diff --git a/.gitignore b/.gitignore index 3f1d8382..d4e807ff 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ build dist # files from script downloads +data openfold/resources/ tests/test_data/ cutlass/ diff --git a/archive/environment.yml b/archive/environment.yml new file mode 100644 index 00000000..427c5559 --- /dev/null +++ b/archive/environment.yml @@ -0,0 +1,64 @@ +name: vip-zarr-vizfold +channels: + - pytorch + - nvidia + - conda-forge + - bioconda + - defaults + +dependencies: + - python=3.10 + - setuptools=59.5.0 + - pip + + # Core scientific stack + - numpy + - pandas + - scipy + - matplotlib + - pillow + - tqdm + - psutil + - PyYAML + - typing-extensions + - requests + - biopython + + # Notebook / interactive workflow + - jupyterlab + - ipykernel + + # PyTorch with CUDA 12.4 + # CPU-only alternative: comment out the three lines below and uncomment these: + # - pytorch + # - torchvision + # - cpuonly + - pytorch::pytorch=2.5 + - pytorch::torchvision + - pytorch::pytorch-cuda=12.4 + + # OpenFold / VizFold runtime dependencies + # Note: cuda package provides nvcc + headers for compiling attn_core_inplace_cuda via setup.py + # Note: gcc=12.4 is Linux-only; on Windows use MSVC (Visual Studio Build Tools with C++ workload) + - cuda + - openmm + - pdbfixer + - pytorch-lightning + - ml-collections + - mkl + + # Zarr archive stack + - zarr=2.16 + - numcodecs + - fsspec + + - pip: + # OpenFold extras + - dm-tree==0.1.6 + - modelcif==0.7 + # NOTE: do NOT put "-e .." here. + # setup.py imports torch at the top level, so pip's isolated build + # subprocess fails with ModuleNotFoundError before torch is installed. + # After activating this env, install openfold manually from the repo root: + # cd .. + # pip install -e . \ No newline at end of file diff --git a/archive/openfold_sweep_to_zarr.py b/archive/openfold_sweep_to_zarr.py new file mode 100644 index 00000000..40b984bb --- /dev/null +++ b/archive/openfold_sweep_to_zarr.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +import argparse +from itertools import product +import json +import pickle +from pathlib import Path +import subprocess +from typing import Any, Dict, Iterable, List, Tuple + +from standardizedarchive.openfold_zarr_archive import ( + OpenFoldRunResult, + OpenFoldZarrArchive, + score_from_output_dict, + select_best_entries, +) + + +def _flag_value(name: str, value: Any) -> str: + if isinstance(value, bool): + return f"--{name}" if value else "" + return f"--{name} {value}" + + +def _normalize_for_id(value: Any) -> str: + safe = str(value).replace(" ", "_").replace("/", "_") + return safe[:80] + + +def expand_grid(param_grid: Dict[str, List[Any]]) -> Iterable[Dict[str, Any]]: + keys = sorted(param_grid.keys()) + value_lists = [param_grid[key] for key in keys] + for values in product(*value_lists): + yield dict(zip(keys, values)) + + +def _find_output_pickle(run_output_dir: Path) -> Path: + candidates = sorted(run_output_dir.rglob("*_output_dict.pkl")) + if not candidates: + raise FileNotFoundError( + f"No *_output_dict.pkl found under {run_output_dir}. " + "Ensure the command enables --save_outputs." + ) + return candidates[-1] + + +def _run_single_openfold_command( + base_command: str, + params: Dict[str, Any], + run_output_dir: Path, +) -> Tuple[subprocess.CompletedProcess[str], str]: + attn_map_dir = run_output_dir / "attention_maps" + attn_map_dir.mkdir(parents=True, exist_ok=True) + + flags = [_flag_value(name, value) for name, value in sorted(params.items())] + flags = [flag for flag in flags if flag] + + command = " ".join( + [ + base_command, + "--save_outputs", + f"--output_dir {run_output_dir}", + f"--attn_map_dir {attn_map_dir}", + *flags, + ] + ) + + completed = subprocess.run(command, shell=True, capture_output=True, text=True) + return completed, command + + +def _collect_result( + run_id: str, + params: Dict[str, Any], + run_output_dir: Path, + command: str, + score_key: str, + changed_param: str | None = None, + from_value: Any | None = None, + to_value: Any | None = None, + score_delta: float | None = None, + step_index: int | None = None, +) -> OpenFoldRunResult: + output_pickle = _find_output_pickle(run_output_dir) + with output_pickle.open("rb") as handle: + output_dict = pickle.load(handle) + + score = score_from_output_dict(output_dict, score_key=score_key) + return OpenFoldRunResult( + run_id=run_id, + score=score, + params=dict(params), + output_dir=str(run_output_dir), + command=command, + model_output_path=str(output_pickle), + changed_param=changed_param, + from_value=from_value, + to_value=to_value, + score_delta=score_delta, + step_index=step_index, + ) + + +def run_sweep( + base_command: str, + param_grid: Dict[str, List[Any]], + runs_root: Path, + score_key: str, +) -> List[OpenFoldRunResult]: + results: List[OpenFoldRunResult] = [] + failures: List[Tuple[str, int, str, str]] = [] + + for idx, params in enumerate(expand_grid(param_grid), start=1): + run_id_parts = [f"{k}-{_normalize_for_id(v)}" for k, v in sorted(params.items())] + run_id = f"run-{idx:03d}__" + "__".join(run_id_parts) + + run_output_dir = runs_root / run_id + run_output_dir.mkdir(parents=True, exist_ok=True) + + completed, command = _run_single_openfold_command(base_command, params, run_output_dir) + if completed.returncode != 0: + failures.append( + ( + run_id, + completed.returncode, + completed.stderr.strip(), + completed.stdout.strip(), + ) + ) + continue + + results.append( + _collect_result( + run_id=run_id, + params=params, + run_output_dir=run_output_dir, + command=command, + score_key=score_key, + ) + ) + + if failures: + print(f"[openfold-sweep] failed runs: {len(failures)}") + for run_id, returncode, stderr, stdout in failures: + print(f"[openfold-sweep] run={run_id} returncode={returncode}") + if stderr: + print("[openfold-sweep] stderr:") + print(stderr[-2000:]) + elif stdout: + print("[openfold-sweep] stdout:") + print(stdout[-2000:]) + + if not results: + raise RuntimeError( + "All OpenFold sweep runs failed. Check the per-run stderr summaries above." + ) + + return results + + +def run_incremental_sweep( + base_command: str, + param_grid: Dict[str, List[Any]], + runs_root: Path, + score_key: str, +) -> Tuple[List[OpenFoldRunResult], List[OpenFoldRunResult], OpenFoldRunResult]: + if not param_grid: + raise ValueError("param_grid must contain at least one parameter") + + ordered_keys = sorted(param_grid.keys()) + for key in ordered_keys: + values = param_grid[key] + if not isinstance(values, list) or len(values) == 0: + raise ValueError(f"Parameter '{key}' must map to a non-empty list") + + baseline_params = {key: param_grid[key][0] for key in ordered_keys} + all_results: List[OpenFoldRunResult] = [] + best_increment_entries: List[OpenFoldRunResult] = [] + failures: List[Tuple[str, int, str, str]] = [] + + baseline_run_id = "run-000__baseline" + baseline_output_dir = runs_root / baseline_run_id + baseline_output_dir.mkdir(parents=True, exist_ok=True) + baseline_completed, baseline_command = _run_single_openfold_command( + base_command, + baseline_params, + baseline_output_dir, + ) + if baseline_completed.returncode != 0: + raise RuntimeError( + "Baseline incremental run failed. stderr:\n" + + baseline_completed.stderr[-2000:] + ) + + current_best = _collect_result( + run_id=baseline_run_id, + params=baseline_params, + run_output_dir=baseline_output_dir, + command=baseline_command, + score_key=score_key, + step_index=0, + ) + all_results.append(current_best) + + run_counter = 1 + for step_index, key in enumerate(ordered_keys, start=1): + current_value = current_best.params[key] + candidates = [value for value in param_grid[key] if value != current_value] + + best_trial: OpenFoldRunResult | None = None + for candidate in candidates: + trial_params = dict(current_best.params) + trial_params[key] = candidate + + run_id = f"run-{run_counter:03d}__step-{step_index:02d}__{key}-{_normalize_for_id(candidate)}" + run_counter += 1 + + run_output_dir = runs_root / run_id + run_output_dir.mkdir(parents=True, exist_ok=True) + + completed, command = _run_single_openfold_command(base_command, trial_params, run_output_dir) + if completed.returncode != 0: + failures.append( + ( + run_id, + completed.returncode, + completed.stderr.strip(), + completed.stdout.strip(), + ) + ) + continue + + trial_result = _collect_result( + run_id=run_id, + params=trial_params, + run_output_dir=run_output_dir, + command=command, + score_key=score_key, + changed_param=key, + from_value=current_value, + to_value=candidate, + step_index=step_index, + ) + all_results.append(trial_result) + + if best_trial is None or trial_result.score > best_trial.score: + best_trial = trial_result + + if best_trial is None: + continue + + score_delta = best_trial.score - current_best.score + if score_delta > 0: + improved = OpenFoldRunResult( + run_id=best_trial.run_id, + score=best_trial.score, + params=best_trial.params, + output_dir=best_trial.output_dir, + command=best_trial.command, + model_output_path=best_trial.model_output_path, + changed_param=best_trial.changed_param, + from_value=best_trial.from_value, + to_value=best_trial.to_value, + score_delta=score_delta, + step_index=best_trial.step_index, + ) + best_increment_entries.append(improved) + current_best = improved + + if failures: + print(f"[openfold-sweep] failed runs: {len(failures)}") + for run_id, returncode, stderr, stdout in failures: + print(f"[openfold-sweep] run={run_id} returncode={returncode}") + if stderr: + print("[openfold-sweep] stderr:") + print(stderr[-2000:]) + elif stdout: + print("[openfold-sweep] stdout:") + print(stdout[-2000:]) + + if len(all_results) == 1 and not best_increment_entries: + print("[openfold-sweep] no successful incremental candidate runs; baseline only") + + return all_results, best_increment_entries, current_best + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run OpenFold parameter sweep and archive best entries in Zarr") + parser.add_argument("--base_command", required=True, help="Base command used to launch OpenFold runs") + parser.add_argument("--grid_json", required=True, help="JSON file mapping parameter names to candidate values") + parser.add_argument("--runs_root", default="outputs/sweep_runs", help="Directory for per-run outputs") + parser.add_argument("--archive_path", default="standardizedarchive/openfold_best_runs.zarr", help="Zarr archive output path") + parser.add_argument("--best_log_path", default="standardizedarchive/best_entries.jsonl", help="Best entries log path") + parser.add_argument("--top_k", type=int, default=1, help="Number of best entries to keep") + parser.add_argument("--score_key", default="plddt", help="Output dict key used for scoring") + parser.add_argument( + "--sweep_strategy", + choices=["incremental", "grid"], + default="incremental", + help="Sweep strategy: incremental one-parameter-at-a-time, or full grid search", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + with open(args.grid_json, "r", encoding="utf-8") as handle: + param_grid = json.load(handle) + + if not isinstance(param_grid, dict): + raise ValueError("grid_json must contain an object mapping parameter names to value lists") + + runs_root = Path(args.runs_root) + runs_root.mkdir(parents=True, exist_ok=True) + + if args.sweep_strategy == "grid": + results = run_sweep( + base_command=args.base_command, + param_grid=param_grid, + runs_root=runs_root, + score_key=args.score_key, + ) + best_entries = select_best_entries(results, top_k=args.top_k) + else: + results, best_increment_entries, final_best = run_incremental_sweep( + base_command=args.base_command, + param_grid=param_grid, + runs_root=runs_root, + score_key=args.score_key, + ) + if best_increment_entries: + best_entries = best_increment_entries + else: + best_entries = [final_best] + + archive = OpenFoldZarrArchive(args.archive_path) + archive.root.attrs["sweep_strategy"] = args.sweep_strategy + archive.root.attrs["score_key"] = args.score_key + archive.append_best_entries(best_entries) + archive.write_best_log(best_entries, args.best_log_path) + + +if __name__ == "__main__": + main() diff --git a/archive/openfold_zarr_archive.py b/archive/openfold_zarr_archive.py new file mode 100644 index 00000000..dde879c3 --- /dev/null +++ b/archive/openfold_zarr_archive.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +import json +from pathlib import Path +import pickle +import re +import shlex +from typing import Any, Dict, Iterable, List + +import numpy as np + +try: + import zarr +except ImportError as exc: # pragma: no cover - exercised in environments missing zarr + raise ImportError( + "zarr is required for standardizedarchive. Install with `pip install zarr`." + ) from exc + + +@dataclass(frozen=True) +class OpenFoldRunResult: + run_id: str + score: float + params: Dict[str, Any] + output_dir: str + command: str + model_output_path: str + changed_param: str | None = None + from_value: Any | None = None + to_value: Any | None = None + score_delta: float | None = None + step_index: int | None = None + + +def score_from_output_dict(output_dict: Dict[str, Any], score_key: str = "plddt") -> float: + """Compute a scalar quality score from an OpenFold output dictionary.""" + if score_key not in output_dict: + keys = ", ".join(sorted(output_dict.keys())) + raise KeyError(f"Score key '{score_key}' not present in output dict. Keys: {keys}") + + values = np.asarray(output_dict[score_key], dtype=np.float64) + if values.size == 0: + raise ValueError(f"Score key '{score_key}' contained no values") + + return float(values.mean()) + + +def select_best_entries(entries: Iterable[OpenFoldRunResult], top_k: int = 1) -> List[OpenFoldRunResult]: + if top_k < 1: + raise ValueError("top_k must be >= 1") + + ranked = sorted(entries, key=lambda entry: entry.score, reverse=True) + return ranked[:top_k] + + +class OpenFoldZarrArchive: + """Persist best OpenFold sweep results in a Zarr hierarchy.""" + + def __init__(self, archive_path: str | Path): + self.archive_path = Path(archive_path) + self.root = zarr.open_group(str(self.archive_path), mode="a") + self.root.attrs.setdefault("archive_type", "openfold_parameter_sweep") + self.root.attrs.setdefault("created_at_utc", datetime.now(timezone.utc).isoformat()) + + def append_best_entries(self, entries: Iterable[OpenFoldRunResult]) -> None: + best_group = self.root.require_group("best_entries") + + for entry in entries: + run_group = best_group.require_group(entry.run_id) + run_group.attrs["score"] = float(entry.score) + run_group.attrs["params_json"] = json.dumps(entry.params, sort_keys=True) + run_group.attrs["output_dir"] = entry.output_dir + run_group.attrs["command"] = entry.command + run_group.attrs["model_output_path"] = entry.model_output_path + if entry.changed_param is not None: + run_group.attrs["changed_param"] = entry.changed_param + if entry.from_value is not None: + run_group.attrs["from_value_json"] = json.dumps(entry.from_value) + if entry.to_value is not None: + run_group.attrs["to_value_json"] = json.dumps(entry.to_value) + if entry.score_delta is not None: + run_group.attrs["score_delta"] = float(entry.score_delta) + if entry.step_index is not None: + run_group.attrs["step_index"] = int(entry.step_index) + run_group.attrs["saved_at_utc"] = datetime.now(timezone.utc).isoformat() + self._archive_run_artifacts(run_group, entry) + + def write_best_log(self, entries: Iterable[OpenFoldRunResult], log_path: str | Path) -> None: + log_file = Path(log_path) + log_file.parent.mkdir(parents=True, exist_ok=True) + with log_file.open("w", encoding="utf-8") as handle: + for entry in entries: + record = { + "run_id": entry.run_id, + "score": float(entry.score), + "params": entry.params, + "output_dir": entry.output_dir, + "command": entry.command, + "model_output_path": entry.model_output_path, + } + if entry.changed_param is not None: + record["changed_param"] = entry.changed_param + if entry.from_value is not None: + record["from_value"] = entry.from_value + if entry.to_value is not None: + record["to_value"] = entry.to_value + if entry.score_delta is not None: + record["score_delta"] = float(entry.score_delta) + if entry.step_index is not None: + record["step_index"] = int(entry.step_index) + handle.write(json.dumps(record, sort_keys=True) + "\n") + + @staticmethod + def _sanitize_component(name: str) -> str: + safe = re.sub(r"[^0-9A-Za-z._-]+", "_", name) + safe = safe.strip("._") + return safe or "item" + + @staticmethod + def _extract_flag_value(command: str, flag: str) -> str | None: + try: + tokens = shlex.split(command) + except ValueError: + return None + + value: str | None = None + for idx, token in enumerate(tokens): + if token == flag and idx + 1 < len(tokens): + value = tokens[idx + 1] + elif token.startswith(f"{flag}="): + value = token.split("=", 1)[1] + + return value + + @staticmethod + def _as_numpy_array(value: Any) -> np.ndarray | None: + if isinstance(value, np.ndarray): + return value + + if isinstance(value, (int, float, bool, np.number)): + return np.asarray(value) + + if hasattr(value, "detach") and hasattr(value, "cpu") and hasattr(value, "numpy"): + return np.asarray(value.detach().cpu().numpy()) + + if isinstance(value, (list, tuple)): + try: + arr = np.asarray(value) + except Exception: + return None + if arr.dtype == object: + return None + return arr + + return None + + def _write_array_dataset(self, parent_group: Any, name: str, array: np.ndarray) -> None: + key = self._sanitize_component(name) + if key in parent_group: + del parent_group[key] + parent_group.create_dataset( + key, + data=array, + shape=array.shape, + dtype=array.dtype, + overwrite=True, + ) + + def _archive_output_dict_arrays(self, group: Any, obj: Any, depth: int = 0) -> None: + if depth > 8: + return + + if isinstance(obj, dict): + for key, value in sorted(obj.items(), key=lambda item: str(item[0])): + child_name = self._sanitize_component(str(key)) + child_group = group.require_group(child_name) + self._archive_output_dict_arrays(child_group, value, depth + 1) + return + + if isinstance(obj, (list, tuple)) and obj and any(isinstance(x, (dict, list, tuple)) for x in obj): + for idx, item in enumerate(obj): + child_group = group.require_group(f"idx_{idx}") + self._archive_output_dict_arrays(child_group, item, depth + 1) + return + + array = self._as_numpy_array(obj) + if array is None or array.dtype == object: + return + + self._write_array_dataset(group, "values", np.asarray(array)) + + def _archive_file_bytes(self, files_group: Any, file_path: Path, relative_path: Path) -> None: + # Preserve directory structure to avoid collisions for files sharing a basename. + safe_parts = [self._sanitize_component(part) for part in relative_path.parts] + file_group = files_group + for part in safe_parts: + file_group = file_group.require_group(part) + if "bytes" in file_group: + del file_group["bytes"] + + payload = np.frombuffer(file_path.read_bytes(), dtype=np.uint8) + file_group.create_dataset( + "bytes", + data=payload, + shape=payload.shape, + dtype=payload.dtype, + overwrite=True, + ) + file_group.attrs["source_path"] = str(file_path) + file_group.attrs["relative_path"] = str(relative_path) + file_group.attrs["size_bytes"] = int(file_path.stat().st_size) + + def _archive_run_artifacts(self, run_group: Any, entry: OpenFoldRunResult) -> None: + artifacts_group = run_group.require_group("artifacts") + + output_dict_path = Path(entry.model_output_path) + if output_dict_path.exists(): + try: + with output_dict_path.open("rb") as handle: + output_dict = pickle.load(handle) + + activations_group = artifacts_group.require_group("layer_wise_activations") + self._archive_output_dict_arrays(activations_group, output_dict) + except Exception as exc: + artifacts_group.attrs["layer_wise_activations_error"] = str(exc) + + attention_dir = self._extract_flag_value(entry.command, "--attn_map_dir") + attention_group = artifacts_group.require_group("attention_maps") + if attention_dir: + attn_path = Path(attention_dir) + attention_group.attrs["attention_dir"] = str(attn_path) + if attn_path.exists() and attn_path.is_dir(): + files_group = attention_group.require_group("files") + for file_path in sorted(attn_path.rglob("*")): + if file_path.is_file(): + self._archive_file_bytes( + files_group, + file_path, + file_path.relative_to(attn_path), + ) + + structure_group = artifacts_group.require_group("structural_outputs") + run_output_dir = Path(entry.output_dir) + structure_group.attrs["run_output_dir"] = str(run_output_dir) + if run_output_dir.exists() and run_output_dir.is_dir(): + files_group = structure_group.require_group("files") + for file_path in sorted(run_output_dir.rglob("*")): + if file_path.is_file() and file_path.suffix.lower() in {".pdb", ".cif", ".mmcif"}: + self._archive_file_bytes( + files_group, + file_path, + file_path.relative_to(run_output_dir), + ) + + metadata_group = artifacts_group.require_group("metadata") + metadata_group.attrs["saved_at_utc"] = datetime.now(timezone.utc).isoformat() + metadata_group.attrs["run_id"] = entry.run_id + metadata_group.attrs["score"] = float(entry.score) + metadata_group.attrs["params_json"] = json.dumps(entry.params, sort_keys=True) + metadata_group.attrs["command"] = entry.command + + model_version = self._extract_flag_value(entry.command, "--config_preset") + checkpoint_path = self._extract_flag_value(entry.command, "--openfold_checkpoint_path") + if model_version is not None: + metadata_group.attrs["model_version"] = model_version + if checkpoint_path is not None: + metadata_group.attrs["checkpoint_path"] = checkpoint_path diff --git a/archive/run.vizfold.zarr/.zgroup b/archive/run.vizfold.zarr/.zgroup new file mode 100644 index 00000000..3b7daf22 --- /dev/null +++ b/archive/run.vizfold.zarr/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} \ No newline at end of file diff --git a/archive/run.vizfold.zarr/attention/.zgroup b/archive/run.vizfold.zarr/attention/.zgroup new file mode 100644 index 00000000..3b7daf22 --- /dev/null +++ b/archive/run.vizfold.zarr/attention/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} \ No newline at end of file diff --git a/archive/run.vizfold.zarr/attention/triangle_start/.zgroup b/archive/run.vizfold.zarr/attention/triangle_start/.zgroup new file mode 100644 index 00000000..3b7daf22 --- /dev/null +++ b/archive/run.vizfold.zarr/attention/triangle_start/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} \ No newline at end of file diff --git a/archive/run.vizfold.zarr/metadata/.zgroup b/archive/run.vizfold.zarr/metadata/.zgroup new file mode 100644 index 00000000..3b7daf22 --- /dev/null +++ b/archive/run.vizfold.zarr/metadata/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} \ No newline at end of file diff --git a/archive/run.vizfold.zarr/metadata/config_version/.zarray b/archive/run.vizfold.zarr/metadata/config_version/.zarray new file mode 100644 index 00000000..f9df8496 --- /dev/null +++ b/archive/run.vizfold.zarr/metadata/config_version/.zarray @@ -0,0 +1,10 @@ +{ + "chunks": [], + "compressor": null, + "dtype": " str: + """Return the zero-padded layer key string, e.g. 2 -> 'layer_02'.""" + return f"layer_{layer_index:02d}" + + +# ============================================================ +# METHOD 1 +# ============================================================ + +def tensor_to_numpy(tensor): + """ + Convert a VizFold tensor into a NumPy array. + + VizFold outputs may come from: + - PyTorch tensors (detach → cpu → numpy) + - NumPy arrays (returned as-is) + + Parameters + ---------- + tensor : torch.Tensor | numpy.ndarray + + Returns + ------- + numpy.ndarray + A CPU-based NumPy representation of the input. + """ + if isinstance(tensor, np.ndarray): + return tensor + if hasattr(tensor, "detach") and hasattr(tensor, "cpu"): + return tensor.detach().cpu().numpy() + return np.asarray(tensor) + + +# ============================================================ +# METHOD 2 +# ============================================================ + +def open_archive(archive_path: str, overwrite: bool = False): + """ + Open (or create) a VizFold Zarr archive and initialise the required + top-level group structure. + + Creates the following empty groups if they do not already exist: + metadata/ + representations/single/ + representations/pair/ + attention/triangle_start/ + structure/ + + Parameters + ---------- + archive_path : str + File-system path for the Zarr DirectoryStore + (e.g. ``'run.vizfold.zarr'``). + + overwrite : bool + If True, delete and recreate the store. + If False, open for appending; raise if the path exists as a + non-Zarr directory. + + Returns + ------- + zarr.Group + The root group of the opened archive, ready for writing. + """ + if overwrite and os.path.exists(archive_path): + shutil.rmtree(archive_path) + + root = zarr.open_group(archive_path, mode="a") + + root.require_group("metadata") + representations = root.require_group("representations") + representations.require_group("single") + representations.require_group("pair") + attention = root.require_group("attention") + attention.require_group("triangle_start") + root.require_group("structure") + + return root + + +# ============================================================ +# METHOD 3 +# ============================================================ + +def store_metadata( + archive_path: str, + model_version: str, + config_version: str, + sequence: str, + num_residues: int, + num_recycles: int, + recycle_info, + residue_index, + representation_names, +): + """ + Write all run-level metadata into the ``metadata/`` group. + + Archive layout written by this function: + + metadata/model_version <- scalar string + metadata/config_version <- scalar string + metadata/sequence <- scalar string + metadata/num_residues <- scalar int + metadata/num_recycles <- scalar int + metadata/recycle_info <- 1-D float array, shape (num_recycles,) + metadata/residue_index <- 1-D int array, shape (num_residues,) + metadata/representation_names <- 1-D object array of strings + + Scalar strings and ints are stored as length-1 Zarr arrays so that the + archive remains self-describing without external config files. + + Parameters + ---------- + archive_path : str + Root path to the Zarr archive. + + model_version : str + Human-readable model version tag (e.g. ``'openfold-v2.2.0'``). + + config_version : str + Identifier for the model config used (e.g. ``'model_1_ptm'``). + + sequence : str + Amino-acid sequence in one-letter code (length == num_residues). + + num_residues : int + Number of residues in the input sequence. + + num_recycles : int + Number of recycling iterations performed. + + recycle_info : array-like, shape (num_recycles,) + Per-recycle diagnostic values (e.g. pLDDT or RMSD between recycles). + + residue_index : array-like, shape (num_residues,) + Original residue numbering from the input (accommodates gaps). + + representation_names : array-like of str + Labels for each representation layer stored under + ``representations/``. + + Returns + ------- + None + """ + recycle_info = tensor_to_numpy(recycle_info) + residue_index = tensor_to_numpy(residue_index) + + if len(sequence) != num_residues: + raise ValueError( + f"sequence length ({len(sequence)}) != num_residues ({num_residues})" + ) + if recycle_info.shape != (num_recycles,): + raise ValueError( + f"recycle_info shape {recycle_info.shape} != ({num_recycles},)" + ) + if residue_index.shape != (num_residues,): + raise ValueError( + f"residue_index shape {residue_index.shape} != ({num_residues},)" + ) + + root = zarr.open_group(archive_path, mode="a") + meta = root.require_group("metadata") + + scalars = { + "model_version": np.array(model_version), + "config_version": np.array(config_version), + "sequence": np.array(sequence), + "num_residues": np.array(num_residues, dtype=np.int64), + "num_recycles": np.array(num_recycles, dtype=np.int64), + } + for key, value in scalars.items(): + if key in meta: + del meta[key] + meta[key] = value + + arrays = { + "recycle_info": np.asarray(recycle_info), + "residue_index": np.asarray(residue_index, dtype=np.int64), + "representation_names": np.asarray(representation_names), + } + for key, value in arrays.items(): + if key in meta: + del meta[key] + meta[key] = value + + +# ============================================================ +# METHOD 4 +# ============================================================ + +def store_single_representation( + archive_path: str, + layer_index: int, + single_array, + chunks=None, + overwrite: bool = True, +): + """ + Store a per-residue (single) representation for one Evoformer layer. + + Archive location: + representations/single/layer_ + + Expected array shape: + (num_residues, single_dim) + + Each layer's output is stored as a separate Zarr array under the + ``representations/single/`` group. The layer key is zero-padded to + two digits (``layer_00``, ``layer_01``, …) so lexicographic and + numeric ordering agree. + + Parameters + ---------- + archive_path : str + Root path to the Zarr archive. + + layer_index : int + Zero-based Evoformer layer index. + + single_array : numpy.ndarray | torch.Tensor + Per-residue representation, shape (num_residues, single_dim). + + chunks : tuple, optional + Zarr chunk shape. Defaults to one residue per chunk: + ``(1, single_dim)``. + + overwrite : bool + Replace existing data for this layer if True. + + Returns + ------- + None + """ + single_array = tensor_to_numpy(single_array) + + # Validate shape + if single_array.ndim != 2: + raise ValueError( + f"Expected (num_residues, single_dim), got {single_array.shape}" + ) + + num_residues, single_dim = single_array.shape + + root = zarr.open_group(archive_path, mode="a") + group = root["representations"]["single"] + + layer_key = _layer_key(layer_index) + + if layer_key in group: + if not overwrite: + raise FileExistsError(f"{layer_key} already exists") + del group[layer_key] + + # Default chunking: per-residue + if chunks is None: + chunks = (1, single_dim) + + group.create_dataset( + layer_key, + data=single_array, + chunks=chunks, + ) + + +# ============================================================ +# METHOD 5 +# ============================================================ + +def store_pair_representation( + archive_path: str, + layer_index: int, + pair_array, + chunks=None, + overwrite: bool = True, +): + """ + Store a pairwise representation for one Evoformer layer. + + Archive location: + representations/pair/layer_ + + Expected array shape: + (num_residues, num_residues, pair_dim) + + The first two dimensions must be equal (square residue × residue matrix). + Each layer is stored as a separate array; no cross-layer data is mixed. + + Parameters + ---------- + archive_path : str + Root path to the Zarr archive. + + layer_index : int + Zero-based Evoformer layer index. + + pair_array : numpy.ndarray | torch.Tensor + Pairwise representation, shape (num_residues, num_residues, pair_dim). + + chunks : tuple, optional + Zarr chunk shape. Defaults to one residue-row per chunk: + ``(1, num_residues, pair_dim)``. + + overwrite : bool + Replace existing data for this layer if True. + + Returns + ------- + None + """ + pair_array = tensor_to_numpy(pair_array) + + # Validate shape + if pair_array.ndim != 3: + raise ValueError( + f"Expected (num_residues, num_residues, pair_dim), got {pair_array.shape}" + ) + + n_i, n_j, pair_dim = pair_array.shape + + if n_i != n_j: + raise ValueError( + f"Pair representation must be square, got {pair_array.shape}" + ) + + root = zarr.open_group(archive_path, mode="a") + group = root["representations"]["pair"] + + layer_key = _layer_key(layer_index) + + if layer_key in group: + if not overwrite: + raise FileExistsError(f"{layer_key} already exists") + del group[layer_key] + + # Default chunking: one row of pair matrix + if chunks is None: + chunks = (1, n_i, pair_dim) + + group.create_dataset( + layer_key, + data=pair_array, + chunks=chunks, + ) + +# ============================================================ +# METHOD 6 +# ============================================================ + +def store_triangle_attention( + archive_path: str, + layer_index: int, + attention_array, + chunks=None, + overwrite: bool = True, +): + """ + Store triangle-start attention weights for one layer. + + Archive location: + attention/triangle_start/layer_ + + Expected array shape (as captured from the forward hook): + (num_residues, num_residues, num_heads) + + The triangle attention mechanism in VizFold/OpenFold operates on pair + representations. Each element [i, j, h] is the attention weight that + residue-pair (i, j) received from head h during the triangular update + starting at node i. + + Recommended chunking: + (num_residues, num_residues, 1) + so that a single attention head can be read without loading all heads. + + Parameters + ---------- + archive_path : str + Root path to the Zarr archive. + + layer_index : int + Zero-based layer index for the triangle attention block. + + attention_array : numpy.ndarray | torch.Tensor + Attention weights, shape (num_residues, num_residues, num_heads). + + chunks : tuple, optional + Zarr chunk shape. Defaults to ``(num_residues, num_residues, 1)``. + + overwrite : bool + Replace existing data for this layer if True. + + Returns + ------- + None + """ + + + attention_array = tensor_to_numpy(attention_array) + + if attention_array.ndim != 3: + raise ValueError(f"Expected (num_residues, num_residues, num_heads), got {attention_array.shape}") + n_i, n_j, num_heads = attention_array.shape + + if n_i != n_j: + raise ValueError(f"Triangle attention must be square in first two dims, got {attention_array.shape}") + + path_info = archive_path + root = zarr.open_group(path_info, mode="a") + group = root["attention"]["triangle_start"] + layer_key = _layer_key(layer_index) + + if layer_key in group: + if not overwrite: + raise FileExistsError("{} already exists".format(layer_key)) + del group[layer_key] + + if chunks is None: + chunks = (n_i, n_j, 1) + + group.create_dataset(layer_key, data=attention_array, chunks=chunks) + + +# ============================================================ +# METHOD 7 +# ============================================================ + +def store_structure( + archive_path: str, + atom_positions, + atom_mask, + ptm: float, + overwrite: bool = True, +): + """ + Store predicted structure outputs into the ``structure/`` group. + + Archive layout written by this function: + + structure/atom_positions <- 2-D float array, shape (num_atoms, 3) + structure/atom_mask <- 1-D float array, shape (num_atoms,) + structure/ptm <- scalar float + + In OpenFold/VizFold ``num_atoms = num_residues × 37`` because the model + predicts coordinates for all 37 heavy-atom positions per residue + (backbone + up to 14 side-chain atoms, padded to 37). + ``atom_mask`` is 1.0 where an atom is present and 0.0 where padded. + + Parameters + ---------- + archive_path : str + Root path to the Zarr archive. + + atom_positions : numpy.ndarray | torch.Tensor + Predicted 3-D coordinates, shape (num_atoms, 3). + Typically obtained from ``output['final_atom_positions']`` + after flattening the residue dimension: + ``positions.reshape(-1, 3)``. + + atom_mask : numpy.ndarray | torch.Tensor + Binary mask, shape (num_atoms,), indicating valid atom positions. + Obtained from ``output['final_atom_mask'].reshape(-1)``. + + ptm : float + Predicted TM-score (pTM) for the run, scalar in [0, 1]. + + overwrite : bool + Replace existing structure data if True. + + Returns + ------- + None + """ + + atom_positions = tensor_to_numpy(atom_positions) + atom_mask = tensor_to_numpy(atom_mask) + + if atom_positions.ndim != 2: + raise ValueError(f"Expected atom_positions shape (num_atoms, 3), got {atom_positions.shape}") + + if atom_positions.shape[1] != 3: + raise ValueError("Expected atom_positions shape (num_atoms, 3), got {})".format(atom_positions.shape)) + + num_atoms = atom_positions.shape[0] + + if atom_mask.ndim != 1: + raise ValueError(f"Expected atom_mask shape ({num_atoms},), got {atom_mask.shape}") + if atom_mask.shape[0] != num_atoms: + raise ValueError(f"Expected atom_mask shape ({num_atoms},), got {atom_mask.shape}") + + root = zarr.open_group(archive_path, mode="a") + group = root.require_group("structure") + + dataArray = ["atom_positions", "atom_mask", "ptm"] + + for i in dataArray: + if i in group and overwrite: + del group[i] + + group.create_dataset("atom_positions", data=atom_positions) + group.create_dataset("atom_mask", data=atom_mask) + group["ptm"] = np.array(float(ptm)) + +# ============================================================ +# METHOD 8 +# ============================================================ + +def load_single_representation(archive_path: str, layer_index: int) -> np.ndarray: + """ + Load the per-residue (single) representation for one Evoformer layer. + + Archive location read: + representations/single/layer_ + + Parameters + ---------- + archive_path : str + Root path to the Zarr archive. + + layer_index : int + Zero-based Evoformer layer index to retrieve. + + Returns + ------- + numpy.ndarray + Per-residue representation, shape (num_residues, single_dim). + + Raises + ------ + KeyError + If the requested layer does not exist in the archive. + """ + root = zarr.open_group(archive_path, mode="r") + single = root["representations"]["single"] + layer_key = f"layer_{layer_index:02d}" + if layer_key not in single: + raise KeyError(f"Layer not found: representations/single/{layer_key}") + return np.asarray(single[layer_key]) + + +# ============================================================ +# METHOD 9 +# ============================================================ + +def load_pair_representation(archive_path: str, layer_index: int) -> np.ndarray: + """ + Load the pairwise representation for one Evoformer layer. + + Archive location read: + representations/pair/layer_ + + Parameters + ---------- + archive_path : str + Root path to the Zarr archive. + + layer_index : int + Zero-based Evoformer layer index to retrieve. + + Returns + ------- + numpy.ndarray + Pairwise representation, shape (num_residues, num_residues, pair_dim). + + Raises + ------ + KeyError + If the requested layer does not exist in the archive. + """ + root = zarr.open_group(archive_path, mode="r") + pair = root["representations"]["pair"] + layer_key = f"layer_{layer_index:02d}" + if layer_key not in pair: + raise KeyError(f"Layer not found: representations/pair/{layer_key}") + return np.asarray(pair[layer_key]) + + +# ============================================================ +# METHOD 10 +# ============================================================ + +def load_triangle_attention(archive_path: str, layer_index: int, head_index: int = None): + """ + Load triangle-start attention weights for one layer. + + If ``head_index`` is given, return only that head's attention matrix. + Otherwise return the full tensor for all heads. + + Archive location read: + attention/triangle_start/layer_ + + Parameters + ---------- + archive_path : str + Root path to the Zarr archive. + + layer_index : int + Zero-based layer index for the triangle attention block. + + head_index : int, optional + If provided, return only the slice ``[:, :, head_index]``, + shape (num_residues, num_residues). + If None, return the full array, shape + (num_residues, num_residues, num_heads). + + Returns + ------- + numpy.ndarray + Shape (num_residues, num_residues) when ``head_index`` is given, + or (num_residues, num_residues, num_heads) otherwise. + + Raises + ------ + KeyError + If the requested layer does not exist in the archive. + IndexError + If ``head_index`` is out of range. + """ + pass + + +# ============================================================ +# METHOD 11 +# ============================================================ + +def validate_archive(archive_path: str) -> bool: + """ + Validate the integrity of a VizFold Zarr archive. + + Checks that all required top-level groups exist and that every stored + array has the expected number of dimensions and internally consistent + shapes. + + Validation rules + ---------------- + metadata/ + All eight keys must be present. + ``sequence`` length must equal ``num_residues``. + ``recycle_info`` length must equal ``num_recycles``. + ``residue_index`` length must equal ``num_residues``. + + representations/single/ + Each array must be 2-D: (num_residues, single_dim). + ``num_residues`` must be consistent across all layers. + + representations/pair/ + Each array must be 3-D: (num_residues, num_residues, pair_dim). + First two dimensions must be equal and match ``num_residues``. + + attention/triangle_start/ + Each array must be 3-D: (num_residues, num_residues, num_heads). + First two dimensions must be equal and match ``num_residues``. + + structure/ + ``atom_positions`` must be 2-D with last dim == 3. + ``atom_mask`` must be 1-D with length == atom_positions.shape[0]. + ``ptm`` must be a scalar in [0, 1]. + + Parameters + ---------- + archive_path : str + Root path to the Zarr archive. + + Returns + ------- + bool + True if all checks pass; False (or raises) if any check fails. + """ + pass + + +# ============================================================ +# METHOD 12 — top-level orchestrator +# ============================================================ + +def archive_vizfold_run( + archive_path: str, + vizfold_output: dict, + config, + sequence: str, + num_recycles: int, + recycle_info, + overwrite: bool = False, +): + """ + Convert a complete VizFold inference run into a Zarr archive. + + This is the single entry-point that calls all other store_* functions + in the correct order. It expects the outputs captured during a VizFold + forward pass (typically collected via PyTorch forward hooks or returned + directly by the model) and writes them into ``archive_path``. + + Expected keys inside ``vizfold_output`` + ---------------------------------------- + ``single_representations`` : list of tensors, one per Evoformer layer + Each tensor has shape (num_residues, single_dim). + + ``pair_representations`` : list of tensors, one per Evoformer layer + Each tensor has shape (num_residues, num_residues, pair_dim). + + ``triangle_attention_start`` : list of tensors, one per attention layer + Each tensor has shape (num_residues, num_residues, num_heads). + + ``final_atom_positions`` : tensor, shape (num_residues, 37, 3) + All-atom predicted coordinates. Will be reshaped to + (num_residues * 37, 3) before storing. + + ``final_atom_mask`` : tensor, shape (num_residues, 37) + Atom presence mask. Will be reshaped to (num_residues * 37,). + + ``ptm`` : float + Predicted TM-score for this run. + + Processing steps + ---------------- + 1. ``open_archive`` — initialise the store and group tree + 2. ``store_metadata`` — write all run-level metadata + 3. ``store_single_representation`` (loop) — one call per layer + 4. ``store_pair_representation`` (loop) — one call per layer + 5. ``store_triangle_attention`` (loop) — one call per layer + 6. ``store_structure`` — atom positions, mask, pTM + + Parameters + ---------- + archive_path : str + Destination path for the Zarr archive + (e.g. ``'outputs/run.vizfold.zarr'``). + + vizfold_output : dict + Dictionary of tensors/arrays collected during inference + (see expected keys above). + + config : openfold.config.model_config (or equivalent) + Model configuration object; used to extract ``model_version`` + and ``config_version`` strings for the metadata group. + + sequence : str + Input amino-acid sequence in one-letter code. + + num_recycles : int + Number of recycling iterations that were performed. + + recycle_info : array-like, shape (num_recycles,) + Per-recycle diagnostic scalars logged during inference. + + overwrite : bool + If True, overwrite any existing archive at ``archive_path``. + + Returns + ------- + zarr.Group + The root group of the completed archive. + + Raises + ------ + KeyError + If a required key is missing from ``vizfold_output``. + FileExistsError + If ``archive_path`` already exists and ``overwrite=False``. + """ + pass diff --git a/candidate_format_comparison/archiveformat.py b/candidate_format_comparison/archiveformat.py new file mode 100644 index 00000000..0f7852d5 --- /dev/null +++ b/candidate_format_comparison/archiveformat.py @@ -0,0 +1,546 @@ +import os +import time +import json +import shutil +import random +import numpy as np +import concurrent.futures +import zarr +import h5py +import pandas as pd +import matplotlib.pyplot as plt +from numcodecs import Blosc +from tqdm import tqdm + + +# Global styling +plt.rcParams.update({ + "figure.figsize": (6, 4), + "axes.grid": True, + "font.size": 11 +}) + +# ----------------------------- +# CONFIGURATION +# ----------------------------- +CONFIG = { + "N_res": 256, + "num_layers": 12, + "num_heads": 8, + "hidden_dim": 384, + "pair_dim": 128, + "dtype": "float32", + "runs": 3, + "output_dir": "benchmark_data" +} + +np.random.seed(42) +random.seed(42) + +# ----------------------------- +# DATA GENERATION +# ----------------------------- +def generate_data(cfg): + N = cfg["N_res"] + L = cfg["num_layers"] + H = cfg["num_heads"] + + data = { + "metadata": { + "num_layers": L, + "num_heads": H, + "N_res": N + }, + "activations": [], + "attention": [], + "outputs": {} + } + + for l in range(L): + data["activations"].append({ + "single": np.random.randn(N, cfg["hidden_dim"]).astype(cfg["dtype"]), + "pair": np.random.randn(N, N, cfg["pair_dim"]).astype(cfg["dtype"]) + }) + + data["attention"].append( + np.random.randn(H, N, N).astype(cfg["dtype"]) + ) + + data["outputs"]["coordinates"] = np.random.randn(N, 3).astype(cfg["dtype"]) + data["outputs"]["confidence"] = np.random.randn(N).astype(cfg["dtype"]) + + return data + +import concurrent.futures +import random +import time +import zarr +import h5py + +# ----------------------------- +# PARALLEL ACCESS TEST +# ----------------------------- +def zarr_worker(root, cfg): + l = random.randint(0, cfg["num_layers"] - 1) + h = random.randint(0, cfg["num_heads"] - 1) + i = random.randint(0, cfg["N_res"] - 32) + return root[f"attention/layer_{l}"][h, i:i+32, i:i+32] + + +def hdf5_worker(path, cfg): + # each worker opens its own handle (realistic multi-process behavior) + with h5py.File(path, "r") as f: + l = random.randint(0, cfg["num_layers"] - 1) + h = random.randint(0, cfg["num_heads"] - 1) + i = random.randint(0, cfg["N_res"] - 32) + return f[f"attention/layer_{l}"][h, i:i+32, i:i+32] + + +def benchmark_parallel_access(cfg, num_workers=8, num_tasks=100): + results = {} + + zarr_path = os.path.join(cfg["output_dir"], "data.zarr") + h5_path = os.path.join(cfg["output_dir"], "data.h5") + + root = zarr.open(zarr_path, mode="r") + + # ----------------------------- + # ZARR PARALLEL + # ----------------------------- + start = time.time() + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(zarr_worker, root, cfg) for _ in range(num_tasks)] + _ = [f.result() for f in futures] + results["zarr_parallel"] = time.time() - start + + # ----------------------------- + # HDF5 PARALLEL + # ----------------------------- + start = time.time() + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(hdf5_worker, h5_path, cfg) for _ in range(num_tasks)] + _ = [f.result() for f in futures] + results["hdf5_parallel"] = time.time() - start + + return results + +# ----------------------------- +# NPZ +# ----------------------------- +def write_npz(data, path): + flat = {} + for i, layer in enumerate(data["activations"]): + flat[f"act_{i}_single"] = layer["single"] + flat[f"act_{i}_pair"] = layer["pair"] + flat[f"attn_{i}"] = data["attention"][i] + + flat["coords"] = data["outputs"]["coordinates"] + flat["conf"] = data["outputs"]["confidence"] + + np.savez_compressed(path, **flat) + + +def read_npz_full(path): + return np.load(path) + + +# ----------------------------- +# HDF5 +# ----------------------------- +def write_hdf5(data, path): + with h5py.File(path, "w") as f: + for i, layer in enumerate(data["activations"]): + grp = f.create_group(f"activations/layer_{i}") + grp.create_dataset("single", data=layer["single"]) + grp.create_dataset("pair", data=layer["pair"]) + f.create_dataset(f"attention/layer_{i}", data=data["attention"][i]) + + f.create_dataset("outputs/coords", data=data["outputs"]["coordinates"]) + f.create_dataset("outputs/conf", data=data["outputs"]["confidence"]) + + +def read_hdf5_full(path): + with h5py.File(path, "r") as f: + _ = f["outputs/coords"][:] + + +# ----------------------------- +# ZARR +# ----------------------------- +def write_zarr(data, path, chunk_size=32): + root = zarr.open(path, mode="w") + compressor = Blosc(cname="zstd", clevel=3) + + for i, layer in enumerate(data["activations"]): + grp = root.create_group(f"activations/layer_{i}") + + # single representation + grp.create_dataset( + "single", + data=layer["single"], + chunks=(chunk_size, layer["single"].shape[1]), + compressor=compressor + ) + + # pair representation + grp.create_dataset( + "pair", + data=layer["pair"], + chunks=(chunk_size, chunk_size, layer["pair"].shape[2]), + compressor=compressor + ) + + # attention + root.create_dataset( + f"attention/layer_{i}", + data=data["attention"][i], + chunks=(1, chunk_size, chunk_size), + compressor=compressor + ) + + # outputs + root.create_dataset( + "outputs/coords", + data=data["outputs"]["coordinates"], + compressor=compressor + ) + root.create_dataset( + "outputs/conf", + data=data["outputs"]["confidence"], + compressor=compressor + ) + +def read_zarr_full(path): + root = zarr.open(path, mode="r") + _ = root["outputs/coords"][:] + + +# ----------------------------- +# BENCHMARK HELPERS +# ----------------------------- +def time_fn(fn, *args): + start = time.time() + fn(*args) + return time.time() - start + +def get_size(path): + if os.path.isfile(path): + return os.path.getsize(path) + else: + total = 0 + for root, _, files in os.walk(path): + for f in files: + total += os.path.getsize(os.path.join(root, f)) + return total + +def benchmark_write(data, cfg): + results = {} + + os.makedirs(cfg["output_dir"], exist_ok=True) + + # NPZ + npz_path = os.path.join(cfg["output_dir"], "data.npz") + results["npz_write"] = time_fn(write_npz, data, npz_path) + results["npz_size"] = get_size(npz_path) + + # HDF5 + h5_path = os.path.join(cfg["output_dir"], "data.h5") + results["hdf5_write"] = time_fn(write_hdf5, data, h5_path) + results["hdf5_size"] = get_size(h5_path) + + # ZARR + zarr_path = os.path.join(cfg["output_dir"], "data.zarr") + if os.path.exists(zarr_path): + shutil.rmtree(zarr_path) + results["zarr_write"] = time_fn(write_zarr, data, zarr_path, 32) + results["zarr_size"] = get_size(zarr_path) + + return results + + +def benchmark_reads(cfg): + results = {} + + npz_path = os.path.join(cfg["output_dir"], "data.npz") + h5_path = os.path.join(cfg["output_dir"], "data.h5") + zarr_path = os.path.join(cfg["output_dir"], "data.zarr") + + # FULL READ + results["npz_full_read"] = time_fn(read_npz_full, npz_path) + results["hdf5_full_read"] = time_fn(read_hdf5_full, h5_path) + results["zarr_full_read"] = time_fn(read_zarr_full, zarr_path) + + # PARTIAL READ (layer) + layer = random.randint(0, CONFIG["num_layers"] - 1) + + # NPZ (loads everything) + start = time.time() + d = np.load(npz_path) + _ = d[f"attn_{layer}"] + results["npz_layer_read"] = time.time() - start + + # HDF5 + with h5py.File(h5_path, "r") as f: + start = time.time() + _ = f[f"attention/layer_{layer}"][:] + results["hdf5_layer_read"] = time.time() - start + + # ZARR + root = zarr.open(zarr_path, mode="r") + start = time.time() + _ = root[f"attention/layer_{layer}"][:] + results["zarr_layer_read"] = time.time() - start + + # HEAD READ + head = random.randint(0, CONFIG["num_heads"] - 1) + + with h5py.File(h5_path, "r") as f: + start = time.time() + _ = f[f"attention/layer_{layer}"][head] + results["hdf5_head_read"] = time.time() - start + + start = time.time() + _ = root[f"attention/layer_{layer}"][head] + results["zarr_head_read"] = time.time() - start + + # RANDOM ACCESS LOOP + def random_access_zarr(): + for _ in range(500): + l = random.randint(0, CONFIG["num_layers"] - 1) + h = random.randint(0, CONFIG["num_heads"] - 1) + i = random.randint(0, CONFIG["N_res"] - 32) + _ = root[f"attention/layer_{l}"][h, i:i+32, i:i+32] + + def random_access_hdf5(): + with h5py.File(h5_path, "r") as f: + for _ in range(500): + l = random.randint(0, CONFIG["num_layers"] - 1) + h = random.randint(0, CONFIG["num_heads"] - 1) + i = random.randint(0, CONFIG["N_res"] - 32) + _ = f[f"attention/layer_{l}"][h, i:i+32, i:i+32] + + results["zarr_random"] = time_fn(random_access_zarr) + results["hdf5_random"] = time_fn(random_access_hdf5) + + return results + + +def viz_simulation_zarr(root, cfg): + L = cfg["num_layers"] + H = cfg["num_heads"] + N = cfg["N_res"] + + start = time.time() + + for l in range(L): + for h in range(H): + i = random.randint(0, N - 32) + _ = root[f"attention/layer_{l}"][h, i:i+32, i:i+32] + + return time.time() - start + +def viz_simulation_hdf5(path, cfg): + L = cfg["num_layers"] + H = cfg["num_heads"] + N = cfg["N_res"] + + start = time.time() + + with h5py.File(path, "r") as f: + for l in range(L): + for h in range(H): + i = random.randint(0, N - 32) + _ = f[f"attention/layer_{l}"][h, i:i+32, i:i+32] + + return time.time() - start + +# ----------------------------- +# MAIN +# ----------------------------- +def run_benchmark(): + base_cfg = CONFIG + + N_values = [256, 512, 1024, 2048] + + all_results = [] + + for N in N_values: + print(f"\n=== Running for N_res = {N} ===") + + cfg = base_cfg.copy() + cfg["N_res"] = N + + data = generate_data(cfg) + + write_res = benchmark_write(data, cfg) + + npz_path = os.path.join(cfg["output_dir"], "data.npz") + h5_path = os.path.join(cfg["output_dir"], "data.h5") + zarr_path = os.path.join(cfg["output_dir"], "data.zarr") + + root = zarr.open(zarr_path, mode="r") + read_res = benchmark_reads(cfg) + + read_res["viz_hdf5"] = viz_simulation_hdf5(h5_path, cfg) + read_res["viz_zarr"] = viz_simulation_zarr(root, cfg) + + parallel_res = benchmark_parallel_access(cfg, num_workers=8, num_tasks=200) + + combined = {"N_res": N, **write_res, **read_res, **parallel_res} + all_results.append(combined) + + df = pd.DataFrame(all_results) + + df["npz_write_MBps"] = df["npz_size"] / (df["npz_write"] * 1e6) + df["hdf5_write_MBps"] = df["hdf5_size"] / (df["hdf5_write"] * 1e6) + df["zarr_write_MBps"] = df["zarr_size"] / (df["zarr_write"] * 1e6) + + def useful_bytes(cfg): + return 32 * 32 * 4 # float32 + + # Partial Read Efficiency + df["useful_bytes"] = useful_bytes(CONFIG) + df["hdf5_efficiency"] = df["useful_bytes"] / df["hdf5_head_read"] + df["zarr_efficiency"] = df["useful_bytes"] / df["zarr_head_read"] + + print("\n=== RESULTS ===") + print(df) + df.to_csv("benchmark_results.csv", index=False) + + + # # ========================================================= + # # Write Performance (System Cost Baseline) + # # ========================================================= + # plt.figure() + # plt.plot(df["N_res"], df["hdf5_write"], marker="o", label="HDF5") + # plt.plot(df["N_res"], df["zarr_write"], marker="o", label="Zarr") + # plt.title("Write Performance vs Sequence Length") + # plt.xlabel("Sequence Length (N_res)") + # plt.ylabel("Time (seconds)") + # plt.legend() + # plt.tight_layout() + # plt.show() + + + # ========================================================= + # Storage Footprint (Efficiency Baseline) + # ========================================================= + plt.figure() + plt.plot(df["N_res"], df["hdf5_size"], marker="o", label="HDF5") + plt.plot(df["N_res"], df["zarr_size"], marker="o", label="Zarr") + plt.plot(df["N_res"], df["npz_size"], marker="o", label="NPZ") + plt.title("Storage Size Scaling") + plt.xlabel("Sequence Length (N_res)") + plt.ylabel("Bytes") + plt.legend() + plt.tight_layout() + plt.show() + + + # ========================================================= + # Full-Archive Read (Not Primary Workload) + # ========================================================= + plt.figure() + plt.plot(df["N_res"], df["hdf5_full_read"], marker="o", label="HDF5") + plt.plot(df["N_res"], df["zarr_full_read"], marker="o", label="Zarr") + plt.plot(df["N_res"], df["npz_full_read"], marker="o", label="NPZ") + plt.title("Full Archive Read Performance") + plt.xlabel("Sequence Length (N_res)") + plt.ylabel("Time (seconds)") + plt.legend() + plt.tight_layout() + plt.show() + + + # ========================================================= + # Partial Access Latency (Interpretability Core) + # ========================================================= + plt.figure() + plt.plot(df["N_res"], df["hdf5_layer_read"], marker="o", label="HDF5") + plt.plot(df["N_res"], df["zarr_layer_read"], marker="o", label="Zarr") + plt.title("Layer-Level Partial Read Latency") + plt.xlabel("Sequence Length (N_res)") + plt.ylabel("Time (seconds)") + plt.legend() + plt.tight_layout() + plt.show() + + + # # ========================================================= + # # Fine-Grained Access (Head-Level Querying) + # # ========================================================= + # plt.figure() + # plt.plot(df["N_res"], df["hdf5_head_read"], marker="o", label="HDF5") + # plt.plot(df["N_res"], df["zarr_head_read"], marker="o", label="Zarr") + # plt.title("Attention Head Access Latency") + # plt.xlabel("Sequence Length (N_res)") + # plt.ylabel("Time (seconds)") + # plt.legend() + # plt.tight_layout() + # plt.show() + + + # ========================================================= + # Random Access Scaling (Key Interpretability Signal) + # ========================================================= + plt.figure() + plt.plot(df["N_res"], df["hdf5_random"], marker="o", label="HDF5") + plt.plot(df["N_res"], df["zarr_random"], marker="o", label="Zarr") + plt.title("Random Access Scaling (Interpretability Simulation)") + plt.xlabel("Sequence Length (N_res)") + plt.ylabel("Time (seconds)") + plt.legend() + plt.tight_layout() + plt.show() + + + # # ========================================================= + # # End-to-End Interpretability Workload (VizFold) + # # ========================================================= + # plt.figure() + # plt.plot(df["N_res"], df["viz_hdf5"], marker="o", label="HDF5") + # plt.plot(df["N_res"], df["viz_zarr"], marker="o", label="Zarr") + # plt.title("End-to-End Interpretability Workload (VizFold Simulation)") + # plt.xlabel("Sequence Length (N_res)") + # plt.ylabel("Time (seconds)") + # plt.legend() + # plt.tight_layout() + # plt.show() + + + # # ========================================================= + # # Access Pattern Sensitivity + # # ========================================================= + # plt.figure() + # plt.plot(df["N_res"], df["hdf5_random"], linestyle="--", label="HDF5 (Random)") + # plt.plot(df["N_res"], df["hdf5_layer_read"], linestyle="-", label="HDF5 (Layer)") + # plt.plot(df["N_res"], df["zarr_random"], linestyle="--", label="Zarr (Random)") + # plt.plot(df["N_res"], df["zarr_layer_read"], linestyle="-", label="Zarr (Layer)") + + # plt.title("Access Pattern Sensitivity Comparison") + # plt.xlabel("Sequence Length (N_res)") + # plt.ylabel("Time (seconds)") + # plt.legend() + # plt.tight_layout() + # plt.show() + + + # ========================================================= + # Parallel Interpretability Workload + # ========================================================= + plt.figure() + plt.plot(df["N_res"], df["hdf5_parallel"], marker="o", label="HDF5") + plt.plot(df["N_res"], df["zarr_parallel"], marker="o", label="Zarr") + + plt.title("Parallel Interpretability Workload") + plt.xlabel("Sequence Length (N_res)") + plt.ylabel("Time (seconds)") + plt.legend() + plt.tight_layout() + plt.show() + + return df + + +if __name__ == "__main__": + run_benchmark() \ No newline at end of file diff --git a/candidate_format_comparison/benchmark_results.csv b/candidate_format_comparison/benchmark_results.csv new file mode 100644 index 00000000..798103cd --- /dev/null +++ b/candidate_format_comparison/benchmark_results.csv @@ -0,0 +1,5 @@ +N_res,npz_write,npz_size,hdf5_write,hdf5_size,zarr_write,zarr_size,npz_full_read,hdf5_full_read,zarr_full_read,npz_layer_read,hdf5_layer_read,zarr_layer_read,hdf5_head_read,zarr_head_read,zarr_random,hdf5_random,viz_hdf5,viz_zarr,zarr_parallel,hdf5_parallel,npz_write_MBps,hdf5_write_MBps,zarr_write_MBps,useful_bytes,hdf5_efficiency,zarr_efficiency +256,20.959038972854614,400619395,0.4188868999481201,432569736,3.2118980884552,387861993,0.0011882781982421875,0.0011639595031738281,0.0006070137023925781,0.008630990982055664,0.0006988048553466797,0.04167675971984863,0.00026297569274902344,0.004118919372558594,0.14974498748779297,0.029961109161376953,0.0058100223541259766,0.029222965240478516,0.09261107444763184,0.037297964096069336,19.114397159090533,1032.6647504459427,120.7578765945674,4096,15575584.029011786,994435.5860152813 +512,83.98201513290405,1593689281,0.6405029296875,1720749448,10.872176885604858,1542865571,0.0007750988006591797,0.0013039112091064453,0.0003631114959716797,0.03376269340515137,0.004105806350708008,0.16460800170898438,0.0004391670227050781,0.012581825256347656,0.24428105354309082,0.0844881534576416,0.015540122985839844,0.05322766304016113,0.08516788482666016,0.0678720474243164,18.976554426301142,2686.559839530398,141.90953543469365,4096,9326747.65689468,325548.949897673 +1024,348.0411250591278,6357243010,4.268572807312012,6864022920,44.16462802886963,6154440006,0.004289865493774414,0.0027647018432617188,0.00041222572326660156,0.14620518684387207,0.025741100311279297,1.0825958251953125,0.002031087875366211,0.0656440258026123,0.38530993461608887,0.13773298263549805,0.031294822692871094,0.09022998809814453,0.11237192153930664,0.08473801612854004,18.26578111687228,1608.036978599033,139.35224365474,4096,2016653.2672848925,62397.14810173936 +2048,1411.5884323120117,25393961662,281.561222076416,27418226056,351.5809578895569,24585016703,0.011762857437133789,0.016555070877075195,0.0022859573364257812,0.556689977645874,0.09723091125488281,3.3311328887939453,0.0060617923736572266,0.21351218223571777,0.42122387886047363,0.2177579402923584,0.05145597457885742,0.10213494300842285,0.0932619571685791,0.1259291172027588,17.98963570451463,97.37926925377054,69.92704283695296,4096,675707.7358505408,19183.917081967764 diff --git a/candidate_format_comparison/fullarchivereadperformance.png b/candidate_format_comparison/fullarchivereadperformance.png new file mode 100644 index 00000000..ef42499f Binary files /dev/null and b/candidate_format_comparison/fullarchivereadperformance.png differ diff --git a/candidate_format_comparison/layerlevelpartialreadlatency.png b/candidate_format_comparison/layerlevelpartialreadlatency.png new file mode 100644 index 00000000..2974d8b3 Binary files /dev/null and b/candidate_format_comparison/layerlevelpartialreadlatency.png differ diff --git a/candidate_format_comparison/parallelinterpretabilityworkload.png b/candidate_format_comparison/parallelinterpretabilityworkload.png new file mode 100644 index 00000000..9f3491fa Binary files /dev/null and b/candidate_format_comparison/parallelinterpretabilityworkload.png differ diff --git a/candidate_format_comparison/randomaccessscaling.png b/candidate_format_comparison/randomaccessscaling.png new file mode 100644 index 00000000..616d1c0c Binary files /dev/null and b/candidate_format_comparison/randomaccessscaling.png differ diff --git a/candidate_format_comparison/requirements.txt b/candidate_format_comparison/requirements.txt new file mode 100644 index 00000000..ea654a16 --- /dev/null +++ b/candidate_format_comparison/requirements.txt @@ -0,0 +1,13 @@ +numpy>=1.24 + +pandas>=2.0 + +matplotlib>=3.7 + +h5py>=3.10 + +zarr<3 + +numcodecs>=0.11 + +tqdm>=4.65 \ No newline at end of file diff --git a/candidate_format_comparison/storagesizescaling.png b/candidate_format_comparison/storagesizescaling.png new file mode 100644 index 00000000..ab4f487d Binary files /dev/null and b/candidate_format_comparison/storagesizescaling.png differ diff --git a/standardized_archive/MNIST/.gitignore b/standardized_archive/MNIST/.gitignore new file mode 100644 index 00000000..c67da8ee --- /dev/null +++ b/standardized_archive/MNIST/.gitignore @@ -0,0 +1,17 @@ +# datasets +data/ + +# generated archives +*.zarr + +# virtual environments +zarr_env/ +venv/ +.env/ + +# python cache +__pycache__/ +*.pyc + +# OS files +.DS_Store \ No newline at end of file diff --git a/standardized_archive/MNIST/mnist_trace_project/archive/schema.py b/standardized_archive/MNIST/mnist_trace_project/archive/schema.py new file mode 100644 index 00000000..ebcadc22 --- /dev/null +++ b/standardized_archive/MNIST/mnist_trace_project/archive/schema.py @@ -0,0 +1,45 @@ +import zarr +from numcodecs import Blosc + +def create_archive(path, dataset_size): + compressor = Blosc(cname='zstd', clevel=5, shuffle=Blosc.SHUFFLE) + + root = zarr.open(path, mode="w") + + root.create_dataset( + "inputs/images", + shape=(dataset_size, 1, 28, 28), + chunks=(64, 1, 28, 28), + dtype="float32", + compressor=compressor, + ) + + root.create_dataset( + "outputs/logits", + shape=(dataset_size, 10), + chunks=(64, 10), + dtype="float32", + compressor=compressor, + ) + + root.create_dataset( + "outputs/predictions", + shape=(dataset_size,), + chunks=(64,), + dtype="int64", + compressor=compressor, + ) + + # Activations (example shapes — update after first forward pass if needed) + root.create_dataset( + "activations/conv1", + shape=(dataset_size, 16, 26, 26), + chunks=(64, 16, 26, 26), + dtype="float32", + compressor=compressor, + ) + + root.attrs["dataset"] = "MNIST" + root.attrs["archive_version"] = "v1" + + return root \ No newline at end of file diff --git a/standardized_archive/MNIST/mnist_trace_project/archive/writer.py b/standardized_archive/MNIST/mnist_trace_project/archive/writer.py new file mode 100644 index 00000000..284893c6 --- /dev/null +++ b/standardized_archive/MNIST/mnist_trace_project/archive/writer.py @@ -0,0 +1,10 @@ +def write_batch(root, start_idx, images, logits, preds, activations): + batch_size = images.shape[0] + end_idx = start_idx + batch_size + + root["inputs/images"][start_idx:end_idx] = images + root["outputs/logits"][start_idx:end_idx] = logits + root["outputs/predictions"][start_idx:end_idx] = preds + root["activations/conv1"][start_idx:end_idx] = activations["conv1"] + + return end_idx \ No newline at end of file diff --git a/standardized_archive/MNIST/mnist_trace_project/benchmarks/benchmark_zarr.py b/standardized_archive/MNIST/mnist_trace_project/benchmarks/benchmark_zarr.py new file mode 100644 index 00000000..0c99ff74 --- /dev/null +++ b/standardized_archive/MNIST/mnist_trace_project/benchmarks/benchmark_zarr.py @@ -0,0 +1,9 @@ +import time +import zarr + +def benchmark(path): + root = zarr.open(path, mode="r") + + start = time.time() + _ = root["activations/conv1"][500] + print("Random read time:", time.time() - start) \ No newline at end of file diff --git a/standardized_archive/MNIST/mnist_trace_project/config.py b/standardized_archive/MNIST/mnist_trace_project/config.py new file mode 100644 index 00000000..e69de29b diff --git a/standardized_archive/MNIST/mnist_trace_project/inference/run_inference.py b/standardized_archive/MNIST/mnist_trace_project/inference/run_inference.py new file mode 100644 index 00000000..7ef85097 --- /dev/null +++ b/standardized_archive/MNIST/mnist_trace_project/inference/run_inference.py @@ -0,0 +1,37 @@ +import torch +from torchvision import datasets, transforms +from torch.utils.data import DataLoader + +from models.cnn import SimpleCNN +from utils.hooks import register_hooks +from archive.schema import create_archive +from archive.writer import write_batch + +def run(output_path): + + transform = transforms.ToTensor() + dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform) + dataloader = DataLoader(dataset, batch_size=64) + + model = SimpleCNN() + model.eval() + + activations = register_hooks(model) + + root = create_archive(output_path, len(dataset)) + + start_idx = 0 + + with torch.no_grad(): + for images, labels in dataloader: + logits = model(images) + preds = logits.argmax(dim=1) + + start_idx = write_batch( + root, + start_idx, + images.numpy(), + logits.numpy(), + preds.numpy(), + activations + ) \ No newline at end of file diff --git a/standardized_archive/MNIST/mnist_trace_project/main.py b/standardized_archive/MNIST/mnist_trace_project/main.py new file mode 100644 index 00000000..fbcd6d7f --- /dev/null +++ b/standardized_archive/MNIST/mnist_trace_project/main.py @@ -0,0 +1,4 @@ +from inference.run_inference import run + +if __name__ == "__main__": + run("mnist_trace.zarr") \ No newline at end of file diff --git a/standardized_archive/MNIST/mnist_trace_project/models/cnn.py b/standardized_archive/MNIST/mnist_trace_project/models/cnn.py new file mode 100644 index 00000000..ba204239 --- /dev/null +++ b/standardized_archive/MNIST/mnist_trace_project/models/cnn.py @@ -0,0 +1,19 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class SimpleCNN(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 16, 3) + self.conv2 = nn.Conv2d(16, 32, 3) + self.fc1 = nn.Linear(32 * 24 * 24, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = torch.flatten(x, 1) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x \ No newline at end of file diff --git a/standardized_archive/MNIST/mnist_trace_project/utils/hooks.py b/standardized_archive/MNIST/mnist_trace_project/utils/hooks.py new file mode 100644 index 00000000..b9f51298 --- /dev/null +++ b/standardized_archive/MNIST/mnist_trace_project/utils/hooks.py @@ -0,0 +1,14 @@ +def register_hooks(model): + activations = {} + + def save_activation(name): + def hook(model, input, output): + activations[name] = output.detach().cpu().numpy() + return hook + + model.conv1.register_forward_hook(save_activation("conv1")) + model.conv2.register_forward_hook(save_activation("conv2")) + model.fc1.register_forward_hook(save_activation("fc1")) + model.fc2.register_forward_hook(save_activation("fc2")) + + return activations \ No newline at end of file diff --git a/standardized_archive/MNIST/tests/test1.py b/standardized_archive/MNIST/tests/test1.py new file mode 100644 index 00000000..d414a832 --- /dev/null +++ b/standardized_archive/MNIST/tests/test1.py @@ -0,0 +1,105 @@ +import zarr +import matplotlib.pyplot as plt + +# root = zarr.open("mnist_trace.zarr", mode="r") + +# image = root["inputs/images"][0][0] # shape: (28, 28) + +# plt.imshow(image, cmap="gray") +# plt.title("MNIST Image") +# plt.axis("off") +# plt.show() +# first shows the black and white MNIST image + +########################################### + +# activation = root["activations/conv1"][0] # shape (16, 26, 26) + +# Pick one filter (e.g., filter 0) +# feature_map = activation[0] + +# plt.imshow(feature_map, cmap="viridis") +# plt.title("Conv1 Filter 0 Activation") +# plt.colorbar() +# plt.axis("off") +# plt.show() +#second shows the conv1 activation heatmap from filter 0 + +########################################### + +import numpy as np + +# activation = root["activations/conv1"][0] +# num_filters = activation.shape[0] + +# fig, axes = plt.subplots(4, 4, figsize=(8, 8)) + +# for i, ax in enumerate(axes.flat): +# ax.imshow(activation[i], cmap="viridis") +# ax.axis("off") +# ax.set_title(f"F{i}") + +# plt.tight_layout() +# plt.show() +#third shows the 16 heatmaps, one for each filter + +########################################### + +# plt.imshow(image, cmap="gray") +# plt.imshow(feature_map, cmap="jet", alpha=0.5) +# plt.axis("off") +# plt.show() +#fourth shows the original digit and where the filter activtes on top of it + +########################################### + +# Normalize activation for better visualization +# feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min()) + +# fig, ax = plt.subplots(figsize=(5,5)) + +# # Show original image +# ax.imshow(image, cmap="gray") + +# # Overlay activation +# overlay = ax.imshow(feature_map, cmap="jet", alpha=0.5) + +# ax.set_title("Original Digit + Conv1 Filter 0 Activation") +# ax.axis("off") + +# # Add colorbar as key +# cbar = plt.colorbar(overlay, ax=ax) +# cbar.set_label("Activation Intensity") + +# plt.show() + +########################################### + +root = zarr.open("mnist_trace.zarr", mode="r") + +image = root["inputs/images"][0][0] + +activation = root["activations/conv1"][0] + +global_min = activation.min() +global_max = activation.max() + +fig, axes = plt.subplots(4, 4, figsize=(8,8)) + +for i, ax in enumerate(axes.flat): + fmap = activation[i] + ax.imshow(image, cmap="gray") + + im = ax.imshow( + fmap, + cmap="jet", + alpha=0.5, + vmin=global_min, + vmax=global_max + ) + ax.set_title(f"Filter {i}") + ax.axis("off") + +fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, label="Activation Intensity") + +plt.show() diff --git a/standardized_archive/ViT/.gitignore b/standardized_archive/ViT/.gitignore new file mode 100644 index 00000000..750bf36a --- /dev/null +++ b/standardized_archive/ViT/.gitignore @@ -0,0 +1,16 @@ +# Python +__pycache__/ +*.pyc + +# Virtual environments +venv/ +.env/ + +# Zarr archives +*.zarr/ + +# OS files +.DS_Store + +# Jupyter +.ipynb_checkpoints/ \ No newline at end of file diff --git a/standardized_archive/ViT/archive_schema.py b/standardized_archive/ViT/archive_schema.py new file mode 100644 index 00000000..74bd650b --- /dev/null +++ b/standardized_archive/ViT/archive_schema.py @@ -0,0 +1,58 @@ +import zarr +import numpy as np + +def create_vit_archive(archive_path, model_config, num_layers=12, num_heads=12, hidden_dim=768, num_tokens=197): + """ + Create a standardized ViT archive with formal metadata, layer shapes, and head counts. + + Parameters: + archive_path (str): Path to the Zarr archive. + model_config (transformers.PretrainedConfig): ViT model config. + num_layers (int): Number of transformer layers. + num_heads (int): Number of attention heads. + hidden_dim (int): Hidden dimension size. + num_tokens (int): Number of tokens (patches + CLS). + + Returns: + zarr.hierarchy.Group: Root Zarr archive. + """ + + # Root archive + archive = zarr.open(archive_path, mode="w") + + # Metadata group + meta = archive.create_group("metadata") + meta.attrs["model_name"] = getattr(model_config, "model_type", "vit-base") + meta.attrs["num_layers"] = num_layers + meta.attrs["num_heads"] = num_heads + meta.attrs["hidden_dim"] = hidden_dim + meta.attrs["num_tokens"] = num_tokens + meta.attrs["input_shape"] = (3, 224, 224) + + # Placeholder for layer-wise metadata + layers_meta = meta.require_group("layers") + for i in range(num_layers): + layer_group = layers_meta.require_group(f"layer_{i}") + layer_group.attrs["hidden_shape"] = (num_tokens, hidden_dim) + layer_group.attrs["num_heads"] = num_heads + layer_group.attrs["attention_shape"] = (num_heads, num_tokens, num_tokens) + + # Inputs + archive.create_group("inputs") + # Processed images will be stored later as datasets + # shape example: (batch, 3, 224, 224) + + # Activations + archive.create_group("activations") + # Each layer will have a "hidden_states" dataset + # shape example: (batch, num_tokens, hidden_dim) + + # Attention + archive.create_group("attention") + # Each layer will have a dataset of shape (batch, num_heads, tokens, tokens) + + # Outputs + archive.create_group("outputs") + # Logits and predicted class will be stored here + + return archive \ No newline at end of file diff --git a/standardized_archive/ViT/run_trace.py b/standardized_archive/ViT/run_trace.py new file mode 100644 index 00000000..e607360c --- /dev/null +++ b/standardized_archive/ViT/run_trace.py @@ -0,0 +1,21 @@ +from utils import load_image +from trace_vit import trace_vit + + +IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg" + + +def main(): + + image = load_image(IMAGE_URL) + + archive = trace_vit( + image=image, + archive_path="vit_trace.zarr" + ) + + print("Trace archive created at vit_trace.zarr") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/standardized_archive/ViT/trace_vit.py b/standardized_archive/ViT/trace_vit.py new file mode 100644 index 00000000..9766b51e --- /dev/null +++ b/standardized_archive/ViT/trace_vit.py @@ -0,0 +1,64 @@ +import torch +import numpy as np +from transformers import ViTImageProcessor, ViTForImageClassification +from archive_schema import create_vit_archive + +def trace_vit(image, archive_path): + """ + Trace a ViT image through the model and store inputs, activations, attention, + outputs, and metadata into a standardized Zarr archive. + """ + + # Load model & processor + processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") + model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="eager") + model.eval() + + # Preprocess image + inputs = processor(images=image, return_tensors="pt") + + # Create archive with metadata + archive = create_vit_archive(archive_path, model.config) + + # Store inputs + archive["inputs"].create_dataset( + "processed_image", + data=inputs["pixel_values"].numpy(), + chunks=True + ) + + # Forward pass with gradients off + with torch.no_grad(): + outputs = model( + **inputs, + output_attentions=True, + output_hidden_states=True + ) + + # Store activations + for i, hidden in enumerate(outputs.hidden_states): + layer_group = archive["activations"].require_group(f"layer_{i}") + layer_group.create_dataset( + "hidden_states", + data=hidden.detach().cpu().numpy(), + chunks=True + ) + + # Store CLS token output from the last layer + cls_output = outputs.hidden_states[-1][:, 0, :].detach().cpu().numpy() + archive["activations"].create_dataset("cls_output", data=cls_output, chunks=True) + + # Store attention + for i, attn in enumerate(outputs.attentions): + archive["attention"].create_dataset( + f"layer_{i}", + data=attn.detach().cpu().numpy(), + chunks=True + ) + + # Store outputs + logits = outputs.logits.detach().cpu().numpy() + archive["outputs"].create_dataset("logits", data=logits) + archive["outputs"].attrs["predicted_class"] = int(logits.argmax(-1)) + + return archive \ No newline at end of file diff --git a/standardized_archive/ViT/utils.py b/standardized_archive/ViT/utils.py new file mode 100644 index 00000000..97dcc347 --- /dev/null +++ b/standardized_archive/ViT/utils.py @@ -0,0 +1,7 @@ +from PIL import Image +import requests + + +def load_image(url): + image = Image.open(requests.get(url, stream=True).raw) + return image \ No newline at end of file diff --git a/standardized_archive/ViT/visualize_attention.py b/standardized_archive/ViT/visualize_attention.py new file mode 100644 index 00000000..36c8586e --- /dev/null +++ b/standardized_archive/ViT/visualize_attention.py @@ -0,0 +1,85 @@ +import zarr +import matplotlib.pyplot as plt +import numpy as np + +# Load archive + +z = zarr.open("vit_trace.zarr", mode="r") +print(z.tree()) + +# Load image + +img = z["inputs"]["processed_image"][:][0] # (3,224,224) + +img = np.transpose(img, (1,2,0)) + +img = (img * 0.5) + 0.5 +img = np.clip(img, 0, 1) + +plt.imshow(img) +plt.title("Input Image") +plt.axis("off") +plt.show() + +# Show ViT patches + +patch_size = 16 + +fig, axes = plt.subplots(14,14, figsize=(8,8)) + +for i in range(14): + for j in range(14): + + patch = img[ + i*patch_size:(i+1)*patch_size, + j*patch_size:(j+1)*patch_size + ] + + axes[i,j].imshow(patch) + axes[i,j].axis("off") + +plt.suptitle("ViT Image Patches") +plt.show() + +# Load attention + +print(list(z["attention"].keys())) + +layers = list(z["attention"].keys()) + +fig, axes = plt.subplots(3, 4, figsize=(12,9)) + +for i, layer in enumerate(layers): + + row = i // 4 + col = i % 4 + + attn = z["attention"][layer][:][0] # remove batch + attn = attn.mean(axis=0) # average heads + + cls_attn = attn[0,1:].reshape(14,14) + cls_attn = cls_attn / cls_attn.max() + + heatmap = np.kron(cls_attn, np.ones((16,16))) + + axes[row,col].imshow(img) + axes[row,col].imshow(heatmap, cmap="jet", alpha=0.4) + axes[row,col].set_title(layer) + axes[row,col].axis("off") + +plt.suptitle("ViT Attention Across Layers", fontsize=16) +plt.tight_layout() +plt.show() + +# Predicted class + +pred = z["outputs"].attrs["predicted_class"] +print("Predicted class index:", pred) + +from transformers import ViTForImageClassification + +model = ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224" +) + +print("Predicted label:", model.config.id2label[pred]) \ No newline at end of file