diff --git a/archive/1UBQ_result_model_1_ptm.pickle b/archive/1UBQ_result_model_1_ptm.pickle new file mode 100644 index 00000000..109b0f0a Binary files /dev/null and b/archive/1UBQ_result_model_1_ptm.pickle differ diff --git a/archive/ARCHITECTURE_DIAGRAM.xml b/archive/ARCHITECTURE_DIAGRAM.xml new file mode 100644 index 00000000..e76b1590 --- /dev/null +++ b/archive/ARCHITECTURE_DIAGRAM.xml @@ -0,0 +1,106 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/archive/NEW_SPEC.md b/archive/NEW_SPEC.md new file mode 100644 index 00000000..13424041 --- /dev/null +++ b/archive/NEW_SPEC.md @@ -0,0 +1,1437 @@ +# VizFold Archive Utilities - Complete API Specification + +**Version:** 2.0 (Post-Refactor) +**Date:** April 26, 2026 +**Status:** Production Ready +**Archive Format:** VizFold 1.0 (Zarr-based hierarchical storage) + +--- + +## Table of Contents + +1. [Overview](#overview) +2. [Architecture](#architecture) +3. [Complete Method Reference](#complete-method-reference) +4. [User Workflows](#user-workflows) +5. [Archive Structure](#archive-structure) +6. [Validation & Safety](#validation--safety) +7. [Pickle Ingestion & Traceability](#pickle-ingestion--traceability) +8. [API Patterns & Guarantees](#api-patterns--guarantees) + +--- + +## Overview + +The **VizFold Archive Utilities** provide a complete API for building, reading, and managing Zarr-based inference trace archives from protein structure prediction models (OpenFold, VizFold, etc.). 
+ +### Key Features + +- **Incremental Writing**: Build archives layer-by-layer, optionally, safely +- **Read-Optimized Loading**: Selective head loading, metadata retrieval without large tensors +- **Multiple Input Formats**: PyTorch tensors, NumPy arrays, text files, pickle outputs +- **Strict & Lenient Validation**: Support both complete archive validation and incremental workflow validation +- **Traceability**: Key matching records show exactly which pickle keys were matched during ingestion +- **Consistent Overwrite Safety**: All write methods use `overwrite=False` default for protection + +### Use Cases + +1. **Incremental Trace Building**: Write metadata → per-layer representations → attention → structure +2. **Pickle Ingestion**: Auto-extract and route arrays from OpenFold output `.pkl` files +3. **Visualization Server**: Load specific heads/layers on-demand without full archive in memory +4. **Archive Validation**: Check integrity at any point in workflow (strict for complete, lenient for partial) +5. **Model Analysis**: Extract and organize inference traces for debugging, interpretation, ablation + +--- + +## Architecture + +### Three-Module Design + +The codebase is split into three focused modules for separation of concerns: + +#### **core.py** — Shared Utilities & Validation (156 lines) + +Contains low-level tensor conversion, Zarr writing, layer indexing, and archive validation. + +- `tensor_to_numpy(tensor)` — Convert PyTorch/NumPy to NumPy +- `tensor_to_zarr_array(path, tensor, chunks, overwrite)` — Write arrays to Zarr with safety +- `_validate_layer_index(layer_index)` — Validate layer indices +- `validate_archive(path, strict=True)` — Comprehensive archive integrity check + +**Responsibility**: Foundational I/O and validation shared by both store.py and load.py + +--- + +#### **store.py** — Writing Methods (391 lines) + +Contains all methods for writing data into archives. 
+ +- `store_single_representation(path, layer_index, single_array, overwrite=False)` — Per-layer residue embeddings +- `store_pair_representation(path, layer_index, pair_array, overwrite=False)` — Per-layer token-token relationships +- `store_attention(path, attention_type, layer_index, attention_array, overwrite=False)` — Attention heads by type +- `store_structure_coordinates(path, atom_positions, atom_mask=None, ptm=None, overwrite=False)` — 3D structure +- `store_metadata(path, model_version, config_version, sequence, num_residues, num_recycles, ...)` — Run context + +**Responsibility**: All data ingestion and storage operations with validation and overwrite control + +--- + +#### **load.py** — Reading & Orchestration (768 lines) + +Contains all methods for reading data, parsing external formats, and orchestrating writes. + +- `load_attention_head(path, attention_type, layer_index, head_index)` — Single head loading +- `ingest_attention_txt(archive_path, txt_file, layer_index, num_tokens, ...)` — Parse text attention format +- `_extract_best_matching_array(container, key_token_patterns)` — Tokenized key matching helper +- `ingest_output_pkl(archive_path, pkl_file, overwrite=False)` — Extract & route arrays from pickle output +- `_load_dataset_as_python_value(dataset)` — Normalize 0-D Zarr values to Python scalars +- `load_metadata(path)` — Read run context as dictionary +- `load_single_representation(path, layer_index)` — Per-layer residue embeddings +- `load_pair_representation(path, layer_index)` — Per-layer token-token relationships +- `ArchiveOrchestrator` — Thin coordination class for sequencing writes with event logging + +**Responsibility**: All data retrieval, external format parsing, and write sequencing + +--- + +### Import Graph + +``` +store.py + ├── from core import: _validate_layer_index, tensor_to_numpy, tensor_to_zarr_array + +load.py + ├── from core import: _validate_layer_index, tensor_to_numpy, validate_archive + └── from store import: 
store_attention, store_metadata, store_pair_representation, + store_single_representation, store_structure_coordinates +``` + +--- + +## Complete Method Reference + +### **core.py Methods** + +#### METHOD 1: `tensor_to_numpy(tensor)` + +Convert any tensor to NumPy for standardized storage. + +```python +def tensor_to_numpy(tensor: torch.Tensor | np.ndarray) -> np.ndarray +``` + +**Parameters:** +- `tensor` (torch.Tensor | numpy.ndarray): Input from any framework + +**Returns:** +- `numpy.ndarray`: CPU-resident NumPy array + +**Behavior:** +- PyTorch tensor → detach, move to CPU, convert to NumPy +- NumPy array → return unchanged +- Other → `np.asarray()` fallback + +**Errors:** None (always succeeds) + +**Example:** +```python +import torch +import numpy as np +from core import tensor_to_numpy + +pt_tensor = torch.randn(10, 512) +np_array = tensor_to_numpy(pt_tensor) # → ndarray (10, 512) + +np_input = np.ones((5, 3)) +result = tensor_to_numpy(np_input) # → same ndarray +``` + +--- + +#### METHOD 2: `tensor_to_zarr_array(path, tensor, chunks=None, overwrite=False)` + +Write tensors directly to Zarr with nested path support and overwrite protection. + +```python +def tensor_to_zarr_array( + path: str, + tensor: torch.Tensor | np.ndarray, + chunks: tuple | None = None, + overwrite: bool = False +) -> zarr.Array +``` + +**Parameters:** +- `path` (str): Zarr location. 
Supports two formats: + - `"archive.zarr::group/dataset"` for nested paths (archive mode) + - `"file.zarr"` for direct Zarr array (file mode) +- `tensor`: Data to store +- `chunks` (tuple, optional): Zarr chunk dimensions (e.g., `(1, tokens, tokens)`) +- `overwrite` (bool): Whether to replace existing data + +**Returns:** +- `zarr.Array`: Zarr array reference + +**Errors:** +- `ValueError`: Empty dataset path in archive mode +- `FileExistsError`: Data exists and `overwrite=False` + +**Example:** +```python +from core import tensor_to_zarr_array +import numpy as np + +# Archive mode: nested path +attention = np.random.randn(8, 128, 128) +tensor_to_zarr_array( + "trace.zarr::attention/triangle_start/layer_00", + attention, + chunks=(1, 128, 128), + overwrite=False +) + +# File mode: direct Zarr array +representation = np.random.randn(128, 768) +tensor_to_zarr_array("reps.zarr", representation) +``` + +--- + +#### METHOD 3: `_validate_layer_index(layer_index)` + +Validate transformer layer indexing (internal). + +```python +def _validate_layer_index(layer_index: int) -> None +``` + +**Parameters:** +- `layer_index` (int): Layer index to validate + +**Errors:** +- `TypeError`: Not an integer +- `ValueError`: Negative index + +--- + +#### METHOD 4: `validate_archive(path, strict=True)` + +Comprehensive integrity check with strict/lenient modes. 
+ +```python +def validate_archive(path: str, strict: bool = True) -> dict +``` + +**Parameters:** +- `path` (str): Archive root directory +- `strict` (bool): + - `True` (default): Raises exceptions, requires complete archive + - `False`: Returns warnings, allows partial/incremental archives + +**Returns:** +```python +{ + "valid": bool, # Overall status + "strict_mode": bool, # Mode used + "path": str, # Checked path + "errors": list[str], # Critical issues + "warnings": list[str], # Soft issues (lenient mode) + "components_found": { + "metadata": bool, + "representations/single": bool, + "representations/pair": bool, + "attention": bool, + "structure/atom_positions": bool + } +} +``` + +**Strict Mode (strict=True):** +- Requires: structure/atom_positions, metadata group, layers group (non-empty), representations/pair +- Raises `ValueError` on any missing required component +- Validates shapes: (N, 3) for positions, (N, N, D) for pair + +**Lenient Mode (strict=False):** +- Requires: structure/atom_positions only (basic structure) +- Missing optional components → warnings only +- Never raises exceptions +- Ideal for incremental workflows + +**Example:** +```python +from core import validate_archive + +# Complete archive validation +report = validate_archive("trace.zarr", strict=True) +assert report["valid"] # Raises ValueError if invalid + +# Incremental workflow validation +report = validate_archive("trace.zarr", strict=False) +if report["warnings"]: + print(f"Warnings: {report['warnings']}") +# Never raises even if missing optional components +``` + +--- + +### **store.py Methods** + +#### METHOD 5: `store_single_representation(path, layer_index, single_array, overwrite=False)` + +Store per-layer single (residue-level) representations. 
+ +```python +def store_single_representation( + path: str, + layer_index: int, + single_array: np.ndarray, + overwrite: bool = False +) -> None +``` + +**Parameters:** +- `path` (str): Archive root +- `layer_index` (int): Transformer layer (0-indexed) +- `single_array` (np.ndarray): Shape `(num_residues, hidden_dim)` +- `overwrite` (bool): Replace if exists + +**Archive Location:** +- `representations/single/layer_00` +- `representations/single/layer_01` +- etc. + +**Errors:** +- `ValueError`: Array is not 2D +- `FileExistsError`: Layer exists and `overwrite=False` + +**Example:** +```python +from store import store_single_representation + +# Layer 0 residue embeddings +single = np.random.randn(128, 512) # 128 residues, 512-dim embedding +store_single_representation("trace.zarr", 0, single) + +# Layer 5 residue embeddings +store_single_representation("trace.zarr", 5, single) +``` + +--- + +#### METHOD 6: `store_pair_representation(path, layer_index, pair_array, overwrite=False)` + +Store per-layer pair (residue-residue) representations. + +```python +def store_pair_representation( + path: str, + layer_index: int, + pair_array: np.ndarray, + overwrite: bool = False +) -> None +``` + +**Parameters:** +- `path` (str): Archive root +- `layer_index` (int): Transformer layer (0-indexed) +- `pair_array` (np.ndarray): Shape `(tokens, tokens, pair_dim)` — must be square in first 2 dims +- `overwrite` (bool): Replace if exists + +**Archive Location:** +- `representations/pair/layer_00` +- `representations/pair/layer_01` +- etc. 
+ +**Errors:** +- `ValueError`: Array is not 3D or not square +- `FileExistsError`: Layer exists and `overwrite=False` + +**Example:** +```python +from store import store_pair_representation + +# Layer 0 pair embeddings: 128×128 tokens × 128 pair dims +pair = np.random.randn(128, 128, 128) +store_pair_representation("trace.zarr", 0, pair) +``` + +--- + +#### METHOD 7: `store_attention(path, attention_type, layer_index, attention_array, overwrite=False)` + +Store attention head maps organized by type. + +```python +def store_attention( + path: str, + attention_type: str, + layer_index: int, + attention_array: np.ndarray, + overwrite: bool = False +) -> None +``` + +**Parameters:** +- `path` (str): Archive root +- `attention_type` (str): Type identifier (e.g., "triangle_start", "triangle_end", "pairwise") +- `layer_index` (int): Transformer layer (0-indexed) +- `attention_array` (np.ndarray): Shape `(num_heads, tokens, tokens)` — must be square +- `overwrite` (bool): Replace if exists + +**Archive Location:** +- `attention/{attention_type}/layer_00` +- `attention/{attention_type}/layer_01` +- etc. + +**Chunking:** +- Applied automatically: `(1, tokens, tokens)` +- Enables head-by-head loading without loading entire tensor + +**Errors:** +- `ValueError`: Array is not 3D, not square, or `attention_type` is empty +- `FileExistsError`: Layer exists and `overwrite=False` + +**Example:** +```python +from store import store_attention + +# 8 attention heads, 128×128 tokens +attn = np.random.randn(8, 128, 128) + +store_attention("trace.zarr", "triangle_start", 0, attn) +store_attention("trace.zarr", "triangle_end", 0, attn) +store_attention("trace.zarr", "pairwise", 0, attn) +``` + +--- + +#### METHOD 8: `store_structure_coordinates(path, atom_positions, atom_mask=None, ptm=None, overwrite=False)` + +Store predicted 3D structure with optional confidence fields. 
+ +```python +def store_structure_coordinates( + path: str, + atom_positions: np.ndarray, + atom_mask: np.ndarray | None = None, + ptm: float | np.ndarray | None = None, + overwrite: bool = False +) -> None +``` + +**Parameters:** +- `path` (str): Archive root +- `atom_positions` (np.ndarray): Shape `(num_residues, 3)` or `(num_residues, num_atoms, 3)` + - Last dimension must be 3 (x, y, z coordinates) +- `atom_mask` (np.ndarray, optional): Presence mask matching `atom_positions` without its trailing coordinate axis (e.g., `(num_residues, num_atoms)`) + - First dimension must equal `num_residues` +- `ptm` (float or scalar array, optional): Predicted TM-score confidence +- `overwrite` (bool): Replace if components exist + +**Archive Location:** +- `structure/atom_positions` (required) +- `structure/atom_mask` (optional) +- `structure/ptm` (optional) + +**Overwrite Behavior:** +- Each dataset (`atom_positions`, `atom_mask`, `ptm`) checked independently +- If any exists and `overwrite=False` → `FileExistsError` for that dataset + +**Errors:** +- `ValueError`: Wrong coordinate dimensions, atom_mask size mismatch, ptm not scalar +- `FileExistsError`: Component exists and `overwrite=False` + +**Example:** +```python +from store import store_structure_coordinates + +# CA atom coordinates only +positions = np.random.randn(128, 3) +store_structure_coordinates("trace.zarr", positions) + +# Full atom coordinates with mask and confidence (overwrite=True replaces the positions stored above) +positions = np.random.randn(128, 37, 3) # 128 residues, 37 atoms/residue +mask = np.ones((128, 37), dtype=bool) +ptm_score = 0.92 +store_structure_coordinates( + "trace.zarr", + positions, + atom_mask=mask, + ptm=ptm_score, overwrite=True +) +``` + +--- + +#### METHOD 9: `store_metadata(path, model_version, config_version, sequence, num_residues, num_recycles, recycle_info=None, residue_index=None, representation_names=None, overwrite=False)` + +Store run-level metadata for archive identification and reproducibility. 
+ +```python +def store_metadata( + path: str, + model_version: str, + config_version: str, + sequence: str, + num_residues: int, + num_recycles: int, + recycle_info: np.ndarray | None = None, + residue_index: np.ndarray | None = None, + representation_names: list | None = None, + overwrite: bool = False +) -> None +``` + +**Parameters:** +- `path` (str): Archive root +- `model_version` (str): Model identifier or release version (e.g., "openfold-v1.0") +- `config_version` (str): Configuration identifier (e.g., "config-r2") +- `sequence` (str): Input amino acid sequence +- `num_residues` (int): Sequence length (must be ≥ 0) +- `num_recycles` (int): Number of inference recycles (must be ≥ 0) +- `recycle_info` (np.ndarray, optional): Per-recycle metadata array +- `residue_index` (np.ndarray, optional): Residue indices, shape `(num_residues,)` +- `representation_names` (list, optional): Names of stored representation types +- `overwrite` (bool): Replace if exists + +**Archive Location:** +- `metadata/model_version` +- `metadata/config_version` +- `metadata/sequence` +- `metadata/num_residues` +- `metadata/num_recycles` +- `metadata/recycle_info` (optional) +- `metadata/residue_index` (optional) +- `metadata/representation_names` (optional) + +**Errors:** +- `ValueError`: Negative residue/recycle counts, residue_index length mismatch +- `FileExistsError`: Metadata fields exist and `overwrite=False` + +**Example:** +```python +from store import store_metadata + +store_metadata( + "trace.zarr", + model_version="openfold-v2.1", + config_version="config-r4", + sequence="MVLSEGEWQLVLHVWAKVEADVAGHGQDILIRLFKSHPETLEKFDRFKHLKTEAEMKASED", + num_residues=61, + num_recycles=4, + residue_index=np.arange(61), + representation_names=["single", "pair"] +) +``` + +--- + +### **load.py Methods** + +#### METHOD 10: `load_attention_head(path, attention_type, layer_index, head_index)` + +Selectively load a single attention head without loading entire tensor. 
+ +```python +def load_attention_head( + path: str, + attention_type: str, + layer_index: int, + head_index: int +) -> np.ndarray +``` + +**Parameters:** +- `path` (str): Archive root +- `attention_type` (str): Attention type (e.g., "triangle_start", "pairwise") +- `layer_index` (int): Transformer layer (0-indexed) +- `head_index` (int): Attention head index + +**Returns:** +- `numpy.ndarray`: Shape `(tokens, tokens)` — single 2D attention matrix + +**Errors:** +- `KeyError`: Attention type or layer not found +- `IndexError`: Head index out of range + +**Example:** +```python +from load import load_attention_head + +# Load head 2 from layer 0 pairwise attention +head = load_attention_head("trace.zarr", "pairwise", 0, 2) +print(head.shape) # (128, 128) +``` + +--- + +#### METHOD 11: `ingest_attention_txt(archive_path, txt_file, layer_index, num_tokens, attention_type="pairwise", overwrite=False)` + +Parse text-format attention and store in archive. + +```python +def ingest_attention_txt( + archive_path: str, + txt_file: str, + layer_index: int, + num_tokens: int, + attention_type: str = "pairwise", + overwrite: bool = False +) -> dict +``` + +**Parameters:** +- `archive_path` (str): Archive root +- `txt_file` (str): Path to text file +- `layer_index` (int): Transformer layer to store under +- `num_tokens` (int): Sequence length (must be > 0) +- `attention_type` (str): Attention type identifier +- `overwrite` (bool): Replace if exists + +**Expected Input Format:** +``` +Layer 0, Head 0 +0 1 0.95 +0 5 0.87 +1 0 0.92 +... +Layer 0, Head 1 +0 0 1.0 +1 1 0.99 +... 
+``` + +**Returns:** +```python +{ + "layer_index": int, + "attention_type": str, + "num_heads": int, + "num_tokens": int, + "source_file": str +} +``` + +**Errors:** +- `ValueError`: `num_tokens` ≤ 0, malformed file, layer mismatch, index out of bounds +- `FileExistsError`: Attention exists and `overwrite=False` + +**Example:** +```python +from load import ingest_attention_txt + +result = ingest_attention_txt( + "trace.zarr", + "attention_layer_0.txt", + layer_index=0, + num_tokens=128, + attention_type="pairwise" +) +print(f"Ingested {result['num_heads']} heads") # e.g., 8 +``` + +--- + +#### METHOD 12: `_extract_best_matching_array(container, key_token_patterns)` + +Find best array match using tokenized key patterns (internal helper). + +```python +def _extract_best_matching_array( + container: dict | list, + key_token_patterns: list[list[str]] +) -> dict | None +``` + +**Parameters:** +- `container` (dict or list): Nested object to search (e.g., pickle output dict) +- `key_token_patterns` (list[list[str]]): Token patterns to match in priority order + +**Returns:** +```python +{ + "array": np.ndarray, # Matched array + "matched_key": str, # Full key path + "pattern": list[str], # Pattern that matched + "shape": tuple # Array shape +} +``` +Returns `None` if no match found. + +**Matching Algorithm:** +1. Tokenize all key paths on non-alphanumeric boundaries (e.g., `"final_atom_positions"` → `["final", "atom", "positions"]`) +2. For each pattern in order, check if all pattern tokens appear in path tokens +3. Return first match (depth-first traversal) +4. 
Scoring prefers more specific patterns and shorter paths + +**Example:** +```python +from load import _extract_best_matching_array + +pkl_dict = { + "final_atom_positions": np.random.randn(128, 37, 3), + "pair_activations": np.random.randn(128, 128, 64), + "other": {...} +} + +# Match "final_atom_positions" +result = _extract_best_matching_array( + pkl_dict, + [["final", "atom", "positions"], ["final", "positions"]] +) +print(result["matched_key"]) # "final_atom_positions" +print(result["shape"]) # (128, 37, 3) + +# Match "pair_activations" +result = _extract_best_matching_array( + pkl_dict, + [["pair", "representation"], ["pair"]] +) +print(result["matched_key"]) # "pair_activations" +``` + +--- + +#### METHOD 13: `ingest_output_pkl(archive_path, pkl_file, overwrite=False)` + +Extract arrays from OpenFold/VizFold pickle output and route to archive. + +```python +def ingest_output_pkl( + archive_path: str, + pkl_file: str, + overwrite: bool = False +) -> dict +``` + +**Parameters:** +- `archive_path` (str): Archive root +- `pkl_file` (str): Path to `.pkl` file (only ingest files from trusted sources — unpickling can execute arbitrary code) +- `overwrite` (bool): Replace if components exist + +**Auto-Routing Behavior:** + +| Key Pattern | Destination | Extraction | Notes | +|-----------|-----------|-----------|-------| +| `["final", "atom", "positions"]` | `structure/atom_positions` | CA atom (index 1) or first atom (index 0) | Converts (N, 37, 3) → (N, 3) | +| `["final", "atom", "mask"]` | `structure/atom_mask` | CA atom mask (index 1) or first atom (index 0) | Optional | +| `["ptm"]` | `structure/ptm` | Scalar value | Optional confidence | +| `["pair", ...]` | `representations/pair/layer_00` | Full 3D array | Stored as layer 0 | + +**Returns:** +```python +{ + "source_file": str, + "stored": list[str], # Successfully stored paths + "skipped": list[str], # Skipped with reasons + "key_matches": { # Traceability for each key + "final_positions": { + "pattern": list, + "matched_key": str | None, + "shape": tuple | None + }, + "final_atom_mask": {...}, + "ptm": {...}, + 
"pair_representation": {...} + } +} +``` + +**Example Output:** +```python +{ + "source_file": "output.pkl", + "stored": [ + "structure/atom_positions", + "structure/atom_mask", + "structure/ptm", + "representations/pair/layer_00" + ], + "skipped": [], + "key_matches": { + "final_positions": { + "pattern": ["final", "atom", "positions"], + "matched_key": "final_atom_positions", + "shape": (128, 37, 3) + }, + "final_atom_mask": { + "pattern": ["final", "atom", "mask"], + "matched_key": "final_atom_mask", + "shape": (128, 37) + }, + "ptm": { + "pattern": ["ptm"], + "matched_key": "ptm", # exact token match + "shape": () + }, + "pair_representation": { + "pattern": ["pair"], + "matched_key": "pair_activations", + "shape": (128, 128, 64) + } + } +} +``` + +**Errors:** +- `ValueError`: Pickle is not a dict +- `FileExistsError`: Components exist and `overwrite=False` (caught and reported in summary) + +**Example:** +```python +from load import ingest_output_pkl + +result = ingest_output_pkl("trace.zarr", "openfold_output.pkl") + +print(f"Stored: {result['stored']}") +print(f"Skipped: {result['skipped']}") + +# Check which keys matched +print(f"Final positions matched: {result['key_matches']['final_positions']['matched_key']}") +``` + +--- + +#### METHOD 14: `_load_dataset_as_python_value(dataset)` + +Normalize 0-D Zarr datasets to Python scalars (internal helper). + +```python +def _load_dataset_as_python_value(dataset: zarr.Array) -> Any +``` + +**Behavior:** +- 0-D array (shape `()`) → `.item()` to get Python scalar +- n-D array → return as ndarray + +--- + +#### METHOD 15: `load_metadata(path)` + +Load run metadata as a dictionary. 
+ +```python +def load_metadata(path: str) -> dict +``` + +**Parameters:** +- `path` (str): Archive root + +**Returns:** +```python +{ + "model_version": str, + "config_version": str, + "sequence": str, + "num_residues": int, + "num_recycles": int, + "recycle_info": np.ndarray, # optional + "residue_index": np.ndarray, # optional + "representation_names": list # optional +} +``` + +**Errors:** +- `KeyError`: Metadata group missing or required fields incomplete + +**Example:** +```python +from load import load_metadata + +meta = load_metadata("trace.zarr") +print(f"Model: {meta['model_version']}") +print(f"Sequence: {meta['sequence']}") +print(f"Residues: {meta['num_residues']}") +``` + +--- + +#### METHOD 16: `load_single_representation(path, layer_index)` + +Load single representation for a specific layer. + +```python +def load_single_representation(path: str, layer_index: int) -> np.ndarray +``` + +**Parameters:** +- `path` (str): Archive root +- `layer_index` (int): Transformer layer (0-indexed) + +**Returns:** +- `numpy.ndarray`: Shape `(num_residues, hidden_dim)` + +**Errors:** +- `KeyError`: Layer not found + +**Example:** +```python +from load import load_single_representation + +single = load_single_representation("trace.zarr", 0) +print(single.shape) # (128, 512) +``` + +--- + +#### METHOD 17: `load_pair_representation(path, layer_index)` + +Load pair representation for a specific layer. 
+ +```python +def load_pair_representation(path: str, layer_index: int) -> np.ndarray +``` + +**Parameters:** +- `path` (str): Archive root +- `layer_index` (int): Transformer layer (0-indexed) + +**Returns:** +- `numpy.ndarray`: Shape `(tokens, tokens, pair_dim)` + +**Errors:** +- `KeyError`: Layer not found + +**Example:** +```python +from load import load_pair_representation + +pair = load_pair_representation("trace.zarr", 0) +print(pair.shape) # (128, 128, 128) +``` + +--- + +#### METHOD 18: `ArchiveOrchestrator` + +Thin helper class for sequencing writes and recording events. + +```python +class ArchiveOrchestrator: + def __init__(self, archive_path: str) + + def add_metadata(self, *args, **kwargs) -> dict + def add_single_layer(self, layer_index, single_array, overwrite=False) -> dict + def add_pair_layer(self, layer_index, pair_array, overwrite=False) -> dict + def add_attention(self, attention_type, layer_index, attention_array, overwrite=False) -> dict + def add_structure(self, atom_positions, atom_mask=None, ptm=None, overwrite=False) -> dict + def validate(self, validator=validate_archive, *args, **kwargs) -> dict + def summary(self) -> dict +``` + +**Purpose:** +- Coordinate a sequence of writes +- Record what was written (event log) +- Provide structured summary for debugging + +**Key Methods:** + +**`add_metadata(...)`** — Record metadata write +**`add_single_layer(layer_index, single_array, overwrite=False)`** — Record single representation write +**`add_pair_layer(layer_index, pair_array, overwrite=False)`** — Record pair representation write +**`add_attention(attention_type, layer_index, attention_array, overwrite=False)`** — Record attention write +**`add_structure(atom_positions, atom_mask=None, ptm=None, overwrite=False)`** — Record structure write +**`validate(validator=validate_archive, *args, **kwargs)`** — Run validation and record result +**`summary()`** — Return event log and archive path + +**Example:** +```python +from load import 
ArchiveOrchestrator + +orchestrator = ArchiveOrchestrator("trace.zarr") + +# Build archive incrementally with event logging +orchestrator.add_metadata( + model_version="openfold-v1.0", + config_version="config-r2", + sequence="MVLSEGEWQLVL...", + num_residues=65, + num_recycles=4 +) + +orchestrator.add_single_layer(0, single_layer_0) +orchestrator.add_pair_layer(0, pair_layer_0) +orchestrator.add_attention("pairwise", 0, attn_layer_0) +orchestrator.add_structure(atom_positions, atom_mask=mask, ptm=0.92) + +# Validate at end +report = orchestrator.validate(strict=False) + +# Get summary +summary = orchestrator.summary() +print(summary) +# Output: +# { +# "archive_path": "trace.zarr", +# "events": [ +# {"action": "store", "target": "metadata"}, +# {"action": "store", "target": "representations/single/layer_00"}, +# ... +# {"action": "validate", "target": "archive", "result": {...}} +# ] +# } +``` + +--- + +## User Workflows + +### Workflow 1: Complete Archive from OpenFold Output + +```python +from load import ingest_output_pkl, ArchiveOrchestrator +from store import store_metadata + +pkl_path = "openfold_output.pkl" +archive_path = "trace.zarr" + +# Step 1: Ingest pickle (auto-extracts structure, pair representation) +result = ingest_output_pkl(archive_path, pkl_path) +print(f"Stored: {result['stored']}") + +# Step 2: Add metadata +store_metadata( + archive_path, + model_version="openfold-v2.1", + config_version="config-r4", + sequence="MVLSEGEWQLVL...", + num_residues=65, + num_recycles=4 +) + +# Step 3: Add single/pair representations for each layer (Step 1 may already have written representations/pair/layer_00, so that layer is overwritten) +from store import store_single_representation, store_pair_representation +for layer_idx in range(48): + store_single_representation(archive_path, layer_idx, single_layer[layer_idx]) + store_pair_representation(archive_path, layer_idx, pair_layer[layer_idx], overwrite=(layer_idx == 0)) + +# Step 4: Validate +from core import validate_archive +report = validate_archive(archive_path, strict=True) +assert report["valid"] +``` + +--- + +### Workflow 
2: Incremental Building with Safety + +```python +from load import ArchiveOrchestrator +from core import validate_archive + +orchestrator = ArchiveOrchestrator("trace.zarr") + +# Build incrementally, safely +try: + orchestrator.add_metadata( + model_version="vizfold-v3", + config_version="config-exp", + sequence="...", + num_residues=256, + num_recycles=8 + ) + + # Add layer 0 + orchestrator.add_single_layer(0, single_0) + orchestrator.add_pair_layer(0, pair_0) + orchestrator.add_attention("triangle_start", 0, attn_start_0) + orchestrator.add_attention("triangle_end", 0, attn_end_0) + + # Validate at checkpoint (lenient) + report = orchestrator.validate(strict=False) + print(f"Checkpoint valid: {report['valid']}") + + # Add structure when ready + orchestrator.add_structure(atom_positions, ptm=ptm_score) + + # Final validation (strict) + final_report = orchestrator.validate(strict=True) + + print(orchestrator.summary()) + +except FileExistsError: + print("Data already exists. Set overwrite=True to replace.") +``` + +--- + +### Workflow 3: Visualization Server (Selective Loading) + +```python +from load import load_metadata, load_attention_head, load_pair_representation + +# Load metadata once at startup +meta = load_metadata("trace.zarr") +print(f"Archive: {meta['model_version']} | {meta['sequence']}") + +# Load only requested heads on demand +def get_attention_for_visualization(layer: int, head: int, attn_type: str): + return load_attention_head("trace.zarr", attn_type, layer, head) + +def get_pair_for_analysis(layer: int): + return load_pair_representation("trace.zarr", layer) + +# User requests layer 5, head 2, pairwise attention +head_data = get_attention_for_visualization(5, 2, "pairwise") +# Only ~16 KB loaded instead of entire 8 MB tensor +``` + +--- + +### Workflow 4: Archive Validation in CI/CD + +```python +from core import validate_archive +import sys + +# During testing/deployment, enforce strict validation +try: + report = validate_archive("trace.zarr", 
strict=True) + print(f"✓ Archive valid | Components found: {report['components_found']}") + sys.exit(0) +except ValueError as e: + print(f"✗ Archive invalid: {e}") + sys.exit(1) +``` + +--- + +## Archive Structure + +### VizFold 1.0 Format (Zarr-based) + +``` +trace.zarr/ +├── metadata/ # Run context +│ ├── model_version # str scalar +│ ├── config_version # str scalar +│ ├── sequence # str scalar +│ ├── num_residues # int32 scalar +│ ├── num_recycles # int32 scalar +│ ├── recycle_info (optional) # array +│ ├── residue_index (optional) # array (num_residues,) +│ └── representation_names (optional) # array +│ +├── representations/ +│ ├── single/ # Per-residue embeddings +│ │ ├── layer_00 # (num_residues, hidden_dim) +│ │ ├── layer_01 +│ │ └── ... +│ │ +│ └── pair/ # Token-token relationships +│ ├── layer_00 # (tokens, tokens, pair_dim) +│ ├── layer_01 +│ └── ... +│ +├── attention/ +│ ├── triangle_start/ +│ │ ├── layer_00 # (num_heads, tokens, tokens) +│ │ ├── layer_01 +│ │ └── ... +│ │ +│ ├── triangle_end/ +│ │ ├── layer_00 +│ │ └── ... +│ │ +│ └── pairwise/ +│ ├── layer_00 +│ └── ... 
+│ +└── structure/ # 3D coordinates + ├── atom_positions # (num_residues, 3) or (N, num_atoms, 3) + ├── atom_mask (optional) # Same shape as positions + └── ptm (optional) # Scalar float +``` + +### Layer Naming Convention + +All per-layer datasets use zero-padded naming: +- `layer_00`, `layer_01`, ..., `layer_09`, ..., `layer_47` + +This enables: +- Lexicographic sorting (for indices 0–99) +- Fixed-width parsing (for indices 0–99) +- Scalability beyond 99 layers: names are formatted with `:02d`, so index 100 becomes `layer_100`; such names remain valid but are wider and no longer sort lexicographically + +### Chunking Strategy + +| Component | Chunking | Rationale | +|-----------|----------|-----------| +| single representations | None (default) | Small per-layer, full load acceptable | +| pair representations | None (default) | Square matrix, full load typical | +| attention | `(1, tokens, tokens)` | Head-by-head loading in visualization | +| structure/atom_positions | None (default) | Typically small (128–1024 residues) | + +--- + +## Validation & Safety + +### Overwrite Protection + +All store methods default to `overwrite=False`: + +```python +# This raises FileExistsError if layer_00 exists +store_pair_representation(archive, 0, pair_array) + +# Explicit re-run protection +store_pair_representation(archive, 0, pair_array, overwrite=False) + +# Intentional replacement +store_pair_representation(archive, 0, updated_pair, overwrite=True) +``` + +### Strict vs. 
Lenient Validation + +| Mode | strict=True | strict=False | +|------|-----------|------------| +| Use Case | Production, complete archives | Development, incremental builds | +| Behavior | Raises exceptions | Returns report with warnings | +| Required Components | metadata, structure, representations/pair, ≥1 layer | structure/atom_positions only | +| Optional Components | Error if missing | Warnings if missing | +| Return Type | Raises or returns valid report | Always returns report (never raises) | + +**Example:** +```python +# Production validation +try: + report = validate_archive("trace.zarr", strict=True) + # If this returns, archive is guaranteed complete +except ValueError as e: + print(f"Archive incomplete: {e}") + +# Development validation +report = validate_archive("trace.zarr", strict=False) +if report["warnings"]: + print(f"Warnings: {report['warnings']}") +# Never raises, safe for incremental workflows +``` + +--- + +## Pickle Ingestion & Traceability + +### Key Matching Algorithm + +The `_extract_best_matching_array()` function uses **tokenized matching** instead of substring search for robustness. + +**Tokenization:** +``` +"final_atom_positions" → ["final", "atom", "positions"] +"final_positions" → ["final", "positions"] +"finalAtomPositions" → ["final", "atom", "positions"] +``` + +**Matching:** +1. For pattern `["final", "atom", "positions"]`, check if ALL tokens appear in path tokens +2. Return first match (depth-first) +3. 
Score based on pattern specificity and path length + +**Benefits:** +- Avoids false positives (e.g., "atom" wouldn't match "atom_positions" incorrectly) +- Handles naming variations (snake_case, camelCase, PascalCase) +- Provides traceability: shows exactly what matched + +### Traceability Example + +```python +result = ingest_output_pkl("trace.zarr", "output.pkl") + +# View exact matches +print(result["key_matches"]) +# Output: +# { +# "final_positions": { +# "pattern": ["final", "atom", "positions"], +# "matched_key": "final_atom_positions", +# "shape": (128, 37, 3) +# }, +# "pair_representation": { +# "pattern": ["pair"], +# "matched_key": "pair_representation", +# "shape": (128, 128, 128) +# } +# } +``` + +This allows users to: +- Verify correct keys were matched +- Debug ingestion failures +- Audit data flow for compliance/reproducibility + +--- + +## API Patterns & Guarantees + +### Pattern 1: Consistent Overwrite Default + +**All store methods** use `overwrite=False` by default: + +```python +store_single_representation(path, layer, array) # ✓ raises if exists +store_pair_representation(path, layer, array) # ✓ raises if exists +store_attention(path, type, layer, array) # ✓ raises if exists +store_structure_coordinates(path, coords) # ✓ raises if exists +store_metadata(path, ...) # ✓ raises if exists +ingest_attention_txt(path, file, layer, tokens) # ✓ raises if exists +ingest_output_pkl(path, pkl) # ✓ raises if exists +``` + +**Rationale**: Protect against accidental overwrites in production workflows. 
+ +--- + +### Pattern 2: Layer Indexing (0-indexed, zero-padded naming) + +All layer storage uses zero-padded layer names: + +```python +store_single_representation(path, 0, array) # → representations/single/layer_00 +store_single_representation(path, 5, array) # → representations/single/layer_05 +store_single_representation(path, 47, array) # → representations/single/layer_47 +``` + +--- + +### Pattern 3: Nested Path Parsing with `::` + +`tensor_to_zarr_array()` supports both direct and nested paths: + +```python +# Direct Zarr array file +tensor_to_zarr_array("file.zarr", array) + +# Nested within archive using :: separator +tensor_to_zarr_array("archive.zarr::group/subgroup/dataset", array) +``` + +--- + +### Pattern 4: Error Propagation + +| Error Type | Raised By | Typical Cause | Handling | +|-----------|-----------|---------------|----------| +| `ValueError` | All validate functions | Invalid input (shape, type, range) | Validate inputs before calling | +| `FileExistsError` | All store functions | Data exists and overwrite=False | Set overwrite=True to replace, or skip the write (lenient mode applies only to validation) | +| `KeyError` | All load functions | Component not found | Check archive structure first | +| `IndexError` | `load_attention_head()` | Head index out of range | Check num_heads (first dimension of the stored attention array) | + +--- + +### Pattern 5: Tensor Flexibility + +All methods accept mixed tensor types: + +```python +# PyTorch tensor +store_single_representation(path, 0, torch_tensor) + +# NumPy array +store_single_representation(path, 0, numpy_array) + +# Python list +store_single_representation(path, 0, [[1, 2], [3, 4]]) + +# All converted to NumPy internally via tensor_to_numpy() +``` + +--- + +## Testing Checklist + +### Overwrite Safety +- [ ] New data: `overwrite=False` → stores successfully +- [ ] Existing data: `overwrite=False` → raises `FileExistsError` +- [ ] Existing data: `overwrite=True` → replaces successfully +- [ ] Consistency across all 7 store functions + +### Strict Validation +- [ ] Complete 
archive → `valid=True` +- [ ] Missing metadata → raises `ValueError` +- [ ] Missing layers → raises `ValueError` +- [ ] Missing pair representation → raises `ValueError` +- [ ] Invalid coordinate shape → raises `ValueError` + +### Lenient Validation +- [ ] Complete archive → `valid=True` +- [ ] Missing optional components → warnings only +- [ ] Never raises exceptions +- [ ] Returns detailed report + +### Pickle Ingestion +- [ ] `ingest_output_pkl()` extracts and routes all components +- [ ] `key_matches` traceability shows matched keys and shapes +- [ ] Correctly extracts CA atom from (N, 37, 3) positions +- [ ] Handles missing optional components gracefully + +### Loaders +- [ ] `load_metadata()` returns complete dict +- [ ] `load_single_representation()` returns correct shape +- [ ] `load_pair_representation()` returns correct shape +- [ ] `load_attention_head()` returns 2D matrix without loading full tensor + +### Orchestrator +- [ ] `ArchiveOrchestrator.summary()` records all events +- [ ] `ArchiveOrchestrator.validate()` integration works +- [ ] Event log shows correct sequence of writes + +--- + +## Implementation Notes + +### Why Three Modules? + +1. **core.py (shared)**: All store and load operations depend on tensor conversion and validation +2. **store.py (write-only)**: Focused on input validation and archive construction +3. **load.py (read + orchestration)**: Includes loading, parsing external formats, and write coordination + +This split enables: +- Independent testing of each concern +- Clear import dependencies (no circular imports) +- Easier team onboarding (each module has one responsibility) +- Future extensibility (new stores don't touch loads, etc.) 
+ +--- + +## Performance Characteristics + +| Operation | Complexity | Memory | Notes | +|-----------|-----------|--------|-------| +| `tensor_to_zarr_array()` | O(n) write | O(n) | Streaming possible with zarr chunks | +| `load_attention_head()` | O(1) seek + O(tokens²) read | O(tokens²) | Head chunking enables selective loading | +| `ingest_output_pkl()` | O(n) parse | O(n) | Single pass through pickle dict | +| `validate_archive()` | O(m) where m = num datasets | O(1) | Metadata-only checks by default | +| `_extract_best_matching_array()` | O(k) depth-first | O(path depth) | Early termination on match | + +--- + +## Changelog from Previous Spec + +### Version 1.0 → 2.0 + +**New in v2.0:** +1. Refactored into three modules (core.py, store.py, load.py) +2. Added `ArchiveOrchestrator` class for write sequencing +3. Enhanced pickle ingestion with `_extract_best_matching_array()` and `key_matches` traceability +4. Standardized `overwrite=False` default across all store methods +5. Added strict/lenient validation modes in `validate_archive()` +6. Implemented selective head loading via chunking +7. Added `store_metadata()` as first-class method +8. Added `load_metadata()` paired loader +9. Unified `tensor_to_numpy()` for all input types + +**Breaking Changes:** +- None (all methods maintain backward compatibility) + +**Deprecated:** +- None + +--- + +## Contact & Support + +For questions or issues with the VizFold Archive Utilities API: +1. Check the user workflows above +2. Review the method reference for your specific use case +3. Run `validate_archive(..., strict=False)` to check archive integrity +4. 
Inspect `ingest_output_pkl()` results for `key_matches` traceability diff --git a/archive/REFACTOR_AND_RUBRIC.md b/archive/REFACTOR_AND_RUBRIC.md new file mode 100644 index 00000000..258d442f --- /dev/null +++ b/archive/REFACTOR_AND_RUBRIC.md @@ -0,0 +1,550 @@ +# Archive Utilities Refactor & PR Rubric Guide + +## Overview + +This document outlines the necessary refactors to align our archive utilities with the VizFold 1.0 specification (Group #39) and the requirements for meeting the 45% PR rubric. + +### Why Refactor? + +The VizFold Inference Trace Archive specification (v1.0, March 2026) defines a formal standard for Zarr-based storage of inference traces. Our current implementation uses a different path structure and organization. Aligning with this spec ensures: + +- ✅ Interoperability with other VizFold tools +- ✅ Compatibility with standard visualization/analysis pipelines +- ✅ Future-proofing for ecosystem adoption +- ✅ Clear, documented standards compliance + +--- + +## Part 1: Required Method Refactors + +### Current vs. VizFold 1.0 Specification + +#### Archive Structure Comparison + +**Current Structure:** +``` +archive.zarr/ +├── structure/ +│ ├── coordinates +│ └── residue_types +├── representations/ +│ └── pair +└── layers/ + ├── 0/ + │ ├── activation + │ └── attention + └── 1/ + ├── activation + └── attention +``` + +**VizFold 1.0 Structure:** +``` +archive.zarr/ +├── metadata/ +│ ├── model_version +│ ├── config_version +│ ├── sequence +│ ├── num_residues +│ ├── num_recycles +│ ├── recycle_info +│ ├── residue_index +│ └── representation_names +├── representations/ +│ ├── single/ +│ │ ├── layer_00 +│ │ ├── layer_01 +│ │ └── ... +│ └── pair/ +│ ├── layer_00 +│ ├── layer_01 +│ └── ... +├── attention/ +│ └── triangle_start/ +│ ├── layer_00 +│ ├── layer_01 +│ └── ... +└── structure/ + ├── atom_positions + ├── atom_mask + └── ptm +``` + +--- + +## Method Refactors Required + +### 1. 
✏️ Rename & Update `store_layer_activation()` + +**Current Signature:** +```python +def store_layer_activation(path, layer_index, activation_array, overwrite=False) +``` + +**New Signature:** +```python +def store_single_representation(path, layer_index, representation_array, overwrite=False) +``` + +**Changes:** +- **Function name:** `store_layer_activation` → `store_single_representation` +- **Path:** `layers/{layer_index}/activation` → `representations/single/layer_{layer_index:02d}` +- **Parameter name:** `activation_array` → `representation_array` (more general) +- **Docstring update:** Explain this stores per-residue representations + +**Implementation Note:** +```python +# OLD: +layer_group = root.require_group(f"layers/{layer_index}") +layer_group["activation"] = zarr.array(array) + +# NEW: +repr_group = root.require_group("representations/single") +repr_group[f"layer_{layer_index:02d}"] = zarr.array(array) +``` + +**Why:** +- Aligns with spec's `representations/single/layer_XX` structure +- Per-layer naming (`layer_00`, `layer_01`) matches spec convention +- Clearer intent: "single" vs "pair" representations + +--- + +### 2. 
✏️ Refactor `store_attention_heads()` → `store_attention()` + +**Current Signature:** +```python +def store_attention_heads(path, layer_index, attention_array, overwrite=False) +``` + +**New Signature:** +```python +def store_attention(path, attention_type, layer_index, attention_array, overwrite=False) +``` + +**Changes:** +- **Function name:** `store_attention_heads` → `store_attention` +- **New parameter:** `attention_type` (e.g., "triangle_start", "triangle_end", "pairwise") +- **Path:** `layers/{layer_index}/attention` → `attention/{attention_type}/layer_{layer_index:02d}` +- **Docstring:** Explain different attention types and why they matter + +**Implementation Note:** +```python +# OLD: +layer_group = root.require_group(f"layers/{layer_index}") +layer_group["attention"] = zarr.array(array, chunks=chunks) + +# NEW: +attn_group = root.require_group(f"attention/{attention_type}") +attn_group[f"layer_{layer_index:02d}"] = zarr.array(array, chunks=chunks) +``` + +**Why:** +- VizFold spec recognizes multiple attention types (triangle_start, triangle_end, etc.) +- Our current function can't differentiate attention types +- Per-layer, per-type organization enables selective loading + +**Usage Examples:** +```python +store_attention(archive, "triangle_start", 0, attn_array) +store_attention(archive, "triangle_end", 0, attn_array) +store_attention(archive, "pairwise", 0, attn_array) +``` + +--- + +### 3. 
✏️ Update `store_pair_representation()` + +**Current Signature:** +```python +def store_pair_representation(path, pair_array, overwrite=False) +``` + +**New Signature:** +```python +def store_pair_representation(path, layer_index, pair_array, overwrite=False) +``` + +**Changes:** +- **New parameter:** `layer_index` (required, currently assumed implicit) +- **Path:** `representations/pair` → `representations/pair/layer_{layer_index:02d}` +- **Docstring:** Clarify this is per-layer pair representations + +**Implementation Note:** +```python +# OLD: +repr_group = root.require_group("representations") +repr_group["pair"] = zarr.array(array) + +# NEW: +repr_group = root.require_group("representations/pair") +repr_group[f"layer_{layer_index:02d}"] = zarr.array(array) +``` + +**Why:** +- VizFold spec stores pair representations per-layer, not as single array +- Enables incremental addition of layers +- Clearer semantics: which layer's pair representations? + +**Usage Examples:** +```python +store_pair_representation(archive, 0, pair_layer_0) +store_pair_representation(archive, 1, pair_layer_1) +# Can add layer 3 later without layer 2 +store_pair_representation(archive, 3, pair_layer_3) +``` + +--- + +### 4. 
✏️ Update `store_structure_coordinates()` + +**Changes to paths & field names:** +- **`coordinates` → `atom_positions`** (spec uses "atom_positions") +- **Add optional:** `atom_mask` parameter +- **Add optional:** `ptm` parameter (predicted TM-score/confidence) + +**Implementation Note:** +```python +# OLD: +structure_group["coordinates"] = zarr.array(coordinates) +structure_group["residue_types"] = zarr.array(residue_types) + +# NEW: +structure_group["atom_positions"] = zarr.array(coordinates) +# Spec also includes: +structure_group["atom_mask"] = zarr.array(atom_mask) # optional +structure_group["ptm"] = zarr.array(ptm) # scalar confidence +``` + +**Why:** +- "atom_positions" is more precise than "coordinates" +- Spec defines optional atom_mask and ptm fields +- Maintains consistency with VizFold terminology + +--- + +### 5. ✏️ Add `store_metadata()` - NEW FUNCTION + +**Signature:** +```python +def store_metadata(path, model_version, config_version, sequence, + num_residues, num_recycles, recycle_info=None, + residue_index=None, representation_names=None): +``` + +**Implementation:** +```python +root = zarr.open(path, mode='a') +metadata_group = root.require_group("metadata") + +# Store all metadata as scalar or 1D arrays +metadata_group["model_version"] = np.array(model_version, dtype=object) +metadata_group["config_version"] = np.array(config_version, dtype=object) +metadata_group["sequence"] = np.array(sequence, dtype=object) +metadata_group["num_residues"] = np.array(num_residues, dtype=np.int32) +metadata_group["num_recycles"] = np.array(num_recycles, dtype=np.int32) + +if recycle_info is not None: + metadata_group["recycle_info"] = zarr.array(recycle_info) +if residue_index is not None: + metadata_group["residue_index"] = zarr.array(residue_index) +if representation_names is not None: + metadata_group["representation_names"] = zarr.array(representation_names, dtype=object) +``` + +**Why:** +- Metadata is required by VizFold spec +- Enables reproducibility 
and traceability +- Necessary for archive validation and tooling + +--- + +### 6. ✏️ Update `validate_archive()` + +**Changes:** + +**Strict Mode (strict=True) must now check:** +- ✅ `metadata/` group exists with required fields +- ✅ `structure/` group with `atom_positions` (new name) +- ✅ `representations/single/` with at least one `layer_XX` +- ✅ `representations/pair/` with at least one `layer_XX` +- ✅ `attention/` group with at least one attention type +- ✅ Layer numbering is sequential (layer_00, layer_01, etc.) + +**Lenient Mode (strict=False):** +- ⚠ Warn if metadata missing +- ⚠ Warn if attention types missing +- ✅ Require only structure data (core) + +**Implementation Example:** +```python +# Check for new structure +if 'representations' in root: + repr_group = root['representations'] + if 'single' not in repr_group: + # Handle error/warning based on strict mode + if 'pair' in repr_group: + # Validate layer_00, layer_01 naming + for layer_key in repr_group['pair'].keys(): + if not layer_key.startswith('layer_'): + # Invalid naming +``` + +--- + +## Part 2: Demo Requirements + +### Demo Deliverables + +Create a reproducible demonstration showing: + +#### 1. **Overwrite Protection Works** +```python +# Scenario: Attempt to write to same layer twice +store_single_representation(archive, 0, representation) +store_single_representation(archive, 0, representation) # Should raise FileExistsError +# Then with overwrite=True, should succeed +store_single_representation(archive, 0, representation_updated, overwrite=True) +``` + +**Expected Output:** +``` +First write: Success ✓ +Second write (overwrite=False): FileExistsError ✓ +Third write (overwrite=True): Success ✓ +``` + +#### 2. 
**Validation Modes Work** +```python +# Incomplete archive +store_structure_coordinates(archive, coords) +validate_archive(archive, strict=False) # Should warn but not fail +# Output: warnings about missing representations + +validate_archive(archive, strict=True) # Should raise ValueError +# Output: Error about required components +``` + +**Expected Output:** +``` +strict=False: Valid=False, warnings=[...] ✓ +strict=True: Raises ValueError ✓ +``` + +#### 3. **Pickle Ingestion Traceability** +```python +summary = ingest_output_pkl(archive, "model_output.pkl") +print(summary['key_matches']) +# Shows what keys were searched for, what was matched +``` + +**Expected Output:** +``` +{ + 'final_positions': { + 'pattern': ['final', 'atom', 'positions'], + 'matched_key': 'output/final_atom_positions', + 'shape': (128, 37, 3) + }, + 'residue_types': { + 'pattern': ['aatype'], + 'matched_key': 'metadata/aatype', + 'shape': (128,) + }, + 'pair_representation': { + 'pattern': ['pair'], + 'matched_key': None, + 'shape': None + } +} +``` + +#### 4. 
**Per-Layer Organization Works** +```python +# Add multiple layers incrementally +for i in [0, 1, 5]: # Note: can skip layer 2, 3, 4 + store_single_representation(archive, i, layer_reps[i]) + +# Validate archive still works +result = validate_archive(archive, strict=False) +# Should show layers 0, 1, 5 present, others missing (but that's OK in lenient mode) +``` + +--- + +## Part 3: PR Rubric Checklist + +### Group Component (20%) + +#### Functional Correctness & Completeness (8%) +- [ ] All refactored methods work end-to-end +- [ ] Overwrite parameter functions correctly in all scenarios +- [ ] Validation correctly identifies complete/incomplete archives +- [ ] Pickle ingestion handles edge cases (missing components, wrong shapes) +- [ ] No crashes or unhandled exceptions + +#### Output / Demonstration (4%) +- [ ] Create `DEMO.md` with reproducible steps +- [ ] Include screenshots showing: + - Code changes in editor + - Test execution output + - Validation report examples +- [ ] Create process video (5-10 min) demonstrating: + - Walking through refactored functions + - Running demo scenarios + - Explaining design decisions + +#### Issue Alignment & Documentation (4%) +- [ ] PR explicitly links to Issue #40 (e.g., "Closes #40") +- [ ] PR description explains: + - What feedback was addressed (consistency, robustness, traceability) + - Why each refactor was necessary + - How new structure aligns with VizFold spec +- [ ] Include clear testing/verification steps in PR + +#### Code Quality & Design (4%) +- [ ] Code is clean and readable +- [ ] Function names are clear and descriptive +- [ ] Follows naming conventions (snake_case, descriptive terms) +- [ ] Appropriate use of helper functions +- [ ] Error messages are informative + +### Individual Component (25%) + +#### Commit Quality & Clarity (10%) +- [ ] 4-6 logical commits (not one giant commit) +- [ ] Each commit has clear, descriptive message +- [ ] Commit messages follow convention: + - `feat: add new 
functionality` + - `refactor: reorganize code structure` + - `docs: update documentation` + - `fix: resolve issue with X` +- [ ] Example good commits: + ``` + refactor: update archive structure to match VizFold 1.0 spec + + - Reorganize layers into representations/single/ + - Add per-layer naming (layer_00, layer_01, etc) + - Update paths for attention types + + feat: add metadata storage support + - Store model_version, config_version, sequence + - Track num_residues and num_recycles + - Enable reproducibility and traceability + + docs: add comprehensive code comments and docstrings + - Explain design decisions in store_* functions + - Add examples of correct usage + - Document edge cases + ``` + +#### Contribution Significance (8%) +- [ ] Demonstrates understanding of feedback +- [ ] Shows problem-solving (refactoring for spec compliance) +- [ ] Meaningful improvements to robustness/safety +- [ ] Clear value addition (not just minor tweaks) + +#### Collaboration & Integration (4%) +- [ ] PR description invites feedback +- [ ] Code cleanly integrates (no conflicts) +- [ ] Responsive to code review comments +- [ ] Provides constructive review of teammate's code + +#### Documentation & Code Comments (3%) +- [ ] Functions have detailed docstrings explaining: + - Purpose and semantics + - Why design decisions were made + - Edge cases handled +- [ ] Comments explain non-obvious code sections +- [ ] Examples in docstrings show correct usage +- [ ] Update METHOD_SPECIFICATIONS.md with new paths + +### Bonus Opportunities (+6%) + +#### PR Code Review (up to 3%) +- [ ] Leave 3+ quality reviews on teammate's PRs +- [ ] Reviews include: + - Specific line references + - Actionable suggestions + - Identification of potential bugs/edge cases + - Compliments on good decisions + +#### Process Video Bonus (3%) +- [ ] Screen recording with audio (5-10 minutes) +- [ ] Shows: + - Running the demo scenarios + - Walking through code changes + - Explaining the rationale + - 
Demonstrating that everything works +- [ ] Audio is clear and explanatory +- [ ] Reasonable editing allowed (can skip long processing) + +--- + +## Verification Checklist + +Before submitting PR, verify: + +- [ ] All refactored methods use new VizFold spec paths +- [ ] All method signatures match documentation +- [ ] Overwrite parameter works in all functions +- [ ] Validation modes (strict=True/False) behave correctly +- [ ] Pickle ingestion returns key_matches with full traceability +- [ ] Demo runs without errors (reproducible steps) +- [ ] Code is commented appropriately +- [ ] Commits are logical and well-messaged +- [ ] PR description is comprehensive +- [ ] Issue #40 is linked +- [ ] DEMO.md is included +- [ ] Process video is recorded and included + +--- + +## Files to Update + +1. **Archive Utils/outline** - Main implementation + - Refactor all method signatures and paths + - Update docstrings and comments + +2. **METHOD_SPECIFICATIONS.md** - Update with new paths + - Replace `layers/{idx}/` with `representations/single/layer_XX` + - Replace `attention` with `attention/{type}/layer_XX` + - Add metadata/ section + - Update validation spec + +3. **New: DEMO.md** - Create demo guide + - Step-by-step reproducible scenarios + - Expected outputs + - Screenshots + +4. **README.md** (existing project) - Update if needed + - Link to new refactored utilities + - Mention VizFold 1.0 compliance + +--- + +## Questions & Clarifications + +**Q: Do we need backward compatibility?** +A: No. This is a comprehensive refactor. Old format is replaced. + +**Q: What if we don't have all metadata fields?** +A: In lenient mode, missing metadata is a warning. In strict mode, it fails. Use lenient during development. + +**Q: Can we store layers out of order?** +A: Yes! The layer_XX naming supports sparse storage (layer_00, layer_03, layer_05). + +**Q: Do attention types have to be "triangle_start"?** +A: Not necessarily. Use what makes sense for your model. 
Common ones are triangle_start, triangle_end, pairwise. Document your choice. + +--- + +## Success Criteria + +The refactor is complete and successful when: + +1. ✅ All methods use VizFold 1.0 spec paths +2. ✅ Demo scenarios run without errors +3. ✅ Validation correctly identifies complete/incomplete archives +4. ✅ Code is clean, commented, and well-organized +5. ✅ PR includes demo, video, and comprehensive description +6. ✅ All rubric checklist items are met diff --git a/archive/__init__.py b/archive/__init__.py new file mode 100644 index 00000000..c2b2e2bf --- /dev/null +++ b/archive/__init__.py @@ -0,0 +1,39 @@ +""" +Archive Utils package exports. + +Convenience re-exports for archive utilities split across core/store/load modules. +""" + +from .core import tensor_to_numpy, tensor_to_zarr_array, validate_archive +from .store import ( + store_attention, + store_metadata, + store_pair_representation, + store_single_representation, + store_structure_coordinates, +) +from .load import ( + ingest_attention_txt, + ingest_output_pkl, + load_attention_head, + load_metadata, + load_pair_representation, + load_single_representation, +) + +__all__ = [ + "tensor_to_numpy", + "tensor_to_zarr_array", + "validate_archive", + "store_single_representation", + "store_pair_representation", + "store_attention", + "store_structure_coordinates", + "store_metadata", + "load_attention_head", + "ingest_attention_txt", + "ingest_output_pkl", + "load_metadata", + "load_single_representation", + "load_pair_representation", +] diff --git a/archive/cli.py b/archive/cli.py new file mode 100644 index 00000000..7b3b5499 --- /dev/null +++ b/archive/cli.py @@ -0,0 +1,103 @@ +""" +Manual CLI wrapper for Archive Utils. 
def build_parser():
    """
    Assemble the argparse parser for the manual Archive Utils CLI.

    Four subcommands are registered, in this order: ``validate``,
    ``ingest-pkl``, ``ingest-attn-txt`` and ``store-metadata``. Each
    subcommand stores its handler under ``func`` via ``set_defaults`` so
    the caller can dispatch with ``args.func(args)``.

    Returns
    -------
    argparse.ArgumentParser
        Parser whose subcommand is required and recorded in ``command``.
    """
    # Shared option specs; the kwargs dicts are only unpacked, never mutated.
    overwrite_flag = ("--overwrite", {"action": "store_true", "default": False})

    def archive_option(help_text="Path to archive"):
        return ("--archive", {"required": True, "help": help_text})

    # (subcommand, help text, handler, ordered option specs).
    # Option order is preserved so the generated --help output is unchanged.
    command_table = [
        (
            "validate",
            "Validate a trace archive",
            _cmd_validate,
            [
                archive_option("Path to archive (e.g., trace.zarr)"),
                (
                    "--strict",
                    {
                        "action": "store_true",
                        "default": False,
                        "help": "Enable strict validation",
                    },
                ),
            ],
        ),
        (
            "ingest-pkl",
            "Ingest OpenFold output_dict.pkl",
            _cmd_ingest_pkl,
            [
                archive_option(),
                ("--pkl", {"required": True, "help": "Path to .pkl file"}),
                overwrite_flag,
            ],
        ),
        (
            "ingest-attn-txt",
            "Ingest attention .txt export",
            _cmd_ingest_attn_txt,
            [
                archive_option(),
                ("--txt", {"required": True, "help": "Path to attention text file"}),
                ("--layer", {"type": int, "required": True, "help": "Layer index"}),
                (
                    "--tokens",
                    {"type": int, "required": True, "help": "Number of tokens/residues"},
                ),
                ("--type", {"default": "pairwise", "help": "Attention type"}),
                overwrite_flag,
            ],
        ),
        (
            "store-metadata",
            "Store minimal metadata",
            _cmd_store_metadata,
            [
                archive_option(),
                ("--model-version", {"required": True}),
                ("--config-version", {"required": True}),
                ("--sequence", {"required": True}),
                ("--num-residues", {"type": int, "required": True}),
                ("--num-recycles", {"type": int, "required": True}),
                overwrite_flag,
            ],
        ),
    ]

    cli = argparse.ArgumentParser(description="Archive Utils manual CLI")
    commands = cli.add_subparsers(dest="command", required=True)

    for command_name, summary, handler, options in command_table:
        sub = commands.add_parser(command_name, help=summary)
        for flag, settings in options:
            sub.add_argument(flag, **settings)
        sub.set_defaults(func=handler)

    return cli
def tensor_to_numpy(tensor):
    """
    Normalize array-like input to a NumPy array for Zarr storage.

    Three cases are handled, in this order:

    - A ``numpy.ndarray`` is returned as-is (the same object, no copy).
    - An object exposing both ``detach`` and ``cpu`` attributes is treated
      as a PyTorch-style tensor and converted with
      ``tensor.detach().cpu().numpy()``. The duck-typing avoids importing
      torch as a hard dependency.
    - Anything else (lists, scalars, ...) goes through ``numpy.asarray``.

    Parameters
    ----------
    tensor : numpy.ndarray | torch.Tensor | array-like
        Data to convert.

    Returns
    -------
    numpy.ndarray
        NumPy representation of ``tensor``.
    """
    already_numpy = isinstance(tensor, np.ndarray)

    # Only probe for the tensor protocol when the input is not an ndarray.
    looks_like_torch = not already_numpy and all(
        hasattr(tensor, attr) for attr in ("detach", "cpu")
    )
    if looks_like_torch:
        on_host = tensor.detach().cpu()
        return on_host.numpy()

    return tensor if already_numpy else np.asarray(tensor)
+ + Typical usage: + - Writing activations + - Writing attention maps + - Writing pair representations + + Responsibilities: + ----------------- + - Convert tensor to NumPy if necessary + - Create a Zarr dataset + - Apply chunking if specified + - Optionally overwrite existing data + + Parameters + ---------- + path : str + Path inside the archive where the array should be stored. + + tensor : numpy.ndarray | torch.Tensor + The tensor data to write. + + chunks : tuple, optional + Chunk size for Zarr storage. + + overwrite : bool + Whether existing data should be replaced. + + Returns + ------- + None + """ + import os + array = tensor_to_numpy(tensor) + array = np.asarray(array) + + if "::" in path: + archive_path, dataset_path = path.split("::", 1) + if not dataset_path: + raise ValueError("Dataset path is empty. Use 'archive.zarr::group/dataset'.") + + root = zarr.open(archive_path, mode="a") + if "/" in dataset_path: + parent_path, dataset_name = dataset_path.rsplit("/", 1) + parent_group = root.require_group(parent_path) + else: + parent_group = root + dataset_name = dataset_path + + if dataset_name in parent_group and not overwrite: + raise FileExistsError( + f"Dataset already exists at '{dataset_path}'. " + "Set overwrite=True to replace it." + ) + + if dataset_name in parent_group: + del parent_group[dataset_name] + + create_kwargs = { + "data": array, + "shape": array.shape, + "dtype": array.dtype, + } + if chunks is not None: + create_kwargs["chunks"] = chunks + parent_group.create_dataset(dataset_name, **create_kwargs) + return parent_group[dataset_name] + + if os.path.exists(path) and not overwrite: + raise FileExistsError( + f"Zarr array already exists at '{path}'. Set overwrite=True to replace it." + ) + + mode = "w" if overwrite else "w-" + z = zarr.open_array( + store=path, + mode=mode, + shape=array.shape, + dtype=array.dtype, + chunks=chunks, + ) + z[...] 
= array + return z + + +# ============================================================ +# INTERNAL HELPERS +# ============================================================ + +def _validate_layer_index(layer_index): + """ + Validate a transformer layer index used for per-layer archive paths. + """ + if not isinstance(layer_index, (int, np.integer)): + raise TypeError( + f"layer_index must be an integer, got {type(layer_index).__name__}" + ) + if layer_index < 0: + raise ValueError(f"layer_index must be >= 0, got {layer_index}") + + +# ============================================================ +# METHOD 8 +# ============================================================ + +def validate_archive(path, strict=True): + """ + Validate the integrity of a VizFold trace archive. + + This function checks whether the expected archive structure + exists and whether key datasets appear to be valid. + + Validation checks may include: + ------------------------------ + - Presence of required groups: + layers/ + representations/ + structure/ + + - Valid shapes for: + activations + attention maps + pair representations + structure atom positions + + - Metadata consistency if present. + + The goal is to ensure the archive can be safely used for + offline visualization and analysis. + + Parameters + ---------- + path : str + Root path to the Zarr archive. + + strict : bool, optional + Whether to enforce complete-archive requirements. Default is True. 
+ + Returns + ------- + dict + {'valid', 'strict_mode', 'path', 'errors', 'warnings', 'components_found'} + """ + report = { + "valid": True, + "strict_mode": strict, + "path": path, + "errors": [], + "warnings": [], + "components_found": { + "metadata": False, + "representations/single": False, + "representations/pair": False, + "attention": False, + "structure/atom_positions": False, + }, + } + + def _fail(message): + report["valid"] = False + report["errors"].append(message) + if strict: + raise ValueError(message) + + def _warn(message): + report["warnings"].append(message) + + if not os.path.exists(path): + msg = f"Archive path does not exist: '{path}'" + report["valid"] = False + report["errors"].append(msg) + if strict: + raise FileNotFoundError(msg) + return report + + try: + root = zarr.open(path, mode="r") + except Exception as e: + _fail(f"Failed to open archive as Zarr: {e}") + return report + + if "structure" not in root or "atom_positions" not in root.get("structure", {}): + _fail("Missing required dataset: 'structure/atom_positions'") + else: + atom_positions = root["structure"]["atom_positions"] + if atom_positions.ndim not in (2, 3) or atom_positions.shape[-1] != 3: + _fail(f"'structure/atom_positions' invalid shape: {atom_positions.shape}") + else: + report["components_found"]["structure/atom_positions"] = True + + if "metadata" not in root: + _fail("Missing group: 'metadata'") if strict else _warn("Missing group: 'metadata'") + else: + report["components_found"]["metadata"] = True + + reprs = root.get("representations", {}) + if "single" not in reprs: + _fail("Missing group: 'representations/single'") if strict else _warn("Missing group: 'representations/single'") + elif not any(k.startswith("layer_") for k in reprs["single"].keys()): + _fail("'representations/single' has no layer_XX datasets") if strict else _warn("'representations/single' has no layer_XX datasets") + else: + report["components_found"]["representations/single"] = True + + if 
def latest_layer_index(group):
    """
    Return the highest numeric index among the ``layer_<N>`` entries of ``group``.

    The index is chosen by numeric comparison rather than by sorting the
    names as strings: string order would rank ``layer_99`` after
    ``layer_100``. Entries whose suffix is not a plain decimal number
    (for example ``layer_meta``) are ignored instead of crashing ``int()``.

    Parameters
    ----------
    group : mapping-like
        Any object exposing ``keys()`` that yields entry names
        (e.g. a Zarr group or a dict).

    Returns
    -------
    int
        The largest layer index found.

    Raises
    ------
    RuntimeError
        If the group contains no ``layer_<N>`` entry.
    """
    prefix = "layer_"
    indices = []
    for name in group.keys():
        if not name.startswith(prefix):
            continue
        suffix = name[len(prefix):]
        # isdecimal() guarantees int() will accept the suffix.
        if suffix.isdecimal():
            indices.append(int(suffix))

    if not indices:
        raise RuntimeError("No layer_XX datasets found in group")
    return max(indices)
def format_vector(values, limit=5):
    """
    Render the first ``limit`` elements of ``values`` as a compact string.

    The input is converted to a NumPy array and flattened, so
    multi-dimensional input is previewed in row-major order.

    Parameters
    ----------
    values : array-like
        Data to preview.

    limit : int, optional
        Maximum number of leading elements to show. Default is 5.

    Returns
    -------
    str
        ``numpy.array2string`` output with 4-digit precision and ``", "``
        as the separator.
    """
    flattened = np.ravel(np.asarray(values))
    head = flattened[:limit]
    return np.array2string(head, separator=", ", precision=4)
def print_delta_examples(source_single, source_pair, source_attention_head, new_single, new_pair, new_attention_head):
    """
    Print before/after samples showing that the new layer differs from its source.

    One sample position is chosen from the module-level SAMPLE_RESIDUE and
    SAMPLE_TARGET constants, each clamped to the last row of
    ``source_single``. For that position the function prints the source and
    new values of the single, pair and attention-head arrays, followed by
    the mean element-wise difference of each pair of arrays.

    Parameters
    ----------
    source_single, new_single : array-like
        Indexed as ``[residue, 0]``.

    source_pair, new_pair : array-like
        Indexed as ``[residue, target, 0]``.

    source_attention_head, new_attention_head : array-like
        Indexed as ``[residue, target]``.

    Returns
    -------
    None
    """
    last_row = source_single.shape[0] - 1
    residue = min(SAMPLE_RESIDUE, last_row)
    target = min(SAMPLE_TARGET, last_row)

    print("\nBefore/after proof for new layer:")

    # Values are computed immediately before each print so that a failure
    # on a later array still leaves the earlier lines printed.
    before = float(source_single[residue, 0])
    after = float(new_single[residue, 0])
    print(f" single[{residue}, 0]: {before:.6f} -> {after:.6f}")

    before = float(source_pair[residue, target, 0])
    after = float(new_pair[residue, target, 0])
    print(f" pair[{residue}, {target}, 0]: {before:.6f} -> {after:.6f}")

    before = float(source_attention_head[residue, target])
    after = float(new_attention_head[residue, target])
    print(f" attention[head=0, {residue}, {target}]: {before:.6f} -> {after:.6f}")

    single_delta = float((new_single - source_single).mean())
    print(f" single mean delta: {single_delta:.6f}")

    pair_delta = float((new_pair - source_pair).mean())
    print(f" pair mean delta: {pair_delta:.6f}")

    attention_delta = float((new_attention_head - source_attention_head).mean())
    print(f" attention head 0 mean delta: {attention_delta:.6f}")
+single = load_single_representation(ARCHIVE, source_layer) +pair = load_pair_representation(ARCHIVE, source_layer) +attention = np.asarray(attention_group[source_name]) +attention_head = load_attention_head(ARCHIVE, ATTENTION_TYPE, source_layer, 0) + +print(f"\nSource {source_name} summaries:") +describe_array("single representation", single) +describe_array("pair representation", pair) +describe_array("attention layer", attention) +describe_array("attention head 0", attention_head) +print_source_examples(single, pair, attention_head, root) + +# Add one incremental layer. The offset makes it easy to verify this is new data. +offset = np.float32(new_layer * 0.01) +print(f"\nAdding incremental archive paths for {new_name} with +{float(offset):.4f} offset:") +print(f" representations/single/{new_name}") +print(f" representations/pair/{new_name}") +print(f" attention/{ATTENTION_TYPE}/{new_name}") +store_single_representation(ARCHIVE, new_layer, single + offset) +store_pair_representation(ARCHIVE, new_layer, pair + offset) +store_attention(ARCHIVE, ATTENTION_TYPE, new_layer, attention + offset) + +# Re-open and read the new data back. 
def load_attention_head(path, attention_type, layer_index, head_index):
    """
    Read one attention head from the archive without loading the other heads.

    Archive location read:
        attention/{attention_type}/layer_{layer_index:02d}

    The stored dataset must be 3-D; its first axis is indexed with
    ``head_index`` and the resulting slice is returned as a NumPy array.

    Parameters
    ----------
    path : str
        Root path to the Zarr archive (opened read-only).

    attention_type : str
        Name of the sub-group under ``attention/``
        (e.g. "triangle_start", "triangle_end", "pairwise").

    layer_index : int
        Zero-based layer index; formatted as a two-digit ``layer_XX`` name.

    head_index : int
        Zero-based head index along the first axis of the stored dataset.

    Returns
    -------
    numpy.ndarray
        The selected head, i.e. ``dataset[head_index]``.

    Raises
    ------
    KeyError
        If the ``attention`` group, the attention type, or the layer is missing.
    ValueError
        If the stored dataset is not 3-D.
    IndexError
        If ``head_index`` is negative or not smaller than the number of heads.

    Examples
    --------
    >>> attn = load_attention_head("trace.zarr", "triangle_start", 0, 2)
    >>> attn.shape
    (128, 128)
    """
    archive = zarr.open(path, mode='r')

    if "attention" not in archive:
        raise KeyError(
            f"No 'attention' group found in archive at '{path}'."
        )
    by_type = archive["attention"]

    if attention_type not in by_type:
        known_types = list(by_type.keys())
        raise KeyError(
            f"Attention type '{attention_type}' not found in archive at '{path}'. "
            f"Available types: {known_types}"
        )
    layers = by_type[attention_type]

    dataset_key = f"layer_{layer_index:02d}"
    if dataset_key not in layers:
        known_layers = list(layers.keys())
        raise KeyError(
            f"Layer {layer_index} ('{dataset_key}') not found for attention type "
            f"'{attention_type}' in archive at '{path}'. "
            f"Available layers: {known_layers}"
        )
    stored = layers[dataset_key]

    if stored.ndim != 3:
        raise ValueError(
            f"Expected stored attention to be 3D (num_heads, tokens, tokens), "
            f"got {stored.ndim}D with shape {stored.shape}"
        )

    head_count = stored.shape[0]
    if not 0 <= head_index < head_count:
        raise IndexError(
            f"head_index {head_index} is out of range for layer {layer_index} "
            f"which has {head_count} attention head(s)."
        )

    # Slicing the Zarr dataset reads only the requested head.
    return np.asarray(stored[head_index])
def _extract_best_matching_array(container, key_token_patterns):
    """
    Find the best numeric array in a nested container using tokenized key patterns.

    Matching is stricter than raw substring search:
    - Key paths are lower-cased and tokenized on non-alphanumeric boundaries.
    - A pattern matches only if all pattern tokens are present in the path tokens.
    - The best match prefers more specific (longer) patterns and shorter paths;
      on equal score the first candidate encountered is kept.

    Only values that convert to a numeric or boolean NumPy array are
    candidates. Dictionaries are never candidates themselves (they are only
    traversed), and arrays with object/string dtype are rejected. Without
    these checks ``np.asarray`` turns a dict or a string into a 0-d array,
    which could out-score the real array nested underneath it.

    Parameters
    ----------
    container : dict | list | tuple
        Nested object to search.

    key_token_patterns : list[list[str]]
        Token patterns to match (e.g., [["final", "atom", "positions"]]).
        Falsy tokens are dropped; patterns left with no tokens are ignored
        because an empty pattern would match every key.

    Returns
    -------
    dict | None
        {
            "array": numpy.ndarray,
            "matched_key": str,
            "pattern": list[str],
            "shape": tuple
        }
        Returns None if no matching array is found.
    """
    normalized_patterns = []
    for pattern in key_token_patterns:
        tokens = [token.lower() for token in pattern if token] if pattern else []
        if tokens:
            normalized_patterns.append(tokens)

    token_splitter = re.compile(r"[^a-z0-9]+")
    best = None
    best_score = None

    def _tokenize(path):
        return [token for token in token_splitter.split(path.lower()) if token]

    def _consider(path, value):
        nonlocal best, best_score

        # Mappings are containers, not arrays; _walk descends into them.
        if isinstance(value, dict):
            return

        try:
            array = tensor_to_numpy(value)
        except Exception:
            # Deliberately best-effort: a value that cannot be converted is
            # simply not a candidate.
            return

        if not isinstance(array, np.ndarray):
            return
        # Keep bool / int / uint / float / complex only; object and string
        # arrays come from non-array values such as lists of dicts or text.
        if array.dtype.kind not in "biufc":
            return

        path_tokens = _tokenize(path)
        if not path_tokens:
            return
        path_token_set = set(path_tokens)

        for pattern in normalized_patterns:
            if not all(token in path_token_set for token in pattern):
                continue
            # Higher score means a more specific and tighter key match.
            score = len(pattern) * 100 - len(path_tokens)
            if best_score is None or score > best_score:
                best_score = score
                best = {
                    "array": array,
                    "matched_key": path,
                    "pattern": pattern,
                    "shape": tuple(array.shape),
                }

    def _walk(obj, path_prefix=""):
        if isinstance(obj, dict):
            children = obj.items()
        elif isinstance(obj, (list, tuple)):
            children = enumerate(obj)
        else:
            return

        for key, value in children:
            path = f"{path_prefix}/{key}" if path_prefix else str(key)
            _consider(path, value)
            _walk(value, path)

    _walk(container)
    return best
+ + Current routing behavior: + - final_atom_positions (N, 37, 3) -> structure/atom_positions using CA atom (index 1) + - final_atom_mask (if found) -> structure/atom_mask using CA atom (index 1) + - ptm (if found) -> structure/ptm + - pair representation (if found) -> representations/pair/layer_00 + """ + with open(pkl_file, "rb") as f: + output_dict = pickle.load(f) + + if not isinstance(output_dict, dict): + raise ValueError("Expected pickle file to contain a dictionary output") + + summary = { + "source_file": pkl_file, + "stored": [], + "skipped": [], + "key_matches": {}, + } + + final_positions_match = _extract_best_matching_array( + output_dict, + [["final", "atom", "positions"], ["final", "positions"]], + ) + final_atom_mask_match = _extract_best_matching_array( + output_dict, + [["final", "atom", "mask"], ["atom", "mask"]], + ) + ptm_match = _extract_best_matching_array(output_dict, [["ptm"], ["predicted", "tm"]]) + pair_match = _extract_best_matching_array( + output_dict, + [["pair", "representation"], ["pair", "activations"], ["pair"]], + ) + + summary["key_matches"]["final_positions"] = { + "pattern": ["final", "atom", "positions"], + "matched_key": None if final_positions_match is None else final_positions_match["matched_key"], + "shape": None if final_positions_match is None else final_positions_match["shape"], + } + summary["key_matches"]["final_atom_mask"] = { + "pattern": ["final", "atom", "mask"], + "matched_key": None if final_atom_mask_match is None else final_atom_mask_match["matched_key"], + "shape": None if final_atom_mask_match is None else final_atom_mask_match["shape"], + } + summary["key_matches"]["ptm"] = { + "pattern": ["ptm"], + "matched_key": None if ptm_match is None else ptm_match["matched_key"], + "shape": None if ptm_match is None else ptm_match["shape"], + } + summary["key_matches"]["pair_representation"] = { + "pattern": ["pair"], + "matched_key": None if pair_match is None else pair_match["matched_key"], + "shape": None if 
pair_match is None else pair_match["shape"], + } + + final_positions = None if final_positions_match is None else final_positions_match["array"] + final_atom_mask = None if final_atom_mask_match is None else final_atom_mask_match["array"] + ptm = None if ptm_match is None else ptm_match["array"] + + if final_positions is not None: + if final_positions.ndim == 3 and final_positions.shape[-1] == 3: + atom_positions = ( + final_positions[:, 1, :] + if final_positions.shape[1] > 1 + else final_positions[:, 0, :] + ) + atom_mask = None + if final_atom_mask is not None: + final_atom_mask = tensor_to_numpy(final_atom_mask) + if final_atom_mask.ndim == 2: + atom_mask = final_atom_mask[:, 1] if final_atom_mask.shape[1] > 1 else final_atom_mask[:, 0] + elif final_atom_mask.ndim == 1: + atom_mask = final_atom_mask + + store_structure_coordinates( + archive_path, + atom_positions, + atom_mask=atom_mask, + ptm=ptm, + overwrite=overwrite, + ) + summary["stored"].append("structure/atom_positions") + if atom_mask is not None: + summary["stored"].append("structure/atom_mask") + if ptm is not None: + summary["stored"].append("structure/ptm") + else: + summary["skipped"].append("final_atom_positions (unexpected shape)") + else: + summary["skipped"].append("final_atom_positions (not found)") + + pair_array = None if pair_match is None else pair_match["array"] + if pair_array is not None and isinstance(pair_array, np.ndarray): + if pair_array.ndim == 3 and pair_array.shape[0] == pair_array.shape[1]: + store_pair_representation(archive_path, 0, pair_array, overwrite=overwrite) + summary["stored"].append("representations/pair/layer_00") + else: + summary["skipped"].append("representations/pair/layer_00 (unexpected shape)") + else: + summary["skipped"].append("representations/pair/layer_00 (not found)") + + return summary + + +# ============================================================ +# METHOD 10 +# ============================================================ + +def 
def load_single_representation(path, layer_index):
    """
    Read the single (per-residue) representation stored for one layer.

    Archive location read:
        representations/single/layer_{layer_index:02d}

    Parameters
    ----------
    path : str
        Root path to the Zarr archive (opened read-only).

    layer_index : int
        Zero-based layer index; formatted as a two-digit ``layer_XX`` name.

    Returns
    -------
    numpy.ndarray
        The stored dataset materialized as a NumPy array.

    Raises
    ------
    KeyError
        If ``representations``, ``representations/single`` or the requested
        layer is missing from the archive.

    Examples
    --------
    >>> single = load_single_representation("trace.zarr", 0)
    >>> single.shape
    (128, 256)
    """
    archive = zarr.open(path, mode='r')

    if "representations" not in archive:
        raise KeyError(
            f"No 'representations' group found in archive at '{path}'."
        )
    representations = archive["representations"]

    if "single" not in representations:
        raise KeyError(
            f"No 'representations/single' group found in archive at '{path}'."
        )
    layers = representations["single"]

    dataset_key = f"layer_{layer_index:02d}"
    if dataset_key in layers:
        return np.asarray(layers[dataset_key])

    known_layers = list(layers.keys())
    raise KeyError(
        f"Layer {layer_index} ('{dataset_key}') not found in "
        f"representations/single at '{path}'. "
        f"Available layers: {known_layers}"
    )
class ArchiveOrchestrator:
    """
    Thin helper that sequences archive writes and records what happened.

    The class intentionally stays lightweight: it does not replace the core
    store_* functions, it only coordinates them and captures a run log.
    Each successful operation appends one event dictionary to ``events``;
    an operation that raises leaves no event behind.
    """

    def __init__(self, archive_path):
        """
        Parameters
        ----------
        archive_path : str | os.PathLike
            Root path of the archive. ``pathlib.Path`` objects are accepted
            and converted to ``str``; trailing "/" characters are stripped.

        Raises
        ------
        TypeError
            If ``archive_path`` is neither a string nor path-like.
        """
        # Function-scope import: the rest of this module does not need os.
        import os

        self.archive_path = os.fspath(archive_path).rstrip("/")
        self.events = []

    def _record(self, action, target, **details):
        """Append an event {"action", "target", **details} and return it."""
        event = {"action": action, "target": target}
        if details:
            event.update(details)
        self.events.append(event)
        return event

    def add_metadata(self, *args, **kwargs):
        """Forward all arguments to store_metadata() and log the write."""
        store_metadata(self.archive_path, *args, **kwargs)
        return self._record("store", "metadata")

    def add_single_layer(self, layer_index, single_array, overwrite=False):
        """Store a single representation for ``layer_index`` and log the write."""
        store_single_representation(
            self.archive_path,
            layer_index,
            single_array,
            overwrite=overwrite,
        )
        return self._record("store", f"representations/single/layer_{layer_index:02d}")

    def add_pair_layer(self, layer_index, pair_array, overwrite=False):
        """Store a pair representation for ``layer_index`` and log the write."""
        store_pair_representation(
            self.archive_path,
            layer_index,
            pair_array,
            overwrite=overwrite,
        )
        return self._record("store", f"representations/pair/layer_{layer_index:02d}")

    def add_attention(self, attention_type, layer_index, attention_array, overwrite=False):
        """Store attention maps for one type/layer and log the write."""
        store_attention(
            self.archive_path,
            attention_type,
            layer_index,
            attention_array,
            overwrite=overwrite,
        )
        return self._record(
            "store",
            f"attention/{attention_type}/layer_{layer_index:02d}",
        )

    def add_structure(self, atom_positions, atom_mask=None, ptm=None, overwrite=False):
        """Store structure coordinates (plus optional mask/ptm) and log the write."""
        store_structure_coordinates(
            self.archive_path,
            atom_positions,
            atom_mask=atom_mask,
            ptm=ptm,
            overwrite=overwrite,
        )
        return self._record("store", "structure")

    def validate(self, validator=None, *args, **kwargs):
        """
        Run a validator on the archive and log its result.

        Parameters
        ----------
        validator : callable, optional
            Called as ``validator(archive_path, *args, **kwargs)``. When
            omitted, validate_archive is used; it is looked up at call time
            rather than frozen into the signature at class creation.

        Returns
        -------
        object
            Whatever the validator returns; the same value is stored in the
            recorded event under "result".
        """
        if validator is None:
            validator = validate_archive
        result = validator(self.archive_path, *args, **kwargs)
        self._record("validate", "archive", result=result)
        return result

    def summary(self):
        """Return the archive path and a shallow copy of the event log."""
        return {
            "archive_path": self.archive_path,
            "events": list(self.events),
        }
# ============================================================
# METHOD 3
# ============================================================

def store_single_representation(path, layer_index, single_array, overwrite=False):
    """
    Store per-layer single representation embeddings.

    The array is written to ``representations/single/layer_{layer_index:02d}``
    inside the archive.

    Parameters
    ----------
    path : str
        Root path to the Zarr archive.

    layer_index : int
        Index of the transformer layer; checked by ``_validate_layer_index``.

    single_array : array-like
        Single representation; must be 2D after conversion with
        ``tensor_to_numpy`` (documented as (num_residues, hidden_dim)).

    overwrite : bool, optional
        Forwarded to ``tensor_to_zarr_array``. Default is False.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the converted array is not 2D.

    Examples
    --------
    >>> store_single_representation("trace.zarr", 0, layer_0_repr)
    >>> store_single_representation("trace.zarr", 1, layer_1_repr)
    """
    _validate_layer_index(layer_index)
    single_array = tensor_to_numpy(single_array)

    if single_array.ndim != 2:
        raise ValueError(
            f"Expected 2D representation array (num_residues, hidden_dim), "
            f"got {single_array.ndim}D with shape {single_array.shape}"
        )

    layer_name = f"layer_{layer_index:02d}"
    # The target is addressed as "<archive root>::<path inside the archive>".
    array_path = f"{path.rstrip('/')}::representations/single/{layer_name}"
    tensor_to_zarr_array(array_path, single_array, overwrite=overwrite)


# ============================================================
# METHOD 4
# ============================================================

def store_pair_representation(path, layer_index, pair_array, overwrite=False):
    """
    Store per-layer pair representation embeddings.

    The array is written to ``representations/pair/layer_{layer_index:02d}``
    inside the archive.

    Parameters
    ----------
    path : str
        Root path to the Zarr archive.

    layer_index : int
        Index of the transformer layer; checked by ``_validate_layer_index``.

    pair_array : array-like
        Pair representation; must be 3D with equal first two dimensions
        (documented as (tokens, tokens, pair_dim)).

    overwrite : bool, optional
        Forwarded to ``tensor_to_zarr_array``. Default is False.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the converted array is not 3D or its first two dimensions differ.

    Examples
    --------
    >>> store_pair_representation("trace.zarr", 0, pair_layer_0)
    >>> store_pair_representation("trace.zarr", 3, pair_layer_3)
    """
    _validate_layer_index(layer_index)
    pair_array = tensor_to_numpy(pair_array)

    if pair_array.ndim != 3:
        raise ValueError(
            f"Expected 3D pair array (tokens, tokens, pair_dim), "
            f"got {pair_array.ndim}D with shape {pair_array.shape}"
        )
    if pair_array.shape[0] != pair_array.shape[1]:
        raise ValueError(
            f"Pair representation must be square in first two dims (tokens x tokens), "
            f"got shape {pair_array.shape}"
        )

    layer_name = f"layer_{layer_index:02d}"
    array_path = f"{path.rstrip('/')}::representations/pair/{layer_name}"
    tensor_to_zarr_array(array_path, pair_array, overwrite=overwrite)


# ============================================================
# METHOD 5
# ============================================================

def store_attention(path, attention_type, layer_index, attention_array, overwrite=False):
    """
    Store attention head maps for a transformer layer by attention type.

    The array is written to
    ``attention/{attention_type}/layer_{layer_index:02d}`` and chunked as
    ``(1, tokens, tokens)`` so that one head occupies one chunk.

    Parameters
    ----------
    path : str
        Root path to the archive.

    attention_type : str
        Non-empty name of the attention mechanism, used verbatim as a path
        component (e.g. "triangle_start", "triangle_end", "pairwise").

    layer_index : int
        Transformer layer index; checked by ``_validate_layer_index`` exactly
        as the representation writers do.

    attention_array : array-like
        Attention tensor; must be 3D with equal last two dimensions
        (documented as (num_heads, tokens, tokens)).

    overwrite : bool, optional
        Forwarded to ``tensor_to_zarr_array``. Default is False.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``attention_type`` is empty or not a string, or if the array is
        not 3D with square trailing dimensions.

    Examples
    --------
    >>> store_attention("trace.zarr", "triangle_start", 0, attn_array)
    >>> store_attention("trace.zarr", "pairwise", 0, attn_array)
    """
    # Cheap argument checks run before the tensor conversion.
    if not attention_type or not isinstance(attention_type, str):
        raise ValueError(
            "attention_type must be a non-empty string (e.g., 'triangle_start', "
            "'triangle_end', 'pairwise')"
        )
    _validate_layer_index(layer_index)

    attention_array = tensor_to_numpy(attention_array)

    if attention_array.ndim != 3:
        raise ValueError(
            f"Expected 3D attention array (num_heads, tokens, tokens), "
            f"got {attention_array.ndim}D with shape {attention_array.shape}"
        )

    _, tokens_i, tokens_j = attention_array.shape
    if tokens_i != tokens_j:
        raise ValueError(
            f"Attention matrix must be square (tokens x tokens), "
            f"got shape {attention_array.shape}"
        )

    layer_name = f"layer_{layer_index:02d}"
    array_path = f"{path.rstrip('/')}::attention/{attention_type}/{layer_name}"
    chunks = (1, tokens_i, tokens_j)
    tensor_to_zarr_array(array_path, attention_array, chunks=chunks, overwrite=overwrite)


# ============================================================
# METHOD 6
# ============================================================

def _prepare_atom_mask(atom_mask, atom_positions):
    """Convert and shape-check an atom mask against ``atom_positions``."""
    num_residues = atom_positions.shape[0]
    raw_mask = np.asarray(tensor_to_numpy(atom_mask))
    mask = np.squeeze(raw_mask)

    # np.squeeze removes every unit axis, including the residue axis of a
    # single-residue input; put that axis back so the checks below compare
    # like with like.
    if num_residues == 1 and raw_mask.ndim >= 1 and raw_mask.shape[0] == 1:
        mask = mask[np.newaxis, ...]

    if mask.ndim == 0:
        raise ValueError(
            f"atom_mask must have a residue dimension of length {num_residues}, "
            f"got shape {raw_mask.shape}"
        )
    if mask.shape[0] != num_residues:
        raise ValueError(
            f"atom_mask first dimension ({mask.shape[0]}) must match "
            f"number of residues ({num_residues})"
        )
    if atom_positions.ndim == 3 and mask.ndim > 1:
        num_atoms = atom_positions.shape[1]
        if mask.shape[1] != num_atoms:
            raise ValueError(
                f"atom_mask second dimension ({mask.shape[1]}) must match "
                f"num_atoms ({num_atoms}) when atom_positions is 3D"
            )
    return mask


def store_structure_coordinates(path, atom_positions, atom_mask=None, ptm=None, overwrite=False):
    """
    Store predicted protein structure atom positions and optional confidence fields.

    Datasets written (all inside the ``structure`` group):
        structure/atom_positions
        structure/atom_mask   (only when ``atom_mask`` is given)
        structure/ptm         (only when ``ptm`` is given)

    All inputs are validated, and all overwrite conflicts detected, before
    the first dataset is written, so a rejected call does not leave a
    partially updated ``structure`` group behind.

    Parameters
    ----------
    path : str
        Root path to the Zarr archive.

    atom_positions : array-like
        Atomic positions with shape (num_residues, 3) or
        (num_residues, num_atoms, 3).

    atom_mask : optional array-like
        Atom-presence mask. Unit axes are squeezed away; the first remaining
        dimension must equal num_residues, and when ``atom_positions`` is 3D
        and the mask has more than one dimension, the second must equal
        num_atoms.

    ptm : optional float | numpy.ndarray
        Predicted TM-score confidence value. Must hold exactly one element;
        stored as a one-element float32 array.

    overwrite : bool, optional
        Whether to replace existing structure datasets. Default is False.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If any input has an unsupported shape.
    FileExistsError
        If a target dataset already exists and ``overwrite`` is False.
    """
    atom_positions = np.asarray(tensor_to_numpy(atom_positions))

    # Validate atom position dimensions: (num_residues, 3) or (num_residues, num_atoms, 3)
    if atom_positions.ndim not in (2, 3):
        raise ValueError(
            f"Expected 2D or 3D atom position array, "
            f"got {atom_positions.ndim}D with shape {atom_positions.shape}"
        )
    if atom_positions.shape[-1] != 3:
        raise ValueError(
            f"Expected last dimension to be 3 (x, y, z), "
            f"got {atom_positions.shape[-1]}"
        )

    # Dict order fixes both the conflict-check order and the write order.
    pending = {"atom_positions": atom_positions}

    if atom_mask is not None:
        pending["atom_mask"] = _prepare_atom_mask(atom_mask, atom_positions)

    if ptm is not None:
        ptm = np.asarray(tensor_to_numpy(ptm))
        if ptm.size != 1:
            raise ValueError(
                f"ptm must be a scalar value, got shape {ptm.shape}"
            )
        pending["ptm"] = np.array([float(ptm.item())], dtype=np.float32)

    # The archive is opened (and created if absent) only after validation.
    root = zarr.open(path.rstrip("/"), mode="a")
    structure_group = root.require_group("structure")

    if not overwrite:
        for dataset_name in pending:
            if dataset_name in structure_group:
                raise FileExistsError(
                    f"Dataset already exists at 'structure/{dataset_name}'. "
                    "Set overwrite=True to replace it."
                )

    for dataset_name, array in pending.items():
        if dataset_name in structure_group:
            del structure_group[dataset_name]
        structure_group.create_dataset(
            dataset_name,
            data=array,
            shape=array.shape,
            dtype=array.dtype,
        )


# ============================================================
# METHOD 9
# ============================================================

def store_metadata(path, model_version, config_version, sequence,
                   num_residues, num_recycles, recycle_info=None,
                   residue_index=None, representation_names=None, overwrite=False):
    """
    Store run-level metadata for a VizFold archive.

    Datasets written (all inside the ``metadata`` group):
        metadata/model_version
        metadata/config_version
        metadata/sequence
        metadata/num_residues
        metadata/num_recycles
        metadata/recycle_info           (only when given)
        metadata/residue_index          (only when given)
        metadata/representation_names   (only when given)

    Argument validation runs before the archive is opened, so invalid input
    neither creates the archive nor writes part of the metadata.

    Parameters
    ----------
    path : str
        Root path to the Zarr archive.

    model_version : str
        Model identifier or release version.

    config_version : str
        Configuration identifier or release version.

    sequence : str
        Input sequence for the run.

    num_residues : int
        Number of residues in the sequence; must be non-negative.

    num_recycles : int
        Number of recycles used during inference; must be non-negative.

    recycle_info : optional array-like
        Additional per-recycle metadata.

    residue_index : optional array-like
        Residue index values. Squeezed to 1D if needed; its length must
        equal ``num_residues``.

    representation_names : optional array-like
        Ordered names of stored representations.

    overwrite : bool, optional
        Forwarded to ``tensor_to_zarr_array`` for every dataset.
        Default is False.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If a count is negative or ``residue_index`` has the wrong shape.
    """
    residue_count = int(num_residues)
    if residue_count < 0:
        raise ValueError("num_residues must be non-negative")
    if int(num_recycles) < 0:
        raise ValueError("num_recycles must be non-negative")

    # (dataset name, array) pairs in the order they are written.
    datasets = [
        ("model_version", np.asarray(model_version)),
        ("config_version", np.asarray(config_version)),
        ("sequence", np.asarray(sequence)),
        ("num_residues", np.asarray(num_residues, dtype=np.int32)),
        ("num_recycles", np.asarray(num_recycles, dtype=np.int32)),
    ]

    if recycle_info is not None:
        datasets.append(("recycle_info", tensor_to_numpy(recycle_info)))

    if residue_index is not None:
        residue_index_array = tensor_to_numpy(residue_index)
        if residue_index_array.ndim != 1:
            residue_index_array = np.squeeze(residue_index_array)
        if residue_index_array.ndim != 1:
            raise ValueError(
                f"residue_index must be 1D, got shape {residue_index_array.shape}"
            )
        if residue_index_array.shape[0] != residue_count:
            raise ValueError(
                f"residue_index length ({residue_index_array.shape[0]}) must match "
                f"num_residues ({residue_count})"
            )
        datasets.append(("residue_index", residue_index_array))

    if representation_names is not None:
        datasets.append(
            ("representation_names", tensor_to_numpy(representation_names))
        )

    archive_path = path.rstrip("/")
    root = zarr.open(archive_path, mode="a")
    root.require_group("metadata")

    for dataset_name, array in datasets:
        tensor_to_zarr_array(
            f"{archive_path}::metadata/{dataset_name}",
            array,
            overwrite=overwrite,
        )
Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_00/c/6/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_00/c/7/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_00/c/7/0/0 new file mode 100644 index 00000000..e3ab8675 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_00/c/7/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/0/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/0/0/0 new file mode 100644 index 00000000..5acde838 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/0/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/1/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/1/0/0 new file mode 100644 index 00000000..207d3662 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/1/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/2/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/2/0/0 new file mode 100644 index 00000000..b307b313 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/2/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/3/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/3/0/0 new file mode 100644 index 00000000..d6020ede Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/3/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/4/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/4/0/0 new file mode 100644 index 00000000..0e6d1ad5 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/4/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/5/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/5/0/0 
new file mode 100644 index 00000000..4b93a0a4 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/5/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/6/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/6/0/0 new file mode 100644 index 00000000..6d1bbca7 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/6/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/7/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/7/0/0 new file mode 100644 index 00000000..9b4d9726 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_01/c/7/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/0/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/0/0/0 new file mode 100644 index 00000000..6a694457 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/0/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/1/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/1/0/0 new file mode 100644 index 00000000..83e6f9e1 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/1/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/2/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/2/0/0 new file mode 100644 index 00000000..249c9eb6 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/2/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/3/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/3/0/0 new file mode 100644 index 00000000..d15796ec Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/3/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/4/0/0 
b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/4/0/0 new file mode 100644 index 00000000..0b4b485c Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/4/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/5/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/5/0/0 new file mode 100644 index 00000000..095ba04f Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/5/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/6/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/6/0/0 new file mode 100644 index 00000000..d6090018 Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/6/0/0 differ diff --git a/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/7/0/0 b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/7/0/0 new file mode 100644 index 00000000..daf5cd3a Binary files /dev/null and b/archive/test_1UBQ.zarr/attention/triangle_start/layer_02/c/7/0/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_00/c/0/0/0 b/archive/test_1UBQ.zarr/representations/pair/layer_00/c/0/0/0 new file mode 100644 index 00000000..dc7e2896 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_00/c/0/0/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_00/c/0/1/0 b/archive/test_1UBQ.zarr/representations/pair/layer_00/c/0/1/0 new file mode 100644 index 00000000..d829e684 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_00/c/0/1/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_00/c/1/0/0 b/archive/test_1UBQ.zarr/representations/pair/layer_00/c/1/0/0 new file mode 100644 index 00000000..739cd9cf Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_00/c/1/0/0 differ diff --git 
a/archive/test_1UBQ.zarr/representations/pair/layer_00/c/1/1/0 b/archive/test_1UBQ.zarr/representations/pair/layer_00/c/1/1/0 new file mode 100644 index 00000000..edd21ec7 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_00/c/1/1/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/0/0 b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/0/0 new file mode 100644 index 00000000..6902f9a2 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/0/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/0/1 b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/0/1 new file mode 100644 index 00000000..51f9fe77 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/0/1 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/1/0 b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/1/0 new file mode 100644 index 00000000..577e6904 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/1/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/1/1 b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/1/1 new file mode 100644 index 00000000..6a5d53d6 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/0/1/1 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/0/0 b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/0/0 new file mode 100644 index 00000000..8adfe366 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/0/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/0/1 b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/0/1 new file mode 100644 index 00000000..c33ed3e6 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/0/1 differ diff --git 
a/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/1/0 b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/1/0 new file mode 100644 index 00000000..0dc50ac3 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/1/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/1/1 b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/1/1 new file mode 100644 index 00000000..c3fe82d1 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_01/c/1/1/1 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/0/0 b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/0/0 new file mode 100644 index 00000000..1bf17819 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/0/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/0/1 b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/0/1 new file mode 100644 index 00000000..39af1760 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/0/1 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/1/0 b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/1/0 new file mode 100644 index 00000000..cb497765 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/1/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/1/1 b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/1/1 new file mode 100644 index 00000000..2f2eab91 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/0/1/1 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/0/0 b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/0/0 new file mode 100644 index 00000000..fc64d80f Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/0/0 differ diff --git 
a/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/0/1 b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/0/1 new file mode 100644 index 00000000..6421d61a Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/0/1 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/1/0 b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/1/0 new file mode 100644 index 00000000..3d22cdf0 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/1/0 differ diff --git a/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/1/1 b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/1/1 new file mode 100644 index 00000000..5a02ec7b Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/pair/layer_02/c/1/1/1 differ diff --git a/archive/test_1UBQ.zarr/representations/single/layer_00/c/0/0 b/archive/test_1UBQ.zarr/representations/single/layer_00/c/0/0 new file mode 100644 index 00000000..f843ec98 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/single/layer_00/c/0/0 differ diff --git a/archive/test_1UBQ.zarr/representations/single/layer_01/c/0/0 b/archive/test_1UBQ.zarr/representations/single/layer_01/c/0/0 new file mode 100644 index 00000000..e90ce2be Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/single/layer_01/c/0/0 differ diff --git a/archive/test_1UBQ.zarr/representations/single/layer_02/c/0/0 b/archive/test_1UBQ.zarr/representations/single/layer_02/c/0/0 new file mode 100644 index 00000000..cc86ea71 Binary files /dev/null and b/archive/test_1UBQ.zarr/representations/single/layer_02/c/0/0 differ diff --git a/archive/test_1UBQ.zarr/structure/atom_mask/c/0/0 b/archive/test_1UBQ.zarr/structure/atom_mask/c/0/0 new file mode 100644 index 00000000..b793ea5e Binary files /dev/null and b/archive/test_1UBQ.zarr/structure/atom_mask/c/0/0 differ diff --git a/archive/test_1UBQ.zarr/structure/atom_positions/c/0/0/0 
b/archive/test_1UBQ.zarr/structure/atom_positions/c/0/0/0 new file mode 100644 index 00000000..6b33b819 Binary files /dev/null and b/archive/test_1UBQ.zarr/structure/atom_positions/c/0/0/0 differ diff --git a/archive/test_1UBQ.zarr/structure/ptm/c/0 b/archive/test_1UBQ.zarr/structure/ptm/c/0 new file mode 100644 index 00000000..965b2bf7 Binary files /dev/null and b/archive/test_1UBQ.zarr/structure/ptm/c/0 differ diff --git a/archive/test_archive.py b/archive/test_archive.py new file mode 100644 index 00000000..73a6e917 --- /dev/null +++ b/archive/test_archive.py @@ -0,0 +1,275 @@ +""" +Test script for VizFold Archive Utilities using real OpenFold output data. +""" +import sys +import os +import pickle +import shutil +from pathlib import Path + +import numpy as np +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from archive import ( # noqa: E402 + load_attention_head, + load_pair_representation, + load_single_representation, + store_attention, + store_pair_representation, + store_single_representation, + store_structure_coordinates, + tensor_to_numpy, + validate_archive, +) + +# Path to a test pickle file +PICKLE_FILE = "./1UBQ_result_model_1_ptm.pickle" +TEST_ARCHIVE = "test_1UBQ.zarr" + + +def get_pickle_path(): + return os.path.join(os.path.dirname(__file__), PICKLE_FILE) + + +def load_pickle_data(pkl_path): + """Load data from OpenFold output pickle.""" + with open(pkl_path, "rb") as f: + return pickle.load(f) + + +@pytest.fixture(scope="module") +def data(): + """Load real OpenFold output data for pytest runs.""" + pkl_path = get_pickle_path() + if not os.path.exists(pkl_path): + pytest.skip(f"Pickle file not found: {pkl_path}") + return load_pickle_data(pkl_path) + + +@pytest.fixture +def archive_path(tmp_path): + """Provide an isolated archive path for each pytest test.""" + return str(tmp_path / TEST_ARCHIVE) + + +def populate_archive(data, archive_path): + 
"""Populate all components required by validate_archive().""" + single = data["representations"]["single"] + pair = data["representations"]["pair"] + sm = data["structure_module"] + num_residues = single.shape[0] + fake_attention = np.random.rand(8, num_residues, num_residues).astype(np.float32) + + store_single_representation(archive_path, layer_index=0, single_array=single, overwrite=True) + store_pair_representation(archive_path, layer_index=0, pair_array=pair, overwrite=True) + store_structure_coordinates( + archive_path, + sm["final_atom_positions"], + atom_mask=sm["final_atom_mask"], + ptm=float(data["ptm"]), + overwrite=True, + ) + store_attention( + archive_path, + attention_type="triangle_start", + layer_index=0, + attention_array=fake_attention, + overwrite=True, + ) + + +def test_store_and_load_single_representation(data, archive_path): + """Test storing and loading single representation.""" + print("\n=== Testing single representation ===") + + single = data["representations"]["single"] + print(f"Original shape: {single.shape}, dtype: {single.dtype}") + + # Store it (treating as layer 0) + store_single_representation(archive_path, layer_index=0, single_array=single, overwrite=True) + print("Stored to archive") + + # Load it back + loaded = load_single_representation(archive_path, layer_index=0) + print(f"Loaded shape: {loaded.shape}, dtype: {loaded.dtype}") + + # Verify + if np.allclose(single, loaded, rtol=1e-3): + print("✓ Single representation: PASS") + return + + print("✗ Single representation: FAIL - data mismatch") + assert False, "single representation data mismatch" + + +def test_store_and_load_pair_representation(data, archive_path): + """Test storing and loading pair representation.""" + print("\n=== Testing pair representation ===") + + pair = data["representations"]["pair"] + print(f"Original shape: {pair.shape}, dtype: {pair.dtype}") + + # Store it (treating as layer 0) + store_pair_representation(archive_path, layer_index=0, 
pair_array=pair, overwrite=True) + print("Stored to archive") + + # Load it back + loaded = load_pair_representation(archive_path, layer_index=0) + print(f"Loaded shape: {loaded.shape}, dtype: {loaded.dtype}") + + # Verify + if np.allclose(pair, loaded, rtol=1e-3): + print("✓ Pair representation: PASS") + return + + print("✗ Pair representation: FAIL - data mismatch") + assert False, "pair representation data mismatch" + + +def test_store_structure(data, archive_path): + """Test storing structure data.""" + print("\n=== Testing structure storage ===") + + sm = data["structure_module"] + atom_positions = sm["final_atom_positions"] + atom_mask = sm["final_atom_mask"] + ptm = float(data["ptm"]) + + print(f"atom_positions shape: {atom_positions.shape}") + print(f"atom_mask shape: {atom_mask.shape}") + print(f"ptm: {ptm}") + + # Store structure + store_structure_coordinates(archive_path, atom_positions, atom_mask=atom_mask, ptm=ptm, overwrite=True) + print("Stored to archive") + + # Verify by opening archive + import zarr + root = zarr.open(archive_path, mode="r") + + if "structure" in root: + struct = root["structure"] + if "atom_positions" in struct: + loaded_pos = np.asarray(struct["atom_positions"]) + print(f"Loaded atom_positions shape: {loaded_pos.shape}") + if np.allclose(atom_positions, loaded_pos, rtol=1e-3): + print("✓ Structure storage: PASS") + return + + print("✗ Structure storage: FAIL") + assert False, "structure atom_positions data mismatch or missing dataset" + + +def test_store_attention(data, archive_path): + """Test storing attention data (synthetic since not in pickle).""" + print("\n=== Testing attention storage ===") + + # Create synthetic attention data based on sequence length + num_residues = data["representations"]["single"].shape[0] + num_heads = 8 + + # Shape: (num_heads, num_residues, num_residues) + fake_attention = np.random.rand(num_heads, num_residues, num_residues).astype(np.float32) + print(f"Synthetic attention shape: 
def run_script_test(func, *args):
    """Run one script-style test and report whether it passed.

    Args:
        func: Test callable to invoke.
        *args: Positional arguments forwarded to ``func``.

    Returns:
        ``True`` if ``func(*args)`` returned normally, ``False`` if it raised
        ``AssertionError`` or any other ``Exception``.
    """
    try:
        func(*args)
    except AssertionError as e:
        print(f"Assertion failed: {e}")
        return False
    except Exception:
        # An erroring test (e.g. a KeyError or a shape ValueError) is a failed
        # test, not a reason to abort the run before the summary is printed.
        # The traceback is printed so the error is never silently swallowed.
        import traceback

        print(f"Unexpected error in {getattr(func, '__name__', repr(func))}:")
        traceback.print_exc()
        return False
    return True
class TestArchiveUtilsIncremental(unittest.TestCase):
    """Round-trip tests for the archive API executed from ``OUTLINE_PATH``.

    The outline script is run once per class via ``runpy.run_path`` and its
    resulting namespace is kept in ``cls.api``; every test writes to its own
    temporary ``trace.zarr`` path.
    """

    @classmethod
    def setUpClass(cls):
        """Execute the outline script and keep its globals as ``cls.api``."""
        outline_file = str(OUTLINE_PATH)
        cls.api = runpy.run_path(outline_file)

    def setUp(self):
        """Create a fresh temporary directory and archive path for the test."""
        self.tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(self.tmpdir.cleanup)
        archive_file = Path(self.tmpdir.name) / "trace.zarr"
        self.archive_path = str(archive_file)

    def test_load_metadata_roundtrip(self):
        """Metadata written by ``store_metadata`` is returned by ``load_metadata``."""
        write_metadata = self.api["store_metadata"]
        read_metadata = self.api["load_metadata"]

        write_metadata(
            self.archive_path,
            model_version="model-v1",
            config_version="config-v1",
            sequence="ACDE",
            num_residues=4,
            num_recycles=2,
            recycle_info=np.array([0.1, 0.2]),
            residue_index=np.array([5, 6, 7, 8]),
            representation_names=np.array(["single", "pair"]),
        )

        stored = read_metadata(self.archive_path)

        expected_scalars = (
            ("model_version", "model-v1"),
            ("config_version", "config-v1"),
            ("sequence", "ACDE"),
            ("num_residues", 4),
            ("num_recycles", 2),
        )
        for key, expected in expected_scalars:
            self.assertEqual(stored[key], expected)

        expected_arrays = (
            ("recycle_info", [0.1, 0.2]),
            ("residue_index", [5, 6, 7, 8]),
            ("representation_names", ["single", "pair"]),
        )
        for key, expected in expected_arrays:
            np.testing.assert_array_equal(stored[key], np.array(expected))

    def test_load_single_representation_roundtrip(self):
        """A single representation stored at layer 2 loads back unchanged."""
        write_single = self.api["store_single_representation"]
        read_single = self.api["load_single_representation"]

        layer = 2
        expected = np.arange(12, dtype=np.float32).reshape(3, 4)
        write_single(self.archive_path, layer, expected)

        actual = read_single(self.archive_path, layer)
        np.testing.assert_array_equal(actual, expected)

    def test_load_pair_representation_roundtrip(self):
        """A pair representation stored at layer 1 loads back unchanged."""
        write_pair = self.api["store_pair_representation"]
        read_pair = self.api["load_pair_representation"]

        layer = 1
        expected = np.arange(48, dtype=np.float32).reshape(4, 4, 3)
        write_pair(self.archive_path, layer, expected)

        actual = read_pair(self.archive_path, layer)
        np.testing.assert_array_equal(actual, expected)

    def test_orchestrator_end_to_end(self):
        """Drive ``ArchiveOrchestrator`` through every step and check its log."""
        make_orchestrator = self.api["ArchiveOrchestrator"]
        read_metadata = self.api["load_metadata"]
        read_single = self.api["load_single_representation"]
        read_pair = self.api["load_pair_representation"]

        def always_valid(path, *args, **kwargs):
            # Stub validator: reports success and echoes the path it was given.
            return {
                "valid": True,
                "path": path,
                "warnings": [],
                "errors": [],
            }

        orchestrator = make_orchestrator(self.archive_path)

        orchestrator.add_metadata(
            model_version="model-v2",
            config_version="config-v2",
            sequence="WXYZ",
            num_residues=4,
            num_recycles=1,
            representation_names=np.array(["single", "pair"]),
        )
        orchestrator.add_single_layer(0, np.ones((4, 3), dtype=np.float32))
        orchestrator.add_pair_layer(0, np.ones((4, 4, 2), dtype=np.float32))
        orchestrator.add_attention(
            "triangle_start",
            0,
            np.ones((2, 4, 4), dtype=np.float32),
        )
        orchestrator.add_structure(
            np.arange(12, dtype=np.float32).reshape(4, 3),
            atom_mask=np.array([1, 1, 1, 1], dtype=np.float32),
            ptm=0.91,
        )
        orchestrator.validate(validator=always_valid)

        summary = orchestrator.summary()
        self.assertEqual(summary["archive_path"], self.archive_path)

        # The first five events are writes, in the order they were issued.
        expected_targets = (
            "metadata",
            "representations/single/layer_00",
            "representations/pair/layer_00",
            "attention/triangle_start/layer_00",
            "structure",
        )
        for position, target in enumerate(expected_targets):
            self.assertEqual(summary["events"][position]["target"], target)

        validation_event = summary["events"][5]
        self.assertEqual(validation_event["action"], "validate")
        self.assertTrue(validation_event["result"]["valid"])

        stored_metadata = read_metadata(self.archive_path)
        np.testing.assert_array_equal(
            stored_metadata["representation_names"],
            np.array(["single", "pair"]),
        )

        stored_single = read_single(self.archive_path, 0)
        np.testing.assert_array_equal(stored_single, np.ones((4, 3), dtype=np.float32))

        stored_pair = read_pair(self.archive_path, 0)
        np.testing.assert_array_equal(stored_pair, np.ones((4, 4, 2), dtype=np.float32))