diff --git a/README.md b/README.md
index cbfaa2b..bc13c8f 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
-
+
@@ -97,34 +97,17 @@ Neural networks come from different sources. With `thingsvision`, you can extrac
### :computer: Setting up your environment
#### Working locally
-First, create a new `conda environment` with Python version 3.8, 3.9, 3.10, or 3.11 e.g. by using `conda`:
-
+First, create a new `conda environment` with Python version 3.10, 3.11, or 3.12 e.g. by using `conda`:
```bash
-$ conda create -n thingsvision python=3.9
+$ conda create -n thingsvision python=3.10
$ conda activate thingsvision
```
-
Then, activate the environment and simply install `thingsvision` via running the following `pip` command in your terminal.
-
```bash
$ pip install --upgrade thingsvision
-$ pip install git+https://github.com/openai/CLIP.git
-```
-
-If you want to extract features for [harmonized models](https://vicco-group.github.io/thingsvision/AvailableModels.html#harmonization) from the [Harmonization repo](https://github.com/serre-lab/harmonization), you have to additionally run the following `pip` command in your `thingsvision` environment (FYI: as of now, this seems to be working smoothly on Ubuntu only but not on macOS),
-
-```bash
-$ pip install git+https://github.com/serre-lab/Harmonization.git
-$ pip install keras-cv-attention-models>=1.3.5
-```
-
-If you want to extract features for [DreamSim](https://dreamsim-nights.github.io/) from the [DreamSim repo](https://github.com/ssundaram21/dreamsim), you have to additionally run the following `pip` command in your `thingsvision` environment,
-
-```bash
-$ pip install dreamsim==0.1.2
```
-See the [docs](https://vicco-group.github.io/thingsvision/AvailableModels.html#dreamsim) for which `DreamSim` models are available in `thingsvision`.
+The package automatically installs the [Harmonization](https://github.com/serre-lab/harmonization) and [DreamSim](https://github.com/ssundaram21/dreamsim) repositories. See the documentation for available [harmonized models](https://vicco-group.github.io/thingsvision/AvailableModels.html#harmonization) and [DreamSim models](https://vicco-group.github.io/thingsvision/AvailableModels.html#dreamsim) in `thingsvision`.
#### Google Colab
Alternatively, you can use Google Colab to play around with `thingsvision` by uploading your image data to Google Drive (via directory mounting).
@@ -252,6 +235,49 @@ for batch in my_dataloader:
... # whatever post-processing you want to add to the extracted features
```
+#### Multi Module Feature Extraction
+
+It is possible to jointly extract features for multiple `module_names` of a single model.
+
+##### PyTorch
+
+```python
+
+module_names = ['visual', ...] # add more module_names here
+
+# your custom dataset and dataloader classes come here (for example, a PyTorch data loader)
+my_dataset = ...
+my_dataloader = ...
+
+with extractor.batch_extraction(module_names=module_names, output_type="tensor") as e:
+ for batch in my_dataloader:
+ ... # whatever preprocessing you want to add to the batch
+ feature_batch_dict = e.extract_batch(
+ batch=batch,
+ flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer
+ )
+ ... # whatever post-processing you want to add to the extracted features
+```
+
+##### TensorFlow / Keras
+
+```python
+module_names = ['visual', ...] # add more module_names here
+
+# your custom dataset and dataloader classes come here (for example, TFRecords files)
+my_dataset = ...
+my_dataloader = ...
+
+for batch in my_dataloader:
+ ... # whatever preprocessing you want to add to the batch
+ feature_batch = extractor.extract_batch(
+ batch=batch,
+ module_names=module_names,
+ flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer
+ )
+ ... # whatever post-processing you want to add to the extracted features
+```
+
#### Human alignment
*Human alignment*: If you want to align the extracted features with human object similarity according to the approach introduced in *[Improving neural network representations using human similiarty judgments](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9febda1c8344cc5f2d51713964864e93-Abstract-Conference.html)* you can optionally `align` the extracted features using the following method:
diff --git a/requirements.txt b/requirements.txt
index 6f5c79b..46d2c66 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,7 @@ numpy<2
open_clip_torch==3.*
pandas
regex
+safetensors<0.6
scikit-image
scikit-learn
scipy
diff --git a/setup.py b/setup.py
index 75f64f9..62c94a6 100644
--- a/setup.py
+++ b/setup.py
@@ -15,6 +15,7 @@
"open_clip_torch==3.*",
"pandas",
"regex",
+ "safetensors<0.6",
"scikit-image",
"scikit-learn",
"scipy",
diff --git a/tests/test_features.py b/tests/test_features.py
index 9533462..0dbd266 100644
--- a/tests/test_features.py
+++ b/tests/test_features.py
@@ -44,6 +44,19 @@ def get_4D_features(self):
flatten_acts=False,
)
return features
+
+ def get_multi_features(self):
+ model_name = "vgg16_bn"
+ extractor, _, batches = helper.create_extractor_and_dataloader(
+ model_name=model_name, pretrained=False, source="torchvision"
+ )
+ module_names = ["features.23", "classifier.3"]
+ features = extractor.extract_features(
+ batches=batches,
+ module_names=module_names,
+ flatten_acts=False,
+ )
+ return features
def test_postprocessing(self):
"""Test different postprocessing methods (e.g., centering, normalization, compression)."""
@@ -89,6 +102,18 @@ def test_storing_4d(self):
)
self.check_file_exists("features", format, False)
+
+ def test_storing_multi(self):
+ features = self.get_multi_features()
+ for _, feature in features.items():
+ for format in set(helper.FILE_FORMATS) - set(["txt"]):
+ # tests whether features can be saved in any of the formats except txt
+ save_features(
+ features=feature,
+ out_path=helper.OUT_PATH,
+ file_format=format,
+ )
+ self.check_file_exists(f"features", format, False)
def test_splitting_2d(self):
n_splits = 3
@@ -129,3 +154,29 @@ def test_splitting_4d(self):
file_format="txt",
n_splits=n_splits,
)
+
+ def test_splitting_multi(self):
+ n_splits = 3
+ features = self.get_multi_features()
+ for format in set(helper.FILE_FORMATS) - set(["txt"]):
+ for _, feature in features.items():
+ if format == "pt":
+ feature = torch.from_numpy(feature)
+ split_features(
+ features=feature,
+ root=helper.OUT_PATH,
+ file_format=format,
+ n_splits=n_splits,
+ )
+
+ for i in range(1, n_splits):
+ self.check_file_exists(f"features_{i:02d}", format, False)
+
+ with self.assertRaises(Exception):
+ for _, feature in features.items():
+ split_features(
+ features=feature,
+ root=helper.OUT_PATH,
+ file_format="txt",
+ n_splits=n_splits,
+ )
diff --git a/thingsvision/core/extraction/base.py b/thingsvision/core/extraction/base.py
index e1bf3fe..6b5db03 100644
--- a/thingsvision/core/extraction/base.py
+++ b/thingsvision/core/extraction/base.py
@@ -2,7 +2,8 @@
import os
import re
import warnings
-from typing import Callable, Iterator, List, Optional, Union
+from typing import Callable, Dict, Iterator, List, Optional, Union
+from collections import defaultdict
import numpy as np
from torchtyping import TensorType
@@ -76,17 +77,32 @@ def load_model(self) -> None:
def extract_batch(
self,
batch: Union[TensorType["b", "c", "h", "w"], Array],
- module_name: str,
- flatten_acts: bool,
- output_type: str,
+ module_name: Optional[str] = None,
+ module_names: Optional[List[str]] = None,
+ flatten_acts: bool = False,
+ output_type: str = "ndarray",
) -> Union[
+ # This is the return type when 'module_names' is used
+ Dict[
+ str,
+ Union[
+ Union[
+ TensorType["n", "num_maps", "h_prime", "w_prime"],
+ TensorType["n", "t", "d"],
+ TensorType["n", "p"],
+ TensorType["n", "d"],
+ ],
+ Array,
+ ],
+ ],
+ # This is the return type when 'module_name' is used (for backward compatibility)
Union[
- TensorType["b", "num_maps", "h_prime", "w_prime"],
- TensorType["b", "t", "d"],
- TensorType["b", "p"],
- TensorType["b", "d"],
+ TensorType["n", "num_maps", "h_prime", "w_prime"],
+ TensorType["n", "t", "d"],
+ TensorType["n", "p"],
+ TensorType["n", "d"],
+ Array,
],
- Array,
]:
"""Extract the activations of a selected module for every image in a mini-batch.
@@ -95,7 +111,9 @@ def extract_batch(
batch : np.ndarray or torch.Tensor
mini-batch of three-dimensional image tensors.
module_name : str
- Name of the module for which features should be extraced.
+ Name of the neural network layer for which features should be extracted.
+ module_names : List[str]
+ Names of the modules for which features should be extracted.
flatten_acts : bool
Whether the activation of a tensor should be flattened to a vector.
output_type : str {"ndarray", "tensor"}
@@ -112,16 +130,19 @@ def extract_batch(
def _extract_batch(
self,
batch: Union[TensorType["b", "c", "h", "w"], Array],
- module_name: str,
- flatten_acts: bool,
- ) -> Union[
+ module_names: List[str],
+ flatten_acts: bool = False,
+ ) -> Dict[
+ str,
Union[
- TensorType["b", "num_maps", "h_prime", "w_prime"],
- TensorType["b", "t", "d"],
- TensorType["b", "p"],
- TensorType["b", "d"],
+ Union[
+ TensorType["n", "num_maps", "h_prime", "w_prime"],
+ TensorType["n", "t", "d"],
+ TensorType["n", "p"],
+ TensorType["n", "d"],
+ ],
+ Array,
],
- Array,
]:
raise NotImplementedError
@@ -129,13 +150,16 @@ def get_output_types(self) -> List[str]:
"""Return the list of available output types (for the feature matrix)."""
return ["ndarray", "tensor"]
- def _module_and_output_check(self, module_name: str, output_type: str) -> None:
+ def _module_and_output_check(
+ self, module_names: List[str], output_type: str
+ ) -> None:
"""Checks whether the provided module name and output type are valid."""
valid_names = self.get_module_names()
- if not module_name in valid_names:
- raise ValueError(
- f"\n{module_name} is not a valid module name. Please choose a name from the following set of modules: {valid_names}\n"
- )
+ for module_name in module_names:
+ if module_name not in valid_names:
+ raise ValueError(
+ f"\n{module_name} is not a valid module name. Please choose a name from the following set of modules: {valid_names}\n"
+ )
assert (
output_type in self.get_output_types()
), f"\nData type of output feature matrix must be set to one of the following available data types: {self.get_output_types()}\n"
@@ -143,19 +167,36 @@ def _module_and_output_check(self, module_name: str, output_type: str) -> None:
def extract_features(
self,
batches: Iterator[Union[TensorType["b", "c", "h", "w"], Array]],
- module_name: str,
+ module_name: Optional[str] = None,
+ module_names: Optional[List[str]] = None,
flatten_acts: bool = False,
output_type: Optional[str] = "ndarray",
output_dir: Optional[str] = None,
step_size: Optional[int] = None,
+ file_name_suffix: str = "",
+ save_in_one_file: bool = False,
) -> Union[
+ # This is the return type when 'module_names' is used
+ Dict[
+ str,
+ Union[
+ Union[
+ TensorType["n", "num_maps", "h_prime", "w_prime"],
+ TensorType["n", "t", "d"],
+ TensorType["n", "p"],
+ TensorType["n", "d"],
+ ],
+ Array,
+ ],
+ ],
+ # This is the return type when 'module_name' is used (for backward compatibility)
Union[
TensorType["n", "num_maps", "h_prime", "w_prime"],
TensorType["n", "t", "d"],
TensorType["n", "p"],
TensorType["n", "d"],
+ Array,
],
- Array,
]:
"""Extract hidden unit activations (at specified layer) for every image in the database.
@@ -166,8 +207,11 @@ def extract_features(
mini-batches, where each element is a
subsample of the full (image) dataset.
module_name : str
- Layer name. Name of neural network layer for
- which features should be extraced.
+ Layer name. Name of the neural network layer for
+ which features should be extracted.
+ module_names : List[str]
+ Layer names. Names of neural network layers for
+ which features should be extracted.
flatten_acts : bool
Whether activation tensor (e.g., activations
from an early layer of the neural network model)
@@ -189,13 +233,28 @@ def extract_features(
are saved to disk. The default uses a heuristic so that
extracted features should fit into 8GB of free memory.
Only used if output_dir is defined.
-
+ file_name_suffix: str
+ Suffix to append to the output file names (e.g., "_train", "_val").
+ save_in_one_file : bool
+ If True, all features are saved in one file. If output_dir is defined,
+ the features are saved in separate files for each module name. They are first
+ saved in chunks of step_size batches, and then all features are concatenated
+ and saved in one file.
Returns
-------
output : np.ndarray or torch.Tensor
Returns the feature matrix (e.g., $X \in \mathbb{R}^{n \times d}$ if penultimate or logits layer or flatten_acts = True).
"""
- self._module_and_output_check(module_name, output_type)
+ if not bool(module_name) ^ bool(module_names):
+ raise ValueError(
+ "\nPlease provide either a single module name or a list of module names, but not both.\n"
+ )
+ if module_name is not None:
+ single_module_call = True
+ module_names = [module_name]
+ else:
+ single_module_call = False
+ self._module_and_output_check(module_names, output_type)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
@@ -203,58 +262,97 @@ def extract_features(
# if step size is not given, assume that features to every image consume 3MB of memory and that the user has at least 8GB of free RAM
step_size = 8000 // (len(next(iter(batches))) * 3) + 1
- features = []
+ # create feature dict per module name
+ features = defaultdict(list)
+ feature_file_names = defaultdict(list)
image_ct, last_image_ct = 0, 0
for i, batch in tqdm(
enumerate(batches, start=1), desc="Batch", total=len(batches)
):
- features.append(
- self._extract_batch(
- batch=batch, module_name=module_name, flatten_acts=flatten_acts
- )
+ modules_features = self._extract_batch(
+ batch=batch, module_names=module_names, flatten_acts=flatten_acts
)
image_ct += len(batch)
del batch
- if output_dir and (i % step_size == 0 or i == len(batches)):
- if self.get_backend() == "pt":
- features_subset = torch.cat(features)
- if output_type == "ndarray":
- features_subset = self._to_numpy(features_subset)
+ for module_name in module_names:
+ features[module_name].append(modules_features[module_name])
+
+ if output_dir and (i % step_size == 0 or i == len(batches)):
+ if self.get_backend() == "pt":
+ features_subset = torch.cat(features[module_name])
+ if output_type == "ndarray":
+ features_subset = self._to_numpy(features_subset)
+ features_subset_file = os.path.join(
+ output_dir,
+ f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.npy",
+ )
+ np.save(features_subset_file, features_subset)
+ else: # output_type = tensor
+ features_subset_file = os.path.join(
+ output_dir,
+ f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.pt",
+ )
+ torch.save(features_subset, features_subset_file)
+ else:
features_subset_file = os.path.join(
output_dir,
- f"features_{last_image_ct}-{image_ct}.npy",
+ f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.npy",
)
+ features_subset = np.vstack(features[module_name])
np.save(features_subset_file, features_subset)
- else: # output_type = tensor
- features_subset_file = os.path.join(
- output_dir,
- f"features_{last_image_ct}-{image_ct}.pt",
- )
- torch.save(features_subset, features_subset_file)
- else:
- features_subset_file = os.path.join(
- output_dir, f"features_{last_image_ct}-{image_ct}.npy"
- )
- features_subset = np.vstack(features)
- np.save(features_subset_file, features_subset)
- features = []
- last_image_ct = image_ct
+ features = defaultdict(list)
+ last_image_ct = image_ct
+ feature_file_names[module_name].append(features_subset_file)
print(
f"...Features successfully extracted for all {image_ct} images in the database."
)
if output_dir:
+ if save_in_one_file:
+ # load features per module name and concatenate them
+ for module_name in module_names:
+ # load from files
+ features = []
+ for file in feature_file_names[module_name]:
+ if self.get_backend() == "pt" and output_type != "ndarray":
+ if file.endswith(".pt"):
+ features.append(
+ torch.load(os.path.join(output_dir, file))
+ )
+ else:
+ if file.endswith(".npy"):
+ features.append(
+ np.load(os.path.join(output_dir, file))
+ )
+ features_file = os.path.join(
+ output_dir, f"{module_name}/features{file_name_suffix}"
+ )
+ if output_type == "ndarray":
+ np.save(f"{features_file}.npy", np.concatenate(features))
+ else: # output_type = tensor
+ torch.save(torch.cat(features), f"{features_file}.pt")
+ print(
+ f"...Features for module '{module_name}' were saved to {features_file}."
+ )
+ # remove temporary files
+ for file in feature_file_names[module_name]:
+ os.remove(os.path.join(output_dir, file))
+
print(f"...Features were saved to {output_dir}.")
return None
else:
- if self.get_backend() == "pt":
- features = torch.cat(features)
- if output_type == "ndarray":
- features = self._to_numpy(features)
- else:
- features = np.vstack(features)
- print(f"...Features shape: {features.shape}")
+ for module_name in module_names:
+ if self.get_backend() == "pt":
+ features[module_name] = torch.cat(features[module_name])
+ if output_type == "ndarray":
+ features[module_name] = self._to_numpy(features[module_name])
+ else:
+ features[module_name] = np.vstack(features[module_name])
+ print(f"...Features shape: {features[module_name].shape}")
+ if single_module_call:
+ # for backward compatibility
+ return features[module_name]
return features
@staticmethod
@@ -264,7 +362,7 @@ def _to_numpy(
TensorType["n", "t", "d"],
TensorType["n", "p"],
TensorType["n", "d"],
- ]
+ ],
) -> Array:
"""Move activations to CPU and convert torch.Tensor to np.ndarray."""
return features.numpy()
diff --git a/thingsvision/core/extraction/tensorflow.py b/thingsvision/core/extraction/tensorflow.py
index 1dc49f5..aff3320 100644
--- a/thingsvision/core/extraction/tensorflow.py
+++ b/thingsvision/core/extraction/tensorflow.py
@@ -38,28 +38,44 @@ def __init__(
def _extract_batch(
self,
batch: Array,
- module_name: str,
+ module_names: Optional[List[str]],
flatten_acts: bool,
- ) -> Array:
- layer_out = [self.model.get_layer(module_name).output]
+ ) -> Dict[str, Array]:
+ layer_outputs = [self.model.get_layer(name).output for name in module_names]
activation_model = keras.models.Model(
inputs=self.model.input,
- outputs=layer_out,
+ outputs=layer_outputs,
)
- activations = activation_model.predict(batch)
+ activations_list = activation_model.predict(batch)
+ if len(module_names) == 1:
+ activations_list = [activations_list]
+ activations_dict = {
+ name: act for name, act in zip(module_names, activations_list)
+ }
if flatten_acts:
- activations = activations.reshape(activations.shape[0], -1)
- return activations
+ for name, act in activations_dict.items():
+ activations_dict[name] = act.reshape(act.shape[0], -1)
+ return activations_dict
def extract_batch(
self,
batch: Array,
- module_name: str,
- flatten_acts: bool,
+ module_name: Optional[str] = None,
+ module_names: Optional[List[str]] = None,
+ flatten_acts: bool = False,
output_type: str = "ndarray",
- ) -> Array:
- self._module_and_output_check(module_name, output_type)
- activations = self._extract_batch(batch, module_name, flatten_acts)
+ ) -> Union[Array, Dict[str, Array]]:
+ if not bool(module_name) ^ bool(module_names):
+ raise ValueError(
+ "\nPlease provide either a single module name or a list of module names, but not both.\n"
+ )
+ if not module_names:
+ module_names = [module_name]
+ self._module_and_output_check(module_names, output_type)
+ # Extract features from the specified module, tensorflow does not support multiple modules extraction
+ activations = self._extract_batch(batch, module_names, flatten_acts)
+ if module_name:
+ return activations[module_name]
return activations
def show_model(self) -> str:
diff --git a/thingsvision/core/extraction/torch.py b/thingsvision/core/extraction/torch.py
index 7c86c1a..c9e0bcd 100644
--- a/thingsvision/core/extraction/torch.py
+++ b/thingsvision/core/extraction/torch.py
@@ -73,21 +73,35 @@ def hook(model, input, output) -> None:
return hook
- def _register_hook(self, module_name: str) -> None:
- """Register a forward hook to store activations."""
+ def _register_hooks(self, module_names: List[str]) -> None:
+ """Register a forward hook to store activations for multiple modules."""
+ self.hook_handles = []
for n, m in self.model.named_modules():
- if n == module_name:
- self.hook_handle = m.register_forward_hook(self.get_activation(n))
- break
-
- def _unregister_hook(self) -> None:
- """Remove the forward hook."""
- self.hook_handle.remove()
+ if n in module_names:
+ handle = m.register_forward_hook(self.get_activation(n))
+ self.hook_handles.append(handle)
+
+ def _unregister_hooks(self) -> None:
+ """Unregister all forward hooks."""
+ if self.hook_handles:
+ for handle in self.hook_handles:
+ handle.remove()
+ self.hook_handles = []
+ else:
+ warnings.warn(
+ "\nNo hooks were registered. Nothing to unregister.\n",
+ category=UserWarning,
+ )
- def batch_extraction(self, module_name: str, output_type: str) -> object:
+ def batch_extraction(
+ self,
+ module_name: Optional[str] = None,
+ module_names: Optional[List[str]] = None,
+ output_type: str = "ndarray",
+ ) -> object:
"""Allows mini-batch extraction for custom data pipeline using a with-statement."""
return BatchExtraction(
- extractor=self, module_name=module_name, output_type=output_type
+ extractor=self, module_name=module_name, module_names=module_names, output_type=output_type
)
def extract_batch(
@@ -95,88 +109,124 @@ def extract_batch(
batch: TensorType["b", "c", "h", "w"],
flatten_acts: bool,
) -> Union[
- TensorType["b", "num_maps", "h_prime", "w_prime"],
- TensorType["b", "t", "d"],
- TensorType["b", "p"],
- TensorType["b", "d"],
+ # This is the return type when 'module_names' is used
+ Dict[
+ str,
+ Union[
+ Union[
+ TensorType["n", "num_maps", "h_prime", "w_prime"],
+ TensorType["n", "t", "d"],
+ TensorType["n", "p"],
+ TensorType["n", "d"],
+ ],
+ Array,
+ ],
+ ],
+ # This is the return type when 'module_name' is used (for backward compatibility)
+ Union[
+ TensorType["n", "num_maps", "h_prime", "w_prime"],
+ TensorType["n", "t", "d"],
+ TensorType["n", "p"],
+ TensorType["n", "d"],
+ Array,
+ ],
]:
- act = self._extract_batch(
- batch=batch, module_name=self.module_name, flatten_acts=flatten_acts
+ acts = self._extract_batch(
+ batch=batch,
+ module_names=self.module_names,
+ flatten_acts=flatten_acts,
)
if self.output_type == "ndarray":
- act = self._to_numpy(act)
- return act
+ for module_name, act in acts.items():
+ acts[module_name] = self._to_numpy(act)
+ if getattr(self, "module_name", None) is not None:
+ return acts[self.module_name]
+ return acts
@torch.no_grad()
def _extract_batch(
self,
batch: TensorType["b", "c", "h", "w"],
- module_name: str,
+ module_names: List[str],
flatten_acts: bool,
- ) -> Union[
- TensorType["b", "num_maps", "h_prime", "w_prime"],
- TensorType["b", "t", "d"],
- TensorType["b", "p"],
- TensorType["b", "d"],
+ ) -> Dict[
+ str,
+ Union[
+ TensorType["b", "num_maps", "h_prime", "w_prime"],
+ TensorType["b", "t", "d"],
+ TensorType["b", "p"],
+ TensorType["b", "d"],
+ ],
]:
"""Extract representations from a batch of images."""
# move mini-batch to current device
batch = batch.to(self.device)
batch_size = batch.shape[0]
_ = self.forward(batch)
- act = self.activations[module_name]
- if len(act.shape) > 2:
- if hasattr(self, "token_extraction"):
- if self.token_extraction == "cls_token":
- act = act[:, 0, :].clone()
- elif self.token_extraction == "avg_pool":
- act = act[:, 1:, :].clone().mean(dim=1)
- elif self.token_extraction == "cls_token+avg_pool":
- cls_token = act[:, 0, :].clone()
- pooled_tokens = act[:, 1:, :].clone().mean(dim=1)
- act = torch.cat((cls_token, pooled_tokens), dim=1)
- else:
- raise ValueError(
- f"\n{self.token_extraction} is not a valid value for token extraction. "
- "\nChoose one of the following: {TOKEN_EXTRACTIONS}.\n "
- )
- elif flatten_acts:
- if self.model_name.lower().startswith("clip"):
- act = self.flatten_acts(act, batch, module_name)
- else:
- act = self.flatten_acts(act)
- if act.shape[0] != batch_size:
- raise ValueError(
- f"The number of extracted features ({act.shape=}) does not match the batch size ({batch.shape=}). "
- "Please check the model, the module name, and the model parameters ..."
- )
- if act.is_cuda or act.get_device() >= 0:
- torch.cuda.empty_cache()
- act = act.cpu()
- return act
+ acts = {}
+ for module_name in module_names:
+ act = self.activations[module_name]
+ if act.shape[0] != batch_size:
+ raise ValueError(
+ f"The number of extracted features ({act.shape=}) does not match the batch size ({batch.shape=}). "
+ "Please check the model, the module name, and the model parameters ..."
+ )
+ if len(act.shape) > 2:
+ if hasattr(self, "token_extraction"):
+ if self.token_extraction == "cls_token":
+ act = act[:, 0, :].clone()
+ elif self.token_extraction == "avg_pool":
+ act = act[:, 1:, :].clone().mean(dim=1)
+ elif self.token_extraction == "cls_token+avg_pool":
+ cls_token = act[:, 0, :].clone()
+ pooled_tokens = act[:, 1:, :].clone().mean(dim=1)
+ act = torch.cat((cls_token, pooled_tokens), dim=1)
+ else:
+ raise ValueError(
+ f"\n{self.token_extraction} is not a valid value for token extraction. "
+ "\nChoose one of the following: {TOKEN_EXTRACTIONS}.\n "
+ )
+ elif flatten_acts:
+ if self.model_name.lower().startswith("clip"):
+ act = self.flatten_acts(act, batch, module_name)
+ else:
+ act = self.flatten_acts(act)
+ if act.is_cuda or act.get_device() >= 0:
+ torch.cuda.empty_cache()
+ act = act.cpu()
+ acts[module_name] = act
+ return acts
def extract_features(
self,
batches: Iterator,
- module_name: str,
+ module_name: Optional[str] = None,
+ module_names: Optional[List[str]] = None,
flatten_acts: bool = False,
output_type: str = "ndarray",
output_dir: Optional[str] = None,
step_size: Optional[int] = None,
):
+ if not bool(module_name) ^ bool(module_names):
+ raise ValueError(
+ "\nPlease provide either a single module name or a list of module names, but not both.\n"
+ )
self.model = self.model.to(self.device)
self.activations = {}
- self._register_hook(module_name=module_name)
+ if module_name is not None:
+ self._register_hooks(module_names=[module_name])
+ else:
+ self._register_hooks(module_names=module_names)
features = super().extract_features(
batches=batches,
module_name=module_name,
+ module_names=module_names,
flatten_acts=flatten_acts,
output_type=output_type,
output_dir=output_dir,
step_size=step_size,
)
- if self.hook_handle:
- self._unregister_hook()
+ self._unregister_hooks()
return features
def forward(
@@ -189,7 +239,7 @@ def forward(
def flatten_acts(
act: Union[
TensorType["b", "num_maps", "h_prime", "w_prime"], TensorType["b", "t", "d"]
- ]
+ ],
) -> TensorType["b", "p"]:
"""Default flattening of activations."""
return act.view(act.size(0), -1)
@@ -276,7 +326,11 @@ def get_backend(self) -> str:
class BatchExtraction(object):
def __init__(
- self, extractor: PyTorchExtractor, module_name: str, output_type: str
+ self,
+ extractor: PyTorchExtractor,
+ module_name: Optional[str] = None,
+ module_names: Optional[List[str]] = None,
+ output_type: str = "ndarray",
) -> None:
"""
Mini-batch extraction object that can be used as a with-statement in a PyTorch extractor.
@@ -285,23 +339,33 @@ def __init__(
----------
extractor (object): PyTorchExtractor class.
module_name (str): The module of model for which features will be extracted.
+ module_names (List[str]): List of modules of model for which features will be extracted.
output_type (str): Type of the feature matrix returned by the extractor.
"""
+ if not bool(module_name) ^ bool(module_names):
+ raise ValueError(
+ "\nPlease provide either a single module name or a list of module names, but not both.\n"
+ )
+ if module_name is not None:
+ module_names = [module_name]
self.extractor = extractor
self.module_name = module_name
+ self.module_names = module_names
self.output_type = output_type
def __enter__(self) -> PyTorchExtractor:
"""Registering hooks and setting attributes during opening."""
- self.extractor._module_and_output_check(self.module_name, self.output_type)
- self.extractor._register_hook(self.module_name)
+ self.extractor._module_and_output_check(self.module_names, self.output_type)
+ self.extractor._register_hooks(self.module_names)
setattr(self.extractor, "module_name", self.module_name)
+ setattr(self.extractor, "module_names", self.module_names)
setattr(self.extractor, "output_type", self.output_type)
return self.extractor
def __exit__(self, *args):
"""Removing hooks and deleting attributes at closing."""
- self.extractor._unregister_hook()
+ self.extractor._unregister_hooks()
delattr(self.extractor, "module_name")
+ delattr(self.extractor, "module_names")
delattr(self.extractor, "output_type")