diff --git a/README.md b/README.md index cbfaa2b..bc13c8f 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ downloads - Python version + Python version License @@ -97,34 +97,17 @@ Neural networks come from different sources. With `thingsvision`, you can extrac ### :computer: Setting up your environment #### Working locally -First, create a new `conda environment` with Python version 3.8, 3.9, 3.10, or 3.11 e.g. by using `conda`: - +First, create a new `conda environment` with Python version 3.10, 3.11, or 3.12 e.g. by using `conda`: ```bash -$ conda create -n thingsvision python=3.9 +$ conda create -n thingsvision python=3.10 $ conda activate thingsvision ``` - Then, activate the environment and simply install `thingsvision` via running the following `pip` command in your terminal. - ```bash $ pip install --upgrade thingsvision -$ pip install git+https://github.com/openai/CLIP.git -``` - -If you want to extract features for [harmonized models](https://vicco-group.github.io/thingsvision/AvailableModels.html#harmonization) from the [Harmonization repo](https://github.com/serre-lab/harmonization), you have to additionally run the following `pip` command in your `thingsvision` environment (FYI: as of now, this seems to be working smoothly on Ubuntu only but not on macOS), - -```bash -$ pip install git+https://github.com/serre-lab/Harmonization.git -$ pip install keras-cv-attention-models>=1.3.5 -``` - -If you want to extract features for [DreamSim](https://dreamsim-nights.github.io/) from the [DreamSim repo](https://github.com/ssundaram21/dreamsim), you have to additionally run the following `pip` command in your `thingsvision` environment, - -```bash -$ pip install dreamsim==0.1.2 ``` -See the [docs](https://vicco-group.github.io/thingsvision/AvailableModels.html#dreamsim) for which `DreamSim` models are available in `thingsvision`. +The package automatically installs the [Harmonization](https://github.com/serre-lab/harmonization) and [DreamSim](https://github.com/ssundaram21/dreamsim) repositories. See the documentation for available [harmonized models](https://vicco-group.github.io/thingsvision/AvailableModels.html#harmonization) and [DreamSim models](https://vicco-group.github.io/thingsvision/AvailableModels.html#dreamsim) in `thingsvision`. #### Google Colab Alternatively, you can use Google Colab to play around with `thingsvision` by uploading your image data to Google Drive (via directory mounting). @@ -252,6 +235,49 @@ for batch in my_dataloader: ... # whatever post-processing you want to add to the extracted features ``` +#### Multi Module Feature Extraction + +It is possible to jointly extract features for multiple `module_names` of a single model. + +##### PyTorch + +```python + +module_names = ['visual', ...] # add more module_names here + +# your custom dataset and dataloader classes come here (for example, a PyTorch data loader) +my_dataset = ... +my_dataloader = ... + +with extractor.batch_extraction(module_names=module_names, output_type="tensor") as e: + for batch in my_dataloader: + ... # whatever preprocessing you want to add to the batch + feature_batch_dict = e.extract_batch( + batch=batch, + flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer + ) + ... # whatever post-processing you want to add to the extracted features +``` + +##### TensorFlow / Keras + +```python +module_names = ['visual', ...] # add more module_names here + +# your custom dataset and dataloader classes come here (for example, TFRecords files) +my_dataset = ... +my_dataloader = ... + +for batch in my_dataloader: + ... # whatever preprocessing you want to add to the batch + feature_batch = extractor.extract_batch( + batch=batch, + module_names=module_names, + flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer + ) + ... # whatever post-processing you want to add to the extracted features +``` + #### Human alignment *Human alignment*: If you want to align the extracted features with human object similarity according to the approach introduced in *[Improving neural network representations using human similiarty judgments](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9febda1c8344cc5f2d51713964864e93-Abstract-Conference.html)* you can optionally `align` the extracted features using the following method: diff --git a/requirements.txt b/requirements.txt index 6f5c79b..46d2c66 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ numpy<2 open_clip_torch==3.* pandas regex +safetensors<0.6 scikit-image scikit-learn scipy diff --git a/setup.py b/setup.py index 75f64f9..62c94a6 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ "open_clip_torch==3.*", "pandas", "regex", + "safetensors<0.6", "scikit-image", "scikit-learn", "scipy", diff --git a/tests/test_features.py b/tests/test_features.py index 9533462..0dbd266 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -44,6 +44,19 @@ def get_4D_features(self): flatten_acts=False, ) return features + + def get_multi_features(self): + model_name = "vgg16_bn" + extractor, _, batches = helper.create_extractor_and_dataloader( + model_name=model_name, pretrained=False, source="torchvision" + ) + module_names = ["features.23", "classifier.3"] + features = extractor.extract_features( + batches=batches, + module_names=module_names, + flatten_acts=False, + ) + return features def test_postprocessing(self): """Test different postprocessing methods (e.g., centering, normalization, compression).""" @@ -89,6 +102,18 @@ def test_storing_4d(self): ) self.check_file_exists("features", format, False) + + def test_storing_multi(self): + features = self.get_multi_features() + for _, feature in features.items(): + for format in set(helper.FILE_FORMATS) - set(["txt"]): + # tests whether features can be saved in any of the formats except txt + save_features( + features=feature, + out_path=helper.OUT_PATH, + file_format=format, + ) + self.check_file_exists(f"features", format, False) def test_splitting_2d(self): n_splits = 3 @@ -129,3 +154,29 @@ def test_splitting_4d(self): file_format="txt", n_splits=n_splits, ) + + def test_splitting_multi(self): + n_splits = 3 + features = self.get_multi_features() + for format in set(helper.FILE_FORMATS) - set(["txt"]): + for _, feature in features.items(): + if format == "pt": + feature = torch.from_numpy(feature) + split_features( + features=feature, + root=helper.OUT_PATH, + file_format=format, + n_splits=n_splits, + ) + + for i in range(1, n_splits): + self.check_file_exists(f"features_{i:02d}", format, False) + + with self.assertRaises(Exception): + for _, feature in features.items(): + split_features( + features=feature, + root=helper.OUT_PATH, + file_format="txt", + n_splits=n_splits, + ) diff --git a/thingsvision/core/extraction/base.py b/thingsvision/core/extraction/base.py index e1bf3fe..6b5db03 100644 --- a/thingsvision/core/extraction/base.py +++ b/thingsvision/core/extraction/base.py @@ -2,7 +2,8 @@ import os import re import warnings -from typing import Callable, Iterator, List, Optional, Union +from typing import Callable, Dict, Iterator, List, Optional, Union +from collections import defaultdict import numpy as np from torchtyping import TensorType @@ -76,17 +77,32 @@ def load_model(self) -> None: def extract_batch( self, batch: Union[TensorType["b", "c", "h", "w"], Array], - module_name: str, - flatten_acts: bool, - output_type: str, + module_name: Optional[str] = None, + module_names: Optional[List[str]] = None, + flatten_acts: bool = False, + output_type: str = "ndarray", ) -> Union[ + # This is the return type when 'module_names' is used + Dict[ + str, + Union[ + Union[ + TensorType["n", "num_maps", "h_prime", "w_prime"], + TensorType["n", "t", "d"], + TensorType["n", "p"], + TensorType["n", "d"], + ], + Array, + ], + ], + # This is the return type when 'module_name' is used (for backward compatibility) Union[ - TensorType["b", "num_maps", "h_prime", "w_prime"], - TensorType["b", "t", "d"], - TensorType["b", "p"], - TensorType["b", "d"], + TensorType["n", "num_maps", "h_prime", "w_prime"], + TensorType["n", "t", "d"], + TensorType["n", "p"], + TensorType["n", "d"], + Array, ], - Array, ]: """Extract the activations of a selected module for every image in a mini-batch. @@ -95,7 +111,9 @@ def extract_batch( batch : np.ndarray or torch.Tensor mini-batch of three-dimensional image tensors. module_name : str - Name of the module for which features should be extraced. + Name of the neural network layer for which features should be extracted. + module_names : List[str] + Names of the modules for which features should be extracted. flatten_acts : bool Whether the activation of a tensor should be flattened to a vector. output_type : str {"ndarray", "tensor"} @@ -112,16 +130,19 @@ def extract_batch( def _extract_batch( self, batch: Union[TensorType["b", "c", "h", "w"], Array], - module_name: str, - flatten_acts: bool, - ) -> Union[ + module_names: List[str], + flatten_acts: bool = False, + ) -> Dict[ + str, Union[ - TensorType["b", "num_maps", "h_prime", "w_prime"], - TensorType["b", "t", "d"], - TensorType["b", "p"], - TensorType["b", "d"], + Union[ + TensorType["n", "num_maps", "h_prime", "w_prime"], + TensorType["n", "t", "d"], + TensorType["n", "p"], + TensorType["n", "d"], + ], + Array, ], - Array, ]: raise NotImplementedError @@ -129,13 +150,16 @@ def get_output_types(self) -> List[str]: """Return the list of available output types (for the feature matrix).""" return ["ndarray", "tensor"] - def _module_and_output_check(self, module_name: str, output_type: str) -> None: + def _module_and_output_check( + self, module_names: List[str], output_type: str + ) -> None: """Checks whether the provided module name and output type are valid.""" valid_names = self.get_module_names() - if not module_name in valid_names: - raise ValueError( - f"\n{module_name} is not a valid module name. Please choose a name from the following set of modules: {valid_names}\n" - ) + for module_name in module_names: + if module_name not in valid_names: + raise ValueError( + f"\n{module_name} is not a valid module name. Please choose a name from the following set of modules: {valid_names}\n" + ) assert ( output_type in self.get_output_types() ), f"\nData type of output feature matrix must be set to one of the following available data types: {self.get_output_types()}\n" @@ -143,19 +167,36 @@ def _module_and_output_check(self, module_name: str, output_type: str) -> None: def extract_features( self, batches: Iterator[Union[TensorType["b", "c", "h", "w"], Array]], - module_name: str, + module_name: Optional[str] = None, + module_names: Optional[List[str]] = None, flatten_acts: bool = False, output_type: Optional[str] = "ndarray", output_dir: Optional[str] = None, step_size: Optional[int] = None, + file_name_suffix: str = "", + save_in_one_file: bool = False, ) -> Union[ + # This is the return type when 'module_names' is used + Dict[ + str, + Union[ + Union[ + TensorType["n", "num_maps", "h_prime", "w_prime"], + TensorType["n", "t", "d"], + TensorType["n", "p"], + TensorType["n", "d"], + ], + Array, + ], + ], + # This is the return type when 'module_name' is used (for backward compatibility) Union[ TensorType["n", "num_maps", "h_prime", "w_prime"], TensorType["n", "t", "d"], TensorType["n", "p"], TensorType["n", "d"], + Array, ], - Array, ]: """Extract hidden unit activations (at specified layer) for every image in the database. @@ -166,8 +207,11 @@ def extract_features( mini-batches, where each element is a subsample of the full (image) dataset. module_name : str - Layer name. Name of neural network layer for - which features should be extraced. + Layer name. Name of the neural network layer for + which features should be extracted. + module_names : List[str] + Layer names. Names of neural network layers for + which features should be extracted. flatten_acts : bool Whether activation tensor (e.g., activations from an early layer of the neural network model) @@ -189,13 +233,28 @@ def extract_features( are saved to disk. The default uses a heuristic so that extracted features should fit into 8GB of free memory. Only used if output_dir is defined. - + file_name_suffix: str + Suffix to append to the output file names (e.g., "_train", "_val"). + save_in_one_file : bool + If True, all features are saved in one file. If output_dir is defined, + the features are saved in separate files for each module name. They are first + saved in chunks of step_size batches, and then all features are concatenated + and saved in one file. Returns ------- output : np.ndarray or torch.Tensor Returns the feature matrix (e.g., $X \in \mathbb{R}^{n \times d}$ if penultimate or logits layer or flatten_acts = True). """ - self._module_and_output_check(module_name, output_type) + if not bool(module_name) ^ bool(module_names): + raise ValueError( + "\nPlease provide either a single module name or a list of module names, but not both.\n" + ) + if module_name is not None: + single_module_call = True + module_names = [module_name] + else: + single_module_call = False + self._module_and_output_check(module_names, output_type) if output_dir: os.makedirs(output_dir, exist_ok=True) @@ -203,58 +262,97 @@ def extract_features( # if step size is not given, assume that features to every image consume 3MB of memory and that the user has at least 8GB of free RAM step_size = 8000 // (len(next(iter(batches))) * 3) + 1 - features = [] + # create feature dict per module name + features = defaultdict(list) + feature_file_names = defaultdict(list) image_ct, last_image_ct = 0, 0 for i, batch in tqdm( enumerate(batches, start=1), desc="Batch", total=len(batches) ): - features.append( - self._extract_batch( - batch=batch, module_name=module_name, flatten_acts=flatten_acts - ) + modules_features = self._extract_batch( + batch=batch, module_names=module_names, flatten_acts=flatten_acts ) image_ct += len(batch) del batch - if output_dir and (i % step_size == 0 or i == len(batches)): - if self.get_backend() == "pt": - features_subset = torch.cat(features) - if output_type == "ndarray": - features_subset = self._to_numpy(features_subset) + for module_name in module_names: + features[module_name].append(modules_features[module_name]) + + if output_dir and (i % step_size == 0 or i == len(batches)): + if self.get_backend() == "pt": + features_subset = torch.cat(features[module_name]) + if output_type == "ndarray": + features_subset = self._to_numpy(features_subset) + features_subset_file = os.path.join( + output_dir, + f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.npy", + ) + np.save(features_subset_file, features_subset) + else: # output_type = tensor + features_subset_file = os.path.join( + output_dir, + f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.pt", + ) + torch.save(features_subset, features_subset_file) + else: features_subset_file = os.path.join( output_dir, - f"features_{last_image_ct}-{image_ct}.npy", + f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.npy", ) + features_subset = np.vstack(features[module_name]) np.save(features_subset_file, features_subset) - else: # output_type = tensor - features_subset_file = os.path.join( - output_dir, - f"features_{last_image_ct}-{image_ct}.pt", - ) - torch.save(features_subset, features_subset_file) - else: - features_subset_file = os.path.join( - output_dir, f"features_{last_image_ct}-{image_ct}.npy" - ) - features_subset = np.vstack(features) - np.save(features_subset_file, features_subset) - features = [] - last_image_ct = image_ct + features = defaultdict(list) + last_image_ct = image_ct + feature_file_names[module_name].append(features_subset_file) print( f"...Features successfully extracted for all {image_ct} images in the database." ) if output_dir: + if save_in_one_file: + # load features per module name and concatenate them + for module_name in module_names: + # load from files + features = [] + for file in feature_file_names[module_name]: + if self.get_backend() == "pt" and output_type != "ndarray": + if file.endswith(".pt"): + features.append( + torch.load(os.path.join(output_dir, file)) + ) + else: + if file.endswith(".npy"): + features.append( + np.load(os.path.join(output_dir, file)) + ) + features_file = os.path.join( + output_dir, f"{module_name}/features{file_name_suffix}" + ) + if output_type == "ndarray": + np.save(f"{features_file}.npy", np.concatenate(features)) + else: # output_type = tensor + torch.save(torch.cat(features), f"{features_file}.pt") + print( + f"...Features for module '{module_name}' were saved to {features_file}." + ) + # remove temporary files + for file in feature_file_names[module_name]: + os.remove(os.path.join(output_dir, file)) + print(f"...Features were saved to {output_dir}.") return None else: - if self.get_backend() == "pt": - features = torch.cat(features) - if output_type == "ndarray": - features = self._to_numpy(features) - else: - features = np.vstack(features) - print(f"...Features shape: {features.shape}") + for module_name in module_names: + if self.get_backend() == "pt": + features[module_name] = torch.cat(features[module_name]) + if output_type == "ndarray": + features[module_name] = self._to_numpy(features[module_name]) + else: + features[module_name] = np.vstack(features[module_name]) + print(f"...Features shape: {features[module_name].shape}") + if single_module_call: + # for backward compatibility + return features[module_name] return features @staticmethod @@ -264,7 +362,7 @@ def _to_numpy( TensorType["n", "t", "d"], TensorType["n", "p"], TensorType["n", "d"], - ] + ], ) -> Array: """Move activations to CPU and convert torch.Tensor to np.ndarray.""" return features.numpy() diff --git a/thingsvision/core/extraction/tensorflow.py b/thingsvision/core/extraction/tensorflow.py index 1dc49f5..aff3320 100644 --- a/thingsvision/core/extraction/tensorflow.py +++ b/thingsvision/core/extraction/tensorflow.py @@ -38,28 +38,44 @@ def __init__( def _extract_batch( self, batch: Array, - module_name: str, + module_names: Optional[List[str]], flatten_acts: bool, - ) -> Array: - layer_out = [self.model.get_layer(module_name).output] + ) -> Dict[str, Array]: + layer_outputs = [self.model.get_layer(name).output for name in module_names] activation_model = keras.models.Model( inputs=self.model.input, - outputs=layer_out, + outputs=layer_outputs, ) - activations = activation_model.predict(batch) + activations_list = activation_model.predict(batch) + if len(module_names) == 1: + activations_list = [activations_list] + activations_dict = { + name: act for name, act in zip(module_names, activations_list) + } if flatten_acts: - activations = activations.reshape(activations.shape[0], -1) - return activations + for name, act in activations_dict.items(): + activations_dict[name] = act.reshape(act.shape[0], -1) + return activations_dict def extract_batch( self, batch: Array, - module_name: str, - flatten_acts: bool, + module_name: Optional[str] = None, + module_names: Optional[List[str]] = None, + flatten_acts: bool = False, output_type: str = "ndarray", - ) -> Array: - self._module_and_output_check(module_name, output_type) - activations = self._extract_batch(batch, module_name, flatten_acts) + ) -> Union[Array, Dict[str, Array]]: + if not bool(module_name) ^ bool(module_names): + raise ValueError( + "\nPlease provide either a single module name or a list of module names, but not both.\n" + ) + if not module_names: + module_names = [module_name] + self._module_and_output_check(module_names, output_type) + # Extract features from the specified module, tensorflow does not support multiple modules extraction + activations = self._extract_batch(batch, module_names, flatten_acts) + if module_name: + return activations[module_name] return activations def show_model(self) -> str: diff --git a/thingsvision/core/extraction/torch.py b/thingsvision/core/extraction/torch.py index 7c86c1a..c9e0bcd 100644 --- a/thingsvision/core/extraction/torch.py +++ b/thingsvision/core/extraction/torch.py @@ -73,21 +73,35 @@ def hook(model, input, output) -> None: return hook - def _register_hook(self, module_name: str) -> None: - """Register a forward hook to store activations.""" + def _register_hooks(self, module_names: List[str]) -> None: + """Register a forward hook to store activations for multiple modules.""" + self.hook_handles = [] for n, m in self.model.named_modules(): - if n == module_name: - self.hook_handle = m.register_forward_hook(self.get_activation(n)) - break - - def _unregister_hook(self) -> None: - """Remove the forward hook.""" - self.hook_handle.remove() + if n in module_names: + handle = m.register_forward_hook(self.get_activation(n)) + self.hook_handles.append(handle) + + def _unregister_hooks(self) -> None: + """Unregister all forward hooks.""" + if self.hook_handles: + for handle in self.hook_handles: + handle.remove() + self.hook_handles = [] + else: + warnings.warn( + "\nNo hooks were registered. Nothing to unregister.\n", + category=UserWarning, + ) - def batch_extraction(self, module_name: str, output_type: str) -> object: + def batch_extraction( + self, + module_name: Optional[str] = None, + module_names: Optional[List[str]] = None, + output_type: str = "ndarray", + ) -> object: """Allows mini-batch extraction for custom data pipeline using a with-statement.""" return BatchExtraction( - extractor=self, module_name=module_name, output_type=output_type + extractor=self, module_name=module_name, module_names=module_names, output_type=output_type ) def extract_batch( @@ -95,88 +109,124 @@ def extract_batch( batch: TensorType["b", "c", "h", "w"], flatten_acts: bool, ) -> Union[ - TensorType["b", "num_maps", "h_prime", "w_prime"], - TensorType["b", "t", "d"], - TensorType["b", "p"], - TensorType["b", "d"], + # This is the return type when 'module_names' is used + Dict[ + str, + Union[ + Union[ + TensorType["n", "num_maps", "h_prime", "w_prime"], + TensorType["n", "t", "d"], + TensorType["n", "p"], + TensorType["n", "d"], + ], + Array, + ], + ], + # This is the return type when 'module_name' is used (for backward compatibility) + Union[ + TensorType["n", "num_maps", "h_prime", "w_prime"], + TensorType["n", "t", "d"], + TensorType["n", "p"], + TensorType["n", "d"], + Array, + ], ]: - act = self._extract_batch( - batch=batch, module_name=self.module_name, flatten_acts=flatten_acts + acts = self._extract_batch( + batch=batch, + module_names=self.module_names, + flatten_acts=flatten_acts, ) if self.output_type == "ndarray": - act = self._to_numpy(act) - return act + for module_name, act in acts.items(): + acts[module_name] = self._to_numpy(act) + if getattr(self, "module_name", None) is not None: + return acts[self.module_name] + return acts @torch.no_grad() def _extract_batch( self, batch: TensorType["b", "c", "h", "w"], - module_name: str, + module_names: List[str], flatten_acts: bool, - ) -> Union[ - TensorType["b", "num_maps", "h_prime", "w_prime"], - TensorType["b", "t", "d"], - TensorType["b", "p"], - TensorType["b", "d"], + ) -> Dict[ + str, + Union[ + TensorType["b", "num_maps", "h_prime", "w_prime"], + TensorType["b", "t", "d"], + TensorType["b", "p"], + TensorType["b", "d"], + ], ]: """Extract representations from a batch of images.""" # move mini-batch to current device batch = batch.to(self.device) batch_size = batch.shape[0] _ = self.forward(batch) - act = self.activations[module_name] - if len(act.shape) > 2: - if hasattr(self, "token_extraction"): - if self.token_extraction == "cls_token": - act = act[:, 0, :].clone() - elif self.token_extraction == "avg_pool": - act = act[:, 1:, :].clone().mean(dim=1) - elif self.token_extraction == "cls_token+avg_pool": - cls_token = act[:, 0, :].clone() - pooled_tokens = act[:, 1:, :].clone().mean(dim=1) - act = torch.cat((cls_token, pooled_tokens), dim=1) - else: - raise ValueError( - f"\n{self.token_extraction} is not a valid value for token extraction. " - "\nChoose one of the following: {TOKEN_EXTRACTIONS}.\n " - ) - elif flatten_acts: - if self.model_name.lower().startswith("clip"): - act = self.flatten_acts(act, batch, module_name) - else: - act = self.flatten_acts(act) - if act.shape[0] != batch_size: - raise ValueError( - f"The number of extracted features ({act.shape=}) does not match the batch size ({batch.shape=}). " - "Please check the model, the module name, and the model parameters ..." - ) - if act.is_cuda or act.get_device() >= 0: - torch.cuda.empty_cache() - act = act.cpu() - return act + acts = {} + for module_name in module_names: + act = self.activations[module_name] + if act.shape[0] != batch_size: + raise ValueError( + f"The number of extracted features ({act.shape=}) does not match the batch size ({batch.shape=}). " + "Please check the model, the module name, and the model parameters ..." + ) + if len(act.shape) > 2: + if hasattr(self, "token_extraction"): + if self.token_extraction == "cls_token": + act = act[:, 0, :].clone() + elif self.token_extraction == "avg_pool": + act = act[:, 1:, :].clone().mean(dim=1) + elif self.token_extraction == "cls_token+avg_pool": + cls_token = act[:, 0, :].clone() + pooled_tokens = act[:, 1:, :].clone().mean(dim=1) + act = torch.cat((cls_token, pooled_tokens), dim=1) + else: + raise ValueError( + f"\n{self.token_extraction} is not a valid value for token extraction. " + "\nChoose one of the following: {TOKEN_EXTRACTIONS}.\n " + ) + elif flatten_acts: + if self.model_name.lower().startswith("clip"): + act = self.flatten_acts(act, batch, module_name) + else: + act = self.flatten_acts(act) + if act.is_cuda or act.get_device() >= 0: + torch.cuda.empty_cache() + act = act.cpu() + acts[module_name] = act + return acts def extract_features( self, batches: Iterator, - module_name: str, + module_name: Optional[str] = None, + module_names: Optional[List[str]] = None, flatten_acts: bool = False, output_type: str = "ndarray", output_dir: Optional[str] = None, step_size: Optional[int] = None, ): + if not bool(module_name) ^ bool(module_names): + raise ValueError( + "\nPlease provide either a single module name or a list of module names, but not both.\n" + ) self.model = self.model.to(self.device) self.activations = {} - self._register_hook(module_name=module_name) + if module_name is not None: + self._register_hooks(module_names=[module_name]) + else: + self._register_hooks(module_names=module_names) features = super().extract_features( batches=batches, module_name=module_name, + module_names=module_names, flatten_acts=flatten_acts, output_type=output_type, output_dir=output_dir, step_size=step_size, ) - if self.hook_handle: - self._unregister_hook() + self._unregister_hooks() return features def forward( @@ -189,7 +239,7 @@ def forward( def flatten_acts( act: Union[ TensorType["b", "num_maps", "h_prime", "w_prime"], TensorType["b", "t", "d"] - ] + ], ) -> TensorType["b", "p"]: """Default flattening of activations.""" return act.view(act.size(0), -1) @@ -276,7 +326,11 @@ def get_backend(self) -> str: class BatchExtraction(object): def __init__( - self, extractor: PyTorchExtractor, module_name: str, output_type: str + self, + extractor: PyTorchExtractor, + module_name: Optional[str] = None, + module_names: Optional[List[str]] = None, + output_type: str = "ndarray", ) -> None: """ Mini-batch extraction object that can be used as a with-statement in a PyTorch extractor. @@ -285,23 +339,33 @@ def __init__( ---------- extractor (object): PyTorchExtractor class. module_name (str): The module of model for which features will be extracted. + module_names (List[str]): List of modules of model for which features will be extracted. output_type (str): Type of the feature matrix returned by the extractor. """ + if not bool(module_name) ^ bool(module_names): + raise ValueError( + "\nPlease provide either a single module name or a list of module names, but not both.\n" + ) + if module_name is not None: + module_names = [module_name] self.extractor = extractor self.module_name = module_name + self.module_names = module_names self.output_type = output_type def __enter__(self) -> PyTorchExtractor: """Registering hooks and setting attributes during opening.""" - self.extractor._module_and_output_check(self.module_name, self.output_type) - self.extractor._register_hook(self.module_name) + self.extractor._module_and_output_check(self.module_names, self.output_type) + self.extractor._register_hooks(self.module_names) setattr(self.extractor, "module_name", self.module_name) + setattr(self.extractor, "module_names", self.module_names) setattr(self.extractor, "output_type", self.output_type) return self.extractor def __exit__(self, *args): """Removing hooks and deleting attributes at closing.""" - self.extractor._unregister_hook() + self.extractor._unregister_hooks() delattr(self.extractor, "module_name") + delattr(self.extractor, "module_names") delattr(self.extractor, "output_type")