Skip to content

Commit ba45dc1

Browse files
Made tests faster and added late interaction testing
1 parent: badeeec · commit: ba45dc1

1 file changed

Lines changed: 24 additions & 22 deletions

File tree

tests/test_integration.py

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
import numpy as np
77
import pandas as pd
88
import pytest
9-
from sentence_transformers import SentenceTransformer
109
from sklearn.cluster import KMeans
1110
from sklearn.datasets import fetch_20newsgroups
1211
from sklearn.decomposition import PCA
@@ -15,13 +14,17 @@
1514
GMM,
1615
AutoEncodingTopicModel,
1716
ClusteringTopicModel,
17+
CTop2Vec,
1818
FASTopic,
1919
KeyNMF,
2020
SemanticSignalSeparation,
2121
SensTopic,
2222
Topeax,
2323
load_model,
2424
)
25+
from turftopic.late import LateSentenceTransformer
26+
27+
ENCODER = "sentence-transformers/static-retrieval-mrl-en-v1"
2528

2629

2730
def batched(iterable, n: int):
@@ -56,44 +59,28 @@ def generate_dates(
5659
remove=("headers", "footers", "quotes"),
5760
)
5861
texts = newsgroups.data
59-
trf = SentenceTransformer("paraphrase-MiniLM-L3-v2")
62+
trf = LateSentenceTransformer(ENCODER)
6063
embeddings = np.asarray(trf.encode(texts))
6164
timestamps = generate_dates(n_dates=len(texts))
6265

6366
models = [
64-
GMM(3, encoder=trf),
6567
SemanticSignalSeparation(3, encoder=trf),
66-
KeyNMF(3, encoder=trf),
6768
KeyNMF(3, encoder=trf, cross_lingual=True),
68-
ClusteringTopicModel(
69-
dimensionality_reduction=PCA(10),
70-
clustering=KMeans(3),
71-
feature_importance="c-tf-idf",
72-
encoder=trf,
73-
reduction_method="average",
74-
),
7569
ClusteringTopicModel(
7670
dimensionality_reduction=PCA(10),
7771
clustering=KMeans(3),
7872
feature_importance="centroid",
7973
encoder=trf,
8074
reduction_method="smallest",
8175
),
82-
AutoEncodingTopicModel(3, combined=True),
83-
FASTopic(3, batch_size=None),
84-
SensTopic(),
85-
Topeax(),
76+
AutoEncodingTopicModel(3, combined=False, encoder=trf),
77+
FASTopic(3, batch_size=None, encoder=trf),
78+
SensTopic(encoder=trf),
79+
Topeax(encoder=trf),
8680
]
8781

8882
dynamic_models = [
8983
GMM(3, encoder=trf),
90-
ClusteringTopicModel(
91-
dimensionality_reduction=PCA(10),
92-
clustering=KMeans(3),
93-
feature_importance="centroid",
94-
encoder=trf,
95-
reduction_method="smallest",
96-
),
9784
ClusteringTopicModel(
9885
dimensionality_reduction=PCA(10),
9986
clustering=KMeans(3),
@@ -106,6 +93,8 @@ def generate_dates(
10693

10794
online_models = [KeyNMF(3, encoder=trf)]
10895

96+
late_models = [CTop2Vec(encoder=trf)]
97+
10998

11099
@pytest.mark.parametrize("model", dynamic_models)
111100
def test_fit_dynamic(model):
@@ -122,6 +111,19 @@ def test_fit_dynamic(model):
122111
df = pd.read_csv(out_path)
123112

124113

114+
@pytest.mark.parametrize("model", late_models)
115+
def test_late(model):
116+
doc_topic_matrix = model.fit_transform(
117+
texts,
118+
)
119+
table = model.export_topics(format="csv")
120+
with tempfile.TemporaryDirectory() as tmpdirname:
121+
out_path = Path(tmpdirname).joinpath("topics.csv")
122+
with out_path.open("w") as out_file:
123+
out_file.write(table)
124+
df = pd.read_csv(out_path)
125+
126+
125127
@pytest.mark.parametrize("model", online_models)
126128
def test_fit_online(model):
127129
for epoch in range(5):

0 commit comments

Comments (0)