66import numpy as np
77import pandas as pd
88import pytest
9- from sentence_transformers import SentenceTransformer
109from sklearn .cluster import KMeans
1110from sklearn .datasets import fetch_20newsgroups
1211from sklearn .decomposition import PCA
1514 GMM ,
1615 AutoEncodingTopicModel ,
1716 ClusteringTopicModel ,
17+ CTop2Vec ,
1818 FASTopic ,
1919 KeyNMF ,
2020 SemanticSignalSeparation ,
2121 SensTopic ,
2222 Topeax ,
2323 load_model ,
2424)
25+ from turftopic .late import LateSentenceTransformer
26+
27+ ENCODER = "sentence-transformers/static-retrieval-mrl-en-v1"
2528
2629
2730def batched (iterable , n : int ):
@@ -56,44 +59,28 @@ def generate_dates(
5659 remove = ("headers" , "footers" , "quotes" ),
5760)
5861texts = newsgroups .data
59- trf = SentenceTransformer ( "paraphrase-MiniLM-L3-v2" )
62+ trf = LateSentenceTransformer ( ENCODER )
6063embeddings = np .asarray (trf .encode (texts ))
6164timestamps = generate_dates (n_dates = len (texts ))
6265
6366models = [
64- GMM (3 , encoder = trf ),
6567 SemanticSignalSeparation (3 , encoder = trf ),
66- KeyNMF (3 , encoder = trf ),
6768 KeyNMF (3 , encoder = trf , cross_lingual = True ),
68- ClusteringTopicModel (
69- dimensionality_reduction = PCA (10 ),
70- clustering = KMeans (3 ),
71- feature_importance = "c-tf-idf" ,
72- encoder = trf ,
73- reduction_method = "average" ,
74- ),
7569 ClusteringTopicModel (
7670 dimensionality_reduction = PCA (10 ),
7771 clustering = KMeans (3 ),
7872 feature_importance = "centroid" ,
7973 encoder = trf ,
8074 reduction_method = "smallest" ,
8175 ),
82- AutoEncodingTopicModel (3 , combined = True ),
83- FASTopic (3 , batch_size = None ),
84- SensTopic (),
85- Topeax (),
76+ AutoEncodingTopicModel (3 , combined = False , encoder = trf ),
77+ FASTopic (3 , batch_size = None , encoder = trf ),
78+ SensTopic (encoder = trf ),
79+ Topeax (encoder = trf ),
8680]
8781
8882dynamic_models = [
8983 GMM (3 , encoder = trf ),
90- ClusteringTopicModel (
91- dimensionality_reduction = PCA (10 ),
92- clustering = KMeans (3 ),
93- feature_importance = "centroid" ,
94- encoder = trf ,
95- reduction_method = "smallest" ,
96- ),
9784 ClusteringTopicModel (
9885 dimensionality_reduction = PCA (10 ),
9986 clustering = KMeans (3 ),
@@ -106,6 +93,8 @@ def generate_dates(
10693
10794online_models = [KeyNMF (3 , encoder = trf )]
10895
96+ late_models = [CTop2Vec (encoder = trf )]
97+
10998
11099@pytest .mark .parametrize ("model" , dynamic_models )
111100def test_fit_dynamic (model ):
@@ -122,6 +111,19 @@ def test_fit_dynamic(model):
122111 df = pd .read_csv (out_path )
123112
124113
114+ @pytest .mark .parametrize ("model" , late_models )
115+ def test_late (model ):
116+ doc_topic_matrix = model .fit_transform (
117+ texts ,
118+ )
119+ table = model .export_topics (format = "csv" )
120+ with tempfile .TemporaryDirectory () as tmpdirname :
121+ out_path = Path (tmpdirname ).joinpath ("topics.csv" )
122+ with out_path .open ("w" ) as out_file :
123+ out_file .write (table )
124+ df = pd .read_csv (out_path )
125+
126+
125127@pytest .mark .parametrize ("model" , online_models )
126128def test_fit_online (model ):
127129 for epoch in range (5 ):
0 commit comments