diff --git a/thingsvision/core/extraction/base.py b/thingsvision/core/extraction/base.py index 6b5db03..f0d5a32 100644 --- a/thingsvision/core/extraction/base.py +++ b/thingsvision/core/extraction/base.py @@ -277,6 +277,8 @@ def extract_features( del batch for module_name in module_names: + if output_dir: + os.makedirs(os.path.join(output_dir, module_name), exist_ok=True) features[module_name].append(modules_features[module_name]) if output_dir and (i % step_size == 0 or i == len(batches)): diff --git a/thingsvision/core/extraction/torch.py b/thingsvision/core/extraction/torch.py index c9e0bcd..ba52f57 100644 --- a/thingsvision/core/extraction/torch.py +++ b/thingsvision/core/extraction/torch.py @@ -206,6 +206,8 @@ def extract_features( output_type: str = "ndarray", output_dir: Optional[str] = None, step_size: Optional[int] = None, + file_name_suffix: str = "", + save_in_one_file: bool = False, ): if not bool(module_name) ^ bool(module_names): raise ValueError( @@ -225,6 +227,8 @@ def extract_features( output_type=output_type, output_dir=output_dir, step_size=step_size, + file_name_suffix=file_name_suffix, + save_in_one_file=save_in_one_file, ) self._unregister_hooks() return features