@@ -118,11 +118,12 @@ def get_feature(self, name: str) -> np.ndarray | None:
118118 def get_all_features (self ) -> dict [str , np .ndarray ]:
119119 return self .features
120120
121- def apply_pca (self ) -> Self :
121+ def apply_pca (self ) -> tuple [ Self , dict [ str , PCA ]] :
122122 if self .pca_processed :
123123 return self
124124
125125 pca_features = {}
126+ pca_models = {}
126127 for name , feature in self .features .items ():
127128 # Test all powers of 2 less than the feature dimension
128129 possible_components_num = []
@@ -138,6 +139,25 @@ def apply_pca(self) -> Self:
138139 break
139140 if name not in pca_features :
140141 pca_features [name ] = feature # If no PCA applied, keep original
142+ pca_models [name ] = None
143+ else :
144+ pca_models [name ] = pca_model
145+
146+ pca_feature_vectors = FeatureVectors (pca_features )
147+
148+ # Avoid reapplying PCA multiple times, if saved to disk,
149+ # this flag will be False when loaded again
150+ pca_feature_vectors .pca_processed = True
151+ return pca_feature_vectors , pca_models
152+
153+ def apply_pca_models (self , pca_models : dict [str , PCA ]) -> Self :
154+ pca_features = {}
155+ for name , feature in self .features .items ():
156+ pca_model = pca_models .get (name , None )
157+ if pca_model is not None :
158+ pca_features [name ] = pca_model .transform (feature )
159+ else :
160+ pca_features [name ] = feature # If no PCA model, keep original
141161
142162 pca_feature_vectors = FeatureVectors (pca_features )
143163
0 commit comments