Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,27 @@ def _detect_condition(eval_dir: Path) -> str:
raise ValueError(f"cannot infer condition from eval_dir name {name!r}: expected trailing _{{mock,denv,zikv}}")


def _load_embeddings(eval_dir: Path, source: str, feature: str) -> tuple[np.ndarray, np.ndarray]:
def _load_embeddings(
eval_dir: Path,
source: str,
feature: str,
cache: dict[Path, tuple[np.ndarray, np.ndarray]] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Return ``(embeddings, fov_ids)`` from one ``*_single_cell_embeddings.npz``.

``np.load`` raises ``FileNotFoundError`` when the NPZ is missing.
``np.load`` raises ``FileNotFoundError`` when the NPZ is missing. When
*cache* is given, the result is memoized on the resolved NPZ path so the
shared reference (``mock``) side is read from disk once per group instead
of once per pair (see :func:`run` / :func:`run_for_group`).
"""
npz_path = eval_dir / "embeddings" / f"{source}_{feature}_single_cell_embeddings.npz"
if cache is not None and npz_path in cache:
return cache[npz_path]
with np.load(npz_path) as data:
return np.asarray(data["embeddings"]), np.asarray(data["fov"])
result = (np.asarray(data["embeddings"]), np.asarray(data["fov"]))
if cache is not None:
cache[npz_path] = result
return result


def _probe_pair(
Expand All @@ -82,6 +95,7 @@ def _probe_pair(
source: str,
n_splits: int,
rng_seed: int,
cache: dict[Path, tuple[np.ndarray, np.ndarray]] | None = None,
) -> dict:
"""Run one ``fov_stratified_auroc`` call for the given (pair, feature, source).

Expand All @@ -106,8 +120,8 @@ def _probe_pair(
row["skipped_reason"] = "missing eval dir for one side of pair"
return row
try:
x0, fov0 = _load_embeddings(eval_dirs_by_condition[c0], source, feature)
x1, fov1 = _load_embeddings(eval_dirs_by_condition[c1], source, feature)
x0, fov0 = _load_embeddings(eval_dirs_by_condition[c0], source, feature, cache)
x1, fov1 = _load_embeddings(eval_dirs_by_condition[c1], source, feature, cache)
except FileNotFoundError as e:
row["skipped_reason"] = f"missing embeddings file: {e}"
return row
Expand Down Expand Up @@ -182,7 +196,9 @@ def run_for_group(
eval_dirs : list[Path]
Per-condition eval dirs of one (model, pool, organelle) group. The
condition is inferred from each dir's trailing ``_{mock,denv,zikv}``;
dirs without a recognized token are ignored.
dirs without a recognized token are ignored. Two dirs mapping to the
same condition raise ``ValueError`` (an ambiguous group) rather than
silently picking one.
n_splits, rng_seed : int
Forwarded to :func:`fov_stratified_auroc`.
"""
Expand All @@ -192,16 +208,21 @@ def run_for_group(
cond = _detect_condition(d)
except ValueError:
continue
if cond in by_condition:
raise ValueError(f"duplicate condition {cond!r}: {by_condition[cond]} and {d}")
by_condition[cond] = d
if "mock" not in by_condition:
return []

# Shared across pairs so the mock reference embeddings are read once, not
# re-read for every infected condition. Local to this call -> released on return.
cache: dict[Path, tuple[np.ndarray, np.ndarray]] = {}
written: list[Path] = []
for ref, cond in _DEFAULT_PAIRS: # ref == "mock" for every default pair
if cond not in by_condition:
continue
rows = [
_probe_pair(by_condition, (ref, cond), feature, source, n_splits, rng_seed)
_probe_pair(by_condition, (ref, cond), feature, source, n_splits, rng_seed, cache)
for feature in _FEATURE_TYPES
for source in _SOURCES
]
Expand Down Expand Up @@ -239,11 +260,14 @@ def run(
raise ValueError(f"duplicate condition {cond!r}: {eval_dirs_by_condition[cond]} and {d}")
eval_dirs_by_condition[cond] = d

# Shared across pairs so a condition's embeddings (e.g. the mock reference
# reused by every pair) are read once. Local to this call -> released on return.
cache: dict[Path, tuple[np.ndarray, np.ndarray]] = {}
rows = []
for feature in _FEATURE_TYPES:
for pair in pairs:
for source in _SOURCES:
rows.append(_probe_pair(eval_dirs_by_condition, pair, feature, source, n_splits, rng_seed))
rows.append(_probe_pair(eval_dirs_by_condition, pair, feature, source, n_splits, rng_seed, cache))

_write_rows(out_path, rows)
return out_path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""

import numpy as np
from cubic.metrics import average_precision

# Ten evenly spaced IoU cut-offs, 0.50 through 0.95 in steps of 0.05.
DEFAULT_IOU_THRESHOLDS = (0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95)
"""IoU thresholds for the AP sweep (Cellpose / StarDist standard 0.50..0.95)."""
Expand Down Expand Up @@ -70,6 +69,12 @@ def instance_average_precision(
ap_vals = [0.0] * len(thresholds)
tp, fp, fn = 0.0, float(n_pred), float(n_gt)
else:
# Imported lazily (not at module top) so importing this module for
# DEFAULT_IOU_THRESHOLDS / _relabel_sequential — e.g. pipeline_cache pulling
# the threshold constant — does not require the GPU-only cubic stack. The
# actual AP computation still hard-requires cubic and fails loudly here.
from cubic.metrics import average_precision

ap, tp_arr, fp_arr, fn_arr = average_precision(gt, pred, thresholds)
ap_vals = [float(a) for a in np.atleast_1d(ap)]
idx = thresholds.index(_PRIMARY_THRESHOLD) if _PRIMARY_THRESHOLD in thresholds else 0
Expand Down