Skip to content

Commit 826caad

Browse files
committed
proxy tasks pca return pca_models
1 parent 88cbada commit 826caad

3 files changed

Lines changed: 23 additions & 3 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "eisp"
3-
version = "0.3.7"
3+
version = "0.3.8"
44
authors = [{ name = "Clara Ernesto", email = "clrcera05@gmail.com" }]
55
description = "A framework using an ensemble for inference on sensitive data using proxy tasks."
66
readme = "README.md"

src/eisp/proxy_tasks.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,12 @@ def get_feature(self, name: str) -> np.ndarray | None:
118118
def get_all_features(self) -> dict[str, np.ndarray]:
119119
return self.features
120120

121-
def apply_pca(self) -> Self:
121+
def apply_pca(self) -> tuple[Self, dict[str, PCA]]:
122122
if self.pca_processed:
123123
return self
124124

125125
pca_features = {}
126+
pca_models = {}
126127
for name, feature in self.features.items():
127128
# Test all powers of 2 less than the feature dimension
128129
possible_components_num = []
@@ -138,6 +139,25 @@ def apply_pca(self) -> Self:
138139
break
139140
if name not in pca_features:
140141
pca_features[name] = feature # If no PCA applied, keep original
142+
pca_models[name] = None
143+
else:
144+
pca_models[name] = pca_model
145+
146+
pca_feature_vectors = FeatureVectors(pca_features)
147+
148+
# Avoid reapplying PCA multiple times, if saved to disk,
149+
# this flag will be False when loaded again
150+
pca_feature_vectors.pca_processed = True
151+
return pca_feature_vectors, pca_models
152+
153+
def apply_pca_models(self, pca_models: dict[str, PCA]) -> Self:
154+
pca_features = {}
155+
for name, feature in self.features.items():
156+
pca_model = pca_models.get(name, None)
157+
if pca_model is not None:
158+
pca_features[name] = pca_model.transform(feature)
159+
else:
160+
pca_features[name] = feature # If no PCA model, keep original
141161

142162
pca_feature_vectors = FeatureVectors(pca_features)
143163

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)