From 00488065c6be2b16c3826eabe702f09a9666580a Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Wed, 10 Jun 2026 12:52:21 -0700 Subject: [PATCH 01/11] Restructure combination/: drop deprecated config path, group supplemental + add-on modules The argparse subpackage (`python -m ...combination.pca_optimization`) is now the canonical entry point, so the config/baseline.yml path is removed and the remaining modules are grouped by role. Remove deprecated config/baseline.yml path (no non-test/non-scratch importers): - cli.py, config_handler.py, file_validator.py - combiners.py (the duplicate PcaOptimizationCombiner + deprecated ComprehensiveCombiner) - classifier_combiner.py / classifier_aggregator.py (dormant; never wired into the CLI) Group downstream-only analysis tools into analysis/: - embedding_overlays, compare_map_scores, compare_modalities, pca_component_to_feature, marker_norm_sweep_runner Group optional flag-gated stages into pipeline_add_ons/: - op_signal, chromosome, guide_chrom_arm_correction Update all importers (pca_optimization __init__/handlers/phase2/embeddings and models/attention/embedding/regen_umap_html) to the new paths. pca_sweep_op_signal is still re-exported through pca_optimization's namespace; test_pca_optimization_refactor passes (41/41). Add README.md (how to run the subpackage) and SCRIPT_MAP.md (core vs supplemental inventory). cell_filters.py is retained pending a port-vs-drop decision (no subpackage equivalent yet). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../attention/embedding/regen_umap_html.py | 2 +- .../post_process/combination/README.md | 200 +++ .../post_process/combination/SCRIPT_MAP.md | 170 +++ .../combination/analysis/__init__.py | 0 .../{ => analysis}/compare_map_scores.py | 0 .../{ => analysis}/compare_modalities.py | 0 .../{ => analysis}/embedding_overlays.py | 2 +- .../marker_norm_sweep_runner.py | 0 .../pca_component_to_feature.py | 0 .../combination/classifier_aggregator.py | 1206 ---------------- .../combination/classifier_combiner.py | 219 --- src/ops_model/post_process/combination/cli.py | 279 ---- .../post_process/combination/combiners.py | 1220 ----------------- .../combination/config_handler.py | 242 ---- .../combination/file_validator.py | 155 --- .../combination/pca_optimization/__init__.py | 6 +- .../pca_optimization/embeddings.py | 2 +- .../combination/pca_optimization/handlers.py | 16 +- .../combination/pca_optimization/phase2.py | 4 +- .../combination/pipeline_add_ons/__init__.py | 0 .../chromosome.py | 0 .../guide_chrom_arm_correction.py | 0 .../op_signal.py | 0 23 files changed, 386 insertions(+), 3337 deletions(-) create mode 100644 src/ops_model/post_process/combination/README.md create mode 100644 src/ops_model/post_process/combination/SCRIPT_MAP.md create mode 100644 src/ops_model/post_process/combination/analysis/__init__.py rename src/ops_model/post_process/combination/{ => analysis}/compare_map_scores.py (100%) rename src/ops_model/post_process/combination/{ => analysis}/compare_modalities.py (100%) rename src/ops_model/post_process/combination/{ => analysis}/embedding_overlays.py (99%) rename src/ops_model/post_process/combination/{ => analysis}/marker_norm_sweep_runner.py (100%) rename src/ops_model/post_process/combination/{ => analysis}/pca_component_to_feature.py (100%) delete mode 100644 src/ops_model/post_process/combination/classifier_aggregator.py delete mode 100644 
src/ops_model/post_process/combination/classifier_combiner.py delete mode 100644 src/ops_model/post_process/combination/cli.py delete mode 100644 src/ops_model/post_process/combination/combiners.py delete mode 100644 src/ops_model/post_process/combination/config_handler.py delete mode 100644 src/ops_model/post_process/combination/file_validator.py create mode 100644 src/ops_model/post_process/combination/pipeline_add_ons/__init__.py rename src/ops_model/post_process/combination/{pca_optimization => pipeline_add_ons}/chromosome.py (100%) rename src/ops_model/post_process/combination/{ => pipeline_add_ons}/guide_chrom_arm_correction.py (100%) rename src/ops_model/post_process/combination/{pca_optimization => pipeline_add_ons}/op_signal.py (100%) diff --git a/src/ops_model/models/attention/embedding/regen_umap_html.py b/src/ops_model/models/attention/embedding/regen_umap_html.py index cb463d8..3cfe97f 100644 --- a/src/ops_model/models/attention/embedding/regen_umap_html.py +++ b/src/ops_model/models/attention/embedding/regen_umap_html.py @@ -99,7 +99,7 @@ def process_group(group_dir: Path, from ops_model.post_process.combination.pca_optimization.aggregation import ( _annotate_genes_from_panel, ) - from ops_model.post_process.combination.embedding_overlays import ( + from ops_model.post_process.combination.analysis.embedding_overlays import ( load_overlay_maps, save_interactive_html, ) diff --git a/src/ops_model/post_process/combination/README.md b/src/ops_model/post_process/combination/README.md new file mode 100644 index 0000000..6f5272e --- /dev/null +++ b/src/ops_model/post_process/combination/README.md @@ -0,0 +1,200 @@ +# `combination/` — combining experiments into multi-experiment profiles + +This package combines per-experiment, cell-level feature embeddings (DINO / cell-DINO / +CellProfiler / …) from many OPS experiments into **combined guide- and gene-level AnnData +objects**, scored on phenotypic metrics. 
The supported pipeline is the **`pca_optimization` +subpackage**. + +- **How it works (internals / data flow):** [`pca_optimization_dataflow.md`](pca_optimization_dataflow.md) +- **What every file in this dir is (core vs supplemental):** [`SCRIPT_MAP.md`](SCRIPT_MAP.md) + +> The older config-driven `cli.py` / `baseline.yml` path (`run_combination`, +> `PcaOptimizationCombiner`, `ComprehensiveCombiner`, …) has been **removed**. Use the +> argparse subpackage below. (History: `SCRIPT_MAP.md` §3.) + +--- + +## What it produces + +A two-phase pipeline: + +- **Phase 1** — one job per biological signal group: pool cells across experiments that share + that signal → (optional z-score / downsample) → fit PCA → pick `n_pcs` (variance sweep or + fixed cutoff) → aggregate to guide/gene → save a per-signal h5ad. +- **Phase 2** — one aggregation job: load the per-signal h5ads, horizontally concatenate, NTC- + normalize, aggregate to gene level, score metrics (activity / distinctiveness / CORUM / CHAD / + EBI), compute UMAP+PHATE, and write the canonical outputs. + +Canonical Phase 2 outputs (under the resolved output dir, see [Output layout](#output-layout)): +- `guide_pca_optimized.h5ad` — combined guide-level profiles +- `gene_embedding_pca_optimized.h5ad` — combined gene-level profiles + UMAP/PHATE +- `pca_report.csv`, `metrics/`, `plots/`, and (default) `second_pca_consensus/` + +--- + +## Entry point + +```bash +python -m ops_model.post_process.combination.pca_optimization [FLAGS] +``` + +Full flag reference (60+ options): + +```bash +python -m ops_model.post_process.combination.pca_optimization --help +``` + +**Exactly one feature-mode flag is required** (no implicit default): +`--cell-dino` · `--dino` · `--cell-profiler` · `--dynaclr` · `--subcell` · `--organelle-profiler`. 
+ +--- + +## Inputs + +For the discovery-based feature modes, the pipeline scans the standard storage roots for each +experiment's cell-level h5ads at: + +``` +<storage_root>/<experiment>/3-assembly/<feature_dir>/anndata_objects/features_processed_<channel>.h5ad +``` + +`<feature_dir>` per mode: `cell_dino_features` (`--cell-dino`), `dino_features` (`--dino`), +`cell-profiler` (`--cell-profiler`), `dynaclr_features`, `subcell_features`. Channels are mapped +to biological-signal groups via the channel maps +(`/hpc/projects/icd.fast.ops/configs/ops_channel_maps.yaml`). +`--organelle-profiler` instead reads consolidated `all_cells_*.h5ad` files from `--op-root`. + +Restrict the experiment set with `--experiments ops0100,ops0105,…`, or `--paper-v1` +(the curated `good_experiment_list_v1.yml`). **Always start with `--dry-run`** to print the +discovered signal-group manifest without processing. + +--- + +## Quick start + +**0. Dry run** — see what would be processed (no compute): +```bash +python -m ops_model.post_process.combination.pca_optimization \ + --cell-dino --phase-only \ + --experiments ops0100,ops0105,ops0117,ops0119,ops0120 \ + --dry-run +``` + +**1. Local run** (no SLURM) — small/quick combine, fixed PCA threshold: +```bash +python -m ops_model.post_process.combination.pca_optimization \ + --cell-dino --phase-only \ + --experiments ops0100,ops0105,ops0117,ops0119,ops0120 \ + --fixed-threshold 0.80 \ + --output-dir /hpc/projects/icd.fast.ops/experiments/<user>/combine_test \ + -y +``` +Omitting `--slurm` runs Phase 1 + Phase 2 in-process. + +**2. SLURM run** — production combine (one Phase-1 job per signal + one Phase-2 job): +```bash +python -m ops_model.post_process.combination.pca_optimization \ + --cell-dino --zscore-per-experiment \ + --paper-v1 \ + --slurm +``` + +**3. 
Validation cohort** (4 experiments, Phase-only, custom CHAD file) — from the module docstring: +```bash +python -m ops_model.post_process.combination.pca_optimization \ + --output-dir /hpc/projects/icd.fast.ops/organelle_attribution/pca_optimized_v0.3 \ + --cell-dino --zscore-per-experiment \ + --run-tag paper_v1/validation_4exp_phase_only \ + --experiments ops0146,ops0147,ops0150,ops0151 \ + --phase-only \ + --chad-annotation /hpc/projects/icd.fast.ops/configs/gene_clusters/val_library_chad_positive_controls_v1.yml \ + --slurm +``` + +--- + +## Key flags + +**Feature mode (pick one, required):** `--cell-dino` `--dino` `--cell-profiler` `--dynaclr` +`--subcell` `--organelle-profiler` (`--op-root `). + +**Channel subset:** `--phase-only` (brightfield only) · `--no-phase` (fluorescent only) · +default = all. Sibling layouts: `--with-cp` / `--with-4i` / `--only-cp` / `--only-4i` / +`--include-cellpainting`. + +**PCA threshold:** `--fixed-threshold 0.80` (default; single cutoff) · `--fixed-threshold 0` +(run the full **consensus variance sweep** instead). CP features default to a lower sweep range. + +**Downsampling:** `--downsampled` (equalize cells across signal groups, floor 750k) · +`--target-cells N` (force exact count) · `--downsample-per-guide --cells-per-guide 250`. + +**Normalization:** `--norm-method ntc|global` (default `ntc`) · +`--zscore-per-experiment` / `--no-zscore-per-experiment` (default on). + +**SLURM:** `--slurm` to dispatch; tune with `--slurm-memory` `--slurm-time` `--slurm-cpus` +`--slurm-partition` `--phase-memory` `--slurm-agg-memory` `--slurm-agg-time`. Omit `--slurm` +to run locally. + +**Embeddings / reproducibility:** `--seed` (default 1 for `--umap-type max`, 42 for `gav`) · +`--umap-type max|gav` · `--distance cosine|euclidean`. + +**Second-pass PCA** (on by default): `--no-second-pca` to disable · `--second-pca-threshold` +(0 = consensus sweep) · `--second-pca-consensus-metrics activity,distinctiveness,ebi`. 
+ +**Chromosome-arm correction** (optional): `--chrom-arm-correct` (+ `--chrom-arm-method`, +`--chrom-arm-knn`, `--chrom-arm-qval`, `--chrom-arm-map-csv`). + +**Experiment selection:** `--experiments ops0100,ops0105` · `--paper-v1` · `--signals "Phase,ER_SEC61B"` +(retry specific signal shards) · `--run-tag ` (organizational subfolder). + +**Misc:** `-y/--yes` (skip confirmation) · `--direct` (use `--output-dir` verbatim, skip auto-nesting) · +`--clean` (wipe prior Phase-1 outputs first) · `--exclude-dud-guides` (default on). + +--- + +## Re-running phases cheaply (skip the expensive PCA sweep) + +After a full run you can regenerate downstream artifacts from the on-disk per-signal/combined +h5ads without redoing Phase 1: + +| Flag | Reuses | Recomputes | +|---|---|---| +| `--aggregate-only` | per-signal h5ads | Phase 2: concat → normalize → score → embed | +| `--second-pca-only` | `guide_pca_optimized.h5ad` | second-pass PCA only | +| `--umap-only` | combined h5ads | UMAP/PHATE + embedding plots | +| `--overlays-only` | combined h5ads | interactive HTML overlays (refits UMAP only if `--seed` differs) | +| `--chad-umap-only` | `gene_embedding_pca_optimized.h5ad` | CHAD-colored UMAP | +| `--sweep-seed` | gene h5ad | a grid PNG of UMAP layouts across seeds | + +Pass the **same flags** that define the output path (feature mode, channel subset, threshold, +distance, …) so the tool resolves to the same directory — see below. 
+ +--- + +## Output layout + +The output path is auto-nested from the flags (use `--direct` to bypass): + +``` +<output_dir>/<feature_type>/[zscore_per_exp/][paper_v1/][<run_tag>/]<channel_subset>/<pca_mode>/<distance>/[agg_<suffix>/] +``` + +- `<feature_type>`: `cell_dino` · `dino` · `cellprofiler` · `dynaclr` · `subcell` · `organelle_profiler` +- `<channel_subset>`: `all_livecell` (default) · `phase_only` · `no_phase` · `*_downsampled` · `with_cp` … +- `<pca_mode>`: `fixed_80%` (fixed threshold) · `consensus_sweep` · `batch` (`--preserve-batch`) · `no_pca` +- `<distance>`: `cosine` (default) · `euclidean` + +Inside the resolved dir: +- `per_channel/` (standard) or `per_signal/` (downsampled) — Phase 1 per-signal h5ads + sweep CSVs +- `guide_pca_optimized.h5ad`, `gene_embedding_pca_optimized.h5ad` — Phase 2 combined outputs +- `pca_report.csv`, `metrics/`, `plots/`, `second_pca_consensus/` + +--- + +## Supplemental tools (run separately, not part of a combine) + +Post-hoc analysis/visualization tools that consume the combined outputs, in the `analysis/` +subpackage (run via `python -m ops_model.post_process.combination.analysis.<module>`) — see +`SCRIPT_MAP.md`: `analysis/compare_map_scores.py`, `analysis/compare_modalities.py`, +`analysis/pca_component_to_feature.py`, `analysis/embedding_overlays.py`, plus the +`titration/` and `hand_annotations/` analyses. diff --git a/src/ops_model/post_process/combination/SCRIPT_MAP.md b/src/ops_model/post_process/combination/SCRIPT_MAP.md new file mode 100644 index 0000000..6b3fa38 --- /dev/null +++ b/src/ops_model/post_process/combination/SCRIPT_MAP.md @@ -0,0 +1,170 @@ +# `combination/` — core pipeline vs supplemental scripts + +Inventory of every `.py` in `ops_model/src/ops_model/post_process/combination/` (read in +full), classified as **core combination pipeline** (produces the combined multi-experiment +guide/gene h5ads), **deprecated**, or **supplemental** (optional stages, downstream +analysis/plotting, one-off tooling). 34 modules, ~27k LOC. 
+ +> **Decision (canonical going forward):** the **argparse subpackage** +> `python -m ops_model.post_process.combination.pca_optimization` is the single supported +> entry point. **Everything tied only to the config/`baseline.yml` path is deprecated** — +> see §3. + +> Line numbers/sizes are a snapshot; anchor on names. + +--- + +## 0. The two implementations — one kept, one deprecated + +There were **two independent implementations** of the same two-phase pca_optimized pipeline. +We are standardizing on the subpackage and retiring the config-driven path. + +| Path | Entry point | Modules | Status | +|---|---|---|---| +| **Argparse subpackage** | `python -m …combination.pca_optimization …` | `pca_optimization/` (`__init__`, `handlers`, `phase1`, `phase2`, `sweep_core`, `aggregation`, `slurm`, `parser`, …) | **CANONICAL** | +| **Config-driven CLI** | `python -m …combination.cli --config baseline.yml` → `run_combination` | `cli.py`, `config_handler.py`, `combiners.py`, `file_validator.py`, `cell_filters.py` | **DEPRECATED** | + +Verified: the subpackage is **fully independent** of the config-path modules (no imports of +`config_handler` / `combiners` / `cell_filters` / `file_validator`). Both paths ultimately +call the shared aggregation primitives in `ops_model.features.anndata_utils` +(`aggregate_to_level`, `hconcat_by_perturbation`, `normalize_guide_adata`). + +--- + +## 1. CORE — the canonical pipeline (`pca_optimization/`) + +Two-phase flow: pool cells per biological signal → fit PCA → sweep n_pcs → aggregate to +guide/gene → NTC-normalize → save guide/gene h5ads → (optional) score + embed. + +| File | LOC | Entry | Role | +|---|---|---|---| +| `pca_optimization/__init__.py` | 501 | `main` | Orchestration hub; parses args, discovers experiments, dispatches handlers. | +| `pca_optimization/__main__.py` | 13 | `python -m …` shim | Re-exports `main`. | +| `pca_optimization/parser.py` | 641 | lib | `_build_parser` — the 60+ CLI flags (the real config surface now). 
| +| `pca_optimization/slurm.py` | 240 | lib | Phase 1/2 submitit job submission + chaining. | +| `pca_optimization/handlers.py` | 1961 | lib | One handler per CLI mode (standard/downsampled, aggregate-only, second-pca, op, umap-only…). Decides which stages run. | +| `pca_optimization/phase1.py` | 571 | lib (SLURM worker) | `pca_sweep_pooled_signal` — Phase 1: pool→PCA→threshold sweep per signal. | +| `pca_optimization/phase2.py` | 1101 | lib (SLURM worker) | `aggregate_channels` — Phase 2: concat per-signal → normalize → score → embed → save; `apply_second_pass_pca` (optional Phase 3). | +| `pca_optimization/sweep_core.py` | 755 | lib | Threshold-sweep scoring + per-signal h5ad/CSV writing (incl. `--no-pca` raw save). | +| `pca_optimization/aggregation.py` | 580 | lib | Phase 2 primitives: hconcat, NTC-normalize, gene re-agg, **`_atomic_write_h5ad` (obs-sanitizing save)**, panel annotation, canonical h5ad write. | + +Shared dependency (outside this dir): `ops_model.features.anndata_utils` aggregation helpers. + +--- + +## 2. CORE-OPTIONAL — subpackage stages that run only under specific flags +| File | LOC | Role | When it runs | +|---|---|---|---| +| `pca_optimization/embeddings.py` | 478 | UMAP/PHATE + metric overlays + distinctiveness/consistency scoring. | Standard Phase 2; also standalone via `--umap-only`. | +Optional add-on stages, moved into `pipeline_add_ons/` on `alexhillsley/refactor` +(`pca_sweep_op_signal` is still re-exported through `pca_optimization`'s namespace): +| File | LOC | Role | When it runs | +|---|---|---|---| +| `pipeline_add_ons/op_signal.py` | 395 | `pca_sweep_op_signal` — Phase 1 variant reading OrganelleProfiler `all_cells_*.h5ad`. | Only `--organelle-profiler` mode. | +| `pipeline_add_ons/chromosome.py` | 308 | Chromosome-arm overlay plots (PNG/SVG/HTML) on gene embeddings. | Only with `--chromosome-csv` or active chrom-arm correction. 
| +| `pipeline_add_ons/guide_chrom_arm_correction.py` | 740 | Removes chromosome-arm clustering artifacts from guide PCA embeddings (3 strategies). | Optional post-processing; called from subpackage `handlers`/`embedding_overlays` when configured. | + +--- + +## 3. DEPRECATED — the config/`baseline.yml` path + +**REMOVED** on branch `alexhillsley/refactor` (verified: no non-test/non-scratch code imported +them). Deleted: `cli.py`, `combiners.py`, `config_handler.py`, `file_validator.py`, +`classifier_combiner.py`, `classifier_aggregator.py` (+ tests `test_classifier_combiner.py`, +`test_classifier_aggregator.py`, `test_combination_e2e.py`). +**Held:** `cell_filters.py` (pending the port-vs-drop decision — see §Migration). + +The table below is retained as a record of what was removed and why. + +| File | LOC | Role | Notes | +|---|---|---|---| +| `cli.py` | 286 | `run_combination` dispatch + `validate_and_save`. | Whole module deprecated. (The recent obs-sanitization fix here is moot — `aggregation._atomic_write_h5ad` already does it on the canonical path.) | +| `config_handler.py` | 242 | `CombinationConfig` + `load_config(yaml)`. | The `baseline.yml`/`CombinationConfig` schema retires with the CLI. Verify external `load_config` users first (§Migration). | +| `combiners.py` | 1220 | `PcaOptimizationCombiner` (duplicate Phase 1/2), the deprecated `ComprehensiveCombiner`, and `_process_signal_group`/`_sweep_pca_thresholds`/`_prepare_cells_for_scoring`. | The whole file is config-path-only. Removing it deletes the **duplicate** pca_optimized implementation and the last in-repo caller of `anndata_utils.concatenate_experiments_comprehensive`. | +| `file_validator.py` | 155 | Input-file validation for `comprehensive`/`vertical`. | Skipped on the pca path; only the deprecated methods use it. | +| `cell_filters.py` | 272 | `build_cell_filter` + `DudGuide`/`TopPhenotype`/`Composed` filters. 
| Config-path-only — **but the subpackage has no cell-filtering equivalent.** See §Migration: port or consciously drop. | + +Also retiring with this path (alternative methods that were never the canonical flow): +| File | LOC | Role | Notes | +|---|---|---|---| +| `combiners.ComprehensiveCombiner` (in `combiners.py`) | — | `comprehensive` method → `anndata_utils.concatenate_experiments_comprehensive`. | Deprecated; see features `REFACTOR_PLAN_1`. | +| `classifier_combiner.py` | 219 | `ClassifierCombiner` — MLP-classifier aggregation. | **Never wired into `cli`**; depends on `CombinationConfig`. Dormant → retire. | +| `classifier_aggregator.py` | 1206 | MLP/CosineClassifier training machinery for the above. | Only via `ClassifierCombiner`. Dormant → retire. | + +--- + +## 4. SUPPLEMENTAL — downstream analysis / plotting (`analysis/` subpackage; consume combined outputs, not in the core path) + +Moved into `analysis/` on branch `alexhillsley/refactor`. Run as +`python -m ops_model.post_process.combination.analysis.`. + +| File | LOC | Entry | Role | +|---|---|---|---| +| `analysis/embedding_overlays.py` | 3080 | lib (`save_extra_overlays`) | Static/interactive UMAP overlays, Leiden + GO enrichment, super-category/CHAD/CORUM/EBI annotation. Invoked (lazily) by subpackage Phase 2 for *extra* plots; pure visualization. | +| `analysis/compare_map_scores.py` | 781 | `main` | Compare mAP metric CSVs across conditions (phase vs no-phase, DINO vs CP). | +| `analysis/compare_modalities.py` | 868 | `main` | Cross-modality distinctiveness at fixed cell budget (cp/4i/livecell set comparisons). | +| `analysis/pca_component_to_feature.py` | 680 | `main` | Interpret PCA loadings → CellProfiler feature categories. | + +--- + +## 5. 
SUPPLEMENTAL — `titration/` (cell-count titration analyses; all standalone, none in the core path) +| File | LOC | Entry | Role | +|---|---|---|---| +| `titration/titration.py` | 2262 | `main` | Per-reporter cell-count titration; score 4 metrics vs cell budget. | +| `titration/combined_titration.py` | 2316 | `main` | Multi-marker group titration (hconcat panels at each budget). | +| `titration/titration_reporter_pair.py` | 1040 | `main` | Two-reporter (e.g. Phase+fluor) titration with optional dual-PCA. | +| `titration/titration_phase_paired_fluor.py` | 572 | `main` | Phase + each fluor marker across budgets. | +| `titration/titration_phase_paired_dual_fluor.py` | 571 | `main` | Phase + two channel-disjoint fluor markers. | +| `titration/titration_paired_plots.py` | 559 | lib | Shared plot helpers for the two phase-paired drivers above. | + +(Note: several titration modules import `anndata_utils` aggregation helpers and the +subpackage's sweep/aggregation — they're independent of the deprecated config path.) + +--- + +## 6. SUPPLEMENTAL — `hand_annotations/` (manual curation / presentation; standalone) +| File | LOC | Entry | Role | +|---|---|---|---| +| `hand_annotations/embedding_param_sweep.py` | 754 | `main` | UMAP/PHATE parameter grid sweep on an existing run. | +| `hand_annotations/hand_annotated_umap_animation.py` | 1362 | `main` | GIF walking through hand-annotated gene clusters on the gene UMAP. | +| `hand_annotations/resolve_cluster_picks.py` | 329 | `main` | Resolve representative gene/channel per hand-annotated cluster from attention CSVs. | + +--- + +## 7. SUPPLEMENTAL — misc tooling (in `analysis/`) +| File | LOC | Entry | Role | +|---|---|---|---| +| `analysis/marker_norm_sweep_runner.py` | 37 | lib | Submitit wrapper subprocessing an external marker-normalization sweep script. Independent of the pipeline. Moved into `analysis/` on `alexhillsley/refactor`. 
| + +--- + +## Migration impact of deprecating the config/baseline path + +Status after removal on `alexhillsley/refactor`: + +- **Tests** `test_classifier_combiner.py` / `test_classifier_aggregator.py` / `test_combination_e2e.py` + — **deleted** with the modules they covered. + - ⚠️ **Follow-up:** `test_combination_e2e.py` was our only end-to-end combination test. It + needs **re-adding against the canonical subpackage** (drive Phase 1 + `pca_sweep_pooled_signal` + Phase 2 `aggregate_channels`, or `main` with argparse flags). + The subpackage's `_atomic_write_h5ad` already sanitizes obs, so no `cli`-side fix is needed. +- **External `load_config` users** — false alarms, no action: `project_shared_umap.py` defines + its **own** local `load_config`; the organelle stage uses its **own** `_load_config()`; the + energy_distance reference is a comment. None imported the deleted `config_handler`. +- **`experiments/scratch/20260414_debugging/pca_optimized_with_filter.py`** — scratch; imports + the now-deleted `combiners` (`PcaOptimizationCombiner`, `_sweep_pca_thresholds`) and + `config_handler`. **Now broken** — left as-is (scratch); update or delete when convenient. + +**Open decision (blocking `cell_filters.py` removal):** +- **Cell filtering** — `cell_filters.py` (dud-guide / top-phenotype filters) has no subpackage + equivalent. **Held, not deleted.** Port it into `pca_optimization` Phase 1, or accept + dropping the feature? It currently has zero importers (its only caller, `combiners.py`, is + gone), so it's dead-but-harmless until decided. + +## Takeaways +- **The canonical core is just `pca_optimization/`** (§1) + the `anndata_utils` aggregation + primitives. §2 are flag-gated stages of it. +- Deprecating the config path (§3) removes ~2.2k LOC (`cli`+`config_handler`+`combiners`+ + `file_validator`+`cell_filters`; of these, `cell_filters` is still held pending the open decision above) plus the dormant classifier (~1.4k LOC), and eliminates the + **duplicate pca_optimized implementation** — the single biggest hazard in this dir. 
+- §4–§7 are legitimate standalone analysis tools — supplemental, low-coupling, keep as-is. diff --git a/src/ops_model/post_process/combination/analysis/__init__.py b/src/ops_model/post_process/combination/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ops_model/post_process/combination/compare_map_scores.py b/src/ops_model/post_process/combination/analysis/compare_map_scores.py similarity index 100% rename from src/ops_model/post_process/combination/compare_map_scores.py rename to src/ops_model/post_process/combination/analysis/compare_map_scores.py diff --git a/src/ops_model/post_process/combination/compare_modalities.py b/src/ops_model/post_process/combination/analysis/compare_modalities.py similarity index 100% rename from src/ops_model/post_process/combination/compare_modalities.py rename to src/ops_model/post_process/combination/analysis/compare_modalities.py diff --git a/src/ops_model/post_process/combination/embedding_overlays.py b/src/ops_model/post_process/combination/analysis/embedding_overlays.py similarity index 99% rename from src/ops_model/post_process/combination/embedding_overlays.py rename to src/ops_model/post_process/combination/analysis/embedding_overlays.py index 6049b07..d41ab9d 100644 --- a/src/ops_model/post_process/combination/embedding_overlays.py +++ b/src/ops_model/post_process/combination/analysis/embedding_overlays.py @@ -377,7 +377,7 @@ def _annotate_chrom_arm(adata, _logger=logger) -> None: if "chrom_arm" in adata.obs.columns: return try: - from ops_model.post_process.combination.guide_chrom_arm_correction import ( + from ops_model.post_process.combination.pipeline_add_ons.guide_chrom_arm_correction import ( SHARED_MAP_CSV_PATH, _load_symbol_to_arm_from_csv, ) diff --git a/src/ops_model/post_process/combination/marker_norm_sweep_runner.py b/src/ops_model/post_process/combination/analysis/marker_norm_sweep_runner.py similarity index 100% rename from 
src/ops_model/post_process/combination/marker_norm_sweep_runner.py rename to src/ops_model/post_process/combination/analysis/marker_norm_sweep_runner.py diff --git a/src/ops_model/post_process/combination/pca_component_to_feature.py b/src/ops_model/post_process/combination/analysis/pca_component_to_feature.py similarity index 100% rename from src/ops_model/post_process/combination/pca_component_to_feature.py rename to src/ops_model/post_process/combination/analysis/pca_component_to_feature.py diff --git a/src/ops_model/post_process/combination/classifier_aggregator.py b/src/ops_model/post_process/combination/classifier_aggregator.py deleted file mode 100644 index 4233379..0000000 --- a/src/ops_model/post_process/combination/classifier_aggregator.py +++ /dev/null @@ -1,1206 +0,0 @@ -"""Classifier-based aggregator for OPS embeddings. - -Replaces mean-pooling aggregation with a learned MLP that predicts perturbation -identity from multi-reporter embeddings. The penultimate-layer representations -become the new gene-level aggregated embeddings. - -See classifier_aggregator_plan.md (in ops_model/eval/) for full design details. 
-""" - -from __future__ import annotations - -import math -import time -import warnings -from collections import Counter, defaultdict -from pathlib import Path - -import anndata as ad -import h5py -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -from numpy.random import SeedSequence, default_rng -from torch.utils.data import DataLoader, Dataset - -from ops_model.features.anndata_utils import DEFAULT_GUIDE_COL, _guide_col -from ops_model.post_process.anndata_processing.anndata_validator import AnndataValidator - - -# --------------------------------------------------------------------------- -# Model -# --------------------------------------------------------------------------- - - -class CosineClassifier(nn.Module): - """L2-normalised linear head with learnable temperature scale.""" - - def __init__( - self, - in_dim: int, - num_classes: int, - init_scale: float = 20.0, - learn_scale: bool = True, - ): - super().__init__() - self.weight = nn.Parameter(torch.randn(num_classes, in_dim)) - nn.init.normal_(self.weight, std=0.01) - if learn_scale: - self.log_scale = nn.Parameter(torch.tensor(math.log(init_scale))) - else: - self.register_buffer("log_scale", torch.tensor(math.log(init_scale))) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = F.normalize(x, dim=1) - w = F.normalize(self.weight, dim=1) - return torch.exp(self.log_scale) * (x @ w.t()) - - -class MLP(nn.Module): - """MLP classifier with a split backbone / head for penultimate-layer extraction. - - Parameters - ---------- - input_dim: - Dimensionality of the concatenated multi-reporter input. - num_classes: - Number of perturbation classes. - hidden_dims: - Width of each hidden layer. - dropout: - Dropout rate applied after each hidden layer. - batch_norm: - If ``True`` (default), include ``BatchNorm1d`` after each linear layer. 
- cosine_classifier: - If ``True`` (default), use a cosine-similarity head instead of a plain - linear layer. - - Attributes - ---------- - backbone : nn.Sequential - All layers up to and including the last hidden block (linear + BN + ReLU - + dropout). Running ``backbone(x)`` returns the penultimate representation. - head : nn.Module - The final classifier layer (``CosineClassifier`` or ``nn.Linear``). - """ - - def __init__( - self, - input_dim: int, - num_classes: int, - hidden_dims: tuple[int, ...] = (512, 512, 512), - dropout: float = 0.4, - batch_norm: bool = True, - cosine_classifier: bool = True, - ): - super().__init__() - - backbone_layers: list[nn.Module] = [] - prev_dim = input_dim - for h in hidden_dims: - backbone_layers.append(nn.Linear(prev_dim, h)) - if batch_norm: - backbone_layers.append(nn.BatchNorm1d(h)) - backbone_layers.append(nn.ReLU()) - backbone_layers.append(nn.Dropout(dropout)) - prev_dim = h - - self.backbone = nn.Sequential(*backbone_layers) - self.head: nn.Module = ( - CosineClassifier(prev_dim, num_classes) - if cosine_classifier - else nn.Linear(prev_dim, num_classes) - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.head(self.backbone(x)) - - -# --------------------------------------------------------------------------- -# Training helpers -# --------------------------------------------------------------------------- - - -def _topk_accuracy(logits: torch.Tensor, labels: torch.Tensor, k: int = 5) -> int: - """Number of samples whose true label is in the top-k predictions.""" - k = min(k, logits.size(1)) - topk_preds = logits.topk(k, dim=1).indices - return topk_preds.eq(labels.unsqueeze(1)).any(dim=1).sum().item() # type: ignore[return-value] - - -@torch.no_grad() -def _evaluate( - model: MLP, - loader: DataLoader, - criterion: nn.Module, - device: torch.device, -) -> tuple[float, float, float]: - """Return ``(avg_loss, top1_acc, top5_acc)`` on the given loader.""" - model.eval() - total_loss = 0.0 - 
def _train_loop(
    model: MLP,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int,
    learning_rate: float,
    weight_decay: float,
    device: torch.device,
    wandb_run=None,
) -> None:
    """Train ``model`` with AdamW + cross-entropy, printing one table row per epoch.

    After every epoch the model is scored on ``val_loader`` via ``_evaluate``
    and a row with train/val loss, top-1 and top-5 accuracy plus the elapsed
    time is printed.

    Parameters
    ----------
    wandb_run:
        Object exposing ``.log(dict, step=int)`` (e.g. the value returned by
        ``wandb.init()``). When not ``None`` the per-epoch metrics are also
        sent to it; ``None`` (default) disables this.
    """
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    header = (
        f"{'Epoch':>5} | {'Train Loss':>10} | {'Train Top1':>10} "
        f"| {'Train Top5':>10} | {'Val Loss':>10} | {'Val Top1':>10} "
        f"| {'Val Top5':>10} | {'Time':>6}"
    )
    print(header)
    print("-" * len(header))

    epoch = 0
    while epoch < num_epochs:
        epoch += 1  # epochs are reported 1-based
        epoch_began = time.time()
        model.train()

        loss_sum = 0.0
        hits_top1 = 0
        hits_top5 = 0
        seen = 0

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            opt.zero_grad()
            scores = model(batch_x)
            batch_loss = loss_fn(scores, batch_y)
            batch_loss.backward()
            opt.step()

            n_in_batch = len(batch_y)
            # criterion returns the batch mean, so weight it by the batch size
            loss_sum += batch_loss.item() * n_in_batch
            hits_top1 += (scores.argmax(dim=1) == batch_y).sum().item()
            hits_top5 += _topk_accuracy(scores, batch_y, k=5)
            seen += n_in_batch

        train_loss = loss_sum / seen
        train_acc = hits_top1 / seen
        train_acc5 = hits_top5 / seen

        val_loss, val_acc, val_acc5 = _evaluate(model, val_loader, loss_fn, device)

        elapsed = time.time() - epoch_began
        print(
            f"{epoch:5d} | {train_loss:10.4f} | {train_acc:10.4%} "
            f"| {train_acc5:10.4%} | {val_loss:10.4f} | {val_acc:10.4%} "
            f"| {val_acc5:10.4%} | {elapsed:5.1f}s"
        )

        if wandb_run is None:
            continue
        epoch_metrics = {
            "train/loss": train_loss,
            "train/top1_acc": train_acc,
            "train/top5_acc": train_acc5,
            "val/loss": val_loss,
            "val/top1_acc": val_acc,
            "val/top5_acc": val_acc5,
        }
        wandb_run.log(epoch_metrics, step=epoch)

    print(" Training complete.")
- - Reads each h5ad file exactly once and discards raw cell embeddings immediately - after processing, keeping peak memory bounded to roughly one file at a time - plus the growing views array. - - For each reporter, ``n_views`` is split evenly across the files that contribute - to it. Each file computes its budget of views by sampling ``n_cells`` cells - uniformly at random and averaging them, independently for each view. - - Parameters - ---------- - h5ad_paths: - One path per (experiment, channel) pair, in processing order. - reporter_for_path: - Parallel list mapping each path to its biological signal / reporter name. - reporters: - Ordered list of unique reporter names. Defines axis-1 of the output. - perturbations: - Ordered list of unique perturbation names. Defines axis-0 of the output. - n_cells: - Number of cells averaged per view. - n_views: - Views to pre-compute per reporter per perturbation. Each reporter - gets this many views independently; if multiple files contribute to - one reporter their budgets are summed to reach this total. - seed: - Base random seed for reproducibility. - - Returns - ------- - views : np.ndarray, shape (n_perturbations, n_reporters, n_views, embedding_dim) - Pre-computed averaged embeddings. Unwritten slots (perturbation absent - from a file) are zero and excluded by ``n_valid``. - n_valid : np.ndarray, shape (n_perturbations, n_reporters), dtype int32 - Number of valid (written) views per (perturbation, reporter) slot. - Always in ``[0, n_views]``. 
- """ - n_p = len(perturbations) - n_r = len(reporters) - pert_to_idx = {p: i for i, p in enumerate(perturbations)} - rep_to_idx = {r: i for i, r in enumerate(reporters)} - - # --- Per-reporter file lists and view budgets ---------------------------- - # reporter_files[r_idx] = list of (file_idx, budget) in processing order - reporter_file_lists: list[list[int]] = [[] for _ in range(n_r)] - for file_idx, reporter in enumerate(reporter_for_path): - r_idx = rep_to_idx[reporter] - reporter_file_lists[r_idx].append(file_idx) - - # Split n_views evenly across files per reporter; earlier files get +1 if - # n_views is not divisible by n_files. - file_view_budget: list[int] = [0] * len(h5ad_paths) - for r_idx, file_indices in enumerate(reporter_file_lists): - if not file_indices: - continue - splits = np.array_split(range(n_views), len(file_indices)) - for file_idx, split in zip(file_indices, splits): - file_view_budget[file_idx] = len(split) - - # --- Discover embedding dim from first file ------------------------------ - embedding_dim = _peek_embedding_dim(h5ad_paths[0]) - print( - f"\n precompute_views: {n_p} perturbations × {n_r} reporters × " - f"{n_views} views, embedding_dim={embedding_dim}" - ) - print( - f" Allocating views array: " - f"{n_p * n_r * n_views * embedding_dim * 4 / 1e9:.2f} GB" - ) - - views = np.zeros((n_p, n_r, n_views, embedding_dim), dtype=np.float32) - n_valid = np.zeros((n_p, n_r), dtype=np.int32) - - # --- Independent RNG per file via SeedSequence -------------------------- - ss = SeedSequence(seed) - file_rngs = [default_rng(child) for child in ss.spawn(len(h5ad_paths))] - - # --- Process files ------------------------------------------------------- - for file_idx, path in enumerate(h5ad_paths): - reporter = reporter_for_path[file_idx] - r_idx = rep_to_idx[reporter] - budget = file_view_budget[file_idx] - rng = file_rngs[file_idx] - - print( - f" [{file_idx + 1}/{len(h5ad_paths)}] {path.name}" - f" reporter={reporter!r} budget={budget} 
views" - ) - - adata = ad.read_h5ad(path) - X = adata.X - if hasattr(X, "toarray"): - X = X.toarray() - X = np.asarray(X, dtype=np.float32) - obs_perturbations: np.ndarray = adata.obs["perturbation"].to_numpy() - del adata # free raw cells immediately - - # Group row indices by perturbation - cells_by_pert: dict[int, np.ndarray] = {} - for row_idx, pert in enumerate(obs_perturbations): - p_idx = pert_to_idx.get(pert) - if p_idx is None: - continue # perturbation not in our master list — skip - if p_idx not in cells_by_pert: - cells_by_pert[p_idx] = [] - cells_by_pert[p_idx].append(row_idx) - - n_written_this_file = 0 - for p_idx, row_list in cells_by_pert.items(): - already = int(n_valid[p_idx, r_idx]) - n_to_write = min(budget, n_views - already) - if n_to_write <= 0: - continue # view budget for this (perturbation, reporter) is full - - cell_indices = np.array(row_list, dtype=np.int32) - n_available = len(cell_indices) - replace = n_available < n_cells - if replace: - warnings.warn( - f"precompute_views: perturbation {perturbations[p_idx]!r}, " - f"reporter {reporter!r} has only {n_available} cells " - f"(need {n_cells}). Sampling with replacement.", - stacklevel=2, - ) - - for k in range(n_to_write): - sampled = rng.choice(cell_indices, size=n_cells, replace=replace) - views[p_idx, r_idx, already + k] = X[sampled].mean(axis=0) - - n_valid[p_idx, r_idx] += n_to_write - n_written_this_file += n_to_write - - n_perts_this_file = len(cells_by_pert) - del X, obs_perturbations, cells_by_pert - print( - f" wrote {n_written_this_file:,} views across {n_perts_this_file} perturbations" - ) - - # --- Summary ------------------------------------------------------------- - mean_valid = float(n_valid.mean()) - min_valid = int(n_valid.min()) - missing = int((n_valid == 0).sum()) - print( - f"\n precompute_views complete." 
def precompute_inference_views(
    h5ad_paths: list[Path],
    reporter_for_path: list[str],
    reporters: list[str],
    perturbations: list[str],
    n_cells: int,
    seed: int = 0,
    obs_col: str = "perturbation",
) -> tuple[list[np.ndarray], np.ndarray, np.ndarray]:
    """Pre-compute non-overlapping averaged cell views for inference.

    Unlike :func:`precompute_views`, this function uses **all available cells**
    with no per-slot view-count cap. For each reporter, all contributing files
    are pooled into a single cell matrix before views are created, so cells
    from different files for the same ``(perturbation, reporter)`` slot share
    one pool. The pool is shuffled once per slot and sliced into consecutive
    non-overlapping chunks of ``n_cells``.

    If a slot has fewer than ``n_cells`` cells across all files, all available
    cells are averaged into one view.

    Each reporter receives its own views array sized to its own maximum view
    count, which avoids padding every reporter to a global maximum.

    Parameters
    ----------
    h5ad_paths, reporter_for_path, reporters, perturbations:
        Same semantics as :func:`precompute_views`.
    n_cells:
        Number of cells averaged per view. Must be at least 1.
    seed:
        Base random seed for cell-shuffle reproducibility.
    obs_col:
        The obs column used to identify entities (perturbations or guides).
        Defaults to ``"perturbation"`` with a ``"label_str"`` fallback for
        backward compatibility. Pass ``"sgRNA"`` for guide-level views.

    Returns
    -------
    views : list[np.ndarray]
        One array per reporter, each shape
        ``(n_perturbations, max_views_r, embedding_dim)`` where
        ``max_views_r`` is the maximum view count for that reporter.
        Unwritten slots are zero; use ``n_valid`` to determine the valid prefix.
    n_valid : np.ndarray, shape (n_perturbations, n_reporters), dtype int32
        Number of valid views per slot.
    cell_counts : np.ndarray, shape (n_perturbations, n_reporters), dtype int64
        Total cell count per (perturbation, reporter) slot, from Phase 1.

    Raises
    ------
    ValueError
        If ``n_cells`` is smaller than 1.
    """
    # Reject a non-positive chunk size before any file is opened; otherwise the
    # view-count computation below would floor-divide by zero.
    if n_cells < 1:
        raise ValueError(f"n_cells must be >= 1, got {n_cells}")

    def _label_column(columns) -> str:
        """Resolve the obs column holding entity labels (with legacy fallback)."""
        if obs_col != "perturbation":
            return obs_col
        return "perturbation" if "perturbation" in columns else "label_str"

    n_p = len(perturbations)
    n_r = len(reporters)
    pert_to_idx = {p: i for i, p in enumerate(perturbations)}
    rep_to_idx = {r: i for i, r in enumerate(reporters)}

    # --- Phase 1: backed-mode scan for global cell counts per slot -----------
    print(
        "\n precompute_inference_views — Phase 1: counting cells per "
        "(perturbation, reporter)..."
    )
    cell_counts = np.zeros((n_p, n_r), dtype=np.int64)
    for file_idx, path in enumerate(h5ad_paths):
        r_idx = rep_to_idx[reporter_for_path[file_idx]]
        adata = ad.read_h5ad(path, backed="r")
        try:
            col = _label_column(adata.obs.columns)
            for pert, count in adata.obs[col].value_counts().items():
                p_idx = pert_to_idx.get(pert)
                if p_idx is not None:
                    cell_counts[p_idx, r_idx] += count
        finally:
            adata.file.close()

    # --- Compute view count per slot from global cell counts ----------------
    # Slots with 0 cells            → 0 views
    # Slots with 1..n_cells-1 cells → 1 view (all cells; edge case)
    # Slots with ≥ n_cells cells    → floor(count / n_cells) views
    n_views_per_slot = np.zeros((n_p, n_r), dtype=np.int32)
    mask_some = cell_counts > 0
    mask_full = cell_counts >= n_cells
    n_views_per_slot[mask_some & ~mask_full] = 1
    n_views_per_slot[mask_full] = (cell_counts[mask_full] // n_cells).astype(np.int32)

    # --- Report entities whose global cell count is below n_cells ----------
    for r_idx, reporter in enumerate(reporters):
        n_total = int(mask_some[:, r_idx].sum())
        n_few = int((mask_some[:, r_idx] & ~mask_full[:, r_idx]).sum())
        if n_few > 0:
            print(
                f" {reporter}: {n_few} / {n_total} guides had < {n_cells} cells,"
                f" using all available as 1 view"
            )

    # Per-reporter max view count — avoids padding all reporters to a global
    # max. A reporter with no views at all still gets a length-1 view axis.
    max_views_per_reporter = [
        (
            int(n_views_per_slot[:, r_idx].max())
            if n_views_per_slot[:, r_idx].max() > 0
            else 1
        )
        for r_idx in range(n_r)
    ]

    # --- Discover embedding dim and allocate per-reporter views arrays ------
    embedding_dim = _peek_embedding_dim(h5ad_paths[0])
    total_gb = sum(n_p * mv * embedding_dim * 4 for mv in max_views_per_reporter) / 1e9
    print(
        f" precompute_inference_views: {n_p} perturbations × {n_r} reporters, "
        f"embedding_dim={embedding_dim}, total allocation={total_gb:.2f} GB"
    )
    for reporter, mv in zip(reporters, max_views_per_reporter):
        gb = n_p * mv * embedding_dim * 4 / 1e9
        print(f" {reporter}: up to {mv} views ({gb:.2f} GB)")

    views: list[np.ndarray] = [
        np.zeros((n_p, mv, embedding_dim), dtype=np.float32)
        for mv in max_views_per_reporter
    ]
    n_written = np.zeros((n_p, n_r), dtype=np.int32)

    # --- Build reporter → file paths mapping --------------------------------
    reporter_to_paths: dict[str, list[Path]] = {r: [] for r in reporters}
    for path, reporter in zip(h5ad_paths, reporter_for_path):
        reporter_to_paths[reporter].append(path)

    # --- One integer seed per (perturbation, reporter) slot ------------------
    # Each slot later builds its own generator from its seed, so a slot's
    # shuffle does not depend on the order in which slots are processed.
    ss = SeedSequence(seed)
    slot_rngs = default_rng(ss).integers(0, 2**31, size=(n_p, n_r), dtype=np.int64)

    # --- Phase 2: pool all files per reporter, then create views ------------
    print(" Phase 2: building non-overlapping inference views...")
    for r_idx, reporter in enumerate(reporters):
        paths = reporter_to_paths[reporter]
        if not paths:
            continue

        print(f" reporter={reporter!r} ({len(paths)} file(s))")

        # Pool cells from all files for this reporter
        adata_list = [ad.read_h5ad(p) for p in paths]
        adata_pooled = ad.concat(adata_list, join="inner", merge="same")
        del adata_list

        X = adata_pooled.X
        if hasattr(X, "toarray"):
            X = X.toarray()
        X = np.asarray(X, dtype=np.float32)
        col = _label_column(adata_pooled.obs.columns)
        obs_labels = adata_pooled.obs[col].to_numpy()
        del adata_pooled  # only X and the labels are needed from here on

        # Group row indices by entity (perturbation or guide); rows stay in
        # ascending order within each group, which fixes the shuffle input.
        cells_by_pert: dict[int, list[int]] = defaultdict(list)
        for row_idx, label in enumerate(obs_labels):
            p_idx = pert_to_idx.get(label)
            if p_idx is not None:
                cells_by_pert[p_idx].append(row_idx)

        n_written_this_reporter = 0
        for p_idx, row_list in cells_by_pert.items():
            n_views = int(n_views_per_slot[p_idx, r_idx])
            if n_views == 0:
                continue

            cell_indices = np.array(row_list, dtype=np.int32)
            rng = default_rng(int(slot_rngs[p_idx, r_idx]))
            rng.shuffle(cell_indices)

            if len(cell_indices) < n_cells:
                # Fewer cells than n_cells globally — use all as 1 view
                views[r_idx][p_idx, 0] = X[cell_indices].mean(axis=0)
                n_written[p_idx, r_idx] = 1
                n_written_this_reporter += 1
            else:
                for k in range(n_views):
                    chunk = cell_indices[k * n_cells : (k + 1) * n_cells]
                    views[r_idx][p_idx, k] = X[chunk].mean(axis=0)
                n_written[p_idx, r_idx] = n_views
                n_written_this_reporter += n_views

        del X, obs_labels, cells_by_pert
        print(f" wrote {n_written_this_reporter:,} views")

    n_valid = n_written
    mean_valid = float(n_valid.mean())
    min_valid = int(n_valid.min())
    missing = int((n_valid == 0).sum())
    print(
        f"\n precompute_inference_views complete."
        f" mean valid views/slot={mean_valid:.1f}"
        f" min={min_valid}"
        f" empty slots={missing}"
    )
    return views, n_valid, cell_counts
class ClassifierAggregator:
    """Train an MLP classifier on pre-computed multi-reporter views and extract
    penultimate-layer representations as gene-level aggregated embeddings.

    Parameters
    ----------
    hidden_dims:
        Width of each MLP hidden layer.
    dropout:
        Dropout rate applied after each hidden layer.
    cosine_classifier:
        Use a cosine-similarity head instead of a plain linear layer.
    batch_size:
        DataLoader batch size for both training and inference.
    num_epochs:
        Number of training epochs.
    learning_rate:
        AdamW learning rate.
    weight_decay:
        AdamW weight decay.
    val_fraction:
        Fraction of the pre-computed views (per slot) held out for validation.
    seed:
        Random seed forwarded to view pre-computation and used for the
        random view pairing during :meth:`transform`.

    Notes
    -----
    The MLP is instantiated during :meth:`fit` once the embedding
    dimensionality is known from the pre-computed views.
    """

    def __init__(
        self,
        hidden_dims: tuple[int, ...] = (512, 512, 512),
        dropout: float = 0.4,
        cosine_classifier: bool = True,
        batch_size: int = 256,
        num_epochs: int = 50,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-4,
        val_fraction: float = 0.2,
        seed: int = 42,
    ):
        self.hidden_dims = tuple(hidden_dims)
        self.dropout = dropout
        self.cosine_classifier = cosine_classifier
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.val_fraction = val_fraction
        self.seed = seed

        # Set after fit()
        self.model: MLP | None = None
        self.views: np.ndarray | None = None
        self.n_valid: np.ndarray | None = None
        self.perturbations: list[str] | None = None
        self.reporters: list[str] | None = None
        self._labels: np.ndarray | None = None
        self._cell_type: str = "cell"
        self._embedding_type: str = ""
        self._guide_col: str = DEFAULT_GUIDE_COL
        self._h5ad_paths: list[Path] | None = None
        self._reporter_for_path: list[str] | None = None
        self._n_cells_per_view: int | None = None
        self._n_experiments: int = 0

    # ------------------------------------------------------------------
    # Fit
    # ------------------------------------------------------------------

    def fit(
        self,
        h5ad_paths: list[Path],
        reporter_for_path: list[str],
        reporters: list[str],
        perturbations: list[str],
        n_cells: int,
        n_views: int,
        device: torch.device,
        wandb_project: str | None = None,
        wandb_run_name: str | None = None,
    ) -> None:
        """Pre-compute views, train the MLP, and store state for transform.

        Parameters
        ----------
        h5ad_paths:
            One path per (experiment, channel) pair.
        reporter_for_path:
            Parallel list mapping each path to its reporter name.
        reporters:
            Ordered list of unique reporter names.
        perturbations:
            Ordered list of unique perturbation names (classification targets).
        n_cells:
            Cells averaged per view during pre-computation.
        n_views:
            Total views pre-computed per (perturbation, reporter).
        device:
            Torch device for MLP training.
        wandb_project:
            W&B project name. When provided, a run is initialised with
            ``wandb.init()`` before training and finished afterwards. Pass
            ``None`` (default) to disable W&B logging.
        wandb_run_name:
            Optional display name for the W&B run. Ignored when
            ``wandb_project`` is ``None``.
        """
        # --- Pre-compute views ---
        views, n_valid = precompute_views(
            h5ad_paths=h5ad_paths,
            reporter_for_path=reporter_for_path,
            reporters=reporters,
            perturbations=perturbations,
            n_cells=n_cells,
            n_views=n_views,
            seed=self.seed,
        )

        n_perturbations, n_reporters, _, embedding_dim = views.shape
        input_dim = embedding_dim * n_reporters
        num_classes = n_perturbations
        labels = np.arange(n_perturbations, dtype=np.int64)

        # --- Train / val split (by view index, not perturbation) ---
        # All perturbations appear in both sets; val uses held-out views so the
        # model cannot be evaluated on classes it was never trained on.
        n_train_views = max(1, n_views - max(1, int(n_views * self.val_fraction)))
        n_val_views = n_views - n_train_views

        train_views = views[:, :, :n_train_views, :]
        val_views = views[:, :, n_train_views:, :]
        train_n_valid = np.minimum(n_valid, n_train_views)
        val_n_valid = np.maximum(n_valid - n_train_views, 0).astype(np.int32)

        print(
            f"\n ClassifierAggregator.fit: {n_perturbations} perturbations, "
            f"{n_reporters} reporters, input_dim={input_dim}, "
            f"num_classes={num_classes}"
        )
        print(
            f" Split: {n_train_views} train / {n_val_views} val views per (perturbation, reporter)"
        )

        train_ds = ClassifierAggregatorDataset(
            train_views, train_n_valid, labels, inference=False
        )
        val_ds = ClassifierAggregatorDataset(
            val_views, val_n_valid, labels, inference=True
        )

        train_loader = DataLoader(
            train_ds, batch_size=self.batch_size, shuffle=True, num_workers=0
        )
        val_loader = DataLoader(
            val_ds, batch_size=self.batch_size, shuffle=False, num_workers=0
        )

        # --- Instantiate and train MLP ---
        model = MLP(
            input_dim=input_dim,
            num_classes=num_classes,
            hidden_dims=self.hidden_dims,
            dropout=self.dropout,
            cosine_classifier=self.cosine_classifier,
        ).to(device)

        total_params = sum(p.numel() for p in model.parameters())
        print(f" MLP: {total_params:,} params hidden_dims={self.hidden_dims}")
        print(model)

        # --- W&B setup ---
        wandb_run = None
        if wandb_project is not None:
            import wandb  # optional dependency; imported only when requested

            wandb_run = wandb.init(
                project=wandb_project,
                name=wandb_run_name,
                config={
                    "hidden_dims": self.hidden_dims,
                    "dropout": self.dropout,
                    "cosine_classifier": self.cosine_classifier,
                    "num_epochs": self.num_epochs,
                    "learning_rate": self.learning_rate,
                    "weight_decay": self.weight_decay,
                    "batch_size": self.batch_size,
                    "val_fraction": self.val_fraction,
                    "seed": self.seed,
                    "n_cells_per_view": n_cells,
                    "n_views": n_views,
                    "n_perturbations": n_perturbations,
                    "n_reporters": n_reporters,
                    "input_dim": input_dim,
                    "num_classes": num_classes,
                    "reporters": reporters,
                },
            )

        try:
            _train_loop(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                num_epochs=self.num_epochs,
                learning_rate=self.learning_rate,
                weight_decay=self.weight_decay,
                device=device,
                wandb_run=wandb_run,
            )
        finally:
            # Always close the W&B run, even if training raised.
            if wandb_run is not None:
                wandb_run.finish()

        # --- Infer uns metadata from first h5ad ---
        _first = ad.read_h5ad(h5ad_paths[0], backed="r")
        try:
            self._cell_type = _first.uns.get("cell_type", "cell")
            self._embedding_type = _first.uns.get("embedding_type", "")
            self._guide_col = _guide_col(_first)
        finally:
            # Release the backed file handle even if metadata lookup fails.
            _first.file.close()

        # --- Store state ---
        self.model = model.cpu()
        self.views = views
        self.n_valid = n_valid
        self.perturbations = list(perturbations)
        self.reporters = list(reporters)
        self._labels = labels
        self._h5ad_paths = list(h5ad_paths)
        self._reporter_for_path = list(reporter_for_path)
        self._n_cells_per_view = n_cells
        # Largest number of files contributing to any single reporter.
        self._n_experiments = max(Counter(reporter_for_path).values())

    # ------------------------------------------------------------------
    # Transform
    # ------------------------------------------------------------------

    def _tta_loop(
        self,
        inference_views: list[np.ndarray],
        inference_n_valid: np.ndarray,
        n_passes: int,
        device: torch.device,
    ) -> np.ndarray:
        """Run TTA inference and return averaged backbone embeddings.

        The forward passes run under ``torch.no_grad()`` so the method is safe
        to call on its own (outside :meth:`transform`). The model's train/eval
        mode and device placement are left to the caller.

        Parameters
        ----------
        inference_views:
            List of ``n_reporters`` arrays, each shape
            ``(n_entities, max_views_r, embedding_dim)``. Output of
            :func:`precompute_inference_views`.
        inference_n_valid:
            Shape ``(n_entities, n_reporters)``, dtype int32.
        n_passes:
            Number of random-pairing passes to accumulate.
        device:
            Torch device for forward passes.

        Returns
        -------
        np.ndarray, shape ``(n_entities, last_hidden_dim)``, dtype float32
        """
        assert self.model is not None
        n_e, n_r = inference_n_valid.shape
        embedding_dim = inference_views[0].shape[-1]
        backbone_out_dim = self.hidden_dims[-1]

        rep_sum = np.zeros((n_e, backbone_out_dim), dtype=np.float64)
        rng = np.random.default_rng(self.seed)
        entity_rows = np.arange(n_e)

        for _ in range(n_passes):
            inputs = np.zeros((n_e, n_r * embedding_dim), dtype=np.float32)
            for r in range(n_r):
                valid_counts = inference_n_valid[:, r]
                has_valid = valid_counts > 0
                # Upper bound is clamped to 1 so empty slots draw index 0;
                # their rows are zeroed right after the gather.
                j_r = rng.integers(0, np.maximum(valid_counts, 1), size=n_e)
                gathered = inference_views[r][entity_rows, j_r, :].copy()
                gathered[~has_valid] = 0.0
                inputs[:, r * embedding_dim : (r + 1) * embedding_dim] = gathered

            with torch.no_grad():
                for start in range(0, n_e, self.batch_size):
                    stop = start + self.batch_size
                    batch = torch.tensor(inputs[start:stop]).to(device)
                    rep_sum[start:stop] += self.model.backbone(batch).cpu().numpy()

        return (rep_sum / n_passes).astype(np.float32)

    def _discover_guides(self) -> tuple[list[str], dict[str, str]]:
        """Backed-mode scan to find all guide IDs and their gene mappings.

        Reads ``obs[self._guide_col]`` and ``obs["perturbation"]`` (or
        ``"label_str"`` when the former is absent) from every h5ad file and
        collects a ``guide → gene`` dict, taking the first gene label seen for
        each guide within a file. Later files overwrite earlier ones for the
        same guide.

        Returns
        -------
        guides : list[str]
            Sorted list of unique guide IDs.
        guide_to_gene : dict[str, str]
            Maps each guide ID to its gene / perturbation label.
        """
        assert self._h5ad_paths is not None
        guide_to_gene: dict[str, str] = {}
        for path in self._h5ad_paths:
            adata = ad.read_h5ad(path, backed="r")
            try:
                gene_col = (
                    "perturbation"
                    if "perturbation" in adata.obs.columns
                    else "label_str"
                )
                mapping = adata.obs.groupby(self._guide_col, observed=True)[
                    gene_col
                ].first()
                guide_to_gene.update(mapping.to_dict())
            finally:
                adata.file.close()
        guides = sorted(guide_to_gene.keys())
        return guides, guide_to_gene

    def _make_adata(
        self,
        X: np.ndarray,
        entities: list[str],
    ) -> ad.AnnData:
        """Build a gene- or guide-level AnnData with standard uns fields."""
        assert self.reporters is not None
        obs = pd.DataFrame({"perturbation": entities}, index=entities)
        adata = ad.AnnData(X=X, obs=obs)
        adata.uns["aggregation_method"] = "classifier"
        adata.uns["reporters"] = self.reporters
        adata.uns["cell_type"] = self._cell_type
        adata.uns["embedding_type"] = self._embedding_type
        adata.uns["guide_col"] = self._guide_col
        return adata

    @torch.no_grad()
    def transform(
        self,
        device: torch.device,
        n_passes: int = 100,
    ) -> tuple[ad.AnnData, ad.AnnData]:
        """Extract penultimate-layer representations using TTA-style random pairing.

        For each of ``n_passes`` forward passes, each reporter independently
        samples one view at random from all available inference views; backbone
        outputs are averaged across passes. Inference views are computed from
        **all available cells** using non-overlapping chunks of
        ``n_cells_per_view`` — no training view-count cap is applied.

        Requires the per-construct identifier column resolved during
        :meth:`fit` (``self._guide_col``) in the h5ad obs. Inference is run at
        the **guide level** (one embedding per guide) and gene-level
        embeddings are derived by averaging guide embeddings with equal weight
        per guide.

        Parameters
        ----------
        device:
            Torch device for the forward pass.
        n_passes:
            Number of random-pairing passes to average.

        Returns
        -------
        guide_adata : ad.AnnData
            Guide-level AnnData ``(n_guides, last_hidden_dim)``.
        gene_adata : ad.AnnData
            Gene-level AnnData ``(n_genes, last_hidden_dim)``.

        Raises
        ------
        RuntimeError
            If :meth:`fit` has not been called.
        ValueError
            If the configured guide column is not present in the h5ad obs.
        """
        if self.model is None:
            raise RuntimeError("Call fit() before transform().")
        assert self.perturbations is not None
        assert self.reporters is not None
        assert self._h5ad_paths is not None
        assert self._reporter_for_path is not None
        assert self._n_cells_per_view is not None

        self.model.eval()
        self.model.to(device)

        # --- Require guide-level obs column ---
        _first = ad.read_h5ad(self._h5ad_paths[0], backed="r")
        try:
            has_guide_col = self._guide_col in _first.obs.columns
        finally:
            _first.file.close()

        if not has_guide_col:
            raise ValueError(
                f"No {self._guide_col!r} column found in obs. Guide-level data "
                "is required for transform(). Ensure h5ad files have an "
                f"{self._guide_col!r} obs column."
            )

        # --- Guide-level inference ---
        guides, guide_to_gene = self._discover_guides()
        n_guides = len(guides)
        print(
            f"\n transform (guide-level): {n_passes} passes, "
            f"{n_guides} guides, {len(self.reporters)} reporters"
        )

        g_views, g_n_valid, g_cell_counts = precompute_inference_views(
            h5ad_paths=self._h5ad_paths,
            reporter_for_path=self._reporter_for_path,
            reporters=self.reporters,
            perturbations=guides,
            n_cells=self._n_cells_per_view,
            seed=self.seed,
            obs_col=self._guide_col,
        )

        X_guides = self._tta_loop(g_views, g_n_valid, n_passes, device)

        guide_adata = self._make_adata(X_guides, guides)
        guide_adata.obs["perturbation"] = [guide_to_gene[g] for g in guides]
        guide_adata.obs[self._guide_col] = guides
        # n_cells per guide = smallest non-zero count over reporters; zeros are
        # masked as NaN so reporters lacking the guide do not drive the minimum.
        _counts = g_cell_counts.astype(np.float64)
        _counts[_counts == 0] = np.nan
        min_cells_per_guide = np.nanmin(_counts, axis=1).astype(np.int64)
        guide_adata.obs["n_cells"] = min_cells_per_guide
        AnndataValidator().validate(guide_adata, level="guide", strict=True)

        # --- Gene-level: equal-weight mean of guide embeddings per gene ---
        gene_labels_s = pd.Series(guide_to_gene)[guides]
        guide_df = pd.DataFrame(X_guides, index=guides)
        gene_df = guide_df.groupby(gene_labels_s.values).mean()
        gene_names = gene_df.index.tolist()
        gene_adata = self._make_adata(gene_df.values.astype(np.float32), gene_names)

        gene_to_guides_map: dict[str, list[str]] = defaultdict(list)
        for guide, gene in guide_to_gene.items():
            gene_to_guides_map[gene].append(guide)

        guide_n_cells = pd.Series(min_cells_per_guide, index=guides)
        gene_n_cells = guide_n_cells.groupby(gene_labels_s.values).min()
        gene_adata.obs["n_cells"] = gene_n_cells[gene_names].values.astype(np.int64)
        gene_adata.obs["guides"] = [
            "|".join(sorted(gene_to_guides_map[g])) for g in gene_names
        ]
        gene_adata.obs["n_experiments"] = self._n_experiments
        AnndataValidator().validate(gene_adata, level="gene", strict=True)

        self.model.cpu()
        return guide_adata, gene_adata

    # ------------------------------------------------------------------
    # Weight persistence
    # ------------------------------------------------------------------

    def save_weights(self, path: Path | str) -> None:
        """Save MLP state dict to disk, creating parent directories as needed."""
        if self.model is None:
            raise RuntimeError("No model to save — call fit() first.")
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), path)
        print(f" Saved classifier weights to {path}")

    def load_weights(
        self, path: Path | str, device: torch.device | None = None
    ) -> None:
        """Load MLP state dict from disk.

        The model must already be instantiated (i.e. ``fit()`` must have
        been called first to set ``self.model``). Tensors are mapped to
        ``device`` when given, otherwise to the CPU.
        """
        if self.model is None:
            raise RuntimeError("Model not instantiated — call fit() first.")
        state = torch.load(path, map_location=device or "cpu", weights_only=True)
        self.model.load_state_dict(state)
        print(f" Loaded classifier weights from {path}")
- - Replaces mean-pooling aggregation: an MLP is trained to predict perturbation - identity from multi-reporter concatenated views and the penultimate-layer - representations become the gene-level embeddings. - - Only produces gene-level output (no guide-level AnnData). - """ - - def __init__(self, config: CombinationConfig) -> None: - self.config = config - - # ------------------------------------------------------------------ - # Public interface - # ------------------------------------------------------------------ - - def combine(self) -> tuple[ad.AnnData, ad.AnnData]: - """Run full pipeline: path resolution → view pre-computation → MLP training → transform. - - Returns - ------- - guide_adata : ad.AnnData or None - Guide-level AnnData ``(n_guides × last_hidden_dim)``, or ``None`` - if no per-construct identifier column (see ``uns["guide_col"]``, - default ``"sgRNA"``) is found in the h5ad files. - gene_adata : ad.AnnData - Gene-level AnnData ``(n_perturbations × last_hidden_dim)``. - """ - from ops_utils.data.feature_discovery import ( - build_signal_groups, - find_cell_h5ad_path, - get_channel_maps_path, - ) - from ops_utils.data.feature_metadata import FeatureMetadata - - agg_cfg = self.config.classifier_aggregation - - # 1. Flatten experiments_channels → (exp, ch) pairs - pairs: List[Tuple[str, str]] = [ - (exp, ch) - for exp, channels in (self.config.experiments_channels or {}).items() - for ch in channels - ] - if not pairs: - raise ValueError( - "No experiment/channel pairs found in config.experiments_channels." - ) - - # 2. Group by biological signal (reporter) - maps_path = get_channel_maps_path() - fm = FeatureMetadata(metadata_path=maps_path) - signal_groups: Dict[str, List[Tuple[str, str]]] = build_signal_groups(pairs, fm) - - if not signal_groups: - raise ValueError( - "No experiment/channel pairs could be resolved to a biological signal." 
- ) - - reporters = list(signal_groups.keys()) - logger.info(f"Resolved {len(reporters)} reporters: {reporters}") - - # 3. Resolve h5ad paths using find_cell_h5ad_path (handles 3-assembly/ - # subdirectory, reporter→channel fallback, and missing files) - storage_roots = [Path(self.config.base_dir)] - feature_dir = self.config.feature_dir - - h5ad_paths: List[Path] = [] - reporter_for_path: List[str] = [] - - for signal, signal_pairs in signal_groups.items(): - for exp, ch in signal_pairs: - path = find_cell_h5ad_path( - exp, ch, storage_roots, feature_dir, maps_path - ) - if path is None: - logger.warning( - f"h5ad not found for {exp}/{ch} (reporter={signal!r}), skipping." - ) - continue - h5ad_paths.append(path) - reporter_for_path.append(signal) - logger.info(f" {signal:<30} {exp}/{ch} → {path.name}") - - if not h5ad_paths: - raise ValueError( - "No h5ad files could be resolved. Check base_dir and feature_dir." - ) - - # 4. Discover perturbations by scanning obs only (no .X loaded) - perturbations = self._discover_perturbations(h5ad_paths) - logger.info(f"Discovered {len(perturbations)} perturbations across all files.") - self._log_cells_per_perturbation(h5ad_paths, reporter_for_path, reporters) - - # 5. Instantiate ClassifierAggregator from config params - aggregator = ClassifierAggregator( - hidden_dims=tuple(agg_cfg.get("hidden_dims", [512, 512, 512])), - dropout=agg_cfg.get("dropout", 0.4), - cosine_classifier=agg_cfg.get("cosine_classifier", True), - batch_size=agg_cfg.get("batch_size", 256), - num_epochs=agg_cfg.get("num_epochs", 50), - learning_rate=agg_cfg.get("learning_rate", 1e-3), - weight_decay=agg_cfg.get("weight_decay", 1e-4), - val_fraction=agg_cfg.get("val_fraction", 0.2), - seed=agg_cfg.get("seed", 42), - ) - - device = torch.device(agg_cfg.get("device", "cpu")) - - # 6. 
Fit - aggregator.fit( - h5ad_paths=h5ad_paths, - reporter_for_path=reporter_for_path, - reporters=reporters, - perturbations=perturbations, - n_cells=agg_cfg["n_cells_per_view"], - n_views=agg_cfg["n_views"], - device=device, - wandb_project=agg_cfg.get("wandb_project"), - wandb_run_name=agg_cfg.get("wandb_run_name"), - ) - - # 7. Save weights - weights_path = agg_cfg.get("weights_path") - if weights_path: - aggregator.save_weights(weights_path) - - # 8. Extract and return guide- and gene-level embeddings - n_passes = agg_cfg.get("inference_n_passes", 100) - return aggregator.transform(device, n_passes=n_passes) - - # ------------------------------------------------------------------ - # Internal helpers - # ------------------------------------------------------------------ - - def _log_cells_per_perturbation( - self, - h5ad_paths: List[Path], - reporter_for_path: List[str], - reporters: List[str], - ) -> None: - """Log mean cells per perturbation for each reporter (obs-only scan).""" - from collections import defaultdict - - reporter_paths: Dict[str, List[Path]] = defaultdict(list) - for path, reporter in zip(h5ad_paths, reporter_for_path): - reporter_paths[reporter].append(path) - - logger.info("Mean cells per perturbation per reporter:") - for reporter in reporters: - paths = reporter_paths.get(reporter, []) - if not paths: - continue - - pert_counts: Dict[str, int] = defaultdict(int) - for path in paths: - adata = ad.read_h5ad(path, backed="r") - try: - col = ( - "perturbation" - if "perturbation" in adata.obs.columns - else "label_str" - ) - for pert, count in adata.obs[col].value_counts().items(): - pert_counts[pert] += count - finally: - adata.file.close() - - if pert_counts: - mean_cells = sum(pert_counts.values()) / len(pert_counts) - logger.info( - f" {reporter:<45} {mean_cells:>8.0f} cells/perturbation" - f" ({len(paths)} file(s), {len(pert_counts)} perturbations)" - ) - - def _discover_perturbations(self, h5ad_paths: List[Path]) -> List[str]: - 
"""Collect the union of perturbation labels across all h5ad files. - - Uses AnnData backed mode so only ``.obs`` is read — ``.X`` is never - loaded into memory. - - Falls back to ``label_str`` if ``perturbation`` is absent (backwards - compatibility with older h5ad files). - """ - all_perts: set[str] = set() - for path in h5ad_paths: - adata = ad.read_h5ad(path, backed="r") - try: - col = ( - "perturbation" - if "perturbation" in adata.obs.columns - else "label_str" - ) - all_perts.update(adata.obs[col].unique()) - finally: - adata.file.close() - return sorted(all_perts) diff --git a/src/ops_model/post_process/combination/cli.py b/src/ops_model/post_process/combination/cli.py deleted file mode 100644 index 91d7093..0000000 --- a/src/ops_model/post_process/combination/cli.py +++ /dev/null @@ -1,279 +0,0 @@ -""" -CLI and programmatic entry point for the AnnData combination pipeline. - -This module provides the main orchestration logic for the combination process, -tying together configuration, file validation, and the core combination logic. -""" - -import argparse -import logging -import sys -from pathlib import Path -from typing import Dict, List, Optional - -import anndata as ad - -from .config_handler import CombinationConfig, load_config -from .file_validator import FileValidator -from .combiners import ComprehensiveCombiner, PcaOptimizationCombiner -from ..anndata_processing.anndata_validator import AnndataValidator, IssueLevel - -# Initialize logger -logger = logging.getLogger(__name__) - - -def setup_logging(level: int = logging.INFO): - """ - Configures the root logger for the application. - - Args: - level: The logging level to set (e.g., logging.INFO, logging.DEBUG). 
- """ - logging.basicConfig( - level=level, - format="%(asctime)s [%(levelname)s] - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - stream=sys.stdout, - ) - - -def load_anndata_objects(file_paths: List[Path]) -> Dict[str, ad.AnnData]: - """ - Loads a list of AnnData files from disk into a dictionary. - - Args: - file_paths: A list of Path objects pointing to .h5ad files. - - Returns: - A dictionary mapping stringified file paths to loaded AnnData objects. - """ - adata_objects = {} - logger.info(f"Loading {len(file_paths)} AnnData object(s)...") - for path in file_paths: - try: - adata_objects[str(path)] = ad.read_h5ad(path) - logger.debug(f"Successfully loaded {path}") - except Exception as e: - logger.error(f"Failed to load AnnData file: {path}. Error: {e}") - # The "warn and continue" policy applies to finding files. - # If a found file fails to load, it's a critical error. - raise - return adata_objects - - -def validate_and_save(adata: ad.AnnData, path: Path, level: str): - """ - Validate an AnnData object against a schema and save it to a .h5ad file. - - This function enforces hard validation constraints. If any ERROR-level issues - are found, it logs a detailed report and raises a ValueError, preventing - the invalid object from being saved. - - Args: - adata: The AnnData object to validate and save. - path: The Path object where the file will be saved. - level: The schema level to validate against (e.g., "multi_experiment"). - - Raises: - ValueError: If the AnnData object fails validation. 
- """ - logger.info(f"Validating final AnnData object against '{level}' schema...") - validator = AnndataValidator() - issues = validator.validate(adata, level=level, strict=False) - errors = [issue for issue in issues.errors if issue.level == IssueLevel.ERROR] - - if errors: - error_summary = ( - f"{len(errors)} validation error(s) found in {level}-level AnnData" - ) - logger.error(f"✗ Validation FAILED: {error_summary}") - for issue in errors[:10]: # Log up to 10 specific errors - logger.error( - f" - {issue.component}{'['+issue.field+']' if issue.field else ''}: {issue.message}" - ) - raise ValueError(error_summary) - - logger.info(f"✓ Validation passed. Saving file to {path}...") - try: - path.parent.mkdir(parents=True, exist_ok=True) - adata.write_h5ad(path) - logger.info(f"✓ File saved successfully.") - except Exception as e: - logger.error(f"Failed to save final AnnData object: {e}") - raise - - -def run_combination(config: CombinationConfig): - """ - Programmatic entry point for running the full combination pipeline. - """ - logger.info("--- Starting Combination Pipeline ---") - - # pca_optimized handles its own file discovery — skip the file validator entirely. - if config.concatenation_method == "pca_optimized": - combiner = PcaOptimizationCombiner(config) - adata_guide, adata_gene = combiner.combine() - - if adata_guide is not None and config.output_path: - output_dir = Path(config.output_path) - stem = Path(config.output_filename or "combined").stem - validate_and_save( - adata_guide, output_dir / f"{stem}_guide.h5ad", level="multi_experiment" - ) - validate_and_save( - adata_gene, output_dir / f"{stem}_gene.h5ad", level="multi_experiment" - ) - elif adata_guide is None: - logger.info("Phase 2 was skipped — no combined output to save.") - else: - logger.warning("output_path is not set. Skipping save.") - - logger.info("--- Combination Pipeline Finished ---") - return - - # All other methods need the file validator. 
- file_validator = FileValidator(config) - valid_files = file_validator.validate_and_collect_files() - - if config.concatenation_method == "comprehensive": - combiner = ComprehensiveCombiner(config) - adata_guide, adata_gene = combiner.combine() - elif config.concatenation_method == "vertical": - # Vertical concatenation - pool same biological signal across experiments - from ops_model.features.anndata_utils import _process_vertical_group - - if not config.experiments or not config.channel: - logger.error( - "Vertical concatenation requires 'experiments' and 'channel' in config." - ) - return - - # Convert experiments to (experiment, channel) pairs - exp_channel_pairs = [(exp, config.channel) for exp in config.experiments] - - # Determine feature directory - if config.feature_dir: - feature_dir = config.feature_dir - elif config.feature_type == "cellprofiler": - feature_dir = "cell-profiler" - else: - feature_dir = f"{config.feature_type}_features" - - logger.info( - f"Running vertical concatenation for {len(config.experiments)} experiments" - ) - logger.info(f"Channel: {config.channel}") - logger.info(f"Per-experiment mode: {config.aggregation_per_experiment}") - logger.info(f"Per-well mode: {config.aggregation_per_well}") - - # Process to guide level - adata_guide = _process_vertical_group( - exp_channel_pairs=exp_channel_pairs, - feature_type=config.feature_type, - base_dir=Path(config.base_dir), - feature_dir=feature_dir, - target_level="guide", - join="inner", - verbose=True, - subsample_controls=False, - control_gene=config.control_subsampling.get("control_gene", "NTC"), - control_group_size=config.control_subsampling.get("group_size", 4), - random_seed=config.control_subsampling.get("random_seed"), - normalize_on_pooling=config.normalization.get("normalize_on_pooling", True), - normalize_on_controls=config.normalization.get( - "normalize_on_controls", False - ), - per_experiment=config.aggregation_per_experiment, - keep_shared_only=not 
config.aggregation_per_well, # shared-only doesn't apply in per-well mode - per_well=config.aggregation_per_well, - ) - - # Aggregate to gene level. - # When per_well is active, group by [perturbation, well, experiment] so that - # each (gene, well, experiment) triple becomes one observation. - from ops_model.features.anndata_utils import aggregate_to_level - - if config.aggregation_per_well: - gene_batch_cols = ["well", "experiment"] - elif config.aggregation_per_experiment: - gene_batch_cols = ["experiment"] - else: - gene_batch_cols = None - - logger.info("Aggregating to gene level...") - adata_gene = aggregate_to_level( - adata_guide, - level="gene", - method="mean", - preserve_batch_info=False, - batch_cols=gene_batch_cols, - subsample_controls=config.control_subsampling.get("enabled", False), - control_gene=config.control_subsampling.get("control_gene", "NTC"), - control_group_size=config.control_subsampling.get("group_size", 4), - random_seed=config.control_subsampling.get("random_seed"), - ) - - logger.info(f"Guide-level: {adata_guide.shape}") - logger.info(f"Gene-level: {adata_gene.shape}") - else: - # For other methods, load all objects first - if not valid_files: - logger.error("No valid input files found. Aborting pipeline.") - return - - anndata_objects = load_anndata_objects(valid_files) - logger.error( - f"Method '{config.concatenation_method}' not yet implemented in this CLI." - ) - return - - # 4. Final validation and saving - if config.output_path: - output_path = Path(config.output_path) - guide_path = output_path.parent / f"{output_path.stem}_guide.h5ad" - gene_path = output_path.parent / f"{output_path.stem}_gene.h5ad" - - validate_and_save(adata_guide, guide_path, level="multi_experiment") - validate_and_save(adata_gene, gene_path, level="multi_experiment") - else: - logger.warning("Output path is not set. 
Skipping save.") - - logger.info("--- Combination Pipeline Finished ---") - - -def main(): - """CLI entry point for the AnnData combination script.""" - parser = argparse.ArgumentParser( - description="Combine AnnData objects from multiple experiments." - ) - parser.add_argument( - "--config", - type=str, - required=True, - help="Path to the YAML configuration file.", - ) - parser.add_argument( - "--output-path", - type=str, - help="Optional: Path to save the final output file. Overrides the path in the config file.", - ) - parser.add_argument( - "--verbose", - action="store_true", - help="Enable debug-level logging for more detailed output.", - ) - args = parser.parse_args() - - # Configure logging - log_level = logging.DEBUG if args.verbose else logging.INFO - setup_logging(log_level) - - # Load configuration - config = load_config(args.config, args.output_path) - - # Run the full combination pipeline - run_combination(config) - - -if __name__ == "__main__": - main() diff --git a/src/ops_model/post_process/combination/combiners.py b/src/ops_model/post_process/combination/combiners.py deleted file mode 100644 index a8ad3bb..0000000 --- a/src/ops_model/post_process/combination/combiners.py +++ /dev/null @@ -1,1220 +0,0 @@ -import logging -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple -import anndata as ad - -from .config_handler import CombinationConfig -from ops_model.features.anndata_utils import ( - concatenate_experiments_comprehensive, - _guide_col, -) - -# Initialize logger -logger = logging.getLogger(__name__) - - -@contextmanager -def temp_log_level(level, module_names): - """Temporarily set the logging level for a list of modules.""" - original_levels = {} - for name in module_names: - module_logger = logging.getLogger(name) - original_levels[name] = module_logger.level - module_logger.setLevel(level) - try: - yield - finally: - for name, original_level in original_levels.items(): - 
logging.getLogger(name).setLevel(original_level) - - -class ComprehensiveCombiner: - """ - Orchestrates the comprehensive combination of experiments by wrapping - the core logic in `anndata_utils.concatenate_experiments_comprehensive`. - """ - - def __init__(self, config: CombinationConfig): - """ - Initializes the combiner with a validated configuration. - """ - self.config = config - - def combine(self) -> Tuple[ad.AnnData, ad.AnnData]: - """ - Execute the full combination pipeline by calling the utility function. - """ - logger.info("Starting comprehensive combination process...") - - # Prepare arguments for the utility function from the config - # Note: Currently, the same embedding config is used for both guide and gene levels - # in concatenate_experiments_comprehensive. Using gene_level config since that's - # typically more important for downstream analysis. - embedding_config = self.config.embeddings.get("gene_level") - if embedding_config is None: - # Fallback to guide_level if gene_level missing - embedding_config = self.config.embeddings.get("guide_level") - if embedding_config is None: - # Final fallback if both missing (shouldn't happen) - from .config_handler import EmbeddingConfig - - embedding_config = EmbeddingConfig() - - # Convert experiments_channels from Dict[str, List[str]] to List[Tuple[str, str]] - experiments_channels_list = [ - (exp, ch) - for exp, channels in (self.config.experiments_channels or {}).items() - for ch in channels - ] - - if not experiments_channels_list: - raise ValueError("No experiment/channel pairs were found to combine.") - - # Temporarily silence verbose logs from underlying libraries - with temp_log_level(logging.WARNING, ["scanpy", "umap"]): - adata_guide, adata_gene = concatenate_experiments_comprehensive( - experiments_channels=experiments_channels_list, - feature_type=self.config.feature_type, - base_dir=self.config.base_dir, - feature_dir=self.config.feature_dir, - 
recompute_embeddings=embedding_config.compute_embeddings, - n_pca_components=embedding_config.n_pca_components, - n_umap_neighbors=embedding_config.n_neighbors, - compute_pca=embedding_config.pca, - compute_umap=embedding_config.umap, - compute_phate=embedding_config.phate, - normalize_on_pooling=self.config.normalization.get( - "normalize_on_pooling", True - ), - normalize_on_controls=self.config.normalization.get( - "normalize_on_controls", False - ), - subsample_controls=self.config.control_subsampling.get( - "enabled", False - ), - control_gene=self.config.control_subsampling.get("control_gene", "NTC"), - control_group_size=self.config.control_subsampling.get("group_size", 4), - random_seed=self.config.control_subsampling.get("random_seed"), - fit_on_aggregated_controls=self.config.fitted_embeddings.get( - "enabled", False - ), - use_pca_for_umap=self.config.fitted_embeddings.get( - "use_pca_for_umap", True - ), - leiden_resolutions=( - self.config.leiden_clustering.get("resolutions") - if self.config.leiden_clustering.get("enabled", False) - else None - ), - ) - - logger.info("Comprehensive combination process complete.") - return adata_guide, adata_gene - - -# ============================================================================= -# PCA-optimized combiner — constants -# ============================================================================= - -_SWEEP_THRESHOLDS_DINO = [ - 0.60, - 0.70, - 0.74, - 0.76, - 0.78, - 0.80, - 0.82, - 0.84, - 0.88, - 0.90, - 0.95, -] -_SWEEP_THRESHOLDS_CP = [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70] -_MIN_PCS = 10 # skip thresholds that yield fewer PCs than this -_MIN_CELLS_FLOOR = 750_000 # floor for auto target_n_cells -_PCA_FIT_CAP = 5_000_000 # cells used to fit PCA axes; larger datasets use passthrough (fit subsample, transform all) - - -# ============================================================================= -# PCA-optimized combiner — module-level helpers (must be picklable for SLURM) -# 
============================================================================= - - -def _prepare_cells_for_scoring(adata: "ad.AnnData") -> "ad.AnnData": - """Strip obs to copairs-required columns and cast X to float64.""" - import numpy as np - - if "n_cells" not in adata.obs.columns: - adata.obs["n_cells"] = 1 - keep = [ - c - for c in [_guide_col(adata), "perturbation", "n_cells"] - if c in adata.obs.columns - ] - adata.obs = adata.obs[keep].copy() - for col in adata.obs.columns: - if adata.obs[col].dtype.name == "category": - adata.obs[col] = adata.obs[col].astype(str) - adata.X = adata.X.astype("float64") - return adata - - -def _sweep_pca_thresholds( - X_pcs: "np.ndarray", - cumvar: "np.ndarray", - obs_df: "pd.DataFrame", - thresholds: List[float], - norm_method: str, - _logger, -) -> Optional[Tuple[List[Dict], float, int]]: - """Sweep variance thresholds, score AUC at each, return best result. - - Returns (sweep_rows, best_threshold, best_n_pcs) or None if no valid threshold found. - The normalization here is temporary (for scoring only) and discarded after selection. 
- """ - import numpy as np - import pandas as pd - import anndata as ad - from ops_utils.analysis.pca import n_pcs_for_threshold - from ops_utils.analysis.map_scores import ( - compute_auc_score, - phenotypic_activity_assesment, - ) - from ops_model.features.anndata_utils import ( - aggregate_to_level, - normalize_guide_adata, - ) - - best_auc_t, best_auc_r, best_auc_a, best_auc_n = None, -1.0, -1.0, 0 - sweep_rows = [] - - for threshold in thresholds: - n_pcs = n_pcs_for_threshold(cumvar, threshold) - X_slice = X_pcs[:, :n_pcs].astype(np.float32) - pc_names = [f"PC{j}" for j in range(n_pcs)] - - adata_tmp = ad.AnnData( - X=X_slice, obs=obs_df.copy(), var=pd.DataFrame(index=pc_names) - ) - guide_tmp = aggregate_to_level( - adata_tmp, level="guide", method="mean", preserve_batch_info=False - ) - del adata_tmp - guide_tmp.X = guide_tmp.X.astype(np.float32) - - guide_norm = normalize_guide_adata(guide_tmp.copy(), norm_method) - guide_norm.X = guide_norm.X.astype(np.float32) - guide_norm = _prepare_cells_for_scoring(guide_norm) - - try: - activity_map, active_ratio = ( - phenotypic_activity_assesment( # distance default="cosine" - guide_norm, - plot_results=False, - null_size=100_000, - ) - ) - auc = compute_auc_score(activity_map) - except Exception as e: - _logger.warning(f" Scoring failed at {threshold:.0%}: {e}") - active_ratio, auc = 0.0, 0.0 - del guide_tmp, guide_norm - - row = { - "threshold": threshold, - "n_pcs": n_pcs, - "activity": active_ratio, - "auc": auc, - } - sweep_rows.append(row) - - if n_pcs < _MIN_PCS: - _logger.info( - f" {threshold:.0%}: {n_pcs} PCs (< {_MIN_PCS}) — {active_ratio:.1%}, AUC={auc:.4f} [skipped]" - ) - continue - - _logger.info( - f" {threshold:.0%}: {n_pcs} PCs — {active_ratio:.1%}, AUC={auc:.4f}" - ) - if auc > best_auc_a or (auc == best_auc_a and active_ratio > best_auc_r): - best_auc_t, best_auc_r, best_auc_a, best_auc_n = ( - threshold, - active_ratio, - auc, - n_pcs, - ) - - if best_auc_t is None: - return None - - 
_logger.info( - f" Best (AUC): {best_auc_t:.0%} ({best_auc_n} PCs) → {best_auc_r:.1%}, AUC={best_auc_a:.4f}" - ) - return sweep_rows, best_auc_t, best_auc_n - - -def _process_signal_group( - signal: str, - exp_channel_pairs: List[Tuple[str, str]], - output_dir: str, - base_dir: str, - feature_dir: str, - pca_config: Dict[str, Any], - downsampling_config: Dict[str, Any], - norm_method: str, - random_seed: int = 42, - preserve_batch: bool = False, - no_pca: bool = False, - cell_filter: Optional[Any] = None, - save_cell_level: bool = False, - cell_path_map: Optional[Dict[Tuple[str, str], str]] = None, -) -> str: - """Phase 1: pool cells for one biological signal, fit PCA, select n_pcs, save h5ads. - - Top-level function (not a method) so submitit can pickle it for SLURM dispatch. - - Saves to output_dir/per_channel/: - - {signal_prefix}_guide.h5ad (aggregated at selected n_pcs, un-normalized) - - {signal_prefix}_gene.h5ad - - {signal_prefix}_sweep.csv (one row per threshold, or one row for fixed mode) - """ - import logging - import time - import warnings - import numpy as np - import pandas as pd - import anndata as ad - from pathlib import Path - - warnings.filterwarnings("ignore", category=FutureWarning) - logging.basicConfig( - level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s" - ) - logging.getLogger("copairs").setLevel(logging.WARNING) - _logger = logging.getLogger(__name__) - t_start = time.time() - - from ops_utils.data.feature_discovery import ( - find_cell_h5ad_path, - load_cell_h5ad, - get_channel_maps_path, - sanitize_signal_filename, - ) - from ops_utils.analysis.pca import fit_pca, n_pcs_for_threshold - from ops_utils.analysis.normalization import zscore_normalize - from ops_model.features.anndata_utils import aggregate_to_level - - output_dir = Path(output_dir) - per_channel_dir = output_dir / "per_channel" - per_channel_dir.mkdir(parents=True, exist_ok=True) - - maps_path = get_channel_maps_path() - storage_roots = [Path(base_dir)] - - 
def _resolve_cell_file(exp, ch): - """Explicit-path override (bypasses discovery) if provided, else standard lookup.""" - if cell_path_map and (exp, ch) in cell_path_map: - p = Path(cell_path_map[(exp, ch)]) - return p if p.exists() else None - return find_cell_h5ad_path(exp, ch, storage_roots, feature_dir, maps_path) - - _logger.info(f"Processing signal: {signal} ({len(exp_channel_pairs)} experiments)") - - # --- Pre-scan cell counts (lightweight h5py read) --- - import h5py - - exp_cell_counts = {} - for exp, ch in exp_channel_pairs: - cell_file = _resolve_cell_file(exp, ch) - if cell_file is not None: - try: - with h5py.File(cell_file, "r") as f: - exp_cell_counts[(exp, ch)] = f["X"].shape[0] - except Exception: - pass - - n_cells_pooled = sum(exp_cell_counts.values()) - if n_cells_pooled == 0: - return f"FAILED: {signal} — no cell data found for any experiment" - - # --- Resolve downsampling target --- - downsampling_enabled = downsampling_config.get("enabled", False) - if downsampling_enabled: - raw_target = downsampling_config.get("target_n_cells", "auto") - actual_target = ( - int(raw_target) - if raw_target != "auto" - else max(min(exp_cell_counts.values()), _MIN_CELLS_FLOOR) - ) - actual_target = min(actual_target, n_cells_pooled) - else: - actual_target = n_cells_pooled - - _logger.info(f" Total cells: {n_cells_pooled:,}, target: {actual_target:,}") - - # --- Load cells, optionally downsample, optionally z-score per experiment --- - rng = np.random.RandomState(random_seed) - all_blocks = [] - n_vars_expected = None - normalize_before_pca = pca_config.get("normalize_before_pca", False) - inferred_cell_type = None # read from first successfully loaded h5ad - inferred_guide_col = None # read from first successfully loaded h5ad - - for exp, ch in exp_channel_pairs: - if exp_cell_counts.get((exp, ch), 0) == 0: - continue - - cell_file = _resolve_cell_file(exp, ch) - adata = ad.read_h5ad(cell_file) if cell_file is not None else None - if adata is None: - 
continue - - if cell_filter is not None: - n_before = adata.n_obs - adata = cell_filter(adata) - if adata.n_obs == 0: - _logger.warning(f" {exp}/{ch}: all cells removed by filter, skipping") - continue - _logger.info(f" {exp}/{ch}: {n_before} → {adata.n_obs} cells after filtering") - - if inferred_cell_type is None: - inferred_cell_type = adata.uns.get("cell_type", "cell") - if inferred_guide_col is None: - inferred_guide_col = _guide_col(adata) - - if n_vars_expected is None: - n_vars_expected = adata.n_vars - elif adata.n_vars != n_vars_expected: - _logger.info( - f" {exp}/{ch}: {adata.n_vars} features (vs {n_vars_expected}), will use shared features on concat" - ) - - # Ensure label_str exists - if "label_str" not in adata.obs.columns and "perturbation" in adata.obs.columns: - adata.obs["label_str"] = adata.obs["perturbation"] - - # Proportional subsample - if downsampling_enabled and actual_target < n_cells_pooled: - fraction = exp_cell_counts[(exp, ch)] / n_cells_pooled - n_take = max(1, int(round(fraction * actual_target))) - n_take = min(n_take, adata.n_obs) - if n_take < adata.n_obs: - idx = rng.choice(adata.n_obs, n_take, replace=False) - idx.sort() - adata = adata[idx].copy() - - keep_cols = [ - c - for c in [_guide_col(adata), "perturbation", "label_str"] - if c in adata.obs.columns - ] - obs = adata.obs[keep_cols].copy() - obs["experiment"] = exp.split("_")[0] - - X_block = np.asarray(adata.X, dtype=np.float32) - feature_cols = list(adata.var_names) - - # Per-experiment z-score before pooling (uses ops_utils backend) - if normalize_before_pca: - df_block = pd.DataFrame(X_block, columns=feature_cols) - df_norm = zscore_normalize( - df_block, feature_cols=feature_cols, method="global" - ) - X_block = df_norm[feature_cols].values.astype(np.float32) - - all_blocks.append(ad.AnnData(X=X_block, obs=obs, var=adata.var.copy())) - _logger.info( - f" {exp.split('_')[0]}/{ch}: {exp_cell_counts[(exp, ch)]:,} → {len(obs):,} cells" - ) - del adata, X_block - - if not 
all_blocks: - return f"FAILED: {signal} — failed to load cell data for all experiments" - - # Concatenate blocks; inner join keeps only shared features across experiments. - # index_unique ensures obs_names are unique across experiments (avoids AnnData warning). - adata_cells = ad.concat(all_blocks, join="inner", index_unique="-") - del all_blocks - # Per-experiment all_blocks were built without .uns; restore guide_col so - # downstream consumers (and any aggregated outputs) see the right column. - if inferred_guide_col is not None: - adata_cells.uns["guide_col"] = inferred_guide_col - if np.isnan(adata_cells.X).any(): - adata_cells.X = np.nan_to_num(adata_cells.X, nan=0.0) - - n_cells = adata_cells.n_obs - n_feats = adata_cells.n_vars - feature_names = list(adata_cells.var_names) - _logger.info(f" Pooled: {n_cells:,} cells, {n_feats} features") - - # Separate obs for scoring (no experiment column — copairs doesn't handle extra string cols) - score_cols = [ - c - for c in [_guide_col(adata_cells), "perturbation", "label_str"] - if c in adata_cells.obs.columns - ] - obs_for_scoring = adata_cells.obs[score_cols].copy() - obs_full = adata_cells.obs[[c for c in adata_cells.obs.columns]].copy() - - X_raw = np.asarray(adata_cells.X, dtype=np.float32) - del adata_cells - - if no_pca: - _logger.info( - f" no_pca=True: skipping PCA, using {n_feats} raw features directly" - ) - X_reduced = X_raw - del X_raw - pc_names = feature_names - n_pcs = n_feats - pca_components = None - peak_t = None - sweep_rows = [] - else: - # --- Fit PCA on subsample, transform all cells in chunks --- - t_pca = time.time() - n_total = X_raw.shape[0] - - if n_total > _PCA_FIT_CAP: - fit_idx = rng.choice(n_total, _PCA_FIT_CAP, replace=False) - fit_idx.sort() - _logger.info( - f" Fitting PCA on {_PCA_FIT_CAP:,}/{n_total:,} subsampled cells..." 
- ) - _, cumvar, pca_model = fit_pca(X_raw[fit_idx]) - del fit_idx - _logger.info(f" Transforming all {n_total:,} cells in chunks...") - chunk_size = 2_000_000 - X_pcs_chunks = [] - for i in range(0, n_total, chunk_size): - chunk = np.asarray(X_raw[i : i + chunk_size], dtype=np.float64) - chunk = np.nan_to_num(chunk, nan=0.0, posinf=0.0, neginf=0.0) - X_pcs_chunks.append(pca_model.transform(chunk).astype(np.float32)) - _logger.info( - f" Transformed chunk {i:,}-{min(i + chunk_size, n_total):,}" - ) - X_pcs = np.vstack(X_pcs_chunks) - del X_pcs_chunks - else: - _logger.info(f" Fitting PCA on {n_total:,} x {X_raw.shape[1]} matrix...") - X_pcs, cumvar, pca_model = fit_pca(X_raw) - - _logger.info( - f" PCA done in {time.time() - t_pca:.0f}s — {X_pcs.shape[1]} components" - ) - pca_components = pca_model.components_.copy() - del X_raw, pca_model - - # --- Select n_pcs: sweep or fixed --- - selection = pca_config.get("selection", "sweep") - if preserve_batch: - selection = "fixed" # skip sweep when preserving batch info - - if selection == "fixed": - cutoff = float(pca_config.get("variance_cutoff", 0.80)) - n_pcs = n_pcs_for_threshold(cumvar, cutoff) - peak_t = cutoff - sweep_rows = [ - {"threshold": cutoff, "n_pcs": n_pcs, "activity": None, "auc": None} - ] - _logger.info(f" Fixed cutoff {cutoff:.0%}: {n_pcs} PCs") - else: - thresholds = pca_config.get("_sweep_thresholds", _SWEEP_THRESHOLDS_DINO) - result = _sweep_pca_thresholds( - X_pcs, cumvar, obs_for_scoring, thresholds, norm_method, _logger - ) - if result is None: - return f"FAILED: {signal} — no valid threshold found (all thresholds yield < {_MIN_PCS} PCs)" - sweep_rows, peak_t, n_pcs = result - - # --- Build AnnData at selected n_pcs and aggregate --- - X_reduced = X_pcs[:, :n_pcs].astype(np.float32) - pc_names = [f"{signal}_PC{j}" for j in range(n_pcs)] - del X_pcs - - # Compute n_experiments per guide from obs_full before dropping the experiment column. 
- # Injected directly into g.obs after guide aggregation (aggregate_to_level at guide level - # does not carry arbitrary cell-obs columns through, so cell-level injection is lost). - guide_col_name = inferred_guide_col or _guide_col(adata_cells) - guide_to_n_exp = obs_full.groupby(guide_col_name)["experiment"].nunique() - - # Drop experiment column before aggregation unless preserving batch info - # (copairs is incompatible with extra string cols, but preserve_batch skips copairs scoring) - if preserve_batch: - obs_for_agg = obs_full - else: - obs_for_agg = obs_full[[c for c in obs_full.columns if c != "experiment"]] - adata_reduced = ad.AnnData( - X=X_reduced, - obs=obs_for_agg, - var=pd.DataFrame(index=pc_names), - ) - adata_reduced.uns["guide_col"] = guide_col_name - del X_reduced - - if save_cell_level: - adata_cell = ad.AnnData( - X=adata_reduced.X, - obs=obs_full.copy(), - var=adata_reduced.var.copy(), - ) - adata_cell.uns["guide_col"] = guide_col_name - - g = aggregate_to_level( - adata_reduced, level="guide", method="mean", preserve_batch_info=preserve_batch - ) - e = aggregate_to_level( - adata_reduced, level="gene", method="mean", preserve_batch_info=preserve_batch - ) - del adata_reduced - - g.X = g.X.astype(np.float32) - e.X = e.X.astype(np.float32) - - # Inject n_experiments into guide obs so Phase 2's guide→gene aggregation picks it up - # via aggregate_to_level's max() path (lines 804-808 of anndata_utils.py). 
- g.obs["n_experiments"] = ( - g.obs[guide_col_name].map(guide_to_n_exp).fillna(1).astype(int) - ) - g.uns["aggregation_method"] = "mean" - e.uns["aggregation_method"] = "mean" - - uns = { - "signal": signal, - "pca_applied": not no_pca, - "n_cells": int(n_cells), - "n_cells_pooled": int(n_cells_pooled), - "n_features_raw": int(n_feats), - "pca_feature_names": feature_names, - "experiments": ",".join(exp.split("_")[0] for exp, _ in exp_channel_pairs), - "exp_cell_counts": { - exp.split("_")[0]: int(cnt) for (exp, ch), cnt in exp_cell_counts.items() - }, - "channels": list({ch for _, ch in exp_channel_pairs}), - "cell_type": inferred_cell_type or "cell", - "embedding_type": feature_dir, - } - if no_pca: - uns["n_features"] = int(n_feats) - else: - uns.update( - { - "pca_threshold": float(peak_t), - "n_pcs": int(n_pcs), - "explained_variance": ( - float(cumvar[n_pcs - 1]) if n_pcs <= len(cumvar) else 1.0 - ), - "pca_components": pca_components[:n_pcs].tolist(), - } - ) - for adata in [g, e]: - adata.uns.update(uns) - - if save_cell_level: - adata_cell.uns.update(uns) - - from ops_model.post_process.anndata_processing.anndata_validator import ( - AnndataValidator, - ) - - _validator = AnndataValidator() - for _adata, _level in [(g, "guide"), (e, "gene")]: - _report = _validator.validate(_adata, level=_level, strict=False) - if not _report.is_valid: - _logger.warning( - f" {signal} {_level}-level AnnData failed validation:\n{_report.summary()}" - ) - else: - _logger.info( - f" {signal} {_level}-level AnnData passed validation ({_report.get_warning_count()} warnings)" - ) - - file_prefix = sanitize_signal_filename(signal) - output_suffix = ("_nopca" if no_pca else "") + ("_batch" if preserve_batch else "") - g.write_h5ad(per_channel_dir / f"{file_prefix}{output_suffix}_guide.h5ad") - e.write_h5ad(per_channel_dir / f"{file_prefix}{output_suffix}_gene.h5ad") - if save_cell_level: - adata_cell.write_h5ad(per_channel_dir / f"{file_prefix}{output_suffix}_cell.h5ad") - 
_logger.info( - f" {signal}: saved cell-level AnnData ({adata_cell.n_obs:,} cells × {adata_cell.n_vars} features)" - ) - if sweep_rows: - pd.DataFrame(sweep_rows).to_csv( - per_channel_dir / f"{file_prefix}{output_suffix}_sweep.csv", index=False - ) - - elapsed = time.time() - t_start - if no_pca: - _logger.info( - f" Done: {signal} in {elapsed:.0f}s — {n_feats} raw features (no PCA)" - ) - return f"SUCCESS: {signal} — {n_feats} raw features, no PCA ({n_cells:,}/{n_cells_pooled:,} cells)" - else: - _logger.info(f" Done: {signal} in {elapsed:.0f}s — {n_pcs} PCs @ {peak_t:.0%}") - return f"SUCCESS: {signal} — {n_pcs} PCs @ {peak_t:.0%} ({n_cells:,}/{n_cells_pooled:,} cells)" - - -# ============================================================================= -# PcaOptimizationCombiner -# ============================================================================= - - -class PcaOptimizationCombiner: - """Config-driven combiner that follows the pca_optimization two-phase pipeline. - - Phase 1 — per biological signal group: - Pool cells from all experiments → optionally downsample → optionally z-score → - fit PCA → select n_pcs (sweep or fixed cutoff) → save per-signal guide/gene h5ads. - - Phase 2 — aggregation: - Load per-signal guide h5ads → hconcat → NTC normalize → aggregate to gene → - optionally compute UMAP/PHATE → return (adata_guide, adata_gene). - """ - - def __init__(self, config: CombinationConfig) -> None: - self.config = config - self._cell_path_map = self._build_cell_path_map() - - # ------------------------------------------------------------------ - # Experiment resolution - # ------------------------------------------------------------------ - - def _build_cell_path_map(self) -> Optional[Dict[Tuple[str, str], str]]: - """Flatten config.cell_paths ({exp: {channel: path}}) to {(exp, channel): path}. - - Returns None when cell_paths is not set (standard discovery is used). 
- """ - if not self.config.cell_paths: - return None - out: Dict[Tuple[str, str], str] = {} - for exp, chan_map in self.config.cell_paths.items(): - if not isinstance(chan_map, dict): - raise ValueError( - f"cell_paths[{exp!r}] must be a mapping of channel -> path, " - f"got {type(chan_map).__name__}." - ) - for ch, path in chan_map.items(): - out[(exp, ch)] = path - return out - - def _resolve_experiments(self) -> List[Tuple[str, str]]: - """Return final (experiment, channel) list from config. - - If cell_paths is set: use its keys as the (experiment, channel) pairs - (discovery is bypassed; the explicit paths are used to load cells). - Else if auto_discover=True: scan base_dir via feature_discovery functions. - Else: flatten experiments_channels from config. - Applies reporters filter to the result in all cases. - """ - from ops_utils.data.feature_discovery import ( - discover_dino_experiments, - discover_cellprofiler_experiments, - ) - - if self._cell_path_map is not None: - pairs = list(self._cell_path_map.keys()) - logger.info( - f"Using {len(pairs)} explicit cell_paths entries " - "(path discovery bypassed)." - ) - if self.config.reporters: - pairs = [(exp, ch) for exp, ch in pairs if ch in self.config.reporters] - if not pairs: - raise ValueError("No experiment-channel pairs remain after filtering.") - return pairs - - if self.config.auto_discover: - if self.config.experiments_channels: - logger.warning( - "auto_discover=True and experiments_channels are both set; " - "ignoring experiments_channels and using auto_discover." 
- ) - storage_roots = [Path(self.config.base_dir)] - if self.config.feature_type == "cellprofiler": - pairs = discover_cellprofiler_experiments(storage_roots) - else: - pairs = discover_dino_experiments( - storage_roots, self.config.feature_dir - ) - logger.info( - f"Auto-discovered {len(pairs)} experiment-channel pairs from {self.config.base_dir}" - ) - else: - if not self.config.experiments_channels: - raise ValueError( - "experiments_channels must be set when auto_discover=False" - ) - pairs = [ - (exp, ch) - for exp, channels in self.config.experiments_channels.items() - for ch in channels - ] - - if self.config.reporters: - pairs = [(exp, ch) for exp, ch in pairs if ch in self.config.reporters] - - if not pairs: - raise ValueError("No experiment-channel pairs remain after filtering.") - - return pairs - - def _group_by_signal( - self, pairs: List[Tuple[str, str]] - ) -> Dict[str, List[Tuple[str, str]]]: - """Group (experiment, channel) pairs by biological signal label.""" - from ops_utils.data.feature_discovery import ( - get_channel_maps_path, - build_signal_groups, - ) - from ops_utils.data.feature_metadata import FeatureMetadata - - maps_path = get_channel_maps_path() - fm = FeatureMetadata(metadata_path=maps_path) - return build_signal_groups(pairs, fm) - - # ------------------------------------------------------------------ - # Auto target_n_cells computation - # ------------------------------------------------------------------ - - def _compute_auto_target( - self, signal_groups: Dict[str, List[Tuple[str, str]]] - ) -> int: - """Scan cell counts per signal group and return max(min_count, floor).""" - import h5py - from ops_utils.data.feature_discovery import ( - find_cell_h5ad_path, - get_channel_maps_path, - ) - - maps_path = get_channel_maps_path() - storage_roots = [Path(self.config.base_dir)] - feature_dir = self.config.feature_dir - - logger.info("Pre-scanning cell counts to compute auto target_n_cells...") - min_count = float("inf") - for signal, 
pairs in signal_groups.items(): - group_total = 0 - for exp, ch in pairs: - if self._cell_path_map and (exp, ch) in self._cell_path_map: - _p = Path(self._cell_path_map[(exp, ch)]) - cell_file = _p if _p.exists() else None - else: - cell_file = find_cell_h5ad_path( - exp, ch, storage_roots, feature_dir, maps_path - ) - if cell_file is not None: - try: - with h5py.File(cell_file, "r") as f: - group_total += f["X"].shape[0] - except Exception: - pass - if group_total > 0: - min_count = min(min_count, group_total) - logger.info(f" {signal}: {group_total:,} cells") - - if min_count == float("inf"): - min_count = _MIN_CELLS_FLOOR - - target = max(int(min_count), _MIN_CELLS_FLOOR) - logger.info(f" Auto target_n_cells = {target:,}") - return target - - # ------------------------------------------------------------------ - # PCA config builder - # ------------------------------------------------------------------ - - def _build_pca_config(self) -> Dict[str, Any]: - """Return pca config dict with default sweep thresholds injected.""" - pca_cfg = dict(self.config.pca) - if pca_cfg.get("selection", "sweep") == "sweep": - if self.config.feature_type == "cellprofiler": - pca_cfg["_sweep_thresholds"] = _SWEEP_THRESHOLDS_CP - else: - pca_cfg["_sweep_thresholds"] = _SWEEP_THRESHOLDS_DINO - return pca_cfg - - # ------------------------------------------------------------------ - # Phase 1 - # ------------------------------------------------------------------ - - def _run_phase1_local( - self, - signal_groups: Dict[str, List[Tuple[str, str]]], - output_dir: Path, - downsampling_config: Dict[str, Any], - cell_filter: Optional[Any] = None, - save_cell_level: bool = False, - ) -> None: - """Run Phase 1 sequentially in the calling process.""" - pca_cfg = self._build_pca_config() - norm_method = self.config.normalization.get("method", "ntc") - - for signal, pairs in signal_groups.items(): - result = _process_signal_group( - signal=signal, - exp_channel_pairs=pairs, - 
output_dir=str(output_dir), - base_dir=self.config.base_dir, - feature_dir=self.config.feature_dir, - pca_config=pca_cfg, - downsampling_config=downsampling_config, - norm_method=norm_method, - preserve_batch=self.config.preserve_batch, - no_pca=self.config.no_pca, - cell_filter=cell_filter, - save_cell_level=save_cell_level, - cell_path_map=self._cell_path_map, - ) - logger.info(f" {result}") - - def _run_phase1_slurm( - self, - signal_groups: Dict[str, List[Tuple[str, str]]], - output_dir: Path, - downsampling_config: Dict[str, Any], - cell_filter: Optional[Any] = None, - save_cell_level: bool = False, - ) -> None: - """Submit Phase 1 as parallel SLURM jobs and wait for completion.""" - from ops_utils.hpc.slurm_batch_utils import submit_parallel_jobs - from ops_utils.data.feature_discovery import sanitize_signal_filename - - pca_cfg = self._build_pca_config() - norm_method = self.config.normalization.get("method", "ntc") - slurm = self.config.slurm - - slurm_params = { - "timeout_min": slurm.get("time_minutes", 10), - "mem": slurm.get("memory", "100GB"), - "cpus_per_task": slurm.get("cpus", 16), - "slurm_partition": slurm.get("partition", "cpu,gpu"), - } - - jobs = [] - for signal, pairs in signal_groups.items(): - sig_safe = sanitize_signal_filename(signal)[:40] - jobs.append( - { - "name": f"pca_opt_{sig_safe}", - "func": _process_signal_group, - "kwargs": { - "signal": signal, - "exp_channel_pairs": pairs, - "output_dir": str(output_dir), - "base_dir": self.config.base_dir, - "feature_dir": self.config.feature_dir, - "pca_config": pca_cfg, - "downsampling_config": downsampling_config, - "norm_method": norm_method, - "preserve_batch": self.config.preserve_batch, - "no_pca": self.config.no_pca, - "cell_filter": cell_filter, - "save_cell_level": save_cell_level, - "cell_path_map": self._cell_path_map, - }, - } - ) - - logger.info(f"Submitting {len(jobs)} SLURM Phase 1 jobs...") - result = submit_parallel_jobs( - jobs_to_submit=jobs, - 
experiment="pca_optimization", - slurm_params=slurm_params, - log_dir="pca_optimization", - manifest_prefix="pca_opt", - wait_for_completion=True, - ) - - failed = result.get("failed", []) - if failed: - logger.warning(f"{len(failed)} Phase 1 job(s) failed: {failed}") - - # ------------------------------------------------------------------ - # Phase 2 - # ------------------------------------------------------------------ - - def _join_gene_panel_metadata( - self, - adata: "ad.AnnData", - panel_csv: Path = Path( - "/hpc/projects/icd.fast.ops/configs/annotated_gene_panel_July2025.csv" - ), - join_col: str = "perturbation", - ) -> None: - """Annotate ``adata.obs`` in-place with per-gene metadata from the - annotated_gene_panel CSV, joined on ``obs[join_col]`` (gene name) == - CSV ``Gene.name``. R-style column names (e.g. ``Priority..smaller.is.higher.``) - are normalized to underscores. Rows whose perturbation isn't in the - panel (e.g. NTCs, off-panel genes) get NaN.""" - import pandas as pd - - if not panel_csv.exists(): - logger.warning(f" Gene panel CSV not found: {panel_csv} — skipping") - return - if join_col not in adata.obs.columns: - logger.warning(f" adata.obs missing '{join_col}' — skipping panel metadata join") - return - panel = pd.read_csv(panel_csv) - if "Gene.name" not in panel.columns: - logger.warning(" Gene panel CSV missing 'Gene.name' — skipping") - return - # Drop the unnamed row-index column the CSV carries. 
- panel = panel.drop(columns=[c for c in panel.columns if c.startswith("Unnamed:")]) - panel = panel.drop_duplicates(subset=["Gene.name"]) - panel.columns = [ - c.replace("..", "_").replace(".", "_").rstrip("_") - for c in panel.columns - ] - panel = panel.set_index("Gene_name") - keys = adata.obs[join_col].astype(str).values - aligned = panel.reindex(keys) - aligned.index = adata.obs.index - n_matched = int(aligned.notna().any(axis=1).sum()) - logger.info( - f" Joined gene panel ({len(panel)} genes × {len(panel.columns)} cols): " - f"{n_matched}/{len(adata.obs)} obs rows matched on '{join_col}'" - ) - for col in aligned.columns: - if col in adata.obs.columns: - logger.debug(f" panel column '{col}' already in obs — overwriting") - adata.obs[col] = aligned[col].values - - def _compute_embeddings(self, adata: "ad.AnnData", embedding_config) -> None: - """Compute UMAP and/or PHATE directly (not via scanpy) and store in obsm.""" - import numpy as np - - X = adata.X.astype(np.float32) - n_obs = adata.n_obs - - if embedding_config.umap: - try: - from umap import UMAP - - nn = min(embedding_config.n_neighbors, n_obs - 1) - if nn >= 2: - logger.info(f"Computing UMAP ({n_obs} obs, n_neighbors={nn})...") - coords = UMAP( - n_components=2, n_neighbors=nn, random_state=42 - ).fit_transform(X) - adata.obsm["X_umap"] = coords.astype(np.float32) - logger.info(" UMAP complete.") - else: - logger.warning(f"Skipping UMAP: too few observations ({n_obs})") - except ImportError: - logger.warning("UMAP skipped: install umap-learn") - except Exception as e: - logger.warning(f"UMAP failed: {e}") - - if embedding_config.phate: - try: - import phate - - knn = min(15 if n_obs > 2000 else 10, n_obs - 1) - if knn >= 2: - logger.info(f"Computing PHATE ({n_obs} obs, knn={knn})...") - coords = phate.PHATE( - n_components=2, - knn=knn, - decay=15, - t="auto", - n_jobs=-1, - random_state=42, - verbose=0, - ).fit_transform(X) - adata.obsm["X_phate"] = coords.astype(np.float32) - logger.info(" PHATE 
complete.") - else: - logger.warning(f"Skipping PHATE: too few observations ({n_obs})") - except ImportError: - logger.warning("PHATE skipped: install phate") - except Exception as e: - logger.warning(f"PHATE failed: {e}") - - def _run_phase2(self, output_dir: Path) -> Tuple["ad.AnnData", "ad.AnnData"]: - """Load per-signal guide h5ads, hconcat, NTC normalize, aggregate, embed.""" - import numpy as np - from ops_model.features.anndata_utils import ( - hconcat_by_perturbation, - normalize_guide_adata, - aggregate_to_level, - ) - - per_channel_dir = output_dir / "per_channel" - guide_files = sorted(per_channel_dir.glob("*_guide.h5ad")) - - if not guide_files: - raise FileNotFoundError( - f"No per-signal guide h5ads found in {per_channel_dir}. " - "Ensure Phase 1 completed successfully before running Phase 2." - ) - - logger.info(f"Phase 2: loading {len(guide_files)} per-signal guide files...") - guide_blocks = [] - for gf in guide_files: - g = ad.read_h5ad(gf) - sig = g.uns.get("signal", gf.stem.replace("_guide", "")) - if sig == "unknown" or sig.startswith("(unmapped:"): - logger.warning(f" Skipping {gf.name}: unmapped signal ({sig!r})") - continue - guide_blocks.append(g) - logger.info(f" {sig}: {g.n_obs} guides × {g.n_vars} PCs") - - if not guide_blocks: - raise ValueError("No valid per-signal guide blocks loaded for Phase 2.") - - logger.info("Concatenating per-signal blocks horizontally...") - cell_type = guide_blocks[0].uns.get("cell_type", "cell") - embedding_type = guide_blocks[0].uns.get( - "embedding_type", self.config.feature_type - ) - - # Build pca_optimized_metadata before consuming guide_blocks - biological_groups = {} - feature_slices = {} - offset = 0 - for g in guide_blocks: - sig = g.uns["signal"] - n_feat = g.n_vars - biological_groups[sig] = { - "biological_signal": sig, - "aggregation_type": "pooled_pca", - "experiments": g.uns.get("experiments", "").split(","), - "channels": g.uns.get("channels", []), - "n_cells_per_experiment": 
g.uns.get("exp_cell_counts", {}), - "n_cells_total": g.uns.get("n_cells_pooled", 0), - "n_cells_used": g.uns.get("n_cells", 0), - "n_features_raw": g.uns.get("n_features_raw", 0), - "n_pcs": g.uns.get("n_pcs", n_feat), - "pca_threshold": g.uns.get("pca_threshold", None), - "explained_variance": g.uns.get("explained_variance", None), - "feature_range": [offset, offset + n_feat], - "n_features": n_feat, - } - feature_slices[sig] = { - "start": offset, - "end": offset + n_feat, - "n_features": n_feat, - } - offset += n_feat - - pca_optimized_metadata = { - "strategy": "pca_optimized", - "feature_type": self.config.feature_type, - "aggregation_level": "guide", - "n_biological_signals": len(biological_groups), - "biological_groups": biological_groups, - "feature_slices": feature_slices, - } - - adata_guide = hconcat_by_perturbation(guide_blocks, "guide") - del guide_blocks - - norm_method = self.config.normalization.get("method", "ntc") - logger.info(f"NTC normalizing at guide level (method={norm_method!r})...") - adata_guide = normalize_guide_adata(adata_guide, norm_method) - adata_guide.X = adata_guide.X.astype(np.float32) - - logger.info("Aggregating guide → gene...") - adata_gene = aggregate_to_level( - adata_guide, - "gene", - preserve_batch_info=False, - subsample_controls=False, - ) - logger.info(f" Guide: {adata_guide.n_obs} obs × {adata_guide.n_vars} features") - logger.info(f" Gene: {adata_gene.n_obs} obs × {adata_gene.n_vars} features") - - # Stamp required metadata fields (inferred from per-signal intermediates) - gene_metadata = {**pca_optimized_metadata, "aggregation_level": "gene"} - for adata, meta in [ - (adata_guide, pca_optimized_metadata), - (adata_gene, gene_metadata), - ]: - adata.uns["cell_type"] = cell_type - adata.uns["embedding_type"] = embedding_type - adata.uns["comprehensive_metadata"] = meta - adata.uns["aggregation_method"] = "mean" - - # Validate before returning - from ops_model.post_process.anndata_processing.anndata_validator import ( 
- AnndataValidator, - ) - - _validator = AnndataValidator() - for _adata, _level in [(adata_guide, "guide"), (adata_gene, "gene")]: - _report = _validator.validate(_adata, level=_level, strict=False) - if not _report.is_valid: - logger.warning( - f"Phase 2 {_level}-level AnnData failed validation:\n{_report.summary()}" - ) - else: - logger.info( - f"Phase 2 {_level}-level AnnData passed validation ({_report.get_warning_count()} warnings)" - ) - - # Embeddings on gene level - embedding_config = self.config.embeddings.get( - "gene_level" - ) or self.config.embeddings.get("guide_level") - if embedding_config is not None and embedding_config.compute_embeddings: - self._compute_embeddings(adata_gene, embedding_config) - - # Annotate obs with per-gene panel metadata (Funk/Ramezani/Replogle map - # coords, CORUM/REACT/GO membership, Gene_Category, NCBI_ID, …) on both - # levels. Guide-level joins via the guide's target gene in obs["perturbation"]. - logger.info("Joining annotated_gene_panel metadata into obs...") - for _adata, _level in [(adata_guide, "guide"), (adata_gene, "gene")]: - logger.info(f" ({_level} level)") - self._join_gene_panel_metadata(_adata) - - return adata_guide, adata_gene - - # ------------------------------------------------------------------ - # Orchestrator - # ------------------------------------------------------------------ - - def combine(self) -> Tuple["ad.AnnData", "ad.AnnData"]: - """Run the full two-phase pipeline and return (adata_guide, adata_gene).""" - output_dir = Path(self.config.output_path) - output_dir.mkdir(parents=True, exist_ok=True) - - # 1. Resolve experiments and group by biological signal - pairs = self._resolve_experiments() - signal_groups = self._group_by_signal(pairs) - - # 2. 
Resolve downsampling target (auto pre-scan happens here if needed) - downsampling_config = dict(self.config.downsampling) - if ( - downsampling_config.get("enabled", False) - and downsampling_config.get("target_n_cells", "auto") == "auto" - ): - downsampling_config["target_n_cells"] = self._compute_auto_target( - signal_groups - ) - - # 3. Phase 1: PCA sweep per signal group - from .cell_filters import build_cell_filter - - cell_filter = build_cell_filter(self.config.cell_filters) - save_cell_level = self.config.save_cell_level - - slurm_enabled = self.config.slurm.get("enabled", False) - if slurm_enabled: - self._run_phase1_slurm(signal_groups, output_dir, downsampling_config, cell_filter, save_cell_level) - else: - self._run_phase1_local(signal_groups, output_dir, downsampling_config, cell_filter, save_cell_level) - - # 4. Phase 2: aggregate, normalize, embed - if self.config.preserve_batch or self.config.no_pca: - logger.info( - f"Skipping Phase 2 aggregation (preserve_batch={self.config.preserve_batch}, no_pca={self.config.no_pca})." 
- ) - return None, None - - logger.info("Starting Phase 2 aggregation...") - adata_guide, adata_gene = self._run_phase2(output_dir) - - logger.info("PCA-optimized combination complete.") - return adata_guide, adata_gene diff --git a/src/ops_model/post_process/combination/config_handler.py b/src/ops_model/post_process/combination/config_handler.py deleted file mode 100644 index c64f27c..0000000 --- a/src/ops_model/post_process/combination/config_handler.py +++ /dev/null @@ -1,242 +0,0 @@ -import logging -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Union -import yaml -from pathlib import Path - -logger = logging.getLogger(__name__) - - -@dataclass -class EmbeddingConfig: - """Configuration for embeddings at a specific level (cell/guide/gene).""" - - compute_embeddings: bool = True - n_pca_components: int = 128 - n_neighbors: int = 15 - pca: bool = True - umap: bool = True - phate: bool = True - - def __post_init__(self): - """Validate embedding configuration.""" - # Note: UMAP can be computed on PCA features (pca=True) or raw features (pca=False) - # Both are valid - no forced validation needed - pass - - def get_embeddings_list(self) -> List[str]: - """Return list of enabled embedding types.""" - return [ - name - for name, enabled in [ - ("PCA", self.pca), - ("UMAP", self.umap), - ("PHATE", self.phate), - ] - if enabled and self.compute_embeddings - ] - - -@dataclass -class CombinationConfig: - """Main configuration object for combining experiments.""" - - # Core Fields - concatenation_method: str - feature_type: str - base_dir: str - feature_dir: str - output_path: Optional[str] = None - - # Validator Fields - these will be populated from input anndata objects - cell_type: Optional[str] = None - embedding_type: Optional[str] = None - - # Experiment/Channel Fields - experiments: Optional[List[str]] = None - experiments_channels: Optional[Dict[str, List[str]]] = None - channel: Optional[str] = None - - # Explicit 
cell-h5ad path override for pca_optimized: {experiment: {channel: path}}. - # When set, bypasses find_cell_h5ad_path discovery (the standard - # {base}/{exp}/3-assembly/{feature_dir}/anndata_objects layout) and uses these - # exact cell-level files. Its keys also define the (experiment, channel) pairs, - # so everything downstream (signal grouping, downsampling, PCA) is unchanged. - cell_paths: Optional[Dict[str, Dict[str, str]]] = None - - # Aggregation & Normalization - aggregation_level: Optional[str] = "cell" - aggregation_per_experiment: bool = False - aggregation_per_well: bool = False - normalization: Dict[str, Any] = field(default_factory=dict) - control_subsampling: Dict[str, Any] = field(default_factory=dict) - fitted_embeddings: Dict[str, Any] = field(default_factory=dict) - leiden_clustering: Dict[str, Any] = field(default_factory=dict) - - # Cell-level PCA optimization (pca_optimized method) - auto_discover: bool = False - reporters: Optional[List[str]] = None - pca: Dict[str, Any] = field(default_factory=dict) - downsampling: Dict[str, Any] = field(default_factory=dict) - slurm: Dict[str, Any] = field(default_factory=dict) - output_filename: Optional[str] = None - preserve_batch: bool = False - no_pca: bool = False - save_cell_level: bool = False - apply_iss_sidecar: bool = False # use ops_model.data.iss_drift_fix sidecars when loading cell h5ads - - # Cell-level filters applied after loading, before PCA/aggregation - cell_filters: List[Dict[str, Any]] = field(default_factory=list) - - # Embedding Configs - embeddings: Dict[str, EmbeddingConfig] = field(default_factory=dict) - - # Raw config for reference - raw_config: Dict[str, Any] = field(default_factory=dict, repr=False) - - def __post_init__(self): - # Initialize embedding configs from raw dict if present - if "embeddings" in self.raw_config and isinstance( - self.raw_config["embeddings"], dict - ): - self.embeddings = { - level: EmbeddingConfig(**params) - for level, params in 
self.raw_config["embeddings"].items() - } - else: - self.embeddings = { - "cell_level": EmbeddingConfig(), - "guide_level": EmbeddingConfig(), - "gene_level": EmbeddingConfig(), - } - - -def load_config( - config_path: Union[str, Path], output_path_override: Optional[str] = None -) -> CombinationConfig: - """ - Loads, validates, and processes the configuration file. - - Args: - config_path: Path to the YAML configuration file. - output_path_override: Optional path to override the output path in the config. - - Returns: - A validated CombinationConfig object. - """ - with open(config_path, "r") as f: - raw_config = yaml.safe_load(f) - - # Store raw config for embedding initialization - config_dict = raw_config.copy() - - # Auto-detection mode - if "feature_extraction_configs" in config_dict: - _auto_detect_experiments(config_dict) - - # Backwards compatibility for embedding settings - if "embedding_settings" in config_dict: - config_dict["embeddings"] = config_dict.pop("embedding_settings") - for level in ["cell_level", "guide_level", "gene_level"]: - if level not in config_dict["embeddings"]: - config_dict["embeddings"][level] = {} - - # Auto-generate output path if not specified - if not config_dict.get("output_path"): - config_dict["output_path"] = _generate_output_path(config_dict) - - if output_path_override: - config_dict["output_path"] = output_path_override - - # Create dataclass instance - # Pass raw_config separately to handle nested dataclasses - final_config_dict = { - k: v for k, v in config_dict.items() if k in CombinationConfig.__annotations__ - } - final_config_dict["raw_config"] = config_dict - - return CombinationConfig(**final_config_dict) - - -def _generate_output_path(config: Dict[str, Any]) -> str: - """Auto-generates an output path if not specified.""" - base_dir = config.get("base_dir", ".") - method = config.get("concatenation_method", "combined") - feature_type = config.get("feature_type", "features") - - experiment_part = 
"multi-experiment" - if config.get("experiments"): - if len(config["experiments"]) == 1: - experiment_part = config["experiments"][0] - else: - experiment_part = f"{config['experiments'][0]}_etc" - - filename = f"{method}_{feature_type}_{experiment_part}.h5ad" - return str(Path(base_dir) / "combined_anndata" / filename) - - -def normalize_feature_type(feature_type: str) -> str: - """Normalize feature type names to canonical form.""" - aliases = { - "dino": "dinov3", - "dinov3": "dinov3", - "cell_dino": "cell_dino", - "cellprofiler": "cellprofiler", - "cell-profiler": "cellprofiler", - } - canonical = aliases.get(feature_type.lower().strip()) - if canonical is None: - logger.warning(f"Unknown feature_type '{feature_type}' - using as-is") - return feature_type - return canonical - - -def _parse_extraction_config(config_path: Path) -> Optional[Tuple[str, str, List[str]]]: - """Parse a feature extraction config to extract metadata.""" - if not config_path.exists(): - raise FileNotFoundError(f"Config file not found: {config_path}") - - with open(config_path) as f: - config = yaml.safe_load(f) - - dm_config = config.get("data_manager", {}) - experiments = dm_config.get("experiments", {}) - if not experiments: - raise ValueError( - f"Config {config_path} has empty or missing 'data_manager.experiments'" - ) - - experiment_name = list(experiments.keys())[0] - feature_type = normalize_feature_type(config.get("model_type", "")) - out_channels = dm_config.get("out_channels", []) - - if isinstance(out_channels, str) and out_channels.lower() in ["random", "all"]: - logger.warning( - f"Skipping {config_path.name}: out_channels='{out_channels}' is not supported for auto-detection." 
- ) - return None - - channels = [out_channels] if isinstance(out_channels, str) else out_channels - return experiment_name, feature_type, channels - - -def _auto_detect_experiments(config_dict: Dict[str, Any]): - """Auto-detects experiments and channels from feature extraction configs.""" - logger.info("AUTO-DETECTION MODE: Using 'feature_extraction_configs'") - extraction_configs = config_dict["feature_extraction_configs"] - - detected_experiments = {} - for config_path_str in extraction_configs: - try: - result = _parse_extraction_config(Path(config_path_str)) - if result: - exp, _, channels = result - if exp not in detected_experiments: - detected_experiments[exp] = [] - detected_experiments[exp].extend(channels) - except (FileNotFoundError, ValueError) as e: - logger.error(f"Error parsing {config_path_str}: {e}") - raise - - config_dict["experiments_channels"] = detected_experiments - logger.info(f"Auto-detected {len(detected_experiments)} experiments.") diff --git a/src/ops_model/post_process/combination/file_validator.py b/src/ops_model/post_process/combination/file_validator.py deleted file mode 100644 index 9d53dc6..0000000 --- a/src/ops_model/post_process/combination/file_validator.py +++ /dev/null @@ -1,155 +0,0 @@ -import logging -from pathlib import Path -from typing import List, Optional, Tuple - -from .config_handler import CombinationConfig -from ops_utils.data.feature_metadata import FeatureMetadata - -# Initialize logger -logger = logging.getLogger(__name__) - - -class FilePathBuilder: - """Constructs file paths for AnnData objects based on configuration.""" - - def __init__(self, base_dir: str, feature_dir: str): - self.base_dir = Path(base_dir) - self.feature_dir = feature_dir - self.meta = FeatureMetadata() - - def get_anndata_path( - self, - experiment: str, - feature_type: str, - aggregation_level: str, - channel: Optional[str] = None, - reporter: Optional[str] = None, - ) -> Path: - """ - Constructs the path to a specific AnnData file. 
- """ - exp_short = experiment.split("_")[0] - try: - exp_dir = self._find_experiment_dir(exp_short) - except FileNotFoundError: - raise - - anndata_dir = exp_dir / "3-assembly" / self.feature_dir / "anndata_objects" - - # If reporter is not provided, derive it from the channel - if reporter is None and channel: - reporter = self.meta.get_biological_signal(exp_short, channel) - - # Determine filename based on aggregation and feature type - if aggregation_level == "cell": - filename = ( - f"features_processed_{reporter}.h5ad" - if reporter - else "features_processed.h5ad" - ) - elif aggregation_level in ["guide", "gene"]: - filename = ( - f"{aggregation_level}_bulked_{reporter}.h5ad" - if reporter - else f"{aggregation_level}_bulked.h5ad" - ) - else: - raise ValueError(f"Unknown aggregation level: {aggregation_level}") - - return anndata_dir / filename - - def _find_experiment_dir(self, exp_short_name: str) -> Path: - """ - Finds the full experiment directory, which may have a date suffix. - - Args: - exp_short_name: The short name of the experiment (e.g., "ops0089"). - - Returns: - The full Path to the experiment directory. - - Raises: - FileNotFoundError: If no matching directory is found. - """ - exp_dirs = list(self.base_dir.glob(f"{exp_short_name}*")) - if not exp_dirs: - raise FileNotFoundError( - f"Experiment directory not found for pattern: {exp_short_name}* in {self.base_dir}" - ) - # Assuming the first match is the correct one - return exp_dirs[0] - - -class FileValidator: - """Validates the existence of input AnnData files based on the configuration.""" - - def __init__(self, config: CombinationConfig): - self.config = config - self.builder = FilePathBuilder(config.base_dir, config.feature_dir) - - def validate_and_collect_files(self) -> List[Path]: - """ - Validates that required input files exist and returns a list of valid paths. - - Implements the "warn and continue" strategy. 
If a file is not found, a - warning is logged, and the file is omitted from the returned list. - - Returns: - A list of Path objects for all found AnnData files. - """ - logger.info("Starting input file validation...") - valid_paths = [] - method = self.config.concatenation_method - - if method == "vertical": - paths_to_check = self._get_paths_for_vertical() - elif method in ["horizontal", "comprehensive"]: - paths_to_check = self._get_paths_for_horizontal_or_comprehensive() - else: - raise ValueError(f"Unknown concatenation method for validation: {method}") - - for path in paths_to_check: - if path.exists(): - valid_paths.append(path) - logger.info(f"✓ Found file: {path}") - else: - logger.warning(f"✗ WARNING: File not found, skipping: {path}") - - if not valid_paths: - logger.error("No valid input files were found. Cannot proceed.") - - logger.info(f"Validation complete. Found {len(valid_paths)} valid input files.") - return valid_paths - - def _get_paths_for_vertical(self) -> List[Path]: - """Generate expected file paths for vertical combination.""" - paths = [] - for exp in self.config.experiments or []: - try: - path = self.builder.get_anndata_path( - experiment=exp, - feature_type=self.config.feature_type, - aggregation_level=self.config.aggregation_level or "cell", - channel=self.config.channel, - ) - paths.append(path) - except FileNotFoundError as e: - logger.warning(e) - return paths - - def _get_paths_for_horizontal_or_comprehensive(self) -> List[Path]: - """Generate expected file paths for horizontal/comprehensive combination.""" - paths = [] - for exp, channels in (self.config.experiments_channels or {}).items(): - for channel in channels: - try: - path = self.builder.get_anndata_path( - experiment=exp, - feature_type=self.config.feature_type, - aggregation_level="cell", # Horizontal is always cell level first - channel=channel, - ) - paths.append(path) - except FileNotFoundError as e: - logger.warning(e) - return paths diff --git 
a/src/ops_model/post_process/combination/pca_optimization/__init__.py b/src/ops_model/post_process/combination/pca_optimization/__init__.py index b6eb363..9187de4 100644 --- a/src/ops_model/post_process/combination/pca_optimization/__init__.py +++ b/src/ops_model/post_process/combination/pca_optimization/__init__.py @@ -17,7 +17,7 @@ 12 variants — feature type × channel subset ------------------------------------------- Each variant produces an independent output subtree and can be compared via -compare_map_scores.py. Replace --slurm with --aggregate-only --slurm to re-run +analysis/compare_map_scores.py. Replace --slurm with --aggregate-only --slurm to re-run Phase 2 only (e.g. after code changes) without redoing the PCA sweeps. Variant Flags Output subdir @@ -132,12 +132,12 @@ plot_positive_controls_grid, ) -from ops_model.post_process.combination.pca_optimization.chromosome import ( +from ops_model.post_process.combination.pipeline_add_ons.chromosome import ( _load_chromosome_map, _plot_chromosome_overlay, _plot_chromosome_overlay_html, ) -from ops_model.post_process.combination.pca_optimization.op_signal import ( +from ops_model.post_process.combination.pipeline_add_ons.op_signal import ( _discover_op_files, pca_sweep_op_signal, ) diff --git a/src/ops_model/post_process/combination/pca_optimization/embeddings.py b/src/ops_model/post_process/combination/pca_optimization/embeddings.py index c1f8234..b070d1a 100644 --- a/src/ops_model/post_process/combination/pca_optimization/embeddings.py +++ b/src/ops_model/post_process/combination/pca_optimization/embeddings.py @@ -30,7 +30,7 @@ from ops_model.post_process.combination.pca_optimization.aggregation import ( _annotate_genes_from_panel, ) -from ops_model.post_process.combination.pca_optimization.chromosome import ( +from ops_model.post_process.combination.pipeline_add_ons.chromosome import ( _load_chromosome_map, _plot_chromosome_overlay, _plot_chromosome_overlay_html, diff --git 
a/src/ops_model/post_process/combination/pca_optimization/handlers.py b/src/ops_model/post_process/combination/pca_optimization/handlers.py index 896ed5a..caafe20 100644 --- a/src/ops_model/post_process/combination/pca_optimization/handlers.py +++ b/src/ops_model/post_process/combination/pca_optimization/handlers.py @@ -46,7 +46,7 @@ _atomic_write_h5ad, _plot_chad_umap, ) -from ops_model.post_process.combination.pca_optimization.chromosome import ( +from ops_model.post_process.combination.pipeline_add_ons.chromosome import ( _load_chromosome_map, _plot_chromosome_overlay, _plot_chromosome_overlay_html, @@ -54,7 +54,7 @@ from ops_model.post_process.combination.pca_optimization.embeddings import ( _compute_and_plot_embeddings, ) -from ops_model.post_process.combination.pca_optimization.op_signal import ( +from ops_model.post_process.combination.pipeline_add_ons.op_signal import ( _discover_op_files, pca_sweep_op_signal, ) @@ -531,7 +531,7 @@ def _load_csv(name): corum_map = _load_csv("phenotypic_consistency_corum.csv") chad_map = _load_csv("phenotypic_consistency_manual.csv") - from ops_model.post_process.combination.embedding_overlays import save_extra_overlays + from ops_model.post_process.combination.analysis.embedding_overlays import save_extra_overlays save_extra_overlays( adata_guide=adata_guide, @@ -813,7 +813,7 @@ def _handle_overlays_only(args, output_dir): subdir = f"second_pca_{int(round(threshold * 100))}" # Honor the chrom-arm-correction suffix used by run_chrom_arm_then_second_pca if getattr(args, "chrom_arm_correct", False): - from ops_model.post_process.combination.guide_chrom_arm_correction import ( + from ops_model.post_process.combination.pipeline_add_ons.guide_chrom_arm_correction import ( METHOD_SUFFIX, ) method = getattr(args, "chrom_arm_method", "cohesion") @@ -824,7 +824,7 @@ def _handle_overlays_only(args, output_dir): # When --chrom-arm-correct is on but no --chromosome-csv was given, # auto-pipe the shared symbol→arm cache so the chr-arm 
overlay plot fires. if getattr(args, "chrom_arm_correct", False) and not getattr(args, "chromosome_csv", None): - from ops_model.post_process.combination.guide_chrom_arm_correction import ( + from ops_model.post_process.combination.pipeline_add_ons.guide_chrom_arm_correction import ( SHARED_MAP_CSV_PATH, ) if SHARED_MAP_CSV_PATH is not None and SHARED_MAP_CSV_PATH.is_file(): @@ -987,7 +987,7 @@ def run_second_pca_then_chrom_arm( clobbers the standard ``second_pca_consensus_chrom_arm_corr*/`` dirs produced by the original-order wrapper. """ - from ops_model.post_process.combination.guide_chrom_arm_correction import ( + from ops_model.post_process.combination.pipeline_add_ons.guide_chrom_arm_correction import ( run_chrom_arm_correction, METHOD_SUFFIX, ) @@ -1060,7 +1060,7 @@ def run_chrom_arm_then_second_pca( Keeping this as a top-level function (not a closure inside ``_handle_second_pca``) so submitit can pickle it for SLURM submission. """ - from ops_model.post_process.combination.guide_chrom_arm_correction import ( + from ops_model.post_process.combination.pipeline_add_ons.guide_chrom_arm_correction import ( run_chrom_arm_correction, METHOD_SUFFIX, ) @@ -1104,7 +1104,7 @@ def _handle_second_pca(args, output_dir): not chromosome_csv and getattr(args, "chrom_arm_correct", False) ): - from ops_model.post_process.combination.guide_chrom_arm_correction import ( + from ops_model.post_process.combination.pipeline_add_ons.guide_chrom_arm_correction import ( SHARED_MAP_CSV_PATH, ) if SHARED_MAP_CSV_PATH is not None and SHARED_MAP_CSV_PATH.is_file(): diff --git a/src/ops_model/post_process/combination/pca_optimization/phase2.py b/src/ops_model/post_process/combination/pca_optimization/phase2.py index 9ba8f70..257b570 100644 --- a/src/ops_model/post_process/combination/pca_optimization/phase2.py +++ b/src/ops_model/post_process/combination/pca_optimization/phase2.py @@ -339,7 +339,7 @@ def aggregate_channels( # Extra overlays (super-category, Leiden, interactive HTML) try: - 
from ops_model.post_process.combination.embedding_overlays import ( + from ops_model.post_process.combination.analysis.embedding_overlays import ( save_extra_overlays, ) @@ -1048,7 +1048,7 @@ def apply_second_pass_pca( # Extra overlays (super-category, Leiden, interactive HTML) try: - from ops_model.post_process.combination.embedding_overlays import ( + from ops_model.post_process.combination.analysis.embedding_overlays import ( save_extra_overlays, ) diff --git a/src/ops_model/post_process/combination/pipeline_add_ons/__init__.py b/src/ops_model/post_process/combination/pipeline_add_ons/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ops_model/post_process/combination/pca_optimization/chromosome.py b/src/ops_model/post_process/combination/pipeline_add_ons/chromosome.py similarity index 100% rename from src/ops_model/post_process/combination/pca_optimization/chromosome.py rename to src/ops_model/post_process/combination/pipeline_add_ons/chromosome.py diff --git a/src/ops_model/post_process/combination/guide_chrom_arm_correction.py b/src/ops_model/post_process/combination/pipeline_add_ons/guide_chrom_arm_correction.py similarity index 100% rename from src/ops_model/post_process/combination/guide_chrom_arm_correction.py rename to src/ops_model/post_process/combination/pipeline_add_ons/guide_chrom_arm_correction.py diff --git a/src/ops_model/post_process/combination/pca_optimization/op_signal.py b/src/ops_model/post_process/combination/pipeline_add_ons/op_signal.py similarity index 100% rename from src/ops_model/post_process/combination/pca_optimization/op_signal.py rename to src/ops_model/post_process/combination/pipeline_add_ons/op_signal.py From e8d1026adc3786ace9700e9011f1265d9f8c5a30 Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Wed, 10 Jun 2026 14:06:31 -0700 Subject: [PATCH 02/11] Add config entry point + external signal_paths ingest to pca_optimization Two additive features for the pca_optimization combination pipeline, both 
reusing the existing argparse Namespace + Phase 1/2 with no parallel schema. 1. --config <path>: run the pipeline from a config file whose keys are the CLI argument names (snake_case dest names). Config values populate argparse defaults via set_defaults; any explicit CLI flag still overrides. Adds run_from_config() (programmatic entry) and _load_and_validate_config() (rejects unknown keys + the phase_only/no_phase conflict). main() is split into main() (parse + config merge) and run(args) (the unchanged dispatch). Example at pca_optimization/example_config.yml. 2. signal_paths (a config key): combine cell-level embedding h5ads that live OUTSIDE the standard experiment layout. Maps a signal-group name -> one h5ad path or a list of paths (pooled); each h5ad uses the same schema as the discovery features_processed_*.h5ad. phase1.pca_sweep_pooled_signal gains an optional cell_paths override (explicit path instead of find_cell_h5ad_path); new handlers._handle_external builds signal groups from the manifest and reuses the pooled worker + Phase 2. Experiment discovery is skipped; output lands under <output_dir>/external/. Verified: 41/41 structural tests pass; external ingest validated end-to-end (two synthetic h5ads pooled into one signal -> per_signal guide/gene outputs). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../post_process/combination/README.md | 55 ++++++ .../combination/pca_optimization/__init__.py | 73 +++++++- .../pca_optimization/example_config.yml | 56 ++++++ .../combination/pca_optimization/handlers.py | 170 ++++++++++++++++++ .../combination/pca_optimization/parser.py | 22 +++ .../combination/pca_optimization/phase1.py | 19 +- 6 files changed, 388 insertions(+), 7 deletions(-) create mode 100644 src/ops_model/post_process/combination/pca_optimization/example_config.yml diff --git a/src/ops_model/post_process/combination/README.md b/src/ops_model/post_process/combination/README.md index 6f5272e..24be61f 100644 --- a/src/ops_model/post_process/combination/README.md +++ b/src/ops_model/post_process/combination/README.md @@ -47,6 +47,61 @@ python -m ops_model.post_process.combination.pca_optimization --help **Exactly one feature-mode flag is required** (no implicit default): `--cell-dino` · `--dino` · `--cell-profiler` · `--dynaclr` · `--subcell` · `--organelle-profiler`. +### Run from a config file + +Instead of (or alongside) CLI flags, pass a YAML config with `--config`: + +```bash +python -m ops_model.post_process.combination.pca_optimization --config my_run.yml +``` + +The config keys are the **same arguments**, written as their snake_case names +(`--cell-dino` → `cell_dino`, `--phase-only` → `phase_only`, `--output-dir` → `output_dir`). +Config values populate the defaults; **any flag passed explicitly on the command line still +overrides the config** (e.g. `--config my_run.yml --no-slurm`). To turn off a default-on flag +in the config, set it `false` (e.g. `second_pca: false`). Unknown keys are rejected. 
+ +A worked example (the validation-cohort run) is at +[`pca_optimization/example_config.yml`](pca_optimization/example_config.yml): + +```yaml +cell_dino: true +output_dir: /hpc/projects/icd.fast.ops/organelle_attribution/pca_optimized_v0.3 +experiments: ops0146,ops0147,ops0150,ops0151 +phase_only: true +fixed_threshold: 0.80 +slurm: true +``` + +Programmatic equivalent: + +```python +from ops_model.post_process.combination.pca_optimization import run_from_config +run_from_config("my_run.yml") +``` + +### Combine embeddings outside the experiment structure + +If your cell-level embeddings don't live in the standard +`…/3-assembly/<feature_dir>/anndata_objects/` layout, add a **`signal_paths`** mapping to the +config (no separate flag). It maps each signal-group name to one h5ad path, or a list of paths +that get **pooled**: + +```yaml +cell_dino: true # still pick a feature mode (for metadata/labels) +output_dir: /path/to/out +signal_paths: + Phase: /data/runA/phase.h5ad # one file = one signal + MAP4: # multiple files pooled into one signal + - /data/runA/map4.h5ad + - /data/runB/map4.h5ad +``` + +Each h5ad must have the **same schema as the discovery `features_processed_*.h5ad`** (obs with +`sgRNA` / `perturbation` / `experiment`; `X` = the embedding matrix). When `signal_paths` is set, +experiment discovery is skipped, every other option (PCA threshold, normalization, downsampling, +SLURM, second-pass PCA, …) applies as usual, and output lands under `<output_dir>/external/`. 
+ --- ## Inputs diff --git a/src/ops_model/post_process/combination/pca_optimization/__init__.py b/src/ops_model/post_process/combination/pca_optimization/__init__.py index 9187de4..0a5131b 100644 --- a/src/ops_model/post_process/combination/pca_optimization/__init__.py +++ b/src/ops_model/post_process/combination/pca_optimization/__init__.py @@ -192,6 +192,7 @@ _handle_aggregate_only, _handle_chad_umap_only, _handle_downsampled, + _handle_external, _handle_op, _handle_overlays_only, _handle_second_pca, @@ -266,6 +267,54 @@ +def _load_and_validate_config(config_path: str) -> dict: + """Load a YAML config and validate its keys against the CLI argument set. + + Keys must be argparse ``dest`` names (snake_case), so a config is just the + CLI args expressed as YAML (``--cell-dino`` → ``cell_dino``). Returns the + parsed dict; the caller feeds it to ``parser.set_defaults(**cfg)``. + """ + import yaml + + with open(config_path) as f: + cfg = yaml.safe_load(f) or {} + if not isinstance(cfg, dict): + raise ValueError( + f"Config {config_path} must be a YAML mapping of arg→value, " + f"got {type(cfg).__name__}." + ) + valid_dests = { + a.dest for a in _build_parser()._actions if a.dest not in ("help", "config") + } + unknown = sorted(set(cfg) - valid_dests) + if unknown: + raise ValueError( + f"Unknown config key(s): {unknown}. Keys must match CLI argument names " + f"as snake_case dest names (e.g. cell_dino, phase_only, output_dir, " + f"fixed_threshold). Run the module with --help for the full list." + ) + # set_defaults bypasses argparse's mutually-exclusive-group check, so guard + # the one pair a config can realistically set together. (The "exactly one + # feature-mode flag" rule is still enforced in run() below.) + if cfg.get("phase_only") and cfg.get("no_phase"): + raise ValueError( + "Config sets both phase_only and no_phase (mutually exclusive)." 
+ ) + return cfg + + +def run_from_config(config_path: str): + """Programmatic entry point: run the pipeline from a YAML config (no CLI). + + Equivalent to ``--config `` on the command line. See + ``pca_optimization/example_config.yml`` for the key set. + """ + cfg = _load_and_validate_config(config_path) + parser = _build_parser() + parser.set_defaults(**cfg) + run(parser.parse_args([])) + + def main(): # Force line-buffered stdout so progress prints appear in real time when # launched under `uv run`, `nohup`, or any other wrapper that pipes @@ -277,8 +326,19 @@ def main(): except (AttributeError, ValueError): pass - global CHAD_ANNOTATION_PATH, EBI_ANNOTATION_PATH args = _build_parser().parse_args() + if getattr(args, "config", None): + # Config file populates argparse defaults; any explicit CLI flag still + # overrides it (re-parse the same argv against the config-seeded parser). + cfg = _load_and_validate_config(args.config) + parser = _build_parser() + parser.set_defaults(**cfg) + args = parser.parse_args() + run(args) + + +def run(args): + global CHAD_ANNOTATION_PATH, EBI_ANNOTATION_PATH CHAD_ANNOTATION_PATH = args.chad_annotation EBI_ANNOTATION_PATH = args.ebi_annotation # --seed default depends on --umap-type: max → 1 (Max's recipe), gav → 42 (legacy). @@ -287,6 +347,17 @@ def main(): print(f"--seed unset, resolved to {args.seed} (umap_type={args.umap_type})") output_dir = Path(args.output_dir) + # External mode: combine explicit per-signal h5ads given in the config's + # `signal_paths` mapping (embeddings outside the experiment structure). + # Bypasses the feature-mode requirement and experiment discovery. 
+ if getattr(args, "signal_paths", None): + args.phase_filter = None + args.all_cells = not getattr(args, "downsampled", False) + out = output_dir if args.direct else output_dir / "external" + print(f"External mode (signal_paths): output → {out}") + _handle_external(args, out) + return + # --only-4i / --only-cp imply the corresponding --with-* and turn off the # standard scan. Apply once here so both --direct and standard paths see it. only_4i = getattr(args, "only_4i", False) diff --git a/src/ops_model/post_process/combination/pca_optimization/example_config.yml b/src/ops_model/post_process/combination/pca_optimization/example_config.yml new file mode 100644 index 0000000..2082031 --- /dev/null +++ b/src/ops_model/post_process/combination/pca_optimization/example_config.yml @@ -0,0 +1,56 @@ +# Example config for the pca_optimization combination pipeline. +# +# Run with: +# python -m ops_model.post_process.combination.pca_optimization --config +# +# Keys are the CLI argument names as snake_case (argparse "dest" names): +# --cell-dino -> cell_dino +# --phase-only -> phase_only +# --output-dir -> output_dir +# --fixed-threshold -> fixed_threshold +# Anything omitted here falls back to the CLI default. Any flag passed on the +# command line overrides the value here (e.g. `--config this.yml --no-slurm`). +# Run the module with --help for the full key list. +# +# This example reproduces the validation-cohort run from the module docstring. 
+ +# --- Feature mode (exactly ONE of: cell_dino / dino / cell_profiler / dynaclr / subcell / organelle_profiler) --- +cell_dino: true + +# --- Output --- +output_dir: /hpc/projects/icd.fast.ops/organelle_attribution/pca_optimized_v0.3 +run_tag: paper_v1/validation_4exp_phase_only # organizational output subfolder + +# --- Experiment / channel selection --- +experiments: ops0146,ops0147,ops0150,ops0151 # comma-separated short names; omit for full discovery +phase_only: true # Phase (brightfield) channels only; or set no_phase: true + +# --- PCA --- +fixed_threshold: 0.80 # single variance cutoff; set 0 to run the full consensus sweep + +# --- Normalization --- +norm_method: ntc # ntc | global +zscore_per_experiment: true + +# --- Consistency-score annotation --- +chad_annotation: /hpc/projects/icd.fast.ops/configs/gene_clusters/val_library_chad_positive_controls_v1.yml + +# --- Execution --- +slurm: true # false → run Phase 1 + Phase 2 in-process + +# To turn OFF a default-on flag, set it false, e.g.: +# second_pca: false + +# --- External embeddings (optional) --- +# To combine embeddings that live OUTSIDE the standard experiment layout, set +# `signal_paths` here instead of using experiment discovery. It maps a +# signal-group name -> one h5ad path or a list of paths (pooled). Each h5ad must +# have the same schema as the discovery features_processed_*.h5ad (obs: sgRNA / +# perturbation / experiment; X = embedding). When set, `experiments` / feature +# discovery are ignored and output lands under <output_dir>/external/. 
+# +# signal_paths: +# Phase: /data/runA/phase_embeddings.h5ad # one file = one signal +# MAP4: # multiple files pooled +# - /data/runA/map4_embeddings.h5ad +# - /data/runB/map4_embeddings.h5ad diff --git a/src/ops_model/post_process/combination/pca_optimization/handlers.py b/src/ops_model/post_process/combination/pca_optimization/handlers.py index caafe20..18cad89 100644 --- a/src/ops_model/post_process/combination/pca_optimization/handlers.py +++ b/src/ops_model/post_process/combination/pca_optimization/handlers.py @@ -1412,6 +1412,176 @@ def _op_job_kwargs(signal: str, path: Path) -> Dict: ) +def _handle_external(args, output_dir): + """External mode: combine explicit per-signal h5ads from config ``signal_paths``. + + ``args.signal_paths`` maps a signal-group name -> one h5ad path or a list of + paths. Each h5ad must have the same schema as the discovery + ``features_processed_*.h5ad`` (obs with sgRNA / perturbation / experiment; + ``X`` = embedding). Multiple paths under one signal are pooled. Experiment + discovery is bypassed; the standard ``pca_sweep_pooled_signal`` worker (with + an explicit-path override) + Phase 2 aggregation are reused unchanged. + """ + ds_output_dir = output_dir + ds_output_dir.mkdir(parents=True, exist_ok=True) + + if getattr(args, "clean", False): + import shutil + + per_signal_dir = ds_output_dir / "per_signal" + if per_signal_dir.exists(): + print(f"--clean: removing {per_signal_dir}") + shutil.rmtree(per_signal_dir) + + spec = args.signal_paths + if not isinstance(spec, dict) or not spec: + print("ERROR: signal_paths must be a non-empty mapping of signal -> path(s).") + return + + # Build signal groups + an explicit (exp_label, channel) -> path override. + # Each file becomes its own synthetic "experiment" batch (channel == signal), + # so per-experiment z-scoring treats each file independently. 
+ signal_groups: Dict[str, list] = {} + cell_path_map: Dict[tuple, str] = {} + missing: list = [] + for signal, paths in spec.items(): + if isinstance(paths, (str, Path)): + paths = [paths] + pairs = [] + for p in paths: + p = Path(p) + if not p.exists(): + missing.append(str(p)) + continue + exp_label = p.stem + pairs.append((exp_label, signal)) + cell_path_map[(exp_label, signal)] = str(p) + if pairs: + signal_groups[signal] = pairs + + if missing: + print("ERROR: signal_paths references missing file(s):") + for m in missing: + print(f" - {m}") + return + if not signal_groups: + print("ERROR: no usable signal_paths entries.") + return + + print(f"External mode: {len(signal_groups)} signal group(s) from explicit paths:") + for sig, pairs in signal_groups.items(): + print(f" {sig}: {len(pairs)} file(s)") + + if getattr(args, "dry_run", False): + print("\n--dry-run: not processing.") + return + + # External files are user-provided; default to keeping all cells unless the + # user opts into downsampling (--target-cells / --downsampled). 
+ target_n_cells = int(getattr(args, "target_cells", 0) or 0) or 10_000_000 + skip_phase2 = getattr(args, "preserve_batch", False) or getattr(args, "no_pca", False) + + def _job_kwargs(signal, pairs): + kwargs = dict( + signal=signal, + exp_channel_pairs=pairs, + output_dir=str(ds_output_dir), + target_n_cells=target_n_cells, + norm_method=args.norm_method, + random_seed=getattr(args, "seed", 42), + distance=args.distance, + zscore_per_experiment=getattr(args, "zscore_per_experiment", False), + exclude_dud_guides=getattr(args, "exclude_dud_guides", True), + cell_paths=cell_path_map, + ) + if args.fixed_threshold is not None and args.fixed_threshold > 0: + kwargs["fixed_threshold"] = args.fixed_threshold + if getattr(args, "preserve_batch", False): + kwargs["preserve_batch"] = True + if getattr(args, "no_pca", False): + kwargs["no_pca"] = True + if getattr(args, "agg_method", "mean") != "mean": + kwargs["agg_method"] = args.agg_method + if getattr(args, "downsample_per_guide", False): + kwargs["downsample_per_guide"] = True + kwargs["cells_per_guide"] = getattr(args, "cells_per_guide", 250) + return kwargs + + if not args.slurm: + print("\nRunning locally (sequential)...") + for signal, pairs in signal_groups.items(): + print(f" {pca_sweep_pooled_signal(**_job_kwargs(signal, pairs))}") + if not skip_phase2: + print( + aggregate_channels( + output_dir=str(ds_output_dir), + norm_method=args.norm_method, + per_unit_subdir="per_signal", + distance=args.distance, + random_seed=getattr(args, "seed", 42), + agg_method=getattr(args, "agg_method", "mean"), + chromosome_csv=getattr(args, "chromosome_csv", None), + umap_type=getattr(args, "umap_type", "max"), + ) + ) + if getattr(args, "second_pca", False): + print("\nChaining 2nd-pass PCA on aggregate output...") + _handle_second_pca(args, ds_output_dir) + return + + # SLURM: one job per signal group, wait, then a chained aggregation job. 
+ from ops_utils.hpc.slurm_batch_utils import submit_parallel_jobs + + jobs = [ + { + "name": f"pca_ext_{sanitize_signal_filename(sig)[:40]}", + "func": pca_sweep_pooled_signal, + "kwargs": _job_kwargs(sig, pairs), + "metadata": {"signal": sig, "n_files": len(pairs)}, + } + for sig, pairs in signal_groups.items() + ] + slurm_params = _make_slurm_params(args) + print( + f"\nSubmitting {len(jobs)} external signal job(s) " + f"({slurm_params['mem']} each, {slurm_params['timeout_min']}min)..." + ) + res = submit_parallel_jobs( + jobs_to_submit=jobs, + experiment="pca_ext", + slurm_params=slurm_params, + log_dir="pca_optimization", + manifest_prefix="pca_ext", + wait_for_completion=True, + ) + if res.get("failed"): + print(f"\n{len(res['failed'])} external signal job(s) failed. Check logs.") + return + print(f"\nAll {len(jobs)} external signal job(s) complete") + + if skip_phase2: + print(" Phase 2 aggregation skipped (--preserve-batch or --no-pca mode)") + return + + sp_kwargs = _build_second_pca_kwargs(args) + print("\nSubmitting aggregation SLURM job...") + _submit_aggregation_slurm( + str(ds_output_dir), + args.norm_method, + "per_signal", + _make_agg_slurm_params(args), + "pca_ext_aggregation", + "pca_ext_agg", + distance=args.distance, + second_pca_kwargs=sp_kwargs, + random_seed=getattr(args, "seed", 42), + agg_method=getattr(args, "agg_method", "mean"), + chromosome_csv=getattr(args, "chromosome_csv", None), + umap_type=getattr(args, "umap_type", "max"), + consensus_metrics=getattr(args, "second_pca_consensus_metrics", None), + ) + + def _handle_downsampled(args, output_dir, cp_override): """Pool cells by signal group, downsample, PCA sweep (local or SLURM).""" from ops_model.post_process.combination.pca_optimization import ( diff --git a/src/ops_model/post_process/combination/pca_optimization/parser.py b/src/ops_model/post_process/combination/pca_optimization/parser.py index ad66825..68bd9ff 100644 --- 
a/src/ops_model/post_process/combination/pca_optimization/parser.py +++ b/src/ops_model/post_process/combination/pca_optimization/parser.py @@ -25,6 +25,28 @@ def _build_parser(): parser = argparse.ArgumentParser( description="Per-signal pooled PCA optimization for organelle attribution" ) + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to a YAML config whose keys are the CLI argument names as " + "snake_case dest names (e.g. cell_dino, phase_only, output_dir, " + "experiments, fixed_threshold, slurm). Config values populate the " + "defaults; any flag passed explicitly on the command line overrides the " + "config. See pca_optimization/example_config.yml.", + ) + parser.add_argument( + "--signal-paths", + dest="signal_paths", + default=None, + help="(Config-only) Combine embeddings that live OUTSIDE the standard " + "experiment layout. Set in the --config YAML under `signal_paths:` as a " + "mapping of signal-group name -> h5ad path (or list of paths to pool). " + "Each h5ad must have the same schema as the discovery " + "features_processed_*.h5ad (obs: sgRNA / perturbation / experiment; " + "X = embedding). When set, experiment discovery is skipped. Output → " + "/external/.", + ) parser.add_argument( "-o", "--output-dir", diff --git a/src/ops_model/post_process/combination/pca_optimization/phase1.py b/src/ops_model/post_process/combination/pca_optimization/phase1.py index 8c3b55d..c609ee7 100644 --- a/src/ops_model/post_process/combination/pca_optimization/phase1.py +++ b/src/ops_model/post_process/combination/pca_optimization/phase1.py @@ -53,6 +53,7 @@ def pca_sweep_pooled_signal( cells_per_guide: int = 250, agg_method: str = "mean", apply_iss_sidecar: bool = False, + cell_paths: Optional[Dict[Tuple[str, str], str]] = None, ) -> str: """PCA variance sweep on pooled & downsampled cells for a biological signal. 
@@ -110,12 +111,19 @@ def pca_sweep_pooled_signal( f"Processing signal group: {signal} ({n_exps} experiments, features: {feature_dir})" ) + def _resolve(exp_: str, ch_: str): + """Explicit-path override (external mode) else standard discovery.""" + if cell_paths and (exp_, ch_) in cell_paths: + _pp = Path(cell_paths[(exp_, ch_)]) + return _pp if _pp.exists() else None + return find_cell_h5ad_path(exp_, ch_, storage_roots, feature_dir, maps_path) + # --- Pass 1: Lightweight pre-scan for cell counts (no full matrix load) --- import h5py exp_cell_counts = {} # (exp, ch) -> n_cells for exp, ch in exp_channel_pairs: - cell_file = find_cell_h5ad_path(exp, ch, storage_roots, feature_dir, maps_path) + cell_file = _resolve(exp, ch) if cell_file is not None: try: with h5py.File(cell_file, "r") as f: @@ -140,7 +148,7 @@ def pca_sweep_pooled_signal( for exp, ch in exp_channel_pairs: if (exp, ch) not in exp_cell_counts or exp_cell_counts[(exp, ch)] == 0: continue - path = find_cell_h5ad_path(exp, ch, storage_roots, feature_dir, maps_path) + path = _resolve(exp, ch) if path is None: continue try: @@ -240,15 +248,14 @@ def pca_sweep_pooled_signal( # carry stale labels into PCA + aggregation. See # ops_model.data.iss_drift_fix for how the sidecars are built. 
from ops_model.features.anndata_utils import load_features_corrected - cell_path = find_cell_h5ad_path( - exp, ch, storage_roots, feature_dir, maps_path - ) + cell_path = _resolve(exp, ch) adata = ( load_features_corrected(cell_path, drop_orphans=True) if cell_path is not None else None ) else: - adata = load_cell_h5ad(exp, ch, storage_roots, feature_dir, maps_path) + _cp = _resolve(exp, ch) + adata = ad.read_h5ad(_cp) if _cp is not None else None if adata is None: continue From 01f0d692d86d4489f242d6abbd806ae9c3599943 Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Wed, 10 Jun 2026 14:35:59 -0700 Subject: [PATCH 03/11] Deprecate eval subdir and remove its tests Move src/ops_model/eval into src/ops_model/deprecated/eval (following the existing deprecated/ convention) and delete tests/eval. Also drop the now-dead run_eval console-script entry point from pyproject.toml. Co-Authored-By: Claude Opus 4.8 (1M context) --- pyproject.toml | 3 - .../{ => deprecated}/eval/__init__.py | 0 .../{ => deprecated}/eval/eval_io.py | 0 .../{ => deprecated}/eval/evaluate_gene.py | 0 .../{ => deprecated}/eval/evaluate_guide.py | 0 .../{ => deprecated}/eval/metrics.py | 0 .../{ => deprecated}/eval/run_eval.py | 0 tests/eval/__init__.py | 0 tests/eval/conftest.py | 172 --------------- tests/eval/test_evaluate_gene.py | 201 ------------------ tests/eval/test_evaluate_guide.py | 179 ---------------- tests/eval/test_metrics.py | 83 -------- tests/eval/test_run_eval.py | 138 ------------ 13 files changed, 776 deletions(-) rename src/ops_model/{ => deprecated}/eval/__init__.py (100%) rename src/ops_model/{ => deprecated}/eval/eval_io.py (100%) rename src/ops_model/{ => deprecated}/eval/evaluate_gene.py (100%) rename src/ops_model/{ => deprecated}/eval/evaluate_guide.py (100%) rename src/ops_model/{ => deprecated}/eval/metrics.py (100%) rename src/ops_model/{ => deprecated}/eval/run_eval.py (100%) delete mode 100644 tests/eval/__init__.py delete mode 100644 tests/eval/conftest.py 
delete mode 100644 tests/eval/test_evaluate_gene.py delete mode 100644 tests/eval/test_evaluate_guide.py delete mode 100644 tests/eval/test_metrics.py delete mode 100644 tests/eval/test_run_eval.py diff --git a/pyproject.toml b/pyproject.toml index b84a223..ed3724a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,9 +68,6 @@ dev = [ homepage = "https://github.com/ahillsley/ops_model" repository = "https://github.com/ahillsley/ops_model" -[project.scripts] -run_eval = "ops_model.eval.run_eval:main" - # Entry points # https://peps.python.org/pep-0621/#entry-points # [project.entry-points."spam.magical"] diff --git a/src/ops_model/eval/__init__.py b/src/ops_model/deprecated/eval/__init__.py similarity index 100% rename from src/ops_model/eval/__init__.py rename to src/ops_model/deprecated/eval/__init__.py diff --git a/src/ops_model/eval/eval_io.py b/src/ops_model/deprecated/eval/eval_io.py similarity index 100% rename from src/ops_model/eval/eval_io.py rename to src/ops_model/deprecated/eval/eval_io.py diff --git a/src/ops_model/eval/evaluate_gene.py b/src/ops_model/deprecated/eval/evaluate_gene.py similarity index 100% rename from src/ops_model/eval/evaluate_gene.py rename to src/ops_model/deprecated/eval/evaluate_gene.py diff --git a/src/ops_model/eval/evaluate_guide.py b/src/ops_model/deprecated/eval/evaluate_guide.py similarity index 100% rename from src/ops_model/eval/evaluate_guide.py rename to src/ops_model/deprecated/eval/evaluate_guide.py diff --git a/src/ops_model/eval/metrics.py b/src/ops_model/deprecated/eval/metrics.py similarity index 100% rename from src/ops_model/eval/metrics.py rename to src/ops_model/deprecated/eval/metrics.py diff --git a/src/ops_model/eval/run_eval.py b/src/ops_model/deprecated/eval/run_eval.py similarity index 100% rename from src/ops_model/eval/run_eval.py rename to src/ops_model/deprecated/eval/run_eval.py diff --git a/tests/eval/__init__.py b/tests/eval/__init__.py deleted file mode 100644 index e69de29..0000000 diff 
--git a/tests/eval/conftest.py b/tests/eval/conftest.py deleted file mode 100644 index 74a923d..0000000 --- a/tests/eval/conftest.py +++ /dev/null @@ -1,172 +0,0 @@ -"""Shared fixtures for eval tests.""" - -from __future__ import annotations - -import numpy as np -import pandas as pd -import pytest -import anndata as ad - - -# --------------------------------------------------------------------------- -# Factory functions (used directly in tests that need custom data) -# --------------------------------------------------------------------------- - -def make_guide_adata(perturbations: list[str], embeddings: np.ndarray) -> ad.AnnData: - """Build a minimal valid guide-level AnnData. - - Parameters - ---------- - perturbations : list of str - Perturbation label for each row. - embeddings : ndarray of shape (n_guides, n_features) - """ - n = len(perturbations) - assert embeddings.shape[0] == n - obs = pd.DataFrame( - { - "perturbation": perturbations, - "sgRNA": [f"sg_{i}" for i in range(n)], - "n_cells": [10] * n, - } - ) - obs.index = obs["sgRNA"].values - obs.index.name = None - adata = ad.AnnData(X=embeddings.astype(np.float32), obs=obs) - adata.uns["aggregation_method"] = "mean" - adata.uns["cell_type"] = "HeLa" - adata.uns["embedding_type"] = "test" - return adata - - -def make_gene_adata(perturbations: list[str], embeddings: np.ndarray) -> ad.AnnData: - """Build a minimal valid gene-level AnnData. - - Parameters - ---------- - perturbations : list of str - Perturbation label for each row (one per gene). 
- embeddings : ndarray of shape (n_genes, n_features) - """ - n = len(perturbations) - assert embeddings.shape[0] == n - obs = pd.DataFrame( - { - "perturbation": perturbations, - "n_cells": [20] * n, - "guides": [["sg_0", "sg_1"]] * n, - "n_experiments": [1] * n, - } - ) - obs.index = perturbations - adata = ad.AnnData(X=embeddings.astype(np.float32), obs=obs) - adata.uns["aggregation_method"] = "mean" - adata.uns["cell_type"] = "HeLa" - adata.uns["embedding_type"] = "test" - return adata - - -def make_clusters(perturbations: list[str]) -> dict: - """Build a clusters dict grouping consecutive pairs of perturbations. - - Parameters - ---------- - perturbations : list of str - Genes to group. Must have even length. - """ - assert len(perturbations) % 2 == 0, "perturbations must have even length" - clusters = {} - for i in range(0, len(perturbations), 2): - cluster_name = f"complex_{i // 2}" - clusters[cluster_name] = {"genes": [perturbations[i], perturbations[i + 1]]} - return clusters - - -def make_activity_map( - perturbations: list[str], - all_active: bool = True, - map_value: float = 1.0, -) -> pd.DataFrame: - """Build a minimal activity_map DataFrame as returned by phenotypic_activity_assesment.""" - return pd.DataFrame( - { - "perturbation": perturbations, - "below_corrected_p": [all_active] * len(perturbations), - "mean_average_precision": [map_value] * len(perturbations), - "corrected_p_value": [0.01 if all_active else 0.5] * len(perturbations), - } - ) - - -def make_consistency_map( - n_complexes: int, - all_significant: bool = True, - map_value: float = 1.0, -) -> pd.DataFrame: - """Build a minimal consistency_map DataFrame as returned by phenotypic_consistency_*.""" - return pd.DataFrame( - { - "complex_id": [f"c{i}" for i in range(n_complexes)], - "below_corrected_p": [all_significant] * n_complexes, - "mean_average_precision": [map_value] * n_complexes, - "corrected_p_value": [0.01 if all_significant else 0.5] * n_complexes, - } - ) - - -# 
--------------------------------------------------------------------------- -# pytest fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def guide_adata_identical(): - """2 perturbations × 2 guides each; identical embeddings within perturbation.""" - embeddings = np.array( - [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] - ) - return make_guide_adata(["A", "A", "B", "B"], embeddings) - - -@pytest.fixture -def guide_adata_orthogonal(): - """2 perturbations × 2 guides each; orthogonal embeddings within perturbation.""" - embeddings = np.array( - [ - [1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - return make_guide_adata(["A", "A", "B", "B"], embeddings) - - -@pytest.fixture -def gene_adata_identical(): - """4 genes in 2 complexes; identical embeddings within complex.""" - perturbations = ["GENE_A1", "GENE_A2", "GENE_B1", "GENE_B2"] - embeddings = np.array( - [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] - ) - return make_gene_adata(perturbations, embeddings) - - -@pytest.fixture -def gene_adata_orthogonal(): - """4 genes in 2 complexes; orthogonal embeddings within complex.""" - perturbations = ["GENE_A1", "GENE_A2", "GENE_B1", "GENE_B2"] - embeddings = np.array( - [ - [1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - return make_gene_adata(perturbations, embeddings) - - -@pytest.fixture -def gene_clusters_fixture(): - """Matching clusters dict for gene_adata_* fixtures.""" - return make_clusters(["GENE_A1", "GENE_A2", "GENE_B1", "GENE_B2"]) diff --git a/tests/eval/test_evaluate_gene.py b/tests/eval/test_evaluate_gene.py deleted file mode 100644 index 4dcaaf0..0000000 --- a/tests/eval/test_evaluate_gene.py +++ /dev/null @@ -1,201 +0,0 @@ -"""Tests for ops_model.eval.evaluate_gene.""" - -from __future__ import annotations - -from unittest.mock import patch - 
-import anndata as ad -import numpy as np -import pandas as pd -import pytest -import warnings - -from ops_model.eval.evaluate_gene import evaluate_gene_level -from tests.eval.conftest import ( - make_gene_adata, - make_activity_map, - make_consistency_map, - make_clusters, -) - -EXPECTED_KEYS = { - "pct_complexes_significant_manual", - "mean_map_complexes_manual", - "pct_complexes_significant_corum", - "mean_map_complexes_corum", - "mean_cosine_sim_within_complex", - "silhouette_within_complex", -} - -PERTURBATIONS = ["GENE_A1", "GENE_A2", "GENE_B1", "GENE_B2"] -CLUSTERS = make_clusters(PERTURBATIONS) - -_EMBEDDINGS_PERFECT = np.array( - [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] -) -_EMBEDDINGS_RANDOM = np.random.default_rng(1).random((4, 3)).astype(np.float32) - - -ACTIVITY_MAP = make_activity_map(PERTURBATIONS, all_active=True) - - -def _mock_all( - mock_manual, - mock_corum, - mock_load, - all_significant=True, - map_value=1.0, -): - consistency_df = make_consistency_map(2, all_significant=all_significant, map_value=map_value) - mock_manual.return_value = (consistency_df, float(all_significant)) - mock_corum.return_value = (consistency_df, float(all_significant)) - mock_load.return_value = CLUSTERS - - -# --------------------------------------------------------------------------- -# Coverage tests -# --------------------------------------------------------------------------- - -def test_raises_on_missing_n_cells(): - adata = ad.AnnData( - X=np.eye(3), - obs=pd.DataFrame( - { - "perturbation": ["A", "B", "C"], - "guides": [["sg_0"]] * 3, - "n_experiments": [1] * 3, - }, - index=["A", "B", "C"], - ), - ) - adata.uns["aggregation_method"] = "mean" - adata.uns["cell_type"] = "HeLa" - adata.uns["embedding_type"] = "test" - with pytest.raises(Exception): - evaluate_gene_level(adata) - - -def test_raises_on_missing_perturbation(): - adata = ad.AnnData( - X=np.eye(3), - obs=pd.DataFrame( - {"n_cells": [10, 10, 10], "guides": [["sg_0"]] * 3, 
"n_experiments": [1] * 3}, - index=["A", "B", "C"], - ), - ) - adata.uns["aggregation_method"] = "mean" - with pytest.raises(Exception): - evaluate_gene_level(adata) - - -@patch("ops_model.eval.evaluate_gene._load_gene_clusters") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_corum") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_manual_annotation") -def test_returns_expected_keys(mock_manual, mock_corum, mock_load): - _mock_all(mock_manual, mock_corum, mock_load) - adata = make_gene_adata(PERTURBATIONS, _EMBEDDINGS_PERFECT) - result = evaluate_gene_level(adata, activity_map=ACTIVITY_MAP) - assert set(result.keys()) == EXPECTED_KEYS - - -@patch("ops_model.eval.evaluate_gene._load_gene_clusters") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_corum") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_manual_annotation") -def test_runs_without_error(mock_manual, mock_corum, mock_load): - _mock_all(mock_manual, mock_corum, mock_load) - adata = make_gene_adata(PERTURBATIONS, _EMBEDDINGS_PERFECT) - evaluate_gene_level(adata, activity_map=ACTIVITY_MAP) # should not raise - - -@patch("ops_model.eval.evaluate_gene._load_gene_clusters") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_corum") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_manual_annotation") -def test_no_activity_map_warns_and_runs(mock_manual, mock_corum, mock_load): - """No activity_map → UserWarning issued, function still runs.""" - _mock_all(mock_manual, mock_corum, mock_load) - adata = make_gene_adata(PERTURBATIONS, _EMBEDDINGS_PERFECT) - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - evaluate_gene_level(adata) # no activity_map - assert any(issubclass(w.category, UserWarning) for w in caught) - - -# --------------------------------------------------------------------------- -# Metric correctness tests -# --------------------------------------------------------------------------- - 
-@patch("ops_model.eval.evaluate_gene._load_gene_clusters") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_corum") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_manual_annotation") -def test_perfect_complex_structure_map_metrics(mock_manual, mock_corum, mock_load): - """Mocked perfect structure → all mAP scalars are 1.0, all pct are 1.0.""" - _mock_all(mock_manual, mock_corum, mock_load, all_significant=True, map_value=1.0) - adata = make_gene_adata(PERTURBATIONS, _EMBEDDINGS_PERFECT) - result = evaluate_gene_level(adata, activity_map=ACTIVITY_MAP) - assert np.isclose(result["pct_complexes_significant_manual"], 1.0) - assert np.isclose(result["mean_map_complexes_manual"], 1.0) - assert np.isclose(result["pct_complexes_significant_corum"], 1.0) - assert np.isclose(result["mean_map_complexes_corum"], 1.0) - - -@patch("ops_model.eval.evaluate_gene._load_gene_clusters") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_corum") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_manual_annotation") -def test_no_structure_runs_without_error(mock_manual, mock_corum, mock_load): - _mock_all(mock_manual, mock_corum, mock_load, all_significant=False, map_value=0.5) - adata = make_gene_adata(PERTURBATIONS, _EMBEDDINGS_RANDOM) - evaluate_gene_level(adata, activity_map=ACTIVITY_MAP) # should not raise - - -@patch("ops_model.eval.evaluate_gene._load_gene_clusters") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_corum") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_manual_annotation") -def test_identical_complex_members_cosine_sim_one(mock_manual, mock_corum, mock_load): - """Identical embeddings within each complex → cosine sim = 1.0.""" - _mock_all(mock_manual, mock_corum, mock_load) - adata = make_gene_adata(PERTURBATIONS, _EMBEDDINGS_PERFECT) - result = evaluate_gene_level(adata, activity_map=ACTIVITY_MAP) - assert np.isclose(result["mean_cosine_sim_within_complex"], 1.0) - - 
-@patch("ops_model.eval.evaluate_gene._load_gene_clusters") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_corum") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_manual_annotation") -def test_orthogonal_complex_members_cosine_sim_zero(mock_manual, mock_corum, mock_load): - """Orthogonal embeddings within each complex → cosine sim = 0.0.""" - _mock_all(mock_manual, mock_corum, mock_load) - embeddings = np.array( - [ - [1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - adata = make_gene_adata(PERTURBATIONS, embeddings) - result = evaluate_gene_level(adata, activity_map=ACTIVITY_MAP) - assert np.isclose(result["mean_cosine_sim_within_complex"], 0.0) - - -@patch("ops_model.eval.evaluate_gene._load_gene_clusters") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_corum") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_manual_annotation") -def test_genes_absent_from_complexes_do_not_affect_cosine(mock_manual, mock_corum, mock_load): - """Genes not in any complex should be ignored in the cosine similarity score.""" - _mock_all(mock_manual, mock_corum, mock_load) - perturbations = PERTURBATIONS + ["EXTRA_GENE"] - embeddings = np.vstack([_EMBEDDINGS_PERFECT, [[0.5, 0.5, 0.0]]]) - adata = make_gene_adata(perturbations, embeddings) - result = evaluate_gene_level(adata, activity_map=ACTIVITY_MAP) - assert np.isclose(result["mean_cosine_sim_within_complex"], 1.0) - - -@patch("ops_model.eval.evaluate_gene._load_gene_clusters") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_corum") -@patch("ops_model.eval.evaluate_gene.phenotypic_consistency_manual_annotation") -def test_perfect_separation_silhouette_one(mock_manual, mock_corum, mock_load): - """Perfect complex structure → silhouette = 1.0.""" - _mock_all(mock_manual, mock_corum, mock_load) - adata = make_gene_adata(PERTURBATIONS, _EMBEDDINGS_PERFECT) - result = evaluate_gene_level(adata, 
activity_map=ACTIVITY_MAP) - assert np.isclose(result["silhouette_within_complex"], 1.0) diff --git a/tests/eval/test_evaluate_guide.py b/tests/eval/test_evaluate_guide.py deleted file mode 100644 index 1ac1c53..0000000 --- a/tests/eval/test_evaluate_guide.py +++ /dev/null @@ -1,179 +0,0 @@ -"""Tests for ops_model.eval.evaluate_guide.""" - -from __future__ import annotations - -from unittest.mock import patch - -import anndata as ad -import numpy as np -import pandas as pd -import pytest -import warnings - -from ops_model.eval.evaluate_guide import evaluate_guide_level -from tests.eval.conftest import ( - make_guide_adata, - make_activity_map, -) - -EXPECTED_KEYS = { - "pct_perturbations_active", - "mean_map_active", - "pct_pos_controls_active", - "mean_map_pos_controls", - "pct_perturbations_distinct", - "mean_map_distinct", - "mean_cosine_sim_within_gene", - "silhouette_within_gene", -} - -PERTURBATIONS = ["A", "A", "B", "B"] - -# Embeddings: perfect separation — A guides identical, B guides identical, A⊥B -_EMBEDDINGS_PERFECT = np.array( - [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] -) -_EMBEDDINGS_RANDOM = np.random.default_rng(0).random((4, 3)).astype(np.float32) - -POS_CONTROLS = {"complex1": {"genes": ["A", "B"]}} - - -def _mock_all(mock_activity, mock_distinct, mock_load, all_active=True, map_value=1.0): - activity_df = make_activity_map(["A", "B"], all_active=all_active, map_value=map_value) - mock_activity.return_value = (activity_df, float(all_active)) - distinct_df = make_activity_map(["A", "B"], all_active=all_active, map_value=map_value) - mock_distinct.return_value = (distinct_df, float(all_active)) - mock_load.return_value = POS_CONTROLS - - -# --------------------------------------------------------------------------- -# Coverage tests -# --------------------------------------------------------------------------- - -def test_raises_on_missing_sgRNA(): - adata = ad.AnnData( - X=np.eye(3), - obs=pd.DataFrame({"perturbation": 
["A", "B", "C"]}, index=["A", "B", "C"]), - ) - adata.uns["aggregation_method"] = "mean" - adata.uns["cell_type"] = "HeLa" - adata.uns["embedding_type"] = "test" - with pytest.raises(Exception): - evaluate_guide_level(adata) - - -def test_raises_on_missing_perturbation(): - adata = ad.AnnData( - X=np.eye(3), - obs=pd.DataFrame({"sgRNA": ["sg_0", "sg_1", "sg_2"]}, index=["sg_0", "sg_1", "sg_2"]), - ) - adata.uns["aggregation_method"] = "mean" - with pytest.raises(Exception): - evaluate_guide_level(adata) - - -@patch("ops_model.eval.evaluate_guide._load_pos_controls") -@patch("ops_model.eval.evaluate_guide.phenotypic_distinctivness") -@patch("ops_model.eval.evaluate_guide.phenotypic_activity_assesment") -def test_returns_expected_keys(mock_activity, mock_distinct, mock_load): - _mock_all(mock_activity, mock_distinct, mock_load) - adata = make_guide_adata(PERTURBATIONS, _EMBEDDINGS_PERFECT) - result, _ = evaluate_guide_level(adata) - assert set(result.keys()) == EXPECTED_KEYS - - -@patch("ops_model.eval.evaluate_guide._load_pos_controls") -@patch("ops_model.eval.evaluate_guide.phenotypic_distinctivness") -@patch("ops_model.eval.evaluate_guide.phenotypic_activity_assesment") -def test_returns_activity_map(mock_activity, mock_distinct, mock_load): - _mock_all(mock_activity, mock_distinct, mock_load) - adata = make_guide_adata(PERTURBATIONS, _EMBEDDINGS_PERFECT) - _, activity_map = evaluate_guide_level(adata) - assert isinstance(activity_map, pd.DataFrame) - assert "perturbation" in activity_map.columns - assert "below_corrected_p" in activity_map.columns - - -@patch("ops_model.eval.evaluate_guide._load_pos_controls") -@patch("ops_model.eval.evaluate_guide.phenotypic_distinctivness") -@patch("ops_model.eval.evaluate_guide.phenotypic_activity_assesment") -def test_runs_without_error(mock_activity, mock_distinct, mock_load): - _mock_all(mock_activity, mock_distinct, mock_load) - adata = make_guide_adata(PERTURBATIONS, _EMBEDDINGS_PERFECT) - evaluate_guide_level(adata) # 
should not raise - - -# --------------------------------------------------------------------------- -# Metric correctness tests -# --------------------------------------------------------------------------- - -@patch("ops_model.eval.evaluate_guide._load_pos_controls") -@patch("ops_model.eval.evaluate_guide.phenotypic_distinctivness") -@patch("ops_model.eval.evaluate_guide.phenotypic_activity_assesment") -def test_perfect_separation_map_metrics(mock_activity, mock_distinct, mock_load): - """Mocked perfect separation → all mAP scalars are 1.0, all pct are 1.0.""" - _mock_all(mock_activity, mock_distinct, mock_load, all_active=True, map_value=1.0) - adata = make_guide_adata(PERTURBATIONS, _EMBEDDINGS_PERFECT) - result, _ = evaluate_guide_level(adata) - assert np.isclose(result["pct_perturbations_active"], 1.0) - assert np.isclose(result["mean_map_active"], 1.0) - assert np.isclose(result["pct_pos_controls_active"], 1.0) - assert np.isclose(result["mean_map_pos_controls"], 1.0) - assert np.isclose(result["pct_perturbations_distinct"], 1.0) - assert np.isclose(result["mean_map_distinct"], 1.0) - - -@patch("ops_model.eval.evaluate_guide._load_pos_controls") -@patch("ops_model.eval.evaluate_guide.phenotypic_distinctivness") -@patch("ops_model.eval.evaluate_guide.phenotypic_activity_assesment") -def test_no_separation_runs_without_error(mock_activity, mock_distinct, mock_load): - _mock_all(mock_activity, mock_distinct, mock_load, all_active=False, map_value=0.5) - adata = make_guide_adata(PERTURBATIONS, _EMBEDDINGS_RANDOM) - evaluate_guide_level(adata) # should not raise (returns tuple, but we don't need it) - - -@patch("ops_model.eval.evaluate_guide._load_pos_controls") -@patch("ops_model.eval.evaluate_guide.phenotypic_distinctivness") -@patch("ops_model.eval.evaluate_guide.phenotypic_activity_assesment") -def test_identical_guides_cosine_sim_one(mock_activity, mock_distinct, mock_load): - """Identical guides within each perturbation → cosine sim = 1.0.""" - 
_mock_all(mock_activity, mock_distinct, mock_load) - embeddings = np.array( - [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] - ) - adata = make_guide_adata(PERTURBATIONS, embeddings) - result, _ = evaluate_guide_level(adata) - assert np.isclose(result["mean_cosine_sim_within_gene"], 1.0) - - -@patch("ops_model.eval.evaluate_guide._load_pos_controls") -@patch("ops_model.eval.evaluate_guide.phenotypic_distinctivness") -@patch("ops_model.eval.evaluate_guide.phenotypic_activity_assesment") -def test_orthogonal_guides_cosine_sim_zero(mock_activity, mock_distinct, mock_load): - """Orthogonal guides within each perturbation → cosine sim = 0.0.""" - _mock_all(mock_activity, mock_distinct, mock_load) - embeddings = np.array( - [ - [1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - adata = make_guide_adata(PERTURBATIONS, embeddings) - result, _ = evaluate_guide_level(adata) - assert np.isclose(result["mean_cosine_sim_within_gene"], 0.0) - - -@patch("ops_model.eval.evaluate_guide._load_pos_controls") -@patch("ops_model.eval.evaluate_guide.phenotypic_distinctivness") -@patch("ops_model.eval.evaluate_guide.phenotypic_activity_assesment") -def test_perfect_separation_silhouette_one(mock_activity, mock_distinct, mock_load): - """Perfect cluster separation → silhouette = 1.0.""" - _mock_all(mock_activity, mock_distinct, mock_load) - embeddings = np.array( - [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] - ) - adata = make_guide_adata(PERTURBATIONS, embeddings) - result, _ = evaluate_guide_level(adata) - assert np.isclose(result["silhouette_within_gene"], 1.0) diff --git a/tests/eval/test_metrics.py b/tests/eval/test_metrics.py deleted file mode 100644 index 9961851..0000000 --- a/tests/eval/test_metrics.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Tests for ops_model.eval.metrics.""" - -from __future__ import annotations - -import numpy as np -import pytest - -from ops_model.eval.metrics 
import mean_cosine_sim_within_groups -from tests.eval.conftest import make_guide_adata - - -# --------------------------------------------------------------------------- -# Coverage tests -# --------------------------------------------------------------------------- - -def test_returns_float(): - adata = make_guide_adata(["A", "A"], np.array([[1.0, 0.0], [1.0, 0.0]])) - result = mean_cosine_sim_within_groups(adata, [[0, 1]]) - assert isinstance(result, float) - - -def test_result_in_valid_range(): - adata = make_guide_adata(["A", "A"], np.array([[1.0, 0.0], [0.5, 0.5]])) - result = mean_cosine_sim_within_groups(adata, [[0, 1]]) - assert -1.0 <= result <= 1.0 - - -def test_skips_groups_smaller_than_two(): - # Both groups have 1 member — all skipped → NaN - adata = make_guide_adata(["A", "B"], np.eye(2)) - result = mean_cosine_sim_within_groups(adata, [[0], [1]]) - assert np.isnan(result) - - -def test_empty_groups_list_returns_nan(): - adata = make_guide_adata(["A"], np.array([[1.0, 0.0]])) - result = mean_cosine_sim_within_groups(adata, []) - assert np.isnan(result) - - -def test_single_group(): - adata = make_guide_adata(["A", "A"], np.array([[1.0, 0.0], [1.0, 0.0]])) - result = mean_cosine_sim_within_groups(adata, [[0, 1]]) - assert np.isclose(result, 1.0) - - -def test_multiple_groups_unequal_size(): - embeddings = np.array( - [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]] - ) - adata = make_guide_adata(["A", "A", "A", "B", "B"], embeddings) - result = mean_cosine_sim_within_groups(adata, [[0, 1, 2], [3, 4]]) - assert np.isclose(result, 1.0) - - -# --------------------------------------------------------------------------- -# Metric correctness tests (synthetic embeddings) -# --------------------------------------------------------------------------- - -def test_identical_vectors_returns_one(): - adata = make_guide_adata(["A", "A"], np.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])) - result = mean_cosine_sim_within_groups(adata, [[0, 1]]) - assert 
np.isclose(result, 1.0) - - -def test_orthogonal_vectors_returns_zero(): - adata = make_guide_adata(["A", "A"], np.array([[1.0, 0.0], [0.0, 1.0]])) - result = mean_cosine_sim_within_groups(adata, [[0, 1]]) - assert np.isclose(result, 0.0) - - -def test_anti_parallel_vectors_returns_minus_one(): - adata = make_guide_adata(["A", "A"], np.array([[1.0, 0.0], [-1.0, 0.0]])) - result = mean_cosine_sim_within_groups(adata, [[0, 1]]) - assert np.isclose(result, -1.0) - - -def test_known_analytic_value(): - # angle between [1,0] and [1,1]/sqrt(2) is 45° → cos = sqrt(2)/2 - embeddings = np.array([[1.0, 0.0], [1.0, 1.0]]) - adata = make_guide_adata(["A", "A"], embeddings) - result = mean_cosine_sim_within_groups(adata, [[0, 1]]) - assert np.isclose(result, np.sqrt(2) / 2, atol=1e-5) diff --git a/tests/eval/test_run_eval.py b/tests/eval/test_run_eval.py deleted file mode 100644 index c2e7e3d..0000000 --- a/tests/eval/test_run_eval.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Tests for ops_model.eval.run_eval CLI.""" - -from __future__ import annotations - -import csv -from pathlib import Path -from unittest.mock import patch, MagicMock - -import pytest - -from ops_model.eval.run_eval import main, _default_output_path - - -# --------------------------------------------------------------------------- -# Coverage tests -# --------------------------------------------------------------------------- - -def test_default_output_path_in_same_directory(tmp_path): - embedding = str(tmp_path / "embeddings.h5ad") - output = _default_output_path(embedding) - assert Path(output).parent == tmp_path - assert output.endswith("_eval.csv") - - -def test_raises_when_no_embedding_provided(monkeypatch): - monkeypatch.setattr("sys.argv", ["run_eval"]) - with pytest.raises(SystemExit): - main() - - -def test_guide_only_produces_guide_columns(tmp_path, monkeypatch): - output_csv = str(tmp_path / "out.csv") - guide_h5ad = str(tmp_path / "guide.h5ad") - guide_results = { - "pct_perturbations_active": 0.8, - 
"mean_map_active": 0.7, - "pct_pos_controls_active": 1.0, - "mean_map_pos_controls": 0.9, - "pct_perturbations_distinct": 0.6, - "mean_map_distinct": 0.65, - "mean_cosine_sim_within_gene": 0.9, - "silhouette_within_gene": 0.85, - } - - monkeypatch.setattr("sys.argv", ["run_eval", "--guide_embedding", guide_h5ad, "--output", output_csv]) - - with patch("ops_model.eval.run_eval.ad.read_h5ad", return_value=MagicMock()): - with patch("ops_model.eval.run_eval.evaluate_guide_level", return_value=(guide_results, MagicMock())): - main() - - with open(output_csv) as f: - reader = csv.DictReader(f) - rows = list(reader) - - assert len(rows) == 1 - assert "guide_embedding_path" in rows[0] - assert "gene_embedding_path" not in rows[0] - assert "pct_perturbations_active" in rows[0] - assert "pct_complexes_significant_manual" not in rows[0] - - -def test_gene_only_produces_gene_columns(tmp_path, monkeypatch): - output_csv = str(tmp_path / "out.csv") - gene_h5ad = str(tmp_path / "gene.h5ad") - gene_results = { - "pct_complexes_significant_manual": 0.9, - "mean_map_complexes_manual": 0.85, - "pct_complexes_significant_corum": 0.7, - "mean_map_complexes_corum": 0.75, - "mean_cosine_sim_within_complex": 0.8, - "silhouette_within_complex": 0.7, - } - - monkeypatch.setattr("sys.argv", ["run_eval", "--gene_embedding", gene_h5ad, "--output", output_csv]) - - with patch("ops_model.eval.run_eval.ad.read_h5ad", return_value=MagicMock()): - with patch("ops_model.eval.run_eval.evaluate_gene_level", return_value=gene_results): - main() - - with open(output_csv) as f: - reader = csv.DictReader(f) - rows = list(reader) - - assert len(rows) == 1 - assert "gene_embedding_path" in rows[0] - assert "guide_embedding_path" not in rows[0] - assert "pct_complexes_significant_manual" in rows[0] - assert "pct_perturbations_active" not in rows[0] - - -def test_both_embeddings_merged_into_one_row(tmp_path, monkeypatch): - output_csv = str(tmp_path / "out.csv") - guide_h5ad = str(tmp_path / "guide.h5ad") - 
gene_h5ad = str(tmp_path / "gene.h5ad") - guide_results = {"pct_perturbations_active": 0.8, "mean_cosine_sim_within_gene": 0.9, - "silhouette_within_gene": 0.85, "mean_map_active": 0.7, - "pct_pos_controls_active": 1.0, "mean_map_pos_controls": 0.9, - "pct_perturbations_distinct": 0.6, "mean_map_distinct": 0.65} - gene_results = {"pct_complexes_significant_manual": 0.9, "mean_cosine_sim_within_complex": 0.8, - "silhouette_within_complex": 0.7, "mean_map_complexes_manual": 0.85, - "pct_complexes_significant_corum": 0.7, "mean_map_complexes_corum": 0.75} - - monkeypatch.setattr( - "sys.argv", - ["run_eval", "--guide_embedding", guide_h5ad, "--gene_embedding", gene_h5ad, "--output", output_csv], - ) - - with patch("ops_model.eval.run_eval.ad.read_h5ad", return_value=MagicMock()): - with patch("ops_model.eval.run_eval.evaluate_guide_level", return_value=(guide_results, MagicMock())): - with patch("ops_model.eval.run_eval.evaluate_gene_level", return_value=gene_results): - main() - - with open(output_csv) as f: - reader = csv.DictReader(f) - rows = list(reader) - - assert len(rows) == 1 - assert "guide_embedding_path" in rows[0] - assert "gene_embedding_path" in rows[0] - assert "pct_perturbations_active" in rows[0] - assert "pct_complexes_significant_manual" in rows[0] - - -def test_default_output_path_used_when_not_specified(tmp_path, monkeypatch): - guide_h5ad = str(tmp_path / "guide.h5ad") - guide_results = {"pct_perturbations_active": 0.8, "mean_cosine_sim_within_gene": 0.9, - "silhouette_within_gene": 0.85, "mean_map_active": 0.7, - "pct_pos_controls_active": 1.0, "mean_map_pos_controls": 0.9, - "pct_perturbations_distinct": 0.6, "mean_map_distinct": 0.65} - - monkeypatch.setattr("sys.argv", ["run_eval", "--guide_embedding", guide_h5ad]) - - with patch("ops_model.eval.run_eval.ad.read_h5ad", return_value=MagicMock()): - with patch("ops_model.eval.run_eval.evaluate_guide_level", return_value=(guide_results, MagicMock())): - main() - - output_files = 
list(tmp_path.glob("*_eval.csv")) - assert len(output_files) == 1 From 94d974b2a31e1c8c4bc9bfcb7ab4f24fd96a2f70 Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Wed, 10 Jun 2026 16:00:24 -0700 Subject: [PATCH 04/11] Unify CSV->AnnData processing into process_features_csv Replace the two near-duplicate pipelines (evaluate_cp.process and evaluate_embeddings.process_embedding_csv) with a single processing_common.process_features_csv that branches on feature type: CellProfiler builds the cell AnnData and splits by reporter; embedding models (dinov3/cell_dino/subcell) build one per-channel AnnData. Shared embedding-config parsing, guide/gene aggregation, and validate_and_save now live in processing_common; cp_features and batch_process_embeddings call the single entry point and the old ones are removed. Also folds in the features/ cleanup: dead functions + unused imports removed, the good_rows_mask NaN-row bug fixed, and the broken test_evaluate_dinov3.py deleted (its module was generalized into evaluate_embeddings). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../features/batch_process_embeddings.py | 8 +- src/ops_model/features/cp_features.py | 4 +- src/ops_model/features/evaluate_cp.py | 488 +---------- src/ops_model/features/evaluate_embeddings.py | 413 +--------- src/ops_model/features/processing_common.py | 298 +++++++ tests/features/test_evaluate_dinov3.py | 764 ------------------ tests/models/test_extractor_e2e.py | 6 +- 7 files changed, 307 insertions(+), 1674 deletions(-) create mode 100644 src/ops_model/features/processing_common.py delete mode 100644 tests/features/test_evaluate_dinov3.py diff --git a/src/ops_model/features/batch_process_embeddings.py b/src/ops_model/features/batch_process_embeddings.py index c957a59..683001d 100644 --- a/src/ops_model/features/batch_process_embeddings.py +++ b/src/ops_model/features/batch_process_embeddings.py @@ -28,8 +28,7 @@ import yaml -from ops_model.features.evaluate_embeddings import process_embedding_csv -from ops_model.features.evaluate_cp import process +from ops_model.features.processing_common import process_features_csv # Base directory for OPS experiments @@ -144,10 +143,7 @@ def process_experiment( # Process print(f"\nProcessing {feature_type} features...") try: - if feature_type == "cellprofiler": - adata = process(str(csv_path), config_path=config_path) - else: - adata = process_embedding_csv(str(csv_path), config_path=config_path) + adata = process_features_csv(str(csv_path), config_path=config_path) print(f"✓ Processing complete") except Exception as e: print(f"✗ Processing failed: {e}") diff --git a/src/ops_model/features/cp_features.py b/src/ops_model/features/cp_features.py index 0322b23..c184570 100644 --- a/src/ops_model/features/cp_features.py +++ b/src/ops_model/features/cp_features.py @@ -279,13 +279,13 @@ def anndata_conversion_worker(csv_path: str, config_path: str = None): ... 
) '/path/to/anndata_objects' """ - from ops_model.features.evaluate_cp import process + from ops_model.features.processing_common import process_features_csv print(f"Starting AnnData conversion for {csv_path}") print(f"Using config: {config_path if config_path else 'default settings'}") # Process the CSV (creates 3 .h5ad files) - process(save_path=csv_path, config_path=config_path) + process_features_csv(save_path=csv_path, config_path=config_path) # Return the output directory output_dir = Path(csv_path).parent / "anndata_objects" diff --git a/src/ops_model/features/evaluate_cp.py b/src/ops_model/features/evaluate_cp.py index bc0d924..13a0460 100644 --- a/src/ops_model/features/evaluate_cp.py +++ b/src/ops_model/features/evaluate_cp.py @@ -1,5 +1,4 @@ # %% -from tqdm import tqdm from pathlib import Path import re import time @@ -7,14 +6,7 @@ import numpy as np import pandas as pd -import scanpy as sc -import scanpy.external as sce import anndata as ad -import matplotlib.pyplot as plt -from sklearn.decomposition import PCA - -from ops_model.data.paths import OpsPaths -from ops_model.features.anndata_utils import create_aggregated_embeddings, pca_embed DEFAULT_GUIDE_COL = "sgRNA" @@ -44,63 +36,6 @@ def timer(name: str): print(f"[TIMING] {name}: {elapsed:.2f} seconds") -def pca_embed( - adata: ad.AnnData, - n_components: int = 128, - variance_plot=False, -) -> ad.AnnData: - - sc.tl.pca(adata, n_comps=n_components) - if variance_plot: - sc.pl.pca_variance_ratio(adata, n_pcs=100, log=False, save=False) - plt.figure() - - return adata - - -def pca_fit_manual(df: pd.DataFrame, n_components: int = 128): - pca = PCA(n_components=n_components) - pca.fit_transform(df) - return pca - - -def cell_size(save_path: str): - features = pd.read_csv(save_path) - - df_sorted = ( - features.groupby("label_str", observed=False) - .mean() - .sort_values(by="cell_mask_sizeshape_Area", ascending=False) - ) - - return df_sorted["cell_mask_sizeshape_Area"] - - -def 
_convert_array_strings_to_float(value): - """ - Convert string representations of arrays to float values. - Examples: '[0.2875]' -> 0.2875, '[1.0, 2.0]' -> 1.5 (mean) - """ - if isinstance(value, str): - # Check if it looks like an array string - if value.startswith("[") and value.endswith("]"): - try: - # Remove brackets and split by comma - inner = value[1:-1].strip() - if inner: - values = [float(x.strip()) for x in inner.split(",")] - # If single value, return it; otherwise raise an error - if len(values) == 1: - return values[0] - else: - raise ValueError(f"Too many values in array string: {value}") - else: - return 0.0 - except (ValueError, AttributeError): - return np.nan - return value - - def create_adata_object( save_path: str, config: dict = None, @@ -224,27 +159,6 @@ def create_adata_object( cols_to_drop = [col for col in nonfeature_cols if col in features.columns] features = features.drop(columns=cols_to_drop) - with timer("Converting array strings to floats"): - pass - # Convert any string representations of arrays to float values - # for col in features.columns: - # if col == "label_str": - # continue - # # Check if column contains string array representations - # if features[col].dtype == "object": - # features[col] = features[col].apply(_convert_array_strings_to_float) - # # Convert to numeric, coercing errors to NaN - # features[col] = pd.to_numeric(features[col], errors="coerce") - - with timer("Converting numeric columns to float32"): - pass - # Convert numeric columns to float32 for memory efficiency - # This can halve memory usage compared to float64 - # numeric_cols = features.select_dtypes(include=[np.number]).columns - # numeric_cols = [col for col in numeric_cols if col != "label_str"] - # features[numeric_cols] = features[numeric_cols].astype("float32") - # print(f"Converted {len(numeric_cols)} numeric columns to float32") - with timer("Dropping constant columns and nans"): if config is not None and 
config["processing"].get("cell-profiler", False): features = features.dropna(subset=["cell_Area"]) @@ -266,7 +180,7 @@ def create_adata_object( # Filter rows with too many NaNs and track which rows are kept num_nan_features_per_row = features.isna().sum(axis=1) - good_rows_mask = num_nan_features_per_row <= 0 if config else 0 + good_rows_mask = num_nan_features_per_row <= 0 features = features[good_rows_mask] # Update metadata arrays to match filtered rows @@ -458,403 +372,3 @@ def split_adata_by_reporter(adata: ad.AnnData, verbose: bool = True) -> dict: ) return reporter_adatas - - -def process(save_path: str, config_path: str = None): - """ - Process CellProfiler features through the full pipeline - - Reporter names are always derived from FeatureMetadata. If multiple channels are present: - 1. Create combined AnnData object - 2. Split by reporter/biological signal - 3. Save separate files for each reporter (keyed by reporter name for feature consistency): - - features_processed_{reporter}.h5ad (cell-level) - - guide_bulked_{reporter}.h5ad - - gene_bulked_{reporter}.h5ad - - Args: - save_path: Path to features CSV - config_path: Path to configuration YAML file (must include 'cell_type' field) - """ - print("\n" + "=" * 60) - print("Starting feature processing pipeline") - print("=" * 60 + "\n") - - total_start = time.time() - - # Load config if provided - config = {} - if config_path is not None: - import yaml - - with open(config_path, "r") as f: - config = yaml.safe_load(f) - print(f"Loaded configuration from {config_path}") - else: - print("No configuration file provided, using default settings.") - - save_path = Path(save_path) - save_dir = save_path.parent / "anndata_objects" - save_dir.mkdir(parents=True, exist_ok=True) - - # Define single checkpoint path - checkpoint_path = save_dir / "features_processed.h5ad" - - # Extract cell_type and embedding_type from config - cell_type = config.get("cell_type", None) if config else None - embedding_type = ( - 
config.get("embedding_type", "cellprofiler") if config else "cellprofiler" - ) - - if not cell_type: - raise ValueError( - "cell_type must be specified in config. " - "Add to config YAML:\n" - " cell_type: 'A549' # or your cell line (e.g., 'HeLa', 'RPE1')" - ) - - # Create anndata object with all features combined - with timer("TOTAL: Create AnnData object"): - features_adata = create_adata_object( - save_path, - config=config, - cell_type=cell_type, - embedding_type=embedding_type, - ) - - # Always split by reporter when channel_mapping is present (even single channel) - has_multiple_channels = ( - "channel_mapping" in features_adata.uns - and len(features_adata.uns["channel_mapping"]) >= 1 - ) - - # Read aggregation configuration (same as DinoV3) - agg_config = config.get("aggregation", {}) - guide_config = agg_config.get("guide_level", {}) - gene_config = agg_config.get("gene_level", {}) - - if has_multiple_channels: - # Split by reporter and save separate files for each reporter - print("\n" + "=" * 60) - print("SPLITTING BY REPORTER (SAVING BY REPORTER NAME)") - print("=" * 60) - - with timer("TOTAL: Split AnnData by reporter"): - reporter_adatas = split_adata_by_reporter(features_adata, verbose=True) - - # Save cell-level, guide-level, and gene-level for each reporter - for reporter, adata_cell in reporter_adatas.items(): - channel_name = adata_cell.uns["channel"] - print( - f"\n--- Processing reporter: {reporter} (channel: {channel_name}) ---" - ) - - # Save cell-level with REPORTER NAME - cell_path = save_dir / f"features_processed_{reporter}.h5ad" - with timer(f"Save cell-level for {reporter}"): - adata_cell.write_h5ad(cell_path) - print(f" Saved: {cell_path}") - - # Guide-level aggregation (configurable) - if guide_config.get( - "enabled", True - ): # Default True for backwards compatibility - with timer(f"Guide-level aggregation for {reporter}"): - # Get embedding settings from config - guide_embeddings = guide_config.get("embeddings", {}) - 
compute_embeddings = guide_config.get("compute_embeddings", True) - - # Get embedding parameters - n_pca = guide_embeddings.get("n_pca_components", 128) - n_neighbors = guide_embeddings.get("n_neighbors", 15) - compute_pca = ( - guide_embeddings.get("pca", True) - if compute_embeddings - else False - ) - compute_umap = ( - guide_embeddings.get("umap", True) - if compute_embeddings - else False - ) - compute_phate = ( - guide_embeddings.get("phate", True) - if compute_embeddings - else False - ) - - # Validate: UMAP requires PCA - if compute_umap and not compute_pca: - print( - f" WARNING: UMAP requires PCA. Enabling PCA for {reporter}." - ) - compute_pca = True - - adata_guide = create_aggregated_embeddings( - adata_cell, - level="guide", - n_pca_components=n_pca, - n_neighbors=n_neighbors, - compute_pca=compute_pca, - compute_umap=compute_umap, - compute_phate=compute_phate, - ) - - # Save output if enabled - if guide_config.get("save_output", True): - guide_path = save_dir / f"guide_bulked_{reporter}.h5ad" - adata_guide.write_h5ad(guide_path) - print(f" Saved: {guide_path}") - else: - print(f" Skipped saving guide-level (save_output=False)") - else: - print(f" Skipped guide-level aggregation (enabled=False)") - - # Gene-level aggregation (configurable) - if gene_config.get( - "enabled", True - ): # Default True for backwards compatibility - with timer(f"Gene-level aggregation for {reporter}"): - # Get embedding settings from config - gene_embeddings = gene_config.get("embeddings", {}) - compute_embeddings = gene_config.get("compute_embeddings", True) - - # Get embedding parameters - n_pca = gene_embeddings.get("n_pca_components", 128) - n_neighbors = gene_embeddings.get("n_neighbors", 15) - compute_pca = ( - gene_embeddings.get("pca", True) - if compute_embeddings - else False - ) - compute_umap = ( - gene_embeddings.get("umap", True) - if compute_embeddings - else False - ) - compute_phate = ( - gene_embeddings.get("phate", True) - if compute_embeddings - else 
False - ) - - # Validate: UMAP requires PCA - if compute_umap and not compute_pca: - print( - f" WARNING: UMAP requires PCA. Enabling PCA for {reporter}." - ) - compute_pca = True - - adata_gene = create_aggregated_embeddings( - adata_cell, - level="gene", - n_pca_components=n_pca, - n_neighbors=n_neighbors, - compute_pca=compute_pca, - compute_umap=compute_umap, - compute_phate=compute_phate, - ) - - # Save output if enabled - if gene_config.get("save_output", True): - gene_path = save_dir / f"gene_bulked_{reporter}.h5ad" - adata_gene.write_h5ad(gene_path) - print(f" Saved: {gene_path}") - else: - print(f" Skipped saving gene-level (save_output=False)") - else: - print(f" Skipped gene-level aggregation (enabled=False)") - - total_time = time.time() - total_start - print("\n" + "=" * 60) - print("PIPELINE COMPLETE (SPLIT BY REPORTER, SAVED BY REPORTER)") - print("=" * 60) - print(f"Total time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)") - print(f"Reporters processed: {len(reporter_adatas)}") - for reporter, adata in reporter_adatas.items(): - channel_name = adata.uns["channel"] - print(f" - {reporter} (channel: {channel_name})") - print(f" Cell: {save_dir}/features_processed_{reporter}.h5ad") - print(f" Guide: {save_dir}/guide_bulked_{reporter}.h5ad") - print(f" Gene: {save_dir}/gene_bulked_{reporter}.h5ad") - print("=" * 60 + "\n") - - else: - # Original behavior: save combined file - print("\n(No channel_mapping found - saving combined file)") - - # Save cell-level - features_adata.write_h5ad(checkpoint_path) - print(f"Saved initial AnnData object to {checkpoint_path}") - - # Guide-level averaged analysis (configurable) - if guide_config.get( - "enabled", True - ): # Default True for backwards compatibility - with timer("TOTAL: Guide-level processing"): - # Get embedding settings from config - guide_embeddings = guide_config.get("embeddings", {}) - compute_embeddings = guide_config.get("compute_embeddings", True) - - # Get embedding parameters - 
n_pca = guide_embeddings.get("n_pca_components", 128) - n_neighbors = guide_embeddings.get("n_neighbors", 15) - compute_pca = ( - guide_embeddings.get("pca", True) if compute_embeddings else False - ) - compute_umap = ( - guide_embeddings.get("umap", True) if compute_embeddings else False - ) - compute_phate = ( - guide_embeddings.get("phate", True) if compute_embeddings else False - ) - - # Validate: UMAP requires PCA - if compute_umap and not compute_pca: - print("WARNING: UMAP requires PCA. Enabling PCA.") - compute_pca = True - - embeddings_guide_bulk_ad = create_aggregated_embeddings( - features_adata, - level="guide", - n_pca_components=n_pca, - n_neighbors=n_neighbors, - compute_pca=compute_pca, - compute_umap=compute_umap, - compute_phate=compute_phate, - ) - - # Save output if enabled - if guide_config.get("save_output", True): - guide_avg_path = save_dir / "guide_bulked.h5ad" - embeddings_guide_bulk_ad.write_h5ad(guide_avg_path) - - # Build embedding info for display - embeddings_computed = [] - if compute_pca: - embeddings_computed.append("PCA") - if compute_umap: - embeddings_computed.append("UMAP") - if compute_phate: - embeddings_computed.append("PHATE") - embeddings_str = ( - "+".join(embeddings_computed) if embeddings_computed else "none" - ) - - print( - f"Saved guide-bulked analysis to {guide_avg_path} (embeddings: {embeddings_str})" - ) - else: - print("Skipped saving guide-level (save_output=False)") - else: - print("Skipped guide-level aggregation (enabled=False)") - - # Gene-level averaged analysis (configurable) - if gene_config.get("enabled", True): # Default True for backwards compatibility - with timer("TOTAL: Gene-level processing"): - # Get embedding settings from config - gene_embeddings = gene_config.get("embeddings", {}) - compute_embeddings = gene_config.get("compute_embeddings", True) - - # Get embedding parameters - n_pca = gene_embeddings.get("n_pca_components", 128) - n_neighbors = gene_embeddings.get("n_neighbors", 15) - 
compute_pca = ( - gene_embeddings.get("pca", True) if compute_embeddings else False - ) - compute_umap = ( - gene_embeddings.get("umap", True) if compute_embeddings else False - ) - compute_phate = ( - gene_embeddings.get("phate", True) if compute_embeddings else False - ) - - # Validate: UMAP requires PCA - if compute_umap and not compute_pca: - print("WARNING: UMAP requires PCA. Enabling PCA.") - compute_pca = True - - embeddings_gene_avg_ad = create_aggregated_embeddings( - features_adata, - level="gene", - n_pca_components=n_pca, - n_neighbors=n_neighbors, - compute_pca=compute_pca, - compute_umap=compute_umap, - compute_phate=compute_phate, - ) - - # Save output if enabled - if gene_config.get("save_output", True): - gene_avg_path = save_dir / "gene_bulked.h5ad" - embeddings_gene_avg_ad.write_h5ad(gene_avg_path) - - # Build embedding info for display - embeddings_computed = [] - if compute_pca: - embeddings_computed.append("PCA") - if compute_umap: - embeddings_computed.append("UMAP") - if compute_phate: - embeddings_computed.append("PHATE") - embeddings_str = ( - "+".join(embeddings_computed) if embeddings_computed else "none" - ) - - print( - f"Saved gene-bulked analysis to {gene_avg_path} (embeddings: {embeddings_str})" - ) - else: - print("Skipped saving gene-level (save_output=False)") - else: - print("Skipped gene-level aggregation (enabled=False)") - - total_time = time.time() - total_start - print("\n" + "=" * 60) - print( - f"Pipeline completed in {total_time:.2f} seconds ({total_time/60:.2f} minutes)" - ) - print(f"Cell-level output: {checkpoint_path} (contains raw features)") - - # Show guide/gene outputs only if they were created - if guide_config.get("enabled", True) and guide_config.get("save_output", True): - print(f"Guide-bulked output: {save_dir / 'guide_bulked.h5ad'}") - if gene_config.get("enabled", True) and gene_config.get("save_output", True): - print(f"Gene-bulked output: {save_dir / 'gene_bulked.h5ad'}") - - print("=" * 60 + "\n") - - 
return features_adata - - -def _build_arg_parser(): - import argparse - - parser = argparse.ArgumentParser(description="Process features.") - parser.add_argument( - "--save_path", - type=str, - required=True, - help="Path to the CSV file containing CellProfiler features.", - ) - parser.add_argument( - "--config_path", - type=str, - default=None, - help="Path to configuration YAML file.", - ) - return parser - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - process( - args.save_path, - config_path=args.config_path, - ) - - -if __name__ == "__main__": - main() diff --git a/src/ops_model/features/evaluate_embeddings.py b/src/ops_model/features/evaluate_embeddings.py index 7b13189..4ccc84e 100644 --- a/src/ops_model/features/evaluate_embeddings.py +++ b/src/ops_model/features/evaluate_embeddings.py @@ -15,96 +15,13 @@ python evaluate_embeddings.py --save_path /path/to/features_Phase2D.csv """ -from tqdm import tqdm from pathlib import Path -import time -from contextlib import contextmanager import numpy as np import pandas as pd -import scanpy as sc -import scanpy.external as sce import anndata as ad -# Import shared utilities from CellProfiler evaluation -from ops_model.features.evaluate_cp import ( - # center_scale_fast, - timer, - NONFEATURE_COLUMNS, - DEFAULT_GUIDE_COL, -) -from ops_model.features.anndata_utils import create_aggregated_embeddings, pca_embed -from ops_model.post_process.anndata_processing.anndata_validator import ( - AnndataValidator, - IssueLevel, -) - - -def validate_and_save( - adata: ad.AnnData, - path: Path, - level: str, -) -> None: - """ - Validate AnnData object and save to h5ad file. - - Validation is always enforced with hard constraints - errors will raise exceptions. 
- - Args: - adata: AnnData object to validate and save - path: Path to save h5ad file - level: Schema level ("cell", "guide", "gene") - - Raises: - ValueError: If validation fails - """ - print(f"\nValidating {level}-level AnnData before saving...") - - # Initialize validator - validator = AnndataValidator() - - # Run validation (returns ValidationReport object) - report = validator.validate(adata, level=level) - - # Access errors and warnings from report - errors = report.errors - warnings = report.warnings - - # Report results - if report.is_valid: - print(f"✓ Validation passed: {level}-level AnnData is compliant") - else: - print(f"Validation found {len(errors)} errors, {len(warnings)} warnings") - - # Show errors - if errors: - print("\nERROR-level issues:") - for issue in errors[:10]: # Show first 10 - print(f" - {issue.field}: {issue.message}") - if len(errors) > 10: - print(f" ... and {len(errors) - 10} more errors") - - # Show warnings - if warnings: - print("\nWARNING-level issues:") - for issue in warnings[:5]: # Show first 5 - print(f" - {issue.field}: {issue.message}") - if len(warnings) > 5: - print(f" ... and {len(warnings) - 5} more warnings") - - # Fail fast if errors found (always enforced) - if errors: - error_summary = ( - f"{len(errors)} validation error(s) found in {level}-level AnnData" - ) - print(f"\n✗ Validation FAILED: {error_summary}") - print(f"Fix these issues before saving. 
First error: {errors[0].message}") - raise ValueError(error_summary) - - # Save file - print(f"Saving {level}-level AnnData to {path}") - adata.write_h5ad(path) - print(f"✓ Saved successfully: {path}") +from ops_model.features.evaluate_cp import timer, DEFAULT_GUIDE_COL def create_adata_object_embedding( @@ -236,16 +153,6 @@ def create_adata_object_embedding( features = features.drop(columns=constant_cols) print(f"Dropped constant columns, remaining features: {features.shape[1]}") - # with timer("Center-scaling features"): - # # Normalize embeddings (mean=0, std=1) - # features['label_str'] = gene_strs - # features_norm = center_scale_fast( - # features, - # on_controls=False, # Normalize on all data by default - # control_column="label_str", - # control_gene="NTC" - # ) - with timer("Creating AnnData object"): # Use features directly (no normalization applied) features_norm = features.copy() @@ -320,321 +227,3 @@ def create_adata_object_embedding( # Note: reporter is in .obs, no need to duplicate in .uns return adata - - -def process_embedding_csv( - save_path: str, - config_path: str = None, -): - """ - Process neural-network embedding features through the full pipeline. - - Unlike CellProfiler pipeline, this does NOT compute cell-level PCA/UMAP. - PCA and UMAP are only computed at guide/gene aggregation level. - - If reporter names are enabled (use_reporter_names=True in config), files will be - saved with reporter suffixes (e.g., features_processed_EEA1.h5ad) instead of - channel names (e.g., features_processed_GFP.h5ad). - - Args: - save_path: Path to embeddings CSV (e.g. 
dinov3_features_Phase2D.csv or - cell_dino_features_Phase2D.csv) - config_path: Path to configuration YAML file - - Returns: - Cell-level AnnData object (without PCA/UMAP) - """ - - config = {} - if config_path is not None: - import yaml - - with open(config_path, "r") as f: - config = yaml.safe_load(f) - print(f"Loaded configuration from {config_path}") - - # Extract required config parameters for validator compliance - cell_type = config.get("cell_type", None) - embedding_type = config.get("embedding_type", "dinov3") - - if not cell_type: - raise ValueError( - "cell_type must be specified in config for validator compliance.\n" - "Add to your config file:\n" - " cell_type: 'A549' # or your cell type" - ) - - print("\n" + "=" * 60) - print(f"Starting {embedding_type} feature processing pipeline") - print("=" * 60 + "\n") - - total_start = time.time() - - save_path = Path(save_path) - save_dir = save_path.parent / "anndata_objects" - save_dir.mkdir(parents=True, exist_ok=True) - - # Extract channel from CSV filename. - # Convention: {model_type}_features_{channel}.csv - # e.g. 
dinov3_features_Phase2D.csv, cell_dino_features_Phase2D.csv, subcell_features_GFP.csv - csv_stem = save_path.stem - marker = "_features_" - if marker in csv_stem: - channel = csv_stem[csv_stem.index(marker) + len(marker) :] - else: - channel = "unknown" - - print(f"Detected channel: {channel}") - - # Read experiment from CSV to enable reporter name lookup - with timer("Reading experiment from CSV"): - df_sample = pd.read_csv(save_path, nrows=1) - if "experiment" in df_sample.columns: - experiment = df_sample["experiment"].iloc[0] - print(f"Detected experiment: {experiment}") - else: - experiment = None - print("Warning: experiment column not found in CSV") - - # Use reporter name as filename suffix - from ops_utils.data.feature_metadata import FeatureMetadata - - meta = FeatureMetadata() - filename_suffix = meta.get_biological_signal(experiment, channel) - print(f"Using reporter name for files: {filename_suffix}") - - # Define checkpoint paths with appropriate suffix - checkpoint_path = save_dir / f"features_processed_{filename_suffix}.h5ad" - guide_avg_path = save_dir / f"guide_bulked_{filename_suffix}.h5ad" - gene_avg_path = save_dir / f"gene_bulked_{filename_suffix}.h5ad" - - # Create AnnData object from embeddings - with timer("TOTAL: Create AnnData object"): - features_adata = create_adata_object_embedding( - str(save_path), - config=config, - channel=channel, - experiment=experiment, - cell_type=cell_type, - embedding_type=embedding_type, - ) - - print( - f"Created AnnData with {features_adata.shape[0]} cells and {features_adata.shape[1]} features" - ) - - # Validate and save cell-level data - validate_and_save( - features_adata, - checkpoint_path, - level="cell", - ) - - # Read aggregation configuration - agg_config = config.get("aggregation", {}) - guide_config = agg_config.get("guide_level", {}) - gene_config = agg_config.get("gene_level", {}) - - # Guide-level aggregation and analysis (configurable) - if guide_config.get("enabled", True): # Default True 
for backwards compatibility - print("\n" + "=" * 60) - print("Guide-level aggregation") - print("=" * 60) - - # Get embedding settings - guide_embeddings = guide_config.get("embeddings", {}) - compute_embeddings = guide_config.get("compute_embeddings", True) - - # Get embedding parameters - n_pca = guide_embeddings.get("n_pca_components", 128) - n_neighbors = guide_embeddings.get("n_neighbors", 15) - compute_pca = guide_embeddings.get("pca", True) if compute_embeddings else False - compute_umap = ( - guide_embeddings.get("umap", True) if compute_embeddings else False - ) - compute_phate = ( - guide_embeddings.get("phate", True) if compute_embeddings else False - ) - - # Validate: UMAP requires PCA - if compute_umap and not compute_pca: - print("WARNING: UMAP requires PCA. Enabling PCA.") - compute_pca = True - - with timer("TOTAL: Guide-level processing"): - embeddings_guide_bulk_ad = create_aggregated_embeddings( - features_adata, - level="guide", - n_pca_components=n_pca, - n_neighbors=n_neighbors, - compute_pca=compute_pca, - compute_umap=compute_umap, - compute_phate=compute_phate, - ) - - # Validate and save guide-level output - if guide_config.get("save_output", True): - validate_and_save( - embeddings_guide_bulk_ad, - guide_avg_path, - level="guide", - ) - - # Optional plotting - if ( - guide_config.get("plot_umap", False) - and "X_umap" in embeddings_guide_bulk_ad.obsm.keys() - ): - plot_path = save_dir / "guide_umap.png" - guide_col_for_plot = embeddings_guide_bulk_ad.uns.get( - "guide_col", DEFAULT_GUIDE_COL - ) - sc.pl.umap( - embeddings_guide_bulk_ad, - color=guide_col_for_plot, - save=str(plot_path), - ) - print(f"Saved guide UMAP plot to {plot_path}") - else: - print("\nSkipping guide-level aggregation (disabled in config)") - embeddings_guide_bulk_ad = None - - # Gene-level aggregation and analysis (configurable) - if gene_config.get("enabled", True): # Default True for backwards compatibility - print("\n" + "=" * 60) - print("Gene-level 
aggregation") - print("=" * 60) - - # Get embedding settings - gene_embeddings = gene_config.get("embeddings", {}) - compute_embeddings = gene_config.get("compute_embeddings", True) - - # Get embedding parameters - n_pca = gene_embeddings.get("n_pca_components", 128) - n_neighbors = gene_embeddings.get("n_neighbors", 15) - compute_pca = gene_embeddings.get("pca", True) if compute_embeddings else False - compute_umap = ( - gene_embeddings.get("umap", True) if compute_embeddings else False - ) - compute_phate = ( - gene_embeddings.get("phate", True) if compute_embeddings else False - ) - - # Validate: UMAP requires PCA - if compute_umap and not compute_pca: - print("WARNING: UMAP requires PCA. Enabling PCA.") - compute_pca = True - - with timer("TOTAL: Gene-level processing"): - embeddings_gene_avg_ad = create_aggregated_embeddings( - features_adata, - level="gene", - n_pca_components=n_pca, - n_neighbors=n_neighbors, - compute_pca=compute_pca, - compute_umap=compute_umap, - compute_phate=compute_phate, - ) - - # Validate and save gene-level output - if gene_config.get("save_output", True): - validate_and_save( - embeddings_gene_avg_ad, - gene_avg_path, - level="gene", - ) - - # Optional plotting - if ( - gene_config.get("plot_umap", False) - and "X_umap" in embeddings_gene_avg_ad.obsm.keys() - ): - plot_path = save_dir / "gene_umap.png" - sc.pl.umap( - embeddings_gene_avg_ad, color="perturbation", save=str(plot_path) - ) - print(f"Saved gene UMAP plot to {plot_path}") - else: - print("\nSkipping gene-level aggregation (disabled in config)") - embeddings_gene_avg_ad = None - - total_time = time.time() - total_start - print("\n" + "=" * 60) - print( - f"Pipeline completed in {total_time:.2f} seconds ({total_time/60:.2f} minutes)" - ) - print(f"Cell-level output: {checkpoint_path} (raw embeddings, no PCA/UMAP)") - - # Report guide-level output - if guide_config.get("enabled", True) and guide_config.get("save_output", True): - guide_embeddings = 
guide_config.get("embeddings", {}) - compute_guide_embeddings = guide_config.get("compute_embeddings", True) - if compute_guide_embeddings: - embeddings_list = [] - if guide_embeddings.get("pca", True): - embeddings_list.append("PCA") - if guide_embeddings.get("umap", True): - embeddings_list.append("UMAP") - if guide_embeddings.get("phate", True): - embeddings_list.append("PHATE") - embeddings_str = ", ".join(embeddings_list) if embeddings_list else "none" - print(f"Guide-bulked output: {guide_avg_path} (with {embeddings_str})") - else: - print(f"Guide-bulked output: {guide_avg_path} (aggregated, no embeddings)") - - # Report gene-level output - if gene_config.get("enabled", True) and gene_config.get("save_output", True): - gene_embeddings = gene_config.get("embeddings", {}) - compute_gene_embeddings = gene_config.get("compute_embeddings", True) - if compute_gene_embeddings: - embeddings_list = [] - if gene_embeddings.get("pca", True): - embeddings_list.append("PCA") - if gene_embeddings.get("umap", True): - embeddings_list.append("UMAP") - if gene_embeddings.get("phate", True): - embeddings_list.append("PHATE") - embeddings_str = ", ".join(embeddings_list) if embeddings_list else "none" - print(f"Gene-bulked output: {gene_avg_path} (with {embeddings_str})") - else: - print(f"Gene-bulked output: {gene_avg_path} (aggregated, no embeddings)") - - print(f"Files saved with reporter suffix: {filename_suffix}") - print("=" * 60 + "\n") - - return features_adata - - -def _build_arg_parser(): - import argparse - - parser = argparse.ArgumentParser( - description="Process neural-network embedding features into AnnData objects." 
- ) - parser.add_argument( - "--save_path", - type=str, - required=True, - help="Path to the CSV file containing embedding features.", - ) - parser.add_argument( - "--config_path", - type=str, - default=None, - help="Path to configuration YAML file.", - ) - return parser - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - process_embedding_csv( - args.save_path, - config_path=args.config_path, - ) - - -if __name__ == "__main__": - main() diff --git a/src/ops_model/features/processing_common.py b/src/ops_model/features/processing_common.py new file mode 100644 index 0000000..7541962 --- /dev/null +++ b/src/ops_model/features/processing_common.py @@ -0,0 +1,298 @@ +"""Unified CSV → AnnData processing pipeline for CellProfiler and embedding features. + +A single entry point, ``process_features_csv``, branches once on the config's +feature type: + +- ``cellprofiler``: build the cell-level AnnData with ``create_adata_object`` and + split it into per-reporter subsets (``split_adata_by_reporter``), saving one set + of outputs per reporter. +- any embedding type (``dinov3``/``cell_dino``/``subcell``): build with + ``create_adata_object_embedding`` (one channel per CSV) and save a single set of + outputs, the reporter name coming from ``FeatureMetadata``. + +Both paths share validation, guide/gene aggregation, and saving. 
+""" + +import time +from dataclasses import dataclass +from pathlib import Path + +import pandas as pd + +from ops_model.features.anndata_utils import create_aggregated_embeddings +from ops_model.features.evaluate_cp import ( + create_adata_object, + split_adata_by_reporter, +) +from ops_model.features.evaluate_embeddings import create_adata_object_embedding +from ops_model.post_process.anndata_processing.anndata_validator import AnndataValidator + + +# --------------------------------------------------------------------------- +# Embedding-config helpers (shared by both feature types) +# --------------------------------------------------------------------------- + + +@dataclass +class EmbeddingConfig: + """Resolved embedding settings for one aggregation level.""" + + compute_embeddings: bool + n_pca_components: int + n_neighbors: int + compute_pca: bool + compute_umap: bool + compute_phate: bool + + +def extract_embedding_config(level_cfg: dict) -> EmbeddingConfig: + """Resolve a ``guide_level``/``gene_level`` config block into embedding settings. + + Precedence: + - ``compute_embeddings`` (default True) is the master switch; when False, PCA, + UMAP and PHATE are all forced off regardless of their individual flags. + - Otherwise each of ``pca``/``umap``/``phate`` defaults to True. + - UMAP requires PCA: if UMAP is on but PCA is off, PCA is enabled. 
+ """ + emb = level_cfg.get("embeddings", {}) + compute_embeddings = level_cfg.get("compute_embeddings", True) + + compute_pca = emb.get("pca", True) if compute_embeddings else False + compute_umap = emb.get("umap", True) if compute_embeddings else False + compute_phate = emb.get("phate", True) if compute_embeddings else False + if compute_umap and not compute_pca: + compute_pca = True + + return EmbeddingConfig( + compute_embeddings=compute_embeddings, + n_pca_components=emb.get("n_pca_components", 128), + n_neighbors=emb.get("n_neighbors", 15), + compute_pca=compute_pca, + compute_umap=compute_umap, + compute_phate=compute_phate, + ) + + +def aggregate_level(cell_adata, level: str, level_cfg: dict): + """Aggregate a cell-level AnnData to ``level`` ("guide"/"gene") with embeddings. + + Returns the aggregated AnnData, or ``None`` if the level is disabled + (``enabled: false``). + """ + if not level_cfg.get("enabled", True): + return None + cfg = extract_embedding_config(level_cfg) + return create_aggregated_embeddings( + cell_adata, + level=level, + n_pca_components=cfg.n_pca_components, + n_neighbors=cfg.n_neighbors, + compute_pca=cfg.compute_pca, + compute_umap=cfg.compute_umap, + compute_phate=cfg.compute_phate, + ) + + +def format_embeddings_list(level_cfg: dict, sep: str = ", ") -> str: + """Human-readable list of which embeddings will run for a level (e.g. 
"PCA, UMAP").""" + cfg = extract_embedding_config(level_cfg) + if not cfg.compute_embeddings: + return "none" + names = [] + if cfg.compute_pca: + names.append("PCA") + if cfg.compute_umap: + names.append("UMAP") + if cfg.compute_phate: + names.append("PHATE") + return sep.join(names) if names else "none" + + +# --------------------------------------------------------------------------- +# Validation + saving +# --------------------------------------------------------------------------- + + +def validate_and_save(adata, path: Path, level: str) -> None: + """Validate an AnnData object against the schema and save it to ``path``. + + Validation is enforced with hard constraints — any errors raise. + + Args: + adata: AnnData object to validate and save. + path: Destination .h5ad path. + level: Schema level ("cell", "guide", "gene"). + """ + print(f"\nValidating {level}-level AnnData before saving...") + report = AnndataValidator().validate(adata, level=level) + errors = report.errors + warnings = report.warnings + + if report.is_valid: + print(f"✓ Validation passed: {level}-level AnnData is compliant") + else: + print(f"Validation found {len(errors)} errors, {len(warnings)} warnings") + if errors: + print("\nERROR-level issues:") + for issue in errors[:10]: + print(f" - {issue.field}: {issue.message}") + if len(errors) > 10: + print(f" ... and {len(errors) - 10} more errors") + if warnings: + print("\nWARNING-level issues:") + for issue in warnings[:5]: + print(f" - {issue.field}: {issue.message}") + if len(warnings) > 5: + print(f" ... and {len(warnings) - 5} more warnings") + + if errors: + error_summary = f"{len(errors)} validation error(s) found in {level}-level AnnData" + print(f"\n✗ Validation FAILED: {error_summary}") + print(f"Fix these issues before saving. 
First error: {errors[0].message}") + raise ValueError(error_summary) + + print(f"Saving {level}-level AnnData to {path}") + adata.write_h5ad(path) + print(f"✓ Saved successfully: {path}") + + +def _save_level_outputs(cell_adata, suffix, save_dir: Path, guide_cfg: dict, gene_cfg: dict): + """Save cell-level, then aggregate to guide/gene level and save each. + + ``suffix`` is the reporter name (appended as ``_{suffix}``); pass ``None`` for + unsuffixed combined outputs. Guide/gene levels are skipped when disabled + (``aggregate_level`` returns None) or when ``save_output`` is False. + """ + tag = f"_{suffix}" if suffix else "" + + validate_and_save(cell_adata, save_dir / f"features_processed{tag}.h5ad", "cell") + + adata_guide = aggregate_level(cell_adata, "guide", guide_cfg) + if adata_guide is not None and guide_cfg.get("save_output", True): + validate_and_save(adata_guide, save_dir / f"guide_bulked{tag}.h5ad", "guide") + print(f" guide-bulked embeddings: {format_embeddings_list(guide_cfg)}") + + adata_gene = aggregate_level(cell_adata, "gene", gene_cfg) + if adata_gene is not None and gene_cfg.get("save_output", True): + validate_and_save(adata_gene, save_dir / f"gene_bulked{tag}.h5ad", "gene") + print(f" gene-bulked embeddings: {format_embeddings_list(gene_cfg)}") + + +# --------------------------------------------------------------------------- +# Unified entry point +# --------------------------------------------------------------------------- + + +def _detect_channel_and_experiment(save_path: Path): + """For embedding CSVs: channel from the filename ``{model}_features_{channel}.csv`` + and experiment from the CSV's first row.""" + stem = save_path.stem + marker = "_features_" + channel = stem[stem.index(marker) + len(marker) :] if marker in stem else "unknown" + + df_sample = pd.read_csv(save_path, nrows=1) + experiment = ( + df_sample["experiment"].iloc[0] if "experiment" in df_sample.columns else None + ) + return channel, experiment + + +def 
process_features_csv(save_path: str, config_path: str = None): + """Process a feature/embedding CSV into cell/guide/gene AnnData objects. + + Branches on ``config['model_type']`` (falling back to ``embedding_type``, then + ``cellprofiler``): the ``cellprofiler`` path splits by reporter; embedding paths + save a single reporter-named set of outputs. Returns the cell-level AnnData. + + Args: + save_path: Path to the features/embeddings CSV. + config_path: Path to the YAML config (must set ``cell_type``). + """ + config = {} + if config_path is not None: + import yaml + + with open(config_path, "r") as f: + config = yaml.safe_load(f) + print(f"Loaded configuration from {config_path}") + + cell_type = config.get("cell_type") + if not cell_type: + raise ValueError( + "cell_type must be specified in config " + "(e.g. cell_type: 'A549') for validator compliance." + ) + + feature_type = ( + config.get("model_type") or config.get("embedding_type") or "cellprofiler" + ) + + save_path = Path(save_path) + save_dir = save_path.parent / "anndata_objects" + save_dir.mkdir(parents=True, exist_ok=True) + + agg_config = config.get("aggregation", {}) + guide_cfg = agg_config.get("guide_level", {}) + gene_cfg = agg_config.get("gene_level", {}) + + print("\n" + "=" * 60) + print(f"Processing {feature_type} features: {save_path.name}") + print("=" * 60) + t_start = time.time() + + if feature_type == "cellprofiler": + cell_adata = create_adata_object( + str(save_path), + config=config, + cell_type=cell_type, + embedding_type=feature_type, + ) + # CellProfiler CSVs interleave multiple channels' features; split per reporter. 
+ if ( + "channel_mapping" in cell_adata.uns + and len(cell_adata.uns["channel_mapping"]) >= 1 + ): + reporter_adatas = split_adata_by_reporter(cell_adata, verbose=True) + for reporter, adata_cell in reporter_adatas.items(): + print(f"\n--- reporter: {reporter} (channel: {adata_cell.uns['channel']}) ---") + _save_level_outputs(adata_cell, reporter, save_dir, guide_cfg, gene_cfg) + else: + print("\n(No channel_mapping found - saving combined file)") + _save_level_outputs(cell_adata, None, save_dir, guide_cfg, gene_cfg) + else: + channel, experiment = _detect_channel_and_experiment(save_path) + from ops_utils.data.feature_metadata import FeatureMetadata + + reporter = FeatureMetadata().get_biological_signal(experiment, channel) + print(f"channel={channel} experiment={experiment} reporter={reporter}") + cell_adata = create_adata_object_embedding( + str(save_path), + config=config, + channel=channel, + experiment=experiment, + cell_type=cell_type, + embedding_type=feature_type, + ) + _save_level_outputs(cell_adata, reporter, save_dir, guide_cfg, gene_cfg) + + elapsed = time.time() - t_start + print("\n" + "=" * 60) + print(f"Pipeline completed in {elapsed:.2f}s ({elapsed/60:.2f} min)") + print("=" * 60 + "\n") + return cell_adata + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Process a feature/embedding CSV into AnnData objects." + ) + parser.add_argument("--save_path", type=str, required=True) + parser.add_argument("--config_path", type=str, default=None) + args = parser.parse_args() + process_features_csv(args.save_path, config_path=args.config_path) + + +if __name__ == "__main__": + main() diff --git a/tests/features/test_evaluate_dinov3.py b/tests/features/test_evaluate_dinov3.py deleted file mode 100644 index f9d8628..0000000 --- a/tests/features/test_evaluate_dinov3.py +++ /dev/null @@ -1,764 +0,0 @@ -""" -Tests for DinoV3 feature evaluation pipeline. 
- -This test suite validates the processing of DinoV3 embeddings through -the AnnData pipeline, including normalization, aggregation, and output generation. -""" - -import warnings - -# Filter anndata and zarr warnings -warnings.filterwarnings("ignore", message=".*zarr v2.*", category=DeprecationWarning) -warnings.filterwarnings("ignore", category=DeprecationWarning) - -import pytest -import numpy as np -import pandas as pd -import anndata as ad -from pathlib import Path -import tempfile -import shutil - -from ops_model.features.evaluate_cp import ( - center_scale_fast, - pca_embed, -) -from ops_model.features.evaluate_dinov3 import ( - create_adata_object_dinov3, - process_dinov3, -) - - -# ============================================================================ -# Fixtures -# ============================================================================ - - -@pytest.fixture(scope="module") -def mock_dinov3_csv(): - """ - Create a synthetic DinoV3 features DataFrame for testing. - - Returns: - pd.DataFrame with 1024 feature columns + metadata - """ - np.random.seed(42) - - # Use 150 cells to allow PCA with 128 components (need > 128 samples) - n_cells = 150 - n_features = 1024 - - # Create feature matrix with realistic embedding values - # DinoV3 embeddings typically have values roughly in [-2, 2] range - features = np.random.randn(n_cells, n_features).astype(np.float32) - - # Create DataFrame with numbered columns (0-1023) - feature_df = pd.DataFrame(features, columns=[str(i) for i in range(n_features)]) - - # Create metadata - genes = ["NTC", "GENE_A", "GENE_B", "GENE_C"] - gene_labels = np.random.choice(genes, size=n_cells) - - # Create multiple guides per gene - guides = [] - for gene in gene_labels: - guide_idx = np.random.randint(1, 4) # 3 guides per gene - guides.append(f"{gene}_sg{guide_idx}") - - # Create label_int mapping - gene_to_int = {gene: idx for idx, gene in enumerate(genes)} - label_ints = [gene_to_int[gene] for gene in gene_labels] - - # Add 
metadata columns - feature_df["label_int"] = label_ints - feature_df["label_str"] = gene_labels - feature_df["sgRNA"] = guides - feature_df["experiment"] = "ops0089_20251119" - feature_df["x_position"] = np.random.uniform(0, 1000, n_cells) - feature_df["y_position"] = np.random.uniform(0, 1000, n_cells) - feature_df["well"] = np.random.choice( - ["A1_ops0089_20251119", "A2_ops0089_20251119"], n_cells - ) - - return feature_df - - -@pytest.fixture(scope="module") -def mock_dinov3_csv_path(mock_dinov3_csv): - """ - Write mock DinoV3 CSV to temporary file. - - Returns: - Path to temporary CSV file - """ - # Create temporary directory - temp_dir = tempfile.mkdtemp() - csv_path = Path(temp_dir) / "dinov3_features_Phase2D.csv" - - # Write CSV - mock_dinov3_csv.to_csv(csv_path, index=False) - - yield csv_path - - # Cleanup - shutil.rmtree(temp_dir) - - -@pytest.fixture(scope="module") -def processed_adata(mock_dinov3_csv_path): - """ - Run full processing pipeline on mock data. - - Returns: - Processed AnnData object (cell-level) - """ - config = { - "normalize_PCA_embeddings": False, - } - - # Note: This will create guide/gene aggregations but we're returning cell-level - adata = create_adata_object_dinov3(str(mock_dinov3_csv_path), config=config) - - return adata - - -# ============================================================================ -# Unit Tests -# ============================================================================ - - -def test_load_dinov3_csv(mock_dinov3_csv): - """Test 1: Validate CSV loading and basic structure.""" - assert isinstance(mock_dinov3_csv, pd.DataFrame) - assert len(mock_dinov3_csv) == 150, "Should have 150 rows" - - # 1024 features + 7 metadata columns - assert ( - len(mock_dinov3_csv.columns) == 1031 - ), "Should have 1024 features + 7 metadata" - - # Check that feature columns are numeric - feature_cols = [str(i) for i in range(1024)] - for col in feature_cols: - assert col in mock_dinov3_csv.columns - assert 
pd.api.types.is_numeric_dtype(mock_dinov3_csv[col]) - - -def test_metadata_columns_present(mock_dinov3_csv): - """Test 2: Validate required metadata columns.""" - required_columns = [ - "label_int", - "label_str", - "sgRNA", - "experiment", - "x_position", - "y_position", - "well", - ] - - for col in required_columns: - assert col in mock_dinov3_csv.columns, f"Missing required column: {col}" - - # Check no missing values in critical columns - assert ( - not mock_dinov3_csv["label_str"].isna().any() - ), "label_str should not have NaN" - assert not mock_dinov3_csv["sgRNA"].isna().any(), "sgRNA should not have NaN" - - # Check data types - assert pd.api.types.is_integer_dtype(mock_dinov3_csv["label_int"]) - assert pd.api.types.is_numeric_dtype(mock_dinov3_csv["x_position"]) - assert pd.api.types.is_numeric_dtype(mock_dinov3_csv["y_position"]) - - -def test_feature_dimensions(mock_dinov3_csv): - """Test 3: Validate feature dimensions and values.""" - feature_cols = [str(i) for i in range(1024)] - features = mock_dinov3_csv[feature_cols] - - # Check dimensions - assert features.shape[1] == 1024, "Should have exactly 1024 features" - - # Check all numeric - assert features.dtypes.apply(lambda x: pd.api.types.is_numeric_dtype(x)).all() - - # Check no NaN values - assert not features.isna().any().any(), "Features should not contain NaN" - - # Check reasonable value ranges (embeddings typically in [-5, 5]) - assert features.min().min() > -10, "Feature values too negative" - assert features.max().max() < 10, "Feature values too large" - - # Check not all zeros - assert features.std().mean() > 0.1, "Features should have variance" - - -def test_normalization(mock_dinov3_csv): - """Test 4: Validate center-scaling normalization.""" - feature_cols = [str(i) for i in range(1024)] - features = mock_dinov3_csv[feature_cols + ["label_str"]].copy() - - # Apply normalization - features_norm = center_scale_fast(features, on_controls=False) - - # Check shape preserved - assert 
features_norm.shape == features.shape - - # Check mean ≈ 0 and std ≈ 1 (excluding label_str column) - feature_cols_only = [col for col in features_norm.columns if col != "label_str"] - means = features_norm[feature_cols_only].mean() - stds = features_norm[feature_cols_only].std() - - assert np.allclose(means, 0, atol=1e-6), "Mean should be close to 0" - assert np.allclose(stds, 1, atol=1e-6), "Std should be close to 1" - - # Test control-based normalization - features_norm_ctrl = center_scale_fast( - features, on_controls=True, control_column="label_str", control_gene="NTC" - ) - assert features_norm_ctrl.shape == features.shape - - # Check that label_str column preserved - assert "label_str" in features_norm.columns - pd.testing.assert_series_equal( - features["label_str"], features_norm["label_str"], check_names=False - ) - - -def test_no_constant_columns(mock_dinov3_csv): - """Test 5: Validate that QC filtering removes constant columns.""" - # Create a copy with one constant column - test_df = mock_dinov3_csv.copy() - test_df["constant_col"] = 1.0 - - # Check that we can identify constant columns - constant_cols = test_df.columns[test_df.nunique(dropna=False) == 1] - assert "constant_col" in constant_cols.tolist() - - # Verify DinoV3 features have no constant columns - feature_cols = [str(i) for i in range(1024)] - features = mock_dinov3_csv[feature_cols] - constant_features = features.columns[features.nunique(dropna=False) == 1] - assert ( - len(constant_features) == 0 - ), "DinoV3 embeddings should not have constant columns" - - -# ============================================================================ -# Integration Tests -# ============================================================================ - - -def test_create_adata_object_structure(processed_adata): - """Test 6: Validate AnnData object structure at cell level.""" - assert isinstance(processed_adata, ad.AnnData) - - # Check .X shape and dtype - assert processed_adata.X.shape[0] == 150, "Should 
have 150 cells" - assert processed_adata.X.shape[1] == 1024, "Should have 1024 features (embeddings)" - # Note: center_scale_fast returns float64 for precision, this is acceptable - assert processed_adata.X.dtype in [ - np.float32, - np.float64, - ], "Should use float32 or float64" - - # Check .obs contains metadata - required_obs_cols = ["label_str", "label_int", "sgRNA", "well"] - for col in required_obs_cols: - assert col in processed_adata.obs.columns, f"Missing obs column: {col}" - - # Check .var_names - assert len(processed_adata.var_names) == 1024 - - # CRITICAL: DinoV3 cell-level should NOT have PCA/UMAP (only at aggregation level) - assert ( - "X_pca" not in processed_adata.obsm.keys() - ), "DinoV3 cell-level AnnData should not have PCA (only at aggregation level)" - assert ( - "X_umap" not in processed_adata.obsm.keys() - ), "DinoV3 cell-level AnnData should not have UMAP (only at aggregation level)" - - -def test_adata_metadata_integrity(processed_adata, mock_dinov3_csv): - """Test 7: Validate metadata preservation through pipeline.""" - # Check that metadata matches original CSV - assert len(processed_adata.obs) == len(mock_dinov3_csv) - - # Check label_str preservation - original_genes = sorted(mock_dinov3_csv["label_str"].unique()) - adata_genes = sorted(processed_adata.obs["label_str"].unique()) - assert original_genes == adata_genes - - # Check sgRNA preservation - original_guides = sorted(mock_dinov3_csv["sgRNA"].unique()) - adata_guides = sorted(processed_adata.obs["sgRNA"].unique()) - assert original_guides == adata_guides - - # Check gene counts - original_gene_counts = mock_dinov3_csv["label_str"].value_counts() - adata_gene_counts = processed_adata.obs["label_str"].value_counts() - pd.testing.assert_series_equal( - original_gene_counts.sort_index(), adata_gene_counts.sort_index() - ) - - # Check guide counts - original_guide_counts = mock_dinov3_csv["sgRNA"].value_counts() - adata_guide_counts = processed_adata.obs["sgRNA"].value_counts() - 
pd.testing.assert_series_equal( - original_guide_counts.sort_index(), adata_guide_counts.sort_index() - ) - - -# ============================================================================ -# Aggregation Tests -# ============================================================================ - - -@pytest.fixture(scope="module") -def guide_bulked_adata(mock_dinov3_csv_path): - """ - Create guide-level aggregated AnnData for testing. - This mimics the guide-bulking step in the pipeline. - """ - # Load and process - config = {"normalize_PCA_embeddings": False} - adata = create_adata_object_dinov3(str(mock_dinov3_csv_path), config=config) - - # Aggregate by guide - embeddings_df = pd.DataFrame(adata.X) - embeddings_df["sgRNA"] = adata.obs["sgRNA"].values - embeddings_guide_bulk = embeddings_df.groupby("sgRNA").mean() - - # Create bulked AnnData - adata_guide = ad.AnnData(embeddings_guide_bulk) - adata_guide.obs["sgRNA"] = adata_guide.obs_names - - # Compute PCA and UMAP - adata_guide = pca_embed( - adata_guide, n_components=min(128, adata_guide.shape[0] - 1) - ) - - # Import scanpy for UMAP and PHATE - import scanpy as sc - import scanpy.external as sce - - n_pcs = min(50, adata_guide.obsm["X_pca"].shape[1]) - n_neighbors = min(15, adata_guide.shape[0] - 1) - - sc.pp.neighbors(adata_guide, n_pcs=n_pcs, n_neighbors=n_neighbors, metric="cosine") - sc.tl.umap(adata_guide, min_dist=0.1) - - # Compute PHATE - import warnings - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning, module="phate") - warnings.filterwarnings("ignore", category=UserWarning, module="graphtools") - warnings.filterwarnings("ignore", category=RuntimeWarning, module="phate") - sce.tl.phate( - adata_guide, - n_components=2, - k=n_neighbors, - n_pca=n_pcs, - knn_dist="cosine", - t="auto", - ) - - return adata_guide - - -@pytest.fixture(scope="module") -def gene_bulked_adata(mock_dinov3_csv_path): - """ - Create gene-level aggregated AnnData for testing. 
- """ - # Load and process - config = {"normalize_PCA_embeddings": False} - adata = create_adata_object_dinov3(str(mock_dinov3_csv_path), config=config) - - # Aggregate by gene - embeddings_df = pd.DataFrame(adata.X) - embeddings_df["label_str"] = adata.obs["label_str"].values - embeddings_gene_avg = embeddings_df.groupby("label_str").mean() - - # Create bulked AnnData - adata_gene = ad.AnnData(embeddings_gene_avg) - adata_gene.obs["label_str"] = adata_gene.obs_names - - # Compute PCA and UMAP - adata_gene = pca_embed(adata_gene, n_components=min(128, adata_gene.shape[0] - 1)) - - # Import scanpy for UMAP and PHATE - import scanpy as sc - import scanpy.external as sce - - n_pcs = min(3, adata_gene.obsm["X_pca"].shape[1]) - n_neighbors = min(3, adata_gene.shape[0] - 1) - - sc.pp.neighbors(adata_gene, n_pcs=n_pcs, n_neighbors=n_neighbors, metric="cosine") - sc.tl.umap(adata_gene, min_dist=0.1) - - # Compute PHATE (skip for very small datasets where PHATE may have numerical issues) - if adata_gene.shape[0] >= 5: - import warnings - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning, module="phate") - warnings.filterwarnings("ignore", category=UserWarning, module="graphtools") - warnings.filterwarnings("ignore", category=RuntimeWarning, module="phate") - sce.tl.phate( - adata_gene, - n_components=2, - k=n_neighbors, - n_pca=n_pcs, - knn_dist="cosine", - t="auto", - ) - - return adata_gene - - -def test_guide_level_aggregation_structure(guide_bulked_adata, mock_dinov3_csv): - """Test 8: Validate guide-level aggregation structure.""" - n_unique_guides = mock_dinov3_csv["sgRNA"].nunique() - - assert isinstance(guide_bulked_adata, ad.AnnData) - assert ( - guide_bulked_adata.shape[0] == n_unique_guides - ), f"Should have {n_unique_guides} guides" - assert ( - guide_bulked_adata.shape[1] == 1024 - ), "Should maintain 1024 feature dimensions" - - # Check each guide represented once - assert 
len(guide_bulked_adata.obs["sgRNA"].unique()) == n_unique_guides - - # Check metadata preserved - assert "sgRNA" in guide_bulked_adata.obs.columns - - -def test_guide_level_pca_umap(guide_bulked_adata): - """Test 9: Validate PCA and UMAP on guide-level data.""" - # Check PCA exists - assert "X_pca" in guide_bulked_adata.obsm.keys(), "Guide-level data should have PCA" - - n_guides = guide_bulked_adata.shape[0] - pca_shape = guide_bulked_adata.obsm["X_pca"].shape - assert pca_shape[0] == n_guides - assert pca_shape[1] <= min( - 128, n_guides - 1 - ), "PCA components should not exceed n_samples - 1" - - # Check explained variance exists - assert "pca" in guide_bulked_adata.uns.keys() - assert "variance" in guide_bulked_adata.uns["pca"].keys() - - # Check UMAP exists - assert ( - "X_umap" in guide_bulked_adata.obsm.keys() - ), "Guide-level data should have UMAP" - - umap_shape = guide_bulked_adata.obsm["X_umap"].shape - assert umap_shape == (n_guides, 2), "UMAP should be 2D" - - # Check UMAP coordinates are finite - assert np.isfinite( - guide_bulked_adata.obsm["X_umap"] - ).all(), "UMAP coordinates should be finite" - - -def test_guide_level_phate(guide_bulked_adata): - """Test 9b: Validate PHATE on guide-level data.""" - # Check PHATE exists - assert ( - "X_phate" in guide_bulked_adata.obsm.keys() - ), "Guide-level data should have PHATE" - - n_guides = guide_bulked_adata.shape[0] - phate_shape = guide_bulked_adata.obsm["X_phate"].shape - assert phate_shape == (n_guides, 2), "PHATE should be 2D" - - # Check PHATE coordinates are finite - assert np.isfinite( - guide_bulked_adata.obsm["X_phate"] - ).all(), "PHATE coordinates should be finite" - - # Check PHATE is different from UMAP (they should produce different embeddings) - if "X_umap" in guide_bulked_adata.obsm.keys(): - assert not np.allclose( - guide_bulked_adata.obsm["X_phate"], guide_bulked_adata.obsm["X_umap"] - ), "PHATE and UMAP should produce different embeddings" - - -def 
test_gene_level_aggregation_structure(gene_bulked_adata, mock_dinov3_csv): - """Test 10: Validate gene-level aggregation structure.""" - n_unique_genes = mock_dinov3_csv["label_str"].nunique() - - assert isinstance(gene_bulked_adata, ad.AnnData) - assert ( - gene_bulked_adata.shape[0] == n_unique_genes - ), f"Should have {n_unique_genes} genes" - assert gene_bulked_adata.shape[1] == 1024, "Should maintain 1024 feature dimensions" - - # Check each gene represented once - assert len(gene_bulked_adata.obs["label_str"].unique()) == n_unique_genes - - # Check metadata preserved - assert "label_str" in gene_bulked_adata.obs.columns - - # Check expected genes present - expected_genes = ["NTC", "GENE_A", "GENE_B", "GENE_C"] - for gene in expected_genes: - assert gene in gene_bulked_adata.obs["label_str"].values - - -def test_gene_level_pca_umap(gene_bulked_adata): - """Test 11: Validate PCA and UMAP on gene-level data.""" - # Check PCA exists - assert "X_pca" in gene_bulked_adata.obsm.keys(), "Gene-level data should have PCA" - - n_genes = gene_bulked_adata.shape[0] - pca_shape = gene_bulked_adata.obsm["X_pca"].shape - assert pca_shape[0] == n_genes - assert pca_shape[1] <= min( - 128, n_genes - 1 - ), "PCA components should not exceed n_samples - 1" - - # Check UMAP exists - assert "X_umap" in gene_bulked_adata.obsm.keys(), "Gene-level data should have UMAP" - - umap_shape = gene_bulked_adata.obsm["X_umap"].shape - assert umap_shape == (n_genes, 2), "UMAP should be 2D" - - # Check UMAP coordinates are finite - assert np.isfinite( - gene_bulked_adata.obsm["X_umap"] - ).all(), "UMAP coordinates should be finite" - - -def test_gene_level_phate(gene_bulked_adata): - """Test 11b: Validate PHATE on gene-level data.""" - # Check PHATE exists - assert ( - "X_phate" in gene_bulked_adata.obsm.keys() - ), "Gene-level data should have PHATE" - - n_genes = gene_bulked_adata.shape[0] - phate_shape = gene_bulked_adata.obsm["X_phate"].shape - assert phate_shape == (n_genes, 2), "PHATE 
should be 2D" - - # Check PHATE coordinates are finite - assert np.isfinite( - gene_bulked_adata.obsm["X_phate"] - ).all(), "PHATE coordinates should be finite" - - # Check PHATE is different from UMAP (they should produce different embeddings) - if "X_umap" in gene_bulked_adata.obsm.keys(): - assert not np.allclose( - gene_bulked_adata.obsm["X_phate"], gene_bulked_adata.obsm["X_umap"] - ), "PHATE and UMAP should produce different embeddings" - - -# ============================================================================ -# Output Tests -# ============================================================================ - - -@pytest.mark.slow -def test_output_files_created(mock_dinov3_csv_path): - """Test 12: Validate that all output files are created.""" - config = { - "normalize_PCA_embeddings": False, - } - - # Run full pipeline - adata = process_dinov3(str(mock_dinov3_csv_path), config=config) - - # Check output directory created - output_dir = mock_dinov3_csv_path.parent / "anndata_objects" - assert output_dir.exists(), "anndata_objects directory should be created" - - # Check main output file - main_file = output_dir / "features_processed.h5ad" - assert main_file.exists(), "features_processed.h5ad should be created" - - # Check guide-bulk file - guide_file = output_dir / "guide_bulked_umap.h5ad" - assert guide_file.exists(), "guide_bulked_umap.h5ad should be created" - - # Check gene-bulk file - gene_file = output_dir / "gene_bulked_umap.h5ad" - assert gene_file.exists(), "gene_bulked_umap.h5ad should be created" - - # Verify files can be read - adata_main = ad.read_h5ad(main_file) - assert isinstance(adata_main, ad.AnnData) - - adata_guide = ad.read_h5ad(guide_file) - assert isinstance(adata_guide, ad.AnnData) - - adata_gene = ad.read_h5ad(gene_file) - assert isinstance(adata_gene, ad.AnnData) - - -@pytest.mark.slow -def test_saved_adata_completeness(mock_dinov3_csv_path): - """Test 13: Validate round-trip save/load preserves data.""" - config = { - 
"normalize_PCA_embeddings": False, - } - - # Run pipeline - adata_original = process_dinov3(str(mock_dinov3_csv_path), config=config) - - output_dir = mock_dinov3_csv_path.parent / "anndata_objects" - - # Load cell-level file - adata_loaded = ad.read_h5ad(output_dir / "features_processed.h5ad") - - # Check .X preserved - assert adata_loaded.X.shape == adata_original.X.shape - assert np.allclose(adata_loaded.X, adata_original.X) - - # Check .obs preserved - assert list(adata_loaded.obs.columns) == list(adata_original.obs.columns) - - # DinoV3 cell-level should NOT have PCA/UMAP - assert ( - len(adata_loaded.obsm.keys()) == 0 - ), "DinoV3 cell-level file should not contain obsm (PCA/UMAP only at aggregation level)" - assert ( - len(adata_original.obsm.keys()) == 0 - ), "DinoV3 cell-level should not have obsm" - - # Load guide-bulk and check UMAP and PHATE present - # NOTE: Guide-bulk uses cell-level PCA embeddings directly (n_pcs=0), - # so it won't have its own X_pca, just X_umap and X_phate - adata_guide = ad.read_h5ad(output_dir / "guide_bulked_umap.h5ad") - assert "X_umap" in adata_guide.obsm.keys() - assert ( - "X_phate" in adata_guide.obsm.keys() - ), "Guide-level file should contain PHATE embeddings" - - # Load gene-bulk and check UMAP and PHATE present - adata_gene = ad.read_h5ad(output_dir / "gene_bulked_umap.h5ad") - assert "X_umap" in adata_gene.obsm.keys() - assert ( - "X_phate" in adata_gene.obsm.keys() - ), "Gene-level file should contain PHATE embeddings" - - -# ============================================================================ -# Edge Cases and Error Handling -# ============================================================================ - - -def test_missing_metadata_columns(): - """Test 14: Validate error handling for missing metadata columns.""" - # Create CSV missing a required column - df = pd.DataFrame( - { - "0": [1.0, 2.0], - "1": [3.0, 4.0], - # Missing label_str, sgRNA, etc. 
- } - ) - - with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: - df.to_csv(f.name, index=False) - temp_path = f.name - - try: - config = {"normalize_PCA_embeddings": False} - # This should raise an error or handle gracefully - with pytest.raises(KeyError): - create_adata_object_dinov3(temp_path, config=config) - finally: - Path(temp_path).unlink() - - -def test_wrong_feature_dimensions(): - """Test 15: Validate handling of incorrect feature dimensions.""" - # Create CSV with wrong number of features (not 1024) - np.random.seed(42) - wrong_features = np.random.randn(10, 512) # Only 512 features - - df = pd.DataFrame(wrong_features, columns=[str(i) for i in range(512)]) - df["label_int"] = [0] * 10 - df["label_str"] = ["GENE_A"] * 10 - df["sgRNA"] = ["GENE_A_sg1"] * 10 - df["experiment"] = ["ops0089_20251119"] * 10 - df["x_position"] = [100.0] * 10 - df["y_position"] = [200.0] * 10 - df["well"] = ["A1_ops0089_20251119"] * 10 - - with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: - df.to_csv(f.name, index=False) - temp_path = f.name - - try: - config = {"normalize_PCA_embeddings": False} - # Should handle gracefully - may process with 512 features - adata = create_adata_object_dinov3(temp_path, config=config) - # Just check it doesn't crash - actual dimension is 512 - assert adata.shape[1] == 512 - finally: - Path(temp_path).unlink() - - -def test_empty_dataframe(): - """Test 16: Validate handling of empty input.""" - # Create empty CSV with correct columns - df = pd.DataFrame( - columns=[str(i) for i in range(1024)] - + [ - "label_int", - "label_str", - "sgRNA", - "experiment", - "x_position", - "y_position", - "well", - ] - ) - - with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: - df.to_csv(f.name, index=False) - temp_path = f.name - - try: - config = {"normalize_PCA_embeddings": False} - # Should handle empty data gracefully - adata = create_adata_object_dinov3(temp_path, 
config=config) - assert adata.shape[0] == 0, "Empty input should produce empty AnnData" - except Exception as e: - # Also acceptable to raise an error for empty data - assert "empty" in str(e).lower() or "shape" in str(e).lower() - finally: - Path(temp_path).unlink() - - -# ============================================================================ -# Comparison Tests -# ============================================================================ - - -def test_pipeline_output_structure_consistency(processed_adata): - """Test 17: Ensure DinoV3 output structure matches expected format.""" - # Check that required metadata columns present - required_obs = ["label_str", "label_int", "sgRNA", "well"] - for col in required_obs: - assert ( - col in processed_adata.obs.columns - ), f"Required obs column {col} missing - inconsistent with CellProfiler format" - - # Check that .X is the feature matrix - assert hasattr(processed_adata, "X") - assert isinstance(processed_adata.X, np.ndarray) - - # Check that cell-level has no PCA/UMAP (different from CellProfiler which may have it) - assert "X_pca" not in processed_adata.obsm.keys() - assert "X_umap" not in processed_adata.obsm.keys() - - # Note: This is expected difference between DinoV3 and CellProfiler - # Both will have PCA/UMAP at aggregation level (guide/gene bulked) diff --git a/tests/models/test_extractor_e2e.py b/tests/models/test_extractor_e2e.py index dd10597..c932abb 100644 --- a/tests/models/test_extractor_e2e.py +++ b/tests/models/test_extractor_e2e.py @@ -7,7 +7,7 @@ per gene knockout, 2. run ``extract_*_features`` (crops are pulled from the real ``phenotyping_v3.zarr`` at the subsampled bboxes), - 3. convert the resulting feature CSV to AnnData via ``process_embedding_csv``. + 3. convert the resulting feature CSV to AnnData via ``process_features_csv``. Outputs at each step are verified, and everything is written under pytest's ``tmp_path`` — only the link CSV / zarr / checkpoints are read. 
@@ -35,7 +35,7 @@ from ops_model.models.cell_dino import extract_cell_dino_features from ops_model.models.dinov3 import extract_dinov3_features from ops_model.models.subcell import extract_subcell_features -from ops_model.features.evaluate_embeddings import process_embedding_csv +from ops_model.features.processing_common import process_features_csv # Integration test: heavy + needs a GPU, so keep it out of the default run and # don't let incidental torch/monai/anndata warnings fail it (the suite treats @@ -228,7 +228,7 @@ def test_extractor_end_to_end(spec, subsampled_link_dir, tmp_path, _require_gpu) assert meta_col in feats.columns, f"missing metadata column {meta_col}" # --- Step 3: feature CSV -> AnnData -------------------------------------- - adata = process_embedding_csv(str(feature_csv), config_path=str(config_path)) + adata = process_features_csv(str(feature_csv), config_path=str(config_path)) anndata_dir = feature_csv.parent / "anndata_objects" produced = list(anndata_dir.glob("features_processed_*.h5ad")) From 6571e39035f378f5bb74fe54e32c5d9fd3438f21 Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Wed, 10 Jun 2026 16:01:01 -0700 Subject: [PATCH 05/11] Add end-to-end test scripts for the feature pipelines MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit tests/e2e_tests/ — self-contained scripts (run directly with uv run python, not pytest) covering each core ops_model feature: the cell_dino, dinov3, subcell and cell_profiler extractors, and the pca_optimization combination pipeline. Each subsets real inputs to a minimal example in a tmp dir, points an inline config at it, runs the feature normally, and verifies the outputs at each step. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/e2e_tests/cell_dino_e2e.py | 168 +++++++++++++++++++++ tests/e2e_tests/cell_profiler_e2e.py | 210 +++++++++++++++++++++++++++ tests/e2e_tests/combination_e2e.py | 168 +++++++++++++++++++++ tests/e2e_tests/dinov3_e2e.py | 168 +++++++++++++++++++++ tests/e2e_tests/subcell_e2e.py | 179 +++++++++++++++++++++++ tests/e2e_tests/testing_plan | 13 ++ 6 files changed, 906 insertions(+) create mode 100644 tests/e2e_tests/cell_dino_e2e.py create mode 100644 tests/e2e_tests/cell_profiler_e2e.py create mode 100644 tests/e2e_tests/combination_e2e.py create mode 100644 tests/e2e_tests/dinov3_e2e.py create mode 100644 tests/e2e_tests/subcell_e2e.py create mode 100644 tests/e2e_tests/testing_plan diff --git a/tests/e2e_tests/cell_dino_e2e.py b/tests/e2e_tests/cell_dino_e2e.py new file mode 100644 index 0000000..be0df71 --- /dev/null +++ b/tests/e2e_tests/cell_dino_e2e.py @@ -0,0 +1,168 @@ +"""End-to-end test for the Cell-DINO feature-extraction pipeline. + +Self-contained script (run directly, not via pytest). It exercises the full +production path for one minimal example: + + 1. subset a real per-well link CSV to one cell per gene KO, saved to a tmp dir + 2. point an inline config at that tmp dir (link_csv_dir) + tmp output_dir + 3. run extraction normally (extract_cell_dino_features) -> feature CSV + 4. convert to AnnData (process_features_csv) -> cell/guide/gene .h5ad + 5. verify the outputs at each step + +Crops are read from the real phenotyping_v3.zarr at the subsampled bboxes, so +only the link CSV is subset; nothing else is mutated. All outputs go to a fresh +tmp dir (printed at the end for inspection). + +Requires a GPU node with the Cell-DINO checkpoint on disk. 
+ +Run with: + uv run python tests/e2e_tests/cell_dino_e2e.py +""" + +import tempfile +from pathlib import Path + +import pandas as pd +import yaml + +from ops_model.models.cell_dino import extract_cell_dino_features +from ops_model.features.processing_common import process_features_csv + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +EXPERIMENT = "ops0031_20250424" +WELL = "A/1/0" +WELL_PREFIX = WELL[0] + WELL[2] # "A/1/0" -> "A1" +CHANNEL = "Phase2D" +EXPECTED_DIM = 1024 # Cell-DINO ViT-L/16 embedding width +MAX_GENES = 32 # cap (one cell per gene) to keep the run fast + +ASSEMBLY_DIR = Path(f"/hpc/projects/icd.fast.ops/{EXPERIMENT}/3-assembly") +SOURCE_LINK_CSV = ASSEMBLY_DIR / f"{WELL_PREFIX}_linked_pheno_iss.csv" +PHENOTYPING_ZARR = ASSEMBLY_DIR / "phenotyping_v3.zarr" + + +def build_config(link_dir: Path, output_dir: Path) -> dict: + """Inline config; mirrors experiments/embedding/configs/cell_dino/ops0031_cell_dino.yml.""" + return { + "model_type": "cell_dino", + "embedding_type": "cell_dino", + "dataset_type": "basic", + "data_manager": { + "experiments": {EXPERIMENT: [WELL]}, + "batch_size": 8, + "data_split": [0, 0, 1], # everything in the test loader + "out_channels": [CHANNEL], + "initial_yx_patch_size": [256, 256], + "final_yx_patch_size": [128, 128], + "num_workers": 0, + "link_csv_dir": str(link_dir), + }, + "output_dir": str(output_dir), + "cell_type": "A549", + # Skip PCA/UMAP: they'd fail on a tiny one-cell-per-gene dataset. 
+ "aggregation": { + "guide_level": {"compute_embeddings": False}, + "gene_level": {"compute_embeddings": False}, + }, + } + + +# --------------------------------------------------------------------------- +# Steps +# --------------------------------------------------------------------------- + + +def subsample_link_csv(link_dir: Path) -> int: + """Copy the real link CSV into link_dir, keeping one cell per gene KO. + + Returns the number of genes (== number of cells kept). + """ + assert SOURCE_LINK_CSV.exists(), f"source link CSV not found: {SOURCE_LINK_CSV}" + assert PHENOTYPING_ZARR.exists(), f"phenotyping zarr not found: {PHENOTYPING_ZARR}" + + df = pd.read_csv(SOURCE_LINK_CSV) + gene_col = "gene_name" if "gene_name" in df.columns else "Gene name" + assert gene_col in df.columns, f"no gene column in {SOURCE_LINK_CSV}" + + # Mirror the loader's basic QC so kept cells survive into a batch. + df = df.dropna(subset=["segmentation_id", gene_col]) + + one_per_gene = df.groupby(gene_col, sort=True).head(1) + genes = one_per_gene[gene_col].unique()[:MAX_GENES] + subsampled = one_per_gene[one_per_gene[gene_col].isin(genes)].reset_index(drop=True) + + n_genes = subsampled[gene_col].nunique() + assert n_genes >= 2, "need at least a couple of genes for the test" + assert len(subsampled) == n_genes, "expected exactly one cell per gene" + + out_csv = link_dir / f"{WELL_PREFIX}_linked_pheno_iss.csv" + subsampled.to_csv(out_csv, index=False) + assert out_csv.exists() + print(f"[1] Subsampled link CSV: {n_genes} cells (1/gene) -> {out_csv}") + return n_genes + + +def verify_feature_csv(feature_csv: Path, n_genes: int) -> int: + """Verify the extracted feature CSV. 
Returns the cell count.""" + assert feature_csv.exists(), f"feature CSV not produced: {feature_csv}" + feats = pd.read_csv(feature_csv) + n_cells = len(feats) + assert 0 < n_cells <= n_genes, f"unexpected cell count {n_cells} (genes={n_genes})" + + feature_cols = [c for c in feats.columns if str(c).isdigit()] + assert ( + len(feature_cols) == EXPECTED_DIM + ), f"expected {EXPECTED_DIM} feature dims, got {len(feature_cols)}" + for col in ("label_int", "label_str", "sgRNA", "experiment", "well"): + assert col in feats.columns, f"missing metadata column {col}" + + print(f"[3] Feature CSV OK: {n_cells} cells x {len(feature_cols)} dims -> {feature_csv}") + return n_cells + + +def verify_anndata(feature_csv: Path, n_cells: int) -> None: + """Verify the AnnData outputs from process_features_csv.""" + import anndata as ad + + anndata_dir = feature_csv.parent / "anndata_objects" + produced = list(anndata_dir.glob("features_processed_*.h5ad")) + assert produced, f"no features_processed_*.h5ad written in {anndata_dir}" + + reloaded = ad.read_h5ad(produced[0]) + assert reloaded.n_obs == n_cells, "AnnData cell count != feature CSV rows" + assert reloaded.n_vars == EXPECTED_DIM, "AnnData feature width mismatch" + print(f"[4] AnnData OK: {reloaded.n_obs} x {reloaded.n_vars} -> {produced[0]}") + + +def main() -> None: + tmp = Path(tempfile.mkdtemp(prefix="cell_dino_e2e_")) + link_dir = tmp / "link_csvs" + link_dir.mkdir(parents=True, exist_ok=True) + output_dir = tmp / "features" + + print(f"Working dir: {tmp}\n") + + # 1. subset + 2. config -> tmp + n_genes = subsample_link_csv(link_dir) + config = build_config(link_dir, output_dir) + config_path = tmp / "config.yml" + config_path.write_text(yaml.safe_dump(config)) + print(f"[2] Wrote config -> {config_path}") + + # 3. run extraction + extract_cell_dino_features(config=config) + feature_csv = output_dir / f"cell_dino_features_{CHANNEL}.csv" + n_cells = verify_feature_csv(feature_csv, n_genes) + + # 4. 
CSV -> AnnData + verify + process_features_csv(str(feature_csv), config_path=str(config_path)) + verify_anndata(feature_csv, n_cells) + + print(f"\n✓ Cell-DINO e2e PASSED. Outputs under: {tmp}") + + +if __name__ == "__main__": + main() diff --git a/tests/e2e_tests/cell_profiler_e2e.py b/tests/e2e_tests/cell_profiler_e2e.py new file mode 100644 index 0000000..6508337 --- /dev/null +++ b/tests/e2e_tests/cell_profiler_e2e.py @@ -0,0 +1,210 @@ +"""End-to-end test for the CellProfiler feature-extraction pipeline. + +Self-contained script (run directly, not via pytest). It exercises the full +production path for one minimal example: + + 1. subset a real per-well link CSV to one cell per gene KO, saved to a tmp dir + 2. point an inline config at that tmp dir (link_csv_dir) + tmp output_dir + 3. run extraction normally (extract_cp_features_parallel, run locally over the + whole minimal subset rather than via SLURM array jobs) -> cp_features.csv + 4. convert to AnnData (process_features_csv, CellProfiler branch: split by + reporter) -> per-reporter cell/guide/gene .h5ad + 5. verify the outputs at each step + +Unlike the embedding extractors, CellProfiler extraction is index-based and +reads its labels from a DataFrame, so the subsetted link CSV is loaded through +OpsDataManager (link_csv_dir) to build that DataFrame. Crops/measurements come +from the real phenotyping_v3.zarr; only the link CSV is subset. All outputs go +to a fresh tmp dir (printed at the end for inspection). + +Requires a GPU node (granularity is computed on GPU workers). 
+ +Run with: + uv run python tests/e2e_tests/cell_profiler_e2e.py +""" + +import tempfile +from pathlib import Path + +import pandas as pd +import yaml + +from ops_model.data import data_loader +from ops_model.features.cp_extraction import extract_cp_features_parallel +from ops_model.features.processing_common import process_features_csv + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +EXPERIMENT = "ops0031_20250424" +WELL = "A/1/0" +WELL_PREFIX = WELL[0] + WELL[2] # "A/1/0" -> "A1" +# Multi-channel so the CellProfiler reporter-split branch is exercised. +OUT_CHANNELS = ["Phase2D"] +GUIDE_COL = "sgRNA" +MAX_GENES = 32 # cap (one cell per gene) to keep the run fast +NUM_WORKERS = 2 # CPU extraction workers (tiny dataset; keep small) + +ASSEMBLY_DIR = Path(f"/hpc/projects/icd.fast.ops/{EXPERIMENT}/3-assembly") +SOURCE_LINK_CSV = ASSEMBLY_DIR / f"{WELL_PREFIX}_linked_pheno_iss.csv" +PHENOTYPING_ZARR = ASSEMBLY_DIR / "phenotyping_v3.zarr" + + +def build_config(link_dir: Path, output_dir: Path) -> dict: + """Inline config; mirrors experiments/embedding/configs/cell-profiler/*.yml.""" + return { + "model_type": "cellprofiler", + "dataset_type": "basic", + "data_manager": { + "experiments": {EXPERIMENT: [WELL]}, + "batch_size": 1, + "data_split": [1, 0, 0], # CP processes the train split + "out_channels": OUT_CHANNELS, + "initial_yx_patch_size": [256, 256], + "final_yx_patch_size": [128, 128], + "num_workers": 0, + "link_csv_dir": str(link_dir), + }, + "output_dir": str(output_dir), + "cell_type": "A549", + "processing": {}, + # Skip PCA/UMAP: they'd fail on a tiny one-cell-per-gene dataset. 
+ "aggregation": { + "guide_level": {"compute_embeddings": False}, + "gene_level": {"compute_embeddings": False}, + }, + } + + +# --------------------------------------------------------------------------- +# Steps +# --------------------------------------------------------------------------- + + +def subsample_link_csv(link_dir: Path) -> int: + """Copy the real link CSV into link_dir, keeping one cell per gene KO. + + Returns the number of genes (== number of cells kept). + """ + assert SOURCE_LINK_CSV.exists(), f"source link CSV not found: {SOURCE_LINK_CSV}" + assert PHENOTYPING_ZARR.exists(), f"phenotyping zarr not found: {PHENOTYPING_ZARR}" + + df = pd.read_csv(SOURCE_LINK_CSV) + gene_col = "gene_name" if "gene_name" in df.columns else "Gene name" + assert gene_col in df.columns, f"no gene column in {SOURCE_LINK_CSV}" + + # Mirror the loader's basic QC so kept cells survive into a batch. + df = df.dropna(subset=["segmentation_id", gene_col]) + + one_per_gene = df.groupby(gene_col, sort=True).head(1) + genes = one_per_gene[gene_col].unique()[:MAX_GENES] + subsampled = one_per_gene[one_per_gene[gene_col].isin(genes)].reset_index(drop=True) + + n_genes = subsampled[gene_col].nunique() + assert n_genes >= 2, "need at least a couple of genes for the test" + assert len(subsampled) == n_genes, "expected exactly one cell per gene" + + out_csv = link_dir / f"{WELL_PREFIX}_linked_pheno_iss.csv" + subsampled.to_csv(out_csv, index=False) + assert out_csv.exists() + print(f"[1] Subsampled link CSV: {n_genes} cells (1/gene) -> {out_csv}") + return n_genes + + +def run_extraction(link_dir: Path, output_dir: Path) -> Path: + """Run CellProfiler extraction locally over the whole minimal subset.""" + # Build the labels DataFrame from the subsetted link CSV (via link_csv_dir). 
+ dm = data_loader.OpsDataManager( + experiments={EXPERIMENT: [WELL]}, + batch_size=1, + data_split=(1, 0, 0), + out_channels=OUT_CHANNELS, + initial_yx_patch_size=(256, 256), + link_csv_dir=str(link_dir), + verbose=False, + guide_col=GUIDE_COL, + ) + dm.construct_dataloaders(num_workers=0, dataset_type="cell_profile") + labels_df = dm.train_loader.dataset.labels_df + indices = list(range(len(labels_df))) + del dm # workers open their own zarr handles + + results_df = extract_cp_features_parallel( + experiment_dict={EXPERIMENT: [WELL]}, + indices=indices, + out_channels=OUT_CHANNELS, + num_workers=NUM_WORKERS, + labels_df=labels_df, + guide_col=GUIDE_COL, + ) + + output_dir.mkdir(parents=True, exist_ok=True) + cp_csv = output_dir / "cp_features.csv" + results_df.to_csv(cp_csv, index=False) + print(f"[3] Extracted CP features -> {cp_csv}") + return cp_csv + + +def verify_feature_csv(cp_csv: Path, n_genes: int) -> int: + """Verify the extracted CP feature CSV. Returns the cell count.""" + assert cp_csv.exists(), f"feature CSV not produced: {cp_csv}" + feats = pd.read_csv(cp_csv) + n_cells = len(feats) + assert 0 < n_cells <= n_genes, f"unexpected cell count {n_cells} (genes={n_genes})" + + meta_cols = {"label_int", "label_str", GUIDE_COL, "well", "experiment"} + for col in meta_cols: + assert col in feats.columns, f"missing metadata column {col}" + # CellProfiler features are variable-width; just confirm there are some. 
+ feature_cols = [c for c in feats.columns if c not in meta_cols and "position" not in c] + assert len(feature_cols) > 0, "no CellProfiler feature columns found" + + print(f"[3] Feature CSV OK: {n_cells} cells x {len(feature_cols)} CP features -> {cp_csv}") + return n_cells + + +def verify_anndata(cp_csv: Path) -> None: + """Verify the per-reporter AnnData outputs from process_features_csv.""" + import anndata as ad + + anndata_dir = cp_csv.parent / "anndata_objects" + produced = sorted(anndata_dir.glob("features_processed_*.h5ad")) + assert produced, f"no features_processed_*.h5ad written in {anndata_dir}" + + print(f"[4] AnnData OK: {len(produced)} reporter file(s)") + for path in produced: + adata = ad.read_h5ad(path) + assert adata.n_obs > 0, f"empty cell-level AnnData: {path}" + assert adata.n_vars > 0, f"no features in: {path}" + print(f" {path.name}: {adata.n_obs} x {adata.n_vars}") + + +def main() -> None: + tmp = Path(tempfile.mkdtemp(prefix="cell_profiler_e2e_")) + link_dir = tmp / "link_csvs" + link_dir.mkdir(parents=True, exist_ok=True) + output_dir = tmp / "cell-profiler" + + print(f"Working dir: {tmp}\n") + + # 1. subset + 2. config -> tmp + n_genes = subsample_link_csv(link_dir) + config = build_config(link_dir, output_dir) + config_path = tmp / "config.yml" + config_path.write_text(yaml.safe_dump(config)) + print(f"[2] Wrote config -> {config_path}") + + # 3. run extraction + cp_csv = run_extraction(link_dir, output_dir) + n_cells = verify_feature_csv(cp_csv, n_genes) + + # 4. CSV -> AnnData (CellProfiler branch: split by reporter) + verify + process_features_csv(str(cp_csv), config_path=str(config_path)) + verify_anndata(cp_csv) + + print(f"\n✓ CellProfiler e2e PASSED ({n_cells} cells). 
Outputs under: {tmp}") + + +if __name__ == "__main__": + main() diff --git a/tests/e2e_tests/combination_e2e.py b/tests/e2e_tests/combination_e2e.py new file mode 100644 index 0000000..5ee4a28 --- /dev/null +++ b/tests/e2e_tests/combination_e2e.py @@ -0,0 +1,168 @@ +"""End-to-end test for the embedding combination pipeline (pca_optimization). + +Self-contained script (run directly, not via pytest). It exercises the full +multi-experiment combination path for a minimal example: + + 1. discover 5 experiments that share the cell_dino reporter set + (Phase / 5xUPRE / SEC61B), pull ALL channels (reporters) for each + 2. subset every per-experiment, per-reporter cell-level h5ad to NTC + 31 + genes (<=50 cells per gene), saved to a tmp dir + 3. write a config pointing `signal_paths` at the subsetted h5ads (one pooled + list of the 5 experiments per reporter) and run the pipeline in-process + (slurm: false) via run_from_config + 4. verify the canonical combined outputs are produced + +The combiner pools cells across experiments sharing a signal, fits PCA, and +aggregates to guide/gene level (Phase 2 also NTC-normalizes + scores metrics). +On a 32-gene subset the metric values are not meaningful, so this test only +asserts the pipeline completes and produces sane-shaped outputs. All outputs go +to a fresh tmp dir (printed at the end). 
+ +Run with: + uv run python tests/e2e_tests/combination_e2e.py +""" + +import tempfile +from pathlib import Path + +import numpy as np +import anndata as ad +import yaml + +from ops_model.post_process.combination.pca_optimization import run_from_config + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +ROOT = Path("/hpc/projects/icd.fast.ops") +FEATURE_DIR = "cell_dino_features" # feature mode: cell_dino +REPORTERS = ["Phase", "5xUPRE", "SEC61B"] # "all channels" +N_EXPERIMENTS = 5 +N_GENES = 32 # NTC + 31 perturbations +CELLS_PER_GENE = 50 # cap per (gene, experiment) for speed +CHAD_ANNOTATION = ( + "/hpc/projects/icd.fast.ops/configs/gene_clusters/" + "val_library_chad_positive_controls_v1.yml" +) + + +def anndata_path(experiment: str, reporter: str) -> Path: + return ( + ROOT + / experiment + / "3-assembly" + / FEATURE_DIR + / "anndata_objects" + / f"features_processed_{reporter}.h5ad" + ) + + +def find_experiments(n: int) -> list[str]: + """First n experiments (sorted) with cell_dino h5ads for all REPORTERS.""" + found = [] + for d in sorted(ROOT.glob(f"ops0*/3-assembly/{FEATURE_DIR}/anndata_objects")): + experiment = d.relative_to(ROOT).parts[0] + if all(anndata_path(experiment, r).exists() for r in REPORTERS): + found.append(experiment) + if len(found) == n: + break + assert len(found) == n, f"only found {len(found)} experiments with all reporters" + return found + + +def pick_genes(experiment: str) -> list[str]: + """NTC + the 31 most-common non-NTC perturbations in one reference h5ad.""" + a = ad.read_h5ad(anndata_path(experiment, REPORTERS[0]), backed="r") + counts = a.obs["perturbation"].astype(str).value_counts() + non_ntc = [g for g in counts.index if g != "NTC"][: N_GENES - 1] + genes = ["NTC"] + non_ntc + assert len(genes) == N_GENES + return genes + + +def subset_h5ad(src: Path, genes: list[str], out: Path) -> int: + """Subset 
src to `genes`, capping cells per gene; write to out. Returns n_obs.""" + a = ad.read_h5ad(src, backed="r") + pert = a.obs["perturbation"].astype(str).values + keep = [] + for g in genes: + idx = np.where(pert == g)[0] + keep.extend(idx[:CELLS_PER_GENE].tolist()) + keep = sorted(keep) + sub = a[keep].to_memory() + sub.write_h5ad(out) + return sub.n_obs + + +def build_config(signal_paths: dict, output_dir: Path) -> dict: + return { + "cell_dino": True, + "signal_paths": signal_paths, # {reporter: [pooled per-experiment h5ads]} + "output_dir": str(output_dir), + "run_tag": "e2e", + "slurm": False, # run Phase 1 + Phase 2 in-process + "fixed_threshold": 0.80, # single PCA cutoff (skip the consensus sweep) + "norm_method": "ntc", + "zscore_per_experiment": True, + "second_pca": False, # skip the second-pass PCA consensus + "chad_annotation": CHAD_ANNOTATION, + } + + +def verify_outputs(output_dir: Path) -> None: + gene = list(output_dir.rglob("gene_embedding_pca_optimized.h5ad")) + guide = list(output_dir.rglob("guide_pca_optimized.h5ad")) + report = list(output_dir.rglob("pca_report.csv")) + assert gene, f"no gene_embedding_pca_optimized.h5ad under {output_dir}" + assert guide, f"no guide_pca_optimized.h5ad under {output_dir}" + assert report, f"no pca_report.csv under {output_dir}" + + adata_gene = ad.read_h5ad(gene[0]) + adata_guide = ad.read_h5ad(guide[0]) + assert adata_gene.n_obs > 0 and adata_gene.n_vars > 0, "empty gene-level output" + assert adata_guide.n_obs > 0 and adata_guide.n_vars > 0, "empty guide-level output" + print(f"[4] Combined outputs OK:") + print(f" guide: {adata_guide.n_obs} x {adata_guide.n_vars} -> {guide[0]}") + print(f" gene: {adata_gene.n_obs} x {adata_gene.n_vars} -> {gene[0]}") + print(f" report: {report[0]}") + + +def main() -> None: + tmp = Path(tempfile.mkdtemp(prefix="combination_e2e_")) + inputs_dir = tmp / "inputs" + inputs_dir.mkdir(parents=True, exist_ok=True) + output_dir = tmp / "combined" + + print(f"Working dir: {tmp}\n") + + 
experiments = find_experiments(N_EXPERIMENTS) + genes = pick_genes(experiments[0]) + print(f"[0] Experiments: {experiments}") + print(f"[0] Genes ({len(genes)}): NTC + {len(genes) - 1} perturbations\n") + + # 1 + 2. subset every (experiment, reporter) h5ad -> tmp, build signal_paths + signal_paths: dict[str, list[str]] = {r: [] for r in REPORTERS} + for experiment in experiments: + for reporter in REPORTERS: + out = inputs_dir / f"{experiment}_{reporter}.h5ad" + n = subset_h5ad(anndata_path(experiment, reporter), genes, out) + signal_paths[reporter].append(str(out)) + print(f"[1] {experiment}/{reporter}: {n} cells -> {out.name}") + + config = build_config(signal_paths, output_dir) + config_path = tmp / "config.yml" + config_path.write_text(yaml.safe_dump(config)) + print(f"\n[2] Wrote config -> {config_path}") + + # 3. run the combination pipeline in-process + run_from_config(str(config_path)) + + # 4. verify + verify_outputs(output_dir) + + print(f"\n✓ Combination e2e PASSED. Outputs under: {tmp}") + + +if __name__ == "__main__": + main() diff --git a/tests/e2e_tests/dinov3_e2e.py b/tests/e2e_tests/dinov3_e2e.py new file mode 100644 index 0000000..7a24f04 --- /dev/null +++ b/tests/e2e_tests/dinov3_e2e.py @@ -0,0 +1,168 @@ +"""End-to-end test for the DINOv3 feature-extraction pipeline. + +Self-contained script (run directly, not via pytest). It exercises the full +production path for one minimal example: + + 1. subset a real per-well link CSV to one cell per gene KO, saved to a tmp dir + 2. point an inline config at that tmp dir (link_csv_dir) + tmp output_dir + 3. run extraction normally (extract_dinov3_features) -> feature CSV + 4. convert to AnnData (process_features_csv) -> cell/guide/gene .h5ad + 5. verify the outputs at each step + +Crops are read from the real phenotyping_v3.zarr at the subsampled bboxes, so +only the link CSV is subset; nothing else is mutated. All outputs go to a fresh +tmp dir (printed at the end for inspection). 
+ +Requires a GPU node with the DINOv3 checkpoint on disk. + +Run with: + uv run python tests/e2e_tests/dinov3_e2e.py +""" + +import tempfile +from pathlib import Path + +import pandas as pd +import yaml + +from ops_model.models.dinov3 import extract_dinov3_features +from ops_model.features.processing_common import process_features_csv + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +EXPERIMENT = "ops0031_20250424" +WELL = "A/1/0" +WELL_PREFIX = WELL[0] + WELL[2] # "A/1/0" -> "A1" +CHANNEL = "Phase2D" +EXPECTED_DIM = 1024 # DINOv3 ViT-L/16 embedding width +MAX_GENES = 32 # cap (one cell per gene) to keep the run fast + +ASSEMBLY_DIR = Path(f"/hpc/projects/icd.fast.ops/{EXPERIMENT}/3-assembly") +SOURCE_LINK_CSV = ASSEMBLY_DIR / f"{WELL_PREFIX}_linked_pheno_iss.csv" +PHENOTYPING_ZARR = ASSEMBLY_DIR / "phenotyping_v3.zarr" + + +def build_config(link_dir: Path, output_dir: Path) -> dict: + """Inline config; mirrors configs/model_configs/dinov3/ops0031_dino.yml.""" + return { + "model_type": "dinov3", + "embedding_type": "dinov3", + "dataset_type": "basic", + "data_manager": { + "experiments": {EXPERIMENT: [WELL]}, + "batch_size": 8, + "data_split": [0, 0, 1], # everything in the test loader + "out_channels": [CHANNEL], + "initial_yx_patch_size": [256, 256], + "final_yx_patch_size": [128, 128], + "num_workers": 0, + "link_csv_dir": str(link_dir), + }, + "output_dir": str(output_dir), + "cell_type": "A549", + # Skip PCA/UMAP: they'd fail on a tiny one-cell-per-gene dataset. 
+ "aggregation": { + "guide_level": {"compute_embeddings": False}, + "gene_level": {"compute_embeddings": False}, + }, + } + + +# --------------------------------------------------------------------------- +# Steps +# --------------------------------------------------------------------------- + + +def subsample_link_csv(link_dir: Path) -> int: + """Copy the real link CSV into link_dir, keeping one cell per gene KO. + + Returns the number of genes (== number of cells kept). + """ + assert SOURCE_LINK_CSV.exists(), f"source link CSV not found: {SOURCE_LINK_CSV}" + assert PHENOTYPING_ZARR.exists(), f"phenotyping zarr not found: {PHENOTYPING_ZARR}" + + df = pd.read_csv(SOURCE_LINK_CSV) + gene_col = "gene_name" if "gene_name" in df.columns else "Gene name" + assert gene_col in df.columns, f"no gene column in {SOURCE_LINK_CSV}" + + # Mirror the loader's basic QC so kept cells survive into a batch. + df = df.dropna(subset=["segmentation_id", gene_col]) + + one_per_gene = df.groupby(gene_col, sort=True).head(1) + genes = one_per_gene[gene_col].unique()[:MAX_GENES] + subsampled = one_per_gene[one_per_gene[gene_col].isin(genes)].reset_index(drop=True) + + n_genes = subsampled[gene_col].nunique() + assert n_genes >= 2, "need at least a couple of genes for the test" + assert len(subsampled) == n_genes, "expected exactly one cell per gene" + + out_csv = link_dir / f"{WELL_PREFIX}_linked_pheno_iss.csv" + subsampled.to_csv(out_csv, index=False) + assert out_csv.exists() + print(f"[1] Subsampled link CSV: {n_genes} cells (1/gene) -> {out_csv}") + return n_genes + + +def verify_feature_csv(feature_csv: Path, n_genes: int) -> int: + """Verify the extracted feature CSV. 
Returns the cell count.""" + assert feature_csv.exists(), f"feature CSV not produced: {feature_csv}" + feats = pd.read_csv(feature_csv) + n_cells = len(feats) + assert 0 < n_cells <= n_genes, f"unexpected cell count {n_cells} (genes={n_genes})" + + feature_cols = [c for c in feats.columns if str(c).isdigit()] + assert ( + len(feature_cols) == EXPECTED_DIM + ), f"expected {EXPECTED_DIM} feature dims, got {len(feature_cols)}" + for col in ("label_int", "label_str", "sgRNA", "experiment", "well"): + assert col in feats.columns, f"missing metadata column {col}" + + print(f"[3] Feature CSV OK: {n_cells} cells x {len(feature_cols)} dims -> {feature_csv}") + return n_cells + + +def verify_anndata(feature_csv: Path, n_cells: int) -> None: + """Verify the AnnData outputs from process_features_csv.""" + import anndata as ad + + anndata_dir = feature_csv.parent / "anndata_objects" + produced = list(anndata_dir.glob("features_processed_*.h5ad")) + assert produced, f"no features_processed_*.h5ad written in {anndata_dir}" + + reloaded = ad.read_h5ad(produced[0]) + assert reloaded.n_obs == n_cells, "AnnData cell count != feature CSV rows" + assert reloaded.n_vars == EXPECTED_DIM, "AnnData feature width mismatch" + print(f"[4] AnnData OK: {reloaded.n_obs} x {reloaded.n_vars} -> {produced[0]}") + + +def main() -> None: + tmp = Path(tempfile.mkdtemp(prefix="dinov3_e2e_")) + link_dir = tmp / "link_csvs" + link_dir.mkdir(parents=True, exist_ok=True) + output_dir = tmp / "features" + + print(f"Working dir: {tmp}\n") + + # 1. subset + 2. config -> tmp + n_genes = subsample_link_csv(link_dir) + config = build_config(link_dir, output_dir) + config_path = tmp / "config.yml" + config_path.write_text(yaml.safe_dump(config)) + print(f"[2] Wrote config -> {config_path}") + + # 3. run extraction + extract_dinov3_features(config=config) + feature_csv = output_dir / f"dinov3_features_{CHANNEL}.csv" + n_cells = verify_feature_csv(feature_csv, n_genes) + + # 4. 
CSV -> AnnData + verify + process_features_csv(str(feature_csv), config_path=str(config_path)) + verify_anndata(feature_csv, n_cells) + + print(f"\n✓ DINOv3 e2e PASSED. Outputs under: {tmp}") + + +if __name__ == "__main__": + main() diff --git a/tests/e2e_tests/subcell_e2e.py b/tests/e2e_tests/subcell_e2e.py new file mode 100644 index 0000000..801d985 --- /dev/null +++ b/tests/e2e_tests/subcell_e2e.py @@ -0,0 +1,179 @@ +"""End-to-end test for the SubCell feature-extraction pipeline. + +Self-contained script (run directly, not via pytest). It exercises the full +production path for one minimal example: + + 1. subset a real per-well link CSV to one cell per gene KO, saved to a tmp dir + 2. point an inline config at that tmp dir (link_csv_dir) + tmp output_dir + 3. run extraction normally (extract_subcell_features) -> feature CSV + 4. convert to AnnData (process_features_csv) -> cell/guide/gene .h5ad + 5. verify the outputs at each step + +SubCell is a two-channel (DNA + protein) model: out_channels is +[dna_channel, protein_channel] and the embedding is named after the protein +channel. Crops are read from the real phenotyping_v3.zarr at the subsampled +bboxes, so only the link CSV is subset. All outputs go to a fresh tmp dir +(printed at the end for inspection). + +Requires a GPU node; the SubCell weights download from S3 on first use and cache +under the model_checkpoints dir. 
+ +Run with: + uv run python tests/e2e_tests/subcell_e2e.py +""" + +import tempfile +from pathlib import Path + +import pandas as pd +import yaml + +from ops_model.models.subcell import extract_subcell_features +from ops_model.features.processing_common import process_features_csv + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +EXPERIMENT = "ops0031_20250424" +WELL = "A/1/0" +WELL_PREFIX = WELL[0] + WELL[2] # "A/1/0" -> "A1" +DNA_CHANNEL = "nuclei_prediction" # SubCell bg model's nuclear/Blue channel +PROTEIN_CHANNEL = "GFP" # protein-of-interest; embedding is named after this +OUT_CHANNELS = [DNA_CHANNEL, PROTEIN_CHANNEL] # [DNA, protein] +EXPECTED_DIM = 1536 # SubCell gated-attention pooler (2 heads x 768) +MAX_GENES = 32 # cap (one cell per gene) to keep the run fast + +ASSEMBLY_DIR = Path(f"/hpc/projects/icd.fast.ops/{EXPERIMENT}/3-assembly") +SOURCE_LINK_CSV = ASSEMBLY_DIR / f"{WELL_PREFIX}_linked_pheno_iss.csv" +PHENOTYPING_ZARR = ASSEMBLY_DIR / "phenotyping_v3.zarr" + + +def build_config(link_dir: Path, output_dir: Path) -> dict: + """Inline config; mirrors experiments/embedding/configs/subcell/ops0031_subcell.yml. + + Note: extract_subcell_features consumes out_channels directly as + [dna_channel, protein_channel] (the orchestrator's dna_channel pairing is + inlined here). + """ + return { + "model_type": "subcell", + "embedding_type": "subcell", + "dataset_type": "basic", + "data_manager": { + "experiments": {EXPERIMENT: [WELL]}, + "batch_size": 4, + "data_split": [0, 0, 1], # everything in the test loader + "out_channels": OUT_CHANNELS, + "dna_channel": DNA_CHANNEL, + "initial_yx_patch_size": [256, 256], + "final_yx_patch_size": [128, 128], + "num_workers": 0, + "link_csv_dir": str(link_dir), + }, + "output_dir": str(output_dir), + "cell_type": "A549", + # Skip PCA/UMAP: they'd fail on a tiny one-cell-per-gene dataset. 
+ "aggregation": { + "guide_level": {"compute_embeddings": False}, + "gene_level": {"compute_embeddings": False}, + }, + } + + +# --------------------------------------------------------------------------- +# Steps +# --------------------------------------------------------------------------- + + +def subsample_link_csv(link_dir: Path) -> int: + """Copy the real link CSV into link_dir, keeping one cell per gene KO. + + Returns the number of genes (== number of cells kept). + """ + assert SOURCE_LINK_CSV.exists(), f"source link CSV not found: {SOURCE_LINK_CSV}" + assert PHENOTYPING_ZARR.exists(), f"phenotyping zarr not found: {PHENOTYPING_ZARR}" + + df = pd.read_csv(SOURCE_LINK_CSV) + gene_col = "gene_name" if "gene_name" in df.columns else "Gene name" + assert gene_col in df.columns, f"no gene column in {SOURCE_LINK_CSV}" + + # Mirror the loader's basic QC so kept cells survive into a batch. + df = df.dropna(subset=["segmentation_id", gene_col]) + + one_per_gene = df.groupby(gene_col, sort=True).head(1) + genes = one_per_gene[gene_col].unique()[:MAX_GENES] + subsampled = one_per_gene[one_per_gene[gene_col].isin(genes)].reset_index(drop=True) + + n_genes = subsampled[gene_col].nunique() + assert n_genes >= 2, "need at least a couple of genes for the test" + assert len(subsampled) == n_genes, "expected exactly one cell per gene" + + out_csv = link_dir / f"{WELL_PREFIX}_linked_pheno_iss.csv" + subsampled.to_csv(out_csv, index=False) + assert out_csv.exists() + print(f"[1] Subsampled link CSV: {n_genes} cells (1/gene) -> {out_csv}") + return n_genes + + +def verify_feature_csv(feature_csv: Path, n_genes: int) -> int: + """Verify the extracted feature CSV. 
Returns the cell count.""" + assert feature_csv.exists(), f"feature CSV not produced: {feature_csv}" + feats = pd.read_csv(feature_csv) + n_cells = len(feats) + assert 0 < n_cells <= n_genes, f"unexpected cell count {n_cells} (genes={n_genes})" + + feature_cols = [c for c in feats.columns if str(c).isdigit()] + assert ( + len(feature_cols) == EXPECTED_DIM + ), f"expected {EXPECTED_DIM} feature dims, got {len(feature_cols)}" + for col in ("label_int", "label_str", "sgRNA", "experiment", "well"): + assert col in feats.columns, f"missing metadata column {col}" + + print(f"[3] Feature CSV OK: {n_cells} cells x {len(feature_cols)} dims -> {feature_csv}") + return n_cells + + +def verify_anndata(feature_csv: Path, n_cells: int) -> None: + """Verify the AnnData outputs from process_features_csv.""" + import anndata as ad + + anndata_dir = feature_csv.parent / "anndata_objects" + produced = list(anndata_dir.glob("features_processed_*.h5ad")) + assert produced, f"no features_processed_*.h5ad written in {anndata_dir}" + + reloaded = ad.read_h5ad(produced[0]) + assert reloaded.n_obs == n_cells, "AnnData cell count != feature CSV rows" + assert reloaded.n_vars == EXPECTED_DIM, "AnnData feature width mismatch" + print(f"[4] AnnData OK: {reloaded.n_obs} x {reloaded.n_vars} -> {produced[0]}") + + +def main() -> None: + tmp = Path(tempfile.mkdtemp(prefix="subcell_e2e_")) + link_dir = tmp / "link_csvs" + link_dir.mkdir(parents=True, exist_ok=True) + output_dir = tmp / "features" + + print(f"Working dir: {tmp}\n") + + # 1. subset + 2. config -> tmp + n_genes = subsample_link_csv(link_dir) + config = build_config(link_dir, output_dir) + config_path = tmp / "config.yml" + config_path.write_text(yaml.safe_dump(config)) + print(f"[2] Wrote config -> {config_path}") + + # 3. 
run extraction (output is named after the protein channel) + extract_subcell_features(config=config) + feature_csv = output_dir / f"subcell_features_{PROTEIN_CHANNEL}.csv" + n_cells = verify_feature_csv(feature_csv, n_genes) + + # 4. CSV -> AnnData + verify + process_features_csv(str(feature_csv), config_path=str(config_path)) + verify_anndata(feature_csv, n_cells) + + print(f"\n✓ SubCell e2e PASSED. Outputs under: {tmp}") + + +if __name__ == "__main__": + main() diff --git a/tests/e2e_tests/testing_plan b/tests/e2e_tests/testing_plan new file mode 100644 index 0000000..3f2e074 --- /dev/null +++ b/tests/e2e_tests/testing_plan @@ -0,0 +1,13 @@ +There are 5 core features of ops_model that should each have a self-contained end-to-end test + - 4 model inference feature extractors + - embedding post-processing and combination + +Structure of the tests: + - the tests should be structured as python scripts, not pytests + - they should be based on the following outline: + 1. define a config at the top of the script + 2. load csvs from a select subset of experiments and subset them to a minimal example, then save that subsample to a tmp dir. + have the config point to this as the source + 3. write that config to a tmp dir + 4. run feature extraction normally + 5. save outputs to the tmp dir and verify that they ran successfully \ No newline at end of file From 3669801969b8d213d13f4d789d0f14cec19c1fe4 Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Wed, 10 Jun 2026 16:01:13 -0700 Subject: [PATCH 06/11] Fix stale / personal hardcoded config paths in scoring + overlays Two hardcoded paths that silently disabled features outside one environment: - annotated_gene_panel_July2025.csv: repoint from the now-missing /hpc/projects/intracellular_dashboard/ops/configs path to the present icd.fast.ops/configs location (data/embeddings/utils.py, funk_clusters.py, combination/analysis/embedding_overlays.py). 
The dead path was caught-and- skipped, silently dropping CORUM consistency scoring. - gene_supercategory_mapping.yaml: default to the in-repo copy (resolved from the repo root) instead of a personal home-dir path that was permission-denied for other users (combination/analysis/embedding_overlays.py, compare_modalities.py, models/attention/atlas/attention_atlas.py). Co-Authored-By: Claude Opus 4.8 (1M context) --- src/ops_model/data/embeddings/funk_clusters.py | 2 +- src/ops_model/data/embeddings/utils.py | 2 +- .../models/attention/atlas/attention_atlas.py | 9 +++++++-- .../combination/analysis/compare_modalities.py | 10 +++++++--- .../combination/analysis/embedding_overlays.py | 11 ++++++++--- 5 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/ops_model/data/embeddings/funk_clusters.py b/src/ops_model/data/embeddings/funk_clusters.py index e2fdb6c..6b84968 100644 --- a/src/ops_model/data/embeddings/funk_clusters.py +++ b/src/ops_model/data/embeddings/funk_clusters.py @@ -8,7 +8,7 @@ def get_funk_clusters(): path = Path( - "/hpc/projects/intracellular_dashboard/ops/configs/annotated_gene_panel_July2025.csv" + "/hpc/projects/icd.fast.ops/configs/annotated_gene_panel_July2025.csv" ) df = pd.read_csv(path) diff --git a/src/ops_model/data/embeddings/utils.py b/src/ops_model/data/embeddings/utils.py index 28dda45..5395752 100644 --- a/src/ops_model/data/embeddings/utils.py +++ b/src/ops_model/data/embeddings/utils.py @@ -23,7 +23,7 @@ def group_guides(): def get_gene_complexes(): - path = "/hpc/projects/intracellular_dashboard/ops/configs/annotated_gene_panel_July2025.csv" + path = "/hpc/projects/icd.fast.ops/configs/annotated_gene_panel_July2025.csv" df = pd.read_csv(path) complex_df = df[["Gene.name", "In_same_complex_with"]] gene_list = list(complex_df["Gene.name"]) diff --git a/src/ops_model/models/attention/atlas/attention_atlas.py b/src/ops_model/models/attention/atlas/attention_atlas.py index a6c2dd3..a681d44 100644 --- 
a/src/ops_model/models/attention/atlas/attention_atlas.py +++ b/src/ops_model/models/attention/atlas/attention_atlas.py @@ -187,8 +187,13 @@ def ta(text, color=None): from ops_utils.data.filesystem import resolve_experiment_name -DEFAULT_SUPERCATEGORY_CONFIG = Path( - "/hpc/mydata/gav.sturm/ops_mono/organelle_profiler/configs/gene_supercategory_mapping.yaml" +# In-repo default (was a hardcoded personal home-dir path). Resolved relative to +# the repo root so it works for any checkout; callers may override. +DEFAULT_SUPERCATEGORY_CONFIG = ( + Path(__file__).resolve().parents[6] + / "organelle_profiler" + / "configs" + / "gene_supercategory_mapping.yaml" ) diff --git a/src/ops_model/post_process/combination/analysis/compare_modalities.py b/src/ops_model/post_process/combination/analysis/compare_modalities.py index eb45080..571af7a 100644 --- a/src/ops_model/post_process/combination/analysis/compare_modalities.py +++ b/src/ops_model/post_process/combination/analysis/compare_modalities.py @@ -71,9 +71,13 @@ NULL_SIZE = 100_000 DEFAULT_GROUPS = ("cp", "4i", "matched_livecell_best") -DEFAULT_SUPERCATEGORY_CONFIG = Path( - "/home/gav.sturm/linked_folders/mydata/ops_mono/organelle_profiler/configs/" - "gene_supercategory_mapping.yaml" +# In-repo default (was a hardcoded personal home-dir path). Resolved relative to +# the repo root so it works for any checkout; callers may override. 
+DEFAULT_SUPERCATEGORY_CONFIG = ( + Path(__file__).resolve().parents[6] + / "organelle_profiler" + / "configs" + / "gene_supercategory_mapping.yaml" ) GROUP_PALETTE = { "cp": "#d97706", # amber diff --git a/src/ops_model/post_process/combination/analysis/embedding_overlays.py b/src/ops_model/post_process/combination/analysis/embedding_overlays.py index d41ab9d..f2c3957 100644 --- a/src/ops_model/post_process/combination/analysis/embedding_overlays.py +++ b/src/ops_model/post_process/combination/analysis/embedding_overlays.py @@ -45,8 +45,13 @@ "Pfam_Domains_2019", "COMPARTMENTS_Curated_2025", ) -DEFAULT_SUPERCATEGORY_CONFIG = Path( - "/home/gav.sturm/linked_folders/mydata/ops_mono/organelle_profiler/configs/gene_supercategory_mapping.yaml" +# In-repo default (was a hardcoded personal home-dir path). Resolved relative to +# the repo root so it works for any checkout; callers may override. +DEFAULT_SUPERCATEGORY_CONFIG = ( + Path(__file__).resolve().parents[6] + / "organelle_profiler" + / "configs" + / "gene_supercategory_mapping.yaml" ) @@ -171,7 +176,7 @@ def _build_corum_map() -> Dict[str, str]: import ast panel_path = Path( - "/hpc/projects/intracellular_dashboard/ops/configs/annotated_gene_panel_July2025.csv" + "/hpc/projects/icd.fast.ops/configs/annotated_gene_panel_July2025.csv" ) if not panel_path.exists(): return {} From dc79c32c0587abef8076a76ede9d8b92dfd0339b Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Thu, 11 Jun 2026 11:56:06 -0700 Subject: [PATCH 07/11] Score CORUM/CHAD/EBI consistency independently _score_consistency previously ran CORUM, CHAD and EBI inside one shared try/except, so a failure in one metric (e.g. CHAD failing to parse its annotation) silently suppressed the others and dropped EBI entirely. Each metric now runs in its own try/except and returns (None, 0.0) on failure; the panel/volcano plots are best-effort. One metric failing no longer takes the rest down. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../pca_optimization/embeddings.py | 122 +++++++++--------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/src/ops_model/post_process/combination/pca_optimization/embeddings.py b/src/ops_model/post_process/combination/pca_optimization/embeddings.py index b070d1a..24e9007 100644 --- a/src/ops_model/post_process/combination/pca_optimization/embeddings.py +++ b/src/ops_model/post_process/combination/pca_optimization/embeddings.py @@ -364,8 +364,9 @@ def _score_consistency( Returns ``(corum_map, corum_ratio, chad_map, chad_ratio, ebi_map, ebi_ratio)`` — six values now that EBI is a permanent third - consistency metric. Failure mode is six zeros so callers can keep - unpacking with one shape. + consistency metric. Each metric is scored independently: a failure in one + (e.g. an annotation that won't parse) yields ``(None, 0.0)`` for that metric + alone and never suppresses the others. NOTE: ``phenotypic_consistency_*`` is called WITHOUT ``activity_map``, so consistency is computed over all genes regardless of the ``suffix`` @@ -379,67 +380,70 @@ def _score_consistency( label = "all geneKOs" if activity_map is None: return None, 0.0, None, 0.0, None, 0.0 - try: - from ops_utils.analysis.map_scores import ( - phenotypic_consistency_corum, - phenotypic_consistency_ebi, - phenotypic_consistency_manual_annotation, - ) + from ops_utils.analysis.map_scores import ( + phenotypic_consistency_corum, + phenotypic_consistency_ebi, + phenotypic_consistency_manual_annotation, + ) - _logger.info(f"Running CORUM consistency ({label})...") - consistency_corum_map, consistency_corum_ratio = phenotypic_consistency_corum( + def _score(name, csv_stem, fn): + """Run one consistency metric independently; (None, 0.0) on failure.""" + _logger.info(f"Running {name} consistency ({label})...") + try: + cmap, ratio = fn() + cmap.to_csv(metrics_dir / f"{csv_stem}{suffix}.csv", index=False) + _logger.info(f" {name} ({label}): 
{ratio:.1%}") + return cmap, ratio + except Exception as exc: + _logger.error(f" {name} consistency ({label}) failed: {exc}") + return None, 0.0 + + corum_map, corum_ratio = _score( + "CORUM", + "phenotypic_consistency_corum", + lambda: phenotypic_consistency_corum( adata_gene, plot_results=False, null_size=100_000, cache_similarity=True, distance=distance, - ) - consistency_corum_map.to_csv( - metrics_dir / f"phenotypic_consistency_corum{suffix}.csv", index=False - ) - _logger.info(f" CORUM ({label}): {consistency_corum_ratio:.1%}") - - _logger.info(f"Running CHAD consistency ({label})...") - consistency_manual_map, consistency_manual_ratio = ( - phenotypic_consistency_manual_annotation( - adata_gene, - plot_results=False, - null_size=100_000, - cache_similarity=True, - distance=distance, - annotation_path=CHAD_ANNOTATION_PATH, - ) - ) - consistency_manual_map.to_csv( - metrics_dir / f"phenotypic_consistency_manual{suffix}.csv", index=False - ) - _logger.info(f" Manual CHAD ({label}): {consistency_manual_ratio:.1%}") - - _logger.info(f"Running EBI consistency ({label})...") - consistency_ebi_map, consistency_ebi_ratio = phenotypic_consistency_ebi( + ), + ) + chad_map, chad_ratio = _score( + "CHAD", + "phenotypic_consistency_manual", + lambda: phenotypic_consistency_manual_annotation( + adata_gene, + plot_results=False, + null_size=100_000, + cache_similarity=True, + distance=distance, + annotation_path=CHAD_ANNOTATION_PATH, + ), + ) + ebi_map, ebi_ratio = _score( + "EBI", + "phenotypic_consistency_ebi", + lambda: phenotypic_consistency_ebi( adata_gene, plot_results=False, null_size=100_000, cache_similarity=True, distance=distance, annotation_path=EBI_ANNOTATION_PATH, - ) - consistency_ebi_map.to_csv( - metrics_dir / f"phenotypic_consistency_ebi{suffix}.csv", index=False - ) - _logger.info(f" EBI ({label}): {consistency_ebi_ratio:.1%}") + ), + ) - # 1×3 panel: CORUM + CHAD + EBI scatter (existing style) + # Plots are best-effort and must not drop already-computed 
metrics. + # plot_map_scatter renders a "No data" placeholder for any None/empty map. + try: fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(22, 7)) - plot_map_scatter(ax1, consistency_corum_map, - f"Consistency CORUM ({label})", - consistency_corum_ratio, show_ntc=False) - plot_map_scatter(ax2, consistency_manual_map, - f"Consistency CHAD ({label})", - consistency_manual_ratio, show_ntc=False) - plot_map_scatter(ax3, consistency_ebi_map, - f"Consistency EBI ({label})", - consistency_ebi_ratio, show_ntc=False) + plot_map_scatter(ax1, corum_map, f"Consistency CORUM ({label})", + corum_ratio, show_ntc=False) + plot_map_scatter(ax2, chad_map, f"Consistency CHAD ({label})", + chad_ratio, show_ntc=False) + plot_map_scatter(ax3, ebi_map, f"Consistency EBI ({label})", + ebi_ratio, show_ntc=False) fig.suptitle( f"Consistency Metrics ({label}) — {total_feats} features", fontsize=13, fontweight="bold", @@ -450,15 +454,18 @@ def _score_consistency( ) plt.close(fig) _logger.info(f" Saved plots/map_consistency{suffix}.png") + except Exception as exc: + _logger.warning(f" Consistency panel plot failed: {exc}") - # Standalone EBI panel using the canonical map-scatter helper — - # same style as the activity / distinctiveness mAP scatters. + # Standalone EBI panel using the canonical map-scatter helper — + # same style as the activity / distinctiveness mAP scatters. 
+ if ebi_map is not None: try: fig, ax = plt.subplots(figsize=(8, 7)) plot_map_scatter( - ax, consistency_ebi_map, + ax, ebi_map, f"Consistency EBI ({label})", - consistency_ebi_ratio, show_ntc=False, + ebi_ratio, show_ntc=False, ) fig.tight_layout() fig.savefig(plots_dir / f"map_ebi_volcano{suffix}.png", @@ -468,11 +475,4 @@ def _score_consistency( except Exception as exc: _logger.warning(f" EBI volcano plot failed: {exc}") - return ( - consistency_corum_map, consistency_corum_ratio, - consistency_manual_map, consistency_manual_ratio, - consistency_ebi_map, consistency_ebi_ratio, - ) - except Exception as exc: - _logger.error(f" Consistency metrics ({label}) failed: {exc}") - return None, 0.0, None, 0.0, None, 0.0 + return corum_map, corum_ratio, chad_map, chad_ratio, ebi_map, ebi_ratio From d285daeeeb628418f903b2b7fea626c5cad06121 Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Thu, 11 Jun 2026 11:56:20 -0700 Subject: [PATCH 08/11] Fix stale icd.ops gene_clusters paths -> icd.fast.ops /hpc/projects/icd.ops/configs/gene_clusters/ no longer exists; the CHAD cluster YAMLs live under icd.fast.ops. Repoint the dead references (CHAD overlay hierarchy/cluster-map in pca_optimization phase2/handlers, deprecated gene/guide eval, titration decay tools, compare_map_scores) so they load instead of being skipped with warnings. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- src/ops_model/deprecated/eval/evaluate_gene.py | 2 +- src/ops_model/deprecated/eval/evaluate_guide.py | 2 +- .../models/attention/titration/decay/map_attention_decay.py | 2 +- .../titration/decay/plot_all_cells_correction_bars.py | 2 +- .../post_process/combination/analysis/compare_map_scores.py | 2 +- .../post_process/combination/pca_optimization/handlers.py | 2 +- .../post_process/combination/pca_optimization/phase2.py | 4 ++-- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/ops_model/deprecated/eval/evaluate_gene.py b/src/ops_model/deprecated/eval/evaluate_gene.py index 3d73ec0..a909b00 100644 --- a/src/ops_model/deprecated/eval/evaluate_gene.py +++ b/src/ops_model/deprecated/eval/evaluate_gene.py @@ -17,7 +17,7 @@ from ops_model.eval.metrics import mean_cosine_sim_within_groups MANUAL_ANNOTATION_YAML_PATH = ( - "/hpc/projects/icd.ops/configs/gene_clusters/chad_positive_controls_v4.yml" + "/hpc/projects/icd.fast.ops/configs/gene_clusters/chad_positive_controls_v4.yml" ) diff --git a/src/ops_model/deprecated/eval/evaluate_guide.py b/src/ops_model/deprecated/eval/evaluate_guide.py index 6966ebb..40603ea 100644 --- a/src/ops_model/deprecated/eval/evaluate_guide.py +++ b/src/ops_model/deprecated/eval/evaluate_guide.py @@ -16,7 +16,7 @@ from ops_model.eval.metrics import mean_cosine_sim_within_groups POS_CONTROLS_YAML_PATH = ( - "/hpc/projects/icd.ops/configs/gene_clusters/chad_positive_controls_v4.yml" + "/hpc/projects/icd.fast.ops/configs/gene_clusters/chad_positive_controls_v4.yml" ) diff --git a/src/ops_model/models/attention/titration/decay/map_attention_decay.py b/src/ops_model/models/attention/titration/decay/map_attention_decay.py index 3e101eb..44913f6 100644 --- a/src/ops_model/models/attention/titration/decay/map_attention_decay.py +++ b/src/ops_model/models/attention/titration/decay/map_attention_decay.py @@ -106,7 +106,7 @@ DEFAULT_ALL_CELLS_GUIDE_H5AD = Path(_PHASE_PCA_DIR) / 
"guide_pca_optimized.h5ad" # CHAD manual-annotation gene-cluster YAML — used by # phenotypic_consistency_manual_annotation as a complement to EBI. -DEFAULT_CHAD_YAML = Path("/hpc/projects/icd.ops/configs/gene_clusters/" +DEFAULT_CHAD_YAML = Path("/hpc/projects/icd.fast.ops/configs/gene_clusters/" "chad_positive_controls_v4.yml") diff --git a/src/ops_model/models/attention/titration/decay/plot_all_cells_correction_bars.py b/src/ops_model/models/attention/titration/decay/plot_all_cells_correction_bars.py index d2b3448..363bd4c 100644 --- a/src/ops_model/models/attention/titration/decay/plot_all_cells_correction_bars.py +++ b/src/ops_model/models/attention/titration/decay/plot_all_cells_correction_bars.py @@ -54,7 +54,7 @@ "/hpc/projects/icd.fast.ops/models/alex_lin_attention/v3/attention_v3/cdino" ) CHAD_YAML = Path( - "/hpc/projects/icd.ops/configs/gene_clusters/chad_positive_controls_v4.yml" + "/hpc/projects/icd.fast.ops/configs/gene_clusters/chad_positive_controls_v4.yml" ) NULL_SIZE = 10_000 diff --git a/src/ops_model/post_process/combination/analysis/compare_map_scores.py b/src/ops_model/post_process/combination/analysis/compare_map_scores.py index 1abf847..3020f7a 100644 --- a/src/ops_model/post_process/combination/analysis/compare_map_scores.py +++ b/src/ops_model/post_process/combination/analysis/compare_map_scores.py @@ -50,7 +50,7 @@ TOP_N_LABELS = 10 SLOPE_MAX_BG = 120 -_CHAD_YAML = "/hpc/projects/icd.ops/configs/gene_clusters/chad_positive_controls_v4.yml" +_CHAD_YAML = "/hpc/projects/icd.fast.ops/configs/gene_clusters/chad_positive_controls_v4.yml" _EBI_YAML = ( "/hpc/projects/icd.fast.ops/configs/gene_clusters/" "EBI_complexes_v1_old_gene_names.yaml" diff --git a/src/ops_model/post_process/combination/pca_optimization/handlers.py b/src/ops_model/post_process/combination/pca_optimization/handlers.py index 18cad89..1860da6 100644 --- a/src/ops_model/post_process/combination/pca_optimization/handlers.py +++ 
b/src/ops_model/post_process/combination/pca_optimization/handlers.py @@ -151,7 +151,7 @@ def _handle_chad_umap_only(args, output_dir): print("ERROR: No X_umap in gene embedding.") return - chad_path = args.chad_annotation or "/hpc/projects/icd.ops/configs/gene_clusters/chad_positive_controls_v5_hierarchy.yml" + chad_path = args.chad_annotation or "/hpc/projects/icd.fast.ops/configs/gene_clusters/chad_positive_controls_v5_hierarchy.yml" with open(chad_path) as f: chad_clusters = _yaml.safe_load(f) diff --git a/src/ops_model/post_process/combination/pca_optimization/phase2.py b/src/ops_model/post_process/combination/pca_optimization/phase2.py index 257b570..1d9cfd7 100644 --- a/src/ops_model/post_process/combination/pca_optimization/phase2.py +++ b/src/ops_model/post_process/combination/pca_optimization/phase2.py @@ -315,7 +315,7 @@ def aggregate_channels( except Exception as exc: _logger.warning(f" 1st-pass metric violin plot failed: {exc}") - _chad_path = CHAD_ANNOTATION_PATH or "/hpc/projects/icd.ops/configs/gene_clusters/chad_positive_controls_v5_hierarchy.yml" + _chad_path = CHAD_ANNOTATION_PATH or "/hpc/projects/icd.fast.ops/configs/gene_clusters/chad_positive_controls_v5_hierarchy.yml" if adata_gene_embed is not None and "X_umap" in adata_gene_embed.obsm: try: import yaml as _yaml @@ -1022,7 +1022,7 @@ def apply_second_pass_pca( _chad_path = ( CHAD_ANNOTATION_PATH - or "/hpc/projects/icd.ops/configs/gene_clusters/chad_positive_controls_v5_hierarchy.yml" + or "/hpc/projects/icd.fast.ops/configs/gene_clusters/chad_positive_controls_v5_hierarchy.yml" ) if adata_gene_embed is not None and "X_umap" in adata_gene_embed.obsm: try: From a89da8b8664d2a771eb647736f9b18324d3b8a59 Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Wed, 24 Jun 2026 09:06:16 -0700 Subject: [PATCH 09/11] Remove deprecated base_dataset, embeddings, and move_links modules Delete the old base_dataset.py, the data/embeddings/* helpers (cosine_similarity, embedding_metrics, funk_clusters, 
pca, umap_plots, utils), and move_links.py, along with their now-obsolete tests (test_basedataset.py, test_feature_metrics.py). Co-Authored-By: Claude Opus 4.8 (1M context) --- src/ops_model/data/base_dataset.py | 332 --------------- .../data/embeddings/cosine_similarity.py | 384 ------------------ .../data/embeddings/embedding_metrics.py | 123 ------ .../data/embeddings/funk_clusters.py | 334 --------------- src/ops_model/data/embeddings/pca.py | 56 --- src/ops_model/data/embeddings/umap_plots.py | 355 ---------------- src/ops_model/data/embeddings/utils.py | 56 --- src/ops_model/data/move_links.py | 142 ------- tests/features/test_feature_metrics.py | 46 --- tests/test_basedataset.py | 135 ------ 10 files changed, 1963 deletions(-) delete mode 100644 src/ops_model/data/base_dataset.py delete mode 100644 src/ops_model/data/embeddings/cosine_similarity.py delete mode 100644 src/ops_model/data/embeddings/embedding_metrics.py delete mode 100644 src/ops_model/data/embeddings/funk_clusters.py delete mode 100644 src/ops_model/data/embeddings/pca.py delete mode 100644 src/ops_model/data/embeddings/umap_plots.py delete mode 100644 src/ops_model/data/embeddings/utils.py delete mode 100644 src/ops_model/data/move_links.py delete mode 100644 tests/features/test_feature_metrics.py delete mode 100644 tests/test_basedataset.py diff --git a/src/ops_model/data/base_dataset.py b/src/ops_model/data/base_dataset.py deleted file mode 100644 index effeb67..0000000 --- a/src/ops_model/data/base_dataset.py +++ /dev/null @@ -1,332 +0,0 @@ -import ast -import random -from typing import Callable, List, Literal, Optional - -import zarr -import numpy as np -import pandas as pd -import torch -from iohub import open_ome_zarr - -from torch.utils.data import Dataset -from monai.transforms import ( - CenterSpatialCropd, - Compose, - SpatialPadd, - ToTensord, -) - - -class BaseDataset(Dataset): - """ - Base PyTorch Dataset for loading and preprocessing microscopy image patches with associated 
labels. - - This dataset handles loading image patches from OME-Zarr stores, applying spatial transformations, - and preparing data for training deep learning models on microscopy data. It supports single or - multiple channel selection, cell masking, and flexible patch sizing. - - Attributes: - stores (dict): Dictionary mapping store keys to OME-Zarr store objects. - labels_df (pd.DataFrame): DataFrame containing crop information and labels for each sample. - initial_yx_patch_size (tuple): Initial spatial size (height, width) for cropping patches. - final_yx_patch_size (tuple): Final spatial size after padding/cropping transformations. - out_channels (List[str] | Literal["random"] | Literal["all"]): Channel selection strategy. - label_int_lut (dict): Lookup table mapping gene names to integer labels. - mask_cell (bool): Whether to apply cell segmentation mask to image data. - use_original_crop_size (bool): Whether to use original crop size without padding/cropping. - transform (Compose): MONAI composition of transformations to apply to data. - """ - - def __init__( - self, - stores: dict, - labels_df: pd.DataFrame, - initial_yx_patch_size: tuple = (128, 128), - final_yx_patch_size: tuple = (128, 128), - out_channels: List[str] | Literal["random"] | Literal["all"] = "random", - label_int_lut: Optional[dict] = None, # string --> int - mask_cell: bool = True, - use_original_crop_size: bool = False, - ): - """ - Initialize the BaseDataset. - - Args: - stores (dict): Dictionary mapping store keys to opened OME-Zarr store objects. - labels_df (pd.DataFrame): DataFrame with columns including 'gene_name', 'bbox', - 'store_key', 'well', 'segmentation_id', and 'total_index'. - initial_yx_patch_size (tuple, optional): Initial (height, width) size for extracting - patches before transformations. Defaults to (128, 128). - final_yx_patch_size (tuple, optional): Final (height, width) size after padding and - center cropping. Defaults to (128, 128). 
- out_channels (List[str] | Literal["random"] | Literal["all"], optional): Strategy for - channel selection. "random" selects one random channel per sample, "all" uses all - available channels, or provide a list of specific channel names. Defaults to "random". - label_int_lut (Optional[dict], optional): Lookup table mapping gene names (str) to - integer labels (int). If None, automatically generated from unique gene names in - labels_df. Defaults to None. - mask_cell (bool, optional): If True, multiply image data by segmentation mask to isolate - individual cells. Defaults to True. - use_original_crop_size (bool, optional): If True, skip padding/cropping transformations - and use original bounding box size. Defaults to False. - """ - self.stores = stores - self.labels_df = labels_df - self.initial_yx_patch_size = initial_yx_patch_size - self.final_yx_patch_size = final_yx_patch_size - self.out_channels = out_channels - self.mask_cell = mask_cell - self.use_original_crop_size = use_original_crop_size - if label_int_lut is None: - gene_labels = sorted(self.labels_df["gene_name"].unique()) - self.label_int_lut = {gene: i for i, gene in enumerate(gene_labels)} - else: - self.label_int_lut = label_int_lut - - if self.use_original_crop_size: - self.transform = Compose( - [ - ToTensord(keys=["data", "mask"]), - ] - ) - else: - self.transform = Compose( - [ - SpatialPadd( - keys=["data", "mask"], - spatial_size=self.initial_yx_patch_size, - ), - CenterSpatialCropd( - keys=["data", "mask"], roi_size=(self.final_yx_patch_size) - ), - ToTensord( - keys=["data", "mask"], - ), - ] - ) - - return - - def _get_bbox(self, ci, final_shape): - """ - Extract and optionally expand bounding box to match target shape. - - Parses the bounding box from crop info and pads it equally on all sides if it's - smaller than the target shape. Padding is distributed symmetrically to keep the - original crop centered. 
- - Args: - ci: Row from labels_df containing crop information with a 'bbox' field. - final_shape (tuple): Target (height, width) dimensions for the bounding box. - - Returns: - tuple: Bounding box as (ymin, xmin, ymax, xmax). If use_original_crop_size is True, - returns the original bbox. Otherwise, returns padded bbox matching final_shape. - - Note: - The bbox format is (ymin, xmin, ymax, xmax) representing top-left and bottom-right - corners in (y, x) coordinates. - """ - bbox = ast.literal_eval(ci.bbox) - - if not self.use_original_crop_size: - - if len(final_shape) > 2: - final_shape = final_shape[-2:] - - ymin, xmin, ymax, xmax = bbox - target_height, target_width = final_shape - - # Calculate current bbox dimensions - current_height = ymax - ymin - current_width = xmax - xmin - - # Calculate padding needed - height_padding = max(0, target_height - current_height) - width_padding = max(0, target_width - current_width) - - # Distribute padding equally on both sides - pad_top = height_padding / 2 - pad_bottom = height_padding / 2 - pad_left = width_padding / 2 - pad_right = width_padding / 2 - - # Apply padding - new_ymin = int(ymin - pad_top) - new_ymax = int(ymax + pad_bottom) - new_xmin = int(xmin - pad_left) - new_xmax = int(xmax + pad_right) - bbox = (new_ymin, new_xmin, new_ymax, new_xmax) - - return bbox - - def _get_channels(self, ci): - """ - Determine which channels to load based on the configured strategy. - - Retrieves available channel names from the OME-Zarr metadata and selects channels - according to the out_channels strategy (random, all, or specific channels). - - Args: - ci: Row from labels_df containing 'store_key' and 'well' fields to identify - the data source. - - Returns: - tuple: A tuple containing: - - channel_names (list): List of channel name strings to load. - - channel_index (list): List of integer indices corresponding to the channels - in the OME-Zarr store. 
- - Note: - If out_channels is "random", selects one random channel per call. - If out_channels is "all", returns all available channels. - Otherwise, uses the specific channel names provided in out_channels. - """ - - attrs = self.stores[ci.store_key][ci.well].attrs.asdict() - all_channel_names = [a["label"] for a in attrs["ome"]["omero"]["channels"]] - - if self.out_channels == "random": - channel_names = [random.choice(all_channel_names)] - if self.out_channels == "all": - channel_names = all_channel_names - else: - channel_names = self.out_channels - channel_index = [all_channel_names.index(c) for c in channel_names] - - return channel_names, channel_index - - def __len__(self): - return len(self.labels_df) - - def add_labels_to_batch(self, ci): - """ - Extract label information from crop info for the current sample. - - Converts gene name to integer label using the lookup table and retrieves - additional metadata for tracking. - - Args: - ci: Row from labels_df containing 'gene_name' and 'total_index' fields. - - Returns: - tuple: A tuple containing: - - gene_label (int): Integer label for the gene name. - - total_index: Unique identifier for this sample. - - crop_info (dict): Dictionary representation of all crop metadata. - """ - gene_label = self.label_int_lut[ci.gene_name] - total_index = ci.total_index - - return gene_label, total_index, ci.to_dict() - - def add_mask_to_batch(self, ci, bbox): - """ - Load and extract binary segmentation mask for a specific cell. - - Retrieves the segmentation mask from the OME-Zarr store, crops it to the bounding box, - and creates a binary mask for the specific cell identified by segmentation_id. - - Args: - ci: Row from labels_df containing 'store_key', 'well', and 'segmentation_id' fields. - bbox (tuple): Bounding box as (ymin, xmin, ymax, xmax) defining the region to extract. 
- - Returns: - np.ndarray: Binary mask of shape (1, height, width) where True indicates pixels - belonging to the target cell (segmentation_id) and False elsewhere. - """ - mask_fov = self.stores[ci.store_key][ci.well]["labels"]["cell_seg"]["0"] - mask = np.asarray( - mask_fov[0:1, :, 0:1, slice(bbox[0], bbox[2]), slice(bbox[1], bbox[3])] - ).copy() - mask = np.squeeze(mask) - mask = np.expand_dims(mask, axis=0) - sc_mask = mask == ci.segmentation_id - - return sc_mask - - def add_data_to_batch(self, ci, bbox, channel_index): - """ - Load and extract image data for specified channels and bounding box. - - Retrieves raw microscopy image data from the OME-Zarr store, crops it to the - bounding box region, and extracts the specified channels. - - Args: - ci: Row from labels_df containing 'store_key' and 'well' fields to identify - the data source. - bbox (tuple): Bounding box as (ymin, xmin, ymax, xmax) defining the region to extract. - channel_index (list): List of integer indices specifying which channels to load. - - Returns: - np.ndarray: Image data as float32 array with shape (C, height, width) where C is - the number of channels. Single channel images are expanded to (1, height, width). - """ - fov = self.stores[ci.store_key][ci.well]["0"] - data = np.asarray( - fov[ - 0:1, - channel_index, - 0:1, - slice(bbox[0], bbox[2]), - slice(bbox[1], bbox[3]), - ] - ).copy() - data = np.squeeze(data) - if len(data.shape) == 2: - data = np.expand_dims(data, axis=0) - - return data.astype(np.float32) - - def __getitem__(self, index): - """ - Load and preprocess a single sample from the dataset. - - This is the main data loading method called by PyTorch DataLoader. It orchestrates - loading the image patch, segmentation mask, and labels, applies transformations, - and returns a dictionary containing all sample data. - - Args: - index (int): Index of the sample to retrieve from labels_df. 
- - Returns: - dict: Dictionary containing: - - 'data' (torch.Tensor): Image data of shape (C, H, W) or (1, C, H, W) if - final_yx_patch_size has 3 dimensions. Optionally masked by cell segmentation. - - 'mask' (torch.Tensor): Binary segmentation mask of shape (1, H, W) or - (1, 1, H, W), with same dimensionality as data. - - 'marker_label' (list): List of channel names loaded for this sample. - - 'gene_label' (int): Integer label for the gene associated with this cell. - - 'total_index' (int): Unique identifier for this sample. - - 'crop_info' (dict): Complete metadata for this crop from labels_df. - - Note: - If mask_cell is True, the returned data will be element-wise multiplied by the mask. - Transformations (padding, cropping, tensor conversion) are applied based on the - transform pipeline configured during initialization. - """ - batch = {} - ci = self.labels_df.iloc[index] # crop info - bbox = self._get_bbox(ci, self.initial_yx_patch_size) - c_names, c_index = self._get_channels(ci) - batch["marker_label"] = c_names - - gene_label, total_index, crop_info = self.add_labels_to_batch(ci) - batch["gene_label"] = gene_label - batch["total_index"] = int(total_index) - batch["crop_info"] = crop_info - - batch["data"] = self.add_data_to_batch(ci, bbox, c_index) - batch["mask"] = self.add_mask_to_batch(ci, bbox) - - if self.mask_cell: - batch["data"] = batch["data"] * batch["mask"] - - if len(self.final_yx_patch_size) == 3: - batch["data"] = np.expand_dims(batch["data"], axis=0) - batch["mask"] = np.expand_dims(batch["mask"], axis=0) - - if self.transform is not None: - batch = self.transform(batch) - - return batch diff --git a/src/ops_model/data/embeddings/cosine_similarity.py b/src/ops_model/data/embeddings/cosine_similarity.py deleted file mode 100644 index cb74884..0000000 --- a/src/ops_model/data/embeddings/cosine_similarity.py +++ /dev/null @@ -1,384 +0,0 @@ -# %% -from tqdm import tqdm -from pathlib import Path - -import torch -import torch.nn.functional as 
F -import matplotlib.pyplot as plt -import pandas as pd - -from ops_model.data.embeddings.utils import load_adata - - -def embedding_spread(adata, label, plot=True): - """ - Calculates the mean cosine similarity from each embedding to the centroid of the class - for all embeddings with the specified label in adata.obs['label_str']. - - Returns: - - the average cosine similarity to the centroid - - a histogram of the cosine similarities - - """ - - # Filter observations with the specified label - mask = adata.obs["label_str"] == label - embeddings = torch.tensor(adata.X[mask]).cuda() - - if embeddings.shape[0] < 2: - print(f"Not enough samples for label '{label}': {embeddings.shape[0]}") - return None, None - - # Normalize embeddings for cosine similarity computation - x = F.normalize(embeddings, dim=1) - - # Compute the centroid (mean of normalized embeddings, then re-normalize) - centroid = x.mean(dim=0, keepdim=True) - centroid = F.normalize(centroid, dim=1) - - # Compute cosine similarity from each embedding to the centroid - cosine_similarities = (x * centroid).sum(dim=1).cpu() - mean_similarity = cosine_similarities.mean().item() - std_similarity = cosine_similarities.std().item() - - if plot: - # Create histogram - fig, ax = plt.subplots(figsize=(8, 5)) - ax.hist(cosine_similarities.numpy(), bins=50, edgecolor="black", alpha=0.7) - ax.set_xlabel("Cosine Similarity to Centroid") - ax.set_ylabel("Frequency") - ax.set_title(f"Cosine Similarity to Centroid for Label: {label}") - ax.axvline( - mean_similarity, - color="red", - linestyle="--", - linewidth=2, - label=f"Mean: {mean_similarity:.4f}", - ) - ax.legend() - plt.tight_layout() - plt.close(fig) - - return mean_similarity, std_similarity, fig - - return mean_similarity, std_similarity, None - - -def embedding_spread_all_labels(adata, min_samples=2): - """ - Calculates the embedding spread (mean cosine similarity to centroid) for all labels - in the adata object. 
- - Args: - adata: AnnData object with embeddings in .X - min_samples: Minimum number of samples required to compute spread (default: 2) - - Returns: - results_dict: Dictionary mapping label names to (mean_similarity, std_similarity) tuples - sorted_results: List of (label, mean_similarity, std_similarity) tuples sorted by mean similarity - fig: Matplotlib figure with histogram of mean similarities across all labels - """ - import matplotlib.pyplot as plt - import numpy as np - - # Get all unique labels - unique_labels = adata.obs["label_str"].unique() - - results_dict = {} - - for label in tqdm(unique_labels): - # Check if label has enough samples - n_samples = (adata.obs["label_str"] == label).sum() - - if n_samples < min_samples: - print(f"Skipping label '{label}': only {n_samples} samples") - continue - - # Compute embedding spread for this label - mean_sim, std_sim, _ = embedding_spread(adata, label, plot=False) - - if mean_sim is not None: - results_dict[label] = (mean_sim, std_sim) - - # Sort by mean similarity (highest similarity = most compact clusters first) - sorted_results = sorted( - [(label, mean, std) for label, (mean, std) in results_dict.items()], - key=lambda x: x[1], - reverse=True, - ) - - # Create histogram of mean similarities - mean_similarities = [mean for _, mean, _ in sorted_results] - overall_mean = np.mean(mean_similarities) - - top_10_tightest = sorted_results[:10] - top_10_diffuse = sorted_results[-10:] - - fig, ax = plt.subplots(figsize=(8, 5)) - ax.hist(mean_similarities, bins=30, edgecolor="black", alpha=0.7) - ax.set_xlabel("Mean Cosine Similarity to Centroid") - ax.set_ylabel("Number of Labels") - ax.set_title("Distribution of Embedding Spread Across Labels") - ax.axvline( - overall_mean, - color="red", - linestyle="--", - linewidth=2, - label=f"Mean: {overall_mean:.4f}", - ) - ax.legend() - plt.tight_layout() - - return top_10_tightest, top_10_diffuse, sorted_results, fig - - -def cosine_similarity_to_reference(adata, 
reference_label): - """ - Calculates the cosine similarity from each label's embedding to a reference label's embedding. - Assumes each label has exactly one embedding in the adata object. - - Args: - adata: AnnData object with embeddings in .X (one embedding per label) - reference_label: The reference label to compare against - - Returns: - similarities_dict: Dictionary mapping label names to cosine similarities from reference - sorted_labels: List of (label, similarity) tuples sorted by similarity (most similar first) - fig: Matplotlib figure with histogram of similarities - """ - - # Get reference embedding - ref_mask = adata.obs["label_str"] == reference_label - if ref_mask.sum() == 0: - raise ValueError(f"Reference label '{reference_label}' not found in adata") - if ref_mask.sum() > 1: - raise ValueError( - f"Reference label '{reference_label}' has multiple embeddings, expected 1" - ) - - ref_embedding = torch.tensor(adata.X[ref_mask]).cuda() - ref_x = F.normalize(ref_embedding, dim=1) - - # Get all embeddings and labels - all_embeddings = torch.tensor(adata.X).cuda() - all_x = F.normalize(all_embeddings, dim=1) - - # Compute cosine similarity between reference and all embeddings - cosine_similarities = (ref_x @ all_x.T).squeeze().cpu() - - # Create dictionary mapping labels to similarities - similarities_dict = {} - for idx, label in enumerate(adata.obs["label_str"]): - similarities_dict[label] = cosine_similarities[idx].item() - - # Sort labels by similarity (highest first) - sorted_labels = sorted(similarities_dict.items(), key=lambda x: x[1], reverse=True) - - # Calculate mean similarity - mean_similarity = cosine_similarities.mean().item() - - top_10_closest = sorted_labels[1:11] - top_10_furthest = sorted_labels[-10:] - - # Create histogram - fig, ax = plt.subplots(figsize=(8, 5)) - ax.hist(cosine_similarities.numpy(), bins=50, edgecolor="black", alpha=0.7) - ax.set_xlabel("Cosine Similarity") - ax.set_ylabel("Frequency") - ax.set_title(f"Cosine 
Similarity Distribution to Reference: {reference_label}") - ax.axvline( - mean_similarity, - color="red", - linestyle="--", - linewidth=2, - label=f"Mean: {mean_similarity:.4f}", - ) - ax.legend() - plt.tight_layout() - - return top_10_closest, top_10_furthest, sorted_labels, fig - - -def mean_similarity( - adata, - n_samples=10_000_000, - batch_size=10000, -): - """ - Compute mean and std of pairwise cosine similarities. - - Args: - adata: AnnData object with embeddings in .X - use_sampling: If True, sample random pairs instead of computing all pairs - n_samples: Number of pairs to sample (only used if use_sampling=True) - batch_size: Batch size for processing (only used if use_sampling=False) - """ - embeddings = torch.tensor(adata.X).cuda() - x = F.normalize(embeddings, dim=1) - n = x.shape[0] - - # Sampling approach: much faster for large datasets - # Sample random pairs and compute their similarities - n_samples = min(n_samples, n * (n - 1) // 2) # Don't sample more than total pairs - - # Generate random pairs - idx_i = torch.randint(0, n, (n_samples,), device="cuda") - idx_j = torch.randint(0, n, (n_samples,), device="cuda") - - # Ensure i != j - mask = idx_i == idx_j - idx_j[mask] = (idx_j[mask] + 1) % n - - # Compute similarities for sampled pairs in batches - similarities = [] - for start in tqdm(range(0, n_samples, batch_size)): - end = min(start + batch_size, n_samples) - batch_i = x[idx_i[start:end]] - batch_j = x[idx_j[start:end]] - sim = (batch_i * batch_j).sum(dim=1) - similarities.append(sim.cpu()) - - similarities = torch.cat(similarities) - mean_similarity = similarities.mean().item() - std_similarity = similarities.std().item() - - fig, ax = plt.subplots(figsize=(8, 5)) - ax.hist(similarities, bins=100, edgecolor="black", alpha=0.7) - ax.set_xlabel("Cosine Similarity") - ax.set_ylabel("Number of Labels") - ax.set_title("Cosine Similarity Distribution Across Labels") - ax.axvline( - mean_similarity, - color="red", - linestyle="--", - linewidth=2, 
- label=f"Mean: {mean_similarity:.4f}", - ) - ax.legend() - plt.tight_layout() - plt.close(fig) - - return mean_similarity, std_similarity, fig - - -def mean_similarity_within_labels( - adata, - n_samples_per_label=10000, - batch_size=10000, -): - """ - Compute mean and std of pairwise cosine similarities for pairs within the same label. - Only computes similarities between embeddings that share the same label in adata.obs['label_str']. - - Args: - adata: AnnData object with embeddings in .X - n_samples_per_label: Number of pairs to sample per label - batch_size: Batch size for processing - - Returns: - mean_similarity: Mean cosine similarity across all within-label pairs - std_similarity: Standard deviation of cosine similarities - """ - import numpy as np - - embeddings = torch.tensor(adata.X).cuda() - x = F.normalize(embeddings, dim=1) - - # Get unique labels - unique_labels = adata.obs["label_str"].unique() - - all_similarities = [] - - for label in tqdm(unique_labels, desc="Processing labels"): - # Get indices for this label - label_mask = adata.obs["label_str"] == label - label_indices = np.where(label_mask)[0] - n_label = len(label_indices) - - # Skip labels with only one sample - if n_label < 2: - continue - - # Determine number of pairs to sample for this label - n_possible_pairs = n_label * (n_label - 1) // 2 - n_samples = min(n_samples_per_label, n_possible_pairs) - - # Convert label indices to tensor - label_indices_tensor = torch.tensor(label_indices, device="cuda") - - # Generate random pairs within this label - idx_i = torch.randint(0, n_label, (n_samples,), device="cuda") - idx_j = torch.randint(0, n_label, (n_samples,), device="cuda") - - # Ensure i != j - mask = idx_i == idx_j - idx_j[mask] = (idx_j[mask] + 1) % n_label - - # Map to actual indices in the full dataset - actual_idx_i = label_indices_tensor[idx_i] - actual_idx_j = label_indices_tensor[idx_j] - - # Compute similarities for sampled pairs in batches - for start in range(0, n_samples, 
batch_size): - end = min(start + batch_size, n_samples) - batch_i = x[actual_idx_i[start:end]] - batch_j = x[actual_idx_j[start:end]] - sim = (batch_i * batch_j).sum(dim=1) - all_similarities.append(sim.cpu()) - - # Combine all similarities - all_similarities = torch.cat(all_similarities) - mean_similarity = all_similarities.mean().item() - std_similarity = all_similarities.std().item() - - fig, ax = plt.subplots(figsize=(8, 5)) - ax.hist(all_similarities, bins=100, edgecolor="black", alpha=0.7) - ax.set_xlabel("Cosine Similarity") - ax.set_ylabel("Number of Labels") - ax.set_title("Cosine Similarity Distribution Within Labels") - ax.axvline( - mean_similarity, - color="red", - linestyle="--", - linewidth=2, - label=f"Mean: {mean_similarity:.4f}", - ) - ax.legend() - plt.tight_layout() - plt.close(fig) - - return mean_similarity, std_similarity, fig - - -if __name__ == "__main__": - adata_path = "/hpc/projects/intracellular_dashboard/ops/ops0031_20250424/3-assembly/dynaclr_features" - adata_cells, adata_guides, adata_genes = load_adata(adata_path) - - mean_similarity_all, std_similarity_all, across_labels_fig = mean_similarity( - adata_cells, n_samples=1_000_000 - ) - mean_similarity_within, std_similarity_within, within_labels_fig = ( - mean_similarity_within_labels(adata_cells, n_samples_per_label=10_000) - ) - - top_10_closest, top_10_furthest, sorted_labels, fig = ( - cosine_similarity_to_reference(adata_genes, reference_label="NTC") - ) - top_10_tightest, top_10_diffuse, spread_sorted_labels, spread_fig = ( - embedding_spread_all_labels(adata_cells, min_samples=2) - ) - - print("Mean Cosine Similarity (All Cells):", f"{mean_similarity_all:.4f}") - print("Std Dev Cosine Similarity (All Cells):", f"{std_similarity_all:.4f}") - print("\nTop 10 Closest Labels to NTC:") - for label, sim in top_10_closest: - print(f" {label}: {sim:.4f}") - print("\nTop 10 Furthest Labels from NTC:") - for label, sim in top_10_furthest: - print(f" {label}: {sim:.4f}") - print("\nTop 
10 Tightest Embedding Spreads:") - for label, mean_sim, std_sim in top_10_tightest: - print(f" {label}: Mean Similarity = {mean_sim:.4f}, Std Dev = {std_sim:.4f}") - print("\nTop 10 Most Diffuse Embedding Spreads:") - for label, mean_sim, std_sim in top_10_diffuse: - print(f" {label}: Mean Similarity = {mean_sim:.4f}, Std Dev = {std_sim:.4f}") diff --git a/src/ops_model/data/embeddings/embedding_metrics.py b/src/ops_model/data/embeddings/embedding_metrics.py deleted file mode 100644 index 08c9915..0000000 --- a/src/ops_model/data/embeddings/embedding_metrics.py +++ /dev/null @@ -1,123 +0,0 @@ -from tqdm import tqdm -from pathlib import Path -import yaml - -import torch -import numpy as np -import pandas as pd -import anndata as ad -import torch.nn.functional as F - -from ops_model.data.paths import OpsPaths - - -def mean_similarity( - adata, - n_samples=10_000_000, - batch_size=10000, -): - """ - Compute mean and std of pairwise cosine similarities. - - Args: - adata: AnnData object with embeddings in .X - use_sampling: If True, sample random pairs instead of computing all pairs - n_samples: Number of pairs to sample (only used if use_sampling=True) - batch_size: Batch size for processing (only used if use_sampling=False) - """ - embeddings = torch.tensor(adata.X).cuda() - x = F.normalize(embeddings, dim=1) - n = x.shape[0] - - # Sampling approach: much faster for large datasets - # Sample random pairs and compute their similarities - n_samples = min(n_samples, n * (n - 1) // 2) # Don't sample more than total pairs - - # Generate random pairs - idx_i = torch.randint(0, n, (n_samples,), device="cuda") - idx_j = torch.randint(0, n, (n_samples,), device="cuda") - - # Ensure i != j - mask = idx_i == idx_j - idx_j[mask] = (idx_j[mask] + 1) % n - - # Compute similarities for sampled pairs in batches - similarities = [] - for start in tqdm(range(0, n_samples, batch_size)): - end = min(start + batch_size, n_samples) - batch_i = x[idx_i[start:end]] - batch_j = 
x[idx_j[start:end]] - sim = (batch_i * batch_j).sum(dim=1) - similarities.append(sim.cpu()) - - similarities = torch.cat(similarities) - mean_similarity = similarities.mean().item() - std_similarity = similarities.std().item() - - return mean_similarity, std_similarity - - -def alignment_and_uniformity(adata, n_uniformity_samples=1_000_000, batch_size=10000): - """ - Compute alignment and uniformity metrics for embeddings. - - Code adapted from: - title={Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere}, - author={Wang, Tongzhou and Isola, Phillip}, - booktitle={International Conference on Machine Learning}, - organization={PMLR}, - pages={9929--9939}, - year={2020} - - Args: - adata: AnnData object with embeddings and gene_int labels - n_uniformity_samples: Number of random pairs to sample for uniformity - batch_size: Batch size for processing - """ - gene_int_list = adata.obs["label_int"].unique().tolist() - x = [] # positive pair i - y = [] # positive pair j - for i in tqdm(gene_int_list): - single_gene_embs = adata.X[adata.obs["label_int"] == i] - x += [single_gene_embs[j] for j in range(len(single_gene_embs) - 1)] - y += [ - single_gene_embs[z] - for z in np.random.permutation(np.arange(len(single_gene_embs) - 1)) - ] - - x = torch.tensor(np.asarray(x)).cuda() - y = torch.tensor(np.asarray(y)).cuda() - - # Compute alignment (all positive pairs) - alignment = (x - y).norm(p=2, dim=1).pow(2).mean().item() - - # Compute uniformity using sampling to avoid OOM - n = x.shape[0] - n_samples = min(n_uniformity_samples, n * (n - 1) // 2) - - # Sample random pairs for uniformity - idx_i = torch.randint(0, n, (n_samples,), device="cuda") - idx_j = torch.randint(0, n, (n_samples,), device="cuda") - - # Ensure i != j - mask = idx_i == idx_j - idx_j[mask] = (idx_j[mask] + 1) % n - - # Compute pairwise distances for sampled pairs in batches - uniformity_vals = [] - for start in tqdm(range(0, n_samples, batch_size)): - end 
= min(start + batch_size, n_samples) - batch_i = x[idx_i[start:end]] - batch_j = x[idx_j[start:end]] - # Compute squared L2 distance - dist_sq = (batch_i - batch_j).norm(p=2, dim=1).pow(2) - uniformity_vals.append(dist_sq.cpu()) - - uniformity_vals = torch.cat(uniformity_vals) - uniformity = uniformity_vals.mul(-2).exp().mean().log().item() - - return alignment, uniformity - - -if __name__ == "__main__": - pass diff --git a/src/ops_model/data/embeddings/funk_clusters.py b/src/ops_model/data/embeddings/funk_clusters.py deleted file mode 100644 index 6b84968..0000000 --- a/src/ops_model/data/embeddings/funk_clusters.py +++ /dev/null @@ -1,334 +0,0 @@ -import yaml -import pandas as pd -from pathlib import Path -import anndata as ad -import matplotlib.pyplot as plt -from ops_model.data.embeddings.utils import load_adata - - -def get_funk_clusters(): - path = Path( - "/hpc/projects/icd.fast.ops/configs/annotated_gene_panel_July2025.csv" - ) - df = pd.read_csv(path) - - funk_clusters = { - "26": { - "genes": df[df["funk_cluster"] == "26"]["Gene.name"].tolist(), - "desc": "26 DNA Replication", - }, - "148": { - "genes": df[df["funk_cluster"] == "148"]["Gene.name"].tolist(), - "desc": "148 Cell Cycle and Cytokinesis", - }, - "14": { - "genes": df[df["funk_cluster"] == "14"]["Gene.name"].tolist(), - "desc": "14 Translation Initiation ", - # 'desc': '14 Translation Initiation and tRNA Ligases' - }, - "106": { - "genes": df[df["funk_cluster"] == "106"]["Gene.name"].tolist(), - "desc": "106 Proteasome 19S Regulatory Particle", - # 'desc': '106 Proteasome 19S Regulatory Particle ATPase Subunits & Ubiquitination Factors', - }, - "138": { - "genes": df[df["funk_cluster"] == "138"]["Gene.name"].tolist(), - "desc": "138 Spliceosome", - }, - "29": { - "genes": df[df["funk_cluster"] == "29"]["Gene.name"].tolist(), - "desc": "29 Adhesion intracellular transport & NSL complex", - }, - "184": { - "genes": df[df["funk_cluster"] == "184"]["Gene.name"].tolist(), - "desc": "184 Actin 
Cytoskeletion & nuclear transport", - }, - "201": { - "genes": df[df["funk_cluster"] == "201"]["Gene.name"].tolist(), - "desc": "201 Golgi-ER transport", - }, - "13": { - "genes": df[df["funk_cluster"] == "13"]["Gene.name"].tolist(), - "desc": "13 DNA damage", - }, - "199": { - "genes": df[df["funk_cluster"] == "199"]["Gene.name"].tolist(), - "desc": "199 RNA Polymerase II", - }, - "200": { - "genes": df[df["funk_cluster"] == "200"]["Gene.name"].tolist(), - "desc": "200 COP9 signalosome", - }, - "52": { - "genes": df[df["funk_cluster"] == "52"]["Gene.name"].tolist(), - "desc": "52 Spliceosome", - }, - "155": { - "genes": df[df["funk_cluster"] == "155"]["Gene.name"].tolist(), - "desc": "155 RNA Polymerase I", - }, - # '82': { - # 'genes': df[df['funk_cluster'] == '82']['Gene.name'].tolist(), - # 'desc': '82 Mitochondrial Ribosome', - # }, - "3": { - "genes": df[df["funk_cluster"] == "3"]["Gene.name"].tolist(), - "desc": "3 DNA Damage", - }, - "46": { - "genes": df[df["funk_cluster"] == "46"]["Gene.name"].tolist(), - "desc": "46 Cell Cycle", - }, - "66": { - "genes": df[df["funk_cluster"] == "66"]["Gene.name"].tolist(), - "desc": "66 Ribosome 40S subunit", - }, - "212": { - "genes": df[df["funk_cluster"] == "212"]["Gene.name"].tolist(), - "desc": "212 Chaperonin TCP-1", - }, - "136": { - "genes": df[df["funk_cluster"] == "136"]["Gene.name"].tolist(), - "desc": "136 40S Ribosome Biogenesis", - }, - "110": { - "genes": df[df["funk_cluster"] == "110"]["Gene.name"].tolist(), - "desc": "110 Spliceosome", - }, - "214": { - "genes": df[df["funk_cluster"] == "214"]["Gene.name"].tolist(), - "desc": "214 Augmin complex", - }, - "21": { - "genes": df[df["funk_cluster"] == "21"]["Gene.name"].tolist(), - "desc": "21 Ribosome Biogenesis", - }, - "15": { - "genes": df[df["funk_cluster"] == "15"]["Gene.name"].tolist(), - "desc": "15 60S Ribosome Biogenesis", - }, - "23": { - "genes": df[df["funk_cluster"] == "23"]["Gene.name"].tolist(), - "desc": "23 Ribosome 60S subunit", - }, - 
"203": { - "genes": df[df["funk_cluster"] == "203"]["Gene.name"].tolist(), - "desc": "203 Ribosome Biogenesis", - }, - "216": { - "genes": df[df["funk_cluster"] == "216"]["Gene.name"].tolist(), - "desc": "216 Ribosome Biogenesis", - }, - "179": { - "genes": df[df["funk_cluster"] == "179"]["Gene.name"].tolist(), - "desc": "179 DNA Damage", - }, - } - return funk_clusters - - -def print_cluster_info(funk_clusters): - for k, v in funk_clusters.items(): - print(f"Cluster {k} ({v['desc']}): {len(v['genes'])} genes") - - return - - -def plot_funk_clusters( - adata, funk_clusters, save_path=None, report_dir=None, filename="funk_clusters.png" -): - """ - Plot UMAP colored by Funk functional clusters. - - Creates a 7×4 grid showing all 28 functional clusters (DNA replication, - ribosome, spliceosome, etc.) colored on UMAP. - - Args: - adata: AnnData object with UMAP coordinates and gene labels - funk_clusters: Dictionary of functional cluster definitions - save_path: Legacy parameter - direct path to save file - report_dir: Path to report directory (preferred over save_path) - filename: Filename to use when saving to report_dir - """ - fig, axs = plt.subplots(nrows=7, ncols=4, figsize=(20, 28)) - axs = axs.flatten() - for i, (cluster_num, cluster_info) in enumerate(funk_clusters.items()): - genes = cluster_info["genes"] - desc = cluster_info["desc"] - plt.sca(axs[i]) - umap = adata.obsm["X_umap"] - plt.scatter(umap[:, 0], umap[:, 1], c="lightgrey", s=20, alpha=0.8, linewidth=0) - for gene in genes: - subset = adata[adata.obs["perturbation"] == gene].obsm["X_umap"] - plt.scatter( - subset[:, 0], subset[:, 1], s=40, alpha=1, linewidth=0, label=gene - ) - plt.title(desc) - plt.xticks([]) - plt.yticks([]) - plt.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize="10") - plt.tight_layout() - - # Determine save path - if report_dir is not None: - save_path = Path(report_dir) / "plots" / filename - - if save_path is not None: - print(f"Saving funk cluster UMAP plot to 
{save_path}") - plt.savefig(save_path, dpi=300) - - return - - -def check_cluster_genes(cluster_list_path: str, gene_list_path: str): - """ - Validate that all genes in cluster definitions exist in the gene panel. - - 1) Load the yaml found at cluster_list_path and the csv found at gene_list_path - 2) Check that each gene in each cluster is found in the gene list - 3) Raise an error if any genes are missing, otherwise print success message - - Args: - cluster_list_path: Path to YAML file containing cluster definitions - gene_list_path: Path to CSV file containing valid gene names - - Raises: - ValueError: If any genes in the clusters are not found in the gene list - """ - # Load cluster definitions from YAML - with open(cluster_list_path, "r") as f: - clusters = yaml.safe_load(f) - - # Load gene list from CSV - gene_df = pd.read_csv(gene_list_path) - valid_genes = set(gene_df["Gene.name"].tolist()) - - # Check each cluster for missing genes - missing_genes_by_cluster = {} - cluster_info = [] - - for cluster_id, cluster_data in clusters.items(): - cluster_name = cluster_data["name"] - cluster_genes = cluster_data["genes"] - - # Find missing genes - missing = [gene for gene in cluster_genes if gene not in valid_genes] - - if missing: - missing_genes_by_cluster[cluster_id] = { - "name": cluster_name, - "missing_genes": missing, - } - - cluster_info.append((cluster_id, cluster_name, len(cluster_genes))) - - # Report results - if missing_genes_by_cluster: - error_msg = "\nValidation FAILED: The following genes are missing from the gene panel:\n\n" - for cluster_id, info in missing_genes_by_cluster.items(): - error_msg += f"Cluster {cluster_id} ({info['name']}):\n" - for gene in info["missing_genes"]: - error_msg += f" - {gene}\n" - error_msg += "\n" - raise ValueError(error_msg) - else: - print("\n✓ Validation PASSED: All genes found in gene panel!\n") - print("Cluster summary:") - for cluster_id, cluster_name, num_genes in cluster_info: - print(f" Cluster {cluster_id} 
({cluster_name}): {num_genes} genes") - print() - - return - - -def plot_clusters( - adata, cluster_list_path: str, save_path: str = None, obsm_key: str = "X_umap" -): - """ - Plot UMAP colored by gene clusters from YAML configuration. - - Takes as input a yaml file containing cluster names and lists of gene KOs, - creates plots similar to plot_funk_clusters but for each cluster in the yaml, - and saves plots to disk at save_path. - - Args: - adata: AnnData object with UMAP coordinates and gene labels - cluster_list_path: Path to YAML file containing cluster definitions - save_path: Path where the plot PNG will be saved - - Raises: - ValueError: If any gene in the clusters is not found in adata.obs['perturbation'] - """ - # Load cluster definitions from YAML - with open(cluster_list_path, "r") as f: - clusters = yaml.safe_load(f) - - # Calculate grid dimensions (5 columns max) - num_clusters = len(clusters) - ncols = min(5, num_clusters) - nrows = (num_clusters + ncols - 1) // ncols # Ceiling division - - # Create figure with subplots - fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5 * ncols, 4 * nrows)) - - # Handle case where there's only one subplot (axs won't be an array) - if num_clusters == 1: - axs = [axs] - else: - axs = axs.flatten() - - # Get all available genes in adata - available_genes = set(adata.obs["perturbation"].unique()) - - # Plot each cluster - umap = adata.obsm[obsm_key] - for i, (cluster_id, cluster_data) in enumerate(clusters.items()): - cluster_name = cluster_data["name"] - genes = cluster_data["genes"] - - # Check if all genes exist in adata - missing_genes = [gene for gene in genes if gene not in available_genes] - if missing_genes: - pass - # raise ValueError( - # f"Cluster {cluster_id} ({cluster_name}) contains genes not found in adata: {missing_genes}" - # ) - - # Plot this cluster - plt.sca(axs[i]) - - # Background: all points in grey - plt.scatter(umap[:, 0], umap[:, 1], c="lightgrey", s=20, alpha=0.8, linewidth=0) - - # 
Highlight genes in this cluster - for gene in genes: - subset = adata[adata.obs["perturbation"] == gene].obsm[obsm_key] - plt.scatter( - subset[:, 0], subset[:, 1], s=40, alpha=1, linewidth=0, label=gene - ) - - # Format subplot - plt.title(f"{cluster_id}: {cluster_name}") - plt.xticks([]) - plt.yticks([]) - if cluster_name == "NTCs": - continue - else: - plt.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize="10") - - # Hide any unused subplots - for i in range(num_clusters, len(axs)): - axs[i].axis("off") - - plt.tight_layout() - - if save_path is not None: - print(f"Saving cluster UMAP plot to {save_path}") - plt.savefig(save_path, dpi=300, bbox_inches="tight") - plt.close() - - return - - -if __name__ == "__main__": - pass diff --git a/src/ops_model/data/embeddings/pca.py b/src/ops_model/data/embeddings/pca.py deleted file mode 100644 index caa1e4c..0000000 --- a/src/ops_model/data/embeddings/pca.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Optional -from pathlib import Path - -import scanpy as sc -import anndata as ad -import matplotlib.pyplot as plt - -from ops_model.data.paths import OpsPaths - - -def plot_pca( - adata: ad.AnnData, - save_path: Optional[str] = None, - description: Optional[str] = None, - report_dir: Optional[str] = None, - filename: str = "pca_variance_ratio.png", -): - """ - Plot PCA variance ratio. 
- - Args: - adata: AnnData object with PCA results - save_path: Legacy parameter - direct path to save file - description: Plot title/description - report_dir: Path to report directory (preferred over save_path) - filename: Filename to use when saving to report_dir - """ - sc.pl.pca_variance_ratio(adata, n_pcs=100, log=False, save=False, show=False) - if description: - plt.title(description) - - # Determine save path - if report_dir is not None: - save_path = Path(report_dir) / "plots" / filename - - if save_path is not None: - plt.savefig(save_path, dpi=300, bbox_inches="tight") - plt.close() - - return - - -def pca_90_percent_variance(adata: ad.AnnData) -> int: - cumulative_variance = adata.uns["pca"]["variance_ratio"].cumsum() - if cumulative_variance[-1] < 0.9: - return ( - cumulative_variance[99], - f"First 100 PCs explain {cumulative_variance[-1]:.2f} variance", - ) - else: - num_components = (cumulative_variance < 0.9).sum() + 1 - return num_components, f"First {num_components} PCs explain 90% variance" - - -if __name__ == "__main__": - pass diff --git a/src/ops_model/data/embeddings/umap_plots.py b/src/ops_model/data/embeddings/umap_plots.py deleted file mode 100644 index 4c52bfe..0000000 --- a/src/ops_model/data/embeddings/umap_plots.py +++ /dev/null @@ -1,355 +0,0 @@ -from tqdm import tqdm -from pathlib import Path -from typing import Literal, Optional -from ast import literal_eval - -import pandas as pd -import anndata as ad -import matplotlib.pyplot as plt - -from ops_model.data.embeddings.utils import group_guides, get_gene_complexes -from ops_model.data.paths import OpsPaths - - -COLORS = [ - "lightcoral", - "brown", - "darkred", - "burlywood", - "darkgoldenrod", -] - - -def plot_umap( - gene: str, - adata: ad.AnnData, - save_path: Optional[str] = None, - guides: Optional[list] = None, - data_point_type: Literal["cell", "guide", "gene"] = "cell", - report_dir: Optional[str] = None, - filename: Optional[str] = None, -): - """ - Plot UMAP highlighting a 
specific gene or guide. - - Args: - gene: Gene symbol to highlight - adata: AnnData object with UMAP coordinates - save_path: Legacy parameter - direct path to save file - guides: List of guide IDs (for guide-level plotting) - data_point_type: Type of data points ("cell", "guide", or "gene") - report_dir: Path to report directory (preferred over save_path) - filename: Filename to use when saving to report_dir - """ - umap = adata.obsm["X_umap"] - if data_point_type == "cell": - subset = adata[adata.obs["label_str"] == gene].obsm["X_umap"] - s = 1 - alpha = 0.5 - elif data_point_type == "gene": - subset = adata[adata.obs["label_str"] == gene].obsm["X_umap"] - s = 20 - alpha = 0.8 - else: # guide-level - subset = adata[adata.obs["sgRNA"].isin(guides)].obsm["X_umap"] - s = 20 - alpha = 0.8 - - plt.scatter(umap[:, 0], umap[:, 1], c="lightgrey", s=s, alpha=alpha, linewidth=0) - plt.scatter(subset[:, 0], subset[:, 1], c="blue", s=s, alpha=alpha, linewidth=0) - plt.title(gene) - plt.xticks([]) - plt.yticks([]) - - # Determine save path - if report_dir is not None: - if filename is None: - filename = f"umap_{data_point_type}_{gene}.png" - save_path = Path(report_dir) / "plots" / filename - - if save_path is not None: - plt.savefig(save_path, dpi=300) - plt.figure() - plt.close() - - # Further implementation for plotting by guide - return - - -def plot_umap_multiple_genes( - genes: list, - adata: ad.AnnData, - save_path: Optional[str] = None, - title: Optional[str] = None, - report_dir: Optional[str] = None, - filename: Optional[str] = None, -): - """ - Plot UMAP highlighting multiple genes with different colors. 
- - Args: - genes: List of gene symbols to highlight - adata: AnnData object with UMAP coordinates - save_path: Legacy parameter - direct path to save file - title: Plot title - report_dir: Path to report directory (preferred over save_path) - filename: Filename to use when saving to report_dir - """ - umap = adata.obsm["X_umap"] - plt.scatter(umap[:, 0], umap[:, 1], c="lightgrey", s=20, alpha=0.8, linewidth=0) - for gene in genes: - subset = adata[adata.obs["label_str"] == gene].obsm["X_umap"] - plt.scatter(subset[:, 0], subset[:, 1], s=20, alpha=1, linewidth=0, label=gene) - plt.title(title if title is not None else "") - plt.xticks([]) - plt.yticks([]) - plt.legend(loc="center left", bbox_to_anchor=(1, 0.5)) - plt.tight_layout() - - # Determine save path - if report_dir is not None: - if filename is None: - # Generate filename from title or genes - if title: - filename = f"umap_{title.replace(' ', '_').lower()}.png" - else: - filename = f"umap_multiple_genes.png" - save_path = Path(report_dir) / "plots" / filename - - if save_path is not None: - plt.savefig(save_path, dpi=300) - plt.figure() - plt.close() - return - - -def report_umap_plot_2( - feature_dir: str, - adata_cells: Optional[ad.AnnData] = None, - adata_guides: Optional[ad.AnnData] = None, - adata_genes: Optional[ad.AnnData] = None, - output_path: Optional[str] = None, - report_dir: Optional[str] = None, -): - """ - Generate UMAP plots for protein complex genes (RPL, NUP, TRAPPC, KRT). 
- - Args: - feature_dir: Path to feature directory (for loading AnnData if not provided) - adata_cells: Cell-level AnnData (optional) - adata_guides: Guide-level AnnData (optional) - adata_genes: Gene-level AnnData (optional) - output_path: Legacy parameter - directory for output - report_dir: Path to report directory (preferred over output_path) - """ - path = Path(feature_dir) - output_path = Path(output_path) if output_path is not None else None - if adata_cells is None: - adata_path_cells = path / "anndata_objects" / "features_processed.h5ad" - adata_cells = ad.read_h5ad(adata_path_cells) - if adata_guides is None: - adata_path_guides = path / "anndata_objects" / "guide_bulked.h5ad" - adata_guides = ad.read_h5ad(adata_path_guides) - if adata_genes is None: - adata_path_genes = path / "anndata_objects" / "gene_bulked.h5ad" - adata_genes = ad.read_h5ad(adata_path_genes) - gene_guide_dict = group_guides() - - plot_umap_multiple_genes( - genes=[ - "RPL18", - "RPL23", - "RPL9", - "RPL30", - "RPL35", - "RPL32", - "RPLP2", - "RPL27A", - "RPL5", - "RPL15", - "RPL41", - "RPL34", - "RPL26", - "RPL37A", - ], - adata=adata_genes, - title="RPL genes UMAP", - save_path=( - output_path / "fig_2_umap_rpl_genes.png" - if output_path is not None - else None - ), - report_dir=report_dir, - filename="umap_rpl_genes.png", - ) - plt.figure() - plot_umap_multiple_genes( - genes=["NUP54", "NUP98", "NUP214", "NUP37"], - adata=adata_genes, - title="NUP genes UMAP", - save_path=( - output_path / "fig_2_umap_nup_genes.png" - if output_path is not None - else None - ), - report_dir=report_dir, - filename="umap_nup_genes.png", - ) - plt.figure() - plot_umap_multiple_genes( - genes=["TRAPPC11", "TRAPPC4", "TRAPPC2L"], - adata=adata_genes, - title="TRAPPC genes UMAP", - save_path=( - output_path / "fig_2_umap_trappc_genes.png" - if output_path is not None - else None - ), - report_dir=report_dir, - filename="umap_trappc_genes.png", - ) - plt.figure() - plot_umap_multiple_genes( - 
genes=["KRT18", "KRT8"], - adata=adata_genes, - title="KRT genes UMAP", - save_path=( - output_path / "fig_2_umap_krt_genes.png" - if output_path is not None - else None - ), - report_dir=report_dir, - filename="umap_krt_genes.png", - ) - plt.figure() - - return - - -def report_umap_plot_1( - feature_dir: str, - adata_cells: Optional[ad.AnnData] = None, - adata_guides: Optional[ad.AnnData] = None, - adata_genes: Optional[ad.AnnData] = None, - output_path: Optional[str] = None, - report_dir: Optional[str] = None, -): - """ - Generate UMAP plots for NTC (non-targeting control) at cell, guide, and gene levels. - - Args: - feature_dir: Path to feature directory (for loading AnnData if not provided) - adata_cells: Cell-level AnnData (optional) - adata_guides: Guide-level AnnData (optional) - adata_genes: Gene-level AnnData (optional) - output_path: Legacy parameter - directory for output - report_dir: Path to report directory (preferred over output_path) - """ - path = Path(feature_dir) - output_path = Path(output_path) if output_path is not None else None - if adata_cells is None: - adata_path_cells = path / "anndata_objects" / "features_processed.h5ad" - adata_cells = ad.read_h5ad(adata_path_cells) - if adata_guides is None: - adata_path_guides = path / "anndata_objects" / "guide_bulked.h5ad" - adata_guides = ad.read_h5ad(adata_path_guides) - if adata_genes is None: - adata_path_genes = path / "anndata_objects" / "gene_bulked.h5ad" - adata_genes = ad.read_h5ad(adata_path_genes) - gene_guide_dict = group_guides() - - plot_umap( - gene="NTC", - adata=adata_cells, - data_point_type="cell", - save_path=( - output_path / "fig_1_umap_cell_ntc.png" if output_path is not None else None - ), - report_dir=report_dir, - filename="umap_cell_ntc.png", - ) - plt.figure() - plot_umap( - gene="Nontargeting", - adata=adata_guides, - guides=gene_guide_dict.get("Nontargeting", []), - data_point_type="guide", - save_path=( - output_path / "fig_1umap_guide_ntc.png" if output_path is not 
None else None - ), - report_dir=report_dir, - filename="umap_guide_ntc.png", - ) - plt.figure() - plot_umap( - gene="NTC", - adata=adata_genes, - data_point_type="gene", - save_path=( - output_path / "fig_1_umap_gene_ntc.png" if output_path is not None else None - ), - report_dir=report_dir, - filename="umap_gene_ntc.png", - ) - plt.figure() - - return - - -def report_umap_plots( - feature_dir: str, - output_path: Optional[str] = None, - report_dir: Optional[str] = None, -): - """ - Generate all standard UMAP report plots. - - Wrapper function that generates: - - NTC control plots (cell, guide, gene levels) - - Protein complex plots (RPL, NUP, TRAPPC, KRT) - - Args: - feature_dir: Path to feature directory containing anndata_objects/ - output_path: Legacy parameter - directory for output - report_dir: Path to report directory (preferred over output_path) - """ - path = Path(feature_dir) - if output_path is not None: - output_path = Path(output_path) - output_path.mkdir(parents=True, exist_ok=True) - - adata_path_cells = path / "anndata_objects" / "features_processed.h5ad" - adata_cells = ad.read_h5ad(adata_path_cells) - adata_path_guides = path / "anndata_objects" / "guide_bulked.h5ad" - adata_guides = ad.read_h5ad(adata_path_guides) - adata_path_genes = path / "anndata_objects" / "gene_bulked.h5ad" - adata_genes = ad.read_h5ad(adata_path_genes) - - report_umap_plot_1( - feature_dir=feature_dir, - adata_cells=adata_cells, - adata_guides=adata_guides, - adata_genes=adata_genes, - output_path=output_path, - report_dir=report_dir, - ) - - report_umap_plot_2( - feature_dir=feature_dir, - adata_cells=adata_cells, - adata_guides=adata_guides, - adata_genes=adata_genes, - output_path=output_path, - report_dir=report_dir, - ) - - return - - -if __name__ == "__main__": - feature_dir = "/hpc/projects/intracellular_dashboard/ops/ops0031_20250424/3-assembly/dynaclr_features" - output_path = 
"/hpc/projects/intracellular_dashboard/ops/ops0031_20250424/3-assembly/dynaclr_features/report_plots" - report_umap_plots( - feature_dir=feature_dir, - output_path=output_path, - ) diff --git a/src/ops_model/data/embeddings/utils.py b/src/ops_model/data/embeddings/utils.py deleted file mode 100644 index 5395752..0000000 --- a/src/ops_model/data/embeddings/utils.py +++ /dev/null @@ -1,56 +0,0 @@ -from ast import literal_eval -from pathlib import Path - -import pandas as pd -import anndata as ad - -from ops_model.data.paths import OpsPaths - - -def group_guides(): - from collections import defaultdict - - gene_lib = pd.read_csv(OpsPaths(experiment="doesnt_matter").other["gene_library"]) - guide_to_gene = defaultdict(list) - for _, row in gene_lib.iterrows(): - if pd.isna(row["gene_symbol"]): - key = "NTC" - else: - key = row["gene_symbol"] - guide_to_gene[key].append(row["sgRNA"]) - - return guide_to_gene - - -def get_gene_complexes(): - path = "/hpc/projects/icd.fast.ops/configs/annotated_gene_panel_July2025.csv" - df = pd.read_csv(path) - complex_df = df[["Gene.name", "In_same_complex_with"]] - gene_list = list(complex_df["Gene.name"]) - unique_complexes = dict() - black_set = set() - count = 0 - for idx, row in complex_df.iterrows(): - if row["Gene.name"] in black_set: - continue - genes_in_complex = literal_eval(row.In_same_complex_with) - genes_in_complex.append(row["Gene.name"]) - for g in genes_in_complex: - to_add = [g for g in genes_in_complex if g in gene_list] - unique_complexes[count] = to_add - count += 1 - for g in to_add: - black_set.add(g) - return {k: v for k, v in unique_complexes.items() if len(v) > 2} - - -def load_adata(path): - path = Path(path) - adata_path_cells = path / "anndata_objects" / "features_processed.h5ad" - adata_cells = ad.read_h5ad(adata_path_cells) - adata_path_guides = path / "anndata_objects" / "guide_bulked.h5ad" - adata_guides = ad.read_h5ad(adata_path_guides) - adata_path_genes = path / "anndata_objects" / "gene_bulked.h5ad" 
- adata_genes = ad.read_h5ad(adata_path_genes) - - return adata_cells, adata_guides, adata_genes diff --git a/src/ops_model/data/move_links.py b/src/ops_model/data/move_links.py deleted file mode 100644 index 5b9ed39..0000000 --- a/src/ops_model/data/move_links.py +++ /dev/null @@ -1,142 +0,0 @@ -# %% -import shutil -from pathlib import Path - -from ops_model.data.paths import OpsPaths - -# For a list of experiments, copy linked_pheno_iss.csv files to a shared folder for model training - -# Configuration -experiments = [ - "ops0015_20250213", - "ops0016_20250220", - "ops0031_20250424", - "ops0032_20250428", - "ops0033_20250429", - "ops0035_20250501", - "ops0036_20250505", - "ops0037_20250506", - "ops0038_20250514", - "ops0041_20250519", - "ops0042_20250520", - "ops0043_20250605", - "ops0044_20250602", - "ops0045_20250603", - "ops0046_20250611", - "ops0047_20250612", - "ops0048_20250616", - "ops0049_20250626", - "ops0050_20250630", - "ops0051_20250623", - "ops0052_20250702", - "ops0053_20250709", - "ops0054_20250710", - "ops0055_20250715", - "ops0056_20250721", - "ops0057_20250722", - "ops0058_20250805", - "ops0059_20250804", - "ops0062_20250729", - "ops0063_20250731", - "ops0064_20250811", - "ops0065_20250812", - "ops0066_20250820", - "ops0067_20250826", - "ops0068_20250901", - "ops0069_20250902", - "ops0070_20250908", - "ops0071_20250828", - "ops0076_20250917", - "ops0078_20250923", - "ops0079_20250916", - "ops0081_20250924", - "ops0084_20251022", - "ops0085_20251118", - "ops0086_20250922", - "ops0089_20251119", - "ops0090_20251120", - "ops0091_20251117", - "ops0092_20251027", - "ops0094_20251217", - "ops0097_20251023", - "ops0098_20251113", - "ops0100_20251218", - "ops0101_20251211", - "ops0102_20251210", - "ops0103_20251216", - "ops0104_20251215", - "ops0105_20260106", - "ops0106_20251204", - "ops0107_20251208", - "ops0108_20251209", - "ops0110_20260108", - "ops0113_20251219", - "ops0114_20260112", - "ops0115_20260121", - "ops0116_20260120", - 
"ops0117_20260128", - "ops0118_20260129", - "ops0119_20260203", - "ops0120_20260204", - "ops0121_20260210", - "ops0122_20260211", - "ops0124_20260218", - "ops0125_20260219", - "ops0126_20260224", - "ops0128_20260225", - "ops0129_20260303", - "ops0130_20260304", - "ops0131_20260310", - "ops0132_20260316", - "ops0134_20260317", - "ops0135_20260318", - "ops0137_20260323", - "ops0138_20260305", - "ops0139_20260325", - "ops0140_20260331", - "ops0141_20260319", - "ops0142_20260401", - "ops0143_20260407", - "ops0144_20260406", - "ops0146_20260402", - "ops0149_20260409", - # Add more experiment names here -] -wells = ["A/1/0", "A/2/0", "A/3/0"] - -# Source directory where experiment folders are located -source_base_path = Path("/hpc/projects/intracellular_dashboard/ops") - -# Destination directory for copied files -destination_path = Path("/hpc/mydata/alexander.hillsley/ops/training_data") - -# Create destination directory if it doesn't exist -destination_path.mkdir(parents=True, exist_ok=True) - -# For each experiment in list of experiments -for experiment in experiments: - print(f"Processing experiment: {experiment}") - for well in wells: - print(f" - Processing well: {well}") - path_obj = OpsPaths(experiment=experiment, well=well) - - # Create a subdir at destination_path with experiment name - experiment_dest_dir = path_obj.links["training"].parent - experiment_dest_dir.mkdir(parents=True, exist_ok=True) - - # Source path for this experiment's linked_pheno_iss.csv - # Adjust this path structure based on your actual directory layout - source_file = path_obj.links["original"] - - # Check if source file exists - if source_file.exists(): - # Copy the linked_pheno_iss.csv file to the new subdir - dest_file = path_obj.links["training"] - shutil.copy2(source_file, dest_file) - print(f"  Copied {source_file} to {dest_file}") - else: - print(f"  File not found: {source_file}") - -print("\nCopying complete!") - -# %% diff --git a/tests/features/test_feature_metrics.py 
b/tests/features/test_feature_metrics.py deleted file mode 100644 index e828898..0000000 --- a/tests/features/test_feature_metrics.py +++ /dev/null @@ -1,46 +0,0 @@ -import warnings - -# Filter anndata zarr deprecation warning BEFORE importing anndata -warnings.filterwarnings("ignore", message=".*zarr v2.*", category=DeprecationWarning) - -import numpy as np -import anndata as ad -import pytest - -from ops_model.data.embeddings.embeddding_metrics import ( - alignment_and_uniformity, - mean_similarity, -) - - -@pytest.fixture(scope="module") -def constant_adata(): - n_cells = 5000 - n_features = 50 - - # All features are constant (zeros) - X = np.repeat(np.arange(n_features).reshape(1, -1), n_cells, axis=0).astype(float) - - # Create observations metadata - obs = { - "label_str": ["gene_A"] * (n_cells // 2) + ["gene_B"] * (n_cells // 2), - "label_int": [0] * (n_cells // 2) + [1] * (n_cells // 2), - } - - # Create AnnData object - adata = ad.AnnData(X=X, obs=obs) - return adata - - -def test_mean_similarity(constant_adata): - mean_sim, std_sim = mean_similarity(constant_adata, n_samples=1000, batch_size=100) - assert np.isclose( - mean_sim, 1.0 - ), f"Expected mean similarity close to 1.0, got {mean_sim}" - assert np.isclose( - std_sim, 0.0 - ), f"Expected std similarity close to 0.0, got {std_sim}" - - -def test_alignment_uniformity(constant_adata): - return diff --git a/tests/test_basedataset.py b/tests/test_basedataset.py deleted file mode 100644 index 8a831dc..0000000 --- a/tests/test_basedataset.py +++ /dev/null @@ -1,135 +0,0 @@ -import pytest -import torch -import zarr -import numpy as np -from ops_model.data.base_dataset import BaseDataset -from ops_model.data import data_loader - -import warnings - -warnings.filterwarnings("ignore", category=zarr.errors.ZarrUserWarning) - - -@pytest.fixture(scope="module") -def dataset_args(): - experiment_dict = {"ops0031_20250424": ["A/1/0", "A/2/0", "A/3/0"]} - dm = data_loader.OpsDataManager( - 
experiments=experiment_dict, - batch_size=2, - data_split=(1, 0, 0), - out_channels=["Phase2D", "mCherry"], - initial_yx_patch_size=(256, 256), - verbose=False, - ) - - labels_df = dm.get_labels() - stores = dm.combine_stores() - - return stores, labels_df - - -def test_base_dataset_defaults(dataset_args): - stores, labels_df = dataset_args - # leave all optional args as defaults - dataset = BaseDataset( - stores=stores, - labels_df=labels_df, - out_channels=["Phase2D"], - ) - basic_item = dataset[0] - expexted_keys = { - "data": torch.Tensor, - "mask": torch.Tensor, - "gene_label": int, - "marker_label": list, - "total_index": int, - "crop_info": dict, - } - batch_keys = list(basic_item.keys()) - for k, v in expexted_keys.items(): - assert k in batch_keys - print(f"Testing key: {k}") - assert isinstance(basic_item[k], v) - - return - - -def test_base_dataset_out_channels(dataset_args): - stores, labels_df = dataset_args - # test with multiple out channels - dataset_a = BaseDataset( - stores=stores, - labels_df=labels_df, - out_channels=["Phase2D", "mCherry"], - ) - basic_item = dataset_a[0] - expected_marker_labels = ["Phase2D", "mCherry"] - assert basic_item["marker_label"] == expected_marker_labels - shape_a = basic_item["data"].shape - assert shape_a[0] == 2 # 2 out channels are expected - - dataset_b = BaseDataset( - stores=stores, - labels_df=labels_df, - out_channels=["Phase2D"], - ) - basic_item = dataset_b[0] - expected_marker_labels = ["Phase2D"] - assert basic_item["marker_label"] == expected_marker_labels - shape_b = basic_item["data"].shape - assert shape_b[0] == 1 # 1 out channel is expected - - return - - -def test_base_dataset_mask_cell(dataset_args): - stores, labels_df = dataset_args - # test with mask_cell True - dataset_no_mask = BaseDataset( - stores=stores, - labels_df=labels_df, - out_channels=["Phase2D"], - mask_cell=False, - use_original_crop_size=True, - ) - basic_item_no_mask = dataset_no_mask[0] - - a, b = 
np.where(basic_item_no_mask["data"][0] == 0) - assert len(a) == 0 # there should be no zeros in data when mask_cell is False - - dataset_mask = BaseDataset( - stores=stores, - labels_df=labels_df, - out_channels=["Phase2D"], - mask_cell=True, - use_original_crop_size=True, - ) - basic_item_mask = dataset_mask[0] - - assert not np.array_equal(basic_item_no_mask["data"], basic_item_mask["data"]) - - return - - -def test_base_dataset_original_shape(dataset_args): - stores, labels_df = dataset_args - # test with multiple out channels - dataset_a = BaseDataset( - stores=stores, - labels_df=labels_df, - out_channels=["Phase2D", "mCherry"], - use_original_crop_size=False, - ) - basic_item_a = dataset_a[0] - - dataset_b = BaseDataset( - stores=stores, - labels_df=labels_df, - out_channels=["Phase2D", "mCherry"], - use_original_crop_size=True, - ) - basic_item_b = dataset_b[0] - - assert basic_item_a["data"].shape != basic_item_b["data"].shape - - return From 70e1a3b68788461f09dfdecb21028c810d1e27fb Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Wed, 24 Jun 2026 09:11:21 -0700 Subject: [PATCH 10/11] remove iss_drift_fix.py from the tracked repo, it was a 1 off script and should not be aprt of the public repo --- src/ops_model/data/iss_drift_fix.py | 377 ---------------------------- 1 file changed, 377 deletions(-) delete mode 100644 src/ops_model/data/iss_drift_fix.py diff --git a/src/ops_model/data/iss_drift_fix.py b/src/ops_model/data/iss_drift_fix.py deleted file mode 100644 index 906fc57..0000000 --- a/src/ops_model/data/iss_drift_fix.py +++ /dev/null @@ -1,377 +0,0 @@ -"""Sidecar metadata patch for stale ISS labels in features_processed_*.h5ad. - -Background ----------- -``move_links.py`` snapshots ``/3-assembly/_linked_pheno_iss.csv`` into -the frozen ``/hpc/projects/icd.fast.ops/models/link_csvs//`` location once -per experiment. When the ISS calling pipeline is later re-run, only the -``3-assembly/`` copy is refreshed; the frozen snapshot stays stale. 
- -Every feature-extraction model (cell-dino, dinov3, subcell, cp_features, -cp_extraction) loads the frozen copy via ``OpsPaths(...).links["training"]`` -([paths.py:50]), so cells extracted after the ISS re-run carry **the old** -gene/sgRNA assignments in ``obs["sgRNA"]`` / ``obs["perturbation"]``. - -Diagnosis confirmed that ``bbox`` is bit-identical across the two CSVs for all -matched cells, so the extracted features themselves remain correct. Only the -metadata labels are wrong, which is fixable without re-running inference. - -What this script does ---------------------- -For each affected experiment, for each ``features_processed_.h5ad``, -writes a sibling parquet named -``features_processed__obs_corrected.parquet`` containing the corrected -metadata. The original h5ad is **never modified**. - -Two-step join (because the h5ad obs has no segmentation_id): - 1. ``(well, x_position, y_position)`` --> FROZEN CSV --> ``segmentation_id`` - 2. ``segmentation_id`` --> CURRENT CSV --> ``sgRNA``, ``gene_name`` - -Sidecar schema (one row per obs row, indexed by positional ``obs_idx``): - obs_idx, sgRNA, perturbation, label_str, segmentation_id, correction_status - -``correction_status`` is one of: - ``matched`` cell matched in both frozen + current ISS; labels - updated to current values. - ``orphan_in_h5ad`` cell was in frozen ISS (and h5ad) but the current - ISS run dropped it; original (stale) labels kept. - ``position_unresolved`` could not even resolve the obs row to a frozen - CSV row (rare; floating-point drift or missing - well CSV); original labels kept. - -Downstream consumers load via -``ops_model.features.anndata_utils.load_features_corrected`` which transparently -applies the sidecar. 
- -Usage ------ - # Build for one experiment locally: - python -m ops_model.data.iss_drift_fix --exp ops0031_20250424 - - # Build for all stale experiments via SLURM: - python -m ops_model.data.iss_drift_fix --slurm - - # Build for all stale experiments locally (slow): - python -m ops_model.data.iss_drift_fix -""" -from __future__ import annotations - -import argparse -import logging -import time -from pathlib import Path - -logger = logging.getLogger(__name__) - -BASE = Path("/hpc/projects/icd.fast.ops") -FEATURE_DIR_NAME = "cell_dino_features" -ANNDATA_SUBDIR = "anndata_objects" -SIDECAR_SUFFIX = "_obs_corrected.parquet" - -# Position-join tolerance: x_position in h5ad obs is x_pheno from the frozen -# CSV ([cell_dino.py:166]), so we expect exact equality up to CSV-roundtrip -# float precision. Rounding to 3 decimals (0.001 px) catches roundtrip noise -# without colliding cells (smallest cell-to-cell distance ~ several pixels). -POS_ROUND = 3 - - -def _well_prefix_from_obs_well(well: str) -> str | None: - """obs["well"] looks like ``A/1/0_ops0031_20250424`` -> ``A1``.""" - path = str(well).split("_", 1)[0] # ``A/1/0`` - if len(path) < 3 or path[1] != "/": - return None - return path[0] + path[2] - - -def _process_experiment(exp: str, base: str = str(BASE), - feature_dir: str = FEATURE_DIR_NAME, - overwrite: bool = False) -> dict: - """Build sidecar parquets for every ``features_processed_*.h5ad`` in one - experiment. - - Self-contained so submitit can pickle it for SLURM dispatch (all imports - inside the function). 
- """ - import anndata as ad - import numpy as np - import pandas as pd - - base_p = Path(base) - h5ad_dir = base_p / exp / "3-assembly" / feature_dir / ANNDATA_SUBDIR - if not h5ad_dir.exists(): - return {"exp": exp, "status": "no_anndata_objects"} - - h5ads = sorted(p for p in h5ad_dir.glob("features_processed_*.h5ad") - if SIDECAR_SUFFIX not in p.name) - - results = [] - t0 = time.time() - for h5 in h5ads: - channel = h5.stem.replace("features_processed_", "") - sidecar = h5.with_name(h5.stem + SIDECAR_SUFFIX) - if sidecar.exists() and not overwrite: - results.append({"channel": channel, "status": "skipped_exists", - "sidecar": str(sidecar)}) - continue - - try: - a = ad.read_h5ad(h5, backed="r") - obs = a.obs[["well", "x_position", "y_position", "sgRNA", - "perturbation"]].copy() - obs["_obs_idx"] = np.arange(len(obs), dtype=np.int64) - obs["_well_prefix"] = obs["well"].map(_well_prefix_from_obs_well) - - blocks = [] - for wp, obs_w in obs.groupby("_well_prefix", dropna=False): - if wp is None or not isinstance(wp, str): - logger.warning(f"{exp} {channel}: bad well prefix; " - f"{len(obs_w)} cells") - blocks.append(_unresolved_block(obs_w)) - continue - cur_p = base_p / exp / "3-assembly" / f"{wp}_linked_pheno_iss.csv" - frz_p = (base_p / "models" / "link_csvs" / exp - / f"{wp}_linked_pheno_iss.csv") - if not (cur_p.exists() and frz_p.exists()): - logger.warning(f"{exp} {wp}: missing link CSV " - f"(cur={cur_p.exists()}, frz={frz_p.exists()})") - blocks.append(_unresolved_block(obs_w)) - continue - blocks.append(_correct_well(obs_w, frz_p, cur_p)) - - patch = pd.concat(blocks, ignore_index=True) - patch = patch.sort_values("obs_idx").reset_index(drop=True) - if len(patch) != len(obs): - raise RuntimeError( - f"size mismatch: patch={len(patch)} vs h5ad={len(obs)}") - - sidecar.parent.mkdir(parents=True, exist_ok=True) - tmp = sidecar.with_suffix(sidecar.suffix + ".tmp") - patch.to_parquet(tmp, index=False) - tmp.replace(sidecar) - - counts = 
patch["correction_status"].value_counts().to_dict() - results.append({ - "channel": channel, - "n_total": len(patch), - "n_matched": int(counts.get("matched", 0)), - "n_orphan": int(counts.get("orphan_in_h5ad", 0)), - "n_unresolved": int(counts.get("position_unresolved", 0)), - "sidecar": str(sidecar), - "status": "ok", - }) - except Exception as e: - results.append({"channel": channel, "status": "error", - "error": str(e)}) - - return {"exp": exp, "elapsed_s": round(time.time() - t0, 1), - "channels": results} - - -def _unresolved_block(obs_w): - """Build a sidecar block where every row is position_unresolved (keep - original labels).""" - import pandas as pd - return pd.DataFrame({ - "obs_idx": obs_w["_obs_idx"].to_numpy(), - "sgRNA": obs_w["sgRNA"].astype(str).to_numpy(), - "perturbation": obs_w["perturbation"].astype(str).to_numpy(), - "label_str": obs_w["perturbation"].astype(str).to_numpy(), - "segmentation_id": pd.array([pd.NA] * len(obs_w), dtype="Int64"), - "correction_status": "position_unresolved", - }) - - -def _correct_well(obs_w, frz_p: Path, cur_p: Path): - """Two-step join for one well's obs rows. - - Step 1: position --> frozen CSV --> segmentation_id. - Step 2: segmentation_id --> current CSV --> corrected sgRNA + gene_name. - """ - import numpy as np - import pandas as pd - - cols = ["segmentation_id", "x_pheno", "y_pheno", "gene_name", "sgRNA"] - frz = pd.read_csv(frz_p, usecols=cols, low_memory=False) - cur = pd.read_csv(cur_p, usecols=cols, low_memory=False) - frz = frz.dropna(subset=["segmentation_id", "x_pheno", "y_pheno"]) - cur = cur.dropna(subset=["segmentation_id"]) - frz["segmentation_id"] = frz["segmentation_id"].astype("int64") - cur["segmentation_id"] = cur["segmentation_id"].astype("int64") - - # Step 1: position --> segmentation_id via frozen CSV. 
- frz["_x"] = frz["x_pheno"].astype(float).round(POS_ROUND) - frz["_y"] = frz["y_pheno"].astype(float).round(POS_ROUND) - obs_w = obs_w.assign( - _x=obs_w["x_position"].astype(float).round(POS_ROUND), - _y=obs_w["y_position"].astype(float).round(POS_ROUND), - ) - # Dedup the frozen position table (extremely rare collisions at sub-pixel - # rounding; keep first). - frz_lookup = (frz[["_x", "_y", "segmentation_id"]] - .drop_duplicates(subset=["_x", "_y"], keep="first")) - - s1 = obs_w.merge(frz_lookup, on=["_x", "_y"], how="left") - - # Step 2: segmentation_id --> corrected labels via current CSV. ISS can - # write multiple rows per segmentation_id (one per barcode read), some - # with NaN gene_name where the call failed; sort so a real gene_name - # wins over a NaN one for the same cell. - cur_sorted = cur.sort_values( - ["segmentation_id", "gene_name"], na_position="last" - ) - cur_lookup = cur_sorted.drop_duplicates( - subset="segmentation_id", keep="first" - )[["segmentation_id", "gene_name", "sgRNA"]] - # NaN gene_name --> "NTC" (matches data_loader.get_labels' fillna). NaN - # sgRNA --> "" (matches what a fresh extraction would record). - cur_lookup = cur_lookup.assign( - _cur_gene_name=cur_lookup["gene_name"].fillna("NTC").astype(str), - _cur_sgRNA=cur_lookup["sgRNA"].fillna("").astype(str), - )[["segmentation_id", "_cur_gene_name", "_cur_sgRNA"]] - - s2 = s1.merge(cur_lookup, on="segmentation_id", how="left") - - # Classify status. A row counts as "matched" whenever the cell's - # segmentation_id is present in the current ISS calls, even if the new - # call is no-barcode-found ("NTC"). True orphans are seg_ids the ISS - # re-run removed entirely. - has_seg = s2["segmentation_id"].notna().to_numpy() - has_cur = s2["_cur_gene_name"].notna().to_numpy() - status = np.where( - ~has_seg, "position_unresolved", - np.where(~has_cur, "orphan_in_h5ad", "matched"), - ) - - # For matched rows, use current labels; otherwise keep stale. 
- cur_gene = s2["_cur_gene_name"].fillna("").astype(str).to_numpy() - cur_sgRNA = s2["_cur_sgRNA"].fillna("").astype(str).to_numpy() - orig_pert = s2["perturbation"].astype(str).to_numpy() - orig_sgRNA = s2["sgRNA"].astype(str).to_numpy() - matched_mask = (status == "matched") - - new_pert = np.where(matched_mask, cur_gene, orig_pert) - new_sgRNA = np.where(matched_mask, cur_sgRNA, orig_sgRNA) - - return pd.DataFrame({ - "obs_idx": s2["_obs_idx"].to_numpy(), - "sgRNA": new_sgRNA, - "perturbation": new_pert, - "label_str": new_pert, - "segmentation_id": s2["segmentation_id"].astype("Int64"), - "correction_status": status, - }) - - -# ---------------------------------------------------------------------------- -# Stale-experiment discovery (content-based, no file-mtime heuristic) -# ---------------------------------------------------------------------------- - -def discover_stale_experiments(base: Path = BASE, - feature_dir: str = FEATURE_DIR_NAME) -> list[str]: - """Return experiments where the 3-assembly link CSVs differ in content - from the frozen ``models/link_csvs//`` snapshot.""" - import pandas as pd - - import hashlib - exps = sorted({p.parent.parent.parent.name - for p in base.glob(f"ops0*/3-assembly/{feature_dir}/{ANNDATA_SUBDIR}")}) - - def _gene_col_hash(p: Path) -> str: - col = pd.read_csv(p, usecols=["gene_name"], low_memory=False)["gene_name"] - return hashlib.md5(col.fillna("").astype(str).str.cat().encode()).hexdigest() - - stale = [] - for exp in exps: - for w in ("A1", "A2", "A3"): - cur_p = base / exp / "3-assembly" / f"{w}_linked_pheno_iss.csv" - frz_p = base / "models" / "link_csvs" / exp / f"{w}_linked_pheno_iss.csv" - if not (cur_p.exists() and frz_p.exists()): - continue - n_cur = sum(1 for _ in open(cur_p)) - 1 - n_frz = sum(1 for _ in open(frz_p)) - 1 - if n_cur != n_frz: - stale.append(exp); break - # Same row count -> hash the full gene_name column to catch drift - # that's spread throughout the file (200-row samples miss this). 
- if _gene_col_hash(cur_p) != _gene_col_hash(frz_p): - stale.append(exp); break - return stale - - -# ---------------------------------------------------------------------------- -# CLI / SLURM orchestration -# ---------------------------------------------------------------------------- - -def _run_slurm(exps: list[str], overwrite: bool) -> dict: - from ops_utils.hpc.slurm_batch_utils import submit_parallel_jobs - - jobs = [] - for exp in exps: - jobs.append({ - "name": f"iss_fix_{exp.split('_')[0]}", - "func": _process_experiment, - "kwargs": {"exp": exp, "overwrite": overwrite}, - }) - return submit_parallel_jobs( - jobs_to_submit=jobs, - experiment="iss_drift_fix", - slurm_params={ - "timeout_min": 30, - "mem": "16GB", - "cpus_per_task": 2, - "slurm_partition": "cpu", - }, - log_dir="iss_drift_fix", - manifest_prefix="iss_fix", - wait_for_completion=True, - ) - - -def main(): - logging.basicConfig(level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s") - ap = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter) - ap.add_argument("--exp", help="Single experiment to process") - ap.add_argument("--exps", nargs="+", help="Explicit list of experiments") - ap.add_argument("--slurm", action="store_true", - help="Fan out across experiments via SLURM") - ap.add_argument("--overwrite", action="store_true", - help="Re-build sidecars even if they already exist") - ap.add_argument("--feature-dir", default=FEATURE_DIR_NAME, - help=f"Feature directory under 3-assembly/ " - f"(default: {FEATURE_DIR_NAME})") - args = ap.parse_args() - - if args.exp: - r = _process_experiment(args.exp, feature_dir=args.feature_dir, - overwrite=args.overwrite) - for c in r.get("channels", []): - logger.info(f" {args.exp} {c}") - return - - if args.exps: - exps = args.exps - else: - logger.info("Discovering stale experiments (content-based)…") - exps = discover_stale_experiments(feature_dir=args.feature_dir) - logger.info(f" found 
{len(exps)} stale experiments") - - if args.slurm: - result = _run_slurm(exps, overwrite=args.overwrite) - failed = result.get("failed", []) - logger.info(f"SLURM done: {len(failed)} failed") - if failed: - logger.warning(f" failed jobs: {failed}") - else: - for exp in exps: - logger.info(f"--- {exp} ---") - r = _process_experiment(exp, feature_dir=args.feature_dir, - overwrite=args.overwrite) - for c in r.get("channels", []): - logger.info(f" {c}") - - -if __name__ == "__main__": - main() From 4efbe68f4989ed33b99558f64314dc31529017e2 Mon Sep 17 00:00:00 2001 From: Alexander Hillsley Date: Wed, 24 Jun 2026 09:18:54 -0700 Subject: [PATCH 11/11] Remove deprecated post_process/map mAP re-export wrapper post_process/map/ was only a backward-compat shim re-exporting the mAP functions from ops_utils.analysis (map_scores, map_umap). Nothing in ops_model imports it anymore, so move it to deprecated/ (kept locally only) and drop it from the tree. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/ops_model/post_process/map/__init__.py | 0 src/ops_model/post_process/map/map.py | 22 ---------------------- 2 files changed, 22 deletions(-) delete mode 100644 src/ops_model/post_process/map/__init__.py delete mode 100644 src/ops_model/post_process/map/map.py diff --git a/src/ops_model/post_process/map/__init__.py b/src/ops_model/post_process/map/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/ops_model/post_process/map/map.py b/src/ops_model/post_process/map/map.py deleted file mode 100644 index 9129111..0000000 --- a/src/ops_model/post_process/map/map.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Phenotypic Activity Assessment using copairs mAP (mean Average Precision). - -Backward-compatible re-export wrapper. All implementations now live in -``ops_utils.analysis`` for cross-pipeline reuse. 
-""" - -# Re-export all shared mAP functions from ops_utils (single source of truth) -from ops_utils.analysis.map_scores import ( # noqa: F401 - _compute_single_complex_map, - adata_to_copairs_df, - compute_auc_score, - compute_threshold_sweep_auc, - phenotypic_activity_assesment, - phenotypic_distinctivness, - phenotypic_consistency_corum, - phenotypic_consistency_manual_annotation, - map_main, -) - -# Re-export UMAP visualization from ops_utils -from ops_utils.analysis.map_umap import metric_umap # noqa: F401