The SpatialTranscriptFormer package exposes a clean API for training, inference, and integration with the Scanpy/AnnData ecosystem.
```python
from spatial_transcript_former import (
    SpatialTranscriptFormer,  # Core model
    Trainer,                  # High-level training orchestrator
    Predictor,                # Inference wrapper
    FeatureExtractor,         # Backbone feature extraction
    save_pretrained,          # Save checkpoint directory
    load_pretrained,          # Load checkpoint directory
    inject_predictions,       # AnnData integration
)
```

```python
from spatial_transcript_former import SpatialTranscriptFormer, Predictor, FeatureExtractor
from spatial_transcript_former.predict import inject_predictions
import pandas as pd
import scanpy as sc

# 1. Load model from checkpoint directory
model = SpatialTranscriptFormer.from_pretrained("./checkpoints/my_run/")
print(model.gene_names[:3])  # ['TP53', 'EGFR', 'MYC']

# 2. Extract features from raw patches
extractor = FeatureExtractor(backbone="phikon", device="cuda")
features = extractor.extract_batch(image_tensor, batch_size=64)  # (N, 768)

# 3. Predict gene expression, one vector per spot
predictor = Predictor(model, device="cuda")
predictions = predictor.predict_wsi(features, coords, return_dense=True)  # (1, N, G)

# 4. Inject into AnnData for Scanpy analysis
N = features.shape[0]
adata = sc.AnnData(obs=pd.DataFrame(index=[f"spot_{i}" for i in range(N)]))
inject_predictions(adata, coords, predictions[0], gene_names=model.gene_names)
sc.pl.spatial(adata, color="TP53")
```

```python
from spatial_transcript_former import save_pretrained

# After training, export a self-contained checkpoint
save_pretrained(model, "./release/v1/", gene_names=gene_list)
```

This creates:

```
release/v1/
├── config.json       # Architecture parameters
├── model.pth         # Model weights (state_dict)
└── gene_names.json   # Ordered gene symbols
```
The core transformer model. Predicts gene expression from histology patch features and spatial coordinates.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_genes` | `int` | required | Number of output genes |
| `num_pathways` | `int` | `50` | Number of pathway bottleneck tokens |
| `backbone_name` | `str` | `"resnet50"` | Backbone identifier (`resnet50`, `phikon`, `ctranspath`, etc.) |
| `pretrained` | `bool` | `True` | Load pretrained backbone weights |
| `token_dim` | `int` | `256` | Common embedding dimension |
| `n_heads` | `int` | `4` | Number of attention heads |
| `n_layers` | `int` | `2` | Number of transformer layers |
| `dropout` | `float` | `0.1` | Dropout probability |
| `pathway_init` | `Tensor` | `None` | `(P, G)` biological pathway membership matrix |
| `use_spatial_pe` | `bool` | `True` | Enable learned spatial positional encodings |
| `output_mode` | `str` | `"counts"` | Output head: `"counts"` (Softplus) or `"zinb"` (Zero-Inflated NB) |
| `interactions` | `list[str]` | all | Attention interactions: `p2p`, `p2h`, `h2p`, `h2h` |
Load a model from a checkpoint directory created by save_pretrained.
```python
model = SpatialTranscriptFormer.from_pretrained("./checkpoint/", device="cuda")
model.gene_names  # List[str] or None
```

| Parameter | Type | Description |
|---|---|---|
| `checkpoint_dir` | `str` | Path to directory with `config.json` + `model.pth` |
| `device` | `str` | Torch device (`"cpu"`, `"cuda"`) |
| `**kwargs` | | Override any `config.json` value (e.g. `dropout=0.0`) |
Returns: SpatialTranscriptFormer in eval mode with .gene_names attribute.
Stateful inference wrapper. Manages device placement, eval mode, and optional AMP.
```python
predictor = Predictor(model, device="cuda", use_amp=True)
```

Single-patch inference from a raw image tensor.

```python
result = predictor.predict_patch(image)  # image: (1, 3, 224, 224) or (3, 224, 224)
# result: (1, num_genes)
```

Note: When the model uses spatial PE, a zero-coordinate is automatically injected — no need to provide coordinates for single patches.
Whole-slide inference from pre-extracted feature embeddings.
```python
# Global prediction (one vector per slide)
result = predictor.predict_wsi(features, coords)  # (1, G)

# Dense prediction (one vector per patch)
result = predictor.predict_wsi(features, coords, return_dense=True)  # (1, N, G)
```

| Parameter | Type | Description |
|---|---|---|
| `features` | `Tensor` | `(N, D)` or `(1, N, D)` embeddings |
| `coords` | `Tensor` | `(N, 2)` or `(1, N, 2)` spatial coordinates |
| `return_pathways` | `bool` | Also return pathway scores |
| `return_dense` | `bool` | Per-patch predictions instead of global |

Validation: Raises `ValueError` with a clear message if the feature dimension doesn't match the model's expected backbone dimension.
Unified entry point — auto-dispatches:

- 4D tensor `(B, 3, H, W)` → `predict_patch`
- 2D tensor `(N, D)` → `predict_wsi` (requires `coords`)
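The routing rule above can be sketched in a few lines (illustrative only; the package's actual dispatcher inspects real tensors):

```python
def dispatch(ndim, has_coords):
    """Route by tensor rank, mirroring the auto-dispatch rule above."""
    if ndim == 4:          # (B, 3, H, W) image batch
        return "predict_patch"
    if ndim == 2:          # (N, D) feature matrix
        if not has_coords:
            raise ValueError("predict_wsi requires coords")
        return "predict_wsi"
    raise ValueError(f"unsupported input rank: {ndim}")

print(dispatch(4, has_coords=False))  # predict_patch
print(dispatch(2, has_coords=True))   # predict_wsi
```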
Wraps a backbone model and its normalization transform for one-line feature extraction.
```python
extractor = FeatureExtractor(backbone="phikon", device="cuda")
extractor.feature_dim    # 768
extractor.backbone_name  # "phikon"
```

| Backbone | `feature_dim` | Source |
|---|---|---|
| `resnet50` | 2048 | torchvision |
| `ctranspath` | 768 | HuggingFace (CTransPath) |
| `phikon` | 768 | Owkin Phikon (HuggingFace) |
| `vit_b_16` | 768 | torchvision |
| `gigapath` | 1536 | ProvGigaPath (gated) |
| `hibou-b` | 768 | Hibou-B (gated) |
| `hibou-l` | 1024 | Hibou-L (gated) |

```python
features = extractor(images)  # (N, D) — all at once
features = extractor.extract_batch(images, batch_size=64)  # batched, returns on CPU
```

Images should be float tensors in the [0, 1] range, shape `(N, 3, H, W)`.
Save a self-contained checkpoint directory.
```python
save_pretrained(model, "./release/v1/", gene_names=["TP53", "EGFR", ...])
```

| Parameter | Type | Description |
|---|---|---|
| `model` | `SpatialTranscriptFormer` | Trained model instance |
| `save_dir` | `str` | Output directory (created if needed) |
| `gene_names` | `list[str]` | Optional ordered gene symbols (must match `num_genes`) |
Raises: ValueError if gene_names length doesn't match num_genes.
If you're coming from a pure deep-learning background, AnnData and Scanpy may be unfamiliar. They are the standard data format and analysis toolkit in single-cell and spatial biology — the equivalent of what Pandas DataFrames are for tabular ML.
An AnnData object is a structured container for observations × variables matrices, designed for genomics. Think of it as a spreadsheet with labelled sidecars:
```
                 var (genes)
            ┌──────────────────┐
            │ TP53  EGFR  MYC  │
       ┌────┼──────────────────┤
  obs  │ s0 │ 0.3   1.2   0.8  │  ← adata.X (the main data matrix)
(spots/│ s1 │ 0.1   0.5   1.1  │
 cells)│ s2 │ 0.9   0.2   0.4  │
       └────┴──────────────────┘
```
| Slot | What it stores | Our usage |
|---|---|---|
| `adata.X` | Main matrix `(N, G)` | Predicted gene expression |
| `adata.obs` | Per-observation metadata | Spot/cell barcodes, cluster labels |
| `adata.var` | Per-variable metadata | Gene symbols as the index |
| `adata.obsm["spatial"]` | Observation-level embeddings | `(N, 2)` spatial coordinates |
| `adata.obsm["spatial_pathways"]` | Additional embeddings | `(N, P)` pathway scores |
| `adata.uns` | Unstructured metadata | Pathway names, model config |
Scanpy (sc) is the analysis library that operates on AnnData objects. Once predictions are inside an adata, you instantly get access to:
```python
import scanpy as sc

# Spatial plotting — visualise gene expression on tissue coordinates
sc.pl.spatial(adata, color="TP53")

# Clustering — find groups of spots with similar expression
sc.pp.neighbors(adata)  # build the kNN graph required by leiden/umap
sc.tl.leiden(adata)

# Differential expression — find marker genes per cluster
sc.tl.rank_genes_groups(adata, groupby="leiden")

# Dimensionality reduction
sc.tl.pca(adata)
sc.tl.umap(adata)
sc.pl.umap(adata, color="leiden")
```

By injecting predictions into AnnData, our model's output becomes instantly compatible with the entire Scanpy ecosystem — clustering, differential testing, spatial plotting, trajectory analysis — without any custom code. Biologists can take our predictions and run their standard workflows immediately.
Inject predictions into an AnnData object for Scanpy integration.
```python
inject_predictions(
    adata,
    coords,                               # → adata.obsm["spatial"]
    predictions,                          # → adata.X
    gene_names=["TP53", "EGFR", ...],     # → adata.var_names
    pathway_scores=pathway_activations,   # → adata.obsm["spatial_pathways"]
    pathway_names=["APOPTOSIS", ...],     # → adata.uns["pathway_names"]
)
```

| Parameter | Type | Description |
|---|---|---|
| `adata` | `AnnData` | Target AnnData object |
| `coords` | `ndarray` | `(N, 2)` spatial coordinates |
| `predictions` | `ndarray` | `(N, G)` gene predictions |
| `gene_names` | `list[str]` | Optional gene symbols |
| `pathway_scores` | `ndarray` | Optional `(N, P)` pathway scores |
| `pathway_names` | `list[str]` | Optional pathway names |
Lazy loading: `anndata` is only imported when this function is called, so it's not required for basic inference.
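The same deferred-import pattern is easy to replicate in your own code; in this sketch the stdlib `colorsys` module stands in for a heavy optional dependency like `anndata`:

```python
import sys

sys.modules.pop("colorsys", None)   # clean slate for the demo

def rgb_to_hls(rgb):
    import colorsys                 # deferred: loaded on first call only
    return colorsys.rgb_to_hls(*rgb)

print("colorsys" in sys.modules)    # False: nothing imported yet
print(rgb_to_hls((1.0, 0.0, 0.0)))  # (0.0, 0.5, 1.0)
print("colorsys" in sys.modules)    # True: loaded on first use
```

Callers that never hit the function never pay the import cost, which is exactly why basic inference works without `anndata` installed.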
```
checkpoint/
├── config.json       # Architecture (JSON)
├── model.pth         # Weights (PyTorch state_dict)
└── gene_names.json   # Gene symbols (JSON array, optional)
```

`config.json` example:
```json
{
    "num_genes": 460,
    "num_pathways": 50,
    "backbone_name": "phikon",
    "token_dim": 256,
    "n_heads": 4,
    "n_layers": 2,
    "dropout": 0.1,
    "use_spatial_pe": true,
    "output_mode": "counts",
    "interactions": ["h2h", "h2p", "p2h", "p2p"]
}
```

`gene_names.json` example:

```json
["TP53", "EGFR", "MYC", "BRCA1", ...]
```

The training pipeline lives in the `spatial_transcript_former.training` subpackage. You can use it via the CLI or programmatically in your own scripts.
Training is launched via the stf-train entry point (or python -m spatial_transcript_former.train):
```shell
# Minimal: train on precomputed features with whole-slide mode
stf-train \
    --model interaction \
    --backbone phikon \
    --data-dir /path/to/hest \
    --precomputed \
    --whole-slide \
    --use-spatial-pe \
    --pathway-init \
    --loss mse_pcc \
    --epochs 100 \
    --lr 1e-4 \
    --warmup-epochs 10

# Resume from checkpoint
stf-train --model interaction --resume --output-dir ./checkpoints
```

| Flag | Default | Description |
|---|---|---|
| `--data-dir` | from config | Root HEST data directory |
| `--feature-dir` | auto | Explicit pre-extracted feature directory |
| `--num-genes` | 1000 | Number of output genes |
| `--precomputed` | off | Use pre-extracted backbone features |
| `--whole-slide` | off | Dense whole-slide prediction mode |
| `--organ` | all | Filter samples by organ type |
| `--max-samples` | all | Limit samples (for debugging) |
| Flag | Default | Description |
|---|---|---|
| `--model` | `he2rna` | Architecture: `interaction`, `he2rna`, `vit_st`, `attention_mil`, `transmil` |
| `--backbone` | `resnet50` | Backbone: `resnet50`, `phikon`, `ctranspath`, `vit_b_16`, etc. |
| `--num-pathways` | 50 | Pathway bottleneck tokens |
| `--token-dim` | 256 | Embedding dimension |
| `--n-heads` | 4 | Attention heads |
| `--n-layers` | 2 | Transformer layers |
| `--use-spatial-pe` | off | Learned spatial positional encoding |
| `--interactions` | all | Attention mask: `p2p p2h h2p h2h` |
| `--pathway-init` | off | Initialize gene head from MSigDB Hallmarks |
| Flag | Default | Description |
|---|---|---|
| `--epochs` | 10 | Total training epochs |
| `--batch-size` | 32 | Batch size |
| `--lr` | 1e-4 | Learning rate |
| `--warmup-epochs` | 10 | Linear warmup before cosine annealing |
| `--weight-decay` | 0.0 | AdamW weight decay |
| `--grad-accum-steps` | 1 | Gradient accumulation steps |
| `--use-amp` | off | Mixed precision (FP16) |
| `--compile` | off | `torch.compile` the model |
| `--resume` | off | Resume from latest checkpoint |
| Flag | Default | Description |
|---|---|---|
| `--loss` | `mse_pcc` | Loss function: `mse`, `pcc`, `mse_pcc`, `zinb`, `poisson`, `logcosh` |
| `--pcc-weight` | 1.0 | PCC term weight in `mse_pcc` |
| `--pathway-loss-weight` | 0.0 | Auxiliary pathway PCC loss weight (0 = disabled) |
The Trainer class handles LR scheduling, AMP, checkpointing, logging, and early stopping:
```python
from spatial_transcript_former import SpatialTranscriptFormer, Trainer
from spatial_transcript_former.training import CompositeLoss, EarlyStoppingCallback

model = SpatialTranscriptFormer(num_genes=460, backbone_name="phikon", ...)

trainer = Trainer(
    model=model,
    train_loader=train_dl,
    val_loader=val_dl,
    criterion=CompositeLoss(alpha=1.0),
    epochs=100,
    warmup_epochs=10,
    device="cuda",
    output_dir="./checkpoints",
    use_amp=True,
    callbacks=[EarlyStoppingCallback(patience=15)],
)

results = trainer.fit()                   # returns {"best_val_loss", "history", ...}
trainer.save_pretrained("./release/v1/")  # inference-ready export
```

| Parameter | Default | Description |
|---|---|---|
| `model` | required | Any `nn.Module` |
| `train_loader` | required | Training DataLoader |
| `val_loader` | required | Validation DataLoader |
| `criterion` | required | Loss function |
| `optimizer` | `None` | Custom optimizer (default: AdamW) |
| `lr` | `1e-4` | Learning rate (if no custom optimizer) |
| `epochs` | `100` | Total training epochs |
| `warmup_epochs` | `10` | Linear warmup before cosine annealing |
| `use_amp` | `False` | Mixed precision (FP16) |
| `grad_accum_steps` | `1` | Gradient accumulation |
| `whole_slide` | `False` | Dense whole-slide mode |
| `output_dir` | `./checkpoints` | Directory for checkpoints/logs |
| `callbacks` | `[]` | List of `TrainerCallback` instances |
| `resume` | `False` | Resume from checkpoint |
Subclass TrainerCallback to hook into the training loop:
```python
import wandb
from spatial_transcript_former.training import TrainerCallback

class WandbCallback(TrainerCallback):
    def on_epoch_end(self, trainer, epoch, metrics):
        wandb.log(metrics, step=epoch)

    def should_stop(self, trainer, epoch, metrics):
        return False  # never stop early
```

| Hook | When |
|---|---|
| `on_train_begin(trainer)` | Start of `fit()` |
| `on_epoch_begin(trainer, epoch)` | Before each epoch |
| `on_epoch_end(trainer, epoch, metrics)` | After validation |
| `on_train_end(trainer, results)` | End of `fit()` |
| `should_stop(trainer, epoch, metrics)` | Return `True` to stop |

Built-in: `EarlyStoppingCallback(patience=15, min_delta=0.0)`
For full control, use the engine functions directly:
```python
import torch
from spatial_transcript_former.training import train_one_epoch, validate, CompositeLoss

criterion = CompositeLoss(alpha=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()  # for mixed-precision training

for epoch in range(100):
    train_loss = train_one_epoch(
        model, train_loader, criterion, optimizer, device,
        whole_slide=True, scaler=scaler, grad_accum_steps=4,
    )
    val_metrics = validate(
        model, val_loader, criterion, device,
        whole_slide=True, use_amp=True,
    )
    print(f"Epoch {epoch}: train={train_loss:.4f}, val={val_metrics['val_loss']:.4f}")
```

All losses accept `(B, G)` patch-level or `(B, N, G)` dense inputs, with an optional `mask` for padded positions.
| Class | Formula / Description |
|---|---|
| `MaskedMSELoss` | Standard MSE with optional padding mask |
| `PCCLoss` | `1 - mean(PCC)` — gene-wise spatial Pearson correlation |
| `CompositeLoss` | `MSE + α · (1 - PCC)` — balances magnitude and spatial pattern |
| `ZINBLoss` | Zero-Inflated Negative Binomial NLL — for raw count data |
| `MaskedHuberLoss` | Huber (SmoothL1) — robust to outlier spots |
| `AuxiliaryPathwayLoss` | Wraps any base loss + PCC on pathway bottleneck scores |
| Function | Description |
|---|---|
| `train_one_epoch(model, loader, criterion, optimizer, device, ...)` | One epoch of training. Handles gradient accumulation, AMP, and both patch/WSI modes. Returns average loss. |
| `validate(model, loader, criterion, device, ...)` | Validation pass. Returns dict with `val_loss`, `val_mae`, `val_pcc`, `pred_variance`, and optional `attn_correlation`. |
Offline-friendly logger (no W&B dependency). Writes metrics to SQLite and a JSON summary.
```python
logger = ExperimentLogger(output_dir, config_dict)
logger.log_epoch(epoch, {"train_loss": 0.1, "val_loss": 0.2, "val_pcc": 0.65})
logger.finalize(best_val_loss=0.15)
```

| Output File | Contents |
|---|---|
| `training_logs.sqlite` | Per-epoch metrics table |
| `results_summary.json` | Config + final metrics + runtime |
During training, checkpoints are managed by training.checkpoint (the internal module — distinct from the public save_pretrained):
| Function | Purpose |
|---|---|
| `save_checkpoint(model, optimizer, scaler, schedulers, ...)` | Saves full training state for `--resume` |
| `load_checkpoint(model, optimizer, scaler, schedulers, ...)` | Restores training state |
After training is complete, use the public save_pretrained to export a clean, inference-ready checkpoint:
```python
from spatial_transcript_former import save_pretrained

# Export for inference (strips optimizer/scheduler state)
save_pretrained(model, "./release/v1/", gene_names=gene_list)
```

All datasets implement the `SpatialDataset` contract (in `data.base`). The contract requires `__getitem__` to return:

```python
(features, gene_counts, rel_coords)
# features:    (S, D) tensor — patch embeddings (S = 1 + neighbours)
# gene_counts: (G,) tensor — expression targets
# rel_coords:  (S, 2) tensor — relative spatial coordinates
```

```python
from spatial_transcript_former.data.base import SpatialDataset
import torch

class MyVisiumDataset(SpatialDataset):
    def __init__(self, features, gene_matrix, coords):
        self._features = torch.as_tensor(features, dtype=torch.float32)
        self._genes = torch.as_tensor(gene_matrix, dtype=torch.float32)
        self._coords = torch.as_tensor(coords, dtype=torch.float32)
        self.num_genes = self._genes.shape[1]

    def __len__(self):
        return len(self._features)

    def __getitem__(self, idx):
        feat = self._features[idx].unsqueeze(0)  # (1, D)
        genes = self._genes[idx]                 # (G,)
        rel_coord = torch.zeros(1, 2)            # centre = [0, 0]
        return feat, genes, rel_coord
```

```python
from torch.utils.data import DataLoader, random_split
from spatial_transcript_former import SpatialTranscriptFormer, Trainer
from spatial_transcript_former.training import CompositeLoss, EarlyStoppingCallback

dataset = MyVisiumDataset(features, gene_matrix, coords)
train_ds, val_ds = random_split(dataset, [0.8, 0.2])

model = SpatialTranscriptFormer(num_genes=dataset.num_genes, backbone_name="phikon")
trainer = Trainer(
    model=model,
    train_loader=DataLoader(train_ds, batch_size=32, shuffle=True),
    val_loader=DataLoader(val_ds, batch_size=64),
    criterion=CompositeLoss(),
    epochs=100,
    callbacks=[EarlyStoppingCallback(patience=15)],
)

results = trainer.fit()
trainer.save_pretrained("./my_model/")
```

See `recipes/custom/README.md` for the full guide.