
Commit ae68c5f

Update and Major fixes to the Model and Pathway Loss (#5)

* feat: add prototype visualization and organ filtering
* fix: correct spatial-pe argument handling
* fix: sync num_genes with global_genes.json
* fix: global sync of num_genes across model and data
* chore: update the monitoring dashboard with a more modern look and refactor it for better maintainability
* refactor: modularize training pipeline and stabilize model initialization
  - Refactor train.py by extracting logic into training submodules:
    - arguments.py: CLI parameter definitions
    - builder.py: model and criterion setup
    - checkpoint.py: robust saving and loading logic
  - Fix the learning-rate plateau by replacing disjoint schedulers with SequentialLR to properly chain the linear warmup and cosine decay phases.
  - Simplify the SpatialTranscriptFormer architecture:
    - Remove the redundant log_temperature parameter to reduce gradient variance.
    - Implement L1-normalization for MSigDB pathway weight initialization to prevent exponential prediction explosion at startup.
  - Enhance load_checkpoint with robust error handling for EOFError (corrupted files) and ValueError (architecture/optimizer mismatches) to ensure graceful fallbacks.
* test: update unit tests for the architectural and API changes
  - Fix test_checkpoint.py by adding the required schedulers argument to save_checkpoint and load_checkpoint calls.
  - Fix test_checkpoints.py by updating the pathway initialization assertion to expect L1-normalized weights instead of raw binary values.
  - Fix test_spatial_interaction.py by removing the obsolete test_temperature_scaling (log_temperature was removed from the model).
  - Fix test_pathways.py by providing mock args to _compute_pathway_truth to satisfy the new visualization parameters.
  - Add test_pathway_stability.py to verify numerical stability and gradient flow of the final stabilized pathway-informed architecture.
* feat(loss): implement Z-score normalization for AuxiliaryPathwayLoss
  Implement spatial Z-score normalization and mean-aggregation for the biological pathway ground-truth calculation. This ensures that every member gene in a pathway (even lowly-expressed transcription factors) contributes equally to the spatial activation signature, preventing high-count housekeeping genes from dominating the pathway patterns.
  - Update AuxiliaryPathwayLoss to spatially standardize genes before projecting onto the pathway matrix.
  - Handle normalization across batch (patch-level) and spatial (whole-slide) dimensions with proper masking.
  - Switch from raw summation to mean-aggregation (averaging by pathway member counts).
  - Synchronize visualization.py ground-truth logic with the new objective.
  - Fix mock tests in test_losses.py to match the normalized targets.
  Variance analysis on HEST data indicated raw gene variance ratios exceeding 300,000x, necessitating this standardization for biologically relevant pathway supervision.
* feat(interpretation): pivot to biological priors and Z-score normalization
  Deprecates the experimental data-driven pathway discovery in favor of strictly biologically-prior-driven interpretability. Updates the auxiliary pathway loss to use spatial Z-score normalization, ensuring lowly-expressed transcription factors contribute equally to the spatial objective.
  - Remove `--sparsity-lambda` and the associated L1 regularization logic.
  - Implement spatial Z-score normalization in `AuxiliaryPathwayLoss`.
  - Synchronize the visualization ground-truth calculation with the new math.
  - Add `--plot-pathways-list` for dynamic user control over heatmaps.
  - Update plot labels to reflect Z-scored spatial patterns.
  - Cleanup: delete LATENT_DISCOVERY.md and scrub legacy doc references.
  Ref: docs/PATHWAY_MAPPING.md, src/spatial_transcript_former/training/losses.py
* fix: update the visualisation test, which was failing with the new pathway selection logic
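The masked spatial standardization described in the commit message above can be sketched in numpy. This is an illustrative sketch only: `spatial_zscore` and its signature are assumptions, not the repo's API (the actual implementation lives in `src/spatial_transcript_former/training/losses.py` and handles batched tensors).

```python
import numpy as np


def spatial_zscore(expr, mask=None, eps=1e-8):
    """Standardize each gene's expression across spatial locations.

    expr: (n_spots, n_genes) array of expression values.
    mask: optional boolean (n_spots,) array; True marks valid spots
          (e.g. real patches, not padding).
    Returns z-scored expression with masked-out spots left at 0.
    Illustrative sketch; not the repo's actual function.
    """
    expr = np.asarray(expr, dtype=float)
    if mask is None:
        mask = np.ones(expr.shape[0], dtype=bool)
    valid = expr[mask]                       # only real spots contribute stats
    mean = valid.mean(axis=0)
    std = valid.std(axis=0)
    z = np.zeros_like(expr)
    z[mask] = (valid - mean) / (std + eps)   # mean 0, variance ~1 per gene
    return z
```

After this step a lowly-expressed transcription factor and a high-count housekeeping gene contribute on the same scale, which is the point of the 300,000x variance-ratio finding cited above.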
1 parent eeac398 commit ae68c5f

34 files changed

Lines changed: 2594 additions & 1837 deletions

README.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -62,11 +62,11 @@ stf-download --species "Homo sapiens" --local_dir hest_data
 We provide presets for baseline models and scaled versions of the SpatialTranscriptFormer.
 
 ```bash
-# Recommended: Run the Interaction model with 4 transformer layers
-python scripts/run_preset.py --preset stf_interaction_l4
+# Recommended: Run the Interaction model (Small)
+python scripts/run_preset.py --preset stf_small
 
-# Run the lightweight 2-layer version
-python scripts/run_preset.py --preset stf_interaction_l2
+# Run the lightweight Tiny version
+python scripts/run_preset.py --preset stf_tiny
 
 # Run baselines
 python scripts/run_preset.py --preset he2rna_baseline
````

docs/LATENT_DISCOVERY.md

Lines changed: 0 additions & 46 deletions
This file was deleted.

docs/MODELS.md

Lines changed: 11 additions & 11 deletions
````diff
@@ -34,12 +34,8 @@ The SpatialTranscriptFormer models the **interaction between biological pathways
 By default, the model operates in **Full Interaction** mode where all four information flows are active. Users can selectively disable any combination using the `--interactions` flag to explore architectural variants:
 
 ```bash
-# Default: Full Interaction (all quadrants enabled)
---interactions p2p p2h h2p h2h
-
-# Pathway Bottleneck: block H↔H to force all inter-patch
-# communication through the pathway bottleneck
---interactions p2p p2h h2p
+# Default: Small Interaction (CTransPath, 4 layers)
+python scripts/run_preset.py --preset stf_small
 ```
 
 > [!TIP]
@@ -53,7 +49,7 @@ Three additional design principles support these interactions:
 
 - **Biological Initialisation** — The gene reconstruction weights are initialised from MSigDB Hallmark gene sets, providing a biologically-grounded starting point that the model refines during training.
 
-### 2.2 Spatial Learning
+## 2.2 Spatial Learning
 
 The spatial relationships of gene expression are central to this model. It is not sufficient to predict correct expression magnitudes at each spot independently — the model must capture **where** on the tissue pathways are active and how that spatial pattern varies across the slide. Two mechanisms enforce this:
 
@@ -218,14 +214,19 @@ The model outputs these parameters, and the loss computes the negative log-likel
 
 To prevent bottleneck collapse and provide a direct gradient signal to the pathway tokens, we use the `AuxiliaryPathwayLoss`. This loss compares the model's internal pathway scores against "ground truth" pathway activations computed from the gene expression targets via MSigDB membership.
 
+To prevent highly-expressed housekeeping genes from dominating the pathway's spatial pattern, the ground-truth targets are computed using **Z-score spatial normalization**:
+
+1. Every gene's spatial expression pattern is standardized (mean=0, variance=1) across the tissue slide.
+2. The normalized genes are projected onto the binary MSigDB pathway matrix.
+3. The resulting pathway scores are **mean-aggregated** (divided by the number of known member genes in each pathway) rather than raw-summed.
+
+This ensures every gene—including critical but lowly-expressed transcription factors—gets an equal vote in determining where a pathway is active.
+
 The total objective becomes:
 $$\mathcal{L} = \mathcal{L}_{gene} + \lambda_{aux} (1 - \text{PCC}(\text{pathway\_scores}, \text{target\_pathways}))$$
 
 The `--log-transform` flag applies `log1p` to targets, mitigating the heavy-tailed gene expression distribution where housekeeping genes dominate MSE.
 
-The full training objective with pathway sparsity regularisation:
-$$\mathcal{L} = \mathcal{L}_{task} + \lambda \|W_{recon}\|_1$$
-
 ---
 
 ## 5. CLI Flags (Model Configuration)
@@ -239,7 +240,6 @@ $$\mathcal{L} = \mathcal{L}_{task} + \lambda \|W_{recon}\|_1$$
 | `--n-layers` | 2 | Transformer layers (minimum 2) |
 | `--num-pathways` | 50 | Number of pathway bottleneck tokens |
 | `--pathway-init` | off | Initialize gene_reconstructor from MSigDB |
-| `--sparsity-lambda` | 0.0 | L1 regularisation on reconstruction weights |
 | `--loss mse_pcc` | `mse` | Loss function (`mse`, `pcc`, `mse_pcc`, `zinb`) |
 | `--pcc-weight` | 1.0 | Weight for PCC term in composite loss |
 | `--pathway-loss-weight` | 0.0 | Weight for auxiliary pathway loss ($\lambda_{aux}$) |
````
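The projection and mean-aggregation steps of the ground-truth computation can be sketched as follows. This is a minimal sketch under assumed shapes; `pathway_targets` is an illustrative name, not a function from the repo, and the input is assumed to already be spatially z-scored.

```python
import numpy as np


def pathway_targets(z_expr, membership, eps=1e-8):
    """Project z-scored genes onto a binary pathway matrix, then
    mean-aggregate by member count.

    z_expr:     (n_spots, n_genes) spatially z-scored expression.
    membership: (n_pathways, n_genes) binary MSigDB membership matrix M.
    Returns:    (n_spots, n_pathways) ground-truth pathway activations.
    Illustrative sketch only.
    """
    membership = np.asarray(membership, dtype=float)
    counts = membership.sum(axis=1)      # member genes per pathway
    scores = z_expr @ membership.T       # project onto pathways
    return scores / (counts + eps)       # mean-aggregate, not raw sum
```

Because the projection divides by the member count, a 200-gene pathway and a 20-gene pathway produce targets on the same scale, which keeps the PCC supervision comparable across pathways.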

docs/PATHWAY_MAPPING.md

Lines changed: 5 additions & 17 deletions
````diff
@@ -35,17 +35,11 @@ In this mode, the network receives direct supervision on its pathway tokens, gui
 1. **Interaction**: Learnable pathway tokens $P$ interact with Histology patch features $H$ via self-attention (e.g., $p2h$, $h2p$).
 2. **Activation**: Pathway scores $S \in \mathbb{R}^P$ are computed using a learnable temperature-scaled cosine similarity between the pathway tokens and image patch tokens.
 3. **Gene Reconstruction**: $\hat{y} = S \cdot \mathbf{W}_{recon} + b$, where $\mathbf{W}_{recon}$ is initialized using the binary pathway membership matrix $M$.
-- **MTL Auxiliary Loss**: To prevent standard bottleneck collapse, an explicit auxiliary loss bridges the spatial representations directly to biological data. The pathway scores $S$ are supervised against a pathway ground truth ($Y_{genes} \cdot M^T$) using a Pearson Correlation Coefficient (PCC) loss.
-$$L_{total} = L_{gene} + \lambda_{pathway} (1 - PCC(S, Y_{genes} \cdot M^T))$$
-- **Benefit**: The model is forced to explicitly align its internal interaction tokens with concrete biological pathways, granting direct interpretability.
-
-#### 2. Data-Driven Discovery (Latent Projection)
-
-In the absence of a biological prior, the model can learn its own "latent pathways".
-
-- **Implementation**: $\mathbf{W}_{recon}$ is randomly initialized and the auxiliary pathway loss is disabled.
-- **Sparsity Constraint**: We apply an L1 penalty to force the model to identify "canonical" sparse gene sets: $L_{total} = L_{gene} + \lambda_{sparsity} \|\mathbf{W}_{recon}\|_1$.
-- **Benefit**: Can discover novel spatial-transcriptomic relationships that aren't yet captured in curated databases.
+- **MTL Auxiliary Loss**: To prevent standard bottleneck collapse, an explicit auxiliary loss bridges the spatial representations directly to biological data. The pathway scores $S$ are supervised against a pathway ground truth using a Pearson Correlation Coefficient (PCC) loss.
+  - To prevent highly expressed housekeeping genes dominating the signal, the raw spatial gene counts ($Y_{genes}$) are first **spatially Z-score normalized** ($Z_{genes}$).
+  - These are then projected onto the pathway matrix and mean-aggregated by member count ($C$):
+$$L_{total} = L_{gene} + \lambda_{pathway} (1 - PCC(S, \frac{Z_{genes} \cdot M^T}{C}))$$
+- **Benefit**: The model is forced to explicitly align its internal interaction tokens with concrete biological pathways, granting direct interpretability where every gene gets an equal vote.
 
 ## 3. Generalizing to HEST1k Tissues
 
@@ -73,18 +67,12 @@ By supplying these functional groupings via `--custom-gmt`, the model's MTL proc
 - GMT file cached in `.cache/` after first download.
 - **Custom Pathway Definitions** (`--custom-gmt` flag): Users can override the default Hallmarks by providing a URL or local path to a `.gmt` file, enabling custom database integrations (e.g., KEGG, Reactome, or highly specific tissue masks).
 
-- **Sparsity Regularization** (`--sparsity-lambda` flag): L1 penalty on `gene_reconstructor` weights to encourage pathway-like groupings when using data-driven (random) initialization.
-
 ### Usage
 
 ```bash
 # With biological initialization (50 MSigDB Hallmarks)
 python -m spatial_transcript_former.train \
     --model interaction --pathway-init ...
-
-# With data-driven pathways + sparsity
-python -m spatial_transcript_former.train \
-    --model interaction --num-pathways 50 --sparsity-lambda 0.01 ...
 ```
 
 - **Spatial Pathway Maps**: Visualize pathway activations as spatial heatmaps overlaid on histology using `stf-predict`. See the [README](../README.md) for inference instructions.
````
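The PCC term of the auxiliary objective, $1 - PCC(S, \cdot)$, reduces to a per-pathway correlation across spots. A minimal numpy sketch, assuming `scores` and `targets` are both `(n_spots, n_pathways)`; `pcc_loss` is an illustrative name (the trained loss in `losses.py` operates on torch tensors):

```python
import numpy as np


def pcc_loss(scores, targets, eps=1e-8):
    """1 - Pearson correlation between model pathway scores S and
    ground-truth activations, computed per pathway across spots and
    averaged over pathways. Illustrative sketch only.
    """
    s = scores - scores.mean(axis=0)
    t = targets - targets.mean(axis=0)
    num = (s * t).sum(axis=0)
    den = np.sqrt((s ** 2).sum(axis=0) * (t ** 2).sum(axis=0)) + eps
    pcc = num / den                      # one correlation per pathway
    return float(1.0 - pcc.mean())      # 0 = perfectly correlated
```

Because PCC is invariant to affine rescaling, this term supervises *where* a pathway is active rather than its absolute magnitude, complementing the magnitude-sensitive gene reconstruction loss.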

docs/TRAINING_GUIDE.md

Lines changed: 12 additions & 34 deletions
````diff
@@ -126,44 +126,22 @@ python -m spatial_transcript_former.train \
 
 > **Note**: `--pathway-init` overrides `--num-pathways` to 50 (the number of Hallmark gene sets). The GMT file is cached in `.cache/` after first download.
 
-### Data-Driven Discovery (Latent Pathways)
+### Recommended: Using Presets
 
-To allow the model to discover its own spatial-transcriptomic relationships without biological priors, omit `--pathway-init` and apply sparsity regularization (`--sparsity-lambda`). This aims to force the model to identify "canonical" sparse gene sets.
+For most cases, it is recommended to use the provided presets:
 
 ```bash
-python -m spatial_transcript_former.train \
-    --data-dir A:\hest_data \
-    --model interaction \
-    --backbone ctranspath \
-    --use-nystrom \
-    --num-pathways 50 \
-    --sparsity-lambda 0.01 \
-    --precomputed \
-    --whole-slide \
-    --use-amp \
-    --log-transform \
-    --epochs 100
-```
-
-> **Note**: Without `--pathway-init`, the model disables the `AuxiliaryPathwayLoss` and relies entirely on the main reconstruction objectives and the L1 sparsity penalty. (I am yet to obtain results with this method)...
+# Tiny (2 layers, 256 dim)
+python scripts/run_preset.py --preset stf_tiny
 
-### Robust Counting: ZINB + Auxiliary Loss
+# Small (4 layers, 384 dim) - Recommended
+python scripts/run_preset.py --preset stf_small
 
-For raw count data with high sparsity, using the ZINB distribution and auxiliary pathway supervision is recommended.
+# Medium (6 layers, 512 dim)
+python scripts/run_preset.py --preset stf_medium
 
-```bash
-python -m spatial_transcript_former.train \
-    --data-dir A:\hest_data \
-    --model interaction \
-    --backbone ctranspath \
-    --pathway-init \
-    --loss zinb \
-    --pathway-loss-weight 0.5 \
-    --lr 5e-5 \
-    --batch-size 4 \
-    --whole-slide \
-    --precomputed \
-    --epochs 200
+# Large (12 layers, 768 dim)
+python scripts/run_preset.py --preset stf_large
 ```
 
 ### Choosing Interaction Modes
@@ -201,7 +179,7 @@ Submit with:
 sbatch hpc/array_train.slurm
 ```
 
-### Collecting Results
+### Collecting Results (Currently broken!)
 
 After experiments complete, aggregate all `results_summary.json` files into a comparison table:
 
@@ -243,8 +221,8 @@ python -m spatial_transcript_former.train --resume --output-dir runs/my_experime
 | `--feature-dir` | Explicit path to precomputed features directory. | Overrides auto-detection. |
 | `--loss` | Loss function: `mse`, `pcc`, `mse_pcc`, `zinb`. | `mse_pcc` or `zinb` recommended. |
 | `--pathway-loss-weight` | Weight ($\lambda$) for auxiliary pathway supervision. | Set `0.5` or `1.0` with `interaction` model. |
-| `--sparsity-lambda` | L1 regularization weight for discovering latent pathways. | Use `0.01` when `--pathway-init` is NOT used. |
 | `--interactions` | Enabled attention quadrants: `p2p`, `p2h`, `h2p`, `h2h`. | Default: `all` (Full Interaction). |
+| `--plot-pathways-list` | Names of explicitly requested pathways to visualize as heatmaps during periodic validation. | Use with `--plot-pathways`. e.g. `HYPOXIA ANGIOGENESIS` |
 | `--log-transform` | Apply log1p to gene expression targets. | Recommended for raw count data. |
 | `--num-genes` | Number of HVGs to predict (default: 1000). | Match your `global_genes.json`. |
 | `--mask-radius` | Euclidean distance for spatial attention gating. | Usually between 200 and 800. |
````
Lines changed: 95 additions & 0 deletions
New file (all lines added):

```python
import os
import argparse

import h5py
import numpy as np


def analyze_sample(h5ad_path):
    print(f"Analyzing {h5ad_path}...")

    with h5py.File(h5ad_path, "r") as f:
        # Check standard AnnData structure
        if "X" in f:
            if isinstance(f["X"], h5py.Group):
                # Sparse format (CSR/CSC)
                data_group = f["X"]["data"][:]
                n_cells = (
                    f["obs"]["_index"].shape[0]
                    if "_index" in f["obs"]
                    else len(f["obs"])
                )
                n_genes = (
                    f["var"]["_index"].shape[0]
                    if "_index" in f["var"]
                    else len(f["var"])
                )

                print(f"Data is sparse, shape: ({n_cells}, {n_genes})")
                print(f"Non-zero elements: {len(data_group)}")

                # Analyze non-zero elements
                mean_val = np.mean(data_group)
                max_val = np.max(data_group)
                min_val = np.min(data_group)

                print(f"Non-zero Mean: {mean_val:.4f}")
                print(f"Max Expression: {max_val:.4f}")
                print(f"Min Expression: {min_val:.4f}")

            else:
                # Dense array
                X = f["X"][:]
                print(f"Data is dense, shape: {X.shape}")

                # Basic stats
                mean_exp = np.mean(X, axis=0)  # per-gene mean
                var_exp = np.var(X, axis=0)  # per-gene variance
                max_exp = np.max(X, axis=0)

                sparsity = np.sum(X == 0) / X.size
                print(f"Overall Sparsity (zeros): {sparsity:.2%}")

                print(
                    f"Gene Mean Range: {np.min(mean_exp):.4f} to {np.max(mean_exp):.4f}"
                )
                print(f"Gene Var Range: {np.min(var_exp):.4f} to {np.max(var_exp):.4f}")
                print(f"Overall Max Expression: {np.max(max_exp):.4f}")

                # Check for extreme differences in variance
                var_ratio = np.max(var_exp) / (np.min(var_exp) + 1e-8)
                print(f"Ratio of max/min gene variance: {var_ratio:.4e}")

                return {
                    "sparsity": sparsity,
                    "var_ratio": var_ratio,
                    "max_exp": np.max(max_exp),
                }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-dir",
        type=str,
        default="A:\\hest_data",
        help="Path to HEST data directory",
    )
    args = parser.parse_args()

    st_dir = os.path.join(args.data_dir, "st")
    if not os.path.exists(st_dir):
        print(f"Error: Directory not found: {st_dir}")
        exit(1)

    # Gather the available .h5ad samples
    samples = [f for f in os.listdir(st_dir) if f.endswith(".h5ad")]
    if not samples:
        print(f"No .h5ad files found in {st_dir}")
        exit(1)  # bail out rather than silently skipping the loop below

    # Analyze the first couple of samples
    for sample in samples[:3]:
        analyze_sample(os.path.join(st_dir, sample))
        print("-" * 50)
```
