diff --git a/CLAUDE.md b/CLAUDE.md index e63604c1..086502c1 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -47,7 +47,7 @@ Single entry point: `protspace = protspace.cli.app:app` | `protspace project` | HDF5 → dimensionality reduction | | `protspace annotate` | Fetch protein annotations | | `protspace bundle` | Combine projections + annotations → .parquetbundle | -| `protspace stats` | Compute projection quality statistics (cluster-validity + faithfulness) | +| `protspace stats` | Compute projection quality statistics (annotation-based cluster-validity + faithfulness) | | `protspace serve` | Launch Dash web frontend | | `protspace style` | Add annotation colors/styles | @@ -67,17 +67,20 @@ protspace prepare -i -m -o [options] # Parameter sweep: protspace prepare -i emb.h5 -m "umap2:n_neighbors=15" -m "umap2:n_neighbors=50" -m pca2 -o output # Inline params: protspace prepare -i emb.h5 -m "pca2,umap2:n_neighbors=50;min_dist=0.3" -o output # Quality stats (opt-in): protspace prepare -i emb.h5 -m pca2,umap2 --stats -o output +# Quality stats scoped to specific annotations: protspace prepare -i emb.h5 -m pca2 --stats --stats-annotation major_group,ec_number -o output ``` ### protspace stats Usage -Compute per-projection quality statistics for an existing project directory (also available inline via `prepare --stats`). Cluster-validity → `statistics.parquet` (bundle 5th part) + per-protein `cluster_elbow_*` / `cluster_silhouette_*` membership columns (each value a `cluster N` label with the per-point silhouette attached as `|score`) + auto legend styles; faithfulness (local kNN + global metrics, tagged `scope`) → each projection's `info_json.quality`. `--cluster-selection elbow|silhouette|both` picks the K-selection method(s). +Compute per-projection quality statistics for an existing project directory (also available inline via `prepare --stats`). Validity is **annotation-based**: silhouette/DBI/CH are scored on a user-selected annotation's own category labels (not auto-clustering), computed once for the source embedding and again for each projection — `statistics.parquet` (bundle 5th part) gains an `annotation` column and `space_kind ∈ {embedding, projection}`. `--stats-annotation auto|name1,name2` (default `auto`) picks which annotation column(s) to score (all "suitable" low-cardinality categoricals, or an explicit list); requires `-a/--annotations`. Auto-clustering (KMeans elbow/silhouette) is retained for the per-protein `cluster_elbow_*` / `cluster_silhouette_*` membership columns (each value a `cluster N` label with the per-point silhouette attached as `|score`) + auto legend styles, but is no longer self-scored — instead its **ARI**/**NMI** agreement against each scored annotation is recorded (`stat_family=cluster_agreement`). Faithfulness (local kNN + global metrics, tagged `scope`) → each projection's `info_json.quality`. `--cluster-selection elbow|silhouette|both` picks the K-selection method(s). ```bash -# Standalone (embeddings needed for faithfulness) +# Standalone (embeddings needed for faithfulness + the once-per-embedding annotation-validity pass) protspace stats -i emb.h5 -p project_dir -o statistics.parquet -# Enrich annotations in place + emit cluster legend styles for `bundle --settings` +# Enrich annotations in place, score annotation-based validity, + emit cluster legend styles for `bundle --settings` protspace stats -i emb.h5 -p project_dir -o statistics.parquet -a annotations.parquet --settings-out styles.json +# Score only specific annotations instead of every suitable categorical (default: auto) +protspace stats -i emb.h5 -p project_dir -o statistics.parquet -a annotations.parquet --stats-annotation major_group,ec_number # Elbow + silhouette-optimal clusterings side by side protspace stats -i emb.h5 -p project_dir -o statistics.parquet -a annotations.parquet --cluster-selection both # Fold a stats parquet + settings into a bundle @@ -149,12 +152,14 @@ src/protspace/ ├── stats/ # Projection quality statistics (opt-in, --stats) │ ├── __init__.py # Lazy STATISTICS registry + compute_statistics entry │ ├── base.py # StatContext / StatRow / AnnotationColumn / StatsReport -│ ├── driver.py # Per-projection contexts, embedding id-join, run stats +│ ├── driver.py # Per-projection contexts + once-per-embedding pass, embedding id-join, run stats │ ├── carriage.py # Route rows to bundle parts (metadata / annotations / legend) +│ ├── annotation_select.py # Pick "suitable" annotations (auto/list) + build id→category labels │ ├── cluster/kmeans_elbow.py # KMeans + distance-to-chord elbow (subsampled at scale) │ └── metrics/ -│ ├── validity.py # silhouette / Davies-Bouldin / Calinski-Harabasz -│ └── faithfulness.py # kNN-overlap / trustworthiness / continuity +│ ├── validity.py # Auto-cluster (KMeans) + ARI/NMI agreement vs annotations +│ ├── annotation_validity.py # silhouette / Davies-Bouldin / Calinski-Harabasz per annotation +│ └── faithfulness.py # kNN-overlap / trustworthiness / continuity ├── utils/ │ ├── __init__.py # Lazy exports: REDUCERS dict, reducer constants │ ├── constants.py # DimensionReductionConfig, method name constants @@ -216,7 +221,7 @@ HDF5 file (float16 embeddings) 2. `projections_metadata` — projection names, dimensions, parameters (faithfulness rides in `info_json.quality` when `--stats`) 3. `projections_data` — reduced coordinates per protein per projection 4. `settings` (optional) — annotation styles, pinned values, display config -5. `statistics` (optional) — tidy per-projection cluster-validity table (`protspace stats` / `prepare --stats`) +5. `statistics` (optional) — tidy table of annotation-based validity (silhouette/DBI/CH per annotation, `space_kind ∈ {embedding, projection}`, `annotation` column) + auto-cluster ARI/NMI agreement (`stat_family=cluster_agreement`) (`protspace stats` / `prepare --stats`) Positional layout is `core(3) + settings? + statistics?`. When statistics are present but settings are absent, the settings slot is written as **zero bytes** so statistics stay at position five (readers branch on emptiness, not part count). Both bundled and separate-file (`--no-bundled`) output persist `settings.parquet` and `statistics.parquet` when present. @@ -240,10 +245,12 @@ uv run pytest tests/ --cov=src/protspace # With coverage | `test_settings_converter.py` | 31 | Settings table ↔ visualization state conversion | | `test_uniprot_annotation_retriever.py` | 24 | UniProt API mocking, inactive entry resolution | | `test_pipeline_utils.py` | 70 | ReductionPipeline, EmbeddingSet, method parsing, multi-input merging, inline param overrides | -| `test_stats.py` | 43 | Projection statistics: elbow, cluster-validity, faithfulness (dual continuity + global metrics), cluster-selection (elbow/silhouette/both), subsample determinism/order-invariance, silhouette consistency | -| `test_stats_cli.py` | 12 | `protspace stats` CLI + `prepare` stats wiring, `--settings-out` guard, `--cluster-selection` validation | +| `test_stats.py` | 48 | Projection statistics: elbow, annotation-based validity (silhouette/DBI/CH per annotation), auto-cluster ARI/NMI agreement, faithfulness (dual continuity + global metrics), cluster-selection (elbow/silhouette/both), subsample determinism/order-invariance, silhouette consistency | +| `test_stats_cli.py` | 16 | `protspace stats` CLI + `prepare` stats wiring, `--stats-annotation` (auto/list) wiring, `--settings-out` guard, `--cluster-selection` validation | | `test_stats_carriage.py` | 10 | Routing rows to bundle parts (metadata quality, annotation columns, cluster legend) | | `test_stats_bundle.py` | 7 | Optional 5th (statistics) bundle part round-trip | +| `test_annotation_select.py` | 6 | Annotation selection: suitability filter (cardinality/numeric/id-like exclusion), `auto` vs explicit-list label building (explicit names bypass the heuristic), missing-value dropping | +| `test_annotation_validity.py` | 6 | `AnnotationValidityStatistic`: silhouette/DBI/CH scored per annotation on `ctx.coords`, embedding vs. projection `space_kind`, missing-value exclusion, single-category no-op, id-canonical subsample determinism | | `test_biocentral_embedder.py` | 23 | Biocentral API client, embedding flow | | `test_fasta.py` | 17 | FASTA parsing, edge cases, CSV annotation loading | | `test_biocentral_retriever.py` | 14 | Biocentral prediction retriever (TMbed parsing, per-sequence) | diff --git a/README.md b/README.md index badf4409..3e62ab0f 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ ProtSpace is a visualization tool for exploring **protein embeddings** or **simi - **Multiple projections**: PCA, UMAP, t-SNE, MDS, PaCMAP, LocalMAP - **Automatic annotations**: UniProt, InterPro, and Taxonomy -- **Quality metrics** _(opt-in)_: per-projection cluster-validity + faithfulness (local & global) via `--stats` +- **Quality metrics** _(opt-in)_: annotation-based cluster-validity + faithfulness (local & global) via `--stats` - **Structure viewer**: Integrated protein structure visualization - **Export**: PNG, PDF, SVG, HTML @@ -65,7 +65,7 @@ protspace stats -i embeddings/prot_t5.h5 -p projections/ -o statistics.parquet protspace bundle -p projections/ -a annotations.parquet -s statistics.parquet -o output.parquetbundle ``` -Or compute quality metrics inline during `prepare` with `--stats` (opt-in): cluster-validity + faithfulness per projection. See the [CLI Reference](docs/cli.md#projection-statistics---stats). +Or compute quality metrics inline during `prepare` with `--stats` (opt-in): annotation-based cluster-validity + faithfulness per projection. See the [CLI Reference](docs/cli.md#projection-statistics---stats). ## 📊 Example Output diff --git a/docs/cli.md b/docs/cli.md index e35a08a1..39c9c551 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -130,8 +130,9 @@ This produces three projections: `ProtT5 — PCA 2`, `ProtT5 — UMAP 2 (n=15)`, | ---- | ----------- | ------- | | `-o, --output` | Output directory. | `.` | | `--bundled / --no-bundled` | Bundle into single `.parquetbundle`. | bundled | -| `--stats / --no-stats` | Compute projection quality statistics (cluster-validity + faithfulness). See [Projection Statistics](#projection-statistics---stats). | off | +| `--stats / --no-stats` | Compute projection quality statistics (annotation-based cluster-validity + faithfulness). See [Projection Statistics](#projection-statistics---stats). | off | | `--cluster-selection` | With `--stats`, how to choose the cluster count K: `elbow`, `silhouette`, or `both`. | `elbow` | +| `--stats-annotation` | With `--stats`, which annotation column(s) to score for cluster-validity: `auto` (all suitable low-cardinality categoricals) or a comma-separated list. | `auto` | | `--keep-tmp` | Cache intermediates for resumability. | on | | `--no-log` | Skip writing `run.log`. | off | | `--dump-cache` | Print cached annotations and exit. | off | @@ -179,17 +180,21 @@ protspace bundle -p projections/ -a annotations.parquet \ ## `protspace stats` -Compute per-projection quality statistics for an existing project directory and write them as a `statistics.parquet` (the optional 5th `.parquetbundle` part). No annotations are required. See [Projection Statistics](#projection-statistics---stats) for what is computed. +Compute per-projection quality statistics for an existing project directory and write them as a `statistics.parquet` (the optional 5th `.parquetbundle` part). Faithfulness and the auto-cluster membership columns need no annotations; annotation-based validity (and its ARI/NMI agreement with the auto-clusters) needs `-a/--annotations`. See [Projection Statistics](#projection-statistics---stats) for what is computed. ```bash # Statistics for a project (embeddings needed for faithfulness) protspace stats -i embeddings/prot_t5.h5 -p projections/ -o statistics.parquet # Also enrich an annotations parquet in place with per-protein cluster-membership -# columns, and write the auto cluster-legend styles for `bundle` +# columns, score annotation-based validity, and write the auto cluster-legend styles protspace stats -i embeddings/prot_t5.h5 -p projections/ -o statistics.parquet \ -a annotations.parquet --settings-out cluster_styles.json +# Score only specific annotations instead of every suitable categorical (default: auto) +protspace stats -i embeddings/prot_t5.h5 -p projections/ -o statistics.parquet \ + -a annotations.parquet --stats-annotation major_group,ec_number + # Emit both the elbow and the silhouette-optimal clustering protspace stats -i embeddings/prot_t5.h5 -p projections/ -o statistics.parquet \ -a annotations.parquet --cluster-selection both @@ -197,28 +202,31 @@ protspace stats -i embeddings/prot_t5.h5 -p projections/ -o statistics.parquet \ | Flag | Description | Default | | ---- | ----------- | ------- | -| `-i, --input` | HDF5 embedding file(s) (for faithfulness). Repeat for multi-embedding; `-i file.h5:name` to override the name. | — | +| `-i, --input` | HDF5 embedding file(s) (for faithfulness + the once-per-embedding annotation-validity pass). Repeat for multi-embedding; `-i file.h5:name` to override the name. | — | | `-p, --projections` | Project directory with `projections_metadata.parquet` + `projections_data.parquet`. | — | | `-o, --output` | Output `statistics.parquet` path. | — | -| `-a, --annotations` | Annotations parquet to enrich in place with per-protein `cluster_*` membership columns (per-point silhouette attached as `value|score`). | — | +| `-a, --annotations` | Annotations parquet to enrich in place with per-protein `cluster_*` membership columns (per-point silhouette attached as `value|score`), and to score for annotation-based validity + ARI/NMI agreement. | — | | `--cluster-selection` | Cluster count K selection: `elbow`, `silhouette`, or `both`. | `elbow` | +| `--stats-annotation` | Which annotation column(s) to score for cluster-validity: `auto` (all suitable low-cardinality categoricals) or a comma-separated list. Requires `-a`. | `auto` | | `--settings-out` | Write auto cluster-legend styles here (JSON) for `bundle --settings`. Requires `-a`. | — | | `--metric` | High-dim distance metric for faithfulness when the projection metadata omits one (e.g. PCA/MDS). | `euclidean` | | `--seed` | Random seed. | `42` | ## Projection Statistics (`--stats`) -`prepare --stats` (opt-in) and the standalone `protspace stats` command compute two families of per-projection quality metrics and bake them into the output: +`prepare --stats` (opt-in) and the standalone `protspace stats` command compute three families of per-projection quality metrics and bake them into the output: -- **Cluster validity** — KMeans labels the projection, scored by **silhouette**, **Davies–Bouldin**, and **Calinski–Harabasz**, written to the tidy `statistics.parquet` (the bundle's 5th part). The cluster count K is chosen by the inertia **elbow** and/or by **max silhouette** — `--cluster-selection elbow|silhouette|both`. Each selection also becomes a per-protein membership column — `cluster_elbow_` and/or `cluster_silhouette_` — with the point's **silhouette attached to its value** as `cluster N|` (the same `value|score` convention as UniProt evidence codes / InterPro bit scores; suppressed by `--no-scores`). Membership columns get an auto Kelly-palette legend (the bundle's 4th settings part); in `statistics.parquet` the two selections are distinguished by `label_kind` (`kmeans_elbow` / `kmeans_silhouette`). +- **Annotation-based validity** — silhouette, Davies–Bouldin, and Calinski–Harabasz scored using an annotation's own category labels (not auto-clustering) — how well proteins already grouped by an annotation (e.g. `major_group`, `ec_number`) separate in a given space. Computed once for the source embedding (a separability "ceiling") and again for each projection, written to the tidy `statistics.parquet` (the bundle's 5th part) with `space_kind ∈ {embedding, projection}` and an `annotation` column naming which one was scored. `--stats-annotation auto|name1,name2` (default `auto`) picks which annotation column(s) to score — `auto` scores every "suitable" low-cardinality categorical (≥2 and ≤min(50, max(2, n/2)) distinct non-empty values, not numeric, and not a generated `cluster_*` column); requires `-a/--annotations`. +- **Auto-cluster agreement** — KMeans labels the projection; the cluster count K is chosen by the inertia **elbow** and/or by **max silhouette** — `--cluster-selection elbow|silhouette|both`. This auto-clustering is no longer scored against itself (that was circular); instead, when annotations are supplied, each labelling's **ARI** (adjusted Rand index) and **NMI** (normalized mutual information) agreement with every scored annotation is recorded (`stat_family=cluster_agreement`). Each selection also becomes a per-protein membership column — `cluster_elbow_` and/or `cluster_silhouette_` — with the point's **silhouette attached to its value** as `cluster N|` (the same `value|score` convention as UniProt evidence codes / InterPro bit scores; suppressed by `--no-scores`). Membership columns get an auto Kelly-palette legend (the bundle's 4th settings part); in `statistics.parquet` the two selections are distinguished by `label_kind` (`kmeans_elbow` / `kmeans_silhouette`). - **Faithfulness** — how well the projection preserves the source embedding's structure; each row is tagged `scope`: - **local** (kNN-neighbourhood): **kNN-overlap**, **trustworthiness**, **continuity**. - **global** (whole-layout): **random_triplet** (relative-ordering accuracy over random triplets, ∈[0,1]) and **spearman_distance** (rank correlation of all pairwise distances, ∈[−1,1]). - These per-projection scalars ride in each projection's `info_json.quality`. + These per-projection scalars ride in each projection's `info_json.quality` — they never land in `statistics.parquet`. Notes: -- Off by default — the compute (a KMeans sweep + faithfulness) and the extra bundle columns/styles are opt-in. +- Off by default — the compute (annotation-validity + a KMeans sweep + faithfulness) and the extra bundle columns/styles are opt-in. +- Annotation-based validity and cluster agreement need `-a/--annotations`; faithfulness and the membership columns do not. - Uses the projection's own high-dim metric (e.g. `cosine`) for faithfulness; falls back to `--metric` / `euclidean` when the reducer doesn't record one. - Best-effort: a failure for one statistic or projection is logged and skipped, never failing the run. At large scale the heavier metrics are subsampled (silhouette/faithfulness) or fit on a bounded subsample (KMeans elbow) with a deterministic seed. diff --git a/docs/superpowers/plans/2026-07-02-annotation-cluster-validity.md b/docs/superpowers/plans/2026-07-02-annotation-cluster-validity.md new file mode 100644 index 00000000..ede08673 --- /dev/null +++ b/docs/superpowers/plans/2026-07-02-annotation-cluster-validity.md @@ -0,0 +1,1171 @@ +# Annotation-based Cluster-Validity Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Compute cluster-validity (silhouette / Davies-Bouldin / Calinski-Harabasz) for user-selected **annotations** on both the source embedding and each projection, plus ARI/NMI agreement between the auto-clusters and each annotation — replacing the circular auto-KMeans self-validity. + +**Architecture:** A new `AnnotationValidityStatistic` scores each annotation's category labels on whatever space its `StatContext` carries (`ctx.coords`); the driver runs it on both a new once-per-embedding pass (`space_kind="embedding"`) and the existing per-projection pass. Agreement (ARI/NMI) is emitted from the existing `ClusterValidityStatistic`, reusing the KMeans labels it already computes (no second sweep). `statistics.parquet` gains an `annotation` dimension column. + +**Tech Stack:** Python ≥3.10, numpy, scikit-learn (function-local imports), pyarrow, pandas, Typer, pytest. + +## Global Constraints + +- Run all Python via `uv run` (e.g. `uv run pytest`, `uv run ruff check`). Never bare `python`. +- Lint clean: `uv run ruff check src/ tests/` (py310 target, 88-char lines). +- Terminology is **annotation**, never "feature", in all new code, help text, columns, and docs. +- scikit-learn imports stay **function-local** (keep CLI startup fast). +- Statistics are **best-effort**: a failure for one statistic/annotation/projection is logged and skipped, never raises. Wrap per-metric bodies in `try/except Exception` with `# noqa: BLE001`. +- Reuse the existing bounded-cost guards: subsample to `sample_threshold` (default `DEFAULT_SAMPLE_THRESHOLD = 5000`) with a deterministic seed; skip silhouette outside `2 ≤ k ≤ n-1`; skip DBI/CH when any cluster is a singleton. +- `--stats-annotation` only ever evaluated inside a stats run; default `auto`. +- Commit messages: use `feat(stats):` for the user-visible flag/behavior; `test(stats):`/`refactor(stats):`/`docs(stats):` for the rest. End every commit body with `Co-Authored-By: Claude Fable 5 `. + +--- + +### Task 1: Data model — `StatContext.annotations` + `StatRow.annotation` column + +**Files:** +- Modify: `src/protspace/stats/base.py` +- Test: `tests/test_stats.py` + +**Interfaces:** +- Produces: `StatContext(..., annotations: dict[str, dict[str, str]] | None = None)`; `StatRow(..., annotation: str = "")`; `STATS_SCHEMA` gains a 9th column `("annotation", pa.string())` positioned after `space_name`. + +- [ ] **Step 1: Write the failing test** + +Add to `tests/test_stats.py`: + +```python +def test_statrow_carries_annotation_column(): + from protspace.stats.base import STATS_SCHEMA, StatRow, StatsReport + + assert "annotation" in STATS_SCHEMA.names + row = StatRow( + space_kind="embedding", + space_name="prot_t5", + stat_family="annotation_validity", + label_kind="annotation", + metric="silhouette", + metric_kind="validity", + value=0.42, + annotation="major_group", + ) + rec = row.to_record() + assert rec["annotation"] == "major_group" + report = StatsReport() + report.add([row]) + tbl = report.to_arrow() + assert tbl.column("annotation").to_pylist() == ["major_group"] + + +def test_statcontext_defaults_annotations_none(): + from protspace.stats.base import StatContext + import numpy as np + + ctx = StatContext("projection", "P", coords=np.zeros((3, 2)), ids=["a", "b", "c"]) + assert ctx.annotations is None +``` + +- [ ] **Step 2: Run to verify it fails** + +Run: `uv run pytest tests/test_stats.py::test_statrow_carries_annotation_column tests/test_stats.py::test_statcontext_defaults_annotations_none -v` +Expected: FAIL (`annotation` not in schema / unexpected kwarg). + +- [ ] **Step 3: Implement** + +In `src/protspace/stats/base.py`: + +Update `STATS_SCHEMA` (add `annotation` after `space_name`) and its comment: + +```python +# The tidy schema. Rows are the bundle-boundary contract. Dimensions of the data +# (space, annotation, label kind, metric) are columns; per-row provenance +# (seeds, sample sizes, inertia lists) goes in ``extra_json``. +STATS_SCHEMA = pa.schema( + [ + ("space_kind", pa.string()), + ("space_name", pa.string()), + ("annotation", pa.string()), + ("stat_family", pa.string()), + ("label_kind", pa.string()), + ("metric", pa.string()), + ("metric_kind", pa.string()), + ("value", pa.float64()), + ("extra_json", pa.string()), + ] +) +``` + +Add `annotations` to `StatContext` (after `params` is fine; keep it keyword-friendly): + +```python + high_dim_metric: str = "euclidean" + params: dict = field(default_factory=dict) + # annotation name -> {protein id -> category label}. Present only when the + # caller requested annotation-based validity; id-keyed so lookup is + # order-independent for any space (embedding or projection). + annotations: dict[str, dict[str, str]] | None = None +``` + +Add `annotation` field to `StatRow` (after `space_name`) and emit it in `to_record`: + +```python + space_kind: str + space_name: str + annotation: str # "" for non-annotation rows; the annotation name otherwise + stat_family: str + label_kind: str + metric: str + metric_kind: str + value: float + extra: dict = field(default_factory=dict) + destination: str = "statistics_part" + + def to_record(self) -> dict: + return { + "space_kind": self.space_kind, + "space_name": self.space_name, + "annotation": self.annotation, + "stat_family": self.stat_family, + "label_kind": self.label_kind, + "metric": self.metric, + "metric_kind": self.metric_kind, + "value": float(self.value), + "extra_json": json.dumps(self.extra, sort_keys=True, default=_json_default), + } +``` + +> NOTE: `annotation` is a required positional field between `space_name` and `stat_family`. Every existing `StatRow(...)` call site (validity.py, faithfulness.py) is updated in later tasks; the test in this task uses the keyword form. + +- [ ] **Step 4: Run to verify it passes** + +Run: `uv run pytest tests/test_stats.py::test_statrow_carries_annotation_column tests/test_stats.py::test_statcontext_defaults_annotations_none -v` +Expected: PASS. + +- [ ] **Step 5: Commit** + +```bash +git add src/protspace/stats/base.py tests/test_stats.py +git commit -m "$(printf 'feat(stats): add annotation dimension to StatRow + StatContext\n\nCo-Authored-By: Claude Fable 5 ')" +``` + +--- + +### Task 2: Annotation selection + suitability filter + +**Files:** +- Create: `src/protspace/stats/annotation_select.py` +- Test: `tests/test_annotation_select.py` + +**Interfaces:** +- Produces: + - `suitable_annotations(frame, id_col: str = "identifier", max_card: int = 50) -> list[str]` + - `build_annotation_labels(frame, selection, id_col: str = "identifier") -> dict[str, dict[str, str]]` where `selection` is the string `"auto"` or a list of annotation names; returns `{name: {id: category}}`, dropping empty / `""` / NaN values. `frame` is a pandas DataFrame. + +- [ ] **Step 1: Write the failing test** + +Create `tests/test_annotation_select.py`: + +```python +import pandas as pd +import pytest + +from protspace.stats.annotation_select import ( + build_annotation_labels, + suitable_annotations, +) + + +def _frame(): + return pd.DataFrame( + { + "identifier": [f"p{i}" for i in range(6)], + "major_group": ["a", "a", "b", "b", "c", "c"], # suitable (3 cats) + "all_unique": [f"u{i}" for i in range(6)], # unsuitable (all unique) + "constant": ["x"] * 6, # unsuitable (1 cat) + "count": [1, 2, 3, 4, 5, 6], # unsuitable (numeric) + "cluster_elbow_P": ["cluster 0"] * 3 + ["cluster 1"] * 3, # excluded + } + ) + + +def test_suitable_annotations_filters(): + names = suitable_annotations(_frame()) + assert names == ["major_group"] + + +def test_build_labels_auto(): + labels = build_annotation_labels(_frame(), "auto") + assert set(labels) == {"major_group"} + assert labels["major_group"]["p0"] == "a" + assert len(labels["major_group"]) == 6 + + +def test_build_labels_explicit_names_and_missing_dropped(): + frame = _frame() + frame.loc[0, "major_group"] = "" # sentinel missing + frame.loc[1, "major_group"] = "" # empty + labels = build_annotation_labels(frame, ["major_group"]) + assert "p0" not in labels["major_group"] # dropped + assert "p1" not in labels["major_group"] # empty dropped + assert labels["major_group"]["p2"] == "b" + + +def test_build_labels_unknown_name_skipped(): + labels = build_annotation_labels(_frame(), ["does_not_exist"]) + assert labels == {} +``` + +- [ ] **Step 2: Run to verify it fails** + +Run: `uv run pytest tests/test_annotation_select.py -v` +Expected: FAIL (`ModuleNotFoundError: protspace.stats.annotation_select`). + +- [ ] **Step 3: Implement** + +Create `src/protspace/stats/annotation_select.py`: + +```python +"""Select which annotation columns to score, and materialise their labels. + +An annotation is "suitable" for cluster-validity when it is a low-cardinality +categorical column: at least 2 distinct non-empty values, at most +``min(max_card, n/2)`` (so it is not effectively an id), and not numeric. The +generated ``cluster_*`` membership columns and the id column are excluded. +""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +_MISSING = {"", "", "nan", "None"} + + +def _clean(series) -> list[str]: + """Non-missing string values of a column.""" + out = [] + for v in series.tolist(): + if v is None: + continue + s = str(v) + if s in _MISSING: + continue + out.append(s) + return out + + +def _is_numeric(series) -> bool: + vals = _clean(series) + if not vals: + return False + try: + for s in vals: + float(s) + return True + except ValueError: + return False + + +def suitable_annotations( + frame, id_col: str = "identifier", max_card: int = 50 +) -> list[str]: + n = len(frame) + cap = min(max_card, max(2, n // 2)) + names: list[str] = [] + for col in frame.columns: + if col == id_col or col.startswith("cluster_"): + continue + vals = _clean(frame[col]) + distinct = len(set(vals)) + if distinct < 2 or distinct > cap: + continue + if distinct == len(vals): # all-unique → id-like + continue + if _is_numeric(frame[col]): + continue + names.append(col) + return names + + +def build_annotation_labels( + frame, selection, id_col: str = "identifier" +) -> dict[str, dict[str, str]]: + """``{annotation name -> {protein id -> category}}`` for the selection. + + ``selection`` is the string ``"auto"`` (all suitable) or a list of column + names. Missing / sentinel values are dropped, so a protein absent from a + column's mapping simply has no category for it. + """ + if id_col not in getattr(frame, "columns", []): + return {} + if isinstance(selection, str) and selection.lower() == "auto": + names = suitable_annotations(frame, id_col=id_col) + else: + wanted = list(selection) + available = suitable_annotations(frame, id_col=id_col) + names = [] + for name in wanted: + if name in available: + names.append(name) + else: + logger.warning( + "--stats-annotation '%s' is missing or unsuitable; skipping", name + ) + labels: dict[str, dict[str, str]] = {} + ids = [str(i) for i in frame[id_col].tolist()] + for name in names: + col = frame[name].tolist() + mapping: dict[str, str] = {} + for pid, v in zip(ids, col, strict=False): + if v is None: + continue + s = str(v) + if s in _MISSING: + continue + mapping[pid] = s + if mapping: + labels[name] = mapping + return labels +``` + +- [ ] **Step 4: Run to verify it passes** + +Run: `uv run pytest tests/test_annotation_select.py -v` +Expected: PASS (4 tests). + +- [ ] **Step 5: Commit** + +```bash +git add src/protspace/stats/annotation_select.py tests/test_annotation_select.py +git commit -m "$(printf 'feat(stats): annotation selection + suitability filter\n\nCo-Authored-By: Claude Fable 5 ')" +``` + +--- + +### Task 3: `AnnotationValidityStatistic` + +**Files:** +- Create: `src/protspace/stats/metrics/annotation_validity.py` +- Test: `tests/test_annotation_validity.py` + +**Interfaces:** +- Consumes: `StatContext` with `coords`, `ids`, `annotations`, `params`, `rng_seed`. +- Produces: class `AnnotationValidityStatistic` with `family = "annotation_validity"`, `requires_embedding = False`, `embedding_space = True`; `compute(ctx) -> list[StatRow]` emitting `stat_family="annotation_validity"`, `label_kind="annotation"`, `metric ∈ {silhouette, davies_bouldin, calinski_harabasz}`, `metric_kind="validity"`, `annotation=`. + +- [ ] **Step 1: Write the failing test** + +Create `tests/test_annotation_validity.py`: + +```python +import numpy as np + +from protspace.stats.base import StatContext, StatRow +from protspace.stats.metrics.annotation_validity import AnnotationValidityStatistic + + +def _blobs(n=200, centers=4, dim=2, seed=1): + from sklearn.datasets import make_blobs + + X, y = make_blobs(n_samples=n, centers=centers, n_features=dim, random_state=seed) + return X, y + + +def test_scores_each_annotation_on_ctx_coords(): + X, y = _blobs(n=200, centers=4, dim=2, seed=3) + ids = [f"p{i}" for i in range(200)] + ann = {"grp": {pid: f"g{int(c)}" for pid, c in zip(ids, y)}} + outs = AnnotationValidityStatistic().compute( + StatContext("projection", "PCA_2", coords=X, ids=ids, annotations=ann) + ) + by_metric = {r.metric: r for r in outs if isinstance(r, StatRow)} + assert {"silhouette", "davies_bouldin", "calinski_harabasz"} <= set(by_metric) + s = by_metric["silhouette"] + assert s.stat_family == "annotation_validity" + assert s.annotation == "grp" and s.label_kind == "annotation" + assert 0.4 < s.value <= 1.0 # well-separated blobs → high silhouette + + +def test_space_kind_is_taken_from_context(): + X, y = _blobs(n=120, centers=3, dim=8, seed=4) + ids = [f"p{i}" for i in range(120)] + ann = {"grp": {pid: f"g{int(c)}" for pid, c in zip(ids, y)}} + outs = AnnotationValidityStatistic().compute( + StatContext("embedding", "prot_t5", coords=X, ids=ids, annotations=ann) + ) + assert all(r.space_kind == "embedding" for r in outs) + assert all(r.space_name == "prot_t5" for r in outs) + + +def test_missing_annotation_values_excluded(): + X, y = _blobs(n=100, centers=2, dim=2, seed=5) + ids = [f"p{i}" for i in range(100)] + # Only half the proteins have a category → the rest are dropped from scoring. + ann = {"grp": {pid: f"g{int(c)}" for pid, c in list(zip(ids, y))[:50]}} + outs = AnnotationValidityStatistic().compute( + StatContext("projection", "P", coords=X, ids=ids, annotations=ann) + ) + sil = next(r for r in outs if r.metric == "silhouette") + assert sil.extra["n_labels"] == 50 + + +def test_single_category_annotation_emits_nothing(): + X, _ = _blobs(n=80, centers=1, dim=2, seed=6) + ids = [f"p{i}" for i in range(80)] + ann = {"grp": {pid: "only" for pid in ids}} # 1 category + outs = AnnotationValidityStatistic().compute( + StatContext("projection", "P", coords=X, ids=ids, annotations=ann) + ) + assert outs == [] + + +def test_no_annotations_returns_empty(): + X, _ = _blobs(n=50, centers=2, dim=2, seed=7) + outs = AnnotationValidityStatistic().compute( + StatContext("projection", "P", coords=X, ids=[f"p{i}" for i in range(50)]) + ) + assert outs == [] +``` + +- [ ] **Step 2: Run to verify it fails** + +Run: `uv run pytest tests/test_annotation_validity.py -v` +Expected: FAIL (module not found). + +- [ ] **Step 3: Implement** + +Create `src/protspace/stats/metrics/annotation_validity.py`: + +```python +"""Annotation-based cluster-validity: how well an annotation's categories +separate in a given space (embedding or projection). + +silhouette / Davies-Bouldin / Calinski-Harabasz are computed with the +annotation's category labels (not auto-KMeans labels), on ``ctx.coords`` — +the driver hands us the embedding for the once-per-embedding pass and the 2D +projection for the per-projection pass. scikit-learn imports are function-local. +""" + +from __future__ import annotations + +import numpy as np + +from protspace.stats.base import StatContext, StatRow + +DEFAULT_SAMPLE_THRESHOLD = 5000 + + +def _subsample(n: int, threshold: int, rng_seed: int): + """Deterministic sorted index subsample, or None when n <= threshold.""" + if n <= threshold: + return None + rng = np.random.default_rng(rng_seed) + return np.sort(rng.permutation(n)[:threshold]) + + +class AnnotationValidityStatistic: + """silhouette / DBI / CH of each annotation's categories on ``ctx.coords``.""" + + family = "annotation_validity" + requires_embedding = False + embedding_space = True # also run by the driver's once-per-embedding pass + + def compute(self, ctx: StatContext) -> list[StatRow]: + if not ctx.annotations: + return [] + from sklearn.metrics import ( + calinski_harabasz_score, + davies_bouldin_score, + silhouette_score, + ) + + X = np.asarray(ctx.coords, dtype=float) + threshold = int(ctx.params.get("sample_threshold", DEFAULT_SAMPLE_THRESHOLD)) + id_to_row = {pid: i for i, pid in enumerate(ctx.ids)} + rows: list[StatRow] = [] + + for name, mapping in ctx.annotations.items(): + # Rows of ctx.coords that have a category for this annotation. + row_idx: list[int] = [] + cats: list[str] = [] + for pid, cat in mapping.items(): + i = id_to_row.get(pid) + if i is not None: + row_idx.append(i) + cats.append(cat) + if len(row_idx) < 3: + continue + uniq = sorted(set(cats)) + if len(uniq) < 2: # need >= 2 categories + continue + cat_to_int = {c: j for j, c in enumerate(uniq)} + Xa = X[np.asarray(row_idx)] + labels = np.asarray([cat_to_int[c] for c in cats]) + + # Bound cost: shared deterministic subsample across all three metrics. + sub = _subsample(Xa.shape[0], threshold, ctx.rng_seed) + if sub is not None: + Xa, labels = Xa[sub], labels[sub] + n = Xa.shape[0] + _, counts = np.unique(labels, return_counts=True) + achieved = len(counts) + if achieved < 2: # a category vanished under subsampling + continue + has_singleton = bool((counts < 2).any()) + base = dict( + space_kind=ctx.space_kind, + space_name=ctx.space_name, + annotation=name, + stat_family=self.family, + label_kind="annotation", + ) + extra = { + "seed": ctx.rng_seed, + "n_labels": int(n), + "n_categories": int(achieved), + "sampled": sub is not None, + } + + if 2 <= achieved <= n - 1: + try: + rows.append( + StatRow( + metric="silhouette", + metric_kind="validity", + value=float(silhouette_score(Xa, labels)), + extra=extra, + **base, + ) + ) + except Exception: # noqa: BLE001 - best-effort + pass + if not has_singleton: + for metric_name, fn in ( + ("davies_bouldin", davies_bouldin_score), + ("calinski_harabasz", calinski_harabasz_score), + ): + try: + rows.append( + StatRow( + metric=metric_name, + metric_kind="validity", + value=float(fn(Xa, labels)), + extra=extra, + **base, + ) + ) + except Exception: # noqa: BLE001 - best-effort + pass + return rows +``` + +- [ ] **Step 4: Run to verify it passes** + +Run: `uv run pytest tests/test_annotation_validity.py -v` +Expected: PASS (5 tests). + +- [ ] **Step 5: Commit** + +```bash +git add src/protspace/stats/metrics/annotation_validity.py tests/test_annotation_validity.py +git commit -m "$(printf 'feat(stats): AnnotationValidityStatistic (silhouette/DBI/CH per annotation)\n\nCo-Authored-By: Claude Fable 5 ')" +``` + +--- + +### Task 4: Rework `ClusterValidityStatistic` — drop self-validity, add ARI/NMI agreement + +**Files:** +- Modify: `src/protspace/stats/metrics/validity.py` +- Test: `tests/test_stats.py` + +**Interfaces:** +- Consumes: `StatContext` with `coords`, `ids`, `annotations`, `params`. +- Produces: `ClusterValidityStatistic.compute` now emits, per auto-cluster labelling: an `n_clusters` **meta** `StatRow`, the membership `AnnotationColumn` (unchanged), and — for each annotation — `adjusted_rand` / `normalized_mutual_info` `StatRow`s with `stat_family="cluster_agreement"`, `metric_kind="agreement"`, `annotation=`, `label_kind=`. It NO LONGER emits `silhouette`/`davies_bouldin`/`calinski_harabasz` rows for the auto-clusters. + +- [ ] **Step 1: Write the failing test** + +Add to `tests/test_stats.py`: + +```python +def test_cluster_validity_emits_agreement_not_self_validity(): + from protspace.stats.base import AnnotationColumn, StatContext, StatRow + from protspace.stats.metrics.validity import ClusterValidityStatistic + + X, y = _blobs(n=200, centers=4, dim=2, seed=61) + ids = [f"p{i}" for i in range(200)] + ann = {"grp": {pid: f"g{int(c)}" for pid, c in zip(ids, y)}} + outs = ClusterValidityStatistic().compute( + StatContext("projection", "PCA_2", coords=X, ids=ids, annotations=ann) + ) + rows = [o for o in outs if isinstance(o, StatRow)] + metrics = {r.metric for r in rows} + # No self-validity rows anymore: + assert not ({"silhouette", "davies_bouldin", "calinski_harabasz"} & metrics) + # n_clusters meta kept: + assert "n_clusters" in metrics + # ARI/NMI agreement vs the annotation, tagged correctly: + agree = [r for r in rows if r.stat_family == "cluster_agreement"] + assert {r.metric for r in agree} == {"adjusted_rand", "normalized_mutual_info"} + assert all(r.annotation == "grp" and r.metric_kind == "agreement" for r in agree) + assert all(r.label_kind == "kmeans_elbow" for r in agree) + # Auto-clusters recover well-separated blobs → high agreement. + ari = next(r for r in agree if r.metric == "adjusted_rand") + assert ari.value > 0.5 + # Membership column still emitted. + assert any(isinstance(o, AnnotationColumn) for o in outs) + + +def test_cluster_validity_no_annotations_still_emits_membership(): + from protspace.stats.base import AnnotationColumn, StatContext, StatRow + from protspace.stats.metrics.validity import ClusterValidityStatistic + + X, _ = _blobs(n=150, centers=3, dim=2, seed=62) + ids = [f"p{i}" for i in range(150)] + outs = ClusterValidityStatistic().compute( + StatContext("projection", "P", coords=X, ids=ids) + ) + assert any(isinstance(o, AnnotationColumn) for o in outs) + assert not [r for r in outs if isinstance(r, StatRow) and r.stat_family == "cluster_agreement"] +``` + +Also update the existing `test_aggregate_silhouette_equals_per_point_mean` and any test asserting a `silhouette`/`davies_bouldin`/`calinski_harabasz` `StatRow` from `ClusterValidityStatistic`: those aggregate rows are removed. Change them to assert on the membership column's attached per-point silhouette instead (the per-point confidence is retained), or delete the now-invalid assertion. Grep first: `uv run grep -rn "davies_bouldin\|calinski_harabasz\|metric == \"silhouette\"" tests/test_stats.py`. + +- [ ] **Step 2: Run to verify it fails** + +Run: `uv run pytest tests/test_stats.py::test_cluster_validity_emits_agreement_not_self_validity -v` +Expected: FAIL (agreement rows absent; self-validity rows still present). + +- [ ] **Step 3: Implement** + +In `src/protspace/stats/metrics/validity.py`, inside `_emit_labeling` (the per-labelling method): + +1. **Delete** the silhouette aggregate block (the `if silhouette_ok:` block that appends the `metric="silhouette"` `StatRow`) and the Davies-Bouldin / Calinski-Harabasz loop (`if not has_singleton:` block that appends those two rows). Keep the `per_point_samples` computation (it still feeds the membership column's attached `|silhouette`) and the `n_clusters` meta row. + +2. **Add** every existing `StatRow(...)` construction the new required positional `annotation=""` — the `n_clusters` meta row becomes: + +```python + rows: list = [ + StatRow( + space_kind=ctx.space_kind, + space_name=ctx.space_name, + annotation="", + stat_family=self.family, + label_kind=label_kind, + metric="n_clusters", + metric_kind="meta", + value=float(achieved), + extra=meta_extra, + ) + ] +``` + +(Replace the old `**base` spread with explicit kwargs, or add `annotation=""` to `base`.) + +3. **Add** the agreement block after the membership column is appended, still inside `_emit_labeling`: + +```python + # ARI/NMI: does this auto-clustering recover each annotation? Reuses the + # KMeans labels already computed (no second sweep). Compared over the + # id-intersection of clustered points and annotated points. + if ctx.annotations: + from sklearn.metrics import ( + adjusted_rand_score, + normalized_mutual_info_score, + ) + + label_by_id = dict(zip(ctx.ids, labels, strict=False)) + for name, mapping in ctx.annotations.items(): + paired_clu: list[int] = [] + paired_ann: list[str] = [] + for pid, cat in mapping.items(): + lbl = label_by_id.get(pid) + if lbl is not None: + paired_clu.append(int(lbl)) + paired_ann.append(cat) + if len(set(paired_ann)) < 2 or len(paired_ann) < 3: + continue + for metric_name, fn in ( + ("adjusted_rand", adjusted_rand_score), + ("normalized_mutual_info", normalized_mutual_info_score), + ): + try: + rows.append( + StatRow( + space_kind=ctx.space_kind, + space_name=ctx.space_name, + annotation=name, + stat_family="cluster_agreement", + label_kind=label_kind, + metric=metric_name, + metric_kind="agreement", + value=float(fn(paired_ann, paired_clu)), + extra={"seed": rng_seed, "n_labels": len(paired_ann)}, + ) + ) + except Exception: # noqa: BLE001 - best-effort + pass +``` + +4. Update the module docstring: it now emits auto-clustering (membership + `n_clusters`) and annotation **agreement**, not self-validity. + +- [ ] **Step 4: Run to verify it passes** + +Run: `uv run pytest tests/test_stats.py -k "cluster_validity or agreement or membership" -v` +Expected: PASS. + +- [ ] **Step 5: Commit** + +```bash +git add src/protspace/stats/metrics/validity.py tests/test_stats.py +git commit -m "$(printf 'refactor(stats): drop auto-cluster self-validity, add ARI/NMI agreement\n\nCo-Authored-By: Claude Fable 5 ')" +``` + +--- + +### Task 5: Driver — thread annotations + once-per-embedding pass + registry + +**Files:** +- Modify: `src/protspace/stats/driver.py` +- Modify: `src/protspace/stats/__init__.py` +- Modify: `src/protspace/stats/metrics/faithfulness.py` (add `annotation=""` to its `StatRow(...)` calls) +- Test: `tests/test_stats.py` + +**Interfaces:** +- Consumes: `AnnotationValidityStatistic` (Task 3), reworked `ClusterValidityStatistic` (Task 4). +- Produces: `compute_statistics(..., annotations: dict[str, dict[str, str]] | None = None)`. Projection contexts carry `annotations`; a new per-embedding loop builds `StatContext(space_kind="embedding", space_name=, coords=, ids=, annotations=...)` and runs only statistics with `embedding_space = True`. + +- [ ] **Step 1: Write the failing test** + +Add to `tests/test_stats.py`: + +```python +def test_driver_emits_embedding_and_projection_annotation_validity(): + from protspace.stats import compute_statistics + from sklearn.decomposition import PCA + + X, y = _blobs(n=180, centers=4, dim=8, seed=71) + coords = PCA(n_components=2, random_state=0).fit_transform(X) + headers = [f"p{i}" for i in range(180)] + ann = {"grp": {pid: f"g{int(c)}" for pid, c in zip(headers, y)}} + + class _Emb: + name = "e" + data = X + headers = headers + precomputed = False + + report = compute_statistics( + [_Emb()], + [{"name": "e — PCA 2", "data": coords, "ids": headers, "source": "e"}], + annotations=ann, + ) + av = [r for r in report.rows if r.stat_family == "annotation_validity"] + kinds = {(r.space_kind, r.annotation) for r in av} + assert ("embedding", "grp") in kinds # once-per-embedding pass + assert ("projection", "grp") in kinds # per-projection pass + # embedding is computed exactly once per (embedding, annotation, metric) + emb_sil = [r for r in av if r.space_kind == "embedding" and r.metric == "silhouette"] + assert len(emb_sil) == 1 +``` + +- [ ] **Step 2: Run to verify it fails** + +Run: `uv run pytest tests/test_stats.py::test_driver_emits_embedding_and_projection_annotation_validity -v` +Expected: FAIL (`compute_statistics` has no `annotations` param / no embedding rows). + +- [ ] **Step 3: Implement** + +In `src/protspace/stats/__init__.py`, register the new statistic: + +```python + from protspace.stats.metrics.annotation_validity import ( + AnnotationValidityStatistic, + ) + from protspace.stats.metrics.faithfulness import FaithfulnessStatistic + from protspace.stats.metrics.validity import ClusterValidityStatistic + + _STATISTICS = [ + ClusterValidityStatistic(), + AnnotationValidityStatistic(), + FaithfulnessStatistic(), + ] +``` + +In `src/protspace/stats/driver.py`: + +Add the `annotations` parameter and thread it into the projection `StatContext` (add `annotations=annotations` to the `StatContext(...)` call), then append the embedding pass just before `return report`: + +```python +def compute_statistics( + embedding_sets: list, + reductions: list[dict], + *, + rng_seed: int = 42, + params: dict | None = None, + statistics: list | None = None, + default_metric: str = "euclidean", + annotations: dict | None = None, +) -> StatsReport: +``` + +Projection `StatContext(...)` gains: + +```python + params=params, + annotations=annotations, + ) +``` + +After the projection loop, before `return report`: + +```python + # Once-per-embedding pass: annotation-validity on the source embedding itself + # (the true-separability "ceiling"), computed once per embedding rather than + # repeated for every projection that shares it. Only statistics that opt in + # via ``embedding_space`` run here. + if annotations: + emb_stats = [s for s in stats if getattr(s, "embedding_space", False)] + for es in embedding_sets: + if getattr(es, "precomputed", False): + continue + try: + ectx = StatContext( + space_kind="embedding", + space_name=es.name, + coords=np.asarray(es.data, dtype=float), + ids=list(es.headers), + rng_seed=rng_seed, + params=params or {}, + annotations=annotations, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("embedding-stats setup failed for '%s': %s", es.name, exc) + continue + for stat in emb_stats: + try: + report.add(stat.compute(ectx)) + except Exception as exc: # noqa: BLE001 - statistics are secondary + logger.warning( + "statistic %s failed for embedding '%s': %s", + getattr(stat, "family", stat), + es.name, + exc, + ) + + return report +``` + +In `src/protspace/stats/metrics/faithfulness.py`, add `annotation=""` to each `StatRow(...)` construction (the skip row and the per-metric rows) so they satisfy the new required field. + +- [ ] **Step 4: Run to verify it passes** + +Run: `uv run pytest tests/test_stats.py::test_driver_emits_embedding_and_projection_annotation_validity -v` +Then the whole stats suite: `uv run pytest tests/test_stats.py -q` +Expected: PASS. + +- [ ] **Step 5: Commit** + +```bash +git add src/protspace/stats/driver.py src/protspace/stats/__init__.py src/protspace/stats/metrics/faithfulness.py tests/test_stats.py +git commit -m "$(printf 'feat(stats): driver runs annotation-validity on embedding + projections\n\nCo-Authored-By: Claude Fable 5 ')" +``` + +--- + +### Task 6: `stats` CLI — `--stats-annotation` + +**Files:** +- Modify: `src/protspace/cli/stats.py` +- Test: `tests/test_stats_cli.py` + +**Interfaces:** +- Consumes: `build_annotation_labels` (Task 2), `compute_statistics(..., annotations=...)` (Task 5). +- Produces: `stats --stats-annotation "auto"|"a,b"` reads the `-a` parquet into a frame, builds labels, passes them to `compute_statistics`; `statistics.parquet` then contains `annotation_validity` + `cluster_agreement` rows. + +- [ ] **Step 1: Write the failing test** + +Add to `tests/test_stats_cli.py` (reuse the existing `_project_dir` helper that builds an h5 + projections; it must also write an annotations parquet with a categorical column — extend it or build inline): + +```python +def test_stats_command_computes_annotation_validity(tmp_path): + import pyarrow as pa + import pyarrow.parquet as pq + from typer.testing import CliRunner + + from protspace.cli.app import app + + h5_path, proj, ids = _project_dir(tmp_path) # returns (h5, proj_dir, id list) + ann_path = tmp_path / "annotations.parquet" + # A separable categorical annotation over the same ids. + groups = ["a" if i % 2 else "b" for i in range(len(ids))] + pq.write_table( + pa.table({"identifier": ids, "major_group": groups}), str(ann_path) + ) + out = tmp_path / "statistics.parquet" + result = CliRunner().invoke( + app, + ["stats", "-i", f"{h5_path}:E", "-p", str(proj), "-o", str(out), + "-a", str(ann_path), "--stats-annotation", "auto"], + ) + assert result.exit_code == 0, result.output + st = pq.read_table(str(out)).to_pandas() + assert "annotation" in st.columns + av = st[st.stat_family == "annotation_validity"] + assert set(av["annotation"]) == {"major_group"} + assert {"embedding", "projection"} <= set(av["space_kind"]) + + +def test_stats_rejects_no_annotation_source_for_stats_annotation(tmp_path): + from typer.testing import CliRunner + from protspace.cli.app import app + + h5_path, proj, _ = _project_dir(tmp_path) + out = tmp_path / "statistics.parquet" + result = CliRunner().invoke( + app, + ["stats", "-i", f"{h5_path}:E", "-p", str(proj), "-o", str(out), + "--stats-annotation", "major_group"], # no -a + ) + # --stats-annotation without -a has nothing to score → clear error. + assert result.exit_code != 0 +``` + +> If `_project_dir` doesn't return the id list, update it to `return h5_path, proj, ids` and fix its existing callers in the same commit. + +- [ ] **Step 2: Run to verify it fails** + +Run: `uv run pytest tests/test_stats_cli.py::test_stats_command_computes_annotation_validity -v` +Expected: FAIL (no `--stats-annotation` option). + +- [ ] **Step 3: Implement** + +In `src/protspace/cli/stats.py`: + +Add the option to the `stats` signature (after `cluster_selection`): + +```python + stats_annotation: Annotated[ + str, + typer.Option( + "--stats-annotation", + help="Which annotation column(s) to score for cluster-validity: " + "'auto' (all suitable categoricals) or a comma-separated list. " + "Requires -a/--annotations.", + ), + ] = "auto", +``` + +Add validation next to the existing `--settings-out requires -a` guard: + +```python + if stats_annotation and annotations is None and stats_annotation != "auto": + raise typer.BadParameter("--stats-annotation requires -a/--annotations.") +``` + +Build labels and pass them in. Just before the `compute_statistics(...)` call: + +```python + import pyarrow.parquet as pq # already imported at function top + + annotation_labels = None + if annotations is not None: + ann_frame = pq.read_table(str(annotations)).to_pandas() + id_col = "identifier" if "identifier" in ann_frame.columns else ann_frame.columns[0] + selection = ( + "auto" + if stats_annotation.strip().lower() == "auto" + else [s.strip() for s in stats_annotation.split(",") if s.strip()] + ) + annotation_labels = build_annotation_labels(ann_frame, selection, id_col=id_col) + + report = compute_statistics( + embedding_sets, + reductions, + rng_seed=seed, + params=params, + default_metric=metric, + annotations=annotation_labels, + ) +``` + +Add the import at the top of the function's import block: + +```python + from protspace.stats.annotation_select import build_annotation_labels +``` + +- [ ] **Step 4: Run to verify it passes** + +Run: `uv run pytest tests/test_stats_cli.py -k "annotation" -v` +Expected: PASS. + +- [ ] **Step 5: Commit** + +```bash +git add src/protspace/cli/stats.py tests/test_stats_cli.py +git commit -m "$(printf 'feat(stats): stats --stats-annotation scores selected annotations\n\nCo-Authored-By: Claude Fable 5 ')" +``` + +--- + +### Task 7: `prepare` CLI + pipeline — `--stats-annotation` + +**Files:** +- Modify: `src/protspace/cli/prepare.py` +- Modify: `src/protspace/data/processors/pipeline.py` +- Test: `tests/test_stats_cli.py` + +**Interfaces:** +- Consumes: `PipelineConfig`, `build_annotation_labels`, `compute_statistics(..., annotations=...)`. +- Produces: `PipelineConfig.stats_annotation: str = "auto"`; `prepare --stats --stats-annotation ...` flows the selection into `_compute_statistics`, which builds labels from the `metadata` frame and passes them to `compute_statistics`. + +- [ ] **Step 1: Write the failing test** + +Add to `tests/test_stats_cli.py` an assertion in (or alongside) the existing `test_prepare_pipeline_compute_statistics` that, with `--stats` and an annotation CSV containing a categorical column, the resulting bundle's `statistics.parquet` has `stat_family == "annotation_validity"` rows. (Follow that test's existing construction of inputs; add `--stats-annotation auto` to the invocation and read the 5th bundle part via `read_statistics_from_bundle`.) + +```python +def test_prepare_stats_annotation_validity_in_bundle(tmp_path): + # ... build FASTA/h5 + a CSV annotation with a categorical 'grp' column, + # mirroring test_prepare_pipeline_compute_statistics ... + # invoke: prepare -i emb.h5 -a grp.csv -m pca2,umap2 --stats --stats-annotation auto -o out + # then: + from protspace.data.io.bundle import read_statistics_from_bundle + import pyarrow.parquet as pq, io + raw = read_statistics_from_bundle(bundle_path) + st = pq.read_table(io.BytesIO(raw)).to_pandas() + assert (st.stat_family == "annotation_validity").any() + assert "annotation" in st.columns +``` + +- [ ] **Step 2: Run to verify it fails** + +Run: `uv run pytest tests/test_stats_cli.py::test_prepare_stats_annotation_validity_in_bundle -v` +Expected: FAIL. + +- [ ] **Step 3: Implement** + +In `src/protspace/data/processors/pipeline.py`, add to `PipelineConfig`: + +```python + cluster_selection: str = "elbow" # elbow | silhouette | both (for --stats) + stats_annotation: str = "auto" # which annotation(s) to score (--stats) +``` + +In `_compute_statistics`, build labels from `metadata` and pass them in: + +```python + from protspace.stats.annotation_select import build_annotation_labels + + annotation_labels = None + if metadata is not None: + selection = ( + "auto" + if str(self.config.stats_annotation).strip().lower() == "auto" + else [ + s.strip() + for s in self.config.stats_annotation.split(",") + if s.strip() + ] + ) + annotation_labels = build_annotation_labels( + metadata, selection, id_col="identifier" + ) + + report = compute_statistics( + embedding_sets, + all_reductions, + rng_seed=self.config.reducer_params.random_state, + params={ + "cluster_selection": self.config.cluster_selection, + "include_scores": not self.config.no_scores, + }, + default_metric=self.config.reducer_params.metric, + annotations=annotation_labels, + ) +``` + +In `src/protspace/cli/prepare.py`, add the option type near `Opt_ClusterSelection`: + +```python +Opt_StatsAnnotation = Annotated[ + str, + typer.Option( + "--stats-annotation", + help="With --stats, which annotation column(s) to score: 'auto' (all " + "suitable categoricals) or a comma-separated list.", + rich_help_panel="Output", + ), +] +``` + +Add the parameter to `prepare(...)` after `cluster_selection`: + +```python + cluster_selection: Opt_ClusterSelection = ClusterSelection.elbow, + stats_annotation: Opt_StatsAnnotation = "auto", +``` + +Pass it into `PipelineConfig(...)`: + +```python + cluster_selection=cluster_selection.value, + stats_annotation=stats_annotation, +``` + +- [ ] **Step 4: Run to verify it passes** + +Run: `uv run pytest tests/test_stats_cli.py -k "annotation or compute_statistics" -v` +Expected: PASS. + +- [ ] **Step 5: Commit** + +```bash +git add src/protspace/cli/prepare.py src/protspace/data/processors/pipeline.py tests/test_stats_cli.py +git commit -m "$(printf 'feat(stats): prepare --stats-annotation flows selection into the pipeline\n\nCo-Authored-By: Claude Fable 5 ')" +``` + +--- + +### Task 8: Full-suite green, lint, and docs + +**Files:** +- Modify: `CLAUDE.md`, `../protspace/CLAUDE.md` (the package one at `src`-level), `docs/cli.md`, `README.md`, `notebooks/ProtSpace_Preparation.ipynb` +- Test: entire suite + +- [ ] **Step 1: Run the whole fast suite + lint** + +Run: +```bash +uv run pytest tests/ -m "not slow" -q +uv run ruff check src/ tests/ +``` +Expected: all pass. Fix any residual `StatRow(...)` call sites missing the new `annotation` field (grep: `uv run grep -rn "StatRow(" src/protspace | wc -l` and confirm each passes `annotation`). Fix any test that asserted the old auto-cluster self-validity rows. + +- [ ] **Step 2: Update docs (annotation terminology + new flag + schema)** + +Edit the `protspace stats` / `prepare` sections and Output Format in `protspace/CLAUDE.md`, `docs/cli.md`, `README.md`, and the notebook stats cell: +- Document `--stats-annotation` (`auto` | comma-list) on both `prepare` and `stats`. +- Describe cluster-validity as **annotation-based** (silhouette/DBI/CH per annotation on embedding + projection) + ARI/NMI agreement; note the auto-cluster membership columns are retained but no longer self-scored. +- Note `statistics.parquet` gains the `annotation` column and `space_kind ∈ {projection, embedding}`. +- Update the `test_stats*` counts in the test-file table. + +- [ ] **Step 3: Commit** + +```bash +git add -A +git commit -m "$(printf 'docs(stats): document annotation-based cluster-validity + --stats-annotation\n\nCo-Authored-By: Claude Fable 5 ')" +``` + +--- + +## Self-Review + +- **Spec coverage:** score space (embedding + projection) → Tasks 3+5; annotation selection + suitability → Task 2; `--stats-annotation` (auto/list, default auto, gated on --stats) → Tasks 6-7; ARI/NMI vs auto-clusters → Task 4; drop self-validity, keep membership + n_clusters → Task 4; schema `annotation` column + `space_kind=embedding` → Task 1; input dependencies (missing embedding/annotation skips only its part) → Tasks 3 (`requires_embedding=False`, empty-annotations guard) + 5 (embedding pass gated on `annotations`); docs/frontend note → Task 8 (frontend #296 + sample regeneration handled after merge, as the spec parks them). Gap/BIC out of scope → #64. Covered. +- **Placeholder scan:** Task 7's test references inputs "mirroring test_prepare_pipeline_compute_statistics" — the implementer copies that test's setup; acceptable since the exact fixture already exists in the file. All algorithmic code is complete. +- **Type consistency:** `build_annotation_labels(frame, selection, id_col)` and `compute_statistics(..., annotations=...)` and `StatContext(annotations=...)` and `StatRow(annotation=...)` are used identically across Tasks 1-7. `embedding_space` attribute set in Task 3, read in Task 5. `stat_family` values `annotation_validity` / `cluster_agreement` consistent between Tasks 3, 4, 6. + +## Post-implementation (NOT part of this plan — resume the parked work) +1. Regenerate the 3FTx sample bundle with the new stats. +2. Update protspace_web#296 spec + post the concise body (still parked). +3. Update/regenerate and re-decide the `feat/projection-statistics` → main merge. diff --git a/docs/superpowers/specs/2026-07-02-annotation-cluster-validity-design.md b/docs/superpowers/specs/2026-07-02-annotation-cluster-validity-design.md new file mode 100644 index 00000000..9258fea6 --- /dev/null +++ b/docs/superpowers/specs/2026-07-02-annotation-cluster-validity-design.md @@ -0,0 +1,106 @@ +# Annotation-based cluster-validity — design + +**Date:** 2026-07-02 +**Branch:** `feat/annotation-cluster-validity` (stacked on `feat/projection-statistics`) +**Refs:** #31 (parent feature request), #63 (extras, merged), #64 (deferred gap/BIC k-selection), protspace_web#296 (frontend spec) + +## Motivation + +The shipped cluster-validity metric diverges from what #31 asked for. Today `ClusterValidityStatistic` runs KMeans on each projection's 2D coordinates and computes `silhouette / davies_bouldin / calinski_harabasz` **on those auto-KMeans labels**. That answers "does this projection form clean KMeans blobs?" — and it is partly circular, since KMeans optimises the very compactness silhouette/CH reward. + +#31 instead asked to *"compute standard clustering quality scores for **any selected feature/annotation**"* and noted *"metrics should be computed on the **original high-dimensional embeddings**, not on UMAP/t-SNE projections which distort distances."* i.e. the scores should measure how well a **biological annotation** (e.g. `major_group`) separates — the answer users actually want. + +This rework makes cluster-validity **annotation-based**, while keeping the automated group-detection columns intact. + +## Terminology + +Use **annotation** throughout (never "feature"). An *annotation* is a per-protein categorical column such as `major_group`. (The only "feature" strings in the repo are legacy Dash help copy — out of scope.) + +## What changes / stays + +| Piece | Before | After | +|---|---|---| +| Validity scores (silhouette/DBI/CH) | on auto-KMeans labels of the 2D projection | on each **annotation's** category labels, on **both** the embedding and each projection | +| Auto-clustering: `cluster_elbow_*` / `cluster_silhouette_*` membership columns (+ per-point silhouette confidence + auto legend) | present | **kept unchanged** (group detection; frontend #296 already consumes them) | +| Agreement (ARI/NMI) | none | **new** — auto-clusters vs each annotation | +| Auto-cluster **self**-silhouette/DBI/CH aggregate rows | in `statistics.parquet` | **removed** (the circular metric). `n_clusters` meta row kept (documents detected K + inertia/knee for an elbow chart) | +| Gap statistic / BIC-AIC k-selection | — | **out of scope** → #64 | + +## What gets computed + +For each selected annotation `a` (per-protein category labels), dropping proteins whose value for `a` is missing/``: + +1. **Embedding-space validity** — `silhouette / davies_bouldin / calinski_harabasz` of `a` on the source embedding. One set per `(embedding × a)`. The true-separability "ceiling" (#31's Key note). Computed **once per embedding** (not repeated per projection). +2. **Projection-space validity** — the same three on each projection's 2D coords. One set per `(projection × a)`. How well each layout displays that separation. +3. **Agreement** — `adjusted_rand` + `normalized_mutual_info` between each auto-cluster labelling (`kmeans_elbow`, `kmeans_silhouette`) and `a`, per projection. "Did automated KMeans recover `a`?" (label-only; coordinate-independent.) + +All silhouette computations reuse the existing subsample (`sample_threshold`, default 5000) and hard-ceiling guards, so cost stays bounded at 570k scale. DBI/CH are `O(n·k)`. ARI/NMI are `O(n)`. + +**Input dependencies (best-effort — a missing input skips only what needs it):** +- Annotation-validity + agreement need **annotation data**: the metadata frame in `prepare`, `-a/--annotations` in the standalone `stats` command. Absent → skipped. +- **Embedding-space** validity needs the **embedding**: always present in `prepare`; `-i/--input` in `stats` (already required there for faithfulness). Absent → only projection-space validity is emitted. +- Projection-space validity + agreement need only projection coords + annotations. + +## Annotation selection (CLI) + +New option on **both** `prepare` and the standalone `stats` command: + +- `--stats-annotation major_group,sub_group` — score exactly those columns. +- `--stats-annotation auto` — score every **suitable** categorical annotation. +- **Default when `--stats` is active and the flag is omitted: `auto`.** Only ever evaluated inside a stats run; bounded by the silhouette guards. Users narrow to explicit names to cut compute. + +(Chosen over `--score-annotation`; `-a/--annotations` remains the annotation *source*, `--stats-annotation` selects *which columns to score*. Revisit at review if the overlap bothers.) + +**Suitable** annotation = categorical AND `2 ≤ n_distinct ≤ min(50, n/2)` AND not all-unique (excludes `identifier`) AND not numeric-valued (excludes `seq_start`, `number_cysteines`) AND not a generated `cluster_*` column. Unsuitable names passed explicitly are skipped with a logged warning (best-effort, never fail the run). + +## `statistics.parquet` schema change + +Additive — **one new column `annotation`**, and `space_kind` gains the value `embedding` (was always `projection`): + +| stat_family | space_kind | space_name | annotation | label_kind | metric | metric_kind | +|---|---|---|---|---|---|---| +| `annotation_validity` | `embedding` | `prot_t5` | `major_group` | `annotation` | `silhouette`/`davies_bouldin`/`calinski_harabasz` | `validity` | +| `annotation_validity` | `projection` | `ProtT5 — UMAP 2` | `major_group` | `annotation` | (same three) | `validity` | +| `cluster_agreement` | `projection` | `ProtT5 — UMAP 2` | `major_group` | `kmeans_elbow`/`kmeans_silhouette` | `adjusted_rand`/`normalized_mutual_info` | `agreement` | +| `cluster_validity` | `projection` | `ProtT5 — UMAP 2` | *(empty)* | `kmeans_elbow`/`kmeans_silhouette` | `n_clusters` | `meta` | + +`annotation` is empty for non-annotation rows. `extra_json` keeps per-metric provenance (`sampled`, `sample_size`, `seed`, `n_labels`, and for agreement the two label kinds compared). Readers branch on the `annotation` column / `space_kind`, not on column count. + +## Architecture + +- **`StatContext`** (`stats/base.py`): add `annotations: dict[str, dict[str, str]] | None` (annotation name → {protein id → category}, id-aligned). Cluster-validity path now also uses `embedding`/`embedding_coords` (today `requires_embedding=False`). +- **Driver** (`stats/driver.py`): accept `annotations` (a frame/dict) and thread it into each `StatContext`. Add a **once-per-embedding pass** that emits embedding-space annotation-validity (so it isn't recomputed for every projection sharing an embedding). Per-projection pass emits projection-space validity + agreement. +- **Statistics classes** (`stats/metrics/`): + - `ClusterValidityStatistic` — keep only the **auto-clustering + membership columns + `n_clusters` meta** (drop the self-validity aggregate rows). + - **new `AnnotationValidityStatistic`** — silhouette/DBI/CH of each annotation on the context's space (embedding or projection). + - **new `ClusterAgreementStatistic`** — ARI/NMI of each auto-cluster labelling vs each annotation (projection contexts only). + - All keep the best-effort/guard conventions (function-local sklearn imports, per-row try/except, singleton guards for DBI/CH). +- **Carriage** (`stats/carriage.py`): annotation-validity + agreement rows route to `statistics.parquet` (5th part) like existing validity rows. Membership columns + legend unchanged. +- **CLI** (`cli/prepare.py`, `cli/stats.py`): add `--stats-annotation`; validate (`auto` or known column names); pass selection + the annotation frame into the pipeline/driver. + +## Frontend (#296) + docs impact + +The schema gains `annotation` + `space_kind == embedding` + `cluster_agreement` rows, and drops the auto-cluster self-silhouette. After this lands: +- Regenerate the 3FTx sample bundle. +- Update the #296 spec (currently the concise draft is **not yet posted** — it stays parked). +- Update `CLAUDE.md` / `docs/cli.md` / README / notebook stats sections. + +**Merge of `feat/projection-statistics` → main and the #296 cleanup remain paused until this ships.** + +## Testing + +New/updated `tests/test_stats*.py` cases: +- Annotation validity computed on embedding vs projection (values differ; embedding is the ceiling on a synthetic separable annotation). +- `auto` suitability filter: skips numeric / all-unique / high-cardinality / `cluster_*`; keeps a valid categorical. +- ARI/NMI high when auto-clusters match a planted annotation, low when random. +- Missing-value handling: proteins with `` for an annotation are excluded from that annotation's scoring only. +- `--stats-annotation` CLI: `auto`, explicit names, unknown-name warning, on both `prepare` and `stats`. +- Removed: assertions on the old auto-cluster self-silhouette aggregate row. + +## Out of scope +- Gap statistic + BIC/AIC k-selection → #64. +- HDBSCAN/GMM auto-clustering models → future (KMeans only). +- Frontend rendering of the new rows → protspace_web#296. + +## Migration / compatibility +Bundle stays 5-part (`core(3) + settings? + statistics?`). The `statistics.parquet` gains a column and a `space_kind` value; both are additive and readers branch on content. No change to parts 1–4 beyond the already-shipped membership columns. diff --git a/notebooks/ProtSpace_Preparation.ipynb b/notebooks/ProtSpace_Preparation.ipynb index af715e54..0e779c05 100644 --- a/notebooks/ProtSpace_Preparation.ipynb +++ b/notebooks/ProtSpace_Preparation.ipynb @@ -310,24 +310,7 @@ "cell_type": "markdown", "id": "b7a15c0stats", "metadata": {}, - "source": [ - "## 📊 Quality statistics (optional)\n", - "\n", - "Gauge how well each projection preserves your data. The CLI bakes two metric families into the bundle:\n", - "\n", - "- **cluster-validity** — silhouette, Davies–Bouldin, Calinski–Harabasz on a KMeans clustering; choose the cluster count K by `elbow`, `silhouette`, or `both` (`--cluster-selection`).\n", - "- **faithfulness** — *local* neighbourhood preservation (kNN-overlap, trustworthiness, continuity) and *global* layout preservation (random_triplet, spearman_distance).\n", - "\n", - "```bash\n", - "# inline during prepare (opt-in)\n", - "protspace prepare -i embeddings.h5 -m pca2,umap2 --stats -o output/\n", - "\n", - "# or compute for an already-generated project directory\n", - "protspace stats -i embeddings.h5 -p output/tmp -o statistics.parquet\n", - "```\n", - "\n", - "This also adds an auto-colored per-protein `cluster_elbow_` membership column — with each point's silhouette confidence attached to its value — that you can explore directly in the viewer. See [the CLI docs](https://github.com/tsenoner/protspace/blob/main/docs/cli.md#projection-statistics---stats)." - ] + "source": "## 📊 Quality statistics (optional)\n\nGauge how well each projection preserves your data, and how well your annotations separate. The CLI bakes three metric families into the bundle:\n\n- **annotation-based validity** — silhouette, Davies–Bouldin, Calinski–Harabasz scored on your own annotation's categories (not auto-clustering), computed once for the embedding and again for each projection. `--stats-annotation` picks which annotation(s) to score: `auto` (every suitable low-cardinality categorical, the default) or a comma-separated list.\n- **auto-cluster agreement** — a KMeans clustering (choose K by `elbow`, `silhouette`, or `both` via `--cluster-selection`) compared against your annotations via ARI/NMI.\n- **faithfulness** — *local* neighbourhood preservation (kNN-overlap, trustworthiness, continuity) and *global* layout preservation (random_triplet, spearman_distance).\n\n```bash\n# inline during prepare (opt-in)\nprotspace prepare -i embeddings.h5 -m pca2,umap2 --stats -o output/\n\n# score only specific annotations instead of every suitable one\nprotspace prepare -i embeddings.h5 -m pca2,umap2 --stats --stats-annotation major_group -o output/\n\n# or compute for an already-generated project directory\nprotspace stats -i embeddings.h5 -p output/tmp -o statistics.parquet\n```\n\nThis also adds an auto-colored per-protein `cluster_elbow_` membership column — with each point's silhouette confidence attached to its value — that you can explore directly in the viewer. See [the CLI docs](https://github.com/tsenoner/protspace/blob/main/docs/cli.md#projection-statistics---stats)." }, { "cell_type": "code", diff --git a/src/protspace/cli/prepare.py b/src/protspace/cli/prepare.py index 78c9c35d..da925ecf 100644 --- a/src/protspace/cli/prepare.py +++ b/src/protspace/cli/prepare.py @@ -136,6 +136,15 @@ rich_help_panel="Output", ), ] +Opt_StatsAnnotation = Annotated[ + str, + typer.Option( + "--stats-annotation", + help="With --stats, which annotation column(s) to score: 'auto' (all " + "suitable categoricals) or a comma-separated list.", + rich_help_panel="Output", + ), +] REFETCH_STAGES = frozenset( { "query", @@ -313,6 +322,7 @@ def prepare( scores: Opt_Scores = True, stats: Opt_Stats = False, cluster_selection: Opt_ClusterSelection = ClusterSelection.elbow, + stats_annotation: Opt_StatsAnnotation = "auto", refetch: Opt_Refetch = None, # Output output: Opt_Output = Path("."), @@ -530,6 +540,7 @@ def prepare( no_scores=not scores, stats=stats, cluster_selection=cluster_selection.value, + stats_annotation=stats_annotation, refetch_stages=refetch_stages, annotations=annotation_list, intermediate_dir=cache_dir, diff --git a/src/protspace/cli/stats.py b/src/protspace/cli/stats.py index fce1e3fd..f633a5f1 100644 --- a/src/protspace/cli/stats.py +++ b/src/protspace/cli/stats.py @@ -2,8 +2,10 @@ Loads the embedding H5(s) (for faithfulness) and the projection coordinates from a project directory, computes the tidy statistics table, and writes it as a -parquet file — the optional fifth ``.parquetbundle`` part. No annotations are -needed. Best-effort: per-statistic failures are isolated by the driver. +parquet file — the optional fifth ``.parquetbundle`` part. Faithfulness and the +cluster-membership columns need no annotations; annotation-based validity and its +ARI/NMI agreement need ``-a/--annotations``. Best-effort: per-statistic failures +are isolated by the driver. """ import json @@ -19,6 +21,11 @@ logger = logging.getLogger(__name__) +def _resolve_id_col(frame) -> str: + """The identifier column: ``identifier`` if present, else the first column.""" + return "identifier" if "identifier" in frame.columns else frame.columns[0] + + def _atomic_write_table(table, path: Path) -> None: """Overwrite ``path`` with ``table`` atomically. @@ -146,14 +153,15 @@ def _merge_quality_into_metadata(meta_path: Path, quality_by_name: dict) -> None _atomic_write_table(table, meta_path) -def _merge_annotations_with_columns(ann_path: Path, report) -> int: +def _merge_annotations_with_columns(ann_path: Path, report, frame=None) -> int: """Merge the report's per-protein ``AnnotationColumn``s into ``ann_path``. Rewrites the annotations parquet in place with the computed ``cluster_*`` membership columns joined by identifier (each value a ``cluster N`` label with the per-point silhouette attached as ``|score``). Added columns are stringified (absent → empty) so they match the prepare path's all-string annotations and the - frontend's content-based type inference. Returns the number of columns added. + frontend's content-based type inference. ``frame`` reuses an already-loaded + DataFrame instead of re-reading ``ann_path``. Returns the number of columns added. """ import pyarrow as pa import pyarrow.parquet as pq @@ -162,9 +170,8 @@ def _merge_annotations_with_columns(ann_path: Path, report) -> int: if not report.annotation_columns or not ann_path.exists(): return 0 - df = pq.read_table(str(ann_path)).to_pandas() - id_col = "identifier" if "identifier" in df.columns else df.columns[0] - added = merge_annotation_columns(report, df, id_col=id_col) + df = frame if frame is not None else pq.read_table(str(ann_path)).to_pandas() + added = merge_annotation_columns(report, df, id_col=_resolve_id_col(df)) for name in added: df[name] = df[name].fillna("").astype(str) _atomic_write_table(pa.Table.from_pandas(df, preserve_index=False), ann_path) @@ -228,6 +235,15 @@ def stats( "(max-silhouette K), or 'both' (emit both clusterings).", ), ] = ClusterSelection.elbow, + stats_annotation: Annotated[ + str, + typer.Option( + "--stats-annotation", + help="Which annotation column(s) to score for cluster-validity: " + "'auto' (all suitable categoricals) or a comma-separated list. " + "Requires -a/--annotations.", + ), + ] = "auto", verbose: Annotated[ int, typer.Option("-v", "--verbose", count=True, help="Increase verbosity.") ] = 0, @@ -239,6 +255,12 @@ def stats( # columns, so --settings-out without -a would silently write nothing. if settings_out is not None and annotations is None: raise typer.BadParameter("--settings-out requires -a/--annotations.") + if ( + stats_annotation + and annotations is None + and stats_annotation.strip().lower() != "auto" + ): + raise typer.BadParameter("--stats-annotation requires -a/--annotations.") import pyarrow.parquet as pq @@ -246,6 +268,7 @@ def stats( from protspace.data.loaders import load_h5 from protspace.data.loaders.embedding_set import merge_same_name_sets from protspace.stats import compute_statistics + from protspace.stats.annotation_select import build_annotation_labels from protspace.stats.carriage import ( build_cluster_legend_settings, route_faithfulness_to_metadata, @@ -267,12 +290,22 @@ def stats( params = {"cluster_selection": cluster_selection.value} if annotations is None: params["cluster_annotations"] = False + + annotation_labels = None + ann_frame = None + if annotations is not None: + ann_frame = pq.read_table(str(annotations)).to_pandas() + annotation_labels = build_annotation_labels( + ann_frame, stats_annotation, id_col=_resolve_id_col(ann_frame) + ) + report = compute_statistics( embedding_sets, reductions, rng_seed=seed, params=params, default_metric=metric, + annotations=annotation_labels, ) # Route per-projection faithfulness into projections_metadata.info_json.quality @@ -289,7 +322,9 @@ def stats( n_cols = 0 if annotations is not None: - n_cols = _merge_annotations_with_columns(annotations, report) + # Reuse the frame already read for label-building — nothing has rewritten + # the annotations parquet since (only projections_metadata was touched). + n_cols = _merge_annotations_with_columns(annotations, report, frame=ann_frame) if settings_out is not None: cluster_settings = build_cluster_legend_settings(report) settings_out.parent.mkdir(parents=True, exist_ok=True) diff --git a/src/protspace/data/processors/pipeline.py b/src/protspace/data/processors/pipeline.py index c0d7ffa4..4a15521a 100644 --- a/src/protspace/data/processors/pipeline.py +++ b/src/protspace/data/processors/pipeline.py @@ -75,6 +75,7 @@ class PipelineConfig: no_scores: bool = False stats: bool = False cluster_selection: str = "elbow" # elbow | silhouette | both (for --stats) + stats_annotation: str = "auto" # which annotation(s) to score (for --stats) refetch_stages: frozenset[str] = field(default_factory=frozenset) annotations: list[str] | None = None intermediate_dir: Path | None = None @@ -728,6 +729,15 @@ def _compute_statistics( for red in all_reductions: red.setdefault("ids", all_headers) + + from protspace.stats.annotation_select import build_annotation_labels + + annotation_labels = None + if metadata is not None: + annotation_labels = build_annotation_labels( + metadata, self.config.stats_annotation, id_col="identifier" + ) + report = compute_statistics( embedding_sets, all_reductions, @@ -742,6 +752,7 @@ def _compute_statistics( # 'metric' from their params, so fall back to the run's metric # rather than silently assuming euclidean. default_metric=self.config.reducer_params.metric, + annotations=annotation_labels, ) route_faithfulness_to_metadata(report, all_reductions) if metadata is not None and report.annotation_columns: diff --git a/src/protspace/stats/__init__.py b/src/protspace/stats/__init__.py index 0c1e550a..13ba9e42 100644 --- a/src/protspace/stats/__init__.py +++ b/src/protspace/stats/__init__.py @@ -14,10 +14,17 @@ def get_statistics() -> list: """Return the registered Statistic instances (lazy-imported).""" global _STATISTICS if _STATISTICS is None: + from protspace.stats.metrics.annotation_validity import ( + AnnotationValidityStatistic, + ) from protspace.stats.metrics.faithfulness import FaithfulnessStatistic from protspace.stats.metrics.validity import ClusterValidityStatistic - _STATISTICS = [ClusterValidityStatistic(), FaithfulnessStatistic()] + _STATISTICS = [ + ClusterValidityStatistic(), + AnnotationValidityStatistic(), + FaithfulnessStatistic(), + ] return _STATISTICS diff --git a/src/protspace/stats/_sampling.py b/src/protspace/stats/_sampling.py new file mode 100644 index 00000000..1ebbfb14 --- /dev/null +++ b/src/protspace/stats/_sampling.py @@ -0,0 +1,33 @@ +"""Deterministic, id-canonical subsampling shared by the stats metrics. + +Bounding cost at 570k+ scale means every heavy metric subsamples. To keep the +scores *comparable across spaces* (a projection vs its source embedding) and +*reproducible* regardless of a parquet/h5 row ordering, the draw is seeded from +the id-set (not the raw seed) and taken over rows in canonical id order — two +inputs sharing an id-set then select the same proteins. +""" + +from __future__ import annotations + +import hashlib + +import numpy as np + + +def id_seed(rng_seed: int, ids: list[str]) -> int: + """Seed derived from ``(rng_seed, sorted ids)``. + + Paired with a canonical-id-order selection, two inputs with the same id-set + draw the same subset regardless of row order. + """ + digest = hashlib.sha256("|".join(sorted(map(str, ids))).encode()).hexdigest()[:8] + return (rng_seed * 2654435761 + int(digest, 16)) % (2**32) + + +def sorted_subsample(n: int, threshold: int, rng) -> np.ndarray | None: + """Sorted positional index subsample of size ``threshold``, or ``None`` when + ``n <= threshold``. Positional, so the caller must pass rows in canonical id + order for the draw to be id-canonical.""" + if n <= threshold: + return None + return np.sort(rng.permutation(n)[:threshold]) diff --git a/src/protspace/stats/annotation_select.py b/src/protspace/stats/annotation_select.py new file mode 100644 index 00000000..011c4a32 --- /dev/null +++ b/src/protspace/stats/annotation_select.py @@ -0,0 +1,170 @@ +"""Select which annotation columns to score, and materialise their labels. + +An annotation is "suitable" for cluster-validity when it is a low-cardinality +categorical column: at least 2 distinct non-empty values, at most +``min(max_card, n/2)`` (so it is not effectively an id), and not numeric. The +generated ``cluster_*`` membership columns and the id column are excluded. +""" + +from __future__ import annotations + +import logging + +from protspace.stats.base import CLUSTER_COLUMN_PREFIX + +logger = logging.getLogger(__name__) + +# Values treated as "missing" (dropped, never a category). Kept in sync with the +# codebase's canonical sentinels: ``core.constants.standardize_missing`` maps +# ``"", nan, none, null, NA, NaN`` → the display sentinel ``""``; pandas +# nullable dtypes stringify to ``""`` / ``"NaT"``. Missing any of these would +# score a phantom missing-value cluster and inflate a column's cardinality. +_MISSING = { + "", + "", + "", + "", + "nan", + "NaN", + "NaT", + "none", + "None", + "null", + "NA", +} + + +def _is_missing(value) -> bool: + """Whether a raw cell value is missing (None or a sentinel string).""" + return value is None or str(value) in _MISSING + + +def _clean(series) -> list[str]: + """Non-missing string values of a column.""" + return [str(v) for v in series.tolist() if not _is_missing(v)] + + +def _is_numeric(vals) -> bool: + """Whether every (already-cleaned) value parses as a float.""" + if not vals: + return False + try: + for s in vals: + float(s) + return True + except ValueError: + return False + + +def _is_suitable_column(series, cap: int) -> bool: + """A low-cardinality categorical: 2..cap distinct non-missing values, not + all-unique (id-like), and not numeric. + + Bails out as soon as the distinct count exceeds ``cap`` so a high-cardinality + free-text column doesn't grow a full 570k-value set before rejection. + """ + seen: set[str] = set() + total = 0 + for v in series.tolist(): + if _is_missing(v): + continue + total += 1 + seen.add(str(v)) + if len(seen) > cap: # too many categories → not a low-card categorical + return False + distinct = len(seen) + if distinct < 2 or distinct == total: # too few, or all-unique (id-like) + return False + return not _is_numeric(seen) + + +def suitable_annotations( + frame, id_col: str = "identifier", max_card: int = 50 +) -> list[str]: + n = len(frame) + cap = min(max_card, max(2, n // 2)) + return [ + col + for col in frame.columns + if col != id_col + and not col.startswith(CLUSTER_COLUMN_PREFIX) + and _is_suitable_column(frame[col], cap) + ] + + +def build_annotation_labels( + frame, selection, id_col: str = "identifier" +) -> dict[str, dict[str, str]]: + """``{annotation name -> {protein id -> category}}`` for the selection. + + ``selection`` is ``"auto"`` (all suitable), a comma-separated string of + column names (the raw ``--stats-annotation`` flag), or a list of names. + Missing / sentinel values are dropped, so a protein absent from a column's + mapping simply has no category for it. + """ + cols = list(getattr(frame, "columns", [])) + if id_col not in cols: + return {} + # ``wanted is None`` means "auto" (all suitable columns); otherwise it is the + # explicit list of requested names. Splitting the raw flag string here keeps + # every caller from re-implementing the "auto vs comma-list" parse. + if isinstance(selection, str): + stripped = selection.strip() + wanted = ( + None + if stripped.lower() == "auto" + else [s.strip() for s in stripped.split(",") if s.strip()] + ) + else: + wanted = [str(s).strip() for s in selection if str(s).strip()] + + if wanted is None: + names = suitable_annotations(frame, id_col=id_col) + else: + # Explicit names: honour the request. The suitability heuristic (cardinality + # cap, numeric/id-like exclusion) is a *discovery* filter for ``auto``, not + # an authorisation gate — a user who names ``ec_number`` wants it scored even + # though it is high-cardinality. Only require what the metric needs: the + # column exists and carries >= 2 distinct non-missing categories. + names = [] + for name in wanted: + if name == id_col or name not in cols: + logger.warning( + "--stats-annotation '%s' is not a column; skipping", name + ) + elif len({*_clean(frame[name])}) < 2: + logger.warning( + "--stats-annotation '%s' has fewer than 2 categories; skipping", + name, + ) + else: + names.append(name) + labels: dict[str, dict[str, str]] = {} + ids = [str(i) for i in frame[id_col].tolist()] + for name in names: + mapping = { + pid: str(v) + for pid, v in zip(ids, frame[name].tolist(), strict=False) + if not _is_missing(v) + } + if mapping: + labels[name] = mapping + return labels + + +def pair_by_id(mapping, lookup): + """Align an annotation ``{id: category}`` mapping to an id-keyed ``lookup``. + + Returns parallel lists ``(values, categories)`` over the ids present in both, + where ``values[i] == lookup[id]``. Ids missing from ``lookup`` (a point absent + from this space) are dropped — used by cluster-agreement to pair each auto + cluster label with its annotation category over the id-intersection. + """ + values: list = [] + categories: list = [] + for pid, cat in mapping.items(): + v = lookup.get(pid) + if v is not None: + values.append(v) + categories.append(cat) + return values, categories diff --git a/src/protspace/stats/base.py b/src/protspace/stats/base.py index bc8a6c1d..1fc8dd69 100644 --- a/src/protspace/stats/base.py +++ b/src/protspace/stats/base.py @@ -2,7 +2,7 @@ A ``Statistic`` describes a projection (and optionally its source embedding). It declares the inputs it needs and returns one or more ``StatRow`` records. The -tidy long-format table produced by ``StatsReport.to_arrow`` (eight columns) is +tidy long-format table produced by ``StatsReport.to_arrow`` (nine columns) is the bundle-boundary contract consumed downstream. Heavy imports (scikit-learn) live inside the metric/cluster modules, function- @@ -18,12 +18,22 @@ import numpy as np import pyarrow as pa -# The frozen eight-column schema. New scalar statistics add rows, never columns; -# any per-source attribute (e.g. an annotation column name) goes in ``extra_json``. +# Prefix of the generated per-protein cluster-membership columns +# (``cluster_elbow_`` / ``cluster_silhouette_``). Shared so +# ``annotation_select`` can exclude them from annotation scoring by the same +# contract that ``ClusterValidityStatistic`` names them by — if the two drift, +# the auto-clusters get scored as annotations again (the circular self-validity +# this design removed). +CLUSTER_COLUMN_PREFIX = "cluster_" + +# The tidy schema. Rows are the bundle-boundary contract. Dimensions of the data +# (space, annotation, label kind, metric) are columns; per-row provenance +# (seeds, sample sizes, inertia lists) goes in ``extra_json``. STATS_SCHEMA = pa.schema( [ ("space_kind", pa.string()), ("space_name", pa.string()), + ("annotation", pa.string()), ("stat_family", pa.string()), ("label_kind", pa.string()), ("metric", pa.string()), @@ -67,6 +77,10 @@ class StatContext: embedding_name: str | None = None high_dim_metric: str = "euclidean" params: dict = field(default_factory=dict) + # annotation name -> {protein id -> category label}. Present only when the + # caller requested annotation-based validity; id-keyed so lookup is + # order-independent for any space (embedding or projection). + annotations: dict[str, dict[str, str]] | None = None @dataclass @@ -74,14 +88,15 @@ class StatRow: """One statistic value. ``destination`` routes the row to a bundle part at carriage time: - ``statistics_part`` (the tidy 8-column table — the default, so every existing - construction is unchanged), ``projection_metadata`` (folded into a projection's - ``info_json``), or ``annotation`` (a per-protein column). It is carriage - metadata, not a tidy-table column, so ``to_record`` never emits it. + ``statistics_part`` (the tidy 9-column table — the default), ``projection_metadata`` + (folded into a projection's ``info_json``), or ``annotation`` (a per-protein + column). It is carriage metadata, not a tidy-table column, so ``to_record`` + never emits it. """ space_kind: str space_name: str + annotation: str # "" for non-annotation rows; the annotation name otherwise stat_family: str label_kind: str metric: str @@ -94,6 +109,7 @@ def to_record(self) -> dict: return { "space_kind": self.space_kind, "space_name": self.space_name, + "annotation": self.annotation, "stat_family": self.stat_family, "label_kind": self.label_kind, "metric": self.metric, @@ -159,10 +175,13 @@ class Statistic(Protocol): """A unit of computation over a projection space. ``requires_embedding`` lets the driver skip statistics when no source - embedding is available for a projection. + embedding is available for a projection. ``embedding_space`` opts a statistic + into the driver's once-per-embedding pass (scoring the source embedding, not + just each projection); defaults to ``False`` for projection-only statistics. """ family: str requires_embedding: bool + embedding_space: bool = False def compute(self, ctx: StatContext) -> list[StatRow]: ... diff --git a/src/protspace/stats/driver.py b/src/protspace/stats/driver.py index 252b1a7f..396f1128 100644 --- a/src/protspace/stats/driver.py +++ b/src/protspace/stats/driver.py @@ -18,6 +18,26 @@ logger = logging.getLogger(__name__) +def _run_stats( + report: StatsReport, ctx: StatContext, stats: list, *, kind: str +) -> None: + """Run each statistic on ``ctx``, isolating per-statistic failures. + + ``kind`` (``projection`` | ``embedding``) only tags the warning message. + """ + for stat in stats: + try: + report.add(stat.compute(ctx)) + except Exception as exc: # noqa: BLE001 - statistics are secondary + logger.warning( + "statistic %s failed for %s '%s': %s", + getattr(stat, "family", stat), + kind, + ctx.space_name, + exc, + ) + + def _select_embedding(reduction: dict, embedding_sets: list, emb_by_name: dict): """Pick the embedding set that produced this projection. @@ -82,6 +102,7 @@ def compute_statistics( params: dict | None = None, statistics: list | None = None, default_metric: str = "euclidean", + annotations: dict | None = None, ) -> StatsReport: """Compute statistics for each projection. @@ -95,6 +116,10 @@ def compute_statistics( ``max_fit_sample``, ``n_triplets_per_point``; ``cluster_selection`` (``elbow`` | ``silhouette`` | ``both``); ``cluster_annotations`` and ``include_scores`` (per-protein membership column + attached silhouette). + annotations: annotation name -> {protein id -> category label}. When + supplied, threaded into every projection's ``StatContext`` and also + drives a once-per-embedding pass (see below) so annotation-validity + statistics can score the source embedding as a separability ceiling. Returns: A ``StatsReport`` (may be partial/empty; never raises on a statistic error). @@ -147,6 +172,7 @@ def compute_statistics( embedding_name=embedding_name, high_dim_metric=high_dim_metric, params=params, + annotations=annotations, ) except Exception as exc: # noqa: BLE001 - one bad reduction must not sink the report logger.warning( @@ -154,17 +180,40 @@ def compute_statistics( ) continue - for stat in stats: - if getattr(stat, "requires_embedding", False) and ctx.embedding is None: + runnable = [ + s + for s in stats + if not getattr(s, "requires_embedding", False) or ctx.embedding is not None + ] + _run_stats(report, ctx, runnable, kind="projection") + + # Once-per-embedding pass: annotation-validity on the source embedding itself + # (the true-separability "ceiling"), computed once per embedding rather than + # repeated for every projection that shares it. Only statistics that opt in + # via ``embedding_space`` run here — skip the whole pass when none do so we + # don't build the (large) embedding context for nothing. + emb_stats = [s for s in stats if getattr(s, "embedding_space", False)] + if annotations and emb_stats: + for es in embedding_sets: + if getattr(es, "precomputed", False): continue try: - report.add(stat.compute(ctx)) - except Exception as exc: # noqa: BLE001 - statistics are secondary + ectx = StatContext( + space_kind="embedding", + space_name=es.name, + # Keep the embedding at its native dtype (float32); the scored + # statistic upcasts only its bounded subsample, not all 570k rows. + coords=np.asarray(es.data), + ids=list(es.headers), + rng_seed=rng_seed, + params=params, + annotations=annotations, + ) + except Exception as exc: # noqa: BLE001 logger.warning( - "statistic %s failed for projection '%s': %s", - getattr(stat, "family", stat), - ctx.space_name, - exc, + "embedding-stats setup failed for '%s': %s", es.name, exc ) + continue + _run_stats(report, ectx, emb_stats, kind="embedding") return report diff --git a/src/protspace/stats/metrics/annotation_validity.py b/src/protspace/stats/metrics/annotation_validity.py new file mode 100644 index 00000000..e6f677ce --- /dev/null +++ b/src/protspace/stats/metrics/annotation_validity.py @@ -0,0 +1,108 @@ +"""Annotation-based cluster-validity: how well an annotation's categories +separate in a given space (embedding or projection). + +silhouette / Davies-Bouldin / Calinski-Harabasz are computed with the +annotation's category labels (not auto-KMeans labels), on ``ctx.coords`` — +the driver hands us the embedding for the once-per-embedding pass and the 2D +projection for the per-projection pass. scikit-learn imports are function-local. +""" + +from __future__ import annotations + +import numpy as np + +from protspace.stats._sampling import id_seed, sorted_subsample +from protspace.stats.base import StatContext, StatRow + +DEFAULT_SAMPLE_THRESHOLD = 5000 + + +class AnnotationValidityStatistic: + """silhouette / DBI / CH of each annotation's categories on ``ctx.coords``.""" + + family = "annotation_validity" + requires_embedding = False + embedding_space = True # also run by the driver's once-per-embedding pass + + def compute(self, ctx: StatContext) -> list[StatRow]: + if not ctx.annotations: + return [] + from sklearn.metrics import ( + calinski_harabasz_score, + davies_bouldin_score, + silhouette_score, + ) + + coords = np.asarray(ctx.coords) + threshold = int(ctx.params.get("sample_threshold", DEFAULT_SAMPLE_THRESHOLD)) + id_to_row = {pid: i for i, pid in enumerate(ctx.ids)} + rows: list[StatRow] = [] + + for name, mapping in ctx.annotations.items(): + # Annotated points present in this space, in canonical id order — so the + # subsample below is reproducible and picks the *same* proteins across + # spaces (embedding vs projection) whenever the annotated id-set matches; + # otherwise the "separability ceiling" would compare two different draws. + present = sorted( + (pid, id_to_row[pid], cat) + for pid, cat in mapping.items() + if pid in id_to_row + ) + if len(present) < 3 or len({c for _, _, c in present}) < 2: + continue # need >= 3 points, >= 2 categories + + # Bound cost: subsample (id-seeded) BEFORE gathering + upcasting, so at + # 570k scale we materialise ~threshold float64 rows, not all of them + # (label integers are arbitrary, so renumbering post-subsample is + # metric-invariant). Shared across all three metrics. + rng = np.random.default_rng(id_seed(ctx.rng_seed, [p[0] for p in present])) + sub = sorted_subsample(len(present), threshold, rng) + if sub is not None: + present = [present[i] for i in sub] + row_idx = [r for _, r, _ in present] + cats = [c for _, _, c in present] + cat_to_int = {c: j for j, c in enumerate(sorted(set(cats)))} + Xa = np.asarray(coords[row_idx], dtype=float) + labels = np.asarray([cat_to_int[c] for c in cats]) + n = Xa.shape[0] + _, counts = np.unique(labels, return_counts=True) + achieved = len(counts) + if achieved < 2: # a category vanished under subsampling + continue + base = { + "space_kind": ctx.space_kind, + "space_name": ctx.space_name, + "annotation": name, + "stat_family": self.family, + "label_kind": "annotation", + } + extra = { + "seed": ctx.rng_seed, + "n_labels": int(n), + "n_categories": int(achieved), + "sampled": sub is not None, + } + + # silhouette needs 2 <= k <= n-1; DBI/CH are unstable with singletons. + candidates: list = [] + if 2 <= achieved <= n - 1: + candidates.append(("silhouette", silhouette_score)) + if not bool((counts < 2).any()): + candidates += [ + ("davies_bouldin", davies_bouldin_score), + ("calinski_harabasz", calinski_harabasz_score), + ] + for metric_name, fn in candidates: + try: + rows.append( + StatRow( + metric=metric_name, + metric_kind="validity", + value=float(fn(Xa, labels)), + extra=extra, + **base, + ) + ) + except Exception: # noqa: BLE001 - best-effort + pass + return rows diff --git a/src/protspace/stats/metrics/faithfulness.py b/src/protspace/stats/metrics/faithfulness.py index 953979d9..950ab615 100644 --- a/src/protspace/stats/metrics/faithfulness.py +++ b/src/protspace/stats/metrics/faithfulness.py @@ -32,10 +32,9 @@ from __future__ import annotations -import hashlib - import numpy as np +from protspace.stats._sampling import id_seed, sorted_subsample from protspace.stats.base import StatContext, StatRow DEFAULT_K = 15 @@ -44,15 +43,6 @@ DEFAULT_N_TRIPLETS_PER_POINT = 5 -def _subsample_seed(rng_seed: int, ids: list[str]) -> int: - """A seed derived from (rng_seed, sorted ids). Paired with a canonical-id-order - selection (see ``compute``), two inputs with the same id-set draw the same - id subset regardless of row order — keeping cross-projection scores comparable - and reproducible without relying on a shared row ordering.""" - digest = hashlib.sha256("|".join(sorted(ids)).encode()).hexdigest()[:8] - return (rng_seed * 2654435761 + int(digest, 16)) % (2**32) - - def _knn_overlap(embedding, coords, k: int, metric: str) -> float: from sklearn.neighbors import NearestNeighbors @@ -171,6 +161,7 @@ class FaithfulnessStatistic: family = "faithfulness" requires_embedding = True + embedding_space = False # projection-only (compares projection vs embedding) def compute(self, ctx: StatContext) -> list[StatRow]: from sklearn.manifold import trustworthiness @@ -193,6 +184,7 @@ def compute(self, ctx: StatContext) -> list[StatRow]: base = { "space_kind": ctx.space_kind, "space_name": ctx.space_name, + "annotation": "", "stat_family": self.family, "label_kind": "none", "metric_kind": "faithfulness", @@ -235,11 +227,11 @@ def compute(self, ctx: StatContext) -> list[StatRow]: hi_metric = ctx.high_dim_metric or "euclidean" sampled = False - if n > sample_threshold: - rng = np.random.default_rng(_subsample_seed(ctx.rng_seed, ids)) - # Rows are already in canonical id order, so a positional draw is itself - # id-canonical and thus row-order invariant. - idx = np.sort(rng.permutation(n)[:sample_threshold]) + # Rows are already in canonical id order, so a positional draw is itself + # id-canonical and thus row-order invariant. + rng = np.random.default_rng(id_seed(ctx.rng_seed, ids)) + idx = sorted_subsample(n, sample_threshold, rng) + if idx is not None: emb = emb[idx] coords = coords[idx] n = len(idx) diff --git a/src/protspace/stats/metrics/validity.py b/src/protspace/stats/metrics/validity.py index e2a12ac1..096d743a 100644 --- a/src/protspace/stats/metrics/validity.py +++ b/src/protspace/stats/metrics/validity.py @@ -1,11 +1,20 @@ -"""Cluster-validity statistics on projection coordinates. - -KMeans labels the projection; silhouette, Davies-Bouldin and Calinski-Harabasz -score that labelling. The K can be chosen by the inertia **elbow** and/or by -**max silhouette** (``ctx.params["cluster_selection"]`` = ``elbow`` | ``silhouette`` -| ``both``); each selection is emitted with its own ``label_kind`` -(``kmeans_elbow`` / ``kmeans_silhouette``). The chosen K is emitted as a -``metric_kind="meta"`` row so consumers can exclude it from validity aggregates. +"""Auto-clustering (KMeans) on projection coordinates + agreement with annotations. + +KMeans labels the projection. The K can be chosen by the inertia **elbow** +and/or by **max silhouette** (``ctx.params["cluster_selection"]`` = ``elbow`` | +``silhouette`` | ``both``); each selection is emitted with its own +``label_kind`` (``kmeans_elbow`` / ``kmeans_silhouette``). The chosen K is +emitted as a ``metric_kind="meta"`` row (``n_clusters``). + +This auto-clustering is no longer self-scored (no silhouette / Davies-Bouldin / +Calinski-Harabasz on the KMeans labels themselves — that was circular: KMeans +optimises inertia, then silhouette grades the KMeans result against itself). +Instead, when ``ctx.annotations`` are supplied, each auto-clustering is compared +against every annotation's category labels via **ARI** (``adjusted_rand``) and +**NMI** (``normalized_mutual_info``) — ``stat_family="cluster_agreement"``, +``metric_kind="agreement"`` — reusing the KMeans labels already computed (no +second sweep). Annotation-based *validity* (silhouette/DBI/CH scored on the +annotation's own categories) lives in ``AnnotationValidityStatistic``. Each labelling also becomes a per-protein ``cluster_*`` membership column whose per-point silhouette rides along as an attached ``value|score`` confidence — the @@ -22,7 +31,13 @@ import numpy as np -from protspace.stats.base import AnnotationColumn, StatContext, StatRow +from protspace.stats.annotation_select import pair_by_id +from protspace.stats.base import ( + CLUSTER_COLUMN_PREFIX, + AnnotationColumn, + StatContext, + StatRow, +) from protspace.stats.cluster.kmeans_elbow import kmeans_elbow DEFAULT_SAMPLE_THRESHOLD = 5000 @@ -34,20 +49,6 @@ DEFAULT_MAX_FIT_SAMPLE = 50_000 -def _silhouette(X, labels, *, rng_seed: int, sample_threshold: int): - from sklearn.metrics import silhouette_score - - n = len(labels) - if n > sample_threshold: - val = float( - silhouette_score( - X, labels, sample_size=sample_threshold, random_state=rng_seed - ) - ) - return val, {"sampled": True, "sample_size": int(sample_threshold)} - return float(silhouette_score(X, labels)), {"sampled": False, "sample_size": int(n)} - - class _Labeling(NamedTuple): """One K-selection's clustering: how it was chosen + its column and labels.""" @@ -59,10 +60,11 @@ class _Labeling(NamedTuple): class ClusterValidityStatistic: - """Elbow / silhouette K + silhouette / Davies-Bouldin / Calinski-Harabasz.""" + """Elbow / silhouette auto-clustering + ARI/NMI agreement vs annotations.""" family = "cluster_validity" requires_embedding = False + embedding_space = False # projection-only (auto-clustering + agreement) def compute(self, ctx: StatContext) -> list: X = np.asarray(ctx.coords, dtype=float) @@ -95,7 +97,7 @@ def compute(self, ctx: StatContext) -> list: labelings.append( _Labeling( "kmeans_elbow", - f"cluster_elbow_{ctx.space_name}", + f"{CLUSTER_COLUMN_PREFIX}elbow_{ctx.space_name}", "elbow", res.k, res.labels, @@ -105,7 +107,7 @@ def compute(self, ctx: StatContext) -> list: labelings.append( _Labeling( "kmeans_silhouette", - f"cluster_silhouette_{ctx.space_name}", + f"{CLUSTER_COLUMN_PREFIX}silhouette_{ctx.space_name}", "silhouette", int(res.silhouette_k), res.silhouette_labels, @@ -119,11 +121,8 @@ def compute(self, ctx: StatContext) -> list: def _emit_labeling(self, ctx, X, n, res, labeling: _Labeling) -> list: """Rows + membership column for one labelling (elbow or silhouette-K).""" - from sklearn.metrics import calinski_harabasz_score, davies_bouldin_score - rng_seed = ctx.rng_seed params = ctx.params - sample_threshold = int(params.get("sample_threshold", DEFAULT_SAMPLE_THRESHOLD)) label_kind = labeling.label_kind col_name = labeling.col_name selection_name = labeling.selection_name @@ -131,23 +130,13 @@ def _emit_labeling(self, ctx, X, n, res, labeling: _Labeling) -> list: k = int(labeling.requested_k) # Report the ACHIEVED number of distinct clusters (KMeans can collapse on - # coincident points), keeping the requested K in extra. The cluster sizes - # also feed the Davies-Bouldin / Calinski-Harabasz singleton guard. - unique_labels, label_counts = np.unique(labels, return_counts=True) + # coincident points), keeping the requested K in extra. + unique_labels = np.unique(labels) achieved = int(len(unique_labels)) - has_singleton = bool((label_counts < 2).any()) - - base = { - "space_kind": ctx.space_kind, - "space_name": ctx.space_name, - "stat_family": self.family, - "label_kind": label_kind, - } # Decide up front whether the exact per-point silhouette will be computed. - # When it is, its mean is the exact aggregate silhouette (== unsampled - # silhouette_score) AND its per-point values ride along on the membership - # column, so nothing is computed twice. + # Its per-point values ride along on the membership column below as a + # `|score` confidence; no aggregate self-validity row is emitted for it. silhouette_ok = 2 <= k <= n - 1 hard_ceiling = int( params.get("silhouette_hard_ceiling", DEFAULT_SILHOUETTE_HARD_CEILING) @@ -178,57 +167,18 @@ def _emit_labeling(self, ctx, X, n, res, labeling: _Labeling) -> list: rows: list = [ StatRow( + space_kind=ctx.space_kind, + space_name=ctx.space_name, + annotation="", + stat_family=self.family, + label_kind=label_kind, metric="n_clusters", metric_kind="meta", value=float(achieved), extra=meta_extra, - **base, ) ] - # silhouette needs 2 <= k <= n - 1 - if silhouette_ok: - try: - if per_point_samples is not None: - # Exact aggregate over all n, consistent with the per-point - # values attached to the membership column below. - sil = float(per_point_samples.mean()) - sx = {"sampled": False, "sample_size": int(n)} - else: - sil, sx = _silhouette( - X, labels, rng_seed=rng_seed, sample_threshold=sample_threshold - ) - rows.append( - StatRow( - metric="silhouette", - metric_kind="validity", - value=sil, - extra={**sx, "seed": rng_seed}, - **base, - ) - ) - except Exception: # noqa: BLE001 - validity is best-effort - pass - - # Davies-Bouldin / Calinski-Harabasz are unstable with singleton clusters. - if not has_singleton: - for metric_name, fn in ( - ("davies_bouldin", davies_bouldin_score), - ("calinski_harabasz", calinski_harabasz_score), - ): - try: - rows.append( - StatRow( - metric=metric_name, - metric_kind="validity", - value=float(fn(X, labels)), - extra={"seed": rng_seed}, - **base, - ) - ) - except Exception: # noqa: BLE001 - pass - # Per-protein membership: a categorical `cluster N` label, with the per-point # silhouette attached as a `value|score` confidence (like ECO / InterPro bit # scores) so a single column carries both membership and its confidence. @@ -265,4 +215,42 @@ def _membership(pid, lbl): }, ) ) + + # ARI/NMI: does this auto-clustering recover each annotation? Reuses the + # KMeans labels already computed (no second sweep). Compared over the + # id-intersection of clustered points and annotated points. + if ctx.annotations: + from sklearn.metrics import ( + adjusted_rand_score, + normalized_mutual_info_score, + ) + + label_by_id = dict(zip(ctx.ids, labels, strict=False)) + for name, mapping in ctx.annotations.items(): + # ``paired_clu`` holds numpy ints straight from the KMeans labels; + # sklearn's ARI/NMI accept them as-is (no per-element cast needed). + paired_clu, paired_ann = pair_by_id(mapping, label_by_id) + if len(set(paired_ann)) < 2 or len(paired_ann) < 3: + continue + for metric_name, fn in ( + ("adjusted_rand", adjusted_rand_score), + ("normalized_mutual_info", normalized_mutual_info_score), + ): + try: + rows.append( + StatRow( + space_kind=ctx.space_kind, + space_name=ctx.space_name, + annotation=name, + stat_family="cluster_agreement", + label_kind=label_kind, + metric=metric_name, + metric_kind="agreement", + value=float(fn(paired_ann, paired_clu)), + extra={"seed": rng_seed, "n_labels": len(paired_ann)}, + ) + ) + except Exception: # noqa: BLE001 - best-effort + pass + return rows diff --git a/tests/test_annotation_select.py b/tests/test_annotation_select.py new file mode 100644 index 00000000..96042ac8 --- /dev/null +++ b/tests/test_annotation_select.py @@ -0,0 +1,66 @@ +import pandas as pd +import pytest + +from protspace.stats.annotation_select import ( + build_annotation_labels, + suitable_annotations, +) + + +def _frame(): + return pd.DataFrame( + { + "identifier": [f"p{i}" for i in range(6)], + "major_group": ["a", "a", "b", "b", "c", "c"], # suitable (3 cats) + "all_unique": [f"u{i}" for i in range(6)], # unsuitable (all unique) + "constant": ["x"] * 6, # unsuitable (1 cat) + "count": [1, 2, 3, 4, 5, 6], # unsuitable (numeric) + "cluster_elbow_P": ["cluster 0"] * 3 + ["cluster 1"] * 3, # excluded + } + ) + + +def test_suitable_annotations_filters(): + names = suitable_annotations(_frame()) + assert names == ["major_group"] + + +def test_build_labels_auto(): + labels = build_annotation_labels(_frame(), "auto") + assert set(labels) == {"major_group"} + assert labels["major_group"]["p0"] == "a" + assert len(labels["major_group"]) == 6 + + +def test_build_labels_explicit_names_and_missing_dropped(): + frame = _frame() + frame.loc[0, "major_group"] = "" # sentinel missing + frame.loc[1, "major_group"] = "" # empty + labels = build_annotation_labels(frame, ["major_group"]) + assert "p0" not in labels["major_group"] # dropped + assert "p1" not in labels["major_group"] # empty dropped + assert labels["major_group"]["p2"] == "b" + + +def test_build_labels_unknown_name_skipped(): + labels = build_annotation_labels(_frame(), ["does_not_exist"]) + assert labels == {} + + +def test_explicit_name_bypasses_auto_suitability_heuristic(): + # A high-cardinality categorical the auto heuristic rejects (all-unique) must + # still be honoured when the user names it explicitly — the suitability filter + # is for discovery, not authorisation. (Regression: the documented + # `--stats-annotation ec_number` example silently scored nothing.) + frame = _frame() + assert "all_unique" not in suitable_annotations(frame) # auto rejects it + labels = build_annotation_labels(frame, "all_unique") # raw-string explicit + assert set(labels) == {"all_unique"} + assert len(labels["all_unique"]) == 6 + + +def test_explicit_numeric_coded_categorical_is_honoured(): + # An integer-coded categorical (auto excludes as "numeric") is scored when named. + labels = build_annotation_labels(_frame(), ["count"]) + assert set(labels) == {"count"} + assert labels["count"]["p0"] == "1" diff --git a/tests/test_annotation_validity.py b/tests/test_annotation_validity.py new file mode 100644 index 00000000..09cc95ba --- /dev/null +++ b/tests/test_annotation_validity.py @@ -0,0 +1,99 @@ +import numpy as np + +from protspace.stats.base import StatContext, StatRow +from protspace.stats.metrics.annotation_validity import AnnotationValidityStatistic + + +def _blobs(n=200, centers=4, dim=2, seed=1): + from sklearn.datasets import make_blobs + + X, y = make_blobs(n_samples=n, centers=centers, n_features=dim, random_state=seed) + return X, y + + +def test_scores_each_annotation_on_ctx_coords(): + X, y = _blobs(n=200, centers=4, dim=2, seed=3) + ids = [f"p{i}" for i in range(200)] + ann = {"grp": {pid: f"g{int(c)}" for pid, c in zip(ids, y, strict=True)}} + outs = AnnotationValidityStatistic().compute( + StatContext("projection", "PCA_2", coords=X, ids=ids, annotations=ann) + ) + by_metric = {r.metric: r for r in outs if isinstance(r, StatRow)} + assert {"silhouette", "davies_bouldin", "calinski_harabasz"} <= set(by_metric) + s = by_metric["silhouette"] + assert s.stat_family == "annotation_validity" + assert s.annotation == "grp" and s.label_kind == "annotation" + assert 0.4 < s.value <= 1.0 # well-separated blobs → high silhouette + + +def test_space_kind_is_taken_from_context(): + X, y = _blobs(n=120, centers=3, dim=8, seed=4) + ids = [f"p{i}" for i in range(120)] + ann = {"grp": {pid: f"g{int(c)}" for pid, c in zip(ids, y, strict=True)}} + outs = AnnotationValidityStatistic().compute( + StatContext("embedding", "prot_t5", coords=X, ids=ids, annotations=ann) + ) + assert all(r.space_kind == "embedding" for r in outs) + assert all(r.space_name == "prot_t5" for r in outs) + + +def test_missing_annotation_values_excluded(): + X, y = _blobs(n=100, centers=2, dim=2, seed=5) + ids = [f"p{i}" for i in range(100)] + # Only half the proteins have a category → the rest are dropped from scoring. + ann = {"grp": {pid: f"g{int(c)}" for pid, c in list(zip(ids, y, strict=True))[:50]}} + outs = AnnotationValidityStatistic().compute( + StatContext("projection", "P", coords=X, ids=ids, annotations=ann) + ) + sil = next(r for r in outs if r.metric == "silhouette") + assert sil.extra["n_labels"] == 50 + + +def test_single_category_annotation_emits_nothing(): + X, _ = _blobs(n=80, centers=1, dim=2, seed=6) + ids = [f"p{i}" for i in range(80)] + ann = {"grp": dict.fromkeys(ids, "only")} # 1 category + outs = AnnotationValidityStatistic().compute( + StatContext("projection", "P", coords=X, ids=ids, annotations=ann) + ) + assert outs == [] + + +def test_no_annotations_returns_empty(): + X, _ = _blobs(n=50, centers=2, dim=2, seed=7) + outs = AnnotationValidityStatistic().compute( + StatContext("projection", "P", coords=X, ids=[f"p{i}" for i in range(50)]) + ) + assert outs == [] + + +def test_subsample_path_flags_sampled_and_is_deterministic(): + """When n exceeds ``sample_threshold`` the shared subsample kicks in: the + emitted silhouette row must report ``sampled=True`` and ``n_labels`` equal to + the threshold (not the full n), and repeating the identical call must + reproduce the exact same value (deterministic rng_seed-based subsample).""" + threshold = 30 + X, y = _blobs(n=200, centers=4, dim=2, seed=9) + ids = [f"p{i}" for i in range(200)] + ann = {"grp": {pid: f"g{int(c)}" for pid, c in zip(ids, y, strict=True)}} + + def _run(): + outs = AnnotationValidityStatistic().compute( + StatContext( + "projection", + "P", + coords=X, + ids=ids, + annotations=ann, + params={"sample_threshold": threshold}, + ) + ) + return next(r for r in outs if r.metric == "silhouette") + + sil = _run() + assert sil.extra["sampled"] is True + assert sil.extra["n_labels"] == threshold + + sil_again = _run() + assert sil_again.value == sil.value + assert sil_again.extra == sil.extra diff --git a/tests/test_stats.py b/tests/test_stats.py index 284578af..8aa7b993 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -40,19 +40,20 @@ def _blobs(n=300, centers=4, dim=2, seed=0): # --------------------------------------------------------------------------- # -def test_registry_returns_two_statistics(): +def test_registry_returns_three_statistics(): stats = get_statistics() families = {s.family for s in stats} - assert families == {"cluster_validity", "faithfulness"} + assert families == {"cluster_validity", "annotation_validity", "faithfulness"} -def test_to_arrow_has_eight_column_schema(): +def test_to_arrow_has_nine_column_schema(): report = StatsReport() report.add( [ StatRow( space_kind="projection", space_name="UMAP_2", + annotation="", stat_family="cluster_validity", label_kind="kmeans_elbow", metric="silhouette", @@ -63,9 +64,10 @@ def test_to_arrow_has_eight_column_schema(): ] ) table = report.to_arrow() - assert table.schema.names == [ + names = [ "space_kind", "space_name", + "annotation", "stat_family", "label_kind", "metric", @@ -73,6 +75,8 @@ def test_to_arrow_has_eight_column_schema(): "value", "extra_json", ] + assert table.schema.names == names + assert len(names) == 9 assert table.num_rows == 1 assert table.column("value")[0].as_py() == pytest.approx(0.42) @@ -83,6 +87,37 @@ def test_empty_report_keeps_schema(): assert table.schema == STATS_SCHEMA +def test_statrow_carries_annotation_column(): + from protspace.stats.base import STATS_SCHEMA, StatRow, StatsReport + + assert "annotation" in STATS_SCHEMA.names + row = StatRow( + space_kind="embedding", + space_name="prot_t5", + stat_family="annotation_validity", + label_kind="annotation", + metric="silhouette", + metric_kind="validity", + value=0.42, + annotation="major_group", + ) + rec = row.to_record() + assert rec["annotation"] == "major_group" + report = StatsReport() + report.add([row]) + tbl = report.to_arrow() + assert tbl.column("annotation").to_pylist() == ["major_group"] + + +def test_statcontext_defaults_annotations_none(): + import numpy as np + + from protspace.stats.base import StatContext + + ctx = StatContext("projection", "P", coords=np.zeros((3, 2)), ids=["a", "b", "c"]) + assert ctx.annotations is None + + # --------------------------------------------------------------------------- # # 2. cluster validity / elbow # --------------------------------------------------------------------------- # @@ -96,31 +131,36 @@ def test_elbow_recovers_known_cluster_count(): assert res.knee_confidence == "high" +def _mean_membership_silhouette(outs): + """Mean of the per-point silhouette attached to the (sole) membership + column's `cluster N|score` values — the aggregate `silhouette` StatRow was + removed (self-validity on auto-clusters is circular); this per-point + confidence is the retained signal.""" + from protspace.stats.base import AnnotationColumn + + col = next(o for o in outs if isinstance(o, AnnotationColumn)) + per_point = [float(v.split("|", 1)[1]) for v in col.values.values()] + return float(np.mean(per_point)) + + def test_cluster_validity_separated_vs_overlapping(): sep, _ = _blobs(n=300, centers=4, dim=2, seed=2) ctx = StatContext( "projection", "PCA_2", coords=sep, ids=[str(i) for i in range(len(sep))] ) - sep_sil = { - r.metric: r.value - for r in ClusterValidityStatistic().compute(ctx) - if isinstance(r, StatRow) - }["silhouette"] + sep_sil = _mean_membership_silhouette(ClusterValidityStatistic().compute(ctx)) assert sep_sil > 0.6 # Heavily overlapping clusters: KMeans still imposes a split, but the - # silhouette is markedly lower than for well-separated clusters. + # per-point silhouette (attached to the membership column) is markedly + # lower than for well-separated clusters. overlap, _ = make_blobs( n_samples=300, centers=4, n_features=2, random_state=2, cluster_std=4.0 ) ctx2 = StatContext( "projection", "PCA_2", coords=overlap, ids=[str(i) for i in range(300)] ) - ov_sil = { - r.metric: r.value - for r in ClusterValidityStatistic().compute(ctx2) - if isinstance(r, StatRow) - }["silhouette"] + ov_sil = _mean_membership_silhouette(ClusterValidityStatistic().compute(ctx2)) assert ov_sil < 0.45 assert sep_sil > ov_sil + 0.2 @@ -135,11 +175,59 @@ def test_cluster_validity_emits_meta_and_validity_kinds(): ] by_metric = {r.metric: r for r in rows} assert by_metric["n_clusters"].metric_kind == "meta" - assert by_metric["silhouette"].metric_kind == "validity" - assert {"davies_bouldin", "calinski_harabasz"} <= set(by_metric) + # Self-validity (silhouette/DBI/CH) on the auto-clusters is no longer + # emitted (circular: KMeans optimises inertia, then silhouette grades the + # KMeans result against itself); without annotations, n_clusters is the + # only row. + assert set(by_metric) == {"n_clusters"} assert all(r.label_kind == "kmeans_elbow" for r in rows) +def test_cluster_validity_emits_agreement_not_self_validity(): + from protspace.stats.base import AnnotationColumn, StatContext, StatRow + from protspace.stats.metrics.validity import ClusterValidityStatistic + + X, y = _blobs(n=200, centers=4, dim=2, seed=61) + ids = [f"p{i}" for i in range(200)] + ann = {"grp": {pid: f"g{int(c)}" for pid, c in zip(ids, y, strict=True)}} + outs = ClusterValidityStatistic().compute( + StatContext("projection", "PCA_2", coords=X, ids=ids, annotations=ann) + ) + rows = [o for o in outs if isinstance(o, StatRow)] + metrics = {r.metric for r in rows} + # No self-validity rows anymore: + assert not ({"silhouette", "davies_bouldin", "calinski_harabasz"} & metrics) + # n_clusters meta kept: + assert "n_clusters" in metrics + # ARI/NMI agreement vs the annotation, tagged correctly: + agree = [r for r in rows if r.stat_family == "cluster_agreement"] + assert {r.metric for r in agree} == {"adjusted_rand", "normalized_mutual_info"} + assert all(r.annotation == "grp" and r.metric_kind == "agreement" for r in agree) + assert all(r.label_kind == "kmeans_elbow" for r in agree) + # Auto-clusters recover well-separated blobs → high agreement. + ari = next(r for r in agree if r.metric == "adjusted_rand") + assert ari.value > 0.5 + # Membership column still emitted. + assert any(isinstance(o, AnnotationColumn) for o in outs) + + +def test_cluster_validity_no_annotations_still_emits_membership(): + from protspace.stats.base import AnnotationColumn, StatContext, StatRow + from protspace.stats.metrics.validity import ClusterValidityStatistic + + X, _ = _blobs(n=150, centers=3, dim=2, seed=62) + ids = [f"p{i}" for i in range(150)] + outs = ClusterValidityStatistic().compute( + StatContext("projection", "P", coords=X, ids=ids) + ) + assert any(isinstance(o, AnnotationColumn) for o in outs) + assert not [ + r + for r in outs + if isinstance(r, StatRow) and r.stat_family == "cluster_agreement" + ] + + def test_cluster_validity_too_few_points(): ctx = StatContext("projection", "PCA_2", coords=np.zeros((2, 2)), ids=["a", "b"]) assert ClusterValidityStatistic().compute(ctx) == [] @@ -278,12 +366,10 @@ def test_driver_full_matrix_shape(): ] report = compute_statistics([emb], reductions, rng_seed=42) metrics = {r.metric for r in report.rows} - assert { - "silhouette", - "davies_bouldin", - "calinski_harabasz", - "n_clusters", - } <= metrics + # Self-validity (silhouette/DBI/CH) on the auto-clusters is gone; only the + # n_clusters meta row remains without annotations. + assert "n_clusters" in metrics + assert not ({"silhouette", "davies_bouldin", "calinski_harabasz"} & metrics) assert {"knn_overlap", "trustworthiness", "continuity"} <= metrics assert all(r.space_name == "ProtT5 — PCA 2" for r in report.rows) @@ -385,7 +471,7 @@ def test_cluster_validity_uses_full_projection_not_embedding_subset(): r.extra["sample_size"] == 60 for r in faith ) # faithfulness on the subset # cluster_validity still runs (on the full 100-point projection) - assert any(r.metric == "silhouette" for r in report.rows) + assert any(r.metric == "n_clusters" for r in report.rows) def test_faithfulness_honors_default_metric_when_info_lacks_metric(): @@ -414,7 +500,7 @@ def test_precomputed_embedding_skips_faithfulness(): rng_seed=42, ) assert not any(r.stat_family == "faithfulness" for r in report.rows) - assert any(r.metric == "silhouette" for r in report.rows) + assert any(r.metric == "n_clusters" for r in report.rows) def test_source_disambiguates_same_id_embeddings(): @@ -443,6 +529,41 @@ def test_source_disambiguates_same_id_embeddings(): assert embs["B — PCA 2"] == "B" +def test_driver_emits_embedding_and_projection_annotation_validity(): + from sklearn.decomposition import PCA + + from protspace.stats import compute_statistics + + X, y = _blobs(n=180, centers=4, dim=8, seed=71) + coords = PCA(n_components=2, random_state=0).fit_transform(X) + # Named `hdrs` (not `headers`) so the class body below can read it: a class + # attribute assignment `headers = headers` would shadow the enclosing local + # in the class namespace before the RHS is resolved, raising NameError. + hdrs = [f"p{i}" for i in range(180)] + ann = {"grp": {pid: f"g{int(c)}" for pid, c in zip(hdrs, y, strict=False)}} + + class _Emb: + name = "e" + data = X + headers = hdrs + precomputed = False + + report = compute_statistics( + [_Emb()], + [{"name": "e — PCA 2", "data": coords, "ids": hdrs, "source": "e"}], + annotations=ann, + ) + av = [r for r in report.rows if r.stat_family == "annotation_validity"] + kinds = {(r.space_kind, r.annotation) for r in av} + assert ("embedding", "grp") in kinds # once-per-embedding pass + assert ("projection", "grp") in kinds # per-projection pass + # embedding is computed exactly once per (embedding, annotation, metric) + emb_sil = [ + r for r in av if r.space_kind == "embedding" and r.metric == "silhouette" + ] + assert len(emb_sil) == 1 + + def test_driver_isolates_failures(): class _Boom: family = "boom" @@ -460,7 +581,7 @@ def compute(self, ctx): statistics=[_Boom(), ClusterValidityStatistic()], ) # Boom is swallowed; cluster validity still produced rows. - assert any(r.metric == "silhouette" for r in report.rows) + assert any(r.metric == "n_clusters" for r in report.rows) # --------------------------------------------------------------------------- # @@ -468,11 +589,12 @@ def compute(self, ctx): # --------------------------------------------------------------------------- # -def _statrow(metric, value, *, destination=None, **extra): +def _statrow(metric, value, *, destination=None, annotation="", **extra): kw = {} if destination is None else {"destination": destination} return StatRow( space_kind="projection", space_name="PCA_2", + annotation=annotation, stat_family="cluster_validity", label_kind="kmeans_elbow", metric=metric, @@ -489,7 +611,7 @@ def test_statrow_defaults_to_statistics_part_destination(): def test_destination_is_not_a_tidy_table_column(): - # destination is carriage metadata, never a column in the 8-column schema. + # destination is carriage metadata, never a column in the tidy schema. rec = _statrow("silhouette", 0.5).to_record() assert "destination" not in rec assert set(rec) == set(STATS_SCHEMA.names) @@ -613,8 +735,8 @@ def test_cluster_annotations_can_be_disabled(): ) ) assert not any(isinstance(o, AnnotationColumn) for o in outs) - # aggregate validity rows still produced - assert any(getattr(o, "metric", None) == "silhouette" for o in outs) + # the n_clusters meta row is still produced (self-validity rows are gone) + assert any(getattr(o, "metric", None) == "n_clusters" for o in outs) # --------------------------------------------------------------------------- # @@ -691,25 +813,37 @@ def test_elbow_result_has_no_silhouette_optimal_k(): assert "silhouette_optimal_k" not in meta.extra -def test_aggregate_silhouette_equals_per_point_mean(): - """The aggregate silhouette is exactly the mean of the per-point silhouettes - attached to the membership column values (consistent, not a sampled estimate).""" +def test_membership_silhouette_matches_sklearn_per_point(): + """The per-point silhouette attached to the membership column (`cluster N|score`) + is exactly sklearn's `silhouette_samples` for the auto-cluster labels — an exact + per-point value, not a sampled estimate. (Previously cross-checked against the + now-removed aggregate `silhouette` StatRow; the aggregate is gone because + self-scoring the auto-clusters this way is circular, so this cross-checks the + per-point column directly against sklearn instead.)""" + from sklearn.metrics import silhouette_samples + from protspace.stats.base import AnnotationColumn + from protspace.stats.cluster.kmeans_elbow import kmeans_elbow X, _ = _blobs(n=300, centers=4, dim=2, seed=45) ids = [f"p{i}" for i in range(300)] outs = ClusterValidityStatistic().compute( StatContext("projection", "P", coords=X, ids=ids) ) - agg = next(o for o in outs if isinstance(o, StatRow) and o.metric == "silhouette") col = next( o for o in outs if isinstance(o, AnnotationColumn) and o.name == "cluster_elbow_P" ) - per_point = [float(v.split("|", 1)[1]) for v in col.values.values()] - assert agg.extra["sampled"] is False - assert agg.value == pytest.approx(float(np.mean(per_point)), abs=1e-4) + per_point = {pid: float(v.split("|", 1)[1]) for pid, v in col.values.items()} + + # Recompute independently via the same deterministic elbow selection. The + # membership column formats the attached score to 4 decimal places, so + # compare at that precision rather than bit-exact. + res = kmeans_elbow(X, rng_seed=42) + expected = silhouette_samples(X, res.labels) + for i, pid in enumerate(ids): + assert per_point[pid] == pytest.approx(float(expected[i]), abs=1e-4) @pytest.mark.parametrize("sample_threshold", [60, 1000]) diff --git a/tests/test_stats_carriage.py b/tests/test_stats_carriage.py index ad40a640..eda807cd 100644 --- a/tests/test_stats_carriage.py +++ b/tests/test_stats_carriage.py @@ -31,6 +31,7 @@ def _faith_row(space_name, metric_name, value, **extra): return StatRow( space_kind="projection", space_name=space_name, + annotation="", # faithfulness rows are not annotation-scoped stat_family="faithfulness", label_kind="none", metric=metric_name, diff --git a/tests/test_stats_cli.py b/tests/test_stats_cli.py index 4db44fd7..03fa2666 100644 --- a/tests/test_stats_cli.py +++ b/tests/test_stats_cli.py @@ -4,6 +4,7 @@ import h5py import numpy as np +import pandas as pd import pyarrow as pa import pyarrow.parquet as pq import pytest @@ -69,8 +70,11 @@ def test_discrete_path_produces_full_matrix(tmp_path): reductions = _load_reductions(proj) report = compute_statistics([emb], reductions, rng_seed=42) metrics = {r.metric for r in report.rows} - # cluster-validity (coords only) + faithfulness (embedding matched by id-join) - assert {"silhouette", "n_clusters"} <= metrics + # Without annotations, cluster_validity emits only the K-selection meta row + # (n_clusters) — silhouette/DBI/CH now live in annotation_validity and require + # `annotations=`, which this call doesn't supply. Faithfulness (embedding + # matched by id-join) is unconditional. + assert {"n_clusters"} <= metrics assert {"knn_overlap", "trustworthiness", "continuity"} <= metrics # The fifth part (to_arrow) now carries aggregate validity only — faithfulness # routes to projection metadata, not this table (route-projection-statistics). @@ -83,7 +87,13 @@ def test_discrete_path_produces_full_matrix(tmp_path): def test_stats_command_writes_aggregate_only_part(tmp_path): """`protspace stats -o statistics.parquet` writes validity/meta rows only — faithfulness now rides in projection metadata, not this fifth part - (route-projection-statistics Phase 1A; the prep stats+bundle path stays valid).""" + (route-projection-statistics Phase 1A; the prep stats+bundle path stays valid). + + Without -a/--annotations, no annotation labels are built, so cluster_validity + contributes only its K-selection meta row (n_clusters) — annotation-based + silhouette/DBI/CH (family=annotation_validity) and cluster_agreement (ARI/NMI) + both require annotations and are absent here (see + test_stats_command_computes_annotation_validity for the annotated path).""" from typer.testing import CliRunner from protspace.cli.app import app @@ -100,12 +110,7 @@ def test_stats_command_writes_aggregate_only_part(tmp_path): families = set(table.column("stat_family").to_pylist()) assert families == {"cluster_validity"} metrics = set(table.column("metric").to_pylist()) - assert { - "silhouette", - "davies_bouldin", - "calinski_harabasz", - "n_clusters", - } <= metrics + assert metrics == {"n_clusters"} assert not ({"knn_overlap", "trustworthiness", "continuity"} & metrics) @@ -457,6 +462,62 @@ def __init__(self, name, data, headers): assert {"knn_overlap", "trustworthiness", "continuity"} <= set(quality) +def test_prepare_stats_annotation_validity_in_bundle(tmp_path): + """`prepare --stats --stats-annotation auto` (with -a CSV) threads annotation + labels through `PipelineConfig.stats_annotation` into `_compute_statistics`, + so the resulting bundle's fifth part gains `annotation_validity` rows scored + against the CSV's categorical column.""" + import io + + from typer.testing import CliRunner + + from protspace.cli.app import app + from protspace.data.io.bundle import read_statistics_from_bundle + + X, _ = make_blobs(n_samples=120, centers=3, n_features=5, random_state=1) + headers = [f"p{i}" for i in range(120)] + + h5_path = tmp_path / "emb.h5" + with h5py.File(h5_path, "w") as f: + for i, h in enumerate(headers): + f.create_dataset(h, data=X[i].astype(np.float32)) + + ann_path = tmp_path / "grp.csv" + groups = ["a" if i % 2 == 0 else "b" for i in range(len(headers))] + pd.DataFrame({"identifier": headers, "grp": groups}).to_csv(ann_path, index=False) + + output_dir = tmp_path / "out" + result = CliRunner().invoke( + app, + [ + "prepare", + "-i", + f"{h5_path}:E", + "-a", + str(ann_path), + "-m", + "pca2", + "--stats", + "--stats-annotation", + "auto", + "--no-log", + "-o", + str(output_dir), + ], + ) + assert result.exit_code == 0, result.output + + bundle_path = output_dir / "data.parquetbundle" + assert bundle_path.exists() + raw = read_statistics_from_bundle(bundle_path) + assert raw is not None + st = pq.read_table(io.BytesIO(raw)).to_pandas() + assert (st.stat_family == "annotation_validity").any() + assert "annotation" in st.columns + av = st[st.stat_family == "annotation_validity"] + assert set(av["annotation"]) == {"grp"} + + def test_stats_rejects_bad_cluster_selection(tmp_path): """`--cluster-selection` is validated (fail-fast) rather than silently ignored.""" from typer.testing import CliRunner @@ -480,3 +541,97 @@ def test_stats_rejects_bad_cluster_selection(tmp_path): ], ) assert result.exit_code != 0 + + +def test_stats_command_computes_annotation_validity(tmp_path): + """`stats --stats-annotation auto` (with -a) builds annotation labels and + threads them into `compute_statistics`, so the fifth part gains + annotation_validity rows for both the source embedding and the projection.""" + from typer.testing import CliRunner + + from protspace.cli.app import app + + h5_path, proj, ids = _project_dir(tmp_path) # returns (h5, proj_dir, id list) + ann_path = tmp_path / "annotations.parquet" + # A separable categorical annotation over the same ids. + groups = ["a" if i % 2 else "b" for i in range(len(ids))] + pq.write_table(pa.table({"identifier": ids, "major_group": groups}), str(ann_path)) + out = tmp_path / "statistics.parquet" + result = CliRunner().invoke( + app, + [ + "stats", + "-i", + f"{h5_path}:E", + "-p", + str(proj), + "-o", + str(out), + "-a", + str(ann_path), + "--stats-annotation", + "auto", + ], + ) + assert result.exit_code == 0, result.output + st = pq.read_table(str(out)).to_pandas() + assert "annotation" in st.columns + av = st[st.stat_family == "annotation_validity"] + assert set(av["annotation"]) == {"major_group"} + assert {"embedding", "projection"} <= set(av["space_kind"]) + + +def test_stats_rejects_no_annotation_source_for_stats_annotation(tmp_path): + """`--stats-annotation ` (non-`auto`) with no -a/--annotations has + nothing to score, so it must fail fast rather than silently doing nothing.""" + from typer.testing import CliRunner + + from protspace.cli.app import app + + h5_path, proj, _ = _project_dir(tmp_path) + out = tmp_path / "statistics.parquet" + result = CliRunner().invoke( + app, + [ + "stats", + "-i", + f"{h5_path}:E", + "-p", + str(proj), + "-o", + str(out), + "--stats-annotation", + "major_group", # no -a + ], + ) + assert result.exit_code != 0 + + +def test_stats_annotation_auto_case_and_whitespace_without_annotations_ok(tmp_path): + """`--stats-annotation` values that normalise to ``auto`` (e.g. mixed case with + surrounding whitespace) must not be wrongly rejected by the `-a/--annotations` + guard when no `-a` is given — the guard has to use the same + ``.strip().lower() == "auto"`` normalisation the parser uses below it, not a + strict ``!= "auto"`` comparison.""" + from typer.testing import CliRunner + + from protspace.cli.app import app + + h5_path, proj, _ = _project_dir(tmp_path) + out = tmp_path / "statistics.parquet" + result = CliRunner().invoke( + app, + [ + "stats", + "-i", + f"{h5_path}:E", + "-p", + str(proj), + "-o", + str(out), + "--stats-annotation", + " Auto ", # no -a; must normalise to "auto", not be rejected + ], + ) + assert result.exit_code == 0, result.output + assert out.exists()