From 0b529ddcb472afba84476d9280f81f0fa425c4c7 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 14 May 2026 09:54:44 -0700 Subject: [PATCH 01/13] Port scale-aware inference and LOT batch correction to modular-viscy-staging Scale-aware patch extraction (packages/viscy-data): - Add _read_pixel_size() helper to read X pixel size from OME-Zarr metadata - Add reference_pixel_size parameter to TripletDataModule: when set, computes initial_yx_patch_size from the pixel-size ratio so the same physical area is covered at inference time - Replace BatchedRescaleYXd (removed per review) with existing BatchedZoomd using scale_factor=(1.0, final_y/initial_y, final_x/initial_x) and mode="bilinear" with antialias; import is lazy to avoid a hard dep LOT batch correction (applications/dynaclr): - Add dynaclr.evaluation.lot_correction submodule with core logic (fit, apply, save, load), Pydantic configs, and Click CLI entry points - Register fit-lot-correction and apply-lot-correction in cli.py via LazyCommand - Add pot and joblib to optional-dependencies.eval in pyproject.toml Co-Authored-By: Claude Sonnet 4.6 --- applications/dynaclr/pyproject.toml | 2 + applications/dynaclr/src/dynaclr/cli.py | 16 + .../evaluation/lot_correction/__init__.py | 15 + .../lot_correction/apply_lot_correction.py | 84 ++++++ .../evaluation/lot_correction/config.py | 118 ++++++++ .../lot_correction/fit_lot_correction.py | 95 ++++++ .../lot_correction/lot_correction.py | 284 ++++++++++++++++++ packages/viscy-data/src/viscy_data/triplet.py | 64 +++- 8 files changed, 675 insertions(+), 3 deletions(-) create mode 100644 applications/dynaclr/src/dynaclr/evaluation/lot_correction/__init__.py create mode 100644 applications/dynaclr/src/dynaclr/evaluation/lot_correction/apply_lot_correction.py create mode 100644 applications/dynaclr/src/dynaclr/evaluation/lot_correction/config.py create mode 100644 applications/dynaclr/src/dynaclr/evaluation/lot_correction/fit_lot_correction.py create mode 100644 applications/dynaclr/src/dynaclr/evaluation/lot_correction/lot_correction.py diff --git a/applications/dynaclr/pyproject.toml b/applications/dynaclr/pyproject.toml index 4cce8f180..ab36b670e 100644 --- a/applications/dynaclr/pyproject.toml +++ b/applications/dynaclr/pyproject.toml @@ -53,8 +53,10 @@ dependencies = [ optional-dependencies.eval = [ "anndata", "dtaidistance", + "joblib", "natsort", "phate", + "pot", "scikit-learn", "statsmodels", "umap-learn", diff --git a/applications/dynaclr/src/dynaclr/cli.py b/applications/dynaclr/src/dynaclr/cli.py index 40d845fdc..30f28b322 100644 --- a/applications/dynaclr/src/dynaclr/cli.py +++ b/applications/dynaclr/src/dynaclr/cli.py @@ -264,6 +264,22 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="fit-lot-correction", + import_path="dynaclr.evaluation.lot_correction.fit_lot_correction.main", + short_help="Fit a LOT batch-correction pipeline on source and target embedding zarrs", + ) +) + +dynaclr.add_command( + LazyCommand( + name="apply-lot-correction", + import_path="dynaclr.evaluation.lot_correction.apply_lot_correction.main", + short_help="Apply a fitted LOT pipeline to correct batch effects in an embedding zarr", + ) +) + def main(): """Run the DynaCLR CLI. diff --git a/applications/dynaclr/src/dynaclr/evaluation/lot_correction/__init__.py b/applications/dynaclr/src/dynaclr/evaluation/lot_correction/__init__.py new file mode 100644 index 000000000..e08fbd0f6 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/lot_correction/__init__.py @@ -0,0 +1,15 @@ +"""LOT (Linear Optimal Transport) batch correction for embedding zarrs.""" + +from dynaclr.evaluation.lot_correction.lot_correction import ( + apply_lot_correction, + fit_lot_correction, + load_lot_pipeline, + save_lot_pipeline, +) + +__all__ = [ + "fit_lot_correction", + "apply_lot_correction", + "save_lot_pipeline", + "load_lot_pipeline", +] diff --git a/applications/dynaclr/src/dynaclr/evaluation/lot_correction/apply_lot_correction.py b/applications/dynaclr/src/dynaclr/evaluation/lot_correction/apply_lot_correction.py new file mode 100644 index 000000000..f42085454 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/lot_correction/apply_lot_correction.py @@ -0,0 +1,84 @@ +"""CLI for applying a fitted LOT pipeline to an embedding zarr. + +Usage +----- + dynaclr apply-lot-correction -c config.yaml + +Transforms all cells through StandardScaler → PCA → LOT and writes a new +zarr whose ``.X`` contains the corrected embeddings (shape n_cells × n_pca). +All ``.obs`` metadata from the input zarr is preserved. + +Example config (YAML) +--------------------- + input_zarr: /path/to/lightsheet_organelle.zarr + pipeline: /path/to/lot_pipeline.pkl + output_zarr: /path/to/corrected_organelle.zarr + overwrite: false +""" + +import logging +from pathlib import Path + +import click +from pydantic import ValidationError + +from dynaclr.evaluation.lot_correction.config import LotApplyConfig +from dynaclr.evaluation.lot_correction.lot_correction import ( + apply_lot_correction, + load_lot_pipeline, +) +from viscy_utils.cli_utils import load_config + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file.", +) +def main(config: Path): + """Apply a fitted LOT pipeline to correct batch effects in an embedding zarr.""" + click.echo("=" * 60) + click.echo("LOT BATCH CORRECTION — APPLY") + click.echo("=" * 60) + + try: + config_dict = load_config(config) + apply_config = LotApplyConfig(**config_dict) + except ValidationError as e: + click.echo(f"\nConfiguration validation failed:\n{e}", err=True) + raise click.Abort() + except Exception as e: + click.echo(f"\nFailed to load configuration: {e}", err=True) + raise click.Abort() + + click.echo(f"\nConfiguration loaded: {config}") + click.echo(f" Input zarr: {apply_config.input_zarr}") + click.echo(f" Pipeline: {apply_config.pipeline}") + click.echo(f" Output zarr: {apply_config.output_zarr}") + click.echo(f" Overwrite: {apply_config.overwrite}") + + try: + pipeline = load_lot_pipeline(apply_config.pipeline) + click.echo( + f"\nPipeline loaded — n_pca={pipeline['n_pca']}, " + f"PCA variance={pipeline.get('pca_variance_explained', float('nan')):.1f}%" + ) + apply_lot_correction( + input_zarr=apply_config.input_zarr, + pipeline=pipeline, + output_zarr=apply_config.output_zarr, + overwrite=apply_config.overwrite, + ) + click.echo(f"\nCorrected zarr written to: {apply_config.output_zarr}") + except Exception as e: + click.echo(f"\nApplication failed: {e}", err=True) + raise click.Abort() + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/lot_correction/config.py b/applications/dynaclr/src/dynaclr/evaluation/lot_correction/config.py new file mode 100644 index 000000000..c3c899194 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/lot_correction/config.py @@ -0,0 +1,118 @@ +"""Pydantic configuration models for LOT batch correction.""" + +from pathlib import Path +from typing import Optional, Union + +from pydantic import BaseModel, Field, model_validator + + +class UninfFilter(BaseModel): + """Specification for selecting uninfected reference cells from an obs table. + + Exactly one of ``startswith`` or ``equals`` must be provided. + + Parameters + ---------- + column : str + Name of the ``.obs`` column to filter on (e.g. ``"fov_name"``). + startswith : str or list[str], optional + Keep cells whose column value starts with any of these prefixes. + equals : str, optional + Keep cells whose column value equals this string. + """ + + column: str = Field(..., min_length=1) + startswith: Optional[Union[str, list[str]]] = Field(default=None) + equals: Optional[str] = Field(default=None) + + @model_validator(mode="after") + def exactly_one_filter(self): + has_sw = self.startswith is not None + has_eq = self.equals is not None + if not has_sw and not has_eq: + raise ValueError("UninfFilter must specify either 'startswith' or 'equals'.") + if has_sw and has_eq: + raise ValueError("UninfFilter must specify only one of 'startswith' or 'equals'.") + return self + + def to_dict(self) -> dict: + """Convert to the dict format expected by _apply_filter.""" + d = {"column": self.column} + if self.startswith is not None: + d["startswith"] = self.startswith + else: + d["equals"] = self.equals + return d + + +class LotFitConfig(BaseModel): + """Configuration for fitting a LOT batch-correction pipeline. + + Parameters + ---------- + source_zarr : str + Path to the source AnnData zarr (e.g. light-sheet embeddings). + target_zarr : str + Path to the target AnnData zarr (e.g. confocal embeddings). + source_uninf_filter : UninfFilter + Filter identifying uninfected cells in the source dataset. + target_uninf_filter : UninfFilter + Filter identifying uninfected cells in the target dataset. + n_pca : int, optional + Number of PCA components for the shared PCA, by default 50. + ns_lot : int, optional + Maximum cells subsampled per dataset for LOT fitting, by default 3000. + random_seed : int, optional + Random seed, by default 42. + output_pipeline : str + Path to save the fitted pipeline (joblib pickle). + """ + + source_zarr: str = Field(..., min_length=1) + target_zarr: str = Field(..., min_length=1) + source_uninf_filter: UninfFilter + target_uninf_filter: UninfFilter + n_pca: int = Field(default=50, gt=0) + ns_lot: int = Field(default=3000, gt=0) + random_seed: int = Field(default=42) + output_pipeline: str = Field(..., min_length=1) + + @model_validator(mode="after") + def validate_paths(self): + if not Path(self.source_zarr).exists(): + raise ValueError(f"source_zarr not found: {self.source_zarr}") + if not Path(self.target_zarr).exists(): + raise ValueError(f"target_zarr not found: {self.target_zarr}") + return self + + +class LotApplyConfig(BaseModel): + """Configuration for applying a fitted LOT pipeline to a zarr. + + Parameters + ---------- + input_zarr : str + Path to the source AnnData zarr to correct. + pipeline : str + Path to the fitted pipeline file (joblib pickle). + output_zarr : str + Path to write the corrected AnnData zarr. + overwrite : bool, optional + Overwrite output if it exists, by default False. + """ + + input_zarr: str = Field(..., min_length=1) + pipeline: str = Field(..., min_length=1) + output_zarr: str = Field(..., min_length=1) + overwrite: bool = Field(default=False) + + @model_validator(mode="after") + def validate_paths(self): + if not Path(self.input_zarr).exists(): + raise ValueError(f"input_zarr not found: {self.input_zarr}") + if not Path(self.pipeline).exists(): + raise ValueError(f"pipeline file not found: {self.pipeline}") + output = Path(self.output_zarr) + if output.exists() and not self.overwrite: + raise ValueError(f"output_zarr already exists: {self.output_zarr}. Set overwrite: true to overwrite.") + return self diff --git a/applications/dynaclr/src/dynaclr/evaluation/lot_correction/fit_lot_correction.py b/applications/dynaclr/src/dynaclr/evaluation/lot_correction/fit_lot_correction.py new file mode 100644 index 000000000..7ecff69b5 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/lot_correction/fit_lot_correction.py @@ -0,0 +1,95 @@ +"""CLI for fitting a LOT batch-correction pipeline on embedding zarrs. + +Usage +----- + dynaclr fit-lot-correction -c config.yaml + +The fitted pipeline (StandardScaler + PCA + LinearTransport) is saved to +the path specified by ``output_pipeline`` in the config file. + +Example config (YAML) +--------------------- + source_zarr: /path/to/lightsheet_organelle.zarr + target_zarr: /path/to/confocal_organelle.zarr + source_uninf_filter: + column: fov_name + startswith: + - "C/1/" + target_uninf_filter: + column: fov_name + startswith: + - "G3BP1/uninfected" + n_pca: 50 + ns_lot: 3000 + random_seed: 42 + output_pipeline: /path/to/lot_pipeline.pkl +""" + +import logging +from pathlib import Path + +import click +from pydantic import ValidationError + +from dynaclr.evaluation.lot_correction.config import LotFitConfig +from dynaclr.evaluation.lot_correction.lot_correction import ( + fit_lot_correction, + save_lot_pipeline, +) +from viscy_utils.cli_utils import load_config + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file.", +) +def main(config: Path): + """Fit a LOT batch-correction pipeline on source and target embedding zarrs.""" + click.echo("=" * 60) + click.echo("LOT BATCH CORRECTION — FIT") + click.echo("=" * 60) + + try: + config_dict = load_config(config) + fit_config = LotFitConfig(**config_dict) + except ValidationError as e: + click.echo(f"\nConfiguration validation failed:\n{e}", err=True) + raise click.Abort() + except Exception as e: + click.echo(f"\nFailed to load configuration: {e}", err=True) + raise click.Abort() + + click.echo(f"\nConfiguration loaded: {config}") + click.echo(f" Source zarr: {fit_config.source_zarr}") + click.echo(f" Target zarr: {fit_config.target_zarr}") + click.echo(f" n_pca: {fit_config.n_pca}") + click.echo(f" ns_lot: {fit_config.ns_lot}") + click.echo(f" Random seed: {fit_config.random_seed}") + click.echo(f" Output: {fit_config.output_pipeline}") + + try: + pipeline = fit_lot_correction( + source_zarr=fit_config.source_zarr, + target_zarr=fit_config.target_zarr, + source_uninf_filter=fit_config.source_uninf_filter.to_dict(), + target_uninf_filter=fit_config.target_uninf_filter.to_dict(), + n_pca=fit_config.n_pca, + ns_lot=fit_config.ns_lot, + random_seed=fit_config.random_seed, + ) + click.echo(f"\nPipeline fitted — PCA explained variance: {pipeline['pca_variance_explained']:.1f}%") + save_lot_pipeline(pipeline, fit_config.output_pipeline) + click.echo(f"Pipeline saved to: {fit_config.output_pipeline}") + except Exception as e: + click.echo(f"\nFitting failed: {e}", err=True) + raise click.Abort() + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/lot_correction/lot_correction.py b/applications/dynaclr/src/dynaclr/evaluation/lot_correction/lot_correction.py new file mode 100644 index 000000000..bc56fed34 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/lot_correction/lot_correction.py @@ -0,0 +1,284 @@ +"""Core functions for LOT (Linear Optimal Transport) batch correction. + +Pipeline +-------- +1. Load source and target embedding zarrs (AnnData format). +2. Filter cells to the uninfected reference population in each dataset. +3. Fit a shared StandardScaler + PCA on the combined source + target cells. +4. Fit a LinearTransport (LOT) map in PCA space using uninfected cells only, + mapping source → target distribution. +5. Save the fitted pipeline (scaler, PCA, LOT) to disk with joblib. + +The saved pipeline can then be applied to any source zarr to produce a new +zarr whose embeddings are in the target's PCA coordinate system, corrected +for cross-platform batch effects. +""" + +import logging +from pathlib import Path +from typing import Union + +import anndata as ad +import joblib +import numpy as np +import ot +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +_logger = logging.getLogger(__name__) + + +def _to_np(X) -> np.ndarray: + """Convert sparse or dense matrix to float32 numpy array.""" + return np.array(X.toarray() if hasattr(X, "toarray") else X, dtype=np.float32) + + +def _apply_filter(obs, filter_spec: dict) -> np.ndarray: + """Return a boolean mask for rows of *obs* matching *filter_spec*. + + Parameters + ---------- + obs : pd.DataFrame + AnnData ``.obs`` table. + filter_spec : dict + Must contain ``"column"`` plus one of: + + * ``"startswith"`` – str or list[str]: keep rows where the column + value starts with any of the given prefixes. + * ``"equals"`` – str: keep rows where the column value equals the + given string. + + Returns + ------- + np.ndarray of bool + Boolean mask with the same length as *obs*. + """ + col = filter_spec["column"] + values = obs[col].astype(str) + + if "startswith" in filter_spec: + prefixes = filter_spec["startswith"] + if isinstance(prefixes, str): + prefixes = [prefixes] + mask = np.zeros(len(obs), dtype=bool) + for p in prefixes: + mask |= values.str.startswith(p).values + return mask + + if "equals" in filter_spec: + return (values == str(filter_spec["equals"])).values + + raise ValueError(f"filter_spec must contain either 'startswith' or 'equals'. Got: {list(filter_spec.keys())}") + + +def fit_lot_correction( + source_zarr: Union[str, Path], + target_zarr: Union[str, Path], + source_uninf_filter: dict, + target_uninf_filter: dict, + n_pca: int = 50, + ns_lot: int = 3000, + random_seed: int = 42, +) -> dict: + """Fit a shared PCA + LOT batch-correction pipeline. + + Parameters + ---------- + source_zarr : str or Path + Path to the source AnnData zarr (e.g. light-sheet embeddings). + target_zarr : str or Path + Path to the target AnnData zarr (e.g. confocal embeddings). + source_uninf_filter : dict + Filter spec selecting uninfected source cells used to fit LOT. + target_uninf_filter : dict + Filter spec selecting uninfected target cells used to fit LOT. + n_pca : int, optional + Number of PCA components, by default 50. + ns_lot : int, optional + Maximum number of cells subsampled per dataset for LOT fitting, + by default 3000. + random_seed : int, optional + Random seed for reproducibility, by default 42. + + Returns + ------- + dict with keys ``"scaler"``, ``"pca"``, ``"lot"``, ``"n_pca"``, + ``"ns_lot"``, ``"random_seed"``, ``"pca_variance_explained"``. + """ + rng = np.random.default_rng(random_seed) + + _logger.info("Loading source zarr: %s", source_zarr) + adata_src = ad.read_zarr(source_zarr) + adata_src.obs_names_make_unique() + + _logger.info("Loading target zarr: %s", target_zarr) + adata_tgt = ad.read_zarr(target_zarr) + adata_tgt.obs_names_make_unique() + + _logger.info("Source shape: %s Target shape: %s", adata_src.shape, adata_tgt.shape) + + X_src = _to_np(adata_src.X) + X_tgt = _to_np(adata_tgt.X) + + src_uninf_mask = _apply_filter(adata_src.obs, source_uninf_filter) + tgt_uninf_mask = _apply_filter(adata_tgt.obs, target_uninf_filter) + + _logger.info( + "Uninfected cells — source: %d / %d, target: %d / %d", + src_uninf_mask.sum(), + len(X_src), + tgt_uninf_mask.sum(), + len(X_tgt), + ) + + if src_uninf_mask.sum() < 5 or tgt_uninf_mask.sum() < 5: + raise ValueError( + "Too few uninfected cells to fit LOT " + f"(source={src_uninf_mask.sum()}, target={tgt_uninf_mask.sum()}). " + "Check your filter specifications." + ) + + _logger.info("Fitting shared StandardScaler + PCA-%d ...", n_pca) + scaler = StandardScaler() + X_combined_scaled = scaler.fit_transform(np.vstack([X_src, X_tgt])) + pca = PCA(n_components=n_pca, random_state=random_seed) + Z_all = pca.fit_transform(X_combined_scaled) + var_exp = pca.explained_variance_ratio_.sum() * 100 + _logger.info("PCA explained variance: %.1f%%", var_exp) + + n_src = len(X_src) + Z_src_uninf = Z_all[:n_src][src_uninf_mask] + Z_tgt_uninf = Z_all[n_src:][tgt_uninf_mask] + + ns_src = min(len(Z_src_uninf), ns_lot) + ns_tgt = min(len(Z_tgt_uninf), ns_lot) + idx_src = rng.choice(len(Z_src_uninf), ns_src, replace=False) + idx_tgt = rng.choice(len(Z_tgt_uninf), ns_tgt, replace=False) + + _logger.info("Fitting LOT (source subsample=%d, target subsample=%d) ...", ns_src, ns_tgt) + lot = ot.da.LinearTransport(reg=1e-3) + lot.fit(Xs=Z_src_uninf[idx_src], Xt=Z_tgt_uninf[idx_tgt]) + _logger.info("LOT fitted.") + + return { + "scaler": scaler, + "pca": pca, + "lot": lot, + "n_pca": n_pca, + "ns_lot": ns_lot, + "random_seed": random_seed, + "pca_variance_explained": float(var_exp), + } + + +def apply_lot_correction( + input_zarr: Union[str, Path], + pipeline: dict, + output_zarr: Union[str, Path], + overwrite: bool = False, +) -> None: + """Apply a fitted LOT pipeline to an embedding zarr. + + Transforms all cells through StandardScaler → PCA → LOT and writes an + AnnData zarr whose ``.X`` contains the corrected embeddings in the + target's PCA space. All ``.obs`` metadata is preserved. + + Parameters + ---------- + input_zarr : str or Path + Path to the source AnnData zarr to correct. + pipeline : dict + Fitted pipeline as returned by :func:`fit_lot_correction`. + output_zarr : str or Path + Path to write the corrected AnnData zarr. + overwrite : bool, optional + If ``False`` (default) and *output_zarr* already exists, raise. + """ + import shutil + + import pandas as pd + + output_zarr = Path(output_zarr) + if output_zarr.exists(): + if not overwrite: + raise FileExistsError(f"Output path already exists: {output_zarr}. Set overwrite=true to overwrite.") + shutil.rmtree(output_zarr) + + _logger.info("Loading input zarr: %s", input_zarr) + adata_in = ad.read_zarr(input_zarr) + adata_in.obs_names_make_unique() + + X = _to_np(adata_in.X) + _logger.info("Input shape: %s", adata_in.shape) + + scaler = pipeline["scaler"] + pca = pipeline["pca"] + lot = pipeline["lot"] + + _logger.info("Applying StandardScaler → PCA → LOT ...") + Z = pca.transform(scaler.transform(X)) + Z_corrected = lot.transform(Z) + _logger.info("Corrected embeddings shape: %s (n_pca=%d)", Z_corrected.shape, pipeline["n_pca"]) + + obs = adata_in.obs.copy() + for col in obs.columns: + dtype = obs[col].dtype + if isinstance(dtype, pd.StringDtype): + obs[col] = obs[col].astype(object) + elif isinstance(dtype, pd.CategoricalDtype) and isinstance(dtype.categories.dtype, pd.StringDtype): + obs[col] = obs[col].astype(object).astype("category") + + try: + ad.settings.allow_write_nullable_strings = True + except AttributeError: + pass + + adata_out = ad.AnnData(X=Z_corrected.astype(np.float32), obs=obs) + adata_out.uns["lot_correction"] = { + "source_zarr": str(input_zarr), + "n_pca": pipeline["n_pca"], + "pca_variance_explained": pipeline.get("pca_variance_explained"), + } + + _logger.info("Writing corrected zarr: %s", output_zarr) + adata_out.write_zarr(output_zarr) + _logger.info("Done.") + + +def save_lot_pipeline(pipeline: dict, path: Union[str, Path]) -> None: + """Save a fitted LOT pipeline to disk using joblib. + + Parameters + ---------- + pipeline : dict + Fitted pipeline as returned by :func:`fit_lot_correction`. + path : str or Path + Output path (e.g. ``lot_pipeline.pkl``). + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + joblib.dump(pipeline, path) + _logger.info("Pipeline saved to %s", path) + + +def load_lot_pipeline(path: Union[str, Path]) -> dict: + """Load a fitted LOT pipeline from disk. + + Parameters + ---------- + path : str or Path + Path to the saved pipeline file. + + Returns + ------- + dict + Pipeline with keys ``"scaler"``, ``"pca"``, ``"lot"``. + """ + pipeline = joblib.load(path) + _logger.info( + "Pipeline loaded from %s (n_pca=%d, pca_var=%.1f%%)", + path, + pipeline["n_pca"], + pipeline.get("pca_variance_explained", float("nan")), + ) + return pipeline diff --git a/packages/viscy-data/src/viscy_data/triplet.py b/packages/viscy-data/src/viscy_data/triplet.py index deed9fe57..5f6592ff7 100644 --- a/packages/viscy-data/src/viscy_data/triplet.py +++ b/packages/viscy-data/src/viscy_data/triplet.py @@ -40,6 +40,25 @@ _logger = logging.getLogger("lightning.pytorch") +def _read_pixel_size(data_path: str | Path) -> float: + """Read the X pixel size (µm/pixel) from the first FOV in an OME-Zarr dataset. + + Parameters + ---------- + data_path : str | Path + Path to the OME-Zarr plate or position. + + Returns + ------- + float + X pixel size in micrometers per pixel. + """ + with open_ome_zarr(data_path, mode="r") as store: + for _, pos in store.positions(): + return float(pos.scale[-1]) + raise ValueError(f"No positions found in {data_path}") + + def _default_tensorstore_config(cache_pool_bytes: int = 0) -> TensorStoreConfig: """Build a TensorStoreConfig with SLURM-aware concurrency.""" cpus = os.environ.get("SLURM_CPUS_PER_TASK") @@ -316,6 +335,7 @@ def __init__( pin_memory: bool = False, z_window_size: int | None = None, cache_pool_bytes: int = 0, + reference_pixel_size: float | None = None, ): """Lightning data module for triplet sampling of patches. @@ -330,9 +350,20 @@ def __init__( z_range : tuple[int, int] Range of valid z-slices initial_yx_patch_size : tuple[int, int], optional - XY size of the initially sampled image patch, by default (512, 512) + YX size of the initially sampled image patch, by default (512, 512). + Ignored when ``reference_pixel_size`` is set — the patch size is then + computed automatically from the pixel-size ratio. final_yx_patch_size : tuple[int, int], optional Output patch size, by default (224, 224) + reference_pixel_size : float | None, optional + X pixel size (µm/pixel) of the dataset used to train the model. + When provided, reads the pixel size of the inference dataset from + its OME-Zarr metadata and computes + ``initial_yx_patch_size = round(final_yx_patch_size * + reference_pixel_size / inference_pixel_size)`` so that the same + physical area is covered. The extracted patch is then rescaled to + ``final_yx_patch_size`` with bilinear interpolation. By default + ``None`` (no rescaling). split_ratio : float, optional Ratio of training samples, by default 0.8 batch_size : int, optional @@ -409,8 +440,35 @@ def __init__( self.return_negative = return_negative self.augment_validation = augment_validation self._cache_pool_bytes = cache_pool_bytes - self._augmentation_transform = Compose(self.normalizations + self.augmentations) - self._no_augmentation_transform = Compose(self.normalizations) + if reference_pixel_size is not None: + inference_pixel_size = _read_pixel_size(data_path) + scale = reference_pixel_size / inference_pixel_size + self.initial_yx_patch_size = tuple(round(s * scale) for s in final_yx_patch_size) + _logger.info( + f"Pixel size rescaling enabled: " + f"reference={reference_pixel_size:.4f} µm/px, " + f"inference={inference_pixel_size:.4f} µm/px, " + f"scale={scale:.4f}. " + f"Extracting {self.initial_yx_patch_size} px patches " + f"and resizing to {final_yx_patch_size} px." + ) + from viscy_transforms import BatchedZoomd + + scale_yx = ( + final_yx_patch_size[0] / self.initial_yx_patch_size[0], + final_yx_patch_size[1] / self.initial_yx_patch_size[1], + ) + rescale_transform = BatchedZoomd( + keys=list(self.source_channel), + scale_factor=(1.0, *scale_yx), + mode="bilinear", + antialias=True, + ) + self._augmentation_transform = Compose(self.normalizations + self.augmentations + [rescale_transform]) + self._no_augmentation_transform = Compose(self.normalizations + [rescale_transform]) + else: + self._augmentation_transform = Compose(self.normalizations + self.augmentations) + self._no_augmentation_transform = Compose(self.normalizations) def _align_tracks_tables_with_positions( self, From 5e6407a2b741462fbc7405b37dbbf44a118a6681 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 14 May 2026 09:58:11 -0700 Subject: [PATCH 02/13] Add visualization script for TripletDataModule scale-aware rescaling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Plots n cell patches side-by-side: - Left: raw patch at initial_yx_patch_size (larger physical area) - Right: same patch bilinearly downscaled to final_yx_patch_size (model input) Both columns share the same percentile contrast window so differences are spatial. Physical scale (µm × µm) is shown in each panel title. Usage: python visualize_triplet_rescaling.py --data-path /path/to/data.zarr --tracks-path /path/to/tracks --source-channel Phase3D --z-range 0 5 --final-yx-patch-size 224 224 --reference-pixel-size 0.325 --output rescaling_comparison.png Co-Authored-By: Claude Sonnet 4.6 --- .../visualize_triplet_rescaling.py | 253 ++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 applications/dynaclr/scripts/dataloader_inspection/visualize_triplet_rescaling.py diff --git a/applications/dynaclr/scripts/dataloader_inspection/visualize_triplet_rescaling.py b/applications/dynaclr/scripts/dataloader_inspection/visualize_triplet_rescaling.py new file mode 100644 index 000000000..75d4bb942 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/visualize_triplet_rescaling.py @@ -0,0 +1,253 @@ +r"""Visualize scale-aware patch rescaling in TripletDataModule. + +Loads a few cell patches from an OME-Zarr dataset using TripletDataModule +with ``reference_pixel_size`` set, then plots each patch in two columns: + + Left — raw patch at ``initial_yx_patch_size`` (larger physical area sampled + to match the reference pixel size) + Right — the same patch bilinearly downscaled to ``final_yx_patch_size`` + (what the model actually receives) + +Both columns use the same percentile-based grayscale contrast window so +that spatial content differences are visible rather than intensity shifts. +A physical-scale annotation (µm × µm) is printed below each patch. + +Usage:: + + python visualize_triplet_rescaling.py \\ + --data-path /path/to/data.zarr \\ + --tracks-path /path/to/tracks \\ + --source-channel Phase3D \\ + --z-range 0 5 \\ + --final-yx-patch-size 224 224 \\ + --reference-pixel-size 0.325 \\ + --n-samples 6 \\ + --output rescaling_comparison.png +""" + +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch + +from viscy_data.triplet import TripletDataModule, _read_pixel_size + +# ── CLI ─────────────────────────────────────────────────────────────────────── + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + p = argparse.ArgumentParser( + description="Visualize TripletDataModule scale-aware patch rescaling.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p.add_argument("--data-path", required=True, help="Path to OME-Zarr plate/position.") + p.add_argument("--tracks-path", required=True, help="Path to tracks CSV directory.") + p.add_argument( + "--source-channel", + required=True, + nargs="+", + help="Channel name(s) to load (e.g. Phase3D).", + ) + p.add_argument( + "--z-range", + required=True, + type=int, + nargs=2, + metavar=("Z_START", "Z_STOP"), + help="Z-slice range [start, stop).", + ) + p.add_argument( + "--final-yx-patch-size", + type=int, + nargs=2, + default=[224, 224], + metavar=("Y", "X"), + help="Target patch size fed to the model (pixels).", + ) + p.add_argument( + "--reference-pixel-size", + type=float, + required=True, + help="X pixel size (µm/px) of the model's training dataset.", + ) + p.add_argument( + "--n-samples", + type=int, + default=6, + help="Number of cell patches to visualize.", + ) + p.add_argument( + "--z-slice", + type=int, + default=None, + help="Z index within the patch to display. Defaults to middle slice.", + ) + p.add_argument( + "--output", + type=Path, + default=Path("rescaling_comparison.png"), + help="Output PNG path.", + ) + return p.parse_args() + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + + +def _percentile_norm(img: np.ndarray, lo: float = 1.0, hi: float = 99.0): + """Return (vmin, vmax) for percentile-based display.""" + vmin, vmax = np.percentile(img, [lo, hi]) + if vmax <= vmin: + vmax = vmin + 1.0 + return float(vmin), float(vmax) + + +def _rescale_yx( + patch: torch.Tensor, + target_yx: tuple[int, int], +) -> torch.Tensor: + """Bilinear-rescale YX of a (C, Z, H, W) patch to target_yx.""" + c, z, h, w = patch.shape + flat = patch.reshape(c * z, 1, h, w).float() + out = torch.nn.functional.interpolate( + flat, + size=target_yx, + mode="bilinear", + align_corners=False, + antialias=True, + ) + return out.reshape(c, z, *target_yx) + + +def _mid_slice(patch: torch.Tensor, z_idx: int | None) -> np.ndarray: + """Return a 2-D (H, W) numpy array from (C, Z, H, W), channel 0, given z.""" + z_size = patch.shape[1] + z = z_idx if z_idx is not None else z_size // 2 + z = max(0, min(z, z_size - 1)) + return patch[0, z].numpy() + + +# ── Main ────────────────────────────────────────────────────────────────────── + + +def main(): + """Run the visualization CLI.""" + args = parse_args() + + final_yx = tuple(args.final_yx_patch_size) + z_range = tuple(args.z_range) + + # Build the data module with scale-aware rescaling enabled. + dm = TripletDataModule( + data_path=args.data_path, + tracks_path=args.tracks_path, + source_channel=args.source_channel, + z_range=z_range, + final_yx_patch_size=final_yx, + reference_pixel_size=args.reference_pixel_size, + batch_size=args.n_samples, + num_workers=0, + ) + dm.setup("predict") + + inference_pixel_size = _read_pixel_size(args.data_path) + initial_yx = dm.initial_yx_patch_size + scale = args.reference_pixel_size / inference_pixel_size + + print( + f"Reference pixel size : {args.reference_pixel_size:.4f} µm/px\n" + f"Inference pixel size : {inference_pixel_size:.4f} µm/px\n" + f"Scale factor : {scale:.4f}\n" + f"initial_yx_patch_size: {initial_yx}\n" + f"final_yx_patch_size : {final_yx}\n" + ) + + # Draw samples directly from the dataset (raw, before any transforms). + n = min(args.n_samples, len(dm.predict_dataset)) + raw_batch = dm.predict_dataset.__getitems__(list(range(n))) + raw_patches = raw_batch["anchor"] # (B, C, Z, initial_Y, initial_X) + + # ── Plot ────────────────────────────────────────────────────────────────── + fig, axes = plt.subplots( + n, + 2, + figsize=(8, 4 * n), + squeeze=False, + ) + + for i in range(n): + raw = raw_patches[i] # (C, Z, initial_Y, initial_X) + rescaled = _rescale_yx(raw, final_yx) # (C, Z, final_Y, final_X) + + raw_2d = _mid_slice(raw, args.z_slice) + rescaled_2d = _mid_slice(rescaled, args.z_slice) + + # Shared contrast from the raw patch so differences are spatial only. + vmin, vmax = _percentile_norm(raw_2d) + + phys_raw_y = initial_yx[0] * inference_pixel_size + phys_raw_x = initial_yx[1] * inference_pixel_size + phys_final_y = final_yx[0] * args.reference_pixel_size + phys_final_x = final_yx[1] * args.reference_pixel_size + + # Left column: raw patch + ax = axes[i, 0] + ax.imshow(raw_2d, cmap="gray", vmin=vmin, vmax=vmax, interpolation="nearest") + ax.set_title( + f"Sample {i} — raw patch\n{initial_yx[0]}×{initial_yx[1]} px ({phys_raw_y:.1f}×{phys_raw_x:.1f} µm)", + fontsize=9, + ) + ax.axis("off") + + # Right column: rescaled patch + ax = axes[i, 1] + ax.imshow(rescaled_2d, cmap="gray", vmin=vmin, vmax=vmax, interpolation="nearest") + ax.set_title( + f"Sample {i} — rescaled (model input)\n" + f"{final_yx[0]}×{final_yx[1]} px " + f"({phys_final_y:.1f}×{phys_final_x:.1f} µm)", + fontsize=9, + ) + ax.axis("off") + + axes[0, 0].annotate( + "RAW (initial_yx_patch_size)", + xy=(0.5, 1.12), + xycoords="axes fraction", + ha="center", + fontsize=11, + fontweight="bold", + color="#e74c3c", + ) + axes[0, 1].annotate( + "RESCALED (final_yx_patch_size)", + xy=(0.5, 1.12), + xycoords="axes fraction", + ha="center", + fontsize=11, + fontweight="bold", + color="#2ecc71", + ) + + fig.suptitle( + f"Scale-aware patch rescaling\n" + f"reference={args.reference_pixel_size} µm/px · " + f"inference={inference_pixel_size:.4f} µm/px · " + f"scale={scale:.3f}", + fontsize=12, + fontweight="bold", + y=1.01, + ) + + fig.tight_layout() + args.output.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(args.output, dpi=150, bbox_inches="tight") + print(f"Saved: {args.output}") + plt.close(fig) + + +if __name__ == "__main__": + main() From 378a4ac5b19087552d5e9fe83f52c49385686555 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 29 Jun 2026 13:16:02 -0700 Subject: [PATCH 03/13] fix(viscy-data): handle nested timepoint_statistics in _collate_norm_meta _collate_norm_meta iterated every normalization level and called torch.stack on each stat, but timepoint_statistics is nested {timepoint: {stat: tensor}} rather than flat {stat: tensor}. Any zarr carrying timepoint_statistics (alongside fov/dataset stats) crashed batch collation in TripletDataModule with "expected Tensor as element 0 ... but got dict". Stack within each timepoint sub-dict instead. Co-Authored-By: Claude Opus 4.8 (1M context) --- packages/viscy-data/src/viscy_data/_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/packages/viscy-data/src/viscy_data/_utils.py b/packages/viscy-data/src/viscy_data/_utils.py index cad79cb7d..bf133e184 100644 --- a/packages/viscy-data/src/viscy_data/_utils.py +++ b/packages/viscy-data/src/viscy_data/_utils.py @@ -187,6 +187,13 @@ def _collate_norm_meta(norm_metas: list[NormMeta]) -> NormMeta: if level_stats is None: result[ch][level] = None continue + if level == "timepoint_statistics": + # Nested {timepoint: {stat: tensor}}; stack within each timepoint. + result[ch][level] = { + tp: {stat: torch.stack([m[ch][level][tp][stat] for m in norm_metas]) for stat in tp_stats} + for tp, tp_stats in level_stats.items() + } + continue result[ch][level] = {stat: torch.stack([m[ch][level][stat] for m in norm_metas]) for stat in level_stats} return result From 19a3301490d0b0ab20af3fe4e3aaedbb3c9e03d2 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 29 Jun 2026 13:17:21 -0700 Subject: [PATCH 04/13] feat(viscy-data): add Z-reduction and focus-centered z_range to TripletDataModule Lets a 3D OME-Zarr feed a 2D model without materializing a separate MIP dataset, and centers the extracted Z window on each FOV's focus plane. - z_reduction ("mip"/"center"): collapse the extracted z_range to one slice via BatchedChannelWiseZReductiond. Label-free channels (resolved by parse_channel_name) take the center slice; others are max-projected. on_after_batch_transfer expects Z=1 when reduction is on. - z_extraction_window/z_focus_offset/focus_channel: resolve a per-FOV focus-centered window from each position's focus_slice[ch].fov_statistics.z_focus_mean (fallback z_total//2), all windows the same width. z_range stays as an explicit override; exactly one of z_range / z_extraction_window must be given. Per-FOV windows are resolved at setup() and looked up per patch in the dataset. Tests cover both reduction strategies (discriminating center vs MIP), the normalize-then-reduce order, per-FOV focus resolution, and the z_range/z_extraction_window XOR guard. Co-Authored-By: Claude Opus 4.8 (1M context) --- packages/viscy-data/src/viscy_data/triplet.py | 214 ++++++++++++++++-- packages/viscy-data/tests/test_triplet.py | 184 ++++++++++++++- 2 files changed, 375 insertions(+), 23 deletions(-) diff --git a/packages/viscy-data/src/viscy_data/triplet.py b/packages/viscy-data/src/viscy_data/triplet.py index 5f6592ff7..2d7f8a32f 100644 --- a/packages/viscy-data/src/viscy_data/triplet.py +++ b/packages/viscy-data/src/viscy_data/triplet.py @@ -59,6 +59,84 @@ def _read_pixel_size(data_path: str | Path) -> float: raise ValueError(f"No positions found in {data_path}") +def _focus_window(z_focus_mean: float | None, z_total: int, z_extraction_window: int, z_focus_offset: float) -> slice: + """Compute a fixed-width Z window centered on a focus plane. + + Parameters + ---------- + z_focus_mean : float or None + Focus plane (slice index). When ``None``, the window is centered on + ``z_total // 2``. + z_total : int + Total number of Z slices available in the FOV. + z_extraction_window : int + Window width (clamped to ``z_total``). + z_focus_offset : float + Fraction of the window placed below the focus plane (0.5 = symmetric). + + Returns + ------- + slice + ``slice(z_start, z_end)`` with ``z_end - z_start == min(z_extraction_window, z_total)``. + """ + z_center = int(round(z_focus_mean)) if z_focus_mean is not None else z_total // 2 + effective_extract = min(z_extraction_window, z_total) + z_below = int(effective_extract * z_focus_offset) + z_start = max(0, z_center - z_below) + z_end = min(z_total, z_start + effective_extract) + z_start = max(0, z_end - effective_extract) + return slice(z_start, z_end) + + +def _resolve_per_fov_z_ranges( + positions: "list[Position]", + z_extraction_window: int, + z_focus_offset: float, + focus_channel: str, +) -> dict[str, slice]: + """Resolve a per-FOV focus-centered Z window for each position. + + Each FOV's window is centered on its own + ``zattrs["focus_slice"][focus_channel]["fov_statistics"]["z_focus_mean"]`` + (the per-FOV focus plane used by the cell-index pipeline). All windows have + the same width (``z_extraction_window``), so the extracted patch Z-size is + uniform across FOVs. FOVs lacking a recorded focus fall back to the + geometric center (``z_total // 2``). + + Parameters + ---------- + positions : list[Position] + Open OME-Zarr positions. + z_extraction_window : int + Window width. + z_focus_offset : float + Fraction of the window placed below the focus plane. + focus_channel : str + Channel name to look up in each FOV's ``focus_slice`` metadata. + + Returns + ------- + dict[str, slice] + Map from stripped FOV name to its ``slice(z_start, z_end)``. + """ + z_ranges: dict[str, slice] = {} + for pos in positions: + fov_name = pos.zgroup.name.strip("/") + z_total = pos["0"].shape[2] + fov_stats = pos.zattrs.get("focus_slice", {}).get(focus_channel, {}).get("fov_statistics", {}) + z_focus_mean = fov_stats.get("z_focus_mean") + z_ranges[fov_name] = _focus_window(z_focus_mean, z_total, z_extraction_window, z_focus_offset) + _logger.info( + "FOV '%s': focus-centered z_range=%s (z_total=%d, z_focus_mean=%s, window=%d).", + fov_name, + (z_ranges[fov_name].start, z_ranges[fov_name].stop), + z_total, + z_focus_mean, + min(z_extraction_window, z_total), + ) + return z_ranges + + def _default_tensorstore_config(cache_pool_bytes: int = 0) -> TensorStoreConfig: """Build a TensorStoreConfig with SLURM-aware concurrency.""" cpus = os.environ.get("SLURM_CPUS_PER_TASK") @@ -78,7 +156,7 @@ def __init__( tracks_tables: "list[pd.DataFrame]", channel_names: list[str], initial_yx_patch_size: tuple[int, int], - z_range: slice, + z_range: "slice | dict[str, slice]", fit: bool = True, predict_cells: bool = False, include_fov_names: list[str] | None = None, @@ -98,8 +176,10 @@ def __init__( Input channel names initial_yx_patch_size : tuple[int, int] YX size of the initially sampled image patch - z_range : slice - Range of Z-slices + z_range : slice or dict[str, slice] + Range of Z-slices. A single ``slice`` applies to every FOV; a + ``dict`` maps each stripped FOV name to its own slice (per-FOV + focus-centered windows). All slices must have the same width. fit : bool, optional Fitting mode in which the full triplet will be sampled, only sample anchor if ``False``, by default True @@ -141,6 +221,12 @@ def __init__( self.return_negative = return_negative self._tensorstores: dict[str, "ts.TensorStore"] = {} + def _fov_z_range(self, fov_name: str) -> slice: + """Return the Z slice for a FOV (per-FOV dict or shared single slice).""" + if isinstance(self.z_range, dict): + return self.z_range[fov_name.strip("/")] + return self.z_range + def _get_tensorstore(self, position: Position) -> "ts.TensorStore": """Get cached tensorstore handle, opening via iohub's tensorstore impl on miss. @@ -175,8 +261,9 @@ def _filter_tracks(self, tracks_tables: "list[pd.DataFrame]") -> "pd.DataFrame": tracks["fov_name"] = pos.zgroup.name.strip("/") tracks["global_track_id"] = tracks["fov_name"].str.cat(tracks["track_id"].astype(str), sep="_") image: ImageArray = pos["0"] - if self.z_range.stop > image.slices: - raise ValueError(f"Z range {self.z_range} exceeds image with Z={image.slices}") + z_range = self._fov_z_range(pos.zgroup.name) + if z_range.stop > image.slices: + raise ValueError(f"Z range {z_range} exceeds image with Z={image.slices}") y_range = (y_exclude, image.height - y_exclude) x_range = (x_exclude, image.width - x_exclude) # FIXME: Check if future time points are available after interval @@ -254,7 +341,7 @@ def _slice_patch(self, track_row: "pd.Series") -> "tuple[ts.TensorStore, NormMet patch = image.oindex[ time, [int(i) for i in self.channel_indices], - self.z_range, + self._fov_z_range(position.zgroup.name), slice(y_center - y_half, y_center + y_half), slice(x_center - x_half, x_center + x_half), ] @@ -314,7 +401,10 @@ def __init__( data_path: str, tracks_path: str, source_channel: str | Sequence[str], - z_range: tuple[int, int], + z_range: tuple[int, int] | None = None, + z_extraction_window: int | None = None, + z_focus_offset: float = 0.5, + focus_channel: str | None = None, initial_yx_patch_size: tuple[int, int] = (512, 512), final_yx_patch_size: tuple[int, int] = (224, 224), split_ratio: float = 0.8, @@ -336,6 +426,7 @@ def __init__( z_window_size: int | None = None, cache_pool_bytes: int = 0, reference_pixel_size: float | None = None, + z_reduction: Literal["mip", "center"] | None = None, ): """Lightning data module for triplet sampling of patches. @@ -347,8 +438,21 @@ def __init__( Tracks labels dataset path source_channel : str | Sequence[str] List of input channel names - z_range : tuple[int, int] - Range of valid z-slices + z_range : tuple[int, int] or None, optional + Explicit ``(z_start, z_end)`` slice range. Mutually exclusive with + ``z_extraction_window``: provide exactly one. When ``None``, the + range is resolved from the focus plane via ``z_extraction_window``. + z_extraction_window : int or None, optional + Number of Z slices to extract, centered on the plate's focus plane + (read from ``zattrs["focus_slice"]``). Mutually exclusive with + ``z_range``. By default ``None`` (use the explicit ``z_range``). + z_focus_offset : float, optional + Fraction of ``z_extraction_window`` placed below the focus plane, + by default 0.5 (symmetric). Only used with ``z_extraction_window``. + focus_channel : str or None, optional + Channel name whose ``focus_slice`` metadata centers the window. + Defaults to the first ``source_channel``. Only used with + ``z_extraction_window``. initial_yx_patch_size : tuple[int, int], optional YX size of the initially sampled image patch, by default (512, 512). Ignored when ``reference_pixel_size`` is set — the patch size is then @@ -364,6 +468,17 @@ def __init__( physical area is covered. The extracted patch is then rescaled to ``final_yx_patch_size`` with bilinear interpolation. By default ``None`` (no rescaling). + z_reduction : {"mip", "center"} or None, optional + Collapse the extracted ``z_range`` window to a single Z-slice so a + 3D dataset can feed a 2D model without materializing a separate + MIP dataset. Label-free channels take the center slice and all other + channels are max-projected; channel type is resolved per channel + name via :func:`viscy_data.channel_utils.parse_channel_name`. The + value (``"mip"`` or ``"center"``) only sets the fallback used when + no channel can be classified as label-free. The caller controls + which Z-planes are collapsed by setting ``z_range`` (e.g. a window + centered on the focus plane). By default ``None`` (no reduction; + full ``z_range`` is kept). split_ratio : float, optional Ratio of training samples, by default 0.8 batch_size : int, optional @@ -412,11 +527,18 @@ def __init__( """ if num_workers > 1: warnings.warn("Using more than 1 thread worker will likely degrade performance.") + if (z_range is None) == (z_extraction_window is None): + raise ValueError("Provide exactly one of 'z_range' or 'z_extraction_window'.") + # Extraction window width is known without opening the zarr: it is the + # explicit z_range span or z_extraction_window. Per-FOV focus centering + # (when z_extraction_window is set) is resolved at setup() time, where + # the positions are open. + extraction_width = (z_range[1] - z_range[0]) if z_range is not None else z_extraction_window super().__init__( data_path=data_path, source_channel=source_channel, target_channel=[], - z_window_size=z_window_size or z_range[1] - z_range[0], + z_window_size=z_window_size or extraction_width, split_ratio=split_ratio, batch_size=batch_size, num_workers=num_workers, @@ -428,7 +550,12 @@ def __init__( prefetch_factor=prefetch_factor, pin_memory=pin_memory, ) - self.z_range = slice(*z_range) + self.z_range = slice(*z_range) if z_range is not None else None + self._z_extraction_window = z_extraction_window + self._z_focus_offset = z_focus_offset + self._focus_channel = focus_channel or ( + source_channel[0] if isinstance(source_channel, (list, tuple)) else source_channel + ) self.tracks_path = Path(tracks_path) self.initial_yx_patch_size = initial_yx_patch_size self._include_wells = fit_include_wells @@ -440,6 +567,11 @@ def __init__( self.return_negative = return_negative self.augment_validation = augment_validation self._cache_pool_bytes = cache_pool_bytes + self.z_reduction = z_reduction + + # Transforms appended after normalization and augmentation, in order. + extra_transforms: list[MapTransform] = [] + if reference_pixel_size is not None: inference_pixel_size = _read_pixel_size(data_path) scale = reference_pixel_size / inference_pixel_size @@ -458,17 +590,36 @@ def __init__( final_yx_patch_size[0] / self.initial_yx_patch_size[0], final_yx_patch_size[1] / self.initial_yx_patch_size[1], ) - rescale_transform = BatchedZoomd( - keys=list(self.source_channel), - scale_factor=(1.0, *scale_yx), - mode="bilinear", - antialias=True, + extra_transforms.append( + BatchedZoomd( + keys=list(self.source_channel), + scale_factor=(1.0, *scale_yx), + mode="bilinear", + antialias=True, + ) ) - self._augmentation_transform = Compose(self.normalizations + self.augmentations + [rescale_transform]) - self._no_augmentation_transform = Compose(self.normalizations + [rescale_transform]) - else: - self._augmentation_transform = Compose(self.normalizations + self.augmentations) - self._no_augmentation_transform = Compose(self.normalizations) + + if z_reduction is not None: + from viscy_data.channel_utils import parse_channel_name + from viscy_transforms import BatchedChannelWiseZReductiond + + labelfree_keys = [ch for ch in self.source_channel if parse_channel_name(ch)["channel_type"] == "labelfree"] + mip_keys = [ch for ch in self.source_channel if ch not in labelfree_keys] + _logger.info( + f"Z-reduction enabled (default_strategy={z_reduction}): collapsing z_range to 1 slice. " + f"MIP channels={mip_keys}, center-slice channels={labelfree_keys}." + ) + extra_transforms.append( + BatchedChannelWiseZReductiond( + keys=list(self.source_channel), + labelfree_keys=labelfree_keys, + default_strategy=z_reduction, + allow_missing_keys=True, + ) + ) + + self._augmentation_transform = Compose(self.normalizations + self.augmentations + extra_transforms) + self._no_augmentation_transform = Compose(self.normalizations + extra_transforms) def _align_tracks_tables_with_positions( self, @@ -508,9 +659,26 @@ def _base_dataset_settings(self) -> dict: "time_interval": self.time_interval, } + def _resolve_z_range(self, positions: "list[Position]") -> "slice | dict[str, slice]": + """Resolve the Z range for a set of positions. + + Returns the explicit ``self.z_range`` slice when one was given, otherwise + a per-FOV focus-centered dict resolved from each position's + ``focus_slice`` metadata. + """ + if self.z_range is not None: + return self.z_range + return _resolve_per_fov_z_ranges( + positions=positions, + z_extraction_window=self._z_extraction_window, + z_focus_offset=self._z_focus_offset, + focus_channel=self._focus_channel, + ) + def _setup_fit(self, dataset_settings: dict): """Set up training and validation triplet datasets.""" positions, tracks_tables = self._align_tracks_tables_with_positions() + dataset_settings = {**dataset_settings, "z_range": self._resolve_z_range(positions)} shuffled_indices = self._set_fit_global_state(len(positions)) positions = [positions[i] for i in shuffled_indices] tracks_tables = [tracks_tables[i] for i in shuffled_indices] @@ -544,6 +712,7 @@ def _setup_predict(self, dataset_settings: dict): """Set up the prediction triplet dataset.""" self._set_predict_global_state() positions, tracks_tables = self._align_tracks_tables_with_positions() + dataset_settings = {**dataset_settings, "z_range": self._resolve_z_range(positions)} self.predict_dataset = TripletDataset( positions=positions, tracks_tables=tracks_tables, @@ -621,7 +790,8 @@ def on_after_batch_transfer(self, batch, dataloader_idx: int): if isinstance(batch, Tensor): # example array return batch - expected_spatial = (self.z_window_size, *self.yx_patch_size) + expected_z = 1 if self.z_reduction is not None else self.z_window_size + expected_spatial = (expected_z, *self.yx_patch_size) for key in ["anchor", "positive", "negative"]: if key in batch: norm_meta_key = f"{key}_norm_meta" diff --git a/packages/viscy-data/tests/test_triplet.py b/packages/viscy-data/tests/test_triplet.py index afb582341..7d0be1e6c 100644 --- a/packages/viscy-data/tests/test_triplet.py +++ b/packages/viscy-data/tests/test_triplet.py @@ -1,8 +1,10 @@ import pandas as pd +import torch from iohub import open_ome_zarr -from pytest import mark +from pytest import mark, raises from viscy_data import TripletDataModule, TripletDataset +from viscy_data.channel_utils import parse_channel_name @mark.parametrize("include_wells", [None, ["A/1", "A/2", "B/1"]]) @@ -107,6 +109,186 @@ def test_datamodule_z_window_size(preprocessed_hcs_dataset, tracks_hcs_dataset, ) +def test_z_range_xor_extraction_window(preprocessed_hcs_dataset, tracks_hcs_dataset): + """Exactly one of z_range / z_extraction_window must be provided.""" + with open_ome_zarr(preprocessed_hcs_dataset) as dataset: + channel_names = dataset.channel_names + common = dict( + data_path=preprocessed_hcs_dataset, + tracks_path=tracks_hcs_dataset, + source_channel=channel_names, + num_workers=0, + ) + with raises(ValueError, match="exactly one"): + TripletDataModule(z_range=None, z_extraction_window=None, **common) + with raises(ValueError, match="exactly one"): + TripletDataModule(z_range=(4, 9), z_extraction_window=8, **common) + + +@mark.parametrize("z_focus_offset", [0.5, 0.3]) +def test_focus_centered_z_range(tmp_path_factory, preprocessed_hcs_dataset, tracks_hcs_dataset, z_focus_offset): + """z_extraction_window resolves a per-FOV focus-centered z_range from zattrs. + + Writes a different per-FOV ``z_focus_mean`` to each position's ``focus_slice`` + ``fov_statistics`` (on a private copy of the session-scoped dataset), then + checks each FOV gets its own window of ``z_extraction_window`` slices centered + on its focus plane with ``z_focus_offset`` of the window below it. + """ + import shutil + + z_extraction_window = 5 + # Copy the session-scoped dataset so writing focus_slice does not leak. + data_path = tmp_path_factory.mktemp("focus") / "data.zarr" + shutil.copytree(preprocessed_hcs_dataset, data_path) + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + fov_names = [name for name, _ in dataset.positions()] + z_total = dataset[fov_names[0]]["0"].shape[2] + focus_channel = channel_names[0] + + # Give each FOV a distinct focus plane so per-FOV resolution is exercised. + per_fov_focus = {fov: float(3 + i % (z_total - 4)) for i, fov in enumerate(fov_names)} + with open_ome_zarr(data_path, mode="r+") as dataset: + for fov, pos in dataset.positions(): + pos.zattrs["focus_slice"] = {focus_channel: {"fov_statistics": {"z_focus_mean": per_fov_focus[fov]}}} + + def expected_window(z_focus_mean): + z_center = round(z_focus_mean) + z_below = int(z_extraction_window * z_focus_offset) + z_start = max(0, z_center - z_below) + z_end = min(z_total, z_start + z_extraction_window) + return slice(max(0, z_end - z_extraction_window), z_end) + + dm = TripletDataModule( + data_path=data_path, + tracks_path=tracks_hcs_dataset, + source_channel=channel_names, + z_extraction_window=z_extraction_window, + z_focus_offset=z_focus_offset, + focus_channel=focus_channel, + initial_yx_patch_size=(32, 32), + final_yx_patch_size=(32, 32), + num_workers=0, + batch_size=4, + return_negative=True, + ) + assert dm.z_range is None # explicit z_range not given; resolved per-FOV at setup + dm.setup(stage="fit") + resolved = dm.train_dataset.z_range + assert isinstance(resolved, dict) + # Every resolved FOV window matches its own focus plane and has uniform width. + for fov, z_slice in resolved.items(): + assert z_slice == expected_window(per_fov_focus[fov.strip("/")]), f"FOV {fov} window mismatch" + assert z_slice.stop - z_slice.start == z_extraction_window + for batch in dm.train_dataloader(): + dm.on_after_batch_transfer(batch, 0) + assert batch["anchor"].shape[2] == z_extraction_window + break + break + + +@mark.parametrize("z_reduction", ["mip", "center"]) +def test_datamodule_z_reduction(preprocessed_hcs_dataset, tracks_hcs_dataset, z_reduction): + """z_reduction collapses the z_range window to a single slice per channel. + + Label-free channels (Phase, Retardance) take the center slice; all other + channels (GFP, DAPI) are max-projected, regardless of ``z_reduction``, + which only sets the fallback strategy. + """ + z_range = (4, 9) + yx_patch_size = [32, 32] + batch_size = 4 + with open_ome_zarr(preprocessed_hcs_dataset) as dataset: + channel_names = dataset.channel_names + dm = TripletDataModule( + data_path=preprocessed_hcs_dataset, + tracks_path=tracks_hcs_dataset, + source_channel=channel_names, + z_range=z_range, + initial_yx_patch_size=(32, 32), + final_yx_patch_size=(32, 32), + num_workers=0, + batch_size=batch_size, + return_negative=True, + z_reduction=z_reduction, + ) + dm.setup(stage="fit") + labelfree = {ch for ch in channel_names if parse_channel_name(ch)["channel_type"] == "labelfree"} + assert labelfree, "fixture must contain at least one label-free channel" + assert set(channel_names) - labelfree, "fixture must contain at least one non-label-free channel" + z_window_size = z_range[1] - z_range[0] + center = z_window_size // 2 + for batch in dm.train_dataloader(): + # Snapshot the raw extracted patch before transforms reduce it. + raw = batch["anchor"].clone() + dm.on_after_batch_transfer(batch, 0) + reduced = batch["anchor"] + assert reduced.shape == (batch_size, len(channel_names), 1, *yx_patch_size) + for ci, ch in enumerate(channel_names): + mip = raw[:, ci].amax(dim=1, keepdim=True) + center_slice = raw[:, ci, center : center + 1] + if ch in labelfree: + assert torch.equal(reduced[:, ci], center_slice), f"label-free channel {ch} should be center-sliced" + # Random Z-stack: center slice must differ from MIP, so a strategy + # swap (center vs mip) would be caught rather than passing silently. + assert not torch.equal(reduced[:, ci], mip), f"label-free channel {ch} was max-projected, not centered" + else: + assert torch.equal(reduced[:, ci], mip), f"non-label-free channel {ch} should be max-projected" + assert not torch.equal(reduced[:, ci], center_slice), ( + f"non-label-free channel {ch} was centered, not MIP" + ) + + +def test_z_reduction_runs_on_normalized_stack(preprocessed_hcs_dataset, tracks_hcs_dataset): + """Z-reduction must run after normalization (production order: normalize -> reduce). + + The fixture's ``dataset_statistics`` normalization is a fixed monotone-increasing + affine ``(x - 0.5) / (1/sqrt(12))``, which commutes with both center-slice and + MIP. So the datamodule output (normalize-then-reduce) must equal the same affine + applied to the reduced raw stack. A bug that reduced *before* normalizing, or + skipped normalization, would change the values and fail this check. + """ + import numpy as np + + from viscy_transforms import NormalizeSampled + + z_range = (4, 9) + batch_size = 4 + mean, std = 0.5, 1 / np.sqrt(12) # matches preprocessed_hcs_dataset fixture + with open_ome_zarr(preprocessed_hcs_dataset) as dataset: + channel_names = dataset.channel_names + normalizations = [ + NormalizeSampled(keys=list(channel_names), level="dataset_statistics", subtrahend="mean", divisor="std") + ] + dm = TripletDataModule( + data_path=preprocessed_hcs_dataset, + tracks_path=tracks_hcs_dataset, + source_channel=channel_names, + z_range=z_range, + initial_yx_patch_size=(32, 32), + final_yx_patch_size=(32, 32), + num_workers=0, + batch_size=batch_size, + return_negative=True, + normalizations=normalizations, + z_reduction="mip", + ) + dm.setup(stage="fit") + labelfree = {ch for ch in channel_names if parse_channel_name(ch)["channel_type"] == "labelfree"} + center = (z_range[1] - z_range[0]) // 2 + for batch in dm.train_dataloader(): + raw = batch["anchor"].clone() + dm.on_after_batch_transfer(batch, 0) + reduced = batch["anchor"] + for ci, ch in enumerate(channel_names): + if ch in labelfree: + reduced_raw = raw[:, ci, center : center + 1] + else: + reduced_raw = raw[:, ci].amax(dim=1, keepdim=True) + expected = (reduced_raw - mean) / (std + 1e-8) + assert torch.allclose(reduced[:, ci], expected, atol=1e-5), f"channel {ch} not reduced on normalized stack" + + def test_filter_anchors_time_interval_any(preprocessed_hcs_dataset, tracks_with_gaps_dataset): """Test that time_interval='any' returns all tracks unchanged.""" with open_ome_zarr(preprocessed_hcs_dataset) as dataset: From 365fc4279743f75808bc2c9cd5ce71fe10e0c7ac Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 29 Jun 2026 13:17:29 -0700 Subject: [PATCH 05/13] docs(dynaclr): triplet inference DAG + sample 2D-from-3D predict config Documents the TripletDataModule predict path (zarr + tracking, not parquet) and adds a runnable sample config demonstrating z_reduction + reference_pixel_size to feed a 3D dataset to a 2D model. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../prediction/predict_triplet_2d_from_3d.yml | 81 ++++++++ .../dynaclr/docs/DAGs/inference_triplet.md | 185 ++++++++++++++++++ 2 files changed, 266 insertions(+) create mode 100644 applications/dynaclr/configs/prediction/predict_triplet_2d_from_3d.yml create mode 100644 applications/dynaclr/docs/DAGs/inference_triplet.md diff --git a/applications/dynaclr/configs/prediction/predict_triplet_2d_from_3d.yml b/applications/dynaclr/configs/prediction/predict_triplet_2d_from_3d.yml new file mode 100644 index 000000000..330962b01 --- /dev/null +++ b/applications/dynaclr/configs/prediction/predict_triplet_2d_from_3d.yml @@ -0,0 +1,81 @@ +# Sample triplet predict config: feed a 3D OME-Zarr to a 2D model. +# +# Demonstrates the TripletDataModule options that avoid materializing a separate +# 2D MIP dataset: +# - z_reduction: collapse the extracted z_range window to a single slice +# (label-free channels -> center slice, others -> max projection). +# - reference_pixel_size: rescale each patch to the model's training pixel size +# when the inference dataset was acquired at a different magnification. +# +# See docs/DAGs/inference_triplet.md for the full pipeline. +# +# TODO: point to the path to save the embeddings +# TODO: point to the path to the data +# TODO: point to the path to the tracks +# TODO: point to the path to the checkpoint + +seed_everything: 42 +trainer: + accelerator: gpu + strategy: auto + devices: 1 + num_nodes: 1 + precision: 32-true + inference_mode: true + logger: false + callbacks: + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: 10 + - class_path: viscy_utils.callbacks.embedding_writer.EmbeddingWriter + init_args: + output_path: #TODO point to the path to save the embeddings (e.g. /embeddings/embeddings.zarr) + embedding_key: features # "projections" for frozen-backbone MLP heads + overwrite: true + pca_kwargs: null # reductions left to a later step (dynaclr reduce-dimensionality) + phate_kwargs: null + umap_kwargs: null +model: + class_path: dynaclr.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy_models.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 1 + in_stack_depth: 1 # 2D model — pairs with data.z_reduction below + stem_kernel_size: [1, 4, 4] + stem_stride: [1, 4, 4] + embedding_dim: 768 + projection_dim: 32 + drop_path_rate: 0.0 # stochastic depth off at inference + example_input_array_shape: [1, 1, 1, 160, 160] +data: + class_path: viscy_data.triplet.TripletDataModule + init_args: + data_path: #TODO point to the path to the data (e.g. /registered_test.zarr) + tracks_path: #TODO point to the path to the tracks (e.g. /track_test.zarr) + source_channel: + - Phase3D + z_range: [15, 45] # 3D window; collapsed to 1 slice by z_reduction. + # Center it on the focus plane to control which planes collapse. + z_reduction: mip # "mip" (max projection) or "center" (center slice). + # Label-free channels always take the center slice; + # others are max-projected. This sets the fallback only. + reference_pixel_size: 0.1494 # µm/px of the model's TRAINING dataset. Remove (or set + # null) when the inference dataset is at the same resolution. + initial_yx_patch_size: [160, 160] # ignored when reference_pixel_size is set (computed from ratio) + final_yx_patch_size: [160, 160] # patch size fed to the model after rescale + z_reduction + batch_size: 32 + num_workers: 0 # REQUIRED for predict (avoids zarr-fork deadlock) + normalizations: + - class_path: viscy_transforms.NormalizeSampled + init_args: + keys: [Phase3D] + level: fov_statistics + subtrahend: mean + divisor: std + # augmentations omitted: predict must be deterministic. The datamodule still + # applies normalizations + reference_pixel_size rescale + z_reduction at predict time. +return_predictions: false +ckpt_path: #TODO point to the path to the checkpoint (e.g. /checkpoints/epoch=94-step=2375.ckpt) diff --git a/applications/dynaclr/docs/DAGs/inference_triplet.md b/applications/dynaclr/docs/DAGs/inference_triplet.md new file mode 100644 index 000000000..d031703ba --- /dev/null +++ b/applications/dynaclr/docs/DAGs/inference_triplet.md @@ -0,0 +1,185 @@ +# Inference DAG (Triplet path) + +Embedding inference for a trained DynaCLR encoder using `TripletDataModule` — +the **zarr + tracking** path (no parquet). Use this when you want to run a +trained checkpoint directly over an OME-Zarr store and its `ultrack` tracking, +rather than the parquet-first `MultiExperimentDataModule` path +(see [evaluation.md](evaluation.md) for the parquet path). + +The triplet path is the one that carries the patch-rescaling +(`reference_pixel_size`) and on-the-fly Z-reduction (`z_reduction`) options, so +a 3D zarr can feed a 2D model without materializing a separate MIP dataset. + +## Prerequisites + +- A trained checkpoint (`last.ckpt` or a selected epoch) for a + `dynaclr.engine.ContrastiveModule`. +- The inference dataset as an OME-Zarr store with `normalization` metadata in + the FOV `zattrs` (so `NormalizeSampled` has per-FOV stats), plus a tracking + zarr/CSV directory with `track_id, t, y, x` columns. +- The model's training pixel size (µm/px) if the inference dataset was acquired + at a different magnification — passed as `reference_pixel_size` to rescale + each patch to the physical area the model was trained on. + +## Step-by-step detail + +``` +dataset.zarr (preprocessed: normalization in FOV zattrs) +tracking.zarr/CSV (track_id, t, y, x per cell) +checkpoint.ckpt (trained ContrastiveModule) + │ + ├──► predict config (TripletDataModule + ContrastiveModule + EmbeddingWriter) + ▼ +viscy predict --config configs/prediction/predict_triplet.yml + │ TripletDataModule(fit=False): samples ONE anchor patch per (cell, timepoint) + │ - extracts z_range window, yx at initial_yx_patch_size + │ - reference_pixel_size → extract larger patch, BatchedZoomd to final_yx + │ - z_reduction → BatchedChannelWiseZReductiond collapses Z to 1 (2D model) + │ ContrastiveModule.predict_step → backbone features (+ projections) + │ EmbeddingWriter accumulates (features, index) and writes one combined store + ▼ +embeddings.zarr (AnnData: obsm["X_backbone"], obs = fov_name/track_id/t/...) + │ + ▼ +dynaclr split-embeddings --input embeddings.zarr --output-dir embeddings/ + │ groups rows by obs["experiment"], writes one zarr per experiment + │ removes the combined store afterwards + ▼ +embeddings/{experiment}.zarr (one per experiment, informatively named) + │ + ▼ +downstream eval (reduce-dimensionality, linear classifiers, MMD, pseudotime …) + see evaluation.md / pseudotime.md +``` + +## Pipeline DAG (process dependency) + +``` +predict config + checkpoint + zarr + tracking + │ + ▼ +viscy predict (GPU, minutes–hours by cell count) + │ + ▼ +split-embeddings (CPU, ~1 min, I/O bound) + │ + ▼ +downstream eval (CPU/GPU, per analysis) +``` + +## Key commands + +| Step | Command | Input | Output | +| ---------------- | --------------------------------------------------------------------------- | --------------------------------------- | ----------------------------------- | +| Predict | `uv run viscy predict --config configs/prediction/predict_triplet.yml` | predict config + ckpt + zarr + tracking | combined `embeddings.zarr` | +| Predict (SLURM) | `sbatch configs/prediction/predict_triplet.sh` | same | combined `embeddings.zarr` | +| Split embeddings | `dynaclr split-embeddings --input embeddings.zarr --output-dir embeddings/` | combined zarr with `obs["experiment"]` | one `{experiment}.zarr` per dataset | + +## What lives where + +| Data | Location | When written | +| --------------------------------- | ------------------------------------------------- | --------------------- | +| Pixel data (TCZYX) | dataset.zarr on VAST | data prep | +| Cell tracks (track_id, t, y, x) | tracking.zarr / CSV on VAST | data prep | +| Normalization stats (per FOV) | dataset.zarr FOV `zattrs["normalization"]` | `viscy preprocess` | +| Backbone embeddings | `embeddings.zarr` → `obsm["X_backbone"]` | `viscy predict` | +| Cell index (fov_name/track_id/t) | `embeddings.zarr` → `obs` | `viscy predict` | +| Per-experiment embeddings | `embeddings/{experiment}.zarr` | `split-embeddings` | + +## Predict config structure + +A ready-to-edit sample lives at +[`configs/prediction/predict_triplet_2d_from_3d.yml`](../../configs/prediction/predict_triplet_2d_from_3d.yml) +(the 2D-from-3D case, with `z_reduction` + `reference_pixel_size`). The skeleton +below annotates the load-bearing fields: + +```yaml +seed_everything: 42 + +trainer: + accelerator: gpu + devices: 1 + precision: 32-true + inference_mode: true + logger: false + callbacks: + - class_path: viscy_utils.callbacks.embedding_writer.EmbeddingWriter + init_args: + output_path: /path/to/embeddings/embeddings.zarr + embedding_key: features # "projections" for frozen-backbone MLP heads + overwrite: true + +model: + class_path: dynaclr.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy_models.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 1 + in_stack_depth: 1 # 2D model — pair with z_reduction below + # … must match the trained checkpoint's encoder args … + +data: + class_path: viscy_data.TripletDataModule + init_args: + data_path: /path/to/dataset.zarr + tracks_path: /path/to/tracking.zarr + source_channel: [Phase3D] + z_range: [0, 16] # window collapsed by z_reduction + final_yx_patch_size: [160, 160] + reference_pixel_size: 0.1494 # rescale to the model's training pixel size (optional) + z_reduction: mip # collapse z_range to 1 slice for a 2D model (optional) + batch_size: 400 + num_workers: 0 # REQUIRED for predict (see Notes) + predict_cells: false # true + include_fov_names/include_track_ids to subset + normalizations: + - class_path: viscy_transforms.NormalizeSampled + init_args: + keys: [Phase3D] + subtrahend: mean + divisor: std + augmentations: [] # MUST be empty for deterministic predict + +ckpt_path: /path/to/checkpoint/last.ckpt +return_predictions: false # writer persists to zarr; don't hold in memory +``` + +## Notes + +- **`num_workers: 0` is required for the predict path.** `HCSDataModule`/ + `TripletDataModule` predict does not use `mmap_preload`, and >0 workers risks a + zarr-fork deadlock. This matches the dynacell predict overlay. +- **`augmentations: []`** — predict must be deterministic. The datamodule still + applies `normalizations` (and the `reference_pixel_size` rescale + `z_reduction` + collapse) at predict time via `_no_augmentation_transform`; only random + augmentations are dropped. +- **2D from 3D without a MIP dataset.** Set `z_reduction: mip` (or `center`) to + collapse the extracted `z_range` window to a single slice. Label-free channels + (resolved by name via `parse_channel_name`) take the center slice; all other + channels are max-projected. Pair with `in_stack_depth: 1` on the encoder. + Center the `z_range` on the focus plane to control which planes are collapsed. +- **Pixel-size rescaling.** When the inference dataset's pixel size differs from + the model's training pixel size, set `reference_pixel_size` (µm/px) so a larger + patch covering the same physical area is extracted and bilinearly resized to + `final_yx_patch_size`. Leave unset for same-resolution datasets. +- **`embedding_key`.** Use `features` for the backbone output (most models) and + `projections` for frozen-backbone MLP-head models, which writes + `obsm["X_projections"]` instead. +- **`split-embeddings` requires `obs["experiment"]`** on the combined store. For a + single-experiment predict run the split step is optional — the combined + `embeddings.zarr` is already per-experiment. +- Downstream analyses (dimensionality reduction, linear classifiers, MMD, + pseudotime) consume the per-experiment zarrs and are documented in + [evaluation.md](evaluation.md) and [pseudotime.md](pseudotime.md). + +## Triplet vs parquet (MultiExperimentDataModule) + +| Aspect | Triplet path (this doc) | Parquet path (evaluation.md) | +| ------------------- | ------------------------------------------------ | -------------------------------------------------- | +| Data entry point | `data_path` zarr + `tracks_path` | `cell_index.parquet` (built + preprocessed) | +| Setup cost | reads tracking + zarr shape at init | reads parquet only at init | +| Focus / z window | caller sets `z_range`; `z_reduction` collapses | per-FOV `z_extraction_window` from `focus_slice` | +| 2D-from-3D | `z_reduction: mip` / `center` | `BatchedChannelWiseZReductiond` in normalizations | +| Pixel rescaling | `reference_pixel_size` | `reference_pixel_size_xy_um` | +| Best for | ad-hoc predict over a single zarr + tracking | large multi-experiment runs, reproducible recipes | From dc39073a6fff39be65c8a706588cc600e2596a1a Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 29 Jun 2026 13:17:53 -0700 Subject: [PATCH 06/13] docs(dynaclr): add triplet z-projection dataloader inspection script Notebook-style script (no CLI) that loads a 3D OME-Zarr + tracking, extracts a per-FOV focus-centered Z window, collapses it via z_reduction, applies a random affine so anchor/positive diverge, and visualizes a couple of batches to a PNG. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../triplet_dataloader_zprojection.py | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 applications/dynaclr/scripts/dataloader_inspection/triplet_dataloader_zprojection.py diff --git a/applications/dynaclr/scripts/dataloader_inspection/triplet_dataloader_zprojection.py b/applications/dynaclr/scripts/dataloader_inspection/triplet_dataloader_zprojection.py new file mode 100644 index 000000000..83773d617 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/triplet_dataloader_zprojection.py @@ -0,0 +1,137 @@ +"""Inspect TripletDataModule with focus-centered z_range + on-the-fly Z-reduction. + +Loads a 3D OME-Zarr + tracking, extracts a focus-centered Z window per FOV +(from ``focus_slice`` zattrs), collapses it to a single slice via ``z_reduction`` +(so a 3D dataset feeds a 2D model), and visualizes a couple of batches. + +Run cell-by-cell in an interactive window (VS Code / Jupyter) or top-to-bottom. +""" + +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +from iohub import open_ome_zarr + +from viscy_data import TripletDataModule +from viscy_transforms import BatchedRandAffined, NormalizeSampled + +# --- edit these --------------------------------------------------------------- +DATASET_DIR = "/hpc/projects/organelle_phenotyping/datasets/2026_04_08_A549_G3BP1_ZIKV" +DATA_PATH = f"{DATASET_DIR}/2026_04_08_A549_G3BP1_ZIKV.zarr" +TRACKS_PATH = f"{DATASET_DIR}/tracking.zarr" +SOURCE_CHANNEL = ["Phase3D", "raw GFP EX488 EM525-45"] # label-free + fluorescence +FOCUS_CHANNEL = "Phase3D" # channel whose focus plane centers the Z window +Z_EXTRACTION_WINDOW = 15 # Z slices extracted, centered on each FOV's focus plane +Z_FOCUS_OFFSET = 0.5 # fraction of the window below the focus plane +Z_REDUCTION = "mip" # "mip" (fluorescence) / center-slice (label-free), or None +YX_PATCH = (256, 256) # extracted == final (no YX rescale; set reference_pixel_size to rescale) +BATCH_SIZE = 8 +N_BATCHES = 2 +OUTPUT_DIR = "/home/eduardo.hirata/repos/viscy/applications/dynaclr/scripts/dataloader_inspection/output" +OUTPUT_PNG = f"{OUTPUT_DIR}/triplet_zprojection.png" +# ------------------------------------------------------------------------------ + +# %% [markdown] +# Peek at the store: channels, shape, and the per-FOV focus plane the window centers on. + +# %% +with open_ome_zarr(DATA_PATH) as plate: + print("channels:", plate.channel_names) + name, pos = next(iter(plate.positions())) + print("first FOV:", name, "| TCZYX:", pos["0"].shape) + focus = pos.zattrs.get("focus_slice", {}).get(FOCUS_CHANNEL, {}) + print("per-FOV z_focus_mean:", focus.get("fov_statistics", {}).get("z_focus_mean")) + +# %% [markdown] +# Build the datamodule. ``z_extraction_window`` (not ``z_range``) makes each FOV's +# window center on its own focus plane; ``z_reduction`` collapses Z to 1. + +# %% +dm = TripletDataModule( + data_path=DATA_PATH, + tracks_path=TRACKS_PATH, + source_channel=SOURCE_CHANNEL, + z_extraction_window=Z_EXTRACTION_WINDOW, + z_focus_offset=Z_FOCUS_OFFSET, + focus_channel=FOCUS_CHANNEL, + initial_yx_patch_size=YX_PATCH, + final_yx_patch_size=YX_PATCH, + z_reduction=Z_REDUCTION, + batch_size=BATCH_SIZE, + num_workers=0, + normalizations=[ + NormalizeSampled( + keys=SOURCE_CHANNEL, + level="fov_statistics", + subtrahend="mean", + divisor="std", + ) + ], + augmentations=[ + # Random affine so anchor and positive (a clone of the anchor when + # time_interval="any") diverge — applied per-key with fresh random + # params, on the 3D stack before z_reduction. prob=1.0 = always fire. + BatchedRandAffined( + keys=SOURCE_CHANNEL, + prob=1.0, + # rotate_range is (Z, Y, X) radians: the Z entry is the in-plane (XY) + # rotation. Rotating about Y/X would tumble the stack out of plane and + # collapse to a strip after MIP, so keep those at 0. + rotate_range=(3.14159, 0.0, 0.0), # full in-plane rotation + scale_range=(0.8, 1.2), + translate_range=(0.0, 0.1, 0.1), # up to 10% YX shift, no Z shift + ) + ], +) +dm.setup(stage="fit") + +# Per-FOV focus-centered windows resolved at setup (one slice per FOV). +resolved = dm.train_dataset.z_range +print(f"resolved {len(resolved)} per-FOV z-windows; e.g.:") +for fov, z_slice in list(resolved.items())[:5]: + print(f" {fov}: {z_slice}") + +# %% [markdown] +# Pull a couple of batches and check shapes. After z_reduction the Z axis is 1. + +# %% +batches = [] +for i, batch in enumerate(dm.train_dataloader()): + dm.on_after_batch_transfer(batch, 0) # normalize + Z-reduce on the batch + print(f"batch {i}: anchor {tuple(batch['anchor'].shape)} (expect Z=1)") + batches.append(batch) + if i + 1 >= N_BATCHES: + break + +# %% [markdown] +# Visualize anchor / positive / negative for the first few cells of each batch. +# Each panel is one channel of the Z-reduced (B, C, 1, Y, X) patch. + +# %% +keys = [k for k in ("anchor", "positive", "negative") if k in batches[0]] +n_cells = min(4, BATCH_SIZE) +n_rows = N_BATCHES * n_cells +n_cols = len(keys) * len(SOURCE_CHANNEL) +fig, axes = plt.subplots(n_rows, n_cols, figsize=(2.4 * n_cols, 2.4 * n_rows), squeeze=False) + +for bi, batch in enumerate(batches): + for ci in range(n_cells): + row = bi * n_cells + ci + col = 0 + for key in keys: + patch = batch[key][ci] # (C, 1, Y, X) + for ch_idx, ch in enumerate(SOURCE_CHANNEL): + img = patch[ch_idx, 0].cpu().numpy() # (Y, X) + lo, hi = (img.min(), img.max()) if img.max() > img.min() else (0.0, 1.0) + ax = axes[row, col] + ax.imshow(img, cmap="gray", vmin=lo, vmax=hi) + ax.set_title(f"b{bi} c{ci}\n{key}/{ch.split()[0]}", fontsize=7) + ax.axis("off") + col += 1 + +fig.suptitle(f"Triplet patches — z_extraction_window={Z_EXTRACTION_WINDOW}, z_reduction={Z_REDUCTION}", fontsize=10) +fig.tight_layout() +Path(OUTPUT_PNG).parent.mkdir(parents=True, exist_ok=True) +fig.savefig(OUTPUT_PNG, dpi=120, bbox_inches="tight") +print("saved:", OUTPUT_PNG) From a14a0ac12d43d7a43eb96eafe0a6ee18e84d4af2 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 29 Jun 2026 14:07:46 -0700 Subject: [PATCH 07/13] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- packages/viscy-data/src/viscy_data/triplet.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/viscy-data/src/viscy_data/triplet.py b/packages/viscy-data/src/viscy_data/triplet.py index 2d7f8a32f..4d42601e2 100644 --- a/packages/viscy-data/src/viscy_data/triplet.py +++ b/packages/viscy-data/src/viscy_data/triplet.py @@ -553,9 +553,7 @@ def __init__( self.z_range = slice(*z_range) if z_range is not None else None self._z_extraction_window = z_extraction_window self._z_focus_offset = z_focus_offset - self._focus_channel = focus_channel or ( - source_channel[0] if isinstance(source_channel, (list, tuple)) else source_channel - ) + self._focus_channel = focus_channel or self.source_channel[0] self.tracks_path = Path(tracks_path) self.initial_yx_patch_size = initial_yx_patch_size self._include_wells = fit_include_wells From 52c72a1d4c3e6402517a3aad9012c796b165eeff Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 29 Jun 2026 14:07:56 -0700 Subject: [PATCH 08/13] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- packages/viscy-data/tests/test_triplet.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/viscy-data/tests/test_triplet.py b/packages/viscy-data/tests/test_triplet.py index 7d0be1e6c..483dac897 100644 --- a/packages/viscy-data/tests/test_triplet.py +++ b/packages/viscy-data/tests/test_triplet.py @@ -184,8 +184,6 @@ def expected_window(z_focus_mean): dm.on_after_batch_transfer(batch, 0) assert batch["anchor"].shape[2] == z_extraction_window break - break - @mark.parametrize("z_reduction", ["mip", "center"]) def test_datamodule_z_reduction(preprocessed_hcs_dataset, tracks_hcs_dataset, z_reduction): From 6d320aeae359e14774d356a57bd079bca6dc05ba Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 29 Jun 2026 14:08:54 -0700 Subject: [PATCH 09/13] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- packages/viscy-data/src/viscy_data/triplet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/viscy-data/src/viscy_data/triplet.py b/packages/viscy-data/src/viscy_data/triplet.py index 4d42601e2..11f746540 100644 --- a/packages/viscy-data/src/viscy_data/triplet.py +++ b/packages/viscy-data/src/viscy_data/triplet.py @@ -529,6 +529,10 @@ def __init__( warnings.warn("Using more than 1 thread worker will likely degrade performance.") if (z_range is None) == (z_extraction_window is None): raise ValueError("Provide exactly one of 'z_range' or 'z_extraction_window'.") + if z_extraction_window is not None and z_extraction_window <= 0: + raise ValueError("'z_extraction_window' must be a positive integer.") + if z_extraction_window is not None and not (0.0 <= z_focus_offset <= 1.0): + raise ValueError("'z_focus_offset' must be between 0.0 and 1.0 (inclusive).") # Extraction window width is known without opening the zarr: it is the # explicit z_range span or z_extraction_window. Per-FOV focus centering # (when z_extraction_window is set) is resolved at setup() time, where From 2a1b327a30a466941727d8c77b95eadfba7ff3c9 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 29 Jun 2026 14:09:06 -0700 Subject: [PATCH 10/13] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- applications/dynaclr/docs/DAGs/inference_triplet.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/applications/dynaclr/docs/DAGs/inference_triplet.md b/applications/dynaclr/docs/DAGs/inference_triplet.md index d031703ba..2aa6bddb1 100644 --- a/applications/dynaclr/docs/DAGs/inference_triplet.md +++ b/applications/dynaclr/docs/DAGs/inference_triplet.md @@ -177,9 +177,10 @@ return_predictions: false # writer persists to zarr; don't hold in | Aspect | Triplet path (this doc) | Parquet path (evaluation.md) | | ------------------- | ------------------------------------------------ | -------------------------------------------------- | -| Data entry point | `data_path` zarr + `tracks_path` | `cell_index.parquet` (built + preprocessed) | -| Setup cost | reads tracking + zarr shape at init | reads parquet only at init | -| Focus / z window | caller sets `z_range`; `z_reduction` collapses | per-FOV `z_extraction_window` from `focus_slice` | -| 2D-from-3D | `z_reduction: mip` / `center` | `BatchedChannelWiseZReductiond` in normalizations | +| Aspect | Triplet path (this doc) | Parquet path (evaluation.md) | +| ------------------- | ------------------------------------------------------------------ | -------------------------------------------------- | +| Data entry point | `data_path` zarr + `tracks_path` | `cell_index.parquet` (built + preprocessed) | +| Setup cost | reads tracking + zarr shape at init | reads parquet only at init | +| Focus / z window | explicit `z_range` or per-FOV `z_extraction_window` from `focus_slice`; `z_reduction` collapses | per-FOV `z_extraction_window` from `focus_slice` | | Pixel rescaling | `reference_pixel_size` | `reference_pixel_size_xy_um` | | Best for | ad-hoc predict over a single zarr + tracking | large multi-experiment runs, reproducible recipes | From dbe22f25da4f28fda81dd71b3f1247c09bf769d5 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 29 Jun 2026 14:09:16 -0700 Subject: [PATCH 11/13] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- packages/viscy-data/src/viscy_data/triplet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/viscy-data/src/viscy_data/triplet.py b/packages/viscy-data/src/viscy_data/triplet.py index 11f746540..d27b11ba5 100644 --- a/packages/viscy-data/src/viscy_data/triplet.py +++ b/packages/viscy-data/src/viscy_data/triplet.py @@ -577,7 +577,7 @@ def __init__( if reference_pixel_size is not None: inference_pixel_size = _read_pixel_size(data_path) scale = reference_pixel_size / inference_pixel_size - self.initial_yx_patch_size = tuple(round(s * scale) for s in final_yx_patch_size) + self.initial_yx_patch_size = tuple(max(1, int(round(s * scale))) for s in final_yx_patch_size) _logger.info( f"Pixel size rescaling enabled: " f"reference={reference_pixel_size:.4f} µm/px, " From 4a2de142dd314f0f431ebdc999bdf5371a5702ed Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 29 Jun 2026 14:15:48 -0700 Subject: [PATCH 12/13] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- packages/viscy-data/src/viscy_data/triplet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/viscy-data/src/viscy_data/triplet.py b/packages/viscy-data/src/viscy_data/triplet.py index d27b11ba5..be04159df 100644 --- a/packages/viscy-data/src/viscy_data/triplet.py +++ b/packages/viscy-data/src/viscy_data/triplet.py @@ -126,7 +126,7 @@ def _resolve_per_fov_z_ranges( fov_stats = pos.zattrs.get("focus_slice", {}).get(focus_channel, {}).get("fov_statistics", {}) z_focus_mean = fov_stats.get("z_focus_mean") z_ranges[fov_name] = _focus_window(z_focus_mean, z_total, z_extraction_window, z_focus_offset) - _logger.info( + _logger.debug( "FOV '%s': focus-centered z_range=%s (z_total=%d, z_focus_mean=%s, window=%d).", fov_name, (z_ranges[fov_name].start, z_ranges[fov_name].stop), From 5b81c2c8bce71bb2321b28f84779eb7fdc4abaac Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 30 Jun 2026 15:17:08 -0700 Subject: [PATCH 13/13] fix(viscy-data): address triplet z-reduction/rescaling review findings - Declare viscy-transforms in viscy-data's triplet extra and hoist the BatchedZoomd/BatchedChannelWiseZReductiond imports to module top; the inline imports were masking a missing runtime dependency that would ImportError for anyone using z_reduction or reference_pixel_size. - Raise in _resolve_per_fov_z_ranges when a FOV's Z is smaller than z_extraction_window, instead of silently emitting a narrower window that breaks cross-FOV batch stacking. - Remove a dead duplicate break in test_focus_centered_z_range. - Correct the inference DAG doc: embeddings live in .X (the embedding_key array), mirrored to obsm["X_backbone"]/["X_projections"]. uv.lock intentionally omitted: the workspace glob currently entangles the untracked applications/eet package, so a clean regen of the single viscy-transforms edge is not possible until eet is committed or removed. Co-Authored-By: Claude Opus 4.8 (1M context) --- applications/dynaclr/docs/DAGs/inference_triplet.md | 5 +++-- packages/viscy-data/pyproject.toml | 2 +- packages/viscy-data/src/viscy_data/triplet.py | 13 ++++++++----- packages/viscy-data/tests/test_triplet.py | 1 + 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/applications/dynaclr/docs/DAGs/inference_triplet.md b/applications/dynaclr/docs/DAGs/inference_triplet.md index 2aa6bddb1..52aa3aab4 100644 --- a/applications/dynaclr/docs/DAGs/inference_triplet.md +++ b/applications/dynaclr/docs/DAGs/inference_triplet.md @@ -38,7 +38,8 @@ viscy predict --config configs/prediction/predict_triplet.yml │ ContrastiveModule.predict_step → backbone features (+ projections) │ EmbeddingWriter accumulates (features, index) and writes one combined store ▼ -embeddings.zarr (AnnData: obsm["X_backbone"], obs = fov_name/track_id/t/...) +embeddings.zarr (AnnData: .X = embedding_key array, mirrored to obsm["X_backbone"] + /["X_projections"]; obs = fov_name/track_id/t/...) │ ▼ dynaclr split-embeddings --input embeddings.zarr --output-dir embeddings/ @@ -82,7 +83,7 @@ downstream eval (CPU/GPU, per analysis) | Pixel data (TCZYX) | dataset.zarr on VAST | data prep | | Cell tracks (track_id, t, y, x) | tracking.zarr / CSV on VAST | data prep | | Normalization stats (per FOV) | dataset.zarr FOV `zattrs["normalization"]` | `viscy preprocess` | -| Backbone embeddings | `embeddings.zarr` → `obsm["X_backbone"]` | `viscy predict` | +| Backbone embeddings | `embeddings.zarr` → `.X` (+ `obsm["X_backbone"]`) | `viscy predict` | | Cell index (fov_name/track_id/t) | `embeddings.zarr` → `obs` | `viscy predict` | | Per-experiment embeddings | `embeddings/{experiment}.zarr` | `split-embeddings` | diff --git a/packages/viscy-data/pyproject.toml b/packages/viscy-data/pyproject.toml index 959f9543e..504e36aa5 100644 --- a/packages/viscy-data/pyproject.toml +++ b/packages/viscy-data/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ optional-dependencies.all = [ "viscy-data[livecell,mmap,triplet]" ] optional-dependencies.livecell = [ "pycocotools", "tifffile", "torchvision" ] optional-dependencies.mmap = [ "tensordict" ] -optional-dependencies.triplet = [ "tensorstore" ] +optional-dependencies.triplet = [ "tensorstore", "viscy-transforms" ] urls.Homepage = "https://github.com/mehta-lab/VisCy" urls.Issues = "https://github.com/mehta-lab/VisCy/issues" urls.Repository = "https://github.com/mehta-lab/VisCy" diff --git a/packages/viscy-data/src/viscy_data/triplet.py b/packages/viscy-data/src/viscy_data/triplet.py index be04159df..6bb81ad0f 100644 --- a/packages/viscy-data/src/viscy_data/triplet.py +++ b/packages/viscy-data/src/viscy_data/triplet.py @@ -34,8 +34,10 @@ _read_norm_meta, _transform_channel_wise, ) +from viscy_data.channel_utils import parse_channel_name from viscy_data.hcs import HCSDataModule from viscy_data.select import _filter_fovs, _filter_wells +from viscy_transforms import BatchedChannelWiseZReductiond, BatchedZoomd _logger = logging.getLogger("lightning.pytorch") @@ -123,6 +125,12 @@ def _resolve_per_fov_z_ranges( for pos in positions: fov_name = pos.zgroup.name.strip("/") z_total = pos["0"].shape[2] + if z_total < z_extraction_window: + raise ValueError( + f"FOV '{fov_name}' has Z={z_total} < z_extraction_window={z_extraction_window}; " + "its window would be narrower than the others and break cross-FOV batch stacking. " + "Lower z_extraction_window or exclude this FOV." + ) fov_stats = pos.zattrs.get("focus_slice", {}).get(focus_channel, {}).get("fov_statistics", {}) z_focus_mean = fov_stats.get("z_focus_mean") z_ranges[fov_name] = _focus_window(z_focus_mean, z_total, z_extraction_window, z_focus_offset) @@ -586,8 +594,6 @@ def __init__( f"Extracting {self.initial_yx_patch_size} px patches " f"and resizing to {final_yx_patch_size} px." ) - from viscy_transforms import BatchedZoomd - scale_yx = ( final_yx_patch_size[0] / self.initial_yx_patch_size[0], final_yx_patch_size[1] / self.initial_yx_patch_size[1], @@ -602,9 +608,6 @@ def __init__( ) if z_reduction is not None: - from viscy_data.channel_utils import parse_channel_name - from viscy_transforms import BatchedChannelWiseZReductiond - labelfree_keys = [ch for ch in self.source_channel if parse_channel_name(ch)["channel_type"] == "labelfree"] mip_keys = [ch for ch in self.source_channel if ch not in labelfree_keys] _logger.info( diff --git a/packages/viscy-data/tests/test_triplet.py b/packages/viscy-data/tests/test_triplet.py index 483dac897..6a6fa939f 100644 --- a/packages/viscy-data/tests/test_triplet.py +++ b/packages/viscy-data/tests/test_triplet.py @@ -185,6 +185,7 @@ def expected_window(z_focus_mean): assert batch["anchor"].shape[2] == z_extraction_window break + @mark.parametrize("z_reduction", ["mip", "center"]) def test_datamodule_z_reduction(preprocessed_hcs_dataset, tracks_hcs_dataset, z_reduction): """z_reduction collapses the z_range window to a single slice per channel.