4444)
4545from turftopic .types import VALID_DISTANCE_METRICS , DistanceMetric
4646from turftopic .utils import safe_binarize
47- from turftopic .vectorizers import PhraseVectorizer
4847from turftopic .vectorizers .default import default_vectorizer
48+ from turftopic .vectorizers .phrases import PhraseVectorizer
4949
5050integer_message = """
5151You tried to pass an integer to ClusteringTopicModel as its first argument.
@@ -719,12 +719,12 @@ def transform(
719719 X = self .vectorizer .transform (raw_documents )
720720 X = normalize (X , axis = 1 , norm = "l1" , copy = False )
721721 X = X * idf_diag
722- doc_topic_matrix = np . exp ( cosine_similarity (X , self .components_ ) )
722+ doc_topic_matrix = cosine_similarity (X , self .components_ )
723723 elif self .feature_importance == "centroid" :
724724 if embeddings is None :
725725 embeddings = self .encode_documents (raw_documents )
726- doc_topic_matrix = np . exp (
727- cosine_similarity ( embeddings , self ._calculate_topic_vectors () )
726+ doc_topic_matrix = cosine_similarity (
727+ embeddings , self ._calculate_topic_vectors ()
728728 )
729729 else :
730730 doc_topic_matrix = safe_binarize (
@@ -909,7 +909,7 @@ def __init__(
909909 reduction_topic_representation : TopicRepresentation = "centroid" ,
910910 window_size : Optional [int ] = 50 ,
911911 step_size : Optional [int ] = 40 ,
912- pooling : Optional [Callable ] = np .mean ,
912+ pooling : Optional [Callable ] = np .nanmean ,
913913 random_state : Optional [int ] = None ,
914914 ):
915915 if dimensionality_reduction is None :
@@ -933,7 +933,10 @@ def __init__(
933933 cluster_selection_method = "eom" ,
934934 )
935935 self .encoder = encoder
936- self .vectorizer = vectorizer
936+ if isinstance (encoder , str ):
937+ encoder = LateSentenceTransformer (encoder )
938+ if vectorizer is None :
939+ vectorizer = PhraseVectorizer ()
937940 self .dimensionality_reduction = dimensionality_reduction
938941 self .clustering = clustering
939942 self .feature_importance = feature_importance
@@ -942,7 +945,7 @@ def __init__(
942945 self .reduction_distance_metric = reduction_distance_metric
943946 self .reduction_topic_representation = reduction_topic_representation
944947 self .random_state = random_state
945- self . model = ClusteringTopicModel (
948+ model = ClusteringTopicModel (
946949 encoder = encoder ,
947950 vectorizer = vectorizer ,
948951 dimensionality_reduction = dimensionality_reduction ,
@@ -955,8 +958,8 @@ def __init__(
955958 reduction_topic_representation = reduction_topic_representation ,
956959 )
957960 super ().__init__ (
958- self . model ,
959- window_size = self . window_size ,
960- step_size = self . step_size ,
961- pooling = self . pooling ,
961+ model ,
962+ window_size = window_size ,
963+ step_size = step_size ,
964+ pooling = pooling ,
962965 )
0 commit comments