diff --git a/docs/user_guide/3d_volumes.md b/docs/user_guide/3d_volumes.md index 6f79e0a..8acc472 100644 --- a/docs/user_guide/3d_volumes.md +++ b/docs/user_guide/3d_volumes.md @@ -2,7 +2,12 @@ This guide covers stack alignment for 3D volumetric data acquired through z-stack scanning in 2-photon microscopy. -PyFlowReg performs **2D frame-by-frame motion correction**. It does not perform true 3D volumetric registration. For z-stack alignment, PyFlowReg uses an adaptive reference approach where the reference frame is updated slice-by-slice as you move through the stack. +PyFlowReg performs **2D frame-by-frame motion correction**. It does not perform true +3D volumetric registration. For z-stack alignment, PyFlowReg uses an adaptive +reference approach where the reference frame is updated slice-by-slice as you move +through the stack. The separate `pyflowreg.z_align` workflow can then estimate +z-shifts against a reference stack and optionally write a z-corrected signal plus +a z-shift-only simulation video. ## Z-Stack Acquisition Strategy @@ -285,7 +290,8 @@ PyFlowReg's z-stack approach vs. true 3D volumetric registration: **PyFlowReg (2D + adaptive reference)**: - Registers each frame as a 2D image - Adapts reference slice-by-slice -- Cannot correct through-plane (z-axis) motion +- The core 2D workflow does not correct through-plane (z-axis) motion; use + `pyflowreg.z_align` for reference-stack z-shift estimation/correction - Fast and memory-efficient - Works well when z-motion is minimal and xy-motion dominates diff --git a/examples/z_align_config.toml b/examples/z_align_config.toml new file mode 100644 index 0000000..f513ed2 --- /dev/null +++ b/examples/z_align_config.toml @@ -0,0 +1,62 @@ +# Example z-align configuration matching the MATLAB z-shift scripts. +# Assumes these files are present in `root`: +# - compensated.tiff +# - file_00004_00001.tif + +# === Data Location === +root = "." 
# Run from folder containing MATLAB-style inputs +input_file = "compensated.tiff" # Recording to z-correct +volume_input_file = "file_00004_00001.tif" # Input for reference volume creation +reference_source_file = "compensated.tiff" # Used to compute Stage-1 reference image + +# MATLAB snippet: get_video_file_reader("compensated.tiff", 10, 20); mean(first 2000) +reference_source_frames = 2000 +reference_source_buffer_size = 10 +reference_source_bin_size = 20 + +# === Output Paths === +# Kept identical to MATLAB output names/locations. +output_root = "." +volume_output_dir = "aligned_stack" +recording_prealigned_output_dir = "prealigned_recording" +z_shift_file = "z_shift.HDF5" +corrected_output_file = "compensated_shift_corrected.tif" +simulated_output_file = "simulated_from_z.tif" + +# === Pipeline Controls === +resume = true +prealign_stack = true # 2D-align the reference stack before z estimation +prealign_recording = false # optionally 2D-align input_file before z estimation +write_corrected = true # direct z-corrected signal +write_simulated = true # z-shift-only video interpolated from the volume + +# === Stage 1: Reference Volume Build (compensate_recording) === +stage1_alpha = 5.0 +stage1_quality_setting = "quality" +stage1_buffer_size = 500 +stage1_bin_size = 1 +stage1_update_reference = true +# Set to the number of scans per z slice to process one slice per batch. +# When set, Stage 1 forces update_reference=true and uses this as volume_bin_size. 
+# stack_scans_per_slice = 9 +flow_backend = "flowreg" + +# === Stage 2: Patch-Based z Estimation === +input_buffer_size = 50 +input_bin_size = 1 +volume_buffer_size = 500 +volume_bin_size = 1 + +win_half = 10 +patch_size = 128 +overlap = 0.75 + +spatial_sigma = 1.5 +temporal_sigma = 1.5 +z_smooth_sigma_spatial = 5.0 +z_smooth_sigma_temporal = 1.5 +parabolic_tau_scale = 1e-3 + +output_dtype = "uint16" + +[backend_params] diff --git a/examples/z_shift_demo.py b/examples/z_shift_demo.py new file mode 100644 index 0000000..f1f6da5 --- /dev/null +++ b/examples/z_shift_demo.py @@ -0,0 +1,96 @@ +""" +Z-Shift Demo - MATLAB-style z alignment workflow + +This example assumes the same input files as the MATLAB scripts: +- compensated.tiff (time recording to z-correct) +- file_00004_00001.tif (stack/source for reference volume creation) + +Outputs (matching MATLAB names) are written to the working directory: +- aligned_stack/compensated.HDF5 +- z_shift.HDF5 +- compensated_shift_corrected.tif +- simulated_from_z.tif +""" + +from pathlib import Path + +from pyflowreg.z_align import ZAlignConfig, run_all_stages + + +def main(): + root = Path(".").resolve() + + required = [root / "compensated.tiff", root / "file_00004_00001.tif"] + missing = [p.name for p in required if not p.exists()] + if missing: + raise FileNotFoundError( + "Missing required input files in working directory: " + ", ".join(missing) + ) + + config = ZAlignConfig( + root=root, + # MATLAB-style inputs + input_file="compensated.tiff", + volume_input_file="file_00004_00001.tif", + reference_source_file="compensated.tiff", + # MATLAB script: read first 2000 frames with buffer/bin (10, 20) + reference_source_frames=2000, + reference_source_buffer_size=10, + reference_source_bin_size=20, + # Keep MATLAB output paths/names + output_root=".", + volume_output_dir="aligned_stack", + recording_prealigned_output_dir="prealigned_recording", + z_shift_file="z_shift.HDF5", + 
corrected_output_file="compensated_shift_corrected.tif", + simulated_output_file="simulated_from_z.tif", + # Stage toggles: + # write_corrected=True -> direct z-corrected signal + # write_simulated=True -> z-shift-only video interpolated from the volume + prealign_stack=True, + prealign_recording=False, + write_corrected=True, + write_simulated=True, + resume=True, + # Stage 1 (volume build) defaults from MATLAB snippet + stage1_alpha=5.0, + stage1_quality_setting="quality", + stage1_buffer_size=500, + stage1_bin_size=1, + stage1_update_reference=True, + # Set to scans per z slice when the stack stores repeated scans per slice. + stack_scans_per_slice=None, + # Stage 2 (patch-based z estimation) defaults from MATLAB snippet + input_buffer_size=50, + input_bin_size=1, + win_half=10, + patch_size=128, + overlap=0.75, + spatial_sigma=1.5, + temporal_sigma=1.5, + z_smooth_sigma_spatial=5.0, + z_smooth_sigma_temporal=1.5, + ) + + print("=" * 60) + print("Z-SHIFT DEMO") + print("=" * 60) + print(f"Root: {root}") + print("Input recording: compensated.tiff") + print("Volume source: file_00004_00001.tif") + + outputs = run_all_stages(config) + + print("\n" + "=" * 60) + print("DEMO COMPLETE") + print("=" * 60) + print(f"Reference volume: {outputs['volume_path']}") + print(f"Z-shift file: {outputs['z_shift_path']}") + if outputs["corrected_path"] is not None: + print(f"Corrected signal: {outputs['corrected_path']}") + if outputs["simulated_path"] is not None: + print(f"Z-shift simulation:{outputs['simulated_path']}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 2deb07b..467a60c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ [project.scripts] pyflowreg-session = "pyflowreg.session.cli:main" +pyflowreg-z-align = "pyflowreg.z_align.cli:main" [project.optional-dependencies] vis = [ diff --git a/src/pyflowreg/z_align/__init__.py b/src/pyflowreg/z_align/__init__.py new file mode 100644 index 
0000000..3a68288 --- /dev/null +++ b/src/pyflowreg/z_align/__init__.py @@ -0,0 +1,28 @@ +""" +Z-alignment pipeline for depth-shift correction. + +The ``z_align`` module implements a stage-based workflow that mirrors the +MATLAB prototype used for z-shift estimation/correction: + +1. Build/load a reference volume +2. Estimate per-pixel z shifts and optionally write z-corrected output +3. Optionally simulate a z-shift-only recording from the estimated z shifts +""" + +from pyflowreg.z_align.config import ZAlignConfig +from pyflowreg.z_align.pipeline import ( + run_recording_prealignment, + run_stage1, + run_stage2, + run_stage3, + run_all_stages, +) + +__all__ = [ + "ZAlignConfig", + "run_recording_prealignment", + "run_stage1", + "run_stage2", + "run_stage3", + "run_all_stages", +] diff --git a/src/pyflowreg/z_align/cli.py b/src/pyflowreg/z_align/cli.py new file mode 100644 index 0000000..491459e --- /dev/null +++ b/src/pyflowreg/z_align/cli.py @@ -0,0 +1,136 @@ +""" +Command-line interface for z-alignment workflows. + +Provides the ``pyflowreg-z-align`` command. 
+""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path +from typing import Any, Dict, Optional + +from pyflowreg.z_align.config import ZAlignConfig +from pyflowreg.z_align.pipeline import ( + run_all_stages, + run_stage1, + run_stage2, + run_stage3, +) + + +def _parse_value(raw: str) -> Any: + """Parse CLI override values.""" + lower = raw.lower() + if lower == "true": + return True + if lower == "false": + return False + + for cast in (int, float): + try: + return cast(raw) + except ValueError: + pass + + # Optional JSON parsing for lists/dicts + if raw.startswith("[") or raw.startswith("{"): + try: + return json.loads(raw) + except json.JSONDecodeError: + return raw + return raw + + +def _parse_overrides(params: Optional[list[str]]) -> Dict[str, Any]: + """Parse KEY=VALUE CLI overrides.""" + overrides: Dict[str, Any] = {} + if not params: + return overrides + + for item in params: + if "=" not in item: + print(f"Warning: ignoring malformed override '{item}' (expected KEY=VALUE)") + continue + key, value = item.split("=", 1) + overrides[key] = _parse_value(value) + return overrides + + +def cmd_run(args: argparse.Namespace) -> None: + """Handle the ``run`` subcommand.""" + config_path = Path(args.config) + if not config_path.exists(): + print(f"Error: configuration file not found: {config_path}") + sys.exit(1) + + config = ZAlignConfig.from_file(config_path) + overrides = _parse_overrides(args.of_params) + + if args.stage == "1": + run_stage1(config, overrides or None) + return + + if args.stage == "2": + run_stage2(config) + return + + if args.stage == "3": + run_stage3(config) + return + + run_all_stages(config, overrides or None) + + +def main() -> None: + """CLI entry point.""" + parser = argparse.ArgumentParser( + description="PyFlowReg z-shift alignment pipeline", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Run full z-align workflow + pyflowreg-z-align run 
--config z_align.toml + + # Run only z-shift estimation/correction + pyflowreg-z-align run --config z_align.toml --stage 2 + + # Override stage-1 OFOptions from CLI + pyflowreg-z-align run --config z_align.toml --of-params alpha=8 quality_setting=balanced + """, + ) + + subparsers = parser.add_subparsers(dest="command", help="Command to run") + + run_parser = subparsers.add_parser("run", help="Run z-align processing") + run_parser.add_argument( + "--config", + "-c", + required=True, + help="Path to z-align config file (.toml/.yaml/.yml)", + ) + run_parser.add_argument( + "--stage", + "-s", + choices=["1", "2", "3"], + help="Run only one stage (default: run all applicable stages)", + ) + run_parser.add_argument( + "--of-params", + nargs="*", + metavar="KEY=VALUE", + help="Override stage-1 OFOptions parameters", + ) + run_parser.set_defaults(func=cmd_run) + + args = parser.parse_args() + if not hasattr(args, "func"): + parser.print_help() + sys.exit(1) + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/src/pyflowreg/z_align/config.py b/src/pyflowreg/z_align/config.py new file mode 100644 index 0000000..2982c13 --- /dev/null +++ b/src/pyflowreg/z_align/config.py @@ -0,0 +1,388 @@ +""" +Configuration model for z-alignment workflows. + +The z-align pipeline mirrors the MATLAB prototypes with three stages: +1) Build or load a reference volume. +2) Estimate per-pixel z-shifts and optionally write a z-corrected signal. +3) Optionally simulate a z-shift-only recording from volume + z-shifts. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import numpy as np +from pydantic import BaseModel, Field, field_validator + + +_STAGE1_PROTECTED_OF_FIELDS = { + "input_file", + "output_path", + "output_format", + "output_file_name", + "reference_frames", +} + +_RECORDING_PREALIGN_PROTECTED_OF_FIELDS = { + "input_file", + "output_path", + "output_format", + "output_file_name", +} + + +class ZAlignConfig(BaseModel): + """Configuration for z-shift alignment and correction.""" + + # Core paths + root: Path + input_file: Path + volume_input_file: Optional[Path] = None + reference_volume: Optional[Path] = None + reference_source_file: Optional[Path] = None + + # Stage 1 reference-frame estimation from a source recording + reference_source_frames: int = 2000 + reference_source_buffer_size: int = 10 + reference_source_bin_size: int = 20 + + # Outputs + output_root: Path = Field(default=Path("z_align_outputs")) + volume_output_dir: Path = Field(default=Path("aligned_stack")) + recording_prealigned_output_dir: Path = Field(default=Path("prealigned_recording")) + z_shift_file: Path = Field(default=Path("z_shift.HDF5")) + corrected_output_file: Path = Field(default=Path("compensated_shift_corrected.tif")) + simulated_output_file: Path = Field(default=Path("simulated_from_z.tif")) + + # Control flags + resume: bool = True + prealign_stack: bool = True + prealign_recording: bool = False + write_corrected: bool = True + write_simulated: bool = True + + # Stage 1 (volume build via compensate_recording / OFOptions) + stage1_alpha: float = 5.0 + stage1_quality_setting: str = "quality" + stage1_buffer_size: int = 500 + stage1_bin_size: int = 1 + stage1_update_reference: bool = True + stack_scans_per_slice: Optional[int] = None + flow_backend: str = "flowreg" + backend_params: Dict[str, Any] = Field(default_factory=dict) + stage1_flow_options: Optional[Union[Dict[str, Any], Path]] = None + 
recording_prealign_flow_options: Optional[Union[Dict[str, Any], Path]] = None + + # Stage 2 (patch-based z estimation) + input_buffer_size: int = 50 + input_bin_size: int = 1 + volume_buffer_size: int = 500 + volume_bin_size: int = 1 + win_half: int = 10 + patch_size: int = 128 + overlap: float = 0.75 + spatial_sigma: float = 1.5 + temporal_sigma: float = 1.5 + z_smooth_sigma_spatial: float = 5.0 + z_smooth_sigma_temporal: float = 1.5 + parabolic_tau_scale: float = 1e-3 + output_dtype: str = "uint16" + n_jobs: int = -1 + parallelization: str = "sequential" + + @field_validator( + "root", + "input_file", + "volume_input_file", + "reference_volume", + "reference_source_file", + "output_root", + "volume_output_dir", + "recording_prealigned_output_dir", + "z_shift_file", + "corrected_output_file", + "simulated_output_file", + mode="before", + ) + @classmethod + def _to_path(cls, v): + if v is None or isinstance(v, Path): + return v + if isinstance(v, str): + return Path(v) + return v + + @field_validator( + "stage1_flow_options", + "recording_prealign_flow_options", + mode="before", + ) + @classmethod + def _normalize_flow_options(cls, v): + if v is None: + return None + if isinstance(v, dict): + return v + if isinstance(v, Path): + return v + if isinstance(v, str): + stripped = v.strip() + return Path(stripped) if stripped else None + raise TypeError("Flow options must be a mapping or path") + + @field_validator("root") + @classmethod + def _validate_root(cls, v: Path): + if not v.exists(): + raise ValueError(f"Root directory does not exist: {v}") + if not v.is_dir(): + raise ValueError(f"Root path is not a directory: {v}") + return v + + @field_validator( + "reference_source_frames", + "reference_source_buffer_size", + "reference_source_bin_size", + "stage1_buffer_size", + "stage1_bin_size", + "stack_scans_per_slice", + "input_buffer_size", + "input_bin_size", + "volume_buffer_size", + "volume_bin_size", + "win_half", + "patch_size", + ) + @classmethod + def 
_validate_positive_int(cls, v: Optional[int]): + if v is None: + return v + if v < 1: + raise ValueError("Value must be >= 1") + return v + + @field_validator( + "stage1_alpha", + "spatial_sigma", + "temporal_sigma", + "z_smooth_sigma_spatial", + "z_smooth_sigma_temporal", + "parabolic_tau_scale", + ) + @classmethod + def _validate_positive_float(cls, v: float): + if v <= 0: + raise ValueError("Value must be > 0") + return v + + @field_validator("overlap") + @classmethod + def _validate_overlap(cls, v: float): + if not (0.0 <= v < 1.0): + raise ValueError("overlap must satisfy 0 <= overlap < 1") + return v + + @field_validator("output_dtype") + @classmethod + def _validate_output_dtype(cls, v: str): + try: + np.dtype(v) + except TypeError as exc: + raise ValueError(f"Invalid output_dtype: {v}") from exc + return v + + @field_validator("n_jobs") + @classmethod + def _validate_n_jobs(cls, v: int): + if v == -1: + return v + if v < 1: + raise ValueError("n_jobs must be -1 or >= 1") + return v + + @field_validator("parallelization", mode="before") + @classmethod + def _validate_parallelization(cls, v): + if not isinstance(v, str): + raise TypeError("parallelization must be a string") + value = v.strip().lower() + allowed = {"sequential", "threading"} + if value not in allowed: + raise ValueError(f"parallelization must be one of {sorted(allowed)}") + return value + + def _resolve_from_root(self, path: Path) -> Path: + p = path.expanduser() + return p if p.is_absolute() else (self.root / p) + + def _resolve_from_output_root(self, path: Path) -> Path: + p = path.expanduser() + return p if p.is_absolute() else (self.resolve_output_root() / p) + + def resolve_output_root(self) -> Path: + return self._resolve_from_root(self.output_root) + + def resolve_input_file(self) -> Path: + return self._resolve_from_root(self.input_file) + + def resolve_volume_input_file(self) -> Optional[Path]: + if self.volume_input_file is None: + return None + return 
self._resolve_from_root(self.volume_input_file) + + def resolve_reference_source_file(self) -> Optional[Path]: + if self.reference_source_file is None: + return None + return self._resolve_from_root(self.reference_source_file) + + def resolve_volume_output_dir(self) -> Path: + return self._resolve_from_output_root(self.volume_output_dir) + + def resolve_recording_prealigned_output_dir(self) -> Path: + return self._resolve_from_output_root(self.recording_prealigned_output_dir) + + def resolve_recording_prealigned_file(self) -> Path: + return self.resolve_recording_prealigned_output_dir() / "compensated.HDF5" + + def resolve_z_shift_file(self) -> Path: + return self._resolve_from_output_root(self.z_shift_file) + + def resolve_corrected_output_file(self) -> Path: + return self._resolve_from_output_root(self.corrected_output_file) + + def resolve_simulated_output_file(self) -> Path: + return self._resolve_from_output_root(self.simulated_output_file) + + def resolve_reference_volume_path(self) -> Path: + """ + Resolve reference volume path. + + If ``reference_volume`` is provided, use it. Otherwise, return the default + compensated-volume path under ``volume_output_dir``. 
+ """ + if self.reference_volume is not None: + return self._resolve_from_root(self.reference_volume) + + volume_dir = self.resolve_volume_output_dir() + default_candidates = [ + volume_dir / "compensated.HDF5", + volume_dir / "compensated.hdf5", + ] + for candidate in default_candidates: + if candidate.exists(): + return candidate + return default_candidates[0] + + def effective_volume_bin_size(self) -> int: + """Return the bin size used when reading the reference stack as z slices.""" + return self.stack_scans_per_slice or self.volume_bin_size + + def _resolve_flow_options_path(self, path: Path) -> Path: + options_path = path.expanduser() + if not options_path.is_absolute(): + options_path = self.root / options_path + return options_path + + def _get_flow_option_overrides( + self, + option_source: Optional[Union[Dict[str, Any], Path]], + *, + protected_fields: set[str], + label: str, + ) -> Dict[str, Any]: + """ + Return OFOptions overrides with workflow-owned fields removed. + + The config supports inline dict values or paths to saved OF_options JSON. 
+ """ + if option_source is None: + return {} + + if isinstance(option_source, dict): + return { + key: value + for key, value in option_source.items() + if key not in protected_fields + } + + options_path = self._resolve_flow_options_path(option_source) + + if not options_path.exists(): + raise ValueError(f"{label} flow options file not found: {options_path}") + + from pyflowreg.motion_correction.OF_options import OFOptions + + options = OFOptions.load_options(options_path) + allowed_fields = set(OFOptions.model_fields.keys()) + allowed_fields.difference_update(protected_fields) + + return { + key: value + for key, value in options.model_dump().items() + if key in allowed_fields + } + + def get_stage1_overrides(self) -> Dict[str, Any]: + """Return OFOptions overrides for stage 1.""" + return self._get_flow_option_overrides( + self.stage1_flow_options, + protected_fields=_STAGE1_PROTECTED_OF_FIELDS, + label="Stage-1", + ) + + def get_recording_prealign_overrides(self) -> Dict[str, Any]: + """Return OFOptions overrides for optional recording prealignment.""" + return self._get_flow_option_overrides( + self.recording_prealign_flow_options, + protected_fields=_RECORDING_PREALIGN_PROTECTED_OF_FIELDS, + label="Recording prealignment", + ) + + @classmethod + def from_toml(cls, path: Union[str, Path]) -> "ZAlignConfig": + import sys + + p = Path(path) + if sys.version_info >= (3, 11): + import tomllib + + with open(p, "rb") as f: + data = tomllib.load(f) + else: + try: + import tomli + except ImportError as exc: + raise ImportError( + "TOML support requires 'tomli' for Python < 3.11." + ) from exc + with open(p, "rb") as f: + data = tomli.load(f) + + return cls(**data) + + @classmethod + def from_yaml(cls, path: Union[str, Path]) -> "ZAlignConfig": + try: + import yaml + except ImportError as exc: + raise ImportError( + "YAML support requires 'pyyaml'. 
Install with: pip install pyyaml" + ) from exc + + with open(path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + + return cls(**data) + + @classmethod + def from_file(cls, path: Union[str, Path]) -> "ZAlignConfig": + p = Path(path) + suffix = p.suffix.lower() + if suffix == ".toml": + return cls.from_toml(p) + if suffix in {".yaml", ".yml"}: + return cls.from_yaml(p) + raise ValueError( + f"Unsupported config file format: {suffix}. Use .toml, .yaml, or .yml." + ) diff --git a/src/pyflowreg/z_align/pipeline.py b/src/pyflowreg/z_align/pipeline.py new file mode 100644 index 0000000..036e42d --- /dev/null +++ b/src/pyflowreg/z_align/pipeline.py @@ -0,0 +1,952 @@ +""" +Stage-based z-alignment pipeline. + +This module ports the MATLAB patch-based z-shift workflow into the existing +PyFlowReg architecture: + +1) Build/load a compensated reference volume. +2) Estimate per-frame/per-pixel z-shifts and optionally write z-corrected data. +3) Optionally simulate a z-shift-only recording from the estimated z-shifts. 
+""" + +from __future__ import annotations + +import json +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from time import time +from typing import Any, Dict, Optional + +import numpy as np +from scipy.ndimage import gaussian_filter + +from pyflowreg.motion_correction.OF_options import OFOptions +from pyflowreg.motion_correction.compensate_recording import compensate_recording +from pyflowreg.util.io.factory import get_video_file_reader, get_video_file_writer +from pyflowreg.z_align.config import ZAlignConfig + + +_STAGE1_PROTECTED_OF_FIELDS = { + "input_file", + "output_path", + "output_format", + "output_file_name", + "reference_frames", +} + +_RECORDING_PREALIGN_PROTECTED_OF_FIELDS = { + "input_file", + "output_path", + "output_format", + "output_file_name", +} + + +def load_or_create_status(output_root: Path) -> Dict[str, Any]: + """Load ``status.json`` from ``output_root`` or return an empty dict.""" + status_path = output_root / "status.json" + if status_path.exists(): + with open(status_path, "r", encoding="utf-8") as f: + return json.load(f) + return {} + + +def save_status(output_root: Path, status: Dict[str, Any]) -> None: + """Atomically persist ``status.json``.""" + status_path = output_root / "status.json" + tmp_path = status_path.with_suffix(".json.tmp") + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(status, f, indent=2) + tmp_path.replace(status_path) + + +def _ensure_thwc(arr: np.ndarray) -> np.ndarray: + """Normalize frame arrays to THWC layout.""" + if arr.ndim == 4: + return arr + if arr.ndim == 3: + # Either (T,H,W) or (H,W,C) single frame. Treat as (T,H,W) here. 
+ return arr[:, :, :, np.newaxis] + if arr.ndim == 2: + return arr[np.newaxis, :, :, np.newaxis] + raise ValueError(f"Expected 2D/3D/4D frame array, got {arr.ndim}D") + + +def _to_hwcz(volume_thwc: np.ndarray) -> np.ndarray: + """Convert THWC -> HWCZ.""" + return np.transpose(volume_thwc, (1, 2, 3, 0)) + + +def _to_hwct(batch_thwc: np.ndarray) -> np.ndarray: + """Convert THWC -> HWCT.""" + return np.transpose(batch_thwc, (1, 2, 3, 0)) + + +def _from_hwct(batch_hwct: np.ndarray) -> np.ndarray: + """Convert HWCT -> THWC.""" + return np.transpose(batch_hwct, (3, 0, 1, 2)) + + +def _parse_output_format(path: Path, fallback: str = "TIFF") -> str: + """Infer writer format from file extension.""" + ext = path.suffix.lower() + if ext in {".tif", ".tiff"}: + return "TIFF" + if ext in {".h5", ".hdf5", ".hdf"}: + return "HDF5" + if ext == ".mat": + return "MAT" + return fallback + + +def _resolve_n_workers(n_jobs: int) -> int: + """Resolve worker count from config style semantics.""" + if n_jobs == -1: + return os.cpu_count() or 4 + return max(1, int(n_jobs)) + + +def _clip_and_cast(frames: np.ndarray, dtype_name: str) -> np.ndarray: + """Clip to dtype range and cast.""" + dtype = np.dtype(dtype_name) + arr = np.maximum(frames, 0) + if np.issubdtype(dtype, np.integer): + info = np.iinfo(dtype) + arr = np.clip(arr, info.min, info.max) + arr = np.rint(arr) + return arr.astype(dtype, copy=False) + + +def _compute_reference_from_source( + config: ZAlignConfig, source_path: Optional[Path] = None +) -> Optional[np.ndarray]: + """ + Build a reference image from the configured or provided reference source. + + Mirrors the MATLAB step: + ``reference = mean(reader.read_frames(1:N), 4)``. 
+ """ + ref_source = ( + source_path + if source_path is not None + else config.resolve_reference_source_file() + ) + if ref_source is None: + return None + if not ref_source.exists(): + raise FileNotFoundError(f"Reference source file not found: {ref_source}") + + reader = get_video_file_reader( + str(ref_source), + buffer_size=config.reference_source_buffer_size, + bin_size=config.reference_source_bin_size, + ) + try: + n_frames = min(config.reference_source_frames, len(reader)) + if n_frames < 1: + raise ValueError("reference_source_file has no frames") + frames = reader[slice(0, n_frames)] + frames = _ensure_thwc(frames).astype(np.float32, copy=False) + reference = np.mean(frames, axis=0) + return reference + finally: + reader.close() + + +def _build_stage1_overrides( + config: ZAlignConfig, runtime_override: Optional[Dict[str, Any]] +) -> Dict[str, Any]: + """Merge config-level and runtime OFOptions overrides.""" + overrides: Dict[str, Any] = {} + config_override = config.get_stage1_overrides() + if config_override: + overrides.update(config_override) + if runtime_override: + overrides.update(runtime_override) + for field in _STAGE1_PROTECTED_OF_FIELDS: + overrides.pop(field, None) + return overrides + + +def _build_recording_prealign_overrides(config: ZAlignConfig) -> Dict[str, Any]: + """Return workflow-safe OFOptions overrides for recording prealignment.""" + overrides = config.get_recording_prealign_overrides() + for field in _RECORDING_PREALIGN_PROTECTED_OF_FIELDS: + overrides.pop(field, None) + return overrides + + +def _compute_xy_gradient(img_2d: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Central-difference style 2D gradients (gx, gy).""" + gy, gx = np.gradient(img_2d.astype(np.float32), edge_order=1) + return gx.astype(np.float32, copy=False), gy.astype(np.float32, copy=False) + + +def _compute_volume_gradients( + volume_hwcz: np.ndarray, spatial_sigma: float +) -> tuple[np.ndarray, np.ndarray]: + """Precompute per-slice spatial gradients for 
the reference volume.""" + H, W, C, Z = volume_hwcz.shape + gx_vol = np.zeros((H, W, C, Z), dtype=np.float32) + gy_vol = np.zeros((H, W, C, Z), dtype=np.float32) + + for c in range(C): + for z in range(Z): + smooth = gaussian_filter(volume_hwcz[:, :, c, z], sigma=spatial_sigma) + gx, gy = _compute_xy_gradient(smooth) + gx_vol[:, :, c, z] = gx + gy_vol[:, :, c, z] = gy + + return gx_vol, gy_vol + + +def _compute_batch_gradients( + batch_hwct: np.ndarray, + spatial_sigma: float, + temporal_sigma: float, +) -> tuple[np.ndarray, np.ndarray]: + """Compute spatiotemporal-smoothed gradients for an input batch.""" + H, W, C, T = batch_hwct.shape + gx_f = np.zeros((H, W, C, T), dtype=np.float32) + gy_f = np.zeros((H, W, C, T), dtype=np.float32) + + for c in range(C): + fc = batch_hwct[:, :, c, :] + fs3 = gaussian_filter(fc, sigma=(spatial_sigma, spatial_sigma, temporal_sigma)) + gy, gx = np.gradient( + fs3.astype(np.float32, copy=False), axis=(0, 1), edge_order=1 + ) + gx_f[:, :, c, :] = gx.astype(np.float32, copy=False) + gy_f[:, :, c, :] = gy.astype(np.float32, copy=False) + + return gx_f, gy_f + + +def _estimate_anchor_z( + gx_vol: np.ndarray, + gy_vol: np.ndarray, + gx_f: np.ndarray, + gy_f: np.ndarray, +) -> tuple[int, np.ndarray]: + """Estimate anchor z-index (0-based) from the first batch.""" + Z = gx_vol.shape[3] + e_sum = np.zeros((Z,), dtype=np.float64) + for z in range(Z): + ex = np.abs(gx_vol[:, :, :, z][:, :, :, None] - gx_f) + ey = np.abs(gy_vol[:, :, :, z][:, :, :, None] - gy_f) + e_sum[z] = np.sum(ex + ey, dtype=np.float64) + anchor_z = int(np.argmin(e_sum)) + return anchor_z, e_sum + + +def _generate_patch_starts(length: int, patch_size: int, stride: int) -> list[int]: + """Generate patch starts with guaranteed end coverage.""" + last = max(length - patch_size, 0) + starts = list(range(0, last + 1, stride)) + if not starts or starts[-1] != last: + starts.append(last) + return sorted(set(starts)) + + +def _score_patch( + patch_bounds: tuple[int, int, int, 
int], + gx_vol: np.ndarray, + gy_vol: np.ndarray, + gx_f: np.ndarray, + gy_f: np.ndarray, + z_candidates: np.ndarray, + z_indices: np.ndarray, + z_min: int, + z_max: int, + tau_scale: float, +) -> np.ndarray: + """ + Score one spatial patch against z candidates and return z_hat for all frames. + + Returns + ------- + ndarray, shape (T,) + Sub-voxel z estimates for this patch and each time sample. + """ + r1, r2, c1, c2 = patch_bounds + + gx_patch = gx_f[r1:r2, c1:c2, :, :] + gy_patch = gy_f[r1:r2, c1:c2, :, :] + gx_vol_patch = gx_vol[r1:r2, c1:c2, :, :] + gy_vol_patch = gy_vol[r1:r2, c1:c2, :, :] + T = gx_patch.shape[3] + + e_patch = np.zeros((T, len(z_candidates)), dtype=np.float64) + for ii, z in enumerate(z_indices): + ex = np.abs(gx_vol_patch[:, :, :, z][:, :, :, None] - gx_patch) + ey = np.abs(gy_vol_patch[:, :, :, z][:, :, :, None] - gy_patch) + e_patch[:, ii] = np.sum(ex + ey, axis=(0, 1, 2), dtype=np.float64) + + s_patch = -e_patch + k_rel = np.argmax(s_patch, axis=1) + km1 = np.maximum(k_rel - 1, 0) + kp1 = np.minimum(k_rel + 1, len(z_candidates) - 1) + t_idx = np.arange(T) + + s0 = s_patch[t_idx, k_rel] + sm = s_patch[t_idx, km1] + sp = s_patch[t_idx, kp1] + den = sm - (2.0 * s0) + sp + + tau = tau_scale * np.maximum(np.abs(s0), 1.0) + den[np.abs(den) < tau] = np.nan + + delta = 0.5 * (sm - sp) / den + delta[~np.isfinite(delta)] = 0.0 + delta = np.clip(delta, -0.5, 0.5) + + z_hat_patch = np.clip(z_candidates[k_rel] + delta, z_min, z_max) + return z_hat_patch + + +def _estimate_z_patchwise( + gx_vol: np.ndarray, + gy_vol: np.ndarray, + gx_f: np.ndarray, + gy_f: np.ndarray, + *, + anchor_z: int, + win_half: int, + patch_size: int, + overlap: float, + tau_scale: float, + z_smooth_sigma_spatial: float, + z_smooth_sigma_temporal: float, + parallelization: str = "sequential", + n_jobs: int = -1, +) -> np.ndarray: + """Patch-based z estimation with sub-voxel quadratic refinement.""" + H, W, _, T = gx_f.shape + Z = gx_vol.shape[3] + + stride = max(1, 
int(round(patch_size * (1.0 - overlap)))) + z_min = max(0, anchor_z - win_half) + z_max = min(Z - 1, anchor_z + win_half) + z_candidates = np.arange(z_min, z_max + 1, dtype=np.float64) + z_indices = z_candidates.astype(np.int32, copy=False) + + row_starts = _generate_patch_starts(H, patch_size, stride) + col_starts = _generate_patch_starts(W, patch_size, stride) + patches: list[tuple[int, int, int, int]] = [] + for r1 in row_starts: + r2 = min(H, r1 + patch_size) + for c1 in col_starts: + c2 = min(W, c1 + patch_size) + patches.append((r1, r2, c1, c2)) + + z_accum = np.zeros((H, W, T), dtype=np.float64) + w_accum = np.zeros((H, W, T), dtype=np.float64) + + n_patches = len(patches) + patch_scores: list[Optional[np.ndarray]] = [None] * n_patches + n_workers = min(_resolve_n_workers(n_jobs), n_patches) if n_patches else 1 + use_threading = parallelization == "threading" and n_workers > 1 + + if use_threading: + with ThreadPoolExecutor(max_workers=n_workers) as executor: + future_to_idx = { + executor.submit( + _score_patch, + patch, + gx_vol, + gy_vol, + gx_f, + gy_f, + z_candidates, + z_indices, + z_min, + z_max, + tau_scale, + ): patch_idx + for patch_idx, patch in enumerate(patches) + } + for future in as_completed(future_to_idx): + patch_idx = future_to_idx[future] + patch_scores[patch_idx] = future.result() + else: + for patch_idx, patch in enumerate(patches): + patch_scores[patch_idx] = _score_patch( + patch, + gx_vol, + gy_vol, + gx_f, + gy_f, + z_candidates, + z_indices, + z_min, + z_max, + tau_scale, + ) + + # Deterministic accumulation: always add patch contributions in row-major order. 
+ for patch_idx, (r1, r2, c1, c2) in enumerate(patches): + z_hat_patch = patch_scores[patch_idx] + if z_hat_patch is None: + raise RuntimeError("Missing patch score during z estimation") + z_accum[r1:r2, c1:c2, :] += z_hat_patch[np.newaxis, np.newaxis, :] + w_accum[r1:r2, c1:c2, :] += 1.0 + + z_hat = z_accum / np.maximum(w_accum, np.finfo(np.float64).eps) + z_hat = gaussian_filter( + z_hat, + sigma=( + z_smooth_sigma_spatial, + z_smooth_sigma_spatial, + z_smooth_sigma_temporal, + ), + ) + return np.clip(z_hat, z_min, z_max) + + +def _apply_z_correction( + batch_hwct: np.ndarray, + z_hat_hwt: np.ndarray, + diff_hwcz: np.ndarray, +) -> np.ndarray: + """Apply direct z-correction via interpolated ``Diff(anchor)-Diff(z)``.""" + H, W, C, T = batch_hwct.shape + Z = diff_hwcz.shape[3] + corrected = np.zeros_like(batch_hwct, dtype=np.float32) + one = np.float32(1.0) + + for t in range(T): + zh = np.clip(z_hat_hwt[:, :, t], 0.0, float(Z - 1)) + z0 = np.floor(zh).astype(np.int32) + z1 = np.minimum(z0 + 1, Z - 1) + alpha = (zh - z0).astype(np.float32) + + for c in range(C): + diff_c = diff_hwcz[:, :, c, :] + d0 = np.take_along_axis(diff_c, z0[:, :, None], axis=2)[:, :, 0] + d1 = np.take_along_axis(diff_c, z1[:, :, None], axis=2)[:, :, 0] + corr = (one - alpha) * d0 + alpha * d1 + corrected[:, :, c, t] = batch_hwct[:, :, c, t] + corr + + return corrected + + +def _simulate_from_z(volume_hwcz: np.ndarray, z_hat_hwt: np.ndarray) -> np.ndarray: + """Simulate recording frames by interpolating along z in the reference volume.""" + H, W, C, Z = volume_hwcz.shape + T = z_hat_hwt.shape[2] + simulated = np.zeros((H, W, C, T), dtype=np.float32) + one = np.float32(1.0) + + for t in range(T): + zh = np.clip(z_hat_hwt[:, :, t], 0.0, float(Z - 1)) + z0 = np.floor(zh).astype(np.int32) + z1 = np.minimum(z0 + 1, Z - 1) + alpha = (zh - z0).astype(np.float32) + alpha[z0 == (Z - 1)] = 0.0 + + for c in range(C): + vol_c = volume_hwcz[:, :, c, :] + v0 = np.take_along_axis(vol_c, z0[:, :, None], 
axis=2)[:, :, 0] + v1 = np.take_along_axis(vol_c, z1[:, :, None], axis=2)[:, :, 0] + simulated[:, :, c, t] = (one - alpha) * v0 + alpha * v1 + + return simulated + + +def _find_compensated_volume(volume_dir: Path) -> Path: + """Return existing compensated volume path or default target path.""" + candidates = [ + volume_dir / "compensated.HDF5", + volume_dir / "compensated.hdf5", + ] + for candidate in candidates: + if candidate.exists(): + return candidate + return candidates[0] + + +def _load_volume(config: ZAlignConfig, volume_path: Path) -> np.ndarray: + """Load reference volume from file and convert to HWCZ float32.""" + reader = get_video_file_reader( + str(volume_path), + buffer_size=config.volume_buffer_size, + bin_size=config.effective_volume_bin_size(), + ) + try: + volume_thwc = _ensure_thwc(reader[:]).astype(np.float32, copy=False) + finally: + reader.close() + + if volume_thwc.shape[0] < 2: + raise ValueError("Reference volume must contain at least 2 z slices") + return _to_hwcz(volume_thwc) + + +def run_stage1( + config: ZAlignConfig, + of_options_override: Optional[Dict[str, Any]] = None, +) -> Path: + """ + Stage 1: build or load the compensated reference volume. + + Returns + ------- + Path + Path to reference volume file. 
+ """ + start_time = time() + output_root = config.resolve_output_root() + output_root.mkdir(parents=True, exist_ok=True) + + status = load_or_create_status(output_root) + volume_output_dir = config.resolve_volume_output_dir() + volume_output_dir.mkdir(parents=True, exist_ok=True) + + if config.reference_volume is not None: + volume_path = config.resolve_reference_volume_path() + if not volume_path.exists(): + raise FileNotFoundError( + f"Configured reference_volume not found: {volume_path}" + ) + status["stage1"] = "done" + status["volume_path"] = str(volume_path) + save_status(output_root, status) + print(f"Stage 1: using existing reference volume {volume_path}") + return volume_path + + volume_input_file = config.resolve_volume_input_file() + if volume_input_file is None: + raise ValueError( + "volume_input_file is required when reference_volume is not provided" + ) + if not volume_input_file.exists(): + raise FileNotFoundError(f"volume_input_file not found: {volume_input_file}") + + if not config.prealign_stack: + status["stage1"] = "done" + status["volume_path"] = str(volume_input_file) + status["prealign_stack"] = False + save_status(output_root, status) + print(f"Stage 1: using raw reference stack {volume_input_file}") + return volume_input_file + + expected_volume = _find_compensated_volume(volume_output_dir) + if config.resume and status.get("stage1") == "done" and expected_volume.exists(): + print(f"Stage 1: reusing existing volume {expected_volume}") + return expected_volume + + reference = _compute_reference_from_source(config) + stage1_buffer_size = config.stack_scans_per_slice or config.stage1_buffer_size + stage1_update_reference = ( + True + if config.stack_scans_per_slice is not None + else config.stage1_update_reference + ) + + of_params: Dict[str, Any] = { + "input_file": str(volume_input_file), + "output_path": str(volume_output_dir), + "output_format": "HDF5", + "alpha": config.stage1_alpha, + "quality_setting": config.stage1_quality_setting, + 
"buffer_size": stage1_buffer_size, + "bin_size": config.stage1_bin_size, + "update_reference": stage1_update_reference, + "flow_backend": config.flow_backend, + "backend_params": config.backend_params, + } + if reference is not None: + of_params["reference_frames"] = reference + + overrides = _build_stage1_overrides(config, of_options_override) + of_params.update(overrides) + if config.stack_scans_per_slice is not None: + of_params["buffer_size"] = config.stack_scans_per_slice + of_params["update_reference"] = True + + options = OFOptions(**of_params) + print("Stage 1: running compensate_recording to build reference volume...") + compensate_recording(options) + + volume_path = _find_compensated_volume(volume_output_dir) + if not volume_path.exists(): + raise RuntimeError( + "Stage 1 did not produce compensated volume. Expected " + f"{volume_output_dir / 'compensated.HDF5'}" + ) + + status["stage1"] = "done" + status["volume_path"] = str(volume_path) + save_status(output_root, status) + + elapsed = time() - start_time + print(f"Stage 1 complete in {elapsed:.2f}s") + return volume_path + + +def run_recording_prealignment(config: ZAlignConfig) -> Optional[Path]: + """ + Optionally prealign the Stage-2 recording before z-shift estimation. + + Returns + ------- + Path or None + Path to the prealigned recording, or None when disabled. 
+ """ + if not config.prealign_recording: + return None + + start_time = time() + output_root = config.resolve_output_root() + output_root.mkdir(parents=True, exist_ok=True) + status = load_or_create_status(output_root) + + output_dir = config.resolve_recording_prealigned_output_dir() + output_dir.mkdir(parents=True, exist_ok=True) + prealigned_path = config.resolve_recording_prealigned_file() + + if ( + config.resume + and status.get("recording_prealign") == "done" + and prealigned_path.exists() + ): + print(f"Recording prealignment: reusing {prealigned_path}") + return prealigned_path + + input_path = config.resolve_input_file() + if not input_path.exists(): + raise FileNotFoundError(f"input_file not found: {input_path}") + + reference_source = config.resolve_reference_source_file() or input_path + reference = _compute_reference_from_source(config, reference_source) + + of_params: Dict[str, Any] = { + "input_file": str(input_path), + "output_path": str(output_dir), + "output_format": "HDF5", + "alpha": config.stage1_alpha, + "quality_setting": config.stage1_quality_setting, + "buffer_size": config.input_buffer_size, + "bin_size": config.input_bin_size, + "update_reference": False, + "flow_backend": config.flow_backend, + "backend_params": config.backend_params, + } + if reference is not None: + of_params["reference_frames"] = reference + + of_params.update(_build_recording_prealign_overrides(config)) + + options = OFOptions(**of_params) + print("Recording prealignment: running compensate_recording...") + compensate_recording(options) + + if not prealigned_path.exists(): + raise RuntimeError( + "Recording prealignment did not produce output. 
Expected " + f"{prealigned_path}" + ) + + status["recording_prealign"] = "done" + status["prealigned_recording_path"] = str(prealigned_path) + save_status(output_root, status) + + elapsed = time() - start_time + print(f"Recording prealignment complete in {elapsed:.2f}s") + return prealigned_path + + +def run_stage2( + config: ZAlignConfig, + volume_path: Optional[Path] = None, +) -> Dict[str, Any]: + """ + Stage 2: estimate z-shifts and optionally write z-corrected output. + + Returns + ------- + dict + Keys: ``z_shift_path``, ``corrected_path``, ``anchor_z``, and + ``prealigned_recording_path``. + """ + start_time = time() + output_root = config.resolve_output_root() + output_root.mkdir(parents=True, exist_ok=True) + status = load_or_create_status(output_root) + + if volume_path is None: + volume_path = config.resolve_reference_volume_path() + if not volume_path.exists(): + raise FileNotFoundError(f"Reference volume not found: {volume_path}") + + z_shift_path = config.resolve_z_shift_file() + corrected_path = config.resolve_corrected_output_file() + z_shift_path.parent.mkdir(parents=True, exist_ok=True) + corrected_path.parent.mkdir(parents=True, exist_ok=True) + + stage2_outputs_ready = z_shift_path.exists() and ( + (not config.write_corrected) or corrected_path.exists() + ) + if config.resume and status.get("stage2") == "done" and stage2_outputs_ready: + prealigned_recording_path = ( + run_recording_prealignment(config) if config.prealign_recording else None + ) + anchor_z = status.get("anchor_z", None) + print("Stage 2: existing outputs found, skipping") + return { + "z_shift_path": z_shift_path, + "corrected_path": corrected_path if config.write_corrected else None, + "anchor_z": anchor_z, + "prealigned_recording_path": prealigned_recording_path, + } + + prealigned_recording_path = run_recording_prealignment(config) + input_path = prealigned_recording_path or config.resolve_input_file() + if not input_path.exists(): + raise FileNotFoundError(f"input_file not 
found: {input_path}") + + volume_hwcz = _load_volume(config, volume_path) + H, W, C, Z = volume_hwcz.shape + gx_vol, gy_vol = _compute_volume_gradients(volume_hwcz, config.spatial_sigma) + + input_reader = get_video_file_reader( + str(input_path), + buffer_size=config.input_buffer_size, + bin_size=config.input_bin_size, + ) + + z_writer = get_video_file_writer(str(z_shift_path), "HDF5") + corrected_writer = None + if config.write_corrected: + corrected_fmt = _parse_output_format(corrected_path, fallback="TIFF") + corrected_writer = get_video_file_writer(str(corrected_path), corrected_fmt) + + anchor_z: Optional[int] = None + diff_hwcz: Optional[np.ndarray] = None + + try: + n_batches = 0 + while input_reader.has_batch(): + batch_thwc = _ensure_thwc(input_reader.read_batch()) + batch_hwct = _to_hwct(batch_thwc).astype(np.float32, copy=False) + if batch_hwct.shape[:3] != (H, W, C): + input_shape = ( + batch_hwct.shape[0], + batch_hwct.shape[1], + batch_hwct.shape[2], + ) + raise ValueError( + "Input recording dimensions do not match reference volume: " + f"input {input_shape} " + f"vs volume {(H, W, C)}" + ) + + gx_f, gy_f = _compute_batch_gradients( + batch_hwct, + spatial_sigma=config.spatial_sigma, + temporal_sigma=config.temporal_sigma, + ) + + if anchor_z is None: + anchor_z, _ = _estimate_anchor_z(gx_vol, gy_vol, gx_f, gy_f) + diff_hwcz = ( + volume_hwcz[:, :, :, anchor_z][:, :, :, None] - volume_hwcz + ).astype(np.float32) + diff_hwcz[:, :, :, anchor_z] = 0.0 + + z_hat_hwt = _estimate_z_patchwise( + gx_vol, + gy_vol, + gx_f, + gy_f, + anchor_z=anchor_z, + win_half=config.win_half, + patch_size=config.patch_size, + overlap=config.overlap, + tau_scale=config.parabolic_tau_scale, + z_smooth_sigma_spatial=config.z_smooth_sigma_spatial, + z_smooth_sigma_temporal=config.z_smooth_sigma_temporal, + parallelization=config.parallelization, + n_jobs=config.n_jobs, + ) + + # Persist z-shifts in MATLAB-style 1-based slice coordinates. 
+ z_batch_thwc = _from_hwct( + (z_hat_hwt + 1.0)[:, :, None, :].astype(np.float32) + ) + z_writer.write_frames(z_batch_thwc) + + if corrected_writer is not None and diff_hwcz is not None: + corrected_hwct = _apply_z_correction(batch_hwct, z_hat_hwt, diff_hwcz) + corrected_thwc = _from_hwct(corrected_hwct) + corrected_writer.write_frames( + _clip_and_cast(corrected_thwc, config.output_dtype) + ) + + n_batches += 1 + print(f"Stage 2: processed batch {n_batches}") + + finally: + input_reader.close() + z_writer.close() + if corrected_writer is not None: + corrected_writer.close() + + if anchor_z is None: + raise RuntimeError( + "Stage 2 processed zero batches; no z-shift estimate produced" + ) + + np.savez( + str(output_root / "stage2_metadata.npz"), + anchor_z_0based=np.array(anchor_z, dtype=np.int32), + anchor_z_1based=np.array(anchor_z + 1, dtype=np.int32), + volume_path=str(volume_path), + input_path=str(input_path), + prealigned_recording_path=( + "" if prealigned_recording_path is None else str(prealigned_recording_path) + ), + ) + + status["stage2"] = "done" + status["anchor_z"] = int(anchor_z) + status["anchor_z_1based"] = int(anchor_z + 1) + save_status(output_root, status) + + elapsed = time() - start_time + print(f"Stage 2 complete in {elapsed:.2f}s") + return { + "z_shift_path": z_shift_path, + "corrected_path": corrected_path if config.write_corrected else None, + "anchor_z": anchor_z, + "prealigned_recording_path": prealigned_recording_path, + } + + +def run_stage3( + config: ZAlignConfig, + volume_path: Optional[Path] = None, + z_shift_path: Optional[Path] = None, +) -> Optional[Path]: + """ + Stage 3: simulate a z-shift-only recording from volume + z-shift. + + Returns + ------- + Path or None + Simulated output path (or None if simulation disabled). 
+ """ + if not config.write_simulated: + print("Stage 3: simulation disabled by config (write_simulated=false)") + return None + + start_time = time() + output_root = config.resolve_output_root() + output_root.mkdir(parents=True, exist_ok=True) + status = load_or_create_status(output_root) + + simulated_path = config.resolve_simulated_output_file() + simulated_path.parent.mkdir(parents=True, exist_ok=True) + + if config.resume and status.get("stage3") == "done" and simulated_path.exists(): + print(f"Stage 3: reusing existing simulation {simulated_path}") + return simulated_path + + if volume_path is None: + volume_path = config.resolve_reference_volume_path() + if z_shift_path is None: + z_shift_path = config.resolve_z_shift_file() + + if not volume_path.exists(): + raise FileNotFoundError(f"Reference volume not found: {volume_path}") + if not z_shift_path.exists(): + raise FileNotFoundError(f"z_shift file not found: {z_shift_path}") + + volume_hwcz = _load_volume(config, volume_path) + H, W, C, _ = volume_hwcz.shape + + z_reader = get_video_file_reader( + str(z_shift_path), + buffer_size=config.input_buffer_size, + bin_size=1, + ) + sim_fmt = _parse_output_format(simulated_path, fallback="TIFF") + sim_writer = get_video_file_writer(str(simulated_path), sim_fmt) + + try: + n_batches = 0 + while z_reader.has_batch(): + z_thwc = _ensure_thwc(z_reader.read_batch()).astype(np.float32, copy=False) + if z_thwc.shape[1] != H or z_thwc.shape[2] != W: + raise ValueError( + "z_shift dimensions do not match reference volume: " + f"z {(z_thwc.shape[1], z_thwc.shape[2])} vs volume {(H, W)}" + ) + if z_thwc.shape[3] < 1: + raise ValueError("z_shift batch must have at least one channel") + + # z_shift is stored as 1-based slice IDs for MATLAB parity. 
+ z_hwt = np.transpose(z_thwc[:, :, :, 0], (1, 2, 0)).astype(np.float64) - 1.0 + sim_hwct = _simulate_from_z(volume_hwcz, z_hwt) + if sim_hwct.shape[2] != C: + raise RuntimeError("Internal channel mismatch in simulated output") + + sim_thwc = _from_hwct(sim_hwct) + sim_writer.write_frames(_clip_and_cast(sim_thwc, config.output_dtype)) + + n_batches += 1 + print(f"Stage 3: processed batch {n_batches}") + + finally: + z_reader.close() + sim_writer.close() + + status["stage3"] = "done" + save_status(output_root, status) + + elapsed = time() - start_time + print(f"Stage 3 complete in {elapsed:.2f}s") + return simulated_path + + +def run_all_stages( + config: ZAlignConfig, + of_options_override: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + Run all z-align stages. + + Returns + ------- + dict + Collected stage outputs. + """ + print("=" * 60) + print("Z-ALIGN STAGE 1: Build/Load Reference Volume") + print("=" * 60) + volume_path = run_stage1(config, of_options_override=of_options_override) + + print("\n" + "=" * 60) + print("Z-ALIGN STAGE 2: Estimate z-shifts and Correct Signal") + print("=" * 60) + stage2_out = run_stage2(config, volume_path=volume_path) + + simulated_path = None + if config.write_simulated: + print("\n" + "=" * 60) + print("Z-ALIGN STAGE 3: Simulate z-shift-only recording") + print("=" * 60) + simulated_path = run_stage3( + config, + volume_path=volume_path, + z_shift_path=stage2_out["z_shift_path"], + ) + else: + print("\nSkipping Stage 3 (write_simulated=false)") + + return { + "volume_path": volume_path, + **stage2_out, + "simulated_path": simulated_path, + } diff --git a/tests/z_align/__init__.py b/tests/z_align/__init__.py new file mode 100644 index 0000000..69e2d30 --- /dev/null +++ b/tests/z_align/__init__.py @@ -0,0 +1 @@ +"""Tests for z-alignment processing.""" diff --git a/tests/z_align/test_cli.py b/tests/z_align/test_cli.py new file mode 100644 index 0000000..1283fa6 --- /dev/null +++ b/tests/z_align/test_cli.py @@ -0,0 
+1,145 @@ +""" +Tests for z-align CLI. +""" + +from __future__ import annotations + +import argparse + +import pytest + +from pyflowreg.z_align.config import ZAlignConfig +from pyflowreg.z_align import cli + + +class TestCLIParsing: + """Test helper parsing functions.""" + + def test_parse_value_scalars(self): + assert cli._parse_value("true") is True + assert cli._parse_value("false") is False + assert cli._parse_value("12") == 12 + assert cli._parse_value("1.25") == 1.25 + assert cli._parse_value("quality") == "quality" + + def test_parse_value_json(self): + assert cli._parse_value("[1, 2, 3]") == [1, 2, 3] + assert cli._parse_value('{"a": 1}') == {"a": 1} + + def test_parse_overrides(self): + out = cli._parse_overrides( + ["alpha=5", "write_simulated=false", "quality_setting=balanced"] + ) + assert out["alpha"] == 5 + assert out["write_simulated"] is False + assert out["quality_setting"] == "balanced" + + +class TestCLIRouting: + """Test subcommand routing behavior.""" + + @pytest.fixture + def config_file(self, tmp_path): + cfg_file = tmp_path / "z_align.toml" + cfg_file.write_text( + "\n".join( + [ + f'root = "{tmp_path.as_posix()}"', + 'input_file = "compensated.tiff"', + ] + ), + encoding="utf-8", + ) + return cfg_file + + def test_cmd_run_stage1(self, config_file, monkeypatch): + cfg = ZAlignConfig(root=config_file.parent, input_file="compensated.tiff") + called = {"stage1": 0} + + monkeypatch.setattr(cli.ZAlignConfig, "from_file", lambda _p: cfg) + monkeypatch.setattr( + cli, + "run_stage1", + lambda config, overrides=None: called.__setitem__( + "stage1", called["stage1"] + 1 + ), + ) + monkeypatch.setattr( + cli, "run_stage2", lambda *args, **kwargs: pytest.fail("unexpected stage2") + ) + monkeypatch.setattr( + cli, "run_stage3", lambda *args, **kwargs: pytest.fail("unexpected stage3") + ) + monkeypatch.setattr( + cli, + "run_all_stages", + lambda *args, **kwargs: pytest.fail("unexpected run_all_stages"), + ) + + args = argparse.Namespace( + 
config=str(config_file), + stage="1", + of_params=["alpha=8", "quality_setting=balanced"], + ) + cli.cmd_run(args) + assert called["stage1"] == 1 + + def test_cmd_run_stage2(self, config_file, monkeypatch): + cfg = ZAlignConfig(root=config_file.parent, input_file="compensated.tiff") + called = {"stage2": 0} + + monkeypatch.setattr(cli.ZAlignConfig, "from_file", lambda _p: cfg) + monkeypatch.setattr( + cli, + "run_stage2", + lambda config: called.__setitem__("stage2", called["stage2"] + 1), + ) + monkeypatch.setattr( + cli, "run_stage1", lambda *args, **kwargs: pytest.fail("unexpected stage1") + ) + + args = argparse.Namespace(config=str(config_file), stage="2", of_params=None) + cli.cmd_run(args) + assert called["stage2"] == 1 + + def test_cmd_run_all_stages(self, config_file, monkeypatch): + cfg = ZAlignConfig(root=config_file.parent, input_file="compensated.tiff") + called = {"all": 0} + + monkeypatch.setattr(cli.ZAlignConfig, "from_file", lambda _p: cfg) + monkeypatch.setattr( + cli, + "run_all_stages", + lambda config, overrides=None: called.__setitem__("all", called["all"] + 1), + ) + monkeypatch.setattr( + cli, "run_stage1", lambda *args, **kwargs: pytest.fail("unexpected stage1") + ) + monkeypatch.setattr( + cli, "run_stage2", lambda *args, **kwargs: pytest.fail("unexpected stage2") + ) + monkeypatch.setattr( + cli, "run_stage3", lambda *args, **kwargs: pytest.fail("unexpected stage3") + ) + + args = argparse.Namespace( + config=str(config_file), + stage=None, + of_params=["alpha=5"], + ) + cli.cmd_run(args) + assert called["all"] == 1 + + def test_cmd_run_missing_config_exits(self, tmp_path): + args = argparse.Namespace( + config=str(tmp_path / "missing.toml"), + stage=None, + of_params=None, + ) + with pytest.raises(SystemExit, match="1"): + cli.cmd_run(args) + + def test_main_without_subcommand_exits(self, monkeypatch): + monkeypatch.setattr("sys.argv", ["pyflowreg-z-align"]) + with pytest.raises(SystemExit, match="1"): + cli.main() diff --git 
a/tests/z_align/test_config.py b/tests/z_align/test_config.py new file mode 100644 index 0000000..29a1227 --- /dev/null +++ b/tests/z_align/test_config.py @@ -0,0 +1,246 @@ +""" +Tests for z-align configuration. +""" + +from pathlib import Path + +import pytest +from pydantic import ValidationError + +from pyflowreg.motion_correction.OF_options import OFOptions +from pyflowreg.z_align.config import ZAlignConfig + + +class TestZAlignConfigBasics: + """Test basic config creation and validation.""" + + def test_minimal_valid_config(self, tmp_path): + cfg = ZAlignConfig(root=tmp_path, input_file="compensated.tiff") + assert cfg.root == tmp_path + assert cfg.input_file == Path("compensated.tiff") + assert cfg.resume is True + assert cfg.prealign_stack is True + assert cfg.prealign_recording is False + assert cfg.stack_scans_per_slice is None + assert cfg.write_corrected is True + assert cfg.write_simulated is True + + def test_root_must_exist(self, tmp_path): + with pytest.raises(ValidationError, match="Root directory does not exist"): + ZAlignConfig( + root=tmp_path / "missing_root", + input_file="compensated.tiff", + ) + + def test_overlap_out_of_range_raises(self, tmp_path): + with pytest.raises(ValidationError, match="overlap must satisfy"): + ZAlignConfig(root=tmp_path, input_file="compensated.tiff", overlap=1.0) + + def test_invalid_output_dtype_raises(self, tmp_path): + with pytest.raises(ValidationError, match="Invalid output_dtype"): + ZAlignConfig( + root=tmp_path, + input_file="compensated.tiff", + output_dtype="not_a_dtype", + ) + + def test_stack_scans_per_slice_must_be_positive(self, tmp_path): + with pytest.raises(ValidationError, match="Value must be >= 1"): + ZAlignConfig( + root=tmp_path, + input_file="compensated.tiff", + stack_scans_per_slice=0, + ) + + +class TestZAlignConfigPathResolution: + """Test path resolution behavior.""" + + def test_resolve_relative_paths(self, tmp_path): + cfg = ZAlignConfig( + root=tmp_path, + 
input_file="compensated.tiff", + output_root="z_out", + volume_output_dir="aligned_stack", + recording_prealigned_output_dir="prealigned_recording", + z_shift_file="z_shift.HDF5", + corrected_output_file="compensated_shift_corrected.tif", + simulated_output_file="simulated_from_z.tif", + ) + + assert cfg.resolve_output_root() == tmp_path / "z_out" + assert cfg.resolve_input_file() == tmp_path / "compensated.tiff" + assert cfg.resolve_volume_output_dir() == tmp_path / "z_out" / "aligned_stack" + assert cfg.resolve_recording_prealigned_output_dir() == ( + tmp_path / "z_out" / "prealigned_recording" + ) + assert cfg.resolve_recording_prealigned_file() == ( + tmp_path / "z_out" / "prealigned_recording" / "compensated.HDF5" + ) + assert cfg.resolve_z_shift_file() == tmp_path / "z_out" / "z_shift.HDF5" + assert cfg.resolve_corrected_output_file() == ( + tmp_path / "z_out" / "compensated_shift_corrected.tif" + ) + assert cfg.resolve_simulated_output_file() == ( + tmp_path / "z_out" / "simulated_from_z.tif" + ) + + def test_resolve_reference_volume_prefers_existing_default(self, tmp_path): + volume_dir = tmp_path / "z_out" / "aligned_stack" + volume_dir.mkdir(parents=True) + existing = volume_dir / "compensated.hdf5" + existing.touch() + + cfg = ZAlignConfig( + root=tmp_path, + input_file="compensated.tiff", + output_root="z_out", + volume_output_dir="aligned_stack", + ) + assert cfg.resolve_reference_volume_path() == existing + + def test_resolve_reference_volume_explicit(self, tmp_path): + explicit = tmp_path / "my_ref.h5" + explicit.touch() + cfg = ZAlignConfig( + root=tmp_path, + input_file="compensated.tiff", + reference_volume="my_ref.h5", + ) + assert cfg.resolve_reference_volume_path() == explicit + + +class TestZAlignFlowOptions: + """Test stage-1 OFOptions override loading.""" + + def test_stage1_flow_options_dict_returns_copy(self, tmp_path): + cfg = ZAlignConfig( + root=tmp_path, + input_file="compensated.tiff", + stage1_flow_options={"alpha": (7.0, 7.0), 
"buffer_size": 123}, + ) + overrides = cfg.get_stage1_overrides() + assert overrides["buffer_size"] == 123 + overrides["buffer_size"] = 999 + # Ensure config stored mapping is unaffected + assert cfg.stage1_flow_options["buffer_size"] == 123 + + def test_stage1_flow_options_protect_workflow_owned_fields(self, tmp_path): + cfg = ZAlignConfig( + root=tmp_path, + input_file="compensated.tiff", + stage1_flow_options={ + "input_file": "other.tif", + "output_path": "other_out", + "output_format": "TIFF", + "output_file_name": "other.tif", + "reference_frames": [1, 2, 3], + "buffer_size": 123, + }, + ) + overrides = cfg.get_stage1_overrides() + assert overrides == {"buffer_size": 123} + + def test_stage1_flow_options_from_json_file(self, tmp_path): + options_path = tmp_path / "of_options.json" + opts = OFOptions( + input_file="input.tif", + output_path=tmp_path / "out_dir", + quality_setting="balanced", + alpha=3.0, + buffer_size=777, + ) + opts.save_options(options_path) + + cfg = ZAlignConfig( + root=tmp_path, + input_file="compensated.tiff", + stage1_flow_options=options_path, + ) + overrides = cfg.get_stage1_overrides() + + assert overrides["quality_setting"] == "balanced" + assert overrides["buffer_size"] == 777 + assert "input_file" not in overrides + assert "output_path" not in overrides + assert "output_format" not in overrides + assert "output_file_name" not in overrides + assert "reference_frames" not in overrides + + def test_stage1_flow_options_file_missing_raises(self, tmp_path): + cfg = ZAlignConfig( + root=tmp_path, + input_file="compensated.tiff", + stage1_flow_options="missing_options.json", + ) + with pytest.raises(ValueError, match="not found"): + cfg.get_stage1_overrides() + + def test_recording_prealign_flow_options_dict_returns_copy(self, tmp_path): + cfg = ZAlignConfig( + root=tmp_path, + input_file="compensated.tiff", + recording_prealign_flow_options={ + "input_file": "other.tif", + "output_path": "other_out", + "output_format": "TIFF", + 
"output_file_name": "other.tif", + "reference_frames": [1, 2, 3], + "alpha": 8.0, + }, + ) + + overrides = cfg.get_recording_prealign_overrides() + assert overrides == {"reference_frames": [1, 2, 3], "alpha": 8.0} + overrides["alpha"] = 2.0 + assert cfg.recording_prealign_flow_options["alpha"] == 8.0 + + def test_effective_volume_bin_size_prefers_stack_scans_per_slice(self, tmp_path): + cfg = ZAlignConfig( + root=tmp_path, + input_file="compensated.tiff", + volume_bin_size=3, + stack_scans_per_slice=9, + ) + assert cfg.effective_volume_bin_size() == 9 + + +class TestZAlignConfigFileLoading: + """Test config file loading helpers.""" + + def test_load_from_toml(self, tmp_path): + root_posix = tmp_path.as_posix() + cfg_file = tmp_path / "z_align.toml" + cfg_file.write_text( + "\n".join( + [ + f'root = "{root_posix}"', + 'input_file = "compensated.tiff"', + 'output_root = "z_out"', + 'recording_prealigned_output_dir = "prealigned"', + "prealign_recording = true", + "stack_scans_per_slice = 9", + "write_corrected = false", + "write_simulated = true", + "patch_size = 64", + ] + ), + encoding="utf-8", + ) + + cfg = ZAlignConfig.from_toml(cfg_file) + assert cfg.root == tmp_path + assert cfg.input_file == Path("compensated.tiff") + assert cfg.output_root == Path("z_out") + assert cfg.recording_prealigned_output_dir == Path("prealigned") + assert cfg.prealign_recording is True + assert cfg.stack_scans_per_slice == 9 + assert cfg.write_corrected is False + assert cfg.write_simulated is True + assert cfg.patch_size == 64 + + def test_from_file_unsupported_suffix_raises(self, tmp_path): + cfg_file = tmp_path / "z_align.json" + cfg_file.write_text("{}", encoding="utf-8") + with pytest.raises(ValueError, match="Unsupported config file format"): + ZAlignConfig.from_file(cfg_file) diff --git a/tests/z_align/test_pipeline.py b/tests/z_align/test_pipeline.py new file mode 100644 index 0000000..fd48dee --- /dev/null +++ b/tests/z_align/test_pipeline.py @@ -0,0 +1,669 @@ +""" +Tests 
for z-align pipeline stages. +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +from pyflowreg.z_align.config import ZAlignConfig +from pyflowreg.z_align import pipeline + + +class DummyBatchReader: + """Simple batch reader used for stage tests.""" + + def __init__(self, batches): + self._batches = [np.asarray(b) for b in batches] + self._idx = 0 + self.closed = False + + def has_batch(self): + return self._idx < len(self._batches) + + def read_batch(self): + out = self._batches[self._idx] + self._idx += 1 + return out + + def close(self): + self.closed = True + + +class RecordingWriter: + """Writer that records frame writes and touches destination on close.""" + + def __init__(self, path: str): + self.path = Path(path) + self.writes = [] + self.closed = False + + def write_frames(self, frames): + self.writes.append(np.asarray(frames)) + + def close(self): + self.path.parent.mkdir(parents=True, exist_ok=True) + self.path.touch(exist_ok=True) + self.closed = True + + +class IndexedReader: + """Reader supporting whole-file indexing for reference-source and volume tests.""" + + def __init__(self, frames): + self.frames = np.asarray(frames) + self.closed = False + + def __len__(self): + return self.frames.shape[0] + + def __getitem__(self, key): + return self.frames[key] + + def close(self): + self.closed = True + + +def _make_config(tmp_path: Path, **kwargs) -> ZAlignConfig: + params = { + "root": tmp_path, + "input_file": "compensated.tiff", + "output_root": "z_out", + "volume_output_dir": "aligned_stack", + "z_shift_file": "z_shift.HDF5", + "corrected_output_file": "compensated_shift_corrected.tif", + "simulated_output_file": "simulated_from_z.tif", + } + params.update(kwargs) + return ZAlignConfig(**params) + + +class TestStatusHelpers: + """Test status.json helper functions.""" + + def test_save_and_load_status_roundtrip(self, tmp_path): + out = tmp_path / "z_out" + out.mkdir() + status = {"stage1": 
"done", "anchor_z": 3} + pipeline.save_status(out, status) + loaded = pipeline.load_or_create_status(out) + assert loaded == status + + def test_load_status_missing_returns_empty_dict(self, tmp_path): + out = tmp_path / "no_status" + out.mkdir() + assert pipeline.load_or_create_status(out) == {} + + +class TestPipelineUtilities: + """Test small pipeline helpers.""" + + def test_clip_and_cast_rounds_integer_outputs(self): + frames = np.array([0.4, 0.6, 1.6, 300.0], dtype=np.float32) + out = pipeline._clip_and_cast(frames, "uint8") + np.testing.assert_array_equal(out, np.array([0, 1, 2, 255], dtype=np.uint8)) + + def test_load_volume_uses_stack_scans_per_slice_as_bin_size( + self, tmp_path, monkeypatch + ): + volume_file = tmp_path / "volume.h5" + volume_file.touch() + cfg = _make_config( + tmp_path, + reference_volume="volume.h5", + volume_bin_size=2, + stack_scans_per_slice=9, + ) + volume_thwc = np.zeros((3, 4, 5, 1), dtype=np.float32) + captured = {} + + def fake_get_video_file_reader(path, *args, **kwargs): + captured["path"] = Path(path) + captured["buffer_size"] = kwargs["buffer_size"] + captured["bin_size"] = kwargs["bin_size"] + return IndexedReader(volume_thwc) + + monkeypatch.setattr( + pipeline, "get_video_file_reader", fake_get_video_file_reader + ) + + out = pipeline._load_volume(cfg, volume_file) + + assert captured["path"] == volume_file + assert captured["buffer_size"] == cfg.volume_buffer_size + assert captured["bin_size"] == 9 + assert out.shape == (4, 5, 1, 3) + + +class TestStage1: + """Test stage 1 behavior.""" + + def test_run_stage1_uses_existing_reference_volume(self, tmp_path): + ref_path = tmp_path / "reference_volume.h5" + ref_path.touch() + cfg = _make_config(tmp_path, reference_volume="reference_volume.h5") + + result = pipeline.run_stage1(cfg) + assert result == ref_path + + status = pipeline.load_or_create_status(cfg.resolve_output_root()) + assert status["stage1"] == "done" + assert Path(status["volume_path"]) == ref_path + + def 
class TestStage1:
    """Test stage 1 behavior (reference-volume construction)."""

    def test_run_stage1_uses_existing_reference_volume(self, tmp_path):
        # A pre-existing reference volume short-circuits the compensation run.
        ref_path = tmp_path / "reference_volume.h5"
        ref_path.touch()
        cfg = _make_config(tmp_path, reference_volume="reference_volume.h5")

        result = pipeline.run_stage1(cfg)
        assert result == ref_path

        status = pipeline.load_or_create_status(cfg.resolve_output_root())
        assert status["stage1"] == "done"
        assert Path(status["volume_path"]) == ref_path

    def test_run_stage1_runs_compensate_recording(self, tmp_path, monkeypatch):
        source = tmp_path / "file_00004_00001.tif"
        source.touch()
        cfg = _make_config(tmp_path, volume_input_file="file_00004_00001.tif")

        captured = {}

        def fake_compensate_recording(options):
            captured["alpha"] = options.alpha
            captured["quality_setting"] = options.quality_setting.value
            out = Path(options.output_path)
            out.mkdir(parents=True, exist_ok=True)
            (out / "compensated.HDF5").touch()
            return None

        monkeypatch.setattr(
            pipeline,
            "compensate_recording",
            fake_compensate_recording,
        )

        result = pipeline.run_stage1(cfg)

        assert result.exists()
        assert result.name.lower().endswith("hdf5")
        # stage1_alpha=5.0 is expanded to a per-dimension (x, y) tuple.
        assert captured["alpha"] == (5.0, 5.0)
        assert captured["quality_setting"] == "quality"

    def test_run_stage1_uses_stack_scans_per_slice_for_batch_and_update_reference(
        self, tmp_path, monkeypatch
    ):
        source = tmp_path / "file_00004_00001.tif"
        source.touch()
        cfg = _make_config(
            tmp_path,
            volume_input_file="file_00004_00001.tif",
            stage1_buffer_size=500,
            stage1_update_reference=False,
            stack_scans_per_slice=9,
        )
        captured = {}

        def fake_compensate_recording(options):
            captured["buffer_size"] = options.buffer_size
            captured["update_reference"] = options.update_reference
            out = Path(options.output_path)
            out.mkdir(parents=True, exist_ok=True)
            (out / "compensated.HDF5").touch()
            return None

        monkeypatch.setattr(pipeline, "compensate_recording", fake_compensate_recording)

        result = pipeline.run_stage1(
            cfg,
            of_options_override={
                "buffer_size": 123,
                "update_reference": False,
                "output_format": "TIFF",
            },
        )

        assert result == cfg.resolve_volume_output_dir() / "compensated.HDF5"
        # stack_scans_per_slice overrides both the configured buffer size and
        # any caller-supplied override, and forces update_reference=True.
        assert captured["buffer_size"] == 9
        assert captured["update_reference"] is True

    def test_run_stage1_can_skip_stack_prealignment(self, tmp_path, monkeypatch):
        source = tmp_path / "file_00004_00001.tif"
        source.touch()
        cfg = _make_config(
            tmp_path,
            volume_input_file="file_00004_00001.tif",
            prealign_stack=False,
        )

        # With prealignment disabled, compensation must never run.
        monkeypatch.setattr(
            pipeline,
            "compensate_recording",
            lambda *_args, **_kwargs: pytest.fail("unexpected compensate_recording"),
        )

        result = pipeline.run_stage1(cfg)

        assert result == source
        status = pipeline.load_or_create_status(cfg.resolve_output_root())
        assert status["stage1"] == "done"
        assert status["prealign_stack"] is False
        assert Path(status["volume_path"]) == source
class TestRecordingPrealignment:
    """Test optional 2D prealignment of the recording used for z estimation."""

    def test_run_recording_prealignment_runs_compensate_recording(
        self, tmp_path, monkeypatch
    ):
        input_file = tmp_path / "compensated.tiff"
        input_file.touch()
        # Flow-option overrides deliberately include forbidden keys
        # (input/output routing) which the pipeline must ignore.
        cfg = _make_config(
            tmp_path,
            prealign_recording=True,
            input_buffer_size=7,
            input_bin_size=2,
            recording_prealign_flow_options={
                "buffer_size": 11,
                "quality_setting": "balanced",
                "input_file": "wrong.tif",
                "output_path": "wrong_out",
                "output_format": "TIFF",
                "output_file_name": "wrong.tif",
            },
        )
        reference = np.ones((4, 5, 1), dtype=np.float32)
        captured = {}

        def fake_compute_reference(config, source_path=None):
            captured["reference_source"] = source_path
            return reference

        def fake_compensate_recording(options):
            captured["input_file"] = Path(options.input_file)
            captured["output_path"] = Path(options.output_path)
            captured["output_format"] = options.output_format.value
            captured["output_file_name"] = options.output_file_name
            captured["quality_setting"] = options.quality_setting.value
            captured["buffer_size"] = options.buffer_size
            captured["bin_size"] = options.bin_size
            captured["update_reference"] = options.update_reference
            captured["reference_frames"] = options.reference_frames
            out = Path(options.output_path)
            out.mkdir(parents=True, exist_ok=True)
            (out / "compensated.HDF5").touch()
            return None

        monkeypatch.setattr(
            pipeline, "_compute_reference_from_source", fake_compute_reference
        )
        monkeypatch.setattr(pipeline, "compensate_recording", fake_compensate_recording)

        result = pipeline.run_recording_prealignment(cfg)

        assert result == cfg.resolve_recording_prealigned_file()
        assert captured["reference_source"] == input_file
        # Routing options come from the config, not the user overrides.
        assert captured["input_file"] == input_file
        assert captured["output_path"] == cfg.resolve_recording_prealigned_output_dir()
        assert captured["output_format"] == "HDF5"
        assert captured["output_file_name"] is None
        # Non-routing overrides (quality, buffer size) ARE honored.
        assert captured["quality_setting"] == "balanced"
        assert captured["buffer_size"] == 11
        assert captured["bin_size"] == 2
        assert captured["update_reference"] is False
        np.testing.assert_array_equal(captured["reference_frames"], reference)

        status = pipeline.load_or_create_status(cfg.resolve_output_root())
        assert status["recording_prealign"] == "done"
        assert Path(status["prealigned_recording_path"]) == result


class TestStage2:
    """Test stage 2 behavior (z-shift estimation and optional correction)."""

    def test_run_stage2_resume_skip(self, tmp_path, monkeypatch):
        cfg = _make_config(tmp_path, write_corrected=False, reference_volume="vol.h5")
        volume_path = tmp_path / "vol.h5"
        volume_path.touch()

        z_shift_path = cfg.resolve_z_shift_file()
        z_shift_path.parent.mkdir(parents=True, exist_ok=True)
        z_shift_path.touch()

        pipeline.save_status(
            cfg.resolve_output_root(),
            {"stage2": "done", "anchor_z": 2, "anchor_z_1based": 3},
        )

        # If resume path is taken, _load_volume should never be called.
        monkeypatch.setattr(
            pipeline,
            "_load_volume",
            lambda *_args, **_kwargs: pytest.fail("Should not load volume on resume"),
        )

        out = pipeline.run_stage2(cfg)
        assert out["z_shift_path"] == z_shift_path
        assert out["corrected_path"] is None
        assert out["anchor_z"] == 2
        assert out["prealigned_recording_path"] is None

    def test_run_stage2_writes_1based_zshift_and_corrected(self, tmp_path, monkeypatch):
        input_file = tmp_path / "compensated.tiff"
        input_file.touch()
        volume_file = tmp_path / "volume.h5"
        volume_file.touch()

        cfg = _make_config(
            tmp_path,
            reference_volume="volume.h5",
            write_corrected=True,
            patch_size=2,
            overlap=0.5,
            win_half=1,
        )

        H, W, C, Z = 4, 4, 1, 3
        T = 2
        batch_thwc = np.full((T, H, W, C), 100.0, dtype=np.float32)
        input_reader = DummyBatchReader([batch_thwc])

        writers = {}

        def fake_get_video_file_reader(path, *args, **kwargs):
            if Path(path) == cfg.resolve_input_file():
                return input_reader
            raise AssertionError(f"Unexpected reader path: {path}")

        def fake_get_video_file_writer(path, output_format, **kwargs):
            writer = RecordingWriter(path)
            writers[str(Path(path))] = writer
            return writer

        monkeypatch.setattr(
            pipeline, "get_video_file_reader", fake_get_video_file_reader
        )
        monkeypatch.setattr(
            pipeline, "get_video_file_writer", fake_get_video_file_writer
        )
        monkeypatch.setattr(
            pipeline,
            "_load_volume",
            lambda _cfg, _path: np.arange(H * W * C * Z, dtype=np.float32).reshape(
                H, W, C, Z
            ),
        )
        # Zero gradients keep the math trivial; only plumbing is under test.
        monkeypatch.setattr(
            pipeline,
            "_compute_volume_gradients",
            lambda volume, sigma: (
                np.zeros_like(volume, dtype=np.float32),
                np.zeros_like(volume, dtype=np.float32),
            ),
        )
        monkeypatch.setattr(
            pipeline,
            "_compute_batch_gradients",
            lambda batch, spatial_sigma, temporal_sigma: (
                np.zeros_like(batch, dtype=np.float32),
                np.zeros_like(batch, dtype=np.float32),
            ),
        )
        monkeypatch.setattr(
            pipeline,
            "_estimate_anchor_z",
            lambda gx_vol, gy_vol, gx_f, gy_f: (1, np.zeros((Z,), dtype=np.float64)),
        )
        monkeypatch.setattr(
            pipeline,
            "_estimate_z_patchwise",
            lambda *args, **kwargs: np.zeros((H, W, T), dtype=np.float64),
        )
        monkeypatch.setattr(
            pipeline,
            "_apply_z_correction",
            lambda batch_hwct, z_hat_hwt, diff_hwcz: batch_hwct + 5.0,
        )

        out = pipeline.run_stage2(cfg, volume_path=volume_file)

        z_writer = writers[str(cfg.resolve_z_shift_file())]
        corrected_writer = writers[str(cfg.resolve_corrected_output_file())]

        assert out["anchor_z"] == 1
        assert out["prealigned_recording_path"] is None
        assert len(z_writer.writes) == 1
        assert len(corrected_writer.writes) == 1

        z_written = z_writer.writes[0]
        corrected_written = corrected_writer.writes[0]

        # z_shift is stored as 1-based IDs for MATLAB compatibility.
        assert np.all(z_written == 1.0)
        assert corrected_written.dtype == np.uint16
        assert np.all(corrected_written == 105)

        status = pipeline.load_or_create_status(cfg.resolve_output_root())
        assert status["stage2"] == "done"
        assert status["anchor_z"] == 1
        assert status["anchor_z_1based"] == 2

    def test_run_stage2_reads_prealigned_recording_when_enabled(
        self, tmp_path, monkeypatch
    ):
        input_file = tmp_path / "compensated.tiff"
        input_file.touch()
        prealigned_file = (
            tmp_path / "z_out" / "prealigned_recording" / "compensated.HDF5"
        )
        prealigned_file.parent.mkdir(parents=True)
        prealigned_file.touch()
        volume_file = tmp_path / "volume.h5"
        volume_file.touch()

        cfg = _make_config(
            tmp_path,
            reference_volume="volume.h5",
            prealign_recording=True,
            write_corrected=False,
            patch_size=2,
            overlap=0.5,
            win_half=1,
        )

        H, W, C, Z = 4, 4, 1, 3
        T = 2
        batch_thwc = np.full((T, H, W, C), 100.0, dtype=np.float32)
        input_reader = DummyBatchReader([batch_thwc])
        writers = {}
        seen = {}

        def fake_get_video_file_reader(path, *args, **kwargs):
            seen["input_reader_path"] = Path(path)
            # Stage 2 must read the prealigned recording, not the raw input.
            if Path(path) == prealigned_file:
                return input_reader
            raise AssertionError(f"Unexpected reader path: {path}")

        def fake_get_video_file_writer(path, output_format, **kwargs):
            writer = RecordingWriter(path)
            writers[str(Path(path))] = writer
            return writer

        monkeypatch.setattr(
            pipeline,
            "run_recording_prealignment",
            lambda _config: prealigned_file,
        )
        monkeypatch.setattr(
            pipeline, "get_video_file_reader", fake_get_video_file_reader
        )
        monkeypatch.setattr(
            pipeline, "get_video_file_writer", fake_get_video_file_writer
        )
        monkeypatch.setattr(
            pipeline,
            "_load_volume",
            lambda _cfg, _path: np.arange(H * W * C * Z, dtype=np.float32).reshape(
                H, W, C, Z
            ),
        )
        monkeypatch.setattr(
            pipeline,
            "_compute_volume_gradients",
            lambda volume, sigma: (
                np.zeros_like(volume, dtype=np.float32),
                np.zeros_like(volume, dtype=np.float32),
            ),
        )
        monkeypatch.setattr(
            pipeline,
            "_compute_batch_gradients",
            lambda batch, spatial_sigma, temporal_sigma: (
                np.zeros_like(batch, dtype=np.float32),
                np.zeros_like(batch, dtype=np.float32),
            ),
        )
        monkeypatch.setattr(
            pipeline,
            "_estimate_anchor_z",
            lambda gx_vol, gy_vol, gx_f, gy_f: (1, np.zeros((Z,), dtype=np.float64)),
        )
        monkeypatch.setattr(
            pipeline,
            "_estimate_z_patchwise",
            lambda *args, **kwargs: np.zeros((H, W, T), dtype=np.float64),
        )

        out = pipeline.run_stage2(cfg, volume_path=volume_file)

        assert seen["input_reader_path"] == prealigned_file
        assert out["prealigned_recording_path"] == prealigned_file
        assert out["corrected_path"] is None
        assert len(writers[str(cfg.resolve_z_shift_file())].writes) == 1
class TestStage3:
    """Test stage 3 behavior (z-shift-only simulation video)."""

    def test_run_stage3_subtracts_one_before_simulation(self, tmp_path, monkeypatch):
        volume_file = tmp_path / "vol.h5"
        volume_file.touch()
        cfg = _make_config(tmp_path, reference_volume="vol.h5")

        z_shift_path = cfg.resolve_z_shift_file()
        z_shift_path.parent.mkdir(parents=True, exist_ok=True)
        z_shift_path.touch()

        H, W, C, Z = 3, 3, 1, 4
        z_batch_thwc = np.full((2, H, W, 1), 2.0, dtype=np.float32)  # 1-based z=2
        z_reader = DummyBatchReader([z_batch_thwc])
        writers = {}
        seen = {}

        def fake_get_video_file_reader(path, *args, **kwargs):
            if Path(path) == z_shift_path:
                return z_reader
            raise AssertionError(f"Unexpected stage3 reader path: {path}")

        def fake_get_video_file_writer(path, output_format, **kwargs):
            writer = RecordingWriter(path)
            writers[str(Path(path))] = writer
            return writer

        def fake_simulate_from_z(volume_hwcz, z_hat_hwt):
            # Capture the (converted) z map so the 1->0-based shift is checkable.
            seen["z_hat_hwt"] = z_hat_hwt.copy()
            T = z_hat_hwt.shape[2]
            return np.full((H, W, C, T), 7.0, dtype=np.float32)

        monkeypatch.setattr(
            pipeline, "get_video_file_reader", fake_get_video_file_reader
        )
        monkeypatch.setattr(
            pipeline, "get_video_file_writer", fake_get_video_file_writer
        )
        monkeypatch.setattr(
            pipeline,
            "_load_volume",
            lambda _cfg, _path: np.zeros((H, W, C, Z), dtype=np.float32),
        )
        monkeypatch.setattr(pipeline, "_simulate_from_z", fake_simulate_from_z)

        out_path = pipeline.run_stage3(
            cfg, volume_path=volume_file, z_shift_path=z_shift_path
        )

        assert out_path == cfg.resolve_simulated_output_file()
        assert np.all(seen["z_hat_hwt"] == 1.0)  # 2 (1-based) -> 1 (0-based)

        writer = writers[str(cfg.resolve_simulated_output_file())]
        assert len(writer.writes) == 1
        assert writer.writes[0].dtype == np.uint16
        assert np.all(writer.writes[0] == 7)

        status = pipeline.load_or_create_status(cfg.resolve_output_root())
        assert status["stage3"] == "done"

    def test_run_stage3_returns_none_when_disabled(self, tmp_path):
        cfg = _make_config(tmp_path, write_simulated=False)
        assert pipeline.run_stage3(cfg) is None


class TestRunAllStages:
    """Test all-stage orchestrator behavior."""

    def test_run_all_stages_skips_stage3_when_disabled(self, tmp_path, monkeypatch):
        cfg = _make_config(tmp_path, write_simulated=False)
        called = {"stage1": 0, "stage2": 0, "stage3": 0}

        monkeypatch.setattr(
            pipeline,
            "run_stage1",
            lambda config, of_options_override=None: called.__setitem__(
                "stage1", called["stage1"] + 1
            )
            or (tmp_path / "vol.h5"),
        )
        monkeypatch.setattr(
            pipeline,
            "run_stage2",
            lambda config, volume_path=None: called.__setitem__(
                "stage2", called["stage2"] + 1
            )
            or {
                "z_shift_path": tmp_path / "z_shift.HDF5",
                "corrected_path": tmp_path / "corr.tif",
                "anchor_z": 0,
            },
        )
        monkeypatch.setattr(
            pipeline,
            "run_stage3",
            lambda *args, **kwargs: called.__setitem__("stage3", called["stage3"] + 1),
        )

        out = pipeline.run_all_stages(cfg)

        assert called["stage1"] == 1
        assert called["stage2"] == 1
        assert called["stage3"] == 0
        assert out["simulated_path"] is None

    def test_run_all_stages_returns_prealigned_recording_path(
        self, tmp_path, monkeypatch
    ):
        cfg = _make_config(tmp_path, write_simulated=False)
        prealigned_file = (
            tmp_path / "z_out" / "prealigned_recording" / "compensated.HDF5"
        )

        monkeypatch.setattr(
            pipeline,
            "run_stage1",
            lambda config, of_options_override=None: tmp_path / "vol.h5",
        )
        monkeypatch.setattr(
            pipeline,
            "run_stage2",
            lambda config, volume_path=None: {
                "z_shift_path": tmp_path / "z_shift.HDF5",
                "corrected_path": tmp_path / "corr.tif",
                "anchor_z": 0,
                "prealigned_recording_path": prealigned_file,
            },
        )

        out = pipeline.run_all_stages(cfg)

        assert out["prealigned_recording_path"] == prealigned_file
# --- tests/z_align/test_pipeline_parallel.py (new file) ---
"""
Parallelization equivalence tests for z-align patch scoring.
"""

import numpy as np

from pyflowreg.z_align.pipeline import _estimate_z_patchwise


def test_estimate_z_patchwise_threading_matches_sequential():
    """Threaded patch scoring should match sequential output up to FP noise."""
    rng = np.random.default_rng(1234)
    H, W, C, Z, T = 32, 32, 1, 9, 7

    # Random but reproducible gradient fields for the volume and the frames.
    gx_vol = rng.standard_normal((H, W, C, Z), dtype=np.float32)
    gy_vol = rng.standard_normal((H, W, C, Z), dtype=np.float32)
    gx_f = rng.standard_normal((H, W, C, T), dtype=np.float32)
    gy_f = rng.standard_normal((H, W, C, T), dtype=np.float32)

    shared_kwargs = {
        "anchor_z": 4,
        "win_half": 3,
        "patch_size": 8,
        "overlap": 0.5,
        "tau_scale": 1e-3,
        "z_smooth_sigma_spatial": 1.5,
        "z_smooth_sigma_temporal": 1.0,
    }

    z_seq = _estimate_z_patchwise(
        gx_vol, gy_vol, gx_f, gy_f,
        parallelization="sequential", n_jobs=1, **shared_kwargs,
    )
    z_thr = _estimate_z_patchwise(
        gx_vol, gy_vol, gx_f, gy_f,
        parallelization="threading", n_jobs=4, **shared_kwargs,
    )

    # Floating-point tolerance only; results must otherwise agree.
    assert np.allclose(z_seq, z_thr, atol=1e-5, rtol=1e-5)