diff --git a/README.md b/README.md index bea1845..771290a 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,46 @@ -# SpatialTranscriptFormer +# SpatialTranscriptFormer Framework > [!WARNING] > **Work in Progress**: This project is under active development. Core architectures, CLI flags, and data formats are subject to major changes. -**SpatialTranscriptFormer** bridges histology and biological pathways through a high-performance transformer architecture. By modeling the dense interplay between morphological features and gene expression signatures, it provides an interpretable and spatially-coherent mapping of the tissue microenvironment. + + +> [!TIP] +> **Framework Release**: SpatialTranscriptFormer has been restructured from a research codebase into a robust framework. You can now use the Python API to train on your own spatial transcriptomics data with custom backbones and architectures. + +**SpatialTranscriptFormer** is a modular deep learning framework designed to bridge histology and biological pathways. It leverages transformer architectures to model the interplay between morphological features and gene expression signatures, providing interpretable mapping of the tissue microenvironment. + +## Python API: Quick Start + +The framework is designed to be integrated programmatically into your scanpy/AnnData workflows: + +```python +from spatial_transcript_former import SpatialTranscriptFormer, Predictor, FeatureExtractor +from spatial_transcript_former.predict import inject_predictions + +# 1. Initialize model and backbone +model = SpatialTranscriptFormer.from_pretrained("./checkpoints/stf_small/") +extractor = FeatureExtractor(backbone="phikon", device="cuda") +predictor = Predictor(model, device="cuda") + +# 2. Predict from features +predictions = predictor.predict_wsi(features, coords) # (1, G) + +# 3. Integrate with Scanpy +inject_predictions(adata, coords, predictions[0], gene_names=model.gene_names) +``` + +For more details, see the **[Python API Reference](docs/API.md)**. ## Key Technical Pillars +- **Modular Architecture**: Decoupled backbones, interaction modules, and output heads. - **Quad-Flow Interaction**: Configurable attention between Pathways and Histology patches (`p2p`, `p2h`, `h2p`, `h2h`). - **Pathway Bottleneck**: Interpretable gene expression prediction via 50 MSigDB Hallmark tokens. -- **Spatial Pattern Coherence**: Optimized using a composite **MSE + PCC (Pearson Correlation) loss** to prevent spatial collapse and ensure accurate morphology-expression mapping. +- **Spatial Pattern Coherence**: Optimized using a composite **MSE + PCC (Pearson Correlation) loss**. - **Foundation Model Ready**: Native support for **CTransPath**, **Phikon**, **Hibou**, and **GigaPath**. -- **Biologically Informed Initialization**: Gene reconstruction weights derived from known hallmark memberships. + +--- ## License @@ -28,76 +57,58 @@ This project is protected by a **Proprietary Source Code License**. See the [LIC The core architectural innovations, including the **SpatialTranscriptFormer** interaction logic and spatial masking strategies, are the unique Intellectual Property of the author. For a detailed breakdown, see the [IP Statement](docs/IP_STATEMENT.md). +--- + ## Installation This project requires [Conda](https://docs.conda.io/en/latest/). 1. Clone the repository. 2. Run the automated setup script: -3. On Windows: `.\setup.ps1` + - On Windows: `.\setup.ps1` - On Linux/HPC: `bash setup.sh` -## Usage +## Exemplar Recipe: HEST-1k Benchmark -### Dataset Access +The `SpatialTranscriptFormer` repository includes a complete, out-of-the-box CLI pipeline as an exemplar for reproducing our benchmarks on the [HEST-1k dataset](https://huggingface.co/datasets/MahmoodLab/hest). -The model uses the **HEST1k** dataset. You can download specific subsets (by organ, technology, etc.) or the entire dataset using the `stf-download` utility: +### 1. Dataset Access & Preprocessing ```bash -# List available filtering options -stf-download --list-options - -# Download a specific subset (e.g., Breast Cancer samples from Visium) +# Download a specific subset stf-download --organ Breast --disease Cancer --tech Visium --local_dir hest_data - -# Download all human samples -stf-download --species "Homo sapiens" --local_dir hest_data ``` -> [!NOTE] -> The HEST dataset is gated on Hugging Face. Ensure you have accepted the terms at [MahmoodLab/hest](https://huggingface.co/datasets/MahmoodLab/hest) and are logged in via `huggingface-cli login`. - -### Train Models - -We provide presets for baseline models and scaled versions of the SpatialTranscriptFormer. +### 2. Training with Presets ```bash # Recommended: Run the Interaction model (Small) python scripts/run_preset.py --preset stf_small - -# Run the lightweight Tiny version -python scripts/run_preset.py --preset stf_tiny - -# Run baselines -python scripts/run_preset.py --preset he2rna_baseline ``` -For a complete list of configurations, see the [Training Guide](docs/TRAINING_GUIDE.md). - -### Real-Time Monitoring - -Monitor training progress, loss curves, and **prediction variance (collapse detector)** via the web dashboard: +### 3. Inference & Visualization ```bash -python scripts/monitor.py --run-dir runs/stf_interaction_l4 +stf-predict --data-dir A:\hest_data --sample-id MEND29 --model-path checkpoints/best_model.pth --model-type interaction ``` -### Inference & Visualization +Visualization plots and spatial expression maps will be saved to the `./results` directory. For the full guide, see the **[HEST Recipe Docs](src/spatial_transcript_former/recipes/hest/README.md)**. -Generate spatial maps comparing Ground Truth vs Predictions: +## Documentation -```bash -stf-predict --data-dir A:\hest_data --sample-id MEND29 --model-path checkpoints/best_model.pth --model-type interaction -``` +### Framework APIs & Usage -Visualization plots will be saved to the `./results` directory. +- **[Python API Reference](docs/API.md)**: Full documentation for `Trainer`, `Predictor`, and `SpatialDataset`. +- **[Bring Your Own Data Guide](src/spatial_transcript_former/recipes/custom/README.md)**: Templates and examples for training on your own non-HEST spatial transcriptomics data. +- **[HEST Recipe Docs](src/spatial_transcript_former/recipes/hest/README.md)**: Detailed documentation for the included HEST-1k dataset recipe. +- **[Training Guide](docs/TRAINING_GUIDE.md)**: Complete list of configuration flags and preset configurations for HEST models. -## Documentation +### Theory & Interpretability -- [Models](docs/MODELS.md): Detailed model architectures and scaling parameters. -- [Data Structure](docs/DATA_STRUCTURE.md): Organization of HEST data on disk. -- [Pathway Mapping](docs/PATHWAY_MAPPING.md): Clinical interpretability and pathway integration. -- [Gene Analysis](docs/GENE_ANALYSIS.md): Modeling strategies for high-dimensional gene space. +- **[Models & Architecture](docs/MODELS.md)**: Deep dive into the quad-flow interaction logic and network scaling. +- **[Pathway Mapping](docs/PATHWAY_MAPPING.md)**: Clinical interpretability, pathway bottleneck design, and MSigDB integration. +- **[Gene Analysis](docs/GENE_ANALYSIS.md)**: Modeling strategies for mapping morphology to high-dimensional gene spaces. +- **[Data Structure](docs/DATA_STRUCTURE.md)**: Detailed breakdown of the HEST data structure on disk, metadata conventions, and preprocessing invariants. ## Development diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 0000000..8f5048d --- /dev/null +++ b/docs/API.md @@ -0,0 +1,614 @@ +# Python API Reference + +The SpatialTranscriptFormer package exposes a clean API for training, inference, and integration with the Scanpy/AnnData ecosystem. + +```python +from spatial_transcript_former import ( + SpatialTranscriptFormer, # Core model + Trainer, # High-level training orchestrator + Predictor, # Inference wrapper + FeatureExtractor, # Backbone feature extraction + save_pretrained, # Save checkpoint directory + load_pretrained, # Load checkpoint directory + inject_predictions, # AnnData integration +) +``` + +--- + +## Quick Start + +### End-to-End Inference (New Data) + +```python +from spatial_transcript_former import SpatialTranscriptFormer, Predictor, FeatureExtractor +from spatial_transcript_former.predict import inject_predictions +import scanpy as sc + +# 1. Load model from checkpoint directory +model = SpatialTranscriptFormer.from_pretrained("./checkpoints/my_run/") +print(model.gene_names[:3]) # ['TP53', 'EGFR', 'MYC'] + +# 2. Extract features from raw patches +extractor = FeatureExtractor(backbone="phikon", device="cuda") +features = extractor.extract_batch(image_tensor, batch_size=64) # (N, 768) + +# 3. Predict gene expression +predictor = Predictor(model, device="cuda") +predictions = predictor.predict_wsi(features, coords) # (1, G) + +# 4. Inject into AnnData for Scanpy analysis +adata = sc.AnnData(obs=pd.DataFrame(index=[f"spot_{i}" for i in range(N)])) +inject_predictions(adata, coords, predictions[0], gene_names=model.gene_names) +sc.pl.spatial(adata, color="TP53") +``` + +### Saving a Trained Model + +```python +from spatial_transcript_former import save_pretrained + +# After training, export a self-contained checkpoint +save_pretrained(model, "./release/v1/", gene_names=gene_list) +``` + +This creates: + +``` +release/v1/ +├── config.json # Architecture parameters +├── model.pth # Model weights (state_dict) +└── gene_names.json # Ordered gene symbols +``` + +--- + +## API Reference + +### `SpatialTranscriptFormer` + +The core transformer model. Predicts gene expression from histology patch features and spatial coordinates. + +#### `SpatialTranscriptFormer.__init__(...)` + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `num_genes` | `int` | *required* | Number of output genes | +| `num_pathways` | `int` | `50` | Number of pathway bottleneck tokens | +| `backbone_name` | `str` | `"resnet50"` | Backbone identifier (`resnet50`, `phikon`, `ctranspath`, etc.) | +| `pretrained` | `bool` | `True` | Load pretrained backbone weights | +| `token_dim` | `int` | `256` | Common embedding dimension | +| `n_heads` | `int` | `4` | Number of attention heads | +| `n_layers` | `int` | `2` | Number of transformer layers | +| `dropout` | `float` | `0.1` | Dropout probability | +| `pathway_init` | `Tensor` | `None` | `(P, G)` biological pathway membership matrix | +| `use_spatial_pe` | `bool` | `True` | Enable learned spatial positional encodings | +| `output_mode` | `str` | `"counts"` | Output head: `"counts"` (Softplus) or `"zinb"` (Zero-Inflated NB) | +| `interactions` | `list[str]` | all | Attention interactions: `p2p`, `p2h`, `h2p`, `h2h` | + +#### `SpatialTranscriptFormer.from_pretrained(checkpoint_dir, device="cpu", **kwargs)` + +Load a model from a checkpoint directory created by `save_pretrained`. + +```python +model = SpatialTranscriptFormer.from_pretrained("./checkpoint/", device="cuda") +model.gene_names # List[str] or None +``` + +| Parameter | Type | Description | +|---|---|---| +| `checkpoint_dir` | `str` | Path to directory with `config.json` + `model.pth` | +| `device` | `str` | Torch device (`"cpu"`, `"cuda"`) | +| `**kwargs` | | Override any `config.json` value (e.g. `dropout=0.0`) | + +**Returns:** `SpatialTranscriptFormer` in eval mode with `.gene_names` attribute. + +--- + +### `Predictor` + +Stateful inference wrapper. Manages device placement, eval mode, and optional AMP. + +#### `Predictor.__init__(model, device="cpu", use_amp=False)` + +```python +predictor = Predictor(model, device="cuda", use_amp=True) +``` + +#### `Predictor.predict_patch(image, return_pathways=False)` + +Single-patch inference from a raw image tensor. + +```python +result = predictor.predict_patch(image) # image: (1, 3, 224, 224) or (3, 224, 224) +# result: (1, num_genes) +``` + +> **Note:** When the model uses spatial PE, a zero-coordinate is automatically injected — no need to provide coordinates for single patches. + +#### `Predictor.predict_wsi(features, coords, return_pathways=False, return_dense=False)` + +Whole-slide inference from pre-extracted feature embeddings. + +```python +# Global prediction (one vector per slide) +result = predictor.predict_wsi(features, coords) # (1, G) + +# Dense prediction (one vector per patch) +result = predictor.predict_wsi(features, coords, return_dense=True) # (1, N, G) +``` + +| Parameter | Type | Description | +|---|---|---| +| `features` | `Tensor` | `(N, D)` or `(1, N, D)` embeddings | +| `coords` | `Tensor` | `(N, 2)` or `(1, N, 2)` spatial coordinates | +| `return_pathways` | `bool` | Also return pathway scores | +| `return_dense` | `bool` | Per-patch predictions instead of global | + +> **Validation:** Raises `ValueError` with a clear message if the feature dimension doesn't match the model's expected backbone dimension. + +#### `Predictor.predict(features, coords=None, **kwargs)` + +Unified entry point — auto-dispatches: + +- 4D tensor `(B, 3, H, W)` → `predict_patch` +- 2D tensor `(N, D)` → `predict_wsi` (requires `coords`) + +--- + +### `FeatureExtractor` + +Wraps a backbone model and its normalization transform for one-line feature extraction. + +#### `FeatureExtractor.__init__(backbone="resnet50", device="cpu", pretrained=True, transform=None)` + +```python +extractor = FeatureExtractor(backbone="phikon", device="cuda") +extractor.feature_dim # 768 +extractor.backbone_name # "phikon" +``` + +| Backbone | `feature_dim` | Source | +|---|---|---| +| `resnet50` | 2048 | torchvision | +| `ctranspath` | 768 | HuggingFace (CTransPath) | +| `phikon` | 768 | Owkin Phikon (HuggingFace) | +| `vit_b_16` | 768 | torchvision | +| `gigapath` | 1536 | ProvGigaPath *(gated)* | +| `hibou-b` | 768 | Hibou-B *(gated)* | +| `hibou-l` | 1024 | Hibou-L *(gated)* | + +#### `extractor(images)` / `extractor.extract_batch(images, batch_size=64)` + +```python +features = extractor(images) # (N, D) — all at once +features = extractor.extract_batch(images, batch_size=64) # batched, returns on CPU +``` + +Images should be float tensors in `[0, 1]` range, shape `(N, 3, H, W)`. + +--- + +### `save_pretrained(model, save_dir, gene_names=None)` + +Save a self-contained checkpoint directory. + +```python +save_pretrained(model, "./release/v1/", gene_names=["TP53", "EGFR", ...]) +``` + +| Parameter | Type | Description | +|---|---|---| +| `model` | `SpatialTranscriptFormer` | Trained model instance | +| `save_dir` | `str` | Output directory (created if needed) | +| `gene_names` | `list[str]` | Optional ordered gene symbols (must match `num_genes`) | + +**Raises:** `ValueError` if `gene_names` length doesn't match `num_genes`. + +### AnnData & Scanpy — A Primer + +If you're coming from a pure deep-learning background, AnnData and Scanpy may be unfamiliar. They are the standard data format and analysis toolkit in single-cell and spatial biology — the equivalent of what Pandas DataFrames are for tabular ML. + +#### What is AnnData? + +An `AnnData` object is a structured container for observations × variables matrices, designed for genomics. Think of it as a spreadsheet with labelled sidecars: + +``` + var (genes) + ┌──────────────────┐ + │ TP53 EGFR MYC │ + ┌────┼──────────────────┤ + obs │ s0 │ 0.3 1.2 0.8 │ ← adata.X (the main data matrix) + (spots/ │ s1 │ 0.1 0.5 1.1 │ + cells) │ s2 │ 0.9 0.2 0.4 │ + └────┴──────────────────┘ +``` + +| Slot | What it stores | Our usage | +|---|---|---| +| `adata.X` | Main matrix `(N, G)` | Predicted gene expression | +| `adata.obs` | Per-observation metadata | Spot/cell barcodes, cluster labels | +| `adata.var` | Per-variable metadata | Gene symbols as the index | +| `adata.obsm["spatial"]` | Observation-level embeddings | `(N, 2)` spatial coordinates | +| `adata.obsm["spatial_pathways"]` | Additional embeddings | `(N, P)` pathway scores | +| `adata.uns` | Unstructured metadata | Pathway names, model config | + +#### What is Scanpy? + +[Scanpy](https://scanpy.readthedocs.io/) (`sc`) is the analysis library that operates on AnnData objects. Once predictions are inside an `adata`, you instantly get access to: + +```python +import scanpy as sc + +# Spatial plotting — visualise gene expression on tissue coordinates +sc.pl.spatial(adata, color="TP53") + +# Clustering — find groups of spots with similar expression +sc.tl.leiden(adata) + +# Differential expression — find marker genes per cluster +sc.tl.rank_genes_groups(adata, groupby="leiden") + +# Dimensionality reduction +sc.tl.pca(adata) +sc.tl.umap(adata) +sc.pl.umap(adata, color="leiden") +``` + +#### Why does this matter? + +By injecting predictions into AnnData, our model's output becomes **instantly compatible** with the entire Scanpy ecosystem — clustering, differential testing, spatial plotting, trajectory analysis — without any custom code. Biologists can take our predictions and run their standard workflows immediately. + +--- + +### `inject_predictions(adata, coords, predictions, ...)` + +Inject predictions into an AnnData object for Scanpy integration. + +```python +inject_predictions( + adata, + coords, # → adata.obsm["spatial"] + predictions, # → adata.X + gene_names=["TP53", "EGFR", ...], # → adata.var_names + pathway_scores=pathway_activations, # → adata.obsm["spatial_pathways"] + pathway_names=["APOPTOSIS", ...], # → adata.uns["pathway_names"] +) +``` + +| Parameter | Type | Description | +|---|---|---| +| `adata` | `AnnData` | Target AnnData object | +| `coords` | `ndarray` | `(N, 2)` spatial coordinates | +| `predictions` | `ndarray` | `(N, G)` gene predictions | +| `gene_names` | `list[str]` | Optional gene symbols | +| `pathway_scores` | `ndarray` | Optional `(N, P)` pathway scores | +| `pathway_names` | `list[str]` | Optional pathway names | + +> **Lazy loading:** `anndata` is only imported when this function is called, so it's not required for basic inference. + +--- + +## Checkpoint Directory Format + +``` +checkpoint/ +├── config.json # Architecture (JSON) +├── model.pth # Weights (PyTorch state_dict) +└── gene_names.json # Gene symbols (JSON array, optional) +``` + +**`config.json` example:** + +```json +{ + "num_genes": 460, + "num_pathways": 50, + "backbone_name": "phikon", + "token_dim": 256, + "n_heads": 4, + "n_layers": 2, + "dropout": 0.1, + "use_spatial_pe": true, + "output_mode": "counts", + "interactions": ["h2h", "h2p", "p2h", "p2p"] +} +``` + +**`gene_names.json` example:** + +```json +["TP53", "EGFR", "MYC", "BRCA1", ...] +``` + +--- + +## Training API + +The training pipeline lives in the `spatial_transcript_former.training` subpackage. You can use it via the **CLI** or **programmatically** in your own scripts. + +### CLI Quick Start + +Training is launched via the `stf-train` entry point (or `python -m spatial_transcript_former.train`): + +```bash +# Minimal: train on precomputed features with whole-slide mode +stf-train \ + --model interaction \ + --backbone phikon \ + --data-dir /path/to/hest \ + --precomputed \ + --whole-slide \ + --use-spatial-pe \ + --pathway-init \ + --loss mse_pcc \ + --epochs 100 \ + --lr 1e-4 \ + --warmup-epochs 10 + +# Resume from checkpoint +stf-train --model interaction --resume --output-dir ./checkpoints +``` + +### CLI Arguments + +#### Data + +| Flag | Default | Description | +| --- | --- | --- | +| `--data-dir` | from config | Root HEST data directory | +| `--feature-dir` | auto | Explicit pre-extracted feature directory | +| `--num-genes` | 1000 | Number of output genes | +| `--precomputed` | off | Use pre-extracted backbone features | +| `--whole-slide` | off | Dense whole-slide prediction mode | +| `--organ` | all | Filter samples by organ type | +| `--max-samples` | all | Limit samples (for debugging) | + +#### Model + +| Flag | Default | Description | +| --- | --- | --- | +| `--model` | `he2rna` | Architecture: `interaction`, `he2rna`, `vit_st`, `attention_mil`, `transmil` | +| `--backbone` | `resnet50` | Backbone: `resnet50`, `phikon`, `ctranspath`, `vit_b_16`, etc. | +| `--num-pathways` | 50 | Pathway bottleneck tokens | +| `--token-dim` | 256 | Embedding dimension | +| `--n-heads` | 4 | Attention heads | +| `--n-layers` | 2 | Transformer layers | +| `--use-spatial-pe` | off | Learned spatial positional encoding | +| `--interactions` | all | Attention mask: `p2p p2h h2p h2h` | +| `--pathway-init` | off | Initialize gene head from MSigDB Hallmarks | + +#### Training + +| Flag | Default | Description | +| --- | --- | --- | +| `--epochs` | 10 | Total training epochs | +| `--batch-size` | 32 | Batch size | +| `--lr` | 1e-4 | Learning rate | +| `--warmup-epochs` | 10 | Linear warmup before cosine annealing | +| `--weight-decay` | 0.0 | AdamW weight decay | +| `--grad-accum-steps` | 1 | Gradient accumulation steps | +| `--use-amp` | off | Mixed precision (FP16) | +| `--compile` | off | `torch.compile` the model | +| `--resume` | off | Resume from latest checkpoint | + +#### Loss + +| Flag | Default | Description | +| --- | --- | --- | +| `--loss` | `mse_pcc` | Loss function: `mse`, `pcc`, `mse_pcc`, `zinb`, `poisson`, `logcosh` | +| `--pcc-weight` | 1.0 | PCC term weight in `mse_pcc` | +| `--pathway-loss-weight` | 0.0 | Auxiliary pathway PCC loss weight (0 = disabled) | + +--- + +### Trainer (High-Level) + +The `Trainer` class handles LR scheduling, AMP, checkpointing, logging, and early stopping: + +```python +from spatial_transcript_former import SpatialTranscriptFormer, Trainer +from spatial_transcript_former.training import CompositeLoss, EarlyStoppingCallback + +model = SpatialTranscriptFormer(num_genes=460, backbone_name="phikon", ...) + +trainer = Trainer( + model=model, + train_loader=train_dl, + val_loader=val_dl, + criterion=CompositeLoss(alpha=1.0), + epochs=100, + warmup_epochs=10, + device="cuda", + output_dir="./checkpoints", + use_amp=True, + callbacks=[EarlyStoppingCallback(patience=15)], +) +results = trainer.fit() # returns {"best_val_loss", "history", ...} +trainer.save_pretrained("./release/v1/") # inference-ready export +``` + +#### Trainer Parameters + +| Parameter | Default | Description | +| --- | --- | --- | +| `model` | *required* | Any `nn.Module` | +| `train_loader` | *required* | Training `DataLoader` | +| `val_loader` | *required* | Validation `DataLoader` | +| `criterion` | *required* | Loss function | +| `optimizer` | `None` | Custom optimizer (default: `AdamW`) | +| `lr` | `1e-4` | Learning rate (if no custom optimizer) | +| `epochs` | `100` | Total training epochs | +| `warmup_epochs` | `10` | Linear warmup before cosine annealing | +| `use_amp` | `False` | Mixed precision (FP16) | +| `grad_accum_steps` | `1` | Gradient accumulation | +| `whole_slide` | `False` | Dense whole-slide mode | +| `output_dir` | `./checkpoints` | Directory for checkpoints/logs | +| `callbacks` | `[]` | List of `TrainerCallback` instances | +| `resume` | `False` | Resume from checkpoint | + +#### Callbacks + +Subclass `TrainerCallback` to hook into the training loop: + +```python +from spatial_transcript_former.training import TrainerCallback + +class WandbCallback(TrainerCallback): + def on_epoch_end(self, trainer, epoch, metrics): + wandb.log(metrics, step=epoch) + + def should_stop(self, trainer, epoch, metrics): + return False # never stop early +``` + +| Hook | When | +| --- | --- | +| `on_train_begin(trainer)` | Start of `fit()` | +| `on_epoch_begin(trainer, epoch)` | Before each epoch | +| `on_epoch_end(trainer, epoch, metrics)` | After validation | +| `on_train_end(trainer, results)` | End of `fit()` | +| `should_stop(trainer, epoch, metrics)` | Return `True` to stop | + +Built-in: `EarlyStoppingCallback(patience=15, min_delta=0.0)` + +--- + +### Programmatic Training (Low-Level) + +For full control, use the engine functions directly: + +```python +from spatial_transcript_former.training import train_one_epoch, validate, CompositeLoss + +criterion = CompositeLoss(alpha=1.0) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + +for epoch in range(100): + train_loss = train_one_epoch( + model, train_loader, criterion, optimizer, device, + whole_slide=True, scaler=scaler, grad_accum_steps=4, + ) + val_metrics = validate( + model, val_loader, criterion, device, + whole_slide=True, use_amp=True, + ) + print(f"Epoch {epoch}: train={train_loss:.4f}, val={val_metrics['val_loss']:.4f}") +``` + +--- + +### Loss Functions (`training.losses`) + +All losses accept `(B, G)` patch-level or `(B, N, G)` dense inputs, with optional `mask` for padded positions. + +| Class | Formula / Description | +| --- | --- | +| `MaskedMSELoss` | Standard MSE with optional padding mask | +| `PCCLoss` | `1 - mean(PCC)` — gene-wise spatial Pearson correlation | +| `CompositeLoss` | `MSE + α · (1 - PCC)` — balances magnitude and spatial pattern | +| `ZINBLoss` | Zero-Inflated Negative Binomial NLL — for raw count data | +| `MaskedHuberLoss` | Huber (SmoothL1) — robust to outlier spots | +| `AuxiliaryPathwayLoss` | Wraps any base loss + PCC on pathway bottleneck scores | + +### Training Engine (`training.engine`) + +| Function | Description | +| --- | --- | +| `train_one_epoch(model, loader, criterion, optimizer, device, ...)` | One epoch of training. Handles gradient accumulation, AMP, and both patch/WSI modes. Returns average loss. | +| `validate(model, loader, criterion, device, ...)` | Validation pass. Returns `dict` with `val_loss`, `val_mae`, `val_pcc`, `pred_variance`, and optional `attn_correlation`. | + +### Experiment Logger (`training.experiment_logger`) + +Offline-friendly logger (no W&B dependency). Writes metrics to SQLite and a JSON summary. + +```python +logger = ExperimentLogger(output_dir, config_dict) +logger.log_epoch(epoch, {"train_loss": 0.1, "val_loss": 0.2, "val_pcc": 0.65}) +logger.finalize(best_val_loss=0.15) +``` + +| Output File | Contents | +| --- | --- | +| `training_logs.sqlite` | Per-epoch metrics table | +| `results_summary.json` | Config + final metrics + runtime | + +### Checkpoint Lifecycle + +During training, checkpoints are managed by `training.checkpoint` (the *internal* module — distinct from the public `save_pretrained`): + +| Function | Purpose | +| --- | --- | +| `save_checkpoint(model, optimizer, scaler, schedulers, ...)` | Saves full training state for `--resume` | +| `load_checkpoint(model, optimizer, scaler, schedulers, ...)` | Restores training state | + +After training is complete, use the public `save_pretrained` to export a clean, inference-ready checkpoint: + +```python +from spatial_transcript_former import save_pretrained + +# Export for inference (strips optimizer/scheduler state) +save_pretrained(model, "./release/v1/", gene_names=gene_list) +``` + +--- + +## Bring Your Own Data + +All datasets implement the `SpatialDataset` contract (in `data.base`). The contract requires `__getitem__` to return: + +```python +(features, gene_counts, rel_coords) +# features: (S, D) tensor — patch embeddings (S = 1 + neighbours) +# gene_counts: (G,) tensor — expression targets +# rel_coords: (S, 2) tensor — relative spatial coordinates +``` + +### Minimal Implementation + +```python +from spatial_transcript_former.data.base import SpatialDataset +import torch + +class MyVisiumDataset(SpatialDataset): + def __init__(self, features, gene_matrix, coords): + self._features = torch.as_tensor(features, dtype=torch.float32) + self._genes = torch.as_tensor(gene_matrix, dtype=torch.float32) + self._coords = torch.as_tensor(coords, dtype=torch.float32) + self.num_genes = self._genes.shape[1] + + def __len__(self): + return len(self._features) + + def __getitem__(self, idx): + feat = self._features[idx].unsqueeze(0) # (1, D) + genes = self._genes[idx] # (G,) + rel_coord = torch.zeros(1, 2) # centre = [0,0] + return feat, genes, rel_coord +``` + +### Training Your Custom Dataset + +```python +from torch.utils.data import DataLoader, random_split +from spatial_transcript_former import SpatialTranscriptFormer, Trainer +from spatial_transcript_former.training import CompositeLoss, EarlyStoppingCallback + +dataset = MyVisiumDataset(features, gene_matrix, coords) +train_ds, val_ds = random_split(dataset, [0.8, 0.2]) + +model = SpatialTranscriptFormer(num_genes=dataset.num_genes, backbone_name="phikon") + +trainer = Trainer( + model=model, + train_loader=DataLoader(train_ds, batch_size=32, shuffle=True), + val_loader=DataLoader(val_ds, batch_size=64), + criterion=CompositeLoss(), + epochs=100, + callbacks=[EarlyStoppingCallback(patience=15)], +) +results = trainer.fit() +trainer.save_pretrained("./my_model/") +``` + +See `recipes/custom/README.md` for the full guide. diff --git a/docs/TRAINING_GUIDE.md b/docs/TRAINING_GUIDE.md index 29f24d5..abede54 100644 --- a/docs/TRAINING_GUIDE.md +++ b/docs/TRAINING_GUIDE.md @@ -1,6 +1,7 @@ -# Training Guide +# Training Guide (HEST Benchmark Recipe) -This guide provides command-line recipes for training different architectures and configurations using `spatial_transcript_former.train`. +> [!NOTE] +> This guide provides command-line recipes specifically for the **HEST-1k benchmark dataset**. If you are looking to train on your own data using the core API, please see the **[Python API Reference](API.md)**. ## Prerequisites diff --git a/scripts/diagnose_collapse.py b/scripts/diagnose_collapse.py index d764933..87bf279 100644 --- a/scripts/diagnose_collapse.py +++ b/scripts/diagnose_collapse.py @@ -21,7 +21,7 @@ import numpy as np from spatial_transcript_former.models.interaction import SpatialTranscriptFormer -from spatial_transcript_former.data.dataset import ( +from spatial_transcript_former.recipes.hest.dataset import ( HEST_FeatureDataset, load_global_genes, ) diff --git a/scripts/download_hest.py b/scripts/download_hest.py index 87fac44..0f6c293 100644 --- a/scripts/download_hest.py +++ b/scripts/download_hest.py @@ -5,7 +5,7 @@ # Add src to path sys.path.append(os.path.abspath("src")) -from spatial_transcript_former.data.download import ( +from spatial_transcript_former.recipes.hest.download import ( download_hest_subset, download_metadata, ) diff --git a/scripts/inspect_outputs.py b/scripts/inspect_outputs.py index 0279923..a347665 100644 --- a/scripts/inspect_outputs.py +++ b/scripts/inspect_outputs.py @@ -4,7 +4,10 @@ import argparse import numpy as np from spatial_transcript_former.models import SpatialTranscriptFormer -from spatial_transcript_former.data.utils import get_sample_ids, setup_dataloaders +from spatial_transcript_former.recipes.hest.utils import ( + get_sample_ids, + setup_dataloaders, +) class Args: diff --git a/scripts/inspect_sample.py b/scripts/inspect_sample.py index f70a038..9a07a14 100644 --- a/scripts/inspect_sample.py +++ b/scripts/inspect_sample.py @@ -5,7 +5,10 @@ # Add src to path sys.path.append(os.path.abspath("src")) -from spatial_transcript_former.data.io import get_hest_data_dir, load_h5ad_metadata +from spatial_transcript_former.recipes.hest.io import ( + get_hest_data_dir, + load_h5ad_metadata, +) from spatial_transcript_former.config import get_config from spatial_transcript_former.data.pathways import ( download_msigdb_gmt, diff --git a/src/spatial_transcript_former/__init__.py b/src/spatial_transcript_former/__init__.py index e69de29..ada9443 100644 --- a/src/spatial_transcript_former/__init__.py +++ b/src/spatial_transcript_former/__init__.py @@ -0,0 +1,33 @@ +""" +SpatialTranscriptFormer — predict gene expression from histology. + +Core public API:: + + from spatial_transcript_former import ( + SpatialTranscriptFormer, # the model + Predictor, # inference wrapper + FeatureExtractor, # backbone feature extraction + Trainer, # high-level training orchestrator + save_pretrained, # checkpoint serialization + inject_predictions, # AnnData integration + ) +""" + +from spatial_transcript_former.models.interaction import SpatialTranscriptFormer +from spatial_transcript_former.predict import ( + FeatureExtractor, + Predictor, + inject_predictions, +) +from spatial_transcript_former.checkpoint import save_pretrained, load_pretrained +from spatial_transcript_former.training.trainer import Trainer + +__all__ = [ + "SpatialTranscriptFormer", + "Predictor", + "FeatureExtractor", + "Trainer", + "save_pretrained", + "load_pretrained", + "inject_predictions", +] diff --git a/src/spatial_transcript_former/checkpoint.py b/src/spatial_transcript_former/checkpoint.py new file mode 100644 index 0000000..1c99b82 --- /dev/null +++ b/src/spatial_transcript_former/checkpoint.py @@ -0,0 +1,203 @@ +""" +Public-facing checkpoint serialization for SpatialTranscriptFormer. + +Saves and loads a self-contained checkpoint directory containing: + - config.json — architecture hyper-parameters + - model.pth — model weights (state_dict) + - gene_names.json — ordered list of gene symbols (optional) +""" + +import json +import os +from typing import Any, Dict, List, Optional + +import torch + + +# Keys serialized into config.json. These correspond to +# SpatialTranscriptFormer.__init__ arguments (minus runtime-only +# arguments like ``pathway_init`` and ``pretrained``). +_CONFIG_KEYS = [ + "num_genes", + "num_pathways", + "backbone_name", + "token_dim", + "n_heads", + "n_layers", + "dropout", + "use_spatial_pe", + "output_mode", + "interactions", +] + + +def _model_config(model) -> Dict[str, Any]: + """Extract serializable architecture config from a live model.""" + from spatial_transcript_former.models.interaction import ( + SpatialTranscriptFormer, + ) + + if not isinstance(model, SpatialTranscriptFormer): + raise TypeError(f"Expected SpatialTranscriptFormer, got {type(model).__name__}") + + # Reconstruct config from the live model's attributes / constructor args. + num_genes = model.gene_reconstructor.out_features + num_pathways = model.num_pathways + token_dim = model.image_proj.out_features + backbone_name = _infer_backbone_name(model) + + # Transformer encoder introspection + first_layer = model.fusion_engine.layers[0] + n_heads = first_layer.self_attn.num_heads + n_layers = len(model.fusion_engine.layers) + dropout = first_layer.dropout.p if hasattr(first_layer, "dropout") else 0.1 + + use_spatial_pe = model.use_spatial_pe + output_mode = model.output_mode + interactions = sorted(model.interactions) + + return { + "num_genes": num_genes, + "num_pathways": num_pathways, + "backbone_name": backbone_name, + "token_dim": token_dim, + "n_heads": n_heads, + "n_layers": n_layers, + "dropout": dropout, + "use_spatial_pe": use_spatial_pe, + "output_mode": output_mode, + "interactions": interactions, + } + + +def _infer_backbone_name(model) -> str: + """Best-effort inference of backbone name from stored attribute or class.""" + # If we explicitly stored it (set by from_pretrained or user code): + if hasattr(model, "_backbone_name"): + return model._backbone_name + # Fallback: inspect the backbone module's class name + backbone_cls = type(model.backbone).__name__.lower() + if "resnet" in backbone_cls: + return "resnet50" + if "ctrans" in backbone_cls: + return "ctranspath" + if "phikon" in backbone_cls or "dinov2" in backbone_cls: + return "phikon" + return "unknown" + + +# ── Public API ──────────────────────────────────────────────────────── + + +def save_pretrained( + model, + save_dir: str, + gene_names: Optional[List[str]] = None, +) -> None: + """Save a SpatialTranscriptFormer checkpoint directory. + + Creates ``save_dir`` containing: + - ``config.json`` — architecture parameters + - ``model.pth`` — ``state_dict`` + - ``gene_names.json`` — ordered gene symbols (if provided) + + Args: + model: A :class:`SpatialTranscriptFormer` instance. + save_dir: Directory to write files into (created if needed). + gene_names: Optional ordered list of gene symbols matching the + model's ``num_genes`` output dimension. + """ + os.makedirs(save_dir, exist_ok=True) + + # 1. Config + config = _model_config(model) + with open(os.path.join(save_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + # 2. Weights + torch.save(model.state_dict(), os.path.join(save_dir, "model.pth")) + + # 3. Gene names (optional) + if gene_names is not None: + if len(gene_names) != config["num_genes"]: + raise ValueError( + f"gene_names length ({len(gene_names)}) does not match " + f"model num_genes ({config['num_genes']})" + ) + with open(os.path.join(save_dir, "gene_names.json"), "w") as f: + json.dump(gene_names, f) + + print(f"Saved pretrained checkpoint to {save_dir}") + + +def load_pretrained( + checkpoint_dir: str, + device: str = "cpu", + **override_kwargs, +): + """Load a SpatialTranscriptFormer from a pretrained checkpoint directory. + + Reads ``config.json`` to reconstruct architectural parameters, then + loads ``model.pth`` weights and (optionally) ``gene_names.json``. + + Args: + checkpoint_dir: Path to directory containing ``config.json`` and + ``model.pth``. + device: Torch device string (e.g. ``"cpu"``, ``"cuda"``). + **override_kwargs: Override any config values (e.g. ``dropout=0.0`` + for deterministic inference). + + Returns: + SpatialTranscriptFormer: The loaded model in eval mode with + ``gene_names`` attribute set (or ``None``). + """ + from spatial_transcript_former.models.interaction import ( + SpatialTranscriptFormer, + ) + + config_path = os.path.join(checkpoint_dir, "config.json") + weights_path = os.path.join(checkpoint_dir, "model.pth") + + if not os.path.isfile(config_path): + raise FileNotFoundError( + f"config.json not found in {checkpoint_dir}. " + "Use save_pretrained() to create a valid checkpoint directory." + ) + if not os.path.isfile(weights_path): + raise FileNotFoundError(f"model.pth not found in {checkpoint_dir}.") + + # 1. Read config + with open(config_path, "r") as f: + config = json.load(f) + + # Apply any user overrides + config.update(override_kwargs) + + # Don't load pretrained backbone weights — we're loading our own + config["pretrained"] = False + + # 2. Instantiate + model = SpatialTranscriptFormer(**config) + + # Store backbone name for future save_pretrained round-trips + model._backbone_name = config.get("backbone_name", "unknown") + + # 3. Load weights + state_dict = torch.load(weights_path, map_location=device, weights_only=True) + model.load_state_dict(state_dict, strict=False) + model.to(device) + model.eval() + + # 4. Gene names (optional) + gene_names_path = os.path.join(checkpoint_dir, "gene_names.json") + if os.path.isfile(gene_names_path): + with open(gene_names_path, "r") as f: + model.gene_names = json.load(f) + else: + model.gene_names = None + + print( + f"Loaded SpatialTranscriptFormer from {checkpoint_dir} " + f"({config['num_genes']} genes, {config['num_pathways']} pathways)" + ) + return model diff --git a/src/spatial_transcript_former/data/__init__.py b/src/spatial_transcript_former/data/__init__.py index be82d84..72f245a 100644 --- a/src/spatial_transcript_former/data/__init__.py +++ b/src/spatial_transcript_former/data/__init__.py @@ -1,4 +1,22 @@ -from .dataset import HEST_Dataset, get_hest_dataloader -from .splitting import split_hest_patients +""" +Data abstractions for SpatialTranscriptFormer. -from .download import download_hest_subset, download_metadata, filter_samples +Core exports: + - :class:`SpatialDataset` — abstract base class implementing the data contract + - :func:`apply_dihedral_augmentation` — D4 coordinate augmentation + - :func:`apply_dihedral_to_tensor` — D4 image augmentation + - :func:`normalize_coordinates` — auto-normalise spatial coordinates + +HEST-specific exports (backward compatibility — prefer ``recipes.hest``): + - :class:`HEST_Dataset`, :func:`get_hest_dataloader` + - :func:`split_hest_patients` + - :func:`download_hest_subset`, :func:`download_metadata`, :func:`filter_samples` +""" + +# Core abstractions (framework) +from .base import ( + SpatialDataset, + apply_dihedral_augmentation, + apply_dihedral_to_tensor, + normalize_coordinates, +) diff --git a/src/spatial_transcript_former/data/base.py b/src/spatial_transcript_former/data/base.py new file mode 100644 index 0000000..e777976 --- /dev/null +++ b/src/spatial_transcript_former/data/base.py @@ -0,0 +1,176 @@ +""" +Abstract data contracts for SpatialTranscriptFormer. + +Defines the :class:`SpatialDataset` ABC that any spatial transcriptomics +dataset must implement. The training engine, ``Trainer``, and ``Predictor`` +all depend only on this contract — never on a specific data source. +""" + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch.utils.data import Dataset + + +class SpatialDataset(Dataset, ABC): + """Abstract base class for spatial transcriptomics datasets. + + Any dataset used with SpatialTranscriptFormer must subclass this and + implement :meth:`__getitem__` and :meth:`__len__`. + + ``__getitem__`` must return a 3-tuple:: + + (features, gene_counts, rel_coords) + + where: + + * **features** — ``(S, D)`` float tensor of patch embeddings + (``S`` = 1 + K neighbours, ``D`` = backbone feature dim), or + ``(3, H, W)`` / ``(S, 3, H, W)`` image tensor in raw-patch mode. + * **gene_counts** — ``(G,)`` float tensor of gene expression targets. + * **rel_coords** — ``(S, 2)`` float tensor of spatial coordinates + relative to the centre patch (centre is always ``[0, 0]``). + + Subclasses SHOULD also expose :attr:`num_genes` and (optionally) + :attr:`gene_names` as properties. + + Example:: + + class MyVisiumDataset(SpatialDataset): + def __init__(self, slide_path, genes, coords, features): + self._features = features # (N, D) + self._genes = genes # (N, G) + self._coords = coords # (N, 2) + + def __len__(self): + return len(self._features) + + def __getitem__(self, idx): + center = self._coords[idx] + rel = self._coords - center # simplest: no neighbour selection + return self._features[idx:idx+1], self._genes[idx], rel[idx:idx+1] + + @property + def num_genes(self): + return self._genes.shape[1] + + @property + def gene_names(self): + return None # or a list of gene symbols + """ + + @abstractmethod + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Return ``(features, gene_counts, rel_coords)`` for index ``idx``.""" + ... + + @abstractmethod + def __len__(self) -> int: + """Return the number of items in the dataset.""" + ... + + # ------------------------------------------------------------------ + # Optional attributes — subclasses can set these as instance + # attributes (self.num_genes = ...) or override as properties. + # ------------------------------------------------------------------ + #: Number of gene expression targets. Set in __init__ or override. + num_genes: int = 0 + #: Ordered list of gene symbols, or None if unavailable. + gene_names: Optional[List[str]] = None + + +# --------------------------------------------------------------------------- +# Generic augmentation helpers (shared across recipes) +# --------------------------------------------------------------------------- + + +def apply_dihedral_augmentation(coords, op=None): + """Apply one of the 8 dihedral (D4) symmetries to 2-D coordinates. + + Args: + coords: ``(N, 2)`` array or tensor of (x, y) coordinates. + op: Integer in ``[0, 7]`` or ``None`` (random). + + Returns: + Tuple of (augmented_coords, op). + """ + is_torch = isinstance(coords, torch.Tensor) + if is_torch: + x, y = coords[..., 0].clone(), coords[..., 1].clone() + else: + x, y = coords[..., 0].copy(), coords[..., 1].copy() + + if op is None: + op = np.random.randint(0, 8) + + if op == 0: + pass + elif op == 1: + x, y = y, -x + elif op == 2: + x, y = -x, -y + elif op == 3: + x, y = -y, x + elif op == 4: + x = -x + elif op == 5: + y = -y + elif op == 6: + x, y = y, x + elif op == 7: + x, y = -y, -x + + if is_torch: + return torch.stack([x, y], dim=-1), op + else: + return np.stack([x, y], axis=-1), op + + +def apply_dihedral_to_tensor(img, op): + """Apply a dihedral operation to a ``(C, H, W)`` image tensor. + + Each operation matches :func:`apply_dihedral_augmentation` so that pixel + content and spatial coordinates stay aligned after augmentation. + """ + if op == 0: + return img + if op == 1: + return torch.rot90(img, k=1, dims=[1, 2]) + if op == 2: + return torch.rot90(img, k=2, dims=[1, 2]) + if op == 3: + return torch.rot90(img, k=3, dims=[1, 2]) + if op == 4: + return torch.flip(img, dims=[2]) + if op == 5: + return torch.flip(img, dims=[1]) + if op == 6: + return img.transpose(1, 2) + if op == 7: + return img.transpose(1, 2).flip(dims=[1, 2]) + return img + + +def normalize_coordinates(coords: np.ndarray) -> np.ndarray: + """Auto-normalize physical coordinates to integer grid indices.""" + if len(coords) == 0: + return coords + + x_vals = np.unique(coords[:, 0]) + y_vals = np.unique(coords[:, 1]) + + dx = x_vals[1:] - x_vals[:-1] + dy = y_vals[1:] - y_vals[:-1] + + steps = np.concatenate([dx, dy]) + valid_steps = steps[steps > 0.5] + + if len(valid_steps) == 0: + return coords + + step_size = valid_steps.min() + if step_size >= 2.0: + return np.round(coords / step_size).astype(coords.dtype) + return coords diff --git a/src/spatial_transcript_former/models/interaction.py b/src/spatial_transcript_former/models/interaction.py index 743e308..3c0dca1 100644 --- a/src/spatial_transcript_former/models/interaction.py +++ b/src/spatial_transcript_former/models/interaction.py @@ -131,6 +131,7 @@ def __init__( self.backbone, self.image_feature_dim = get_backbone( backbone_name, pretrained=pretrained ) + self._backbone_name = backbone_name # 2. Image Projector self.image_proj = nn.Linear(self.image_feature_dim, token_dim) @@ -233,6 +234,26 @@ def get_sparsity_loss(self): """ return torch.norm(self.gene_reconstructor.weight, p=1) + @classmethod + def from_pretrained(cls, checkpoint_dir, device="cpu", **kwargs): + """Load a pretrained SpatialTranscriptFormer from a checkpoint directory. + + The directory should contain ``config.json`` and ``model.pth`` + (created by :func:`~spatial_transcript_former.checkpoint.save_pretrained`). + An optional ``gene_names.json`` will be loaded as ``model.gene_names``. + + Args: + checkpoint_dir (str): Path to checkpoint directory. + device (str): Torch device to load onto. + **kwargs: Override any config values. + + Returns: + SpatialTranscriptFormer: Model in eval mode. + """ + from spatial_transcript_former.checkpoint import load_pretrained + + return load_pretrained(checkpoint_dir, device=device, **kwargs) + def forward( self, x, diff --git a/src/spatial_transcript_former/predict.py b/src/spatial_transcript_former/predict.py index 6df3245..fba4115 100644 --- a/src/spatial_transcript_former/predict.py +++ b/src/spatial_transcript_former/predict.py @@ -1,7 +1,369 @@ +""" +Inference API for SpatialTranscriptFormer. + +Provides three user-facing components: + +* :class:`FeatureExtractor` — wraps a backbone (ResNet, Phikon, …) to turn + raw image patches into feature embeddings. +* :class:`Predictor` — wraps a trained :class:`SpatialTranscriptFormer` model + to predict gene expression from features + spatial coordinates. +* :func:`inject_predictions` — injects predictions into an AnnData object + for seamless Scanpy integration. + +Additionally retains the :func:`plot_training_summary` helper used during +training validation. +""" + import os +from typing import Dict, List, Optional, Union + import numpy as np +import torch +import torch.nn as nn +from torchvision import transforms + + +# ═══════════════════════════════════════════════════════════════════════ +# FeatureExtractor +# ═══════════════════════════════════════════════════════════════════════ + + +class FeatureExtractor: + """Extract feature embeddings from histology image patches. + + Wraps a backbone model (e.g. ResNet-50, Phikon, CTransPath) and its + associated normalization transform so callers don't need to worry + about model-specific preprocessing. + + Example:: + + extractor = FeatureExtractor(backbone="phikon", device="cuda") + # images: Tensor of shape (N, 3, 224, 224), uint8 or float [0, 1] + features = extractor(images) # → (N, D) + features = extractor.extract_batch(images, batch_size=64) + """ + + # Standard ImageNet normalization — used by all current backbones + _IMAGENET_NORM = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ) + + def __init__( + self, + backbone: str = "resnet50", + device: str = "cpu", + pretrained: bool = True, + transform: Optional[transforms.Compose] = None, + ): + """ + Args: + backbone: Backbone identifier (see ``models.backbones``). + device: Torch device string. + pretrained: Whether to load pretrained backbone weights. + transform: Optional custom normalization transform. If + ``None``, standard ImageNet normalization is applied. + """ + from spatial_transcript_former.models.backbones import get_backbone + + self.backbone_name = backbone + self.device = torch.device(device) + self.model, self.feature_dim = get_backbone(backbone, pretrained=pretrained) + self.model.to(self.device) + self.model.eval() + self.transform = transform or self._IMAGENET_NORM + + @torch.no_grad() + def __call__(self, images: torch.Tensor) -> torch.Tensor: + """Extract features from a batch of images. + + Args: + images: ``(N, 3, H, W)`` float tensor in ``[0, 1]`` range. + + Returns: + ``(N, D)`` feature tensor on the same device. + """ + images = images.to(self.device) + images = self.transform(images) + return self.model(images) + + @torch.no_grad() + def extract_batch( + self, + images: torch.Tensor, + batch_size: int = 64, + ) -> torch.Tensor: + """Extract features in batches to manage GPU memory. + + Args: + images: ``(N, 3, H, W)`` float tensor in ``[0, 1]``. + batch_size: Number of images per forward pass. + + Returns: + ``(N, D)`` concatenated feature tensor on CPU. + """ + all_features = [] + for i in range(0, len(images), batch_size): + batch = images[i : i + batch_size] + features = self(batch) + all_features.append(features.cpu()) + return torch.cat(all_features, dim=0) + + +# ═══════════════════════════════════════════════════════════════════════ +# Predictor +# ═══════════════════════════════════════════════════════════════════════ + + +class Predictor: + """High-level inference wrapper for SpatialTranscriptFormer. + + Manages model state (eval mode, device, AMP), and provides + convenience methods for single-patch and whole-slide inference. + + Example:: + + model = SpatialTranscriptFormer.from_pretrained("./checkpoint/") + predictor = Predictor(model, device="cuda") + + # Single patch + genes = predictor.predict_patch(image_tensor) + + # Whole slide (pre-extracted features) + genes = predictor.predict_wsi(features, coords) + """ + + def __init__( + self, + model: nn.Module, + device: str = "cpu", + use_amp: bool = False, + ): + """ + Args: + model: A trained :class:`SpatialTranscriptFormer` instance. + device: Torch device string. + use_amp: Enable automatic mixed precision for inference. + """ + self.device = torch.device(device) + self.model = model.to(self.device) + self.model.eval() + self.use_amp = use_amp + + # Expose gene names if the model has them (set by from_pretrained) + self.gene_names: Optional[List[str]] = getattr(model, "gene_names", None) + + @torch.no_grad() + def predict_patch( + self, + image: torch.Tensor, + return_pathways: bool = False, + ) -> Union[torch.Tensor, tuple]: + """Predict gene expression from a single image patch. + + Args: + image: ``(1, 3, H, W)`` or ``(3, H, W)`` image tensor. + return_pathways: Also return pathway activation scores. + + Returns: + Gene expression tensor ``(1, G)`` or tuple with pathway scores. + """ + if image.dim() == 3: + image = image.unsqueeze(0) + + image = image.to(self.device) + + # For single-patch mode, spatial PE needs a dummy coordinate + rel_coords = None + if getattr(self.model, "use_spatial_pe", False): + rel_coords = torch.zeros(image.shape[0], 1, 2, device=self.device) + + with torch.amp.autocast("cuda", enabled=self.use_amp): + result = self.model( + image, + rel_coords=rel_coords, + return_pathways=return_pathways, + ) + return result + + @torch.no_grad() + def predict_wsi( + self, + features: torch.Tensor, + coords: torch.Tensor, + return_pathways: bool = False, + return_dense: bool = False, + ) -> Union[torch.Tensor, tuple]: + """Predict gene expression from pre-extracted whole-slide features. + + Args: + features: ``(N, D)`` or ``(1, N, D)`` feature embeddings. + coords: ``(N, 2)`` or ``(1, N, 2)`` spatial coordinates. + return_pathways: Also return pathway activation scores. + return_dense: If True, return per-patch predictions ``(1, N, G)``. + + Returns: + Gene expression tensor ``(1, G)`` or ``(1, N, G)`` if dense, + optionally with pathway scores as a tuple. + """ + # Ensure batch dimension + if features.dim() == 2: + features = features.unsqueeze(0) + if coords.dim() == 2: + coords = coords.unsqueeze(0) + + # Validate feature dimension matches model expectation + expected_dim = self.model.image_proj.in_features + actual_dim = features.shape[-1] + if actual_dim != expected_dim: + raise ValueError( + f"Feature dimension mismatch: model expects {expected_dim} " + f"(backbone '{getattr(self.model, '_backbone_name', 'unknown')}'), " + f"but got {actual_dim}. " + f"Did you extract features with the correct backbone?" + ) + + features = features.to(self.device) + coords = coords.to(self.device) + + with torch.amp.autocast("cuda", enabled=self.use_amp): + result = self.model( + features, + rel_coords=coords, + return_pathways=return_pathways, + return_dense=return_dense, + ) + return result + + @torch.no_grad() + def predict( + self, + features: torch.Tensor, + coords: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, tuple]: + """Unified prediction entry point. + + Automatically dispatches to :meth:`predict_patch` (if input looks + like a raw image) or :meth:`predict_wsi` (if input looks like + pre-computed features). + + Args: + features: Either ``(B, 3, H, W)`` image or ``(N, D)`` features. + coords: Required for WSI mode, ignored for patch mode. + **kwargs: Forwarded to the underlying predict method. + """ + if features.dim() == 4 and features.shape[1] == 3: + # Looks like an image tensor + return self.predict_patch(features, **kwargs) + else: + if coords is None: + raise ValueError( + "coords are required for feature-based prediction. " + "Pass spatial coordinates alongside pre-extracted features." + ) + return self.predict_wsi(features, coords, **kwargs) + + +# ═══════════════════════════════════════════════════════════════════════ +# Scanpy / AnnData Integration +# ═══════════════════════════════════════════════════════════════════════ + + +def inject_predictions( + adata, + coords: np.ndarray, + predictions: np.ndarray, + gene_names: Optional[List[str]] = None, + pathway_scores: Optional[np.ndarray] = None, + pathway_names: Optional[List[str]] = None, +): + """Inject SpatialTranscriptFormer predictions into an AnnData object. + + Registers spatial coordinates and gene/pathway predictions into the + appropriate AnnData slots so that standard Scanpy spatial plotting + and analysis tools work out of the box. + + Args: + adata: A :class:`anndata.AnnData` instance. Must have the same + number of observations as ``coords`` rows. + coords: ``(N, 2)`` spatial coordinates array. + predictions: ``(N, G)`` predicted gene expression array. + gene_names: Optional list of G gene symbols. If provided, they + are set as ``adata.var_names``. + pathway_scores: Optional ``(N, P)`` pathway activation scores. + pathway_names: Optional list of P pathway names. + + Returns: + The modified ``adata`` (in-place). + + Example:: + + import scanpy as sc + from spatial_transcript_former.predict import inject_predictions + + adata = sc.AnnData(obs=pd.DataFrame(index=[f"spot_{i}" for i in range(N)])) + inject_predictions(adata, coords, preds, gene_names=model.gene_names) + + sc.pl.spatial(adata, color="TP53") + """ + try: + import anndata # noqa: F401 + except ImportError: + raise ImportError( + "anndata is required for inject_predictions. " + "Install it with: pip install anndata" + ) + + n_obs = adata.n_obs + if coords.shape[0] != n_obs: + raise ValueError( + f"coords has {coords.shape[0]} rows but adata has {n_obs} observations" + ) + if predictions.shape[0] != n_obs: + raise ValueError( + f"predictions has {predictions.shape[0]} rows but adata has {n_obs} observations" + ) + + # Convert torch tensors to numpy if needed + if isinstance(coords, torch.Tensor): + coords = coords.cpu().numpy() + if isinstance(predictions, torch.Tensor): + predictions = predictions.cpu().numpy() + + # 1. Spatial coordinates + adata.obsm["spatial"] = coords + + # 2. Gene predictions → adata.X + import pandas as pd + + if gene_names is not None: + if len(gene_names) != predictions.shape[1]: + raise ValueError( + f"gene_names length ({len(gene_names)}) != prediction columns ({predictions.shape[1]})" + ) + adata.var = pd.DataFrame(index=gene_names) + + adata.X = predictions + + # 3. Pathway scores (optional) → adata.obsm + if pathway_scores is not None: + if isinstance(pathway_scores, torch.Tensor): + pathway_scores = pathway_scores.cpu().numpy() + adata.obsm["spatial_pathways"] = pathway_scores + + if pathway_names is not None: + adata.uns["pathway_names"] = pathway_names + + return adata + + +# ═══════════════════════════════════════════════════════════════════════ +# Training Visualization (existing utility) +# ═══════════════════════════════════════════════════════════════════════ +# Kept here for backwards compatibility with training scripts. + import matplotlib.pyplot as plt -from typing import List, Optional def plot_training_summary( diff --git a/src/spatial_transcript_former/recipes/__init__.py b/src/spatial_transcript_former/recipes/__init__.py new file mode 100644 index 0000000..ab95942 --- /dev/null +++ b/src/spatial_transcript_former/recipes/__init__.py @@ -0,0 +1,7 @@ +""" +Dataset recipes for SpatialTranscriptFormer. + +Each sub-package (e.g. ``recipes.hest``) provides dataset-specific +loaders that implement the :class:`~spatial_transcript_former.data.base.SpatialDataset` +contract. +""" diff --git a/src/spatial_transcript_former/recipes/custom/README.md b/src/spatial_transcript_former/recipes/custom/README.md new file mode 100644 index 0000000..94734dd --- /dev/null +++ b/src/spatial_transcript_former/recipes/custom/README.md @@ -0,0 +1,137 @@ +# Bring Your Own Data + +This guide shows how to implement a custom dataset for SpatialTranscriptFormer using the `SpatialDataset` contract. + +## Data Contract + +Every dataset must subclass `SpatialDataset` and return 3-tuples from `__getitem__`: + +```python +(features, gene_counts, rel_coords) +``` + +| Field | Shape | Description | +| --- | --- | --- | +| `features` | `(S, D)` | Patch embeddings (`S` = 1 + neighbours, `D` = backbone dim) | +| `gene_counts` | `(G,)` | Gene expression targets for the centre patch | +| `rel_coords` | `(S, 2)` | Spatial coordinates relative to centre (centre = `[0, 0]`) | + +## Minimal Example + +```python +import torch +import numpy as np +from spatial_transcript_former.data.base import SpatialDataset + +class MyVisiumDataset(SpatialDataset): + """Custom dataset for your own spatial transcriptomics data.""" + + def __init__(self, features, gene_matrix, coords, gene_names=None): + """ + Args: + features: (N, D) pre-extracted backbone features + gene_matrix: (N, G) gene expression matrix + coords: (N, 2) spatial coordinates + gene_names: optional list of G gene symbols + """ + self._features = torch.as_tensor(features, dtype=torch.float32) + self._genes = torch.as_tensor(gene_matrix, dtype=torch.float32) + self._coords = torch.as_tensor(coords, dtype=torch.float32) + self._gene_names = gene_names + + def __len__(self): + return len(self._features) + + def __getitem__(self, idx): + # Centre patch feature (unsqueeze to get shape (1, D)) + feat = self._features[idx].unsqueeze(0) + + # Gene expression target + genes = self._genes[idx] + + # Relative coordinate (centre is always [0, 0]) + rel_coord = torch.zeros(1, 2) + + return feat, genes, rel_coord + + @property + def num_genes(self): + return self._genes.shape[1] + + @property + def gene_names(self): + return self._gene_names +``` + +## Using with the Trainer + +```python +from torch.utils.data import DataLoader, random_split +from spatial_transcript_former import SpatialTranscriptFormer, Predictor +from spatial_transcript_former.training.engine import train_one_epoch, validate +from spatial_transcript_former.training.losses import CompositeLoss + +# 1. Create your dataset +dataset = MyVisiumDataset(features, gene_matrix, coords, gene_names=my_genes) + +# 2. Split +train_ds, val_ds = random_split(dataset, [0.8, 0.2]) +train_loader = DataLoader(train_ds, batch_size=32, shuffle=True) +val_loader = DataLoader(val_ds, batch_size=64) + +# 3. Build model +model = SpatialTranscriptFormer( + num_genes=dataset.num_genes, + backbone_name="phikon", + pretrained=False, # backbone weights not needed for pre-extracted features + use_spatial_pe=True, +).to(device) + +# 4. Train +criterion = CompositeLoss(alpha=1.0) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + +for epoch in range(50): + loss = train_one_epoch(model, train_loader, criterion, optimizer, device, + whole_slide=False) + metrics = validate(model, val_loader, criterion, device) + print(f"Epoch {epoch}: loss={loss:.4f}, val={metrics['val_loss']:.4f}") + +# 5. Save for inference +from spatial_transcript_former import save_pretrained +save_pretrained(model, "./my_model/", gene_names=my_genes) +``` + +## Preparing Your Data + +### From AnnData / Scanpy + +```python +import scanpy as sc +import numpy as np + +adata = sc.read_h5ad("my_experiment.h5ad") + +# Gene expression matrix +gene_matrix = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X +gene_names = list(adata.var_names) + +# Spatial coordinates +coords = adata.obsm["spatial"] + +# Pre-extract features using FeatureExtractor +from spatial_transcript_former import FeatureExtractor +extractor = FeatureExtractor(backbone="phikon", device="cuda") +# ... extract patches from WSI and run through extractor +``` + +### From Raw Patches + +If you have image patches as tensors: + +```python +from spatial_transcript_former import FeatureExtractor + +extractor = FeatureExtractor(backbone="phikon", device="cuda") +features = extractor.extract_batch(patch_tensor, batch_size=64) # (N, 768) +``` diff --git a/src/spatial_transcript_former/recipes/hest/README.md b/src/spatial_transcript_former/recipes/hest/README.md new file mode 100644 index 0000000..a025c89 --- /dev/null +++ b/src/spatial_transcript_former/recipes/hest/README.md @@ -0,0 +1,16 @@ +# HEST-1k Recipe (Exemplar) + +This directory serves as a comprehensive **exemplar** for training `SpatialTranscriptFormer` on the **HEST-1k** benchmark dataset. + +While the core `SpatialTranscriptFormer` framework is dataset-agnostic, this recipe provides a complete, out-of-the-box pipeline for reproducing our benchmarks, including data downloading, preprocessing, and specialized dataloaders. + +## Components + +- **`dataset.py`**: Contains `HEST_Dataset` and `HEST_FeatureDataset`, which subclass `SpatialDataset` to handle the specific `.h5ad` structure and metadata conventions of the HEST dataset. +- **`io.py`**: Utilities for reading spatial graphs, coordinates, and `.h5ad` matrices. +- **`utils.py`**: HEST-specific dataset setup routines, splitting logic, and vocabulary loading. +- **`download.py`**: Logic for fetching subsets of the gated HEST dataset from Hugging Face. + +## Usage + +For complete CLI usage and training preset commands, refer to the main **[README.md](../../../../README.md)** and the **[Training Guide](../../../../docs/TRAINING_GUIDE.md)**. diff --git a/src/spatial_transcript_former/recipes/hest/__init__.py b/src/spatial_transcript_former/recipes/hest/__init__.py new file mode 100644 index 0000000..04767da --- /dev/null +++ b/src/spatial_transcript_former/recipes/hest/__init__.py @@ -0,0 +1,80 @@ +""" +HEST-1k recipe for SpatialTranscriptFormer. + +Provides HEST-specific dataset classes, download utilities, and dataloader +factories. All components implement the generic :class:`SpatialDataset` +contract from ``spatial_transcript_former.data.base``. + +Quick start:: + + from spatial_transcript_former.recipes.hest import ( + HEST_Dataset, + HEST_FeatureDataset, + get_hest_dataloader, + get_hest_feature_dataloader, + get_sample_ids, + setup_dataloaders, + download_hest_subset, + ) +""" + +# Dataset classes and DataLoader factories +from spatial_transcript_former.recipes.hest.dataset import ( + HEST_Dataset, + HEST_FeatureDataset, + get_hest_dataloader, + get_hest_feature_dataloader, + load_gene_expression_matrix, + load_global_genes, +) + +# I/O utilities +from spatial_transcript_former.recipes.hest.io import ( + get_hest_data_dir, + load_h5ad_metadata, + get_image_from_h5ad, + decode_h5_string, +) + +# Download +from spatial_transcript_former.recipes.hest.download import ( + download_hest_subset, + download_metadata, + filter_samples, +) + +# Sample discovery and dataloader setup +from spatial_transcript_former.recipes.hest.utils import ( + get_sample_ids, + setup_dataloaders, +) + +# Splitting +from spatial_transcript_former.recipes.hest.splitting import split_hest_patients + +# Vocab building +from spatial_transcript_former.recipes.hest.build_vocab import scan_h5ad_files + +__all__ = [ + # Datasets + "HEST_Dataset", + "HEST_FeatureDataset", + "get_hest_dataloader", + "get_hest_feature_dataloader", + "load_gene_expression_matrix", + "load_global_genes", + # I/O + "get_hest_data_dir", + "load_h5ad_metadata", + "get_image_from_h5ad", + "decode_h5_string", + # Download + "download_hest_subset", + "download_metadata", + "filter_samples", + # Utils + "get_sample_ids", + "setup_dataloaders", + "split_hest_patients", + "scan_h5ad_files", +] diff --git a/src/spatial_transcript_former/data/build_vocab.py b/src/spatial_transcript_former/recipes/hest/build_vocab.py similarity index 98% rename from src/spatial_transcript_former/data/build_vocab.py rename to src/spatial_transcript_former/recipes/hest/build_vocab.py index 21da9ab..4481d92 100644 --- a/src/spatial_transcript_former/data/build_vocab.py +++ b/src/spatial_transcript_former/recipes/hest/build_vocab.py @@ -11,7 +11,10 @@ # Add src to path sys.path.append(os.path.abspath("src")) -from spatial_transcript_former.data.io import get_hest_data_dir, load_h5ad_metadata +from spatial_transcript_former.recipes.hest.io import ( + get_hest_data_dir, + load_h5ad_metadata, +) from spatial_transcript_former.config import get_config from spatial_transcript_former.data.pathways import ( download_msigdb_gmt, diff --git a/src/spatial_transcript_former/data/dataset.py b/src/spatial_transcript_former/recipes/hest/dataset.py similarity index 87% rename from src/spatial_transcript_former/data/dataset.py rename to src/spatial_transcript_former/recipes/hest/dataset.py index 69d3e00..774f3fc 100644 --- a/src/spatial_transcript_former/data/dataset.py +++ b/src/spatial_transcript_former/recipes/hest/dataset.py @@ -23,116 +23,22 @@ import pandas as pd import numpy as np from .io import decode_h5_string, load_h5ad_metadata -from torch.utils.data import Dataset, DataLoader, ConcatDataset +from spatial_transcript_former.data.base import ( + SpatialDataset, + apply_dihedral_augmentation, + apply_dihedral_to_tensor, + normalize_coordinates, +) +from torch.utils.data import DataLoader, ConcatDataset from scipy.sparse import csr_matrix from typing import List, Optional, Tuple, Union from scipy.spatial import KDTree import torch.nn.functional as F -# --------------------------------------------------------------------------- -# Spatial augmentation helpers -# --------------------------------------------------------------------------- - - -def apply_dihedral_augmentation(coords, op=None): - """Apply one of the 8 dihedral symmetries to a set of 2-D coordinates. - - The dihedral group D4 contains 4 rotations and 4 reflections, which leave - a square grid invariant. Applying the same operation to both pixel tensors - and coordinate tensors keeps spatial relationships consistent after - augmentation. - - Args: - coords (torch.Tensor or np.ndarray): Shape ``(N, 2)`` array of (x, y) - coordinates defined in a *centred* frame (i.e. the origin is the - centre of the slide region, not the top-left corner). - op (int, optional): Integer in ``[0, 7]`` selecting the operation. - If ``None``, one is chosen uniformly at random. - - Returns: - tuple: - - **augmented_coords** – same type and shape as the input. - - **op** (*int*) – the operation that was applied (useful for - applying the same transformation to the corresponding image). - - Operations - ---------- - ===== ============== - Index Description - ===== ============== - 0 Identity - 1 Rotate 90° CCW - 2 Rotate 180° - 3 Rotate 270° CCW - 4 Flip horizontal (negate x) - 5 Flip vertical (negate y) - 6 Transpose (swap x and y) - 7 Anti-transpose (swap and negate both) - ===== ============== - """ - is_torch = isinstance(coords, torch.Tensor) - if is_torch: - x, y = coords[..., 0].clone(), coords[..., 1].clone() - else: - x, y = coords[..., 0].copy(), coords[..., 1].copy() - - if op is None: - op = np.random.randint(0, 8) - - if op == 0: # Identity - pass - elif op == 1: # Rotate 90° CCW - x, y = y, -x - elif op == 2: # Rotate 180° - x, y = -x, -y - elif op == 3: # Rotate 270° CCW - x, y = -y, x - elif op == 4: # Flip horizontal - x = -x - elif op == 5: # Flip vertical - y = -y - elif op == 6: # Transpose - x, y = y, x - elif op == 7: # Anti-transpose - x, y = -y, -x - - if is_torch: - return torch.stack([x, y], dim=-1), op - else: - return np.stack([x, y], axis=-1), op - - -def apply_dihedral_to_tensor(img, op): - """Apply a dihedral operation to a ``(C, H, W)`` image tensor. - - Each operation matches the coordinate transform in - :func:`apply_dihedral_augmentation` so that pixel content and spatial - coordinates stay aligned after augmentation. - - Args: - img (torch.Tensor): Image tensor of shape ``(C, H, W)``. - op (int): Operation index in ``[0, 7]``. - - Returns: - torch.Tensor: Transformed image tensor, same shape as ``img``. - """ - if op == 0: - return img - if op == 1: - return torch.rot90(img, k=1, dims=[1, 2]) # Rotate 90° CCW - if op == 2: - return torch.rot90(img, k=2, dims=[1, 2]) # Rotate 180° - if op == 3: - return torch.rot90(img, k=3, dims=[1, 2]) # Rotate 270° CCW - if op == 4: - return torch.flip(img, dims=[2]) # Flip horizontal (width axis) - if op == 5: - return torch.flip(img, dims=[1]) # Flip vertical (height axis) - if op == 6: - return img.transpose(1, 2) # Transpose - if op == 7: - return img.transpose(1, 2).flip(dims=[1, 2]) # Anti-transpose - return img +# Augmentation helpers and normalize_coordinates are now in data.base +# and imported above. Kept here for backward compatibility: +# from spatial_transcript_former.recipes.hest.dataset import apply_dihedral_augmentation +# still works via the import at the top of this file. # --------------------------------------------------------------------------- @@ -140,7 +46,7 @@ def apply_dihedral_to_tensor(img, op): # --------------------------------------------------------------------------- -class HEST_Dataset(Dataset): +class HEST_Dataset(SpatialDataset): """PyTorch Dataset that loads raw histology patches from a HEST ``.h5`` file. Each item is a tuple ``(patches, gene_counts, rel_coords)`` where: @@ -398,27 +304,7 @@ def load_gene_expression_matrix( return final_subset, valid_patch_mask, selected_names -def normalize_coordinates(coords: np.ndarray) -> np.ndarray: - """Auto-normalizes physical coordinates to integer grid indices.""" - if len(coords) == 0: - return coords - - x_vals = np.unique(coords[:, 0]) - y_vals = np.unique(coords[:, 1]) - - dx = x_vals[1:] - x_vals[:-1] - dy = y_vals[1:] - y_vals[:-1] - - steps = np.concatenate([dx, dy]) - valid_steps = steps[steps > 0.5] - - if len(valid_steps) == 0: - return coords - - step_size = valid_steps.min() - if step_size >= 2.0: - return np.round(coords / step_size).astype(coords.dtype) - return coords +# normalize_coordinates is now in data.base and imported above. def load_global_genes(root_dir: str, num_genes: int = 1000) -> List[str]: @@ -587,7 +473,7 @@ def get_hest_dataloader( # --------------------------------------------------------------------------- -class HEST_FeatureDataset(Dataset): +class HEST_FeatureDataset(SpatialDataset): """Dataset for pre-computed backbone feature vectors. Loads CTransPath (or any other backbone) feature vectors from a ``.pt`` diff --git a/src/spatial_transcript_former/data/download.py b/src/spatial_transcript_former/recipes/hest/download.py similarity index 100% rename from src/spatial_transcript_former/data/download.py rename to src/spatial_transcript_former/recipes/hest/download.py diff --git a/src/spatial_transcript_former/data/extract_features.py b/src/spatial_transcript_former/recipes/hest/extract_features.py similarity index 100% rename from src/spatial_transcript_former/data/extract_features.py rename to src/spatial_transcript_former/recipes/hest/extract_features.py diff --git a/src/spatial_transcript_former/data/io.py b/src/spatial_transcript_former/recipes/hest/io.py similarity index 98% rename from src/spatial_transcript_former/data/io.py rename to src/spatial_transcript_former/recipes/hest/io.py index 9c65390..f3a6b39 100644 --- a/src/spatial_transcript_former/data/io.py +++ b/src/spatial_transcript_former/recipes/hest/io.py @@ -9,7 +9,7 @@ import h5py import numpy as np from typing import List, Dict, Any, Tuple, Optional -from ..config import get_config +from spatial_transcript_former.config import get_config def get_hest_data_dir() -> str: diff --git a/src/spatial_transcript_former/data/splitting.py b/src/spatial_transcript_former/recipes/hest/splitting.py similarity index 100% rename from src/spatial_transcript_former/data/splitting.py rename to src/spatial_transcript_former/recipes/hest/splitting.py diff --git a/src/spatial_transcript_former/data/utils.py b/src/spatial_transcript_former/recipes/hest/utils.py similarity index 100% rename from src/spatial_transcript_former/data/utils.py rename to src/spatial_transcript_former/recipes/hest/utils.py diff --git a/src/spatial_transcript_former/train.py b/src/spatial_transcript_former/train.py index 286be45..530b515 100644 --- a/src/spatial_transcript_former/train.py +++ b/src/spatial_transcript_former/train.py @@ -24,7 +24,10 @@ from spatial_transcript_former.training.engine import train_one_epoch, validate from spatial_transcript_former.training.experiment_logger import ExperimentLogger from spatial_transcript_former.visualization import run_inference_plot -from spatial_transcript_former.data.utils import get_sample_ids, setup_dataloaders +from spatial_transcript_former.recipes.hest.utils import ( + get_sample_ids, + setup_dataloaders, +) from spatial_transcript_former.training.arguments import parse_args from spatial_transcript_former.training.builder import setup_model, setup_criterion diff --git a/src/spatial_transcript_former/training/__init__.py b/src/spatial_transcript_former/training/__init__.py new file mode 100644 index 0000000..854f9bb --- /dev/null +++ b/src/spatial_transcript_former/training/__init__.py @@ -0,0 +1,36 @@ +""" +Training subpackage for SpatialTranscriptFormer. + +Exposes the high-level :class:`Trainer` and the lower-level building blocks. +""" + +from .trainer import Trainer, TrainerCallback, EarlyStoppingCallback +from .engine import train_one_epoch, validate +from .losses import ( + CompositeLoss, + PCCLoss, + MaskedMSELoss, + MaskedHuberLoss, + ZINBLoss, + AuxiliaryPathwayLoss, +) +from .experiment_logger import ExperimentLogger + +__all__ = [ + # High-level + "Trainer", + "TrainerCallback", + "EarlyStoppingCallback", + # Engine + "train_one_epoch", + "validate", + # Losses + "CompositeLoss", + "PCCLoss", + "MaskedMSELoss", + "MaskedHuberLoss", + "ZINBLoss", + "AuxiliaryPathwayLoss", + # Logging + "ExperimentLogger", +] diff --git a/src/spatial_transcript_former/training/trainer.py b/src/spatial_transcript_former/training/trainer.py new file mode 100644 index 0000000..7deff27 --- /dev/null +++ b/src/spatial_transcript_former/training/trainer.py @@ -0,0 +1,392 @@ +""" +High-level Trainer for SpatialTranscriptFormer. + +Wraps the low-level :func:`train_one_epoch` / :func:`validate` engine with +LR scheduling, checkpointing, experiment logging, and early stopping. + +Example:: + + from spatial_transcript_former import SpatialTranscriptFormer + from spatial_transcript_former.training import Trainer + from spatial_transcript_former.training.losses import CompositeLoss + + model = SpatialTranscriptFormer(num_genes=460, backbone_name="phikon", ...) + trainer = Trainer( + model=model, + train_loader=train_dl, + val_loader=val_dl, + criterion=CompositeLoss(), + epochs=100, + ) + results = trainer.fit() + trainer.save_pretrained("./release/v1/", gene_names=my_genes) +""" + +import os +import time +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.optim as optim + +from spatial_transcript_former.training.engine import train_one_epoch, validate +from spatial_transcript_former.training.experiment_logger import ExperimentLogger +from spatial_transcript_former.training.checkpoint import ( + save_checkpoint, + load_checkpoint, +) + + +# --------------------------------------------------------------------------- +# Callback protocol +# --------------------------------------------------------------------------- + + +class TrainerCallback: + """Base class for Trainer callbacks. + + Override any of these hooks. All methods are no-ops by default. + """ + + def on_train_begin(self, trainer: "Trainer") -> None: + """Called at the start of :meth:`Trainer.fit`.""" + + def on_train_end(self, trainer: "Trainer", results: dict) -> None: + """Called at the end of :meth:`Trainer.fit`.""" + + def on_epoch_begin(self, trainer: "Trainer", epoch: int) -> None: + """Called at the beginning of each epoch.""" + + def on_epoch_end(self, trainer: "Trainer", epoch: int, metrics: dict) -> None: + """Called after validation. ``metrics`` has train_loss, val_loss, etc.""" + + def should_stop(self, trainer: "Trainer", epoch: int, metrics: dict) -> bool: + """Return ``True`` to request early stopping.""" + return False + + +class EarlyStoppingCallback(TrainerCallback): + """Stop training when validation loss does not improve for ``patience`` epochs. + + Args: + patience: Number of epochs to wait for improvement. + min_delta: Minimum decrease in val_loss to be considered an improvement. + """ + + def __init__(self, patience: int = 15, min_delta: float = 0.0): + self.patience = patience + self.min_delta = min_delta + self._best_loss = float("inf") + self._wait = 0 + + def on_epoch_end(self, trainer, epoch, metrics): + val_loss = metrics.get("val_loss", float("inf")) + if val_loss < self._best_loss - self.min_delta: + self._best_loss = val_loss + self._wait = 0 + else: + self._wait += 1 + + def should_stop(self, trainer, epoch, metrics): + if self._wait >= self.patience: + print( + f"Early stopping: no improvement for {self.patience} epochs " + f"(best={self._best_loss:.4f})." + ) + return True + return False + + +# --------------------------------------------------------------------------- +# Trainer +# --------------------------------------------------------------------------- + + +class Trainer: + """High-level training orchestrator. + + Manages the full lifecycle: LR scheduling, gradient accumulation, + AMP, checkpointing, logging, and callbacks. + + Args: + model: The model to train (any ``nn.Module``). + train_loader: Training ``DataLoader``. + val_loader: Validation ``DataLoader``. + criterion: Loss function. + optimizer: Optimizer. If ``None``, ``AdamW`` is created with ``lr`` + and ``weight_decay``. + lr: Learning rate (used only when ``optimizer`` is ``None``). + weight_decay: Weight decay (used only when ``optimizer`` is ``None``). + epochs: Total training epochs. + warmup_epochs: Linear warmup epochs before cosine annealing. + device: Device string (``"cuda"``, ``"cpu"``). + output_dir: Directory for checkpoints and logs. + model_name: Name used in checkpoint filenames. + use_amp: Enable automatic mixed precision (FP16). + grad_accum_steps: Gradient accumulation steps. + whole_slide: Whole-slide prediction mode. + callbacks: List of :class:`TrainerCallback` instances. + resume: Attempt to resume from a checkpoint in ``output_dir``. + """ + + def __init__( + self, + model: torch.nn.Module, + train_loader: torch.utils.data.DataLoader, + val_loader: torch.utils.data.DataLoader, + criterion: torch.nn.Module, + *, + optimizer: Optional[torch.optim.Optimizer] = None, + lr: float = 1e-4, + weight_decay: float = 0.0, + epochs: int = 100, + warmup_epochs: int = 10, + device: str = "cuda", + output_dir: str = "./checkpoints", + model_name: str = "model", + use_amp: bool = False, + grad_accum_steps: int = 1, + whole_slide: bool = False, + callbacks: Optional[List[TrainerCallback]] = None, + resume: bool = False, + ): + self.model = model.to(device) + self.train_loader = train_loader + self.val_loader = val_loader + self.criterion = criterion.to(device) + self.epochs = epochs + self.warmup_epochs = warmup_epochs + self.device = device + self.output_dir = output_dir + self.model_name = model_name + self.use_amp = use_amp + self.grad_accum_steps = grad_accum_steps + self.whole_slide = whole_slide + self.callbacks = callbacks or [] + self.resume = resume + + # State + self.current_epoch: int = 0 + self.best_val_loss: float = float("inf") + self.history: List[Dict[str, Any]] = [] + + # Optimizer + if optimizer is not None: + self.optimizer = optimizer + else: + self.optimizer = optim.AdamW( + self.model.parameters(), lr=lr, weight_decay=weight_decay + ) + + # LR Scheduler: warmup → cosine + self._build_scheduler() + + # AMP scaler + self.scaler = torch.amp.GradScaler("cuda") if use_amp else None + + # Logger + os.makedirs(output_dir, exist_ok=True) + self.logger = ExperimentLogger( + output_dir, + { + "epochs": epochs, + "lr": lr, + "weight_decay": weight_decay, + "warmup_epochs": warmup_epochs, + "use_amp": use_amp, + "grad_accum_steps": grad_accum_steps, + "whole_slide": whole_slide, + "model_name": model_name, + }, + ) + + # Resume + if resume: + self._resume_from_checkpoint() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _build_scheduler(self): + warmup = optim.lr_scheduler.LinearLR( + self.optimizer, + start_factor=0.01, + total_iters=max(1, self.warmup_epochs), + ) + cosine = optim.lr_scheduler.CosineAnnealingLR( + self.optimizer, + T_max=max(1, self.epochs - self.warmup_epochs), + eta_min=1e-6, + ) + + if self.warmup_epochs > 0: + self.scheduler = optim.lr_scheduler.SequentialLR( + self.optimizer, + schedulers=[warmup, cosine], + milestones=[self.warmup_epochs], + ) + else: + self.scheduler = cosine + + def _resume_from_checkpoint(self): + schedulers = {"main": self.scheduler} + start_epoch, best_val_loss, loaded_schedulers = load_checkpoint( + self.model, + self.optimizer, + self.scaler, + schedulers, + self.output_dir, + self.model_name, + self.device, + ) + self.current_epoch = start_epoch + self.best_val_loss = best_val_loss + + # Catch up scheduler for old checkpoints + if start_epoch > 0 and self.scheduler.last_epoch < start_epoch: + for _ in range(start_epoch): + self.scheduler.step() + + # ------------------------------------------------------------------ + # Core training loop + # ------------------------------------------------------------------ + + def fit(self) -> Dict[str, Any]: + """Run the full training loop. + + Returns: + dict: Final training results including ``best_val_loss`` and + ``history`` (list of per-epoch metrics). + """ + for cb in self.callbacks: + cb.on_train_begin(self) + + for epoch in range(self.current_epoch, self.epochs): + self.current_epoch = epoch + + for cb in self.callbacks: + cb.on_epoch_begin(self, epoch) + + print(f"\nEpoch {epoch + 1}/{self.epochs}") + + # --- Train --- + train_loss = train_one_epoch( + self.model, + self.train_loader, + self.criterion, + self.optimizer, + self.device, + whole_slide=self.whole_slide, + scaler=self.scaler, + grad_accum_steps=self.grad_accum_steps, + ) + + # --- Validate --- + val_metrics = validate( + self.model, + self.val_loader, + self.criterion, + self.device, + whole_slide=self.whole_slide, + use_amp=self.use_amp, + ) + + val_loss = val_metrics["val_loss"] + lr = self.optimizer.param_groups[0]["lr"] + + print( + f"Train Loss: {train_loss:.4f}, " + f"Val Loss: {val_loss:.4f}, " + f"LR: {lr:.2e}" + ) + + # Step scheduler + self.scheduler.step() + + # --- Metrics --- + epoch_metrics = { + "train_loss": train_loss, + "val_loss": val_loss, + "lr": lr, + } + for key in ("val_mae", "val_pcc", "pred_variance", "attn_correlation"): + if val_metrics.get(key) is not None: + epoch_metrics[key] = val_metrics[key] + + # Hardware metrics (optional) + try: + import psutil + + epoch_metrics["sys_cpu_percent"] = psutil.cpu_percent() + epoch_metrics["sys_ram_percent"] = psutil.virtual_memory().percent + except ImportError: + pass + + if torch.cuda.is_available(): + epoch_metrics["sys_gpu_mem_mb"] = round( + torch.cuda.memory_allocated() / (1024**2), 2 + ) + + self.history.append(epoch_metrics) + self.logger.log_epoch(epoch + 1, epoch_metrics) + + # --- Best model --- + if val_loss < self.best_val_loss: + self.best_val_loss = val_loss + best_path = os.path.join( + self.output_dir, f"best_model_{self.model_name}.pth" + ) + torch.save(self.model.state_dict(), best_path) + print(f"Saved best model -> {best_path}") + + # --- Checkpoint --- + save_checkpoint( + self.model, + self.optimizer, + self.scaler, + {"main": self.scheduler}, + epoch, + self.best_val_loss, + self.output_dir, + self.model_name, + ) + + # --- Callbacks --- + for cb in self.callbacks: + cb.on_epoch_end(self, epoch, epoch_metrics) + + if any(cb.should_stop(self, epoch, epoch_metrics) for cb in self.callbacks): + print(f"Training stopped at epoch {epoch + 1}.") + break + + # --- Finalize --- + results = { + "best_val_loss": self.best_val_loss, + "epochs_completed": self.current_epoch + 1, + "history": self.history, + } + + self.logger.finalize(self.best_val_loss) + + for cb in self.callbacks: + cb.on_train_end(self, results) + + return results + + # ------------------------------------------------------------------ + # Convenience + # ------------------------------------------------------------------ + + def save_pretrained( + self, path: str, gene_names: Optional[List[str]] = None + ) -> None: + """Export an inference-ready checkpoint (strips optimizer state). + + Delegates to :func:`spatial_transcript_former.checkpoint.save_pretrained`. + """ + from spatial_transcript_former.checkpoint import ( + save_pretrained as _save_pretrained, + ) + + _save_pretrained(self.model, path, gene_names=gene_names) diff --git a/src/spatial_transcript_former/visualization.py b/src/spatial_transcript_former/visualization.py index 98a670a..0e85b96 100644 --- a/src/spatial_transcript_former/visualization.py +++ b/src/spatial_transcript_former/visualization.py @@ -3,7 +3,7 @@ import numpy as np import h5py import matplotlib.pyplot as plt -from spatial_transcript_former.data.utils import setup_dataloaders +from spatial_transcript_former.recipes.hest.utils import setup_dataloaders def _load_histology(h5ad_path): @@ -177,7 +177,7 @@ def run_inference_plot(model, args, sample_id, epoch, device): return # 4. Compute Pathway Truth - from spatial_transcript_former.data.dataset import load_global_genes + from spatial_transcript_former.recipes.hest.dataset import load_global_genes gene_names = load_global_genes(args.data_dir, args.num_genes) diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..5a15324 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,332 @@ +""" +Tests for the public Python API surface. + +Covers: + - Package-level imports (__init__.py) + - Config serialization (save_pretrained / load_pretrained round-trip) + - from_pretrained class method + - Predictor (patch and WSI mode) + - FeatureExtractor + - inject_predictions (AnnData integration) +""" + +import json +import os +import tempfile + +import numpy as np +import pytest +import torch + +from spatial_transcript_former import ( + SpatialTranscriptFormer, + Predictor, + FeatureExtractor, + save_pretrained, + load_pretrained, + inject_predictions, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def small_model(): + """A minimal SpatialTranscriptFormer for fast tests.""" + return SpatialTranscriptFormer( + num_genes=50, + num_pathways=10, + backbone_name="resnet50", + pretrained=False, + token_dim=64, + n_heads=4, + n_layers=2, + use_spatial_pe=True, + ) + + +@pytest.fixture +def checkpoint_dir(small_model): + """Save a small model to a temp directory and return the path.""" + with tempfile.TemporaryDirectory() as tmpdir: + gene_names = [f"GENE_{i}" for i in range(50)] + save_pretrained(small_model, tmpdir, gene_names=gene_names) + yield tmpdir + + +# --------------------------------------------------------------------------- +# Package imports +# --------------------------------------------------------------------------- + + +class TestPackageImports: + def test_model_importable(self): + from spatial_transcript_former import SpatialTranscriptFormer + + assert SpatialTranscriptFormer is not None + + def test_predictor_importable(self): + from spatial_transcript_former import Predictor + + assert Predictor is not None + + def test_feature_extractor_importable(self): + from spatial_transcript_former import FeatureExtractor + + assert FeatureExtractor is not None + + def test_checkpoint_functions_importable(self): + from spatial_transcript_former import save_pretrained, load_pretrained + + assert callable(save_pretrained) + assert callable(load_pretrained) + + def test_inject_predictions_importable(self): + from spatial_transcript_former import inject_predictions + + assert callable(inject_predictions) + + +# --------------------------------------------------------------------------- +# Config serialization round-trip +# --------------------------------------------------------------------------- + + +class TestCheckpointSerialization: + def test_save_creates_files(self, small_model): + """save_pretrained should create config.json and model.pth.""" + with tempfile.TemporaryDirectory() as tmpdir: + save_pretrained(small_model, tmpdir) + assert os.path.isfile(os.path.join(tmpdir, "config.json")) + assert os.path.isfile(os.path.join(tmpdir, "model.pth")) + + def test_save_with_gene_names(self, small_model): + """save_pretrained should create gene_names.json when provided.""" + with tempfile.TemporaryDirectory() as tmpdir: + names = [f"G{i}" for i in range(50)] + save_pretrained(small_model, tmpdir, gene_names=names) + path = os.path.join(tmpdir, "gene_names.json") + assert os.path.isfile(path) + with open(path) as f: + loaded = json.load(f) + assert loaded == names + + def test_save_gene_names_length_mismatch(self, small_model): + """Mismatched gene_names length should raise ValueError.""" + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(ValueError, match="gene_names length"): + save_pretrained(small_model, tmpdir, gene_names=["A", "B"]) + + def test_config_json_contents(self, small_model): + """config.json should contain all expected architecture keys.""" + with tempfile.TemporaryDirectory() as tmpdir: + save_pretrained(small_model, tmpdir) + with open(os.path.join(tmpdir, "config.json")) as f: + config = json.load(f) + assert config["num_genes"] == 50 + assert config["num_pathways"] == 10 + assert config["token_dim"] == 64 + assert config["n_heads"] == 4 + assert config["n_layers"] == 2 + assert config["use_spatial_pe"] is True + + def test_round_trip_weights(self, small_model, checkpoint_dir): + """Weights should be identical after save → load.""" + loaded = load_pretrained(checkpoint_dir, device="cpu") + for (n1, p1), (n2, p2) in zip( + small_model.named_parameters(), loaded.named_parameters() + ): + assert n1 == n2, f"Parameter name mismatch: {n1} vs {n2}" + assert torch.allclose(p1, p2), f"Weight mismatch in {n1}" + + def test_round_trip_gene_names(self, checkpoint_dir): + """gene_names should survive the round trip.""" + model = load_pretrained(checkpoint_dir) + assert model.gene_names is not None + assert len(model.gene_names) == 50 + assert model.gene_names[0] == "GENE_0" + + def test_load_missing_config_raises(self): + """Loading from empty directory should raise FileNotFoundError.""" + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(FileNotFoundError, match="config.json"): + load_pretrained(tmpdir) + + +# --------------------------------------------------------------------------- +# from_pretrained +# --------------------------------------------------------------------------- + + +class TestFromPretrained: + def test_from_pretrained_returns_model(self, checkpoint_dir): + """from_pretrained should return a SpatialTranscriptFormer in eval mode.""" + model = SpatialTranscriptFormer.from_pretrained(checkpoint_dir) + assert isinstance(model, SpatialTranscriptFormer) + assert not model.training # should be in eval mode + + def test_from_pretrained_with_overrides(self, checkpoint_dir): + """Overriding dropout should be reflected in loaded model.""" + model = SpatialTranscriptFormer.from_pretrained(checkpoint_dir, dropout=0.0) + # The first transformer layer should use the override + layer = model.fusion_engine.layers[0] + # dropout is an attribute of the layer + assert layer.dropout.p == 0.0 + + +# --------------------------------------------------------------------------- +# Predictor +# --------------------------------------------------------------------------- + + +class TestPredictor: + def test_predict_patch(self, small_model): + """predict_patch should return (1, G) tensor.""" + predictor = Predictor(small_model, device="cpu") + image = torch.randn(1, 3, 224, 224) + result = predictor.predict_patch(image) + assert result.shape == (1, 50) + + def test_predict_patch_no_batch_dim(self, small_model): + """predict_patch should accept (3, H, W) without batch dim.""" + predictor = Predictor(small_model, device="cpu") + image = torch.randn(3, 224, 224) + result = predictor.predict_patch(image) + assert result.shape == (1, 50) + + def test_predict_wsi(self, small_model): + """predict_wsi should return (1, G) tensor for global mode.""" + predictor = Predictor(small_model, device="cpu") + features = torch.randn(20, small_model.image_proj.in_features) + coords = torch.randn(20, 2) + result = predictor.predict_wsi(features, coords) + assert result.shape == (1, 50) + + def test_predict_wsi_dense(self, small_model): + """predict_wsi with return_dense should return (1, N, G).""" + predictor = Predictor(small_model, device="cpu") + n_patches = 15 + features = torch.randn(n_patches, small_model.image_proj.in_features) + coords = torch.randn(n_patches, 2) + result = predictor.predict_wsi(features, coords, return_dense=True) + assert result.shape == (1, n_patches, 50) + + def test_predict_wsi_feature_dim_mismatch(self, small_model): + """Wrong feature dim should raise ValueError with helpful message.""" + predictor = Predictor(small_model, device="cpu") + features = torch.randn(10, 999) # wrong dim + coords = torch.randn(10, 2) + with pytest.raises(ValueError, match="Feature dimension mismatch"): + predictor.predict_wsi(features, coords) + + def test_predict_unified_dispatch_image(self, small_model): + """predict() should dispatch to patch mode for 4D image input.""" + predictor = Predictor(small_model, device="cpu") + image = torch.randn(1, 3, 224, 224) + result = predictor.predict(image) + assert result.shape == (1, 50) + + def test_predict_unified_dispatch_features(self, small_model): + """predict() should dispatch to WSI mode for 2D features.""" + predictor = Predictor(small_model, device="cpu") + features = torch.randn(10, small_model.image_proj.in_features) + coords = torch.randn(10, 2) + result = predictor.predict(features, coords) + assert result.shape == (1, 50) + + def test_predict_features_without_coords_raises(self, small_model): + """predict() on features without coords should raise.""" + predictor = Predictor(small_model, device="cpu") + features = torch.randn(10, small_model.image_proj.in_features) + with pytest.raises(ValueError, match="coords are required"): + predictor.predict(features) + + def test_gene_names_exposed(self, checkpoint_dir): + """Predictor should expose gene_names from the model.""" + model = SpatialTranscriptFormer.from_pretrained(checkpoint_dir) + predictor = Predictor(model) + assert predictor.gene_names is not None + assert len(predictor.gene_names) == 50 + + +# --------------------------------------------------------------------------- +# inject_predictions (AnnData) +# --------------------------------------------------------------------------- + + +class TestInjectPredictions: + def test_basic_injection(self): + """Should set adata.X and adata.obsm['spatial'].""" + anndata = pytest.importorskip("anndata") + import pandas as pd + + n, g = 100, 50 + adata = anndata.AnnData(obs=pd.DataFrame(index=[f"spot_{i}" for i in range(n)])) + coords = np.random.rand(n, 2) + predictions = np.random.rand(n, g).astype(np.float32) + + inject_predictions(adata, coords, predictions) + + assert adata.X is not None + assert adata.X.shape == (n, g) + np.testing.assert_array_equal(adata.obsm["spatial"], coords) + + def test_with_gene_names(self): + """Gene names should populate adata.var_names.""" + anndata = pytest.importorskip("anndata") + import pandas as pd + + n, g = 50, 20 + adata = anndata.AnnData(obs=pd.DataFrame(index=[f"s{i}" for i in range(n)])) + gene_names = [f"GENE_{i}" for i in range(g)] + inject_predictions( + adata, + np.zeros((n, 2)), + np.zeros((n, g)), + gene_names=gene_names, + ) + assert list(adata.var_names) == gene_names + + def test_with_pathway_scores(self): + """Pathway scores should go into adata.obsm['spatial_pathways'].""" + anndata = pytest.importorskip("anndata") + import pandas as pd + + n, g, p = 30, 10, 5 + adata = anndata.AnnData(obs=pd.DataFrame(index=[f"s{i}" for i in range(n)])) + pathway_scores = np.random.rand(n, p).astype(np.float32) + pathway_names = [f"PW_{i}" for i in range(p)] + + inject_predictions( + adata, + np.zeros((n, 2)), + np.zeros((n, g)), + pathway_scores=pathway_scores, + pathway_names=pathway_names, + ) + assert adata.obsm["spatial_pathways"].shape == (n, p) + assert adata.uns["pathway_names"] == pathway_names + + def test_shape_mismatch_raises(self): + """Mismatched row counts should raise ValueError.""" + anndata = pytest.importorskip("anndata") + import pandas as pd + + adata = anndata.AnnData(obs=pd.DataFrame(index=[f"s{i}" for i in range(10)])) + with pytest.raises(ValueError, match="coords has"): + inject_predictions(adata, np.zeros((5, 2)), np.zeros((10, 20))) + + def test_torch_tensor_input(self): + """Should accept torch tensors and convert them.""" + anndata = pytest.importorskip("anndata") + import pandas as pd + + n, g = 20, 10 + adata = anndata.AnnData(obs=pd.DataFrame(index=[f"s{i}" for i in range(n)])) + coords = torch.rand(n, 2) + preds = torch.rand(n, g) + inject_predictions(adata, coords, preds) + assert adata.X.shape == (n, g) diff --git a/tests/test_augmentation_sync.py b/tests/test_augmentation_sync.py index 6d35ca5..e5ed4fd 100644 --- a/tests/test_augmentation_sync.py +++ b/tests/test_augmentation_sync.py @@ -1,7 +1,7 @@ import torch import numpy as np import pytest -from spatial_transcript_former.data.dataset import ( +from spatial_transcript_former.recipes.hest.dataset import ( apply_dihedral_augmentation, apply_dihedral_to_tensor, ) diff --git a/tests/test_build_vocab.py b/tests/test_build_vocab.py index 4b204cc..3835a50 100644 --- a/tests/test_build_vocab.py +++ b/tests/test_build_vocab.py @@ -1,7 +1,7 @@ import os import pytest from unittest.mock import patch -from spatial_transcript_former.data.build_vocab import scan_h5ad_files +from spatial_transcript_former.recipes.hest.build_vocab import scan_h5ad_files def test_scan_h5ad_files_success(tmp_path): diff --git a/tests/test_data_integrity.py b/tests/test_data_integrity.py index ca63416..925514c 100644 --- a/tests/test_data_integrity.py +++ b/tests/test_data_integrity.py @@ -3,8 +3,11 @@ import torch import numpy as np import h5py -from spatial_transcript_former.data.io import get_hest_data_dir, load_h5ad_metadata -from spatial_transcript_former.data.dataset import load_global_genes +from spatial_transcript_former.recipes.hest.io import ( + get_hest_data_dir, + load_h5ad_metadata, +) +from spatial_transcript_former.recipes.hest.dataset import load_global_genes from spatial_transcript_former.data.pathways import ( download_msigdb_gmt, parse_gmt, diff --git a/tests/test_dataloader_h5ad.py b/tests/test_dataloader_h5ad.py index f7ad618..42b5a25 100644 --- a/tests/test_dataloader_h5ad.py +++ b/tests/test_dataloader_h5ad.py @@ -1,5 +1,5 @@ import os -from spatial_transcript_former.data.dataset import get_hest_dataloader +from spatial_transcript_former.recipes.hest.dataset import get_hest_dataloader import torch data_dir = r"A:\hest_data" diff --git a/tests/test_dataset_logic.py b/tests/test_dataset_logic.py index 92ae4d3..d516e78 100644 --- a/tests/test_dataset_logic.py +++ b/tests/test_dataset_logic.py @@ -1,7 +1,7 @@ import torch import numpy as np import pytest -from spatial_transcript_former.data.dataset import ( +from spatial_transcript_former.recipes.hest.dataset import ( apply_dihedral_augmentation, apply_dihedral_to_tensor, normalize_coordinates, diff --git a/tests/test_dataset_mocks.py b/tests/test_dataset_mocks.py index 93a7c5a..025b8ff 100644 --- a/tests/test_dataset_mocks.py +++ b/tests/test_dataset_mocks.py @@ -2,7 +2,10 @@ import numpy as np import pytest from unittest.mock import MagicMock, patch -from spatial_transcript_former.data.dataset import HEST_Dataset, HEST_FeatureDataset +from spatial_transcript_former.recipes.hest.dataset import ( + HEST_Dataset, + HEST_FeatureDataset, +) @pytest.fixture @@ -38,10 +41,10 @@ def test_hest_dataset_augmentation_consistency(mock_h5_file): # are called with the same 'op'. with ( patch( - "spatial_transcript_former.data.dataset.apply_dihedral_to_tensor" + "spatial_transcript_former.recipes.hest.dataset.apply_dihedral_to_tensor" ) as mock_tensor_aug, patch( - "spatial_transcript_former.data.dataset.apply_dihedral_augmentation" + "spatial_transcript_former.recipes.hest.dataset.apply_dihedral_augmentation" ) as mock_coord_aug, ): @@ -75,7 +78,7 @@ def test_hest_feature_dataset_neighborhood_dropout(): with ( patch("torch.load") as mock_load, patch( - "spatial_transcript_former.data.dataset.load_gene_expression_matrix" + "spatial_transcript_former.recipes.hest.dataset.load_gene_expression_matrix" ) as mock_gene_load, ): diff --git a/tests/test_download.py b/tests/test_download.py index 7d7c6bd..669f04b 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -10,7 +10,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) # Correct import based on the project structure -from spatial_transcript_former.data.download import ( +from spatial_transcript_former.recipes.hest.download import ( download_metadata, filter_samples, download_hest_subset, @@ -28,7 +28,7 @@ def tearDown(self): # Clean up the temporary directory shutil.rmtree(self.test_dir) - @patch("spatial_transcript_former.data.download.hf_hub_download") + @patch("spatial_transcript_former.recipes.hest.download.hf_hub_download") @patch("os.path.exists") def test_download_metadata_exists(self, mock_exists, mock_download): # Test case where metadata already exists @@ -39,7 +39,7 @@ def test_download_metadata_exists(self, mock_exists, mock_download): self.assertEqual(result, self.metadata_path) mock_download.assert_not_called() - @patch("spatial_transcript_former.data.download.hf_hub_download") + @patch("spatial_transcript_former.recipes.hest.download.hf_hub_download") @patch("os.path.exists") def test_download_metadata_missing(self, mock_exists, mock_download): # Test case where metadata is missing and needs download @@ -80,7 +80,7 @@ def test_filter_samples(self): samples = filter_samples(self.metadata_path, organ="Brain") self.assertEqual(samples, []) - @patch("spatial_transcript_former.data.download.snapshot_download") + @patch("spatial_transcript_former.recipes.hest.download.snapshot_download") def test_download_hest_subset_calls(self, mock_snapshot): # Test that snapshot_download is called with correct patterns sample_ids = ["S1", "S2"] @@ -104,7 +104,7 @@ def test_download_hest_subset_calls(self, mock_snapshot): # Check additional patterns self.assertIn("extra_file.txt", patterns) - @patch("spatial_transcript_former.data.download.snapshot_download") + @patch("spatial_transcript_former.recipes.hest.download.snapshot_download") @patch("zipfile.ZipFile") @patch("os.listdir") @patch("os.path.exists") diff --git a/tests/test_io.py b/tests/test_io.py index ec43a79..134818a 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -3,7 +3,7 @@ import numpy as np import pytest from unittest.mock import patch, MagicMock -from spatial_transcript_former.data.io import ( +from spatial_transcript_former.recipes.hest.io import ( get_hest_data_dir, decode_h5_string, load_h5ad_metadata, @@ -17,7 +17,7 @@ def test_decode_h5_string(): assert decode_h5_string(123) == "123" -@patch("spatial_transcript_former.data.io.get_config") +@patch("spatial_transcript_former.recipes.hest.io.get_config") @patch("os.path.exists") def test_get_hest_data_dir_from_config(mock_exists, mock_get_config): # Mock config to return a specific path @@ -32,7 +32,7 @@ def test_get_hest_data_dir_from_config(mock_exists, mock_get_config): mock_get_config.assert_called_with("data_dirs", []) -@patch("spatial_transcript_former.data.io.get_config") +@patch("spatial_transcript_former.recipes.hest.io.get_config") @patch("os.path.exists") def test_get_hest_data_dir_fallbacks(mock_exists, mock_get_config): mock_get_config.return_value = [] @@ -46,7 +46,7 @@ def test_get_hest_data_dir_fallbacks(mock_exists, mock_get_config): assert get_hest_data_dir() == fallback_path -@patch("spatial_transcript_former.data.io.get_config") +@patch("spatial_transcript_former.recipes.hest.io.get_config") @patch("os.path.exists") def test_get_hest_data_dir_not_found(mock_exists, mock_get_config): mock_get_config.return_value = [] diff --git a/tests/test_spatial_augment.py b/tests/test_spatial_augment.py index d1b0ea9..38ad5e0 100644 --- a/tests/test_spatial_augment.py +++ b/tests/test_spatial_augment.py @@ -1,7 +1,7 @@ import torch import numpy as np import pytest -from spatial_transcript_former.data.dataset import apply_dihedral_augmentation +from spatial_transcript_former.recipes.hest.dataset import apply_dihedral_augmentation def test_apply_dihedral_augmentation_torch(): diff --git a/tests/test_splitting_logic.py b/tests/test_splitting_logic.py index b6f3411..dc26573 100644 --- a/tests/test_splitting_logic.py +++ b/tests/test_splitting_logic.py @@ -1,7 +1,7 @@ import pandas as pd import os import pytest -from spatial_transcript_former.data import split_hest_patients +from spatial_transcript_former.recipes.hest.splitting import split_hest_patients def test_split_hest_patients(): diff --git a/tests/test_splitting_robust.py b/tests/test_splitting_robust.py index 09a4873..c5b116f 100644 --- a/tests/test_splitting_robust.py +++ b/tests/test_splitting_robust.py @@ -2,7 +2,7 @@ import pytest import os import tempfile -from spatial_transcript_former.data.splitting import split_hest_patients, main +from spatial_transcript_former.recipes.hest.splitting import split_hest_patients, main import sys from unittest.mock import patch diff --git a/tests/test_trainer.py b/tests/test_trainer.py new file mode 100644 index 0000000..c6108a6 --- /dev/null +++ b/tests/test_trainer.py @@ -0,0 +1,322 @@ +""" +Tests for the Trainer class and callback system. + +Uses a tiny synthetic dataset and model to verify the training lifecycle +without requiring real data or GPU. +""" + +import os +import pytest +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from spatial_transcript_former.training.trainer import ( + Trainer, + TrainerCallback, + EarlyStoppingCallback, +) +from spatial_transcript_former.data.base import SpatialDataset + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +class TinySpatialDataset(SpatialDataset): + """Minimal SpatialDataset for testing.""" + + def __init__(self, n=32, feature_dim=64, num_genes=10): + self._features = torch.randn(n, 1, feature_dim) + self._genes = torch.randn(n, num_genes).abs() + self._coords = torch.zeros(n, 1, 2) + self.num_genes = num_genes + + def __len__(self): + return len(self._features) + + def __getitem__(self, idx): + return self._features[idx], self._genes[idx], self._coords[idx] + + +class TinyModel(nn.Module): + """Simple linear model for testing (mimics patch-level prediction).""" + + def __init__(self, in_dim=64, num_genes=10): + super().__init__() + self.fc = nn.Linear(in_dim, num_genes) + + def forward(self, x, **kwargs): + # x shape: (B, 1, D) -> squeeze -> (B, D) + if x.dim() == 3: + x = x.squeeze(1) + return self.fc(x) + + +@pytest.fixture +def tiny_setup(tmp_path): + """Create a minimal training setup with a tmp_path for output.""" + ds = TinySpatialDataset(n=32, feature_dim=64, num_genes=10) + train_loader = DataLoader(ds, batch_size=8, shuffle=True) + val_loader = DataLoader(ds, batch_size=8) + + model = TinyModel(in_dim=64, num_genes=10) + criterion = nn.MSELoss() + output_dir = str(tmp_path / "output") + + return model, criterion, train_loader, val_loader, output_dir + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestTrainerImports: + """Verify the Trainer is importable from all expected paths.""" + + def test_top_level_import(self): + from spatial_transcript_former import Trainer + + assert Trainer is not None + + def test_training_subpackage_import(self): + from spatial_transcript_former.training import Trainer + + assert Trainer is not None + + def test_direct_import(self): + from spatial_transcript_former.training.trainer import Trainer + + assert Trainer is not None + + def test_callback_imports(self): + from spatial_transcript_former.training import ( + TrainerCallback, + EarlyStoppingCallback, + ) + + assert TrainerCallback is not None + assert EarlyStoppingCallback is not None + + +class TestTrainerBasicFit: + """Test the core fit() lifecycle.""" + + def test_fit_runs_to_completion(self, tiny_setup): + model, criterion, train_loader, val_loader, output_dir = tiny_setup + + trainer = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + criterion=criterion, + epochs=3, + warmup_epochs=1, + device="cpu", + output_dir=output_dir, + use_amp=False, + ) + + results = trainer.fit() + + assert "best_val_loss" in results + assert results["epochs_completed"] == 3 + assert len(results["history"]) == 3 + + def test_fit_records_metrics(self, tiny_setup): + model, criterion, train_loader, val_loader, output_dir = tiny_setup + + trainer = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + criterion=criterion, + epochs=2, + warmup_epochs=0, + device="cpu", + output_dir=output_dir, + ) + results = trainer.fit() + + for row in results["history"]: + assert "train_loss" in row + assert "val_loss" in row + assert "lr" in row + + def test_saves_best_model_and_checkpoint(self, tiny_setup): + model, criterion, train_loader, val_loader, output_dir = tiny_setup + + trainer = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + criterion=criterion, + epochs=2, + warmup_epochs=0, + device="cpu", + output_dir=output_dir, + model_name="test", + ) + trainer.fit() + + assert os.path.exists(os.path.join(output_dir, "best_model_test.pth")) + assert os.path.exists(os.path.join(output_dir, "latest_model_test.pth")) + + def test_saves_logger_outputs(self, tiny_setup): + model, criterion, train_loader, val_loader, output_dir = tiny_setup + + trainer = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + criterion=criterion, + epochs=2, + warmup_epochs=0, + device="cpu", + output_dir=output_dir, + ) + trainer.fit() + + assert os.path.exists(os.path.join(output_dir, "training_logs.sqlite")) + assert os.path.exists(os.path.join(output_dir, "results_summary.json")) + + +class TestTrainerCustomOptimizer: + """Test passing a custom optimizer.""" + + def test_custom_optimizer(self, tiny_setup): + model, criterion, train_loader, val_loader, output_dir = tiny_setup + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + trainer = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + criterion=criterion, + optimizer=optimizer, + epochs=2, + warmup_epochs=0, + device="cpu", + output_dir=output_dir, + ) + assert trainer.optimizer is optimizer + results = trainer.fit() + assert results["epochs_completed"] == 2 + + +class TestCallbacks: + """Test the callback system.""" + + def test_callbacks_are_called(self, tiny_setup): + model, criterion, train_loader, val_loader, output_dir = tiny_setup + + call_log = [] + + class LogCallback(TrainerCallback): + def on_train_begin(self, trainer): + call_log.append("train_begin") + + def on_epoch_begin(self, trainer, epoch): + call_log.append(f"epoch_begin_{epoch}") + + def on_epoch_end(self, trainer, epoch, metrics): + call_log.append(f"epoch_end_{epoch}") + + def on_train_end(self, trainer, results): + call_log.append("train_end") + + trainer = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + criterion=criterion, + epochs=2, + warmup_epochs=0, + device="cpu", + output_dir=output_dir, + callbacks=[LogCallback()], + ) + trainer.fit() + + assert call_log == [ + "train_begin", + "epoch_begin_0", + "epoch_end_0", + "epoch_begin_1", + "epoch_end_1", + "train_end", + ] + + def test_early_stopping(self, tiny_setup): + model, criterion, train_loader, val_loader, output_dir = tiny_setup + + # Wrap so the model always outputs a near-constant → loss stays flat → early stop + class ConstantModel(nn.Module): + def __init__(self, base): + super().__init__() + self.base = base # keep parameters so optimizer doesn't crash + + def forward(self, x, **kwargs): + out = self.base(x, **kwargs) + # Multiply by tiny eps to keep grad flow, but output is ~1.0 + return out * 1e-10 + 1.0 + + const_model = ConstantModel(model) + + trainer = Trainer( + model=const_model, + train_loader=train_loader, + val_loader=val_loader, + criterion=criterion, + epochs=100, + warmup_epochs=0, + device="cpu", + output_dir=output_dir, + callbacks=[EarlyStoppingCallback(patience=2)], + ) + results = trainer.fit() + + # With constant output, loss never changes, so early stopping should + # fire after patience + 1 epochs + assert results["epochs_completed"] <= 5 + + +class TestTrainerResume: + """Test checkpoint resumption.""" + + def test_resume_continues_from_checkpoint(self, tiny_setup): + model, criterion, train_loader, val_loader, output_dir = tiny_setup + + # Train for 3 epochs + trainer1 = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + criterion=criterion, + epochs=3, + warmup_epochs=0, + device="cpu", + output_dir=output_dir, + model_name="resume_test", + ) + trainer1.fit() + + # Resume — should start from epoch 3 and run to 5 + trainer2 = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + criterion=criterion, + epochs=5, + warmup_epochs=0, + device="cpu", + output_dir=output_dir, + model_name="resume_test", + resume=True, + ) + results2 = trainer2.fit() + + # Should have completed 5 total epochs (3 from first run + 2 more) + assert results2["epochs_completed"] == 5