diff --git a/.gitignore b/.gitignore index a5cb8ca..ec27911 100644 --- a/.gitignore +++ b/.gitignore @@ -297,6 +297,7 @@ dask-worker-space/ # temporary files *.*~ +build/** # scratch space examples/data/** @@ -308,4 +309,6 @@ expt_ctrl/temp*.cfg *.hdf5 # documentation -docs/_* \ No newline at end of file +docs/_* + +.codex/** \ No newline at end of file diff --git a/examples/build_dpc_mie_stack.py b/examples/build_dpc_mie_stack.py new file mode 100644 index 0000000..a0f4ed7 --- /dev/null +++ b/examples/build_dpc_mie_stack.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +""" +Generate the default DPC Mie simulation stack used by the example scripts. +""" + +from __future__ import annotations + +from pathlib import Path +from dataclasses import dataclass + +import numpy as np + +try: + import cupy as cp # type: ignore +except Exception: + cp = None + +from simulate_dpc_microsphere import SphereSpec, simulate_dpc_images_sphere_to_zarr + + +@dataclass(frozen=True) +class SimParams: + wavelength_um: float = 0.515 + na_obj: float = 0.8 + led_grid_shape: tuple[int, int] = (64, 64) + pitch_mm: float = 2.5 + camera_pixel_um: float = 2.4 + magnification: float = 20.0 + camera_oversample: int = 1 + esize_camera: tuple[int, int] = (256, 256) + z_plane_um: float = 0.0 + sphere: SphereSpec = SphereSpec(radius_um=1., n_sphere=1.59, n_medium=1.55) + inner_na: float = 0.0 + include_center_led: bool = False + pattern_order: tuple[str, str, str, str] = ("left", "right", "up", "down") + normalize_by_led_count: bool = True + led_subsample: int = 1 + mie_kwargs: dict | None = None + camera_gains: float = 2.0 + camera_offsets: float = 100.0 + camera_readout_noise_sds: float = 3.0 + camera_photon_shot_noise: bool = True + camera_saturation: float | None = None + camera_image_is_integer: bool = True + exposure_time_ms: float = 10.0 + illumination_photons_per_s_per_um2: float = 1e7 + focal_stack_planes: int = 31 + focal_stack_step_um: float = 0.325 + + @property + def z_planes_um(self) -> 
np.ndarray: + return ( + np.arange(self.focal_stack_planes) - self.focal_stack_planes // 2 + ) * self.focal_stack_step_um + + +def ensure_dpc_mie_stack(zarr_path: Path, *, use_gpu: bool) -> None: + """ + Ensure the default DPC Mie simulation stack exists at the provided path. + """ + if zarr_path.exists(): + return + params = SimParams() + zarr_path.parent.mkdir(parents=True, exist_ok=True) + simulate_dpc_images_sphere_to_zarr( + zarr_path, + wavelength_um=params.wavelength_um, + na_obj=params.na_obj, + led_grid_shape=params.led_grid_shape, + pitch_mm=params.pitch_mm, + camera_pixel_um=params.camera_pixel_um, + magnification=params.magnification, + camera_oversample=params.camera_oversample, + esize_camera=params.esize_camera, + z_plane_um=params.z_plane_um, + sphere=params.sphere, + inner_na=params.inner_na, + include_center_led=params.include_center_led, + pattern_order=params.pattern_order, + normalize_by_led_count=params.normalize_by_led_count, + led_subsample=params.led_subsample, + use_gpu=use_gpu, + mie_kwargs=params.mie_kwargs, + camera_gains=params.camera_gains, + camera_offsets=params.camera_offsets, + camera_readout_noise_sds=params.camera_readout_noise_sds, + camera_photon_shot_noise=params.camera_photon_shot_noise, + camera_saturation=params.camera_saturation, + camera_image_is_integer=params.camera_image_is_integer, + exposure_time_ms=params.exposure_time_ms, + illumination_photons_per_s_per_um2=params.illumination_photons_per_s_per_um2, + focal_stack_planes=params.focal_stack_planes, + focal_stack_step_um=params.focal_stack_step_um, + overwrite=True, + reuse_cache=True, + ) + + +def main() -> None: + default_path = Path("build/dpc_mie_stack.zarr") + use_gpu = False + if cp is not None: + try: + use_gpu = bool(cp.is_available()) + except Exception: + use_gpu = False + ensure_dpc_mie_stack(default_path, use_gpu=use_gpu) + print(f"Simulation written to {default_path}") + + +if __name__ == "__main__": + main() diff --git 
a/examples/compare_wotf_predictions.py b/examples/compare_wotf_predictions.py new file mode 100644 index 0000000..0292447 --- /dev/null +++ b/examples/compare_wotf_predictions.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python3 +""" +Compare WOTF-predicted images for FISTA vs Tikhonov reconstructions. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import numpy as np +import zarr # type: ignore + +try: + import cupy as cp # type: ignore +except Exception: + cp = None +try: + import matplotlib.pyplot as plt # type: ignore +except Exception: + plt = None + +from mcsim.analysis.dpc_meta import DPCMeta +from mcsim.analysis.fft import ft3, ift3 +from mcsim.analysis.field_prop import get_n +from mcsim.analysis.optimize import to_cpu +from mcsim.analysis.wotf_fista import ( + WOTFParams, + WOTFFISTAOptimizer, + build_led_na_grid, + build_led_pattern_membership, + build_wotf_transfer, + camera_to_photons, +) +from build_dpc_mie_stack import ensure_dpc_mie_stack + + +def _load_simulated_stack(zarr_path: Path) -> tuple[np.ndarray, dict, np.ndarray]: + g = zarr.open_group(str(zarr_path), mode="r") + if "dpc" not in g: + raise ValueError("Missing 'dpc' dataset in simulation output.") + dpc = np.asarray(g["dpc"], dtype=np.float32) + + meta_group = g.get("meta") + if meta_group is None or "focal_offsets_um" not in meta_group: + raise ValueError("Missing 'meta/focal_offsets_um' in simulation output.") + z_planes_um = np.asarray(meta_group["focal_offsets_um"], dtype=float) + + if dpc.ndim == 3: + if dpc.shape[0] != 4: + raise ValueError(f"Expected dpc shape (4, ny, nx), got {dpc.shape}") + dpc = dpc[None, ...] 
+ elif dpc.ndim == 4: + if dpc.shape[1] != 4: + raise ValueError(f"Expected dpc shape (nz, 4, ny, nx), got {dpc.shape}") + else: + raise ValueError(f"Unexpected dpc ndim: {dpc.ndim}") + + I_cam = np.moveaxis(dpc, 1, 0) + return I_cam, dict(g.attrs), z_planes_um + + +def _build_meta(I_cam: np.ndarray, attrs: dict, z_planes_um: np.ndarray) -> DPCMeta: + nz = int(I_cam.shape[1]) + ny = int(I_cam.shape[2]) + nx = int(I_cam.shape[3]) + if z_planes_um.size != nz: + raise ValueError("z_planes_um length must match stack depth.") + + wavelength_um = float(attrs["wavelength_um"]) + na_obj = float(attrs["na_obj"]) + camera_pixel_um = float(attrs["camera_pixel_um"]) + magnification = float(attrs["magnification"]) + n_medium = float(attrs.get("n_medium", 1.0)) + led_grid_shape = tuple(int(v) for v in attrs.get("led_grid_shape", (64, 64))) + pattern_order = tuple(attrs.get("pattern_order", ("left", "right", "up", "down"))) + inner_na = float(attrs.get("inner_na", 0.0)) + include_center_led = bool(attrs.get("include_center_led", False)) + + dz = 0.0 + if z_planes_um.size > 1: + dzs = np.diff(z_planes_um) + if not np.allclose(dzs, dzs[0]): + raise ValueError("z_planes_um must be evenly spaced.") + dz = float(dzs[0]) + + dxy_um = camera_pixel_um / magnification + + return DPCMeta( + wavelength_um=wavelength_um, + n_background=n_medium, + NA_obj=na_obj, + magnification=magnification, + camera_pixel_pitch_um=camera_pixel_um, + volume_shape_zyx=(nz, ny, nx), + voxel_size_um_zyx=(dz, dxy_um, dxy_um), + z_planes_um=z_planes_um, + led_grid_shape=led_grid_shape, + inner_na=inner_na, + include_center_led=include_center_led, + pattern_order=pattern_order, # type: ignore[arg-type] + ) + + +def _report_z_sampling(meta: DPCMeta) -> tuple[float, float]: + dzs = np.diff(np.asarray(meta.z_planes_um, dtype=float)) + dz_mean = float(np.mean(dzs)) if dzs.size else 0.0 + dz_std = float(np.std(dzs)) if dzs.size else 0.0 + dz_meta = float(meta.voxel_size_um_zyx[0]) + print( + "Z sampling:", + 
f"dz_meta={dz_meta:.6g} um", + f"dz_mean={dz_mean:.6g} um", + f"dz_std={dz_std:.3g} um", + f"nz={meta.volume_shape_zyx[0]}", + ) + if dzs.size and not np.isclose(dz_mean, dz_meta, rtol=1.0e-6, atol=1.0e-9): + print("Warning: dz in meta does not match z_planes spacing.") + return dz_mean, dz_meta + + +def _axial_transfer_summary(H: np.ndarray) -> np.ndarray: + return np.mean(np.abs(H), axis=(2, 3)) + + +def _save_wotf_slice_plots( + out_dir: Path, + H_real: np.ndarray, + H_imag: np.ndarray, + *, + suffix: str, +) -> None: + if plt is None: + print("matplotlib not available; skipping WOTF plot images.") + return + + n_patterns, nz, ny, nx = H_real.shape + z_mid = nz // 2 + y_mid = ny // 2 + + def _grid_plot(data: np.ndarray, title: str, out_name: str) -> None: + log_data = np.log10(data + 1.0e-12) + fig, axes = plt.subplots(2, 2, figsize=(8, 8), constrained_layout=True) + axes = axes.ravel() + for pid in range(min(4, n_patterns)): + axes[pid].imshow(log_data[pid], origin="lower", aspect="auto") + axes[pid].set_title(f"pattern {pid}") + for ax in axes[n_patterns:]: + ax.axis("off") + fig.suptitle(f"{title} (log10)") + fig.savefig(out_dir / out_name, dpi=150) + plt.close(fig) + + h_real_kxy = np.abs(H_real[:, z_mid]) + h_imag_kxy = np.abs(H_imag[:, z_mid]) + h_real_kzx = np.abs(H_real[:, :, y_mid, :]) + h_imag_kzx = np.abs(H_imag[:, :, y_mid, :]) + + _grid_plot(h_real_kxy, "H_real |kz mid| (kx-ky)", f"wotf_real_kxy_{suffix}.png") + _grid_plot(h_imag_kxy, "H_imag |kz mid| (kx-ky)", f"wotf_imag_kxy_{suffix}.png") + _grid_plot(h_real_kzx, "H_real |ky mid| (kz-kx)", f"wotf_real_kzx_{suffix}.png") + _grid_plot(h_imag_kzx, "H_imag |ky mid| (kz-kx)", f"wotf_imag_kzx_{suffix}.png") + + +def _tikhonov_wotf( + contrast: np.ndarray, + H_real: np.ndarray, + *, + reg: float, +) -> np.ndarray: + contrast_ft = ft3(contrast, axes=(1, 2, 3), shift=True) + num = np.sum(np.conj(H_real) * contrast_ft, axis=0) + den = np.sum(np.abs(H_real) ** 2, axis=0) + float(reg) + v_ft_real = num / den 
+ v_real = ift3(v_ft_real, axes=(0, 1, 2), shift=True).real + return v_real.astype(np.float32, copy=False) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--sim-zarr", type=Path, default=Path("build/dpc_mie_stack.zarr")) + parser.add_argument("--out", type=Path, default=Path("build/wotf_pred_compare.zarr")) + parser.add_argument("--use-gpu", action="store_true") + parser.add_argument("--step", type=float, default=5.0e-5) + parser.add_argument("--iters", type=int, default=50) + parser.add_argument("--tv-weight", type=float, default=0.0) + parser.add_argument("--tv-max-num-iter", type=int, default=50) + parser.add_argument("--tv-eps", type=float, default=2.0e-4) + parser.add_argument("--tv-aniso-z", type=float, default=1.0) + parser.add_argument("--z-taper", type=int, default=8) + parser.add_argument("--xy-taper", type=int, default=None) + parser.add_argument("--pupil-taper-na", type=float, default=0.0) + parser.add_argument("--eps", type=float, default=1.0e-8) + parser.add_argument("--line-search", action="store_true") + parser.add_argument("--line-search-factor", type=float, default=0.5) + parser.add_argument("--restart-line-search", action="store_true") + parser.add_argument("--tikhonov-reg", type=float, default=5.0e-4) + parser.add_argument("--led-subsample", type=int, default=1) + parser.add_argument("--pad-z", type=int, default=64) + parser.add_argument("--pad-yx", type=int, default=64) + parser.add_argument("--adjoint-check", action="store_true") + parser.add_argument("--adjoint-crop-z", type=int, default=0) + parser.add_argument("--adjoint-crop-yx", type=int, default=0) + args = parser.parse_args() + + use_gpu = bool(args.use_gpu and cp is not None) + print(f"Using GPU: {use_gpu}") + + ensure_dpc_mie_stack(args.sim_zarr, use_gpu=use_gpu) + I_cam, attrs, z_planes_um = _load_simulated_stack(args.sim_zarr) + meta = _build_meta(I_cam, attrs, z_planes_um) + _report_z_sampling(meta) + + camera_gain = 
attrs.get("camera_gains", 1.0) + camera_offset = attrs.get("camera_offsets", 0.0) + I_cam_xp = cp.asarray(I_cam) if use_gpu else np.asarray(I_cam) + I_meas_phot = camera_to_photons( + I_cam_xp, + camera_offset_adu=camera_offset, + camera_gain_photons_per_adu=camera_gain, + ) + + I_meas = np.asarray(to_cpu(I_meas_phot), dtype=np.float32) + I0_pred = np.mean(I_meas, axis=(2, 3), keepdims=True).astype(np.float32, copy=False) + + led_na_xy = build_led_na_grid( + meta.led_grid_shape[0], + meta.led_grid_shape[1], + na_obj=meta.NA_obj, + na_in=meta.inner_na, + include_center=meta.include_center_led, + led_subsample=int(args.led_subsample), + ) + membership = build_led_pattern_membership( + led_na_xy, + order=meta.pattern_order, + include_center=meta.include_center_led, + ) + + params = WOTFParams( + wavelength_um=meta.wavelength_um, + na_obj=meta.NA_obj, + na_in=meta.inner_na, + n0=meta.n_background, + dxy_um=meta.dxy_um, + dz_um=float(meta.voxel_size_um_zyx[0]), + nz=meta.volume_shape_zyx[0] + 2 * int(args.pad_z), + ny=meta.volume_shape_zyx[1] + 2 * int(args.pad_yx), + nx=meta.volume_shape_zyx[2] + 2 * int(args.pad_yx), + pupil_taper_na=float(args.pupil_taper_na), + ) + H_real, H_imag = build_wotf_transfer(led_na_xy, membership, params) + expected_shape = ( + meta.volume_shape_zyx[0] + 2 * int(args.pad_z), + meta.volume_shape_zyx[1] + 2 * int(args.pad_yx), + meta.volume_shape_zyx[2] + 2 * int(args.pad_yx), + ) + if H_real.shape[1:] != expected_shape: + print(f"Warning: H_real shape {H_real.shape} != expected {expected_shape}") + + xp = cp if use_gpu else np + n0_vol = xp.full(meta.volume_shape_zyx, float(meta.n_background), dtype=xp.float32) + + contrast = (I_meas / I0_pred - 1.0) / -1.0 + pad_z = int(args.pad_z) + pad_yx = int(args.pad_yx) + if pad_z or pad_yx: + contrast = np.pad( + contrast, + ((0, 0), (pad_z, pad_z), (pad_yx, pad_yx), (pad_yx, pad_yx)), + mode="constant", + ) + v_real = _tikhonov_wotf(contrast, H_real, reg=float(args.tikhonov_reg)) + if pad_z or 
pad_yx: + nz, ny, nx = meta.volume_shape_zyx + v_real = v_real[pad_z : pad_z + nz, pad_yx : pad_yx + ny, pad_yx : pad_yx + nx] + n_tikh = get_n(v_real, meta.n_background, meta.wavelength_um).real.astype(np.float32, copy=False) + + optimizer = WOTFFISTAOptimizer( + xp.asarray(I_meas), + xp.asarray(I0_pred), + xp.asarray(H_real), + xp.asarray(H_imag), + n0=meta.n_background, + wavelength_um=meta.wavelength_um, + eps=float(args.eps), + tv_weight=float(args.tv_weight), + tv_max_num_iter=int(args.tv_max_num_iter), + tv_eps=float(args.tv_eps), + tv_voxel_size_zyx=(params.dz_um, params.dxy_um, params.dxy_um), + tv_weight_scale_zyx=(float(args.tv_aniso_z), 1.0, 1.0), + pad_zyx=(int(args.pad_z), int(args.pad_yx), int(args.pad_yx)), + z_taper=int(args.z_taper), + xy_taper=args.xy_taper, + use_real_constraint=True, + ) + + if args.adjoint_check: + nz, ny, nx = meta.volume_shape_zyx + crop_z = int(args.adjoint_crop_z) + crop_yx = int(args.adjoint_crop_yx) + if crop_z > 0: + crop_z = min(crop_z, nz) + if crop_yx > 0: + crop_yx = min(crop_yx, ny, nx) + z0 = (nz - crop_z) // 2 if crop_z > 0 else 0 + y0 = (ny - crop_yx) // 2 if crop_yx > 0 else 0 + x0 = (nx - crop_yx) // 2 if crop_yx > 0 else 0 + z1 = z0 + (crop_z if crop_z > 0 else nz) + y1 = y0 + (crop_yx if crop_yx > 0 else ny) + x1 = x0 + (crop_yx if crop_yx > 0 else nx) + v = xp.zeros((nz, ny, nx), dtype=xp.float32) + g = xp.zeros((H_real.shape[0], nz, ny, nx), dtype=xp.float32) + rng = np.random.default_rng(0) + v_cpu = rng.standard_normal((z1 - z0, y1 - y0, x1 - x0)).astype(np.float32) + g_cpu = rng.standard_normal((H_real.shape[0], z1 - z0, y1 - y0, x1 - x0)).astype(np.float32) + v[z0:z1, y0:y1, x0:x1] = xp.asarray(v_cpu) + g[:, z0:z1, y0:y1, x0:x1] = xp.asarray(g_cpu) + pad_z = int(args.pad_z) + pad_yx = int(args.pad_yx) + if pad_z or pad_yx: + v_pad = xp.pad(v, ((pad_z, pad_z), (pad_yx, pad_yx), (pad_yx, pad_yx)), mode="constant") + g_pad = xp.pad(g, ((0, 0), (pad_z, pad_z), (pad_yx, pad_yx), (pad_yx, pad_yx)), 
mode="constant") + else: + v_pad = v + g_pad = g + v_ft = ft3(v_pad, axes=(0, 1, 2), shift=True) + pred_ft = xp.asarray(H_real) * v_ft.real + xp.asarray(H_imag) * v_ft.imag + contrast = ift3(pred_ft, axes=(1, 2, 3), shift=True).real + if pad_z or pad_yx: + contrast = contrast[:, pad_z : pad_z + nz, pad_yx : pad_yx + ny, pad_yx : pad_yx + nx] + lhs = float(to_cpu(xp.sum(contrast * g))) + g_ft = ift3(g_pad, axes=(1, 2, 3), shift=True, adjoint=True) + grad_vft_real = xp.sum(xp.conj(xp.asarray(H_real)) * g_ft, axis=0) + grad_vft_imag = xp.sum(xp.conj(xp.asarray(H_imag)) * g_ft, axis=0) + grad_v = ft3( + xp.real(grad_vft_real) + 1j * xp.real(grad_vft_imag), + axes=(0, 1, 2), + shift=True, + adjoint=True, + ) + if pad_z or pad_yx: + grad_v = grad_v[pad_z : pad_z + nz, pad_yx : pad_yx + ny, pad_yx : pad_yx + nx] + rhs = float(to_cpu(xp.sum(v * grad_v.real))) + denom = max(1.0e-12, abs(lhs), abs(rhs)) + rel_err = abs(lhs - rhs) / denom + n_init = (cp.asarray if use_gpu else np.asarray)( + np.full(meta.volume_shape_zyx, meta.n_background, dtype=np.float32) + ) + diag_optimizer = WOTFFISTAOptimizer( + xp.asarray(I_meas), + xp.asarray(I0_pred), + xp.asarray(H_real), + xp.asarray(H_imag), + n0=meta.n_background, + wavelength_um=meta.wavelength_um, + eps=float(args.eps), + tv_weight=0.0, + tv_max_num_iter=0, + tv_eps=float(args.tv_eps), + tv_voxel_size_zyx=(params.dz_um, params.dxy_um, params.dxy_um), + tv_weight_scale_zyx=(1.0, 1.0, 1.0), + pad_zyx=(int(args.pad_z), int(args.pad_yx), int(args.pad_yx)), + z_taper=int(args.z_taper), + xy_taper=args.xy_taper, + use_real_constraint=True, + ) + cost_tikh = float(to_cpu(diag_optimizer.cost(n_init))[0]) + grad_tikh = diag_optimizer.gradient(n_init)[0] + grad_norm = float(to_cpu(xp.sqrt(xp.sum(grad_tikh**2)))) + step_size = float(args.step) + n_after_raw = n_init - step_size * grad_tikh + n_after = diag_optimizer.prox(n_after_raw, step_size) + cost_tikh_step = float(to_cpu(diag_optimizer.cost(n_after))[0]) + cost_tikh_step_raw = 
float(to_cpu(diag_optimizer.cost(n_after_raw))[0]) + dir_deriv = float(to_cpu(xp.sum(grad_tikh * grad_tikh))) + n_min = float(to_cpu(xp.min(n_init))) + n0 = float(meta.n_background) + n_below = float(to_cpu(xp.mean(n_init < n0))) + eps_candidates = [1.0e-5, 1.0e-6, 1.0e-7, 1.0e-8] + step_costs = [] + for eps_step in eps_candidates: + n_pos = n_init - eps_step * grad_tikh + n_neg = n_init + eps_step * grad_tikh + cost_pos = float(to_cpu(diag_optimizer.cost(n_pos))[0]) + cost_neg = float(to_cpu(diag_optimizer.cost(n_neg))[0]) + step_costs.append((eps_step, cost_pos, cost_neg)) + line_search_iter_limit = None if args.line_search else 0 + result = optimizer.run( + n_init, + step=float(args.step), + max_iterations=int(args.iters), + use_fista=True, + n_batch=1, + compute_batch_grad_parallel=True, + verbose=True, + compute_cost=True, + line_search_iter_limit=line_search_iter_limit, + line_search_factor=float(args.line_search_factor), + restart_line_search=bool(args.restart_line_search), + ) + n_fista = np.asarray(to_cpu(result["x"]), dtype=np.float32) + + I_pred_fista = np.asarray( + to_cpu( + optimizer._forward_intensity( # type: ignore[attr-defined] + cp.asarray(n_fista) if use_gpu else n_fista + ) + ), + dtype=np.float32, + ) + I_pred_tikh = np.asarray( + to_cpu( + optimizer._forward_intensity( # type: ignore[attr-defined] + cp.asarray(n_tikh) if use_gpu else n_tikh + ) + ), + dtype=np.float32, + ) + rmse_fista = np.sqrt(np.mean((I_pred_fista - I_meas) ** 2, axis=(1, 2, 3))) + rmse_tikh = np.sqrt(np.mean((I_pred_tikh - I_meas) ** 2, axis=(1, 2, 3))) + print("RMSE per pattern (FISTA):", rmse_fista) + print("RMSE per pattern (Tikhonov):", rmse_tikh) + + out_path = args.out + out_path.parent.mkdir(parents=True, exist_ok=True) + _save_wotf_slice_plots(out_path.parent, H_real, H_imag, suffix=f"{out_path.stem}_fista") + _save_wotf_slice_plots(out_path.parent, H_real, H_imag, suffix=f"{out_path.stem}_tikh") + g = zarr.open_group(str(out_path), mode="w", zarr_format=3) + 
g.attrs.update(meta.as_dict()) + g.attrs["use_gpu"] = bool(use_gpu) + g.attrs["iters"] = int(args.iters) + g.attrs["step"] = float(args.step) + g.attrs["tv_weight"] = float(args.tv_weight) + g.attrs["z_taper"] = int(args.z_taper) + g.attrs["xy_taper"] = None if args.xy_taper is None else int(args.xy_taper) + g.attrs["pupil_taper_na"] = float(args.pupil_taper_na) + g.attrs["tikhonov_reg"] = float(args.tikhonov_reg) + g.attrs["init_from_tikhonov"] = True + g.attrs["diag_cost_tikh"] = float(cost_tikh) + g.attrs["diag_cost_tikh_step"] = float(cost_tikh_step) + g.attrs["diag_step_size"] = float(step_size) + g.attrs["diag_grad_norm"] = float(grad_norm) + g.attrs["diag_dir_deriv"] = float(dir_deriv) + g.attrs["diag_step_costs"] = [(float(eps), float(cpos), float(cneg)) for eps, cpos, cneg in step_costs] + g.create_array("I_meas", shape=I_meas.shape, dtype="float32")[...] = I_meas + g.create_array("I0_pred", shape=I0_pred.shape, dtype="float32")[...] = I0_pred + g.create_array("n_fista", shape=n_fista.shape, dtype="float32")[...] = n_fista + g.create_array("n_tikh", shape=n_tikh.shape, dtype="float32")[...] = n_tikh + g.create_array("rmse_fista", shape=rmse_fista.shape, dtype="float32")[...] = rmse_fista.astype( + np.float32, copy=False + ) + g.create_array("rmse_tikh", shape=rmse_tikh.shape, dtype="float32")[...] = rmse_tikh.astype( + np.float32, copy=False + ) + g.create_array("I_pred_fista", shape=I_pred_fista.shape, dtype="float32")[...] = I_pred_fista + g.create_array("I_pred_tikh", shape=I_pred_tikh.shape, dtype="float32")[...] = I_pred_tikh + h_real_abs = _axial_transfer_summary(H_real) + h_imag_abs = _axial_transfer_summary(H_imag) + g.create_array("H_real_abs_mean", shape=h_real_abs.shape, dtype="float32")[...] = h_real_abs.astype( + np.float32, copy=False + ) + g.create_array("H_imag_abs_mean", shape=h_imag_abs.shape, dtype="float32")[...] 
= h_imag_abs.astype( + np.float32, copy=False + ) + z_mid = int(meta.volume_shape_zyx[0]) // 2 + y_mid = int(meta.volume_shape_zyx[1]) // 2 + h_real_kxy = np.abs(H_real[:, z_mid]) + h_imag_kxy = np.abs(H_imag[:, z_mid]) + h_real_kzx = np.abs(H_real[:, :, y_mid, :]) + h_imag_kzx = np.abs(H_imag[:, :, y_mid, :]) + g.create_array("H_real_kxy_abs", shape=h_real_kxy.shape, dtype="float32")[...] = h_real_kxy.astype( + np.float32, copy=False + ) + g.create_array("H_imag_kxy_abs", shape=h_imag_kxy.shape, dtype="float32")[...] = h_imag_kxy.astype( + np.float32, copy=False + ) + g.create_array("H_real_kzx_abs", shape=h_real_kzx.shape, dtype="float32")[...] = h_real_kzx.astype( + np.float32, copy=False + ) + g.create_array("H_imag_kzx_abs", shape=h_imag_kzx.shape, dtype="float32")[...] = h_imag_kzx.astype( + np.float32, copy=False + ) + kz = np.fft.fftshift(np.fft.fftfreq(int(meta.volume_shape_zyx[0]), float(meta.voxel_size_um_zyx[0]))) + g.create_array("kz_cycles_per_um", shape=kz.shape, dtype="float32")[...] = kz.astype(np.float32, copy=False) + g.create_array("led_na_xy", shape=led_na_xy.shape, dtype="float32")[...] = led_na_xy.astype(np.float32, copy=False) + g.create_array("pattern_membership", shape=membership.shape, dtype="uint8")[...] = membership.astype( + np.uint8, copy=False + ) + print(f"Wrote outputs to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/run_wotf_fista_mie.py b/examples/run_wotf_fista_mie.py new file mode 100644 index 0000000..39d4882 --- /dev/null +++ b/examples/run_wotf_fista_mie.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +""" +Run WOTF FISTA reconstruction on a Mie-simulated DPC stack. 
+""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import numpy as np +import zarr # type: ignore + +try: + import cupy as cp # type: ignore +except Exception: + cp = None + +from mcsim.analysis.dpc_meta import DPCMeta +from mcsim.analysis.fft import ft3, ift3 +from mcsim.analysis.field_prop import get_n +from mcsim.analysis.optimize import to_cpu +from mcsim.analysis.wotf_fista import ( + WOTFParams, + WOTFFISTAOptimizer, + build_led_na_grid, + build_led_pattern_membership, + build_wotf_transfer, + camera_to_photons, +) +from build_dpc_mie_stack import ensure_dpc_mie_stack + + +def _load_simulated_stack(zarr_path: Path) -> tuple[np.ndarray, dict, np.ndarray]: + g = zarr.open_group(str(zarr_path), mode="r") + if "dpc" not in g: + raise ValueError("Missing 'dpc' dataset in simulation output.") + dpc = np.asarray(g["dpc"], dtype=np.float32) + + meta_group = g.get("meta") + if meta_group is None or "focal_offsets_um" not in meta_group: + raise ValueError("Missing 'meta/focal_offsets_um' in simulation output.") + z_planes_um = np.asarray(meta_group["focal_offsets_um"], dtype=float) + + if dpc.ndim == 3: + if dpc.shape[0] != 4: + raise ValueError(f"Expected dpc shape (4, ny, nx), got {dpc.shape}") + dpc = dpc[None, ...] 
+ elif dpc.ndim == 4: + if dpc.shape[1] != 4: + raise ValueError(f"Expected dpc shape (nz, 4, ny, nx), got {dpc.shape}") + else: + raise ValueError(f"Unexpected dpc ndim: {dpc.ndim}") + + I_cam = np.moveaxis(dpc, 1, 0) + return I_cam, dict(g.attrs), z_planes_um + + +def _build_meta(I_cam: np.ndarray, attrs: dict, z_planes_um: np.ndarray) -> DPCMeta: + nz = int(I_cam.shape[1]) + ny = int(I_cam.shape[2]) + nx = int(I_cam.shape[3]) + if z_planes_um.size != nz: + raise ValueError("z_planes_um length must match stack depth.") + + wavelength_um = float(attrs["wavelength_um"]) + na_obj = float(attrs["na_obj"]) + camera_pixel_um = float(attrs["camera_pixel_um"]) + magnification = float(attrs["magnification"]) + n_medium = float(attrs.get("n_medium", 1.0)) + led_grid_shape = tuple(int(v) for v in attrs.get("led_grid_shape", (64, 64))) + pattern_order = tuple(attrs.get("pattern_order", ("left", "right", "up", "down"))) + inner_na = float(attrs.get("inner_na", 0.0)) + include_center_led = bool(attrs.get("include_center_led", False)) + + dz = 0.0 + if z_planes_um.size > 1: + dzs = np.diff(z_planes_um) + if not np.allclose(dzs, dzs[0]): + raise ValueError("z_planes_um must be evenly spaced.") + dz = float(dzs[0]) + + dxy_um = camera_pixel_um / magnification + + return DPCMeta( + wavelength_um=wavelength_um, + n_background=n_medium, + NA_obj=na_obj, + magnification=magnification, + camera_pixel_pitch_um=camera_pixel_um, + volume_shape_zyx=(nz, ny, nx), + voxel_size_um_zyx=(dz, dxy_um, dxy_um), + z_planes_um=z_planes_um, + led_grid_shape=led_grid_shape, + inner_na=inner_na, + include_center_led=include_center_led, + pattern_order=pattern_order, # type: ignore[arg-type] + ) + + +def _tikhonov_wotf( + contrast: np.ndarray, + H_real: np.ndarray, + *, + reg: float, +) -> np.ndarray: + contrast_ft = ft3(contrast, axes=(1, 2, 3), shift=True) + num = np.sum(np.conj(H_real) * contrast_ft, axis=0) + den = np.sum(np.abs(H_real) ** 2, axis=0) + float(reg) + v_ft_real = num / den + v_real = 
ift3(v_ft_real, axes=(0, 1, 2), shift=True).real + return v_real.astype(np.float32, copy=False) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--sim-zarr", type=Path, default=Path("build/dpc_mie_stack.zarr")) + parser.add_argument("--out", type=Path, default=Path("build/wotf_fista_mie.zarr")) + parser.add_argument("--use-gpu", action="store_true") + parser.add_argument("--step", type=float, default=5.0e-5) + parser.add_argument("--iters", type=int, default=50) + parser.add_argument("--tv-weight", type=float, default=0.0) + parser.add_argument("--tv-max-num-iter", type=int, default=50) + parser.add_argument("--tv-eps", type=float, default=2.0e-4) + parser.add_argument("--tv-aniso-z", type=float, default=1.0) + parser.add_argument("--z-taper", type=int, default=8) + parser.add_argument("--xy-taper", type=int, default=None) + parser.add_argument("--pupil-taper-na", type=float, default=0.0) + parser.add_argument("--eps", type=float, default=1.0e-8) + parser.add_argument("--line-search", action="store_true") + parser.add_argument("--line-search-factor", type=float, default=0.5) + parser.add_argument("--restart-line-search", action="store_true") + parser.add_argument("--tikhonov-reg", type=float, default=5.0e-4) + parser.add_argument("--led-subsample", type=int, default=1) + parser.add_argument("--pad-z", type=int, default=64) + parser.add_argument("--pad-yx", type=int, default=64) + args = parser.parse_args() + + use_gpu = bool(args.use_gpu and cp is not None) + print(f"Using GPU: {use_gpu}") + + ensure_dpc_mie_stack(args.sim_zarr, use_gpu=use_gpu) + I_cam, attrs, z_planes_um = _load_simulated_stack(args.sim_zarr) + meta = _build_meta(I_cam, attrs, z_planes_um) + + camera_gain = attrs.get("camera_gains", 1.0) + camera_offset = attrs.get("camera_offsets", 0.0) + I_cam_xp = cp.asarray(I_cam) if use_gpu else np.asarray(I_cam) + I_meas_phot = camera_to_photons( + I_cam_xp, + camera_offset_adu=camera_offset, + 
camera_gain_photons_per_adu=camera_gain, + ) + + I_meas = np.asarray(to_cpu(I_meas_phot), dtype=np.float32) + I0_pred = np.mean(I_meas, axis=(2, 3), keepdims=True).astype(np.float32, copy=False) + + led_na_xy = build_led_na_grid( + meta.led_grid_shape[0], + meta.led_grid_shape[1], + na_obj=meta.NA_obj, + na_in=meta.inner_na, + include_center=meta.include_center_led, + led_subsample=int(args.led_subsample), + ) + membership = build_led_pattern_membership( + led_na_xy, + order=meta.pattern_order, + include_center=meta.include_center_led, + ) + + params = WOTFParams( + wavelength_um=meta.wavelength_um, + na_obj=meta.NA_obj, + na_in=meta.inner_na, + n0=meta.n_background, + dxy_um=meta.dxy_um, + dz_um=float(meta.voxel_size_um_zyx[0]), + nz=meta.volume_shape_zyx[0] + 2 * int(args.pad_z), + ny=meta.volume_shape_zyx[1] + 2 * int(args.pad_yx), + nx=meta.volume_shape_zyx[2] + 2 * int(args.pad_yx), + pupil_taper_na=float(args.pupil_taper_na), + ) + H_real, H_imag = build_wotf_transfer(led_na_xy, membership, params) + + xp = cp if use_gpu else np + n0_vol = xp.full(meta.volume_shape_zyx, float(meta.n_background), dtype=xp.float32) + + contrast = (I_meas / I0_pred - 1.0) / -1.0 + pad_z = int(args.pad_z) + pad_yx = int(args.pad_yx) + if pad_z or pad_yx: + contrast = np.pad( + contrast, + ((0, 0), (pad_z, pad_z), (pad_yx, pad_yx), (pad_yx, pad_yx)), + mode="constant", + ) + v_real = _tikhonov_wotf(contrast, H_real, reg=float(args.tikhonov_reg)) + if pad_z or pad_yx: + nz, ny, nx = meta.volume_shape_zyx + v_real = v_real[pad_z : pad_z + nz, pad_yx : pad_yx + ny, pad_yx : pad_yx + nx] + n_tikh = get_n(v_real, meta.n_background, meta.wavelength_um).real.astype(np.float32, copy=False) + + optimizer = WOTFFISTAOptimizer( + xp.asarray(I_meas), + xp.asarray(I0_pred), + xp.asarray(H_real), + xp.asarray(H_imag), + n0=meta.n_background, + wavelength_um=meta.wavelength_um, + eps=float(args.eps), + tv_weight=float(args.tv_weight), + tv_max_num_iter=int(args.tv_max_num_iter), + 
tv_eps=float(args.tv_eps), + tv_voxel_size_zyx=(params.dz_um, params.dxy_um, params.dxy_um), + tv_weight_scale_zyx=(float(args.tv_aniso_z), 1.0, 1.0), + pad_zyx=(int(args.pad_z), int(args.pad_yx), int(args.pad_yx)), + z_taper=int(args.z_taper), + xy_taper=args.xy_taper, + use_real_constraint=True, + ) + + n_init = (cp.asarray if use_gpu else np.asarray)( + np.full(meta.volume_shape_zyx, meta.n_background, dtype=np.float32) + ) + line_search_iter_limit = None if args.line_search else 0 + result = optimizer.run( + n_init, + step=float(args.step), + max_iterations=int(args.iters), + use_fista=True, + n_batch=1, + compute_batch_grad_parallel=True, + verbose=True, + compute_cost=True, + line_search_iter_limit=line_search_iter_limit, + line_search_factor=float(args.line_search_factor), + restart_line_search=bool(args.restart_line_search), + ) + n_fista = np.asarray(to_cpu(result["x"]), dtype=np.float32) + + I_pred_fista = np.asarray( + to_cpu( + optimizer._forward_intensity( # type: ignore[attr-defined] + cp.asarray(n_fista) if use_gpu else n_fista + ) + ), + dtype=np.float32, + ) + + out_path = args.out + out_path.parent.mkdir(parents=True, exist_ok=True) + g = zarr.open_group(str(out_path), mode="w", zarr_format=3) + g.attrs.update(meta.as_dict()) + g.attrs["use_gpu"] = bool(use_gpu) + g.attrs["iters"] = int(args.iters) + g.attrs["step"] = float(args.step) + g.attrs["tv_weight"] = float(args.tv_weight) + g.attrs["z_taper"] = int(args.z_taper) + g.attrs["xy_taper"] = None if args.xy_taper is None else int(args.xy_taper) + g.attrs["pupil_taper_na"] = float(args.pupil_taper_na) + g.attrs["tikhonov_reg"] = float(args.tikhonov_reg) + g.attrs["init_from_tikhonov"] = True + g.create_array("I_meas", shape=I_meas.shape, dtype="float32")[...] = I_meas + g.create_array("I0_pred", shape=I0_pred.shape, dtype="float32")[...] = I0_pred + g.create_array("n_fista", shape=n_fista.shape, dtype="float32")[...] = n_fista + g.create_array("n_tikh", shape=n_tikh.shape, dtype="float32")[...] 
= n_tikh + g.create_array("I_pred_fista", shape=I_pred_fista.shape, dtype="float32")[...] = I_pred_fista + g.create_array("led_na_xy", shape=led_na_xy.shape, dtype="float32")[...] = led_na_xy.astype(np.float32, copy=False) + g.create_array("pattern_membership", shape=membership.shape, dtype="uint8")[...] = membership.astype( + np.uint8, copy=False + ) + print(f"Wrote outputs to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/simulate_dpc_microsphere.py b/examples/simulate_dpc_microsphere.py new file mode 100644 index 0000000..fe5e928 --- /dev/null +++ b/examples/simulate_dpc_microsphere.py @@ -0,0 +1,2101 @@ +""" +Synthetic DPC image generation using Mie-theory scattered fields from a sphere, +with a simple synthetic imaging model (objective pupil + camera sampling). + +Overview +-------- +For each LED on a centered board, we: +1) map its board position to an incidence direction (NA_x, NA_y), +2) compute the complex vector electric field around/through a sphere using Mie theory + (`mie_fields.mie_efield`), +3) apply a coherent imaging operator representing the objective/tube-lens relay + (modeled as a circular pupil in spatial-frequency space with cutoff NA/lambda), +4) compute irradiance I(x,y) = sum_{c in {x,y,z}} |E_c(x,y)|^2, +5) optionally bin/average to camera pixels (camera integration / sampling), +6) incoherently sum irradiances for LEDs in left/right/up/down half-plane patterns. + +Key implementation requirement +------------------------------ +A 64 x 64 board has 4096 LEDs. This module simulates each active LED exactly once +(after circular/annular mask and optional subsampling), and accumulates that LED's +irradiance into any applicable DPC pattern sums (left/right/up/down). This avoids +redundant Mie computations across patterns. + +Optical model included (requested) +---------------------------------- +- Objective: circular pupil with coherent cutoff f_c = NA / lambda (cycles/µm). 
+- Tube lens: assumed ideal/infinity-corrected; its role is magnification. +- Camera: sampling and optional pixel integration via binning from an oversampled grid. + +This is a first-order synthetic imaging model suitable for generating test data and +debugging inverse-model implementations. + +Zarr output +----------- +Helpers are provided to write a Zarr dataset in the exact layout expected by the +unit tests in `test_dpc_idt_reconstruction.py`: +- root array: /dpc, shape (4, ny, nx), dtype float32 +- root attrs: wavelength_um, na_obj, camera_pixel_um, magnification, n_medium, + led_grid_shape, pattern_order, plus optional nz and z_span_um + +Dependencies +------------ +- mie_fields.py (user-provided): provides `mie_efield(...)`. + +""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Literal, Sequence + +import numpy as np +import hashlib +import json +import os +from tqdm import trange + +try: + import cupy as cp # Optional; mie_fields supports GPU mode +except ImportError: # pragma: no cover + cp = None + +try: + from localize_psf.camera import simulated_img +except ImportError: # pragma: no cover + simulated_img = None + +from mcsim.analysis.mie_fields import mie_efield +from mcsim.analysis.fft import ft2, ift2 +try: + from mcsim.analysis.field_prop import propagate_homogeneous +except Exception as _e: # pragma: no cover + propagate_homogeneous = None + +from mcsim.analysis.dpc_fista_solver import LEDBoard, compute_led_geometry + + +if cp: + array = np.ndarray | cp.ndarray +else: + array = np.ndarray + + +PatternName = Literal["left", "right", "up", "down"] + + +@dataclass(frozen=True, slots=True) +class SphereSpec: + """ + Optical / geometric description of the sphere for Mie simulation. + + Parameters + ---------- + radius_um : float + Sphere radius (µm). Forwarded to `mie_fields.mie_efield` and used to pick + a safe evaluation plane outside the sphere. 
@dataclass(frozen=True, slots=True)
class SimulationSpace:
    """
    Grid used for field simulation (object space).

    Parameters
    ----------
    dxy_um : float
        Sampling pitch (µm) on the simulation grid passed to `mie_fields.mie_efield`
        and `field_prop.propagate_homogeneous`.
    esize : tuple[int, int]
        Grid shape (ny, nx) for the simulated field prior to camera binning.
    z_plane_um : float or None
        Axial distance from the sphere center for field evaluation before refocus.
        If None, a plane just outside the sphere is chosen for stability.
    """
    # frozen + slots: immutable, lightweight record; safe to share between
    # simulator instances and to use as part of a cache key.
    dxy_um: float
    esize: tuple[int, int]
    z_plane_um: float | None = None
+ readout_noise_sds : array or float + Standard deviation (ADU) of Gaussian readout noise added in `simulated_img`. + photon_shot_noise : bool + If True, enable Poisson shot noise in `simulated_img`. + saturation : int or None + Clip simulated images above this value in `simulated_img`. + image_is_integer : bool + If True, round the final simulated image to integers in `simulated_img`. + """ + pixel_um: float + magnification: float + shape: tuple[int, int] + psf: array | None = None # type: ignore + apodization: array | int | float = 1 # type: ignore + gains: array | float = 1.0 # type: ignore + offsets: array | float = 0.0 # type: ignore + readout_noise_sds: array | float = 0.0 # type: ignore + photon_shot_noise: bool = False + saturation: int | None = None + image_is_integer: bool = False + + @property + def object_pixel_um(self) -> float: + return float(self.pixel_um) / float(self.magnification) + + +def _get_xp(use_gpu: bool): + """ + Select NumPy or CuPy module based on GPU usage. + + Parameters + ---------- + use_gpu : bool + If True and CuPy is available, return `cupy`; otherwise return `numpy`. + + Returns + ------- + module + Backend array module. + """ + if use_gpu and (cp is not None): + return cp + return np + + +def _to_xp(x: array, *, use_gpu: bool) -> array: # type: ignore + """ + Explicitly move arrays to the requested backend. + + Parameters + ---------- + x : array + Input array (NumPy, CuPy, or dask). + use_gpu : bool + If True, ensure the output is a CuPy array; otherwise ensure NumPy. + + Returns + ------- + array + Array on the requested backend, avoiding implicit host/device transfers. + """ + if use_gpu: + if cp is None: + raise ImportError("use_gpu=True requested but CuPy is not available.") + return cp.asarray(x) + if cp is not None and isinstance(x, cp.ndarray): + return cp.asnumpy(x) + return x + + +def _nan_to_zero(x: array, *, use_gpu: bool) -> array: # type: ignore + """ + Replace NaN/Inf with 0 on the requested backend. 
+ + Parameters + ---------- + x : array + Input array possibly containing NaN/Inf. + use_gpu : bool + If True, operate with CuPy; otherwise NumPy. + + Returns + ------- + array + Cleaned array on the requested backend with NaN/Inf replaced by 0. + """ + x = _to_xp(x, use_gpu=use_gpu) + xp = cp if (use_gpu and cp is not None) else np + # Both NumPy and CuPy implement nan_to_num (including for complex dtypes). + return xp.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) + +def _json_dumps_stable(obj) -> str: + """ + Stable JSON dump for hashing parameters. + + Parameters + ---------- + obj : Any + Object to serialize. + + Returns + ------- + str + JSON string with deterministic key ordering. + """ + return json.dumps(obj, sort_keys=True, separators=(",", ":"), default=str) + + +def _cache_key(params: dict) -> str: + """ + Create a short stable cache key for a simulation run. + + Parameters + ---------- + params : dict + Parameter dictionary to hash. + + Returns + ------- + str + SHA1-based short key (first 16 hex chars). + """ + import hashlib + + s = _json_dumps_stable(params).encode("utf-8") + return hashlib.sha1(s).hexdigest()[:16] + + +def _cache_run_dir(cache_dir: str | Path, params: dict) -> Path: + """ + Create (or reuse) the cache directory for a parameter set. + + Parameters + ---------- + cache_dir : str or Path + Base directory for simulation caches. + params : dict + Parameter dictionary hashed to form a unique run key. + + Returns + ------- + Path + Path to the run-specific cache directory containing metadata and Zarr stores. 
+ """ + cache_dir = Path(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + key = _cache_key(params) + run_dir = cache_dir / key + run_dir.mkdir(parents=True, exist_ok=True) + meta_path = run_dir / "meta.json" + if not meta_path.exists(): + meta_path.write_text(_json_dumps_stable(params) + "\n") + return run_dir + + +def _open_led_cache_zarr( + run_dir: Path, + *, + n_led: int, + n_planes: int, + ny: int, + nx: int, + chunks: tuple[int, int, int, int], +): + """ + Open (or create) a per-LED irradiance cache as a Zarr store. + + Parameters + ---------- + run_dir : Path + Directory in which to create the cache. + n_led : int + Number of LED entries to cache. + n_planes : int + Number of axial planes per LED (for focal stacks). + ny, nx : int + Camera image dimensions. + chunks : tuple[int, int, int, int] + Chunk shape for the cached array (n_led, n_planes, ny, nx). + + Returns + ------- + group, I_arr, done + Zarr group, cached camera irradiance array, and completion mask. + + Notes + ----- + We cache the binned camera irradiance (float32) to reduce size and avoid recomputation. 
+ """ + try: + import zarr # type: ignore + except Exception as e: # pragma: no cover + raise ImportError("zarr is required for caching") from e + + store_path = run_dir / "led_cache.zarr" + g = zarr.open_group(str(store_path), mode="a") + + g.attrs["n_led"] = int(n_led) + g.attrs["n_planes"] = int(n_planes) + g.attrs["ny"] = int(ny) + g.attrs["nx"] = int(nx) + + if "I_cam" in g: + I_arr = g["I_cam"] + if tuple(I_arr.shape) != (int(n_led), int(n_planes), int(ny), int(nx)): + raise ValueError( + f"Existing LED cache shape {tuple(I_arr.shape)} != {(int(n_led), int(n_planes), int(ny), int(nx))}" + ) + else: + # Use chunks kwarg for compatibility with current Zarr API + I_arr = g.create_array( + "I_cam", + shape=(int(n_led), int(n_planes), int(ny), int(nx)), + dtype="float32", + chunks=chunks, + ) # type: ignore[call-arg] + + if "done" in g: + done = g["done"] + if tuple(done.shape) != (int(n_led),): + raise ValueError(f"Existing LED cache done shape {tuple(done.shape)} != {(int(n_led),)}") + else: + try: + done = g.create_array("done", shape=(int(n_led),), dtype="u1", overwrite=False) + except (TypeError, AttributeError): + done = g.create_array("done", shape=(int(n_led),), dtype="u1") + done[...] = 0 + + return g, I_arr, done + + +def _zarr_write_array( + g, + name: str, + data: np.ndarray, + *, + dtype: str = "float32", + chunks=None, + compressor=None, + overwrite: bool = True, +): + """ + Write an array using Zarr v3 `Group.create_array`. + + Parameters + ---------- + g : zarr.Group + Target group. + name : str + Dataset name. + data : np.ndarray + Array to write (converted to dtype). + dtype : str + Target dtype for the stored array. + chunks : tuple[int, ...] or None + Chunk shape passed as `chunk_shape` to `create_array`. If None, let Zarr decide. + overwrite : bool + Included for API compatibility; Zarr v3 `create_array` overwrites only if allowed by mode. + + Returns + ------- + zarr.Array + Written array. 
+ """ + shape = tuple(int(s) for s in data.shape) + create_kwargs = {"shape": shape, "dtype": dtype} + if compressor is not None: + if isinstance(compressor, dict): + create_kwargs.update(compressor) + else: + create_kwargs["compressor"] = compressor + + if chunks is not None: + arr = g.create_array(name, chunks=chunks, **create_kwargs) # type: ignore[call-arg] + else: + arr = g.create_array(name, **create_kwargs) + arr[...] = data + return arr + +def sample_pixel_size_um(camera_pixel_um: float, magnification: float, *, oversample: int = 1) -> float: + """ + Convert camera pixel size to object-plane pixel size, optionally oversampled. + + Parameters + ---------- + camera_pixel_um : float + Camera pixel pitch (micrometers). This matches `CameraSpace.pixel_um`. + magnification : float + System magnification (object→camera). Higher magnification reduces the + effective object-plane sampling. + oversample : int + Oversampling factor. If >1, the object-plane sampling is divided by this + factor (finer simulation grid), and images can be binned back to camera + sampling using `bin2_average` or `simulated_img(bin_size=oversample)`. + + Returns + ------- + dxy_um : float + Object-plane pixel size (micrometers). + """ + if magnification <= 0: + raise ValueError("magnification must be > 0") + if oversample < 1: + raise ValueError("oversample must be >= 1") + return float(camera_pixel_um) / float(magnification) / float(oversample) + + + +def _default_z_plane_um(*, sphere_radius_um: float, dxy_um: float, wavelength_um: float) -> float: + """ + Choose a stable observation plane for Mie evaluation. + + The Mie scattered-field expansion is unstable near r=0; this picks an exterior + plane to evaluate fields. + + Parameters + ---------- + sphere_radius_um : float + Sphere radius (µm). + dxy_um : float + Object-plane sampling (µm). + wavelength_um : float + Vacuum wavelength (µm). + + Returns + ------- + float + Axial distance (µm) from sphere center to evaluation plane. 
+ """ + safety = float(max(float(dxy_um), 0.25 * float(wavelength_um))) + return float(sphere_radius_um) + safety + + +def _safe_z_plane_um( + z_plane_um: float, + *, + sphere_radius_um: float, + dxy_um: float, + wavelength_um: float, +) -> float: + """Ensure the Mie evaluation plane lies outside the sphere. + + If |z_plane_um| is less than radius + safety_margin, we override it to + sign(z_plane_um) * (radius + safety_margin). If z_plane_um==0, we choose +. + + This avoids evaluating the scattered-field expansion at/near r=0 and prevents + overflow/NaN in irradiance formation. + """ + safety = float(max(float(dxy_um), 0.01 * float(wavelength_um))) + z_min = float(sphere_radius_um) + safety + z = float(z_plane_um) + if abs(z) < z_min: + sgn = 1.0 if z == 0.0 else float(np.sign(z)) + return sgn * z_min + return z + + + +def make_led_na_positions( + ny_led: int, + nx_led: int, + *, + na_obj: float, + inner_na: float = 0.0, + include_center: bool = False, +) -> np.ndarray: + """ + Generate LED positions in NA space for a centered rectangular board, masked to an inscribed circle. + + Assumptions + ---------- + - Board is centered on the objective. + - The outermost LED inside the circular mask maps to `na_obj`. + - LEDs lie on a regular grid; positions are scaled to NA units. + + Parameters + ---------- + ny_led, nx_led : int + LED grid shape (64, 64 for the requested board). + na_obj : float + Objective NA; sets max LED radius in NA space. Passed later to + `DPCMieSimulator` to form the pupil cutoff. + inner_na : float + Inner NA radius (>=0) to form an annulus. Default 0 (filled disk). + include_center : bool + If False, excludes the center LED (NA=0). + + Returns + ------- + na_xy : np.ndarray + array of shape (N, 2) with columns (NA_x, NA_y). 
+ """ + if ny_led < 1 or nx_led < 1: + raise ValueError("ny_led and nx_led must be >= 1") + if inner_na < 0 or inner_na >= na_obj: + raise ValueError("inner_na must satisfy 0 <= inner_na < na_obj") + + cy = (ny_led - 1) / 2.0 + cx = (nx_led - 1) / 2.0 + + yy, xx = np.meshgrid( + np.arange(ny_led, dtype=np.float32), + np.arange(nx_led, dtype=np.float32), + indexing="ij", + ) + x = xx - cx + y = yy - cy + r = np.sqrt(x * x + y * y) + + # Use inscribed circle to avoid corners. + r_circle = float(min(cx, cy)) if (cx > 0 and cy > 0) else float(np.max(r)) + mask = r <= r_circle + + if not include_center: + mask = mask & (r > 0) + + r_mask_max = float(np.max(r[mask])) if np.any(mask) else 1.0 + scale = float(na_obj) / r_mask_max + + na_x = x[mask] * scale + na_y = y[mask] * scale + na_r = np.sqrt(na_x * na_x + na_y * na_y) + + ann = na_r >= float(inner_na) + na_xy = np.stack([na_x[ann], na_y[ann]], axis=1) + return na_xy + + +def split_dpc_patterns( + na_xy: np.ndarray, + *, + order: Sequence[PatternName] = ("left", "right", "up", "down"), +) -> dict[PatternName, np.ndarray]: + """ + Split LED NA positions into canonical DPC half-circle patterns. + + Pattern membership is based on the sign of NA components: + - left: NA_x < 0 + - right: NA_x > 0 + - up: NA_y > 0 + - down: NA_y < 0 + + Parameters + ---------- + na_xy : np.ndarray + LED positions in NA units, shape (N, 2). + order : Sequence[PatternName] + Must be a permutation of ("left","right","up","down"). Used for validation. + + Returns + ------- + patterns : dict[PatternName, np.ndarray] + Mapping pattern name -> LED NA positions in that pattern, each of shape (N_p, 2). + These positions are later iterated in `DPCMieSimulator.simulate_patterns` + when calling `mie_efield` for each LED. + """ + allowed: tuple[PatternName, ...] 
= ("left", "right", "up", "down") + if len(order) != 4 or set(order) != set(allowed): + raise ValueError(f"order must be a permutation of {allowed}; got {tuple(order)}") + + na_x = na_xy[:, 0] + na_y = na_xy[:, 1] + + masks = { + "left": na_x < 0, + "right": na_x > 0, + "up": na_y > 0, + "down": na_y < 0, + } + + out: dict[PatternName, np.ndarray] = {} + for name in allowed: + pts = na_xy[masks[name]] + if pts.shape[0] == 0: + raise ValueError(f"pattern '{name}' has zero LEDs (check board geometry).") + out[name] = pts + + return out + + +def na_to_incidence_angles( + na_x: float, + na_y: float, + *, + n_medium: float = 1.0, +) -> tuple[float, float]: + """ + Convert (NA_x, NA_y) into incidence direction angles for `mie_efield`. + + We interpret NA components as: + NA_x = n_medium * sin(theta) * cos(phi) + NA_y = n_medium * sin(theta) * sin(phi) + + therefore: + sin(theta) = sqrt(NA_x^2 + NA_y^2) / n_medium + phi = atan2(NA_y, NA_x) + + Parameters + ---------- + na_x, na_y : float + NA components (dimensionless). + n_medium : float + Medium refractive index. + + Returns + ------- + beam_theta, beam_phi : float + Angles (radians) to pass as `beam_theta` and `beam_phi`. + `beam_psi` can be set to 0 for x-polarized input in mie_fields. + These are forwarded unmodified to `mie_fields.mie_efield`. + """ + na_r = float(np.sqrt(na_x * na_x + na_y * na_y)) + s = na_r / float(n_medium) + s = float(np.clip(s, 0.0, 1.0)) + theta = float(np.arcsin(s)) + phi = float(np.arctan2(na_y, na_x)) + return theta, phi + + +def _fftshift_freq_grid(n: int, d: float, xp) -> array: # type: ignore + """ + Frequency samples (cycles/µm) aligned to fftshifted spectra. + """ + f = xp.fft.fftfreq(n, d=d) + return xp.fft.fftshift(f) + + +def make_pupil( + ny: int, + nx: int, + *, + dxy_um: float, + wavelength_um: float, + na_obj: float, + xp, +) -> array: # type: ignore + """ + Create a circular coherent pupil mask for the objective. 
+ + Parameters + ---------- + ny, nx : int + Output mask dimensions. + dxy_um : float + Sampling pitch (µm) of the simulation grid. + wavelength_um : float + Vacuum wavelength (µm). + na_obj : float + Objective NA determining cutoff f_c = NA / wavelength. + xp : module + Backend array module (NumPy or CuPy). + + Returns + ------- + array + Pupil mask of shape (ny, nx) with ones inside the cutoff and zeros outside, + used by `apply_objective_imaging`. + """ + fx = _fftshift_freq_grid(nx, dxy_um, xp) # (nx,) + fy = _fftshift_freq_grid(ny, dxy_um, xp) # (ny,) + fxx, fyy = xp.meshgrid(fx, fy, indexing="xy") + fr = xp.sqrt(fxx * fxx + fyy * fyy) + + fc = float(na_obj) / float(wavelength_um) + return (fr <= fc).astype(xp.float32) + + +# -----------------------------------------------------------------------------# +# Coherent imaging helpers (pupil + optional apodization) with caching +# -----------------------------------------------------------------------------# + +_APOD_CACHE: dict[tuple, array] = {} # type: ignore + + +def _tukey_1d(n: int, alpha: float, xp) -> array: # type: ignore + """ + Create a 1D Tukey window on the requested backend. + + Parameters + ---------- + n : int + Number of samples (must be > 0). + alpha : float + Taper parameter in [0,1]; 0 yields a rectangular window, 1 yields Hann. + xp : module + Backend array module (NumPy or CuPy). + + Returns + ------- + array + Window of shape (n,) on the requested backend. 
+ """ + if n <= 0: + raise ValueError("n must be > 0") + if alpha <= 0: + return xp.ones((n,), dtype=xp.float32) + if alpha >= 1: + # Hann + x = xp.arange(n, dtype=xp.float32) + return (0.5 - 0.5 * xp.cos(2 * xp.pi * x / (n - 1))).astype(xp.float32) + + x = xp.linspace(0.0, 1.0, n, dtype=xp.float32) + w = xp.ones((n,), dtype=xp.float32) + edge = alpha / 2.0 + + m1 = x < edge + m2 = x > (1.0 - edge) + # rising cosine + w[m1] = 0.5 * (1.0 + xp.cos(xp.pi * ((2.0 * x[m1] / alpha) - 1.0))) + # falling cosine + w[m2] = 0.5 * (1.0 + xp.cos(xp.pi * ((2.0 * x[m2] / alpha) - (2.0 / alpha) + 1.0))) + return w.astype(xp.float32) + + +def _get_cached_apodization( + ny: int, + nx: int, + *, + alpha: float, + use_gpu: bool, +): + """ + Retrieve or build a cached Tukey apodization window on the requested backend. + + Parameters + ---------- + ny, nx : int + Window dimensions. + alpha : float + Tukey alpha parameter passed to `_tukey_1d`. + use_gpu : bool + If True, cache CuPy arrays; otherwise NumPy. + + Returns + ------- + array + 2D apodization window cached for reuse. + """ + xp = cp if (use_gpu and cp is not None) else np + key = (int(ny), int(nx), float(alpha), bool(use_gpu)) + apo = _APOD_CACHE.get(key) + if apo is None: + wy = _tukey_1d(int(ny), float(alpha), xp) + wx = _tukey_1d(int(nx), float(alpha), xp) + apo = (wy[:, None] * wx[None, :]).astype(xp.float32, copy=False) + _APOD_CACHE[key] = apo + return apo + + +def apply_objective_imaging( + e_vec: array, # type: ignore + *, + dxy_um: float, + wavelength_um: float, + na_obj: float, + use_gpu: bool = False, +) -> array: # type: ignore + """ + Apply a simple coherent imaging operator to a vector field. + + This models an ideal infinity-corrected objective + tube lens as a circular pupil + (coherent transfer) in spatial-frequency space. + `dxy_um`, `wavelength_um`, and `na_obj` jointly define the pupil generated by + `make_pupil`, which is applied component-wise to `e_vec` before returning to + the spatial domain. 
def apply_objective_imaging(
    e_vec: "array",
    *,
    dxy_um: float,
    wavelength_um: float,
    na_obj: float,
    use_gpu: bool = False,
) -> "array":
    """
    Apply a simple coherent imaging operator to a vector field.

    Models an ideal infinity-corrected objective + tube lens as a circular
    pupil (coherent transfer) in spatial-frequency space. The pupil from
    `make_pupil` (defined by `dxy_um`, `wavelength_um`, `na_obj`) is applied
    component-wise before returning to the spatial domain. Inputs are moved
    explicitly to the requested backend to avoid implicit host<->device
    transfers.

    Parameters
    ----------
    e_vec : array
        Vector field, shape (3, ny, nx).
    dxy_um : float
        Sampling (µm) for the e_vec grid.
    wavelength_um : float
        Vacuum wavelength (µm).
    na_obj : float
        Objective NA.
    use_gpu : bool
        If True, uses CuPy and runs FFTs on the GPU (requires CuPy).

    Returns
    -------
    e_img : array
        Filtered vector field, shape (3, ny, nx), on the requested backend.

    Raises
    ------
    ValueError
        If `e_vec` does not have a leading axis of length 3.
    """
    if e_vec.shape[0] != 3:
        raise ValueError(f"e_vec must have shape (3, ny, nx), got {e_vec.shape}")

    e_vec = _to_xp(e_vec, use_gpu=use_gpu)
    xp = cp if (use_gpu and cp is not None) else np

    ny, nx = int(e_vec.shape[-2]), int(e_vec.shape[-1])
    pupil = make_pupil(
        ny,
        nx,
        dxy_um=float(dxy_um),
        wavelength_um=float(wavelength_um),
        na_obj=float(na_obj),
        xp=xp,
    )

    # Filter each Cartesian component in the (centered) Fourier domain.
    filtered = xp.empty_like(e_vec)
    for comp in range(3):
        field = _nan_to_zero(e_vec[comp], use_gpu=use_gpu)
        spectrum = ft2(field, axes=(-2, -1), shift=True, adjoint=False) * pupil
        filtered[comp] = ift2(spectrum, axes=(-2, -1), shift=True, adjoint=False)
    return filtered
def bin2_average(im: "array", factor: int, *, use_gpu: bool = False) -> "array":
    """
    Bin/average a 2D image by an integer factor.

    Approximates camera pixel integration when `factor` is the oversampling
    ratio between the simulation grid and the camera grid; mirrors the binning
    performed inside `localize_psf.camera.simulated_img` when `bin_size`
    matches `factor`. Trailing rows/columns that do not fill a complete
    `factor x factor` tile are discarded. The input is explicitly moved to the
    requested backend to avoid implicit host<->device transfers.

    Parameters
    ----------
    im : array
        2D image, shape (ny, nx).
    factor : int
        Binning factor (>=1).
    use_gpu : bool
        If True, operate on the GPU (requires CuPy).

    Returns
    -------
    binned : array
        Binned image, shape (ny//factor, nx//factor), on the requested backend.

    Raises
    ------
    ValueError
        If `factor` < 1.
    """
    if factor < 1:
        raise ValueError("factor must be >= 1")

    # Backend handling (inlined `_to_xp` convention): upload when the GPU is
    # requested, otherwise download any device array back to the host.
    if use_gpu:
        if cp is None:
            raise ImportError("use_gpu=True requested but CuPy is not available.")
        im = cp.asarray(im)
    elif cp is not None and isinstance(im, cp.ndarray):
        im = cp.asnumpy(im)

    if factor == 1:
        return im

    xp = cp if (use_gpu and cp is not None) else np
    ny, nx = int(im.shape[-2]), int(im.shape[-1])
    ny_crop = (ny // factor) * factor
    nx_crop = (nx // factor) * factor

    # Crop to whole tiles, then average each factor x factor tile.
    tiles = im[:ny_crop, :nx_crop].reshape(ny_crop // factor, factor, nx_crop // factor, factor)
    return xp.mean(tiles, axis=(1, 3))
+ led_subsample : int, optional + Subsample factor for LEDs simulated. + use_gpu : bool, optional + If True, use CuPy-backed computations where available. + mie_kwargs : dict or None, optional + Extra kwargs passed to `mie_fields.mie_efield`. + cache_dir : str or Path or None, optional + Directory for Zarr per-LED camera irradiance cache. + reuse_cache : bool, optional + If True, reuse cached per-LED images when present. + """ + + def __init__( + self, + *, + wavelength_um: float, + na_obj: float, + sphere: SphereSpec, + simulation: SimulationSpace, + camera: CameraSpace, + led_grid_shape: tuple[int, int] = (64, 64), + pitch_mm: float = 2.5, + inner_na: float = 0.0, + include_center_led: bool = False, + pattern_order: Sequence[PatternName] = ("left", "right", "up", "down"), + normalize_by_led_count: bool = True, + led_subsample: int = 1, + use_gpu: bool = False, + mie_kwargs: dict | None = None, + cache_dir: str | Path | None = None, + reuse_cache: bool = True, + exposure_time_ms: float = 1.0, + illumination_photons_per_s_per_um2: float = 1.0, + focal_stack_planes: int | None = None, + focal_stack_step_um: float = 0.5, + ): + if simulated_img is None: + raise ImportError("localize_psf.camera.simulated_img is required but not importable.") + if propagate_homogeneous is None: + raise ImportError("mcsim.analysis.field_prop.propagate_homogeneous is required but not importable.") + + allowed: tuple[PatternName, ...] 
= ("left", "right", "up", "down") + if len(pattern_order) != 4 or set(pattern_order) != set(allowed): + raise ValueError(f"pattern_order must be a permutation of {allowed}; got {tuple(pattern_order)}") + if led_subsample < 1: + raise ValueError("led_subsample must be >= 1") + + self.wavelength_um = float(wavelength_um) + self.na_obj = float(na_obj) + self.sphere = sphere + self.simulation = simulation + self.camera = camera + self.led_grid_shape = (int(led_grid_shape[0]), int(led_grid_shape[1])) + self.pitch_mm = float(pitch_mm) + self.inner_na = float(inner_na) + self.include_center_led = bool(include_center_led) + self.pattern_order = pattern_order + self.normalize_by_led_count = bool(normalize_by_led_count) + self.led_subsample = int(led_subsample) + self.use_gpu = bool(use_gpu) + self.mie_kwargs = mie_kwargs or {} + self.cache_dir = cache_dir + self.reuse_cache = bool(reuse_cache) + self.exposure_time_s = float(exposure_time_ms) / 1000.0 + self.illumination_photons_per_s_per_um2 = float(illumination_photons_per_s_per_um2) + if focal_stack_planes is None or focal_stack_planes <= 1: + self.focal_offsets_um = np.array([0.0], dtype=float) + else: + n = int(focal_stack_planes) + offsets = (np.arange(n) - (n - 1) / 2.0) * float(focal_stack_step_um) + self.focal_offsets_um = offsets.astype(float) + + self.xp = _get_xp(use_gpu) + self.camera_bin_factor = self._camera_bin_factor() + + def _camera_bin_factor(self) -> int: + """ + Integer ratio between camera sampling (object plane) and simulation sampling. + + Returns + ------- + int + Bin factor (`camera.object_pixel_um / simulation.dxy_um`) used as + `bin_size` for `localize_psf.camera.simulated_img`. 
+ """ + ratio = float(self.camera.object_pixel_um) / float(self.simulation.dxy_um) + ratio_round = int(round(ratio)) + if ratio_round < 1 or not np.isclose(ratio, ratio_round, rtol=1e-3, atol=1e-6): + raise ValueError( + f"Camera pixel size / simulation sampling must be a positive integer; got ratio={ratio:.4f}" + ) + return ratio_round + + def _field_plane_z(self) -> float: + """ + Choose a stable evaluation plane for the Mie field. + + This is the `dz` plane passed to `mie_fields.mie_efield`. If the user provides + `simulation.z_plane_um`, it is validated with `_safe_z_plane_um`; otherwise a + default just outside the sphere surface is selected. + """ + if self.simulation.z_plane_um is not None: + return float(_safe_z_plane_um( + self.simulation.z_plane_um, + sphere_radius_um=float(self.sphere.radius_um), + dxy_um=float(self.simulation.dxy_um), + wavelength_um=self.wavelength_um, + )) + return _safe_z_plane_um( + _default_z_plane_um( + sphere_radius_um=float(self.sphere.radius_um), + dxy_um=float(self.simulation.dxy_um), + wavelength_um=self.wavelength_um, + ), + sphere_radius_um=float(self.sphere.radius_um), + dxy_um=float(self.simulation.dxy_um), + wavelength_um=self.wavelength_um, + ) + + def _simulate_led_ground_truth(self, na_x: float, na_y: float) -> array: # type: ignore + """ + Simulate pupil-filtered field on the simulation grid for a single LED (before camera). + + Parameters + ---------- + na_x, na_y : float + Illumination NA components converted to angles for `mie_efield`. + + Returns + ------- + array + Complex field on the simulation grid after refocus and pupil filtering. + + Notes + ----- + Uses `mie_fields.mie_efield` with `simulation.dxy_um`, `simulation.esize`, + and `_field_plane_z()`, refocuses via `propagate_homogeneous`, and applies + `apply_objective_imaging`. 
+ """ + xp = self.xp + theta, phi = na_to_incidence_angles( + float(na_x), + float(na_y), + n_medium=float(self.sphere.n_medium), + ) + + z_plane = self._field_plane_z() + ny_sim, nx_sim = self.simulation.esize + + e_scatt_vec, _, e_inc_vec, _ = mie_efield( + float(self.wavelength_um), + float(self.sphere.n_medium), + float(self.sphere.radius_um), + complex(self.sphere.n_sphere), + float(self.simulation.dxy_um), + (int(ny_sim), int(nx_sim)), + dz=float(z_plane), + beam_theta=float(theta), + beam_phi=float(phi), + use_gpu=bool(self.use_gpu), + **self.mie_kwargs, + ) + + # Scalar field (x-component) and gentle apodization to suppress wrap-around. + e0 = (e_scatt_vec[0] + e_inc_vec[0]).astype(xp.complex64, copy=False) + apod = _get_cached_apodization(int(ny_sim), int(nx_sim), alpha=0.1, use_gpu=bool(self.use_gpu)) + e0 = (e0 * apod.astype(e0.real.dtype, copy=False)).astype(xp.complex64, copy=False) + + e_in = e0[None, None, :, :] # shape (1, 1, ny, nx) for propagate_homogeneous + e_prop = propagate_homogeneous( + e_in, + [-float(z_plane)], + float(self.sphere.n_medium), + (float(self.simulation.dxy_um), float(self.simulation.dxy_um)), + float(self.wavelength_um), + ) + + if e_prop.ndim >= 5: + e_foc = e_prop[..., 0, :, :][0, 0] + elif e_prop.ndim == 4: + e_foc = e_prop[0, 0] + else: + e_foc = e_prop + + # Reuse the existing objective pupil helper. + e_vec = xp.zeros((3,) + e_foc.shape, dtype=e_foc.dtype) + e_vec[0] = e_foc + e_img_vec = apply_objective_imaging( + e_vec, + dxy_um=float(self.simulation.dxy_um), + wavelength_um=float(self.wavelength_um), + na_obj=float(self.na_obj), + use_gpu=bool(self.use_gpu), + ) + return e_img_vec[0] + + def _render_camera(self, irradiance: array) -> array: # type: ignore + """ + Project simulation irradiance onto the camera grid (binning + optional noise). + + Parameters + ---------- + irradiance : array + Simulation-grid irradiance (pre-camera). 
+ + Returns + ------- + array + Camera-grid image after binning/noise from `localize_psf.camera.simulated_img`. + + Notes + ----- + `bin_size` is derived from `camera.object_pixel_um / simulation.dxy_um`; all + camera noise parameters (gains, offsets, readout_noise_sds, photon_shot_noise, + saturation, image_is_integer) are forwarded from `CameraSpace`. PSF input is + not supported; the coherent transfer is modeled via the pupil internally. + """ + xp = self.xp + gt = xp.asarray(irradiance, dtype=xp.float32) + gt = gt * self.exposure_time_s + + gains = xp.asarray(self.camera.gains, dtype=gt.dtype) + offsets = xp.asarray(self.camera.offsets, dtype=gt.dtype) + readout = xp.asarray(self.camera.readout_noise_sds, dtype=gt.dtype) + + psf = None + if self.camera.psf is not None: + psf = xp.asarray(self.camera.psf, dtype=gt.dtype) + + apo = self.camera.apodization + if isinstance(apo, (np.ndarray,)) or (cp is not None and isinstance(apo, cp.ndarray)): + apo = xp.asarray(apo, dtype=gt.dtype) + + img, _ = simulated_img( + ground_truth=gt, + gains=gains, + offsets=offsets, + readout_noise_sds=readout, + psf=psf, + photon_shot_noise=bool(self.camera.photon_shot_noise), + bin_size=int(self.camera_bin_factor), + apodization=apo, + saturation=self.camera.saturation, + image_is_integer=bool(self.camera.image_is_integer), + ) + return xp.asarray(img, dtype=xp.float32) + + def _simulate_led(self, na_x: float, na_y: float) -> array: # type: ignore + """ + Full pipeline for one LED: field generation → optional defocus → camera rendering. + + Returns a focal stack with shape (n_planes, ny_cam, nx_cam). 
+ """ + xp = self.xp + field0 = self._simulate_led_ground_truth(na_x, na_y) + + flux_scale = float(self.illumination_photons_per_s_per_um2) * float(self.simulation.dxy_um) ** 2 + stack = [] + for dz in self.focal_offsets_um: + if dz == 0: + f_def = field0 + else: + e_in = field0[None, None, :, :] + e_prop = propagate_homogeneous( + e_in, + [float(dz)], + float(self.sphere.n_medium), + (float(self.simulation.dxy_um), float(self.simulation.dxy_um)), + float(self.wavelength_um), + ) + if e_prop.ndim >= 5: + f_def = e_prop[..., 0, :, :][0, 0] + elif e_prop.ndim == 4: + f_def = e_prop[0, 0] + else: + f_def = e_prop + + I_sim = (xp.abs(f_def) ** 2).astype(xp.float32, copy=False) * flux_scale + stack.append(self._render_camera(I_sim)) + + return xp.stack(stack, axis=0) + + def simulate_patterns(self) -> tuple[array, dict[str, array]]: # type: ignore + """ + Simulate all LEDs and accumulate into DPC patterns. + + Returns + ------- + tuple[array, dict[str, array]] + DPC stack ordered by `pattern_order` and metadata including NA positions, + grid sizes, bin factor, and LED counts. + + Notes + ----- + - LED NA grid and pattern membership come from board geometry (n_side, pitch_mm) + mapped to NA via inscribed circle matched to na_obj, then split by + `split_dpc_patterns` using `inner_na`, `include_center_led`, and `pattern_order`. + - Each LED image is produced by `_simulate_led` (internally + `mie_fields.mie_efield` → `propagate_homogeneous` → `apply_objective_imaging` + → `localize_psf.camera.simulated_img`). + - Optional per-LED cache stored in Zarr (v3 only). + """ + allowed: tuple[PatternName, ...] 
= ("left", "right", "up", "down") + ny_led, nx_led = self.led_grid_shape + board = LEDBoard( + n_side=ny_led, + pitch_mm=float(self.pitch_mm), + na_obj=float(self.na_obj), + wavelength_um=float(self.wavelength_um), + n_medium=float(self.sphere.n_medium), + ) + geom = compute_led_geometry(board) + na_xy_all = geom.na_components + if not self.include_center_led: + r = np.linalg.norm(na_xy_all, axis=1) + na_xy_all = na_xy_all[r > 0] + if float(self.inner_na) > 0: + r = np.linalg.norm(na_xy_all, axis=1) + na_xy_all = na_xy_all[r >= float(self.inner_na)] + patterns_full = split_dpc_patterns(na_xy_all, order=self.pattern_order) + + na_xy = na_xy_all[:: self.led_subsample] + na_x = na_xy[:, 0] + na_y = na_xy[:, 1] + + membership = { + "left": na_x < 0, + "right": na_x > 0, + "up": na_y > 0, + "down": na_y < 0, + } + + if self.include_center_led: + center_mask = (na_x == 0) & (na_y == 0) + if np.any(center_mask): + for k in membership: + membership[k] = membership[k] | center_mask + + xp = self.xp + ny_cam, nx_cam = self.camera.shape + n_planes = int(len(self.focal_offsets_um)) + acc = {k: xp.zeros((n_planes, ny_cam, nx_cam), dtype=xp.float32) for k in allowed} + led_counts = {k: int(np.sum(np.asarray(membership[k], dtype=bool))) for k in allowed} + + run_dir = None + I_cache = None + done_cache = None + + if self.cache_dir is not None: + cache_params = { + "wavelength_um": float(self.wavelength_um), + "na_obj": float(self.na_obj), + "led_grid_shape": [int(self.led_grid_shape[0]), int(self.led_grid_shape[1])], + "pitch_mm": float(self.pitch_mm), + "camera_pixel_um": float(self.camera.pixel_um), + "magnification": float(self.camera.magnification), + "camera_shape": [int(ny_cam), int(nx_cam)], + "simulation_dxy_um": float(self.simulation.dxy_um), + "simulation_esize": [int(self.simulation.esize[0]), int(self.simulation.esize[1])], + "field_generation_z_um": float(self._field_plane_z()), + "sphere": { + "radius_um": float(self.sphere.radius_um), + "n_sphere": 
str(self.sphere.n_sphere), + "n_medium": float(self.sphere.n_medium), + }, + "inner_na": float(self.inner_na), + "include_center_led": bool(self.include_center_led), + "pattern_order": [str(p) for p in self.pattern_order], + "normalize_by_led_count": bool(self.normalize_by_led_count), + "led_subsample": int(self.led_subsample), + "mie_kwargs": self.mie_kwargs, + "camera_bin_factor": int(self.camera_bin_factor), + "exposure_time_s": float(self.exposure_time_s), + "illumination_photons_per_s_per_um2": float(self.illumination_photons_per_s_per_um2), + "focal_offsets_um": self.focal_offsets_um.tolist(), + } + run_dir = _cache_run_dir(self.cache_dir, cache_params) + chunks = (1, 1, min(256, int(ny_cam)), min(256, int(nx_cam))) + _, I_cache, done_cache = _open_led_cache_zarr( + Path(run_dir), + n_led=int(na_xy.shape[0]), + n_planes=int(n_planes), + ny=int(ny_cam), + nx=int(nx_cam), + chunks=chunks, + ) + + for j in trange(int(na_xy.shape[0]), desc="Simulating LEDs"): + nax = float(na_xy[j, 0]) + nay = float(na_xy[j, 1]) + + I_cam = None + if I_cache is not None and done_cache is not None and self.reuse_cache and int(done_cache[j]) == 1: + I_cam_np = np.asarray(I_cache[j, ...], dtype=np.float32) + I_cam = cp.asarray(I_cam_np) if (self.use_gpu and cp is not None) else I_cam_np + I_cam = I_cam.astype(xp.float32, copy=False) + + if I_cam is None: + I_cam = self._simulate_led(nax, nay).astype(xp.float32, copy=False) + + if I_cache is not None and done_cache is not None: + I_cam_save = cp.asnumpy(I_cam) if (cp is not None and hasattr(I_cam, "__cuda_array_interface__")) else np.asarray(I_cam) + I_cache[j, ...] 
= np.asarray(I_cam_save, dtype=np.float32) + done_cache[j] = 1 + + if not (xp.isfinite(I_cam).all()) or xp.isinf(I_cam).any(): + print(f"isfinite check: {xp.isfinite(I_cam).all()}") + print(f"isinf check: {xp.isinf(I_cam).any()}") + + if membership["left"][j]: + acc["left"] += I_cam + if membership["right"][j]: + acc["right"] += I_cam + if membership["up"][j]: + acc["up"] += I_cam + if membership["down"][j]: + acc["down"] += I_cam + + if self.normalize_by_led_count: + for k in allowed: + if led_counts[k] <= 0: + raise RuntimeError(f"No LEDs accumulated for pattern '{k}' (after subsampling).") + acc[k] = acc[k] / float(led_counts[k]) + + # Sanity check: ensure no pattern is identically zero after accumulation + for k in allowed: + if float(xp.sum(acc[k])) == 0.0: + raise RuntimeError(f"Accumulated pattern '{k}' is zero; check illumination masks and cache settings.") + + dpc = xp.stack([acc[p].copy() for p in self.pattern_order], axis=0) + dpc = xp.moveaxis(dpc, 0, 1) # (n_planes, 4, ny, nx) + dpc_out: array # type: ignore + if dpc.shape[0] == 1: + dpc_out = dpc[0] + else: + dpc_out = dpc + + meta: dict[str, array] = { # type: ignore + "na_xy": na_xy_all, + "pitch_mm": float(self.pitch_mm), + "simulation_dxy_um": np.asarray(self.simulation.dxy_um, dtype=np.float32), + "simulation_esize": (int(self.simulation.esize[0]), int(self.simulation.esize[1])), + "camera_shape": (int(ny_cam), int(nx_cam)), + "camera_bin_factor": int(self.camera_bin_factor), + "led_counts": {k: int(led_counts[k]) for k in allowed}, + "z_plane_um": float(self._field_plane_z()), + "exposure_time_s": float(self.exposure_time_s), + "illumination_photons_per_s_per_um2": float(self.illumination_photons_per_s_per_um2), + "pattern_sums": {k: float(xp.sum(acc[k])) for k in allowed}, + "focal_offsets_um": self.focal_offsets_um.tolist(), + } + for k, v in patterns_full.items(): + meta[k] = v + + return dpc_out, meta + + + +def simulate_led_irradiance( + *, + wavelength_um: float, + sphere: SphereSpec, + 
dxy_um: float, + esize: tuple[int, int], + z_plane_um: float, + na_x: float, + na_y: float, + na_obj: float, + camera_bin: int = 1, + use_gpu: bool = False, + mie_kwargs: dict | None = None, +) -> array: # type: ignore + """ + Simulate the (binned) camera-plane irradiance for a single LED illumination. + + Parameters + ---------- + wavelength_um : float + Vacuum wavelength (µm) for `mie_efield` and the objective pupil. + sphere : SphereSpec + Sphere properties forwarded to `mie_efield`. + dxy_um : float + Simulation sampling (µm) passed to `mie_efield` and propagation. + esize : tuple[int, int] + Simulation grid size (ny, nx). + z_plane_um : float + Field evaluation plane for `mie_efield`; if inside the sphere, a safe plane is chosen. + na_x, na_y : float + Illumination NA components converted to angles for `mie_efield`. + na_obj : float + Objective NA for the coherent pupil. + camera_bin : int, optional + Bin size for `simulated_img`, derived from simulation vs camera sampling. + use_gpu : bool, optional + If True, use CuPy-backed operations when available. + mie_kwargs : dict or None, optional + Extra keyword arguments passed directly to `mie_fields.mie_efield`. + + Returns + ------- + array + Binned camera-plane irradiance for the single LED. + + Notes + ----- + Internally builds a `DPCMieSimulator` with matching grids and calls its single-LED + pipeline: `mie_efield` → `propagate_homogeneous` → `apply_objective_imaging` → + `localize_psf.camera.simulated_img`. 
+ """ + if sphere is None: + sphere = SphereSpec(radius_um=5.0, n_sphere=1.59 + 0.0j, n_medium=1.55) + + sim_space = SimulationSpace( + dxy_um=float(dxy_um), + esize=(int(esize[0]), int(esize[1])), + z_plane_um=float(z_plane_um), + ) + cam_shape = (int(esize[0] // camera_bin), int(esize[1] // camera_bin)) + cam_space = CameraSpace( + pixel_um=float(dxy_um) * float(camera_bin), + magnification=1.0, + shape=cam_shape, + photon_shot_noise=False, + readout_noise_sds=0.0, + gains=1.0, + offsets=0.0, + image_is_integer=False, + ) + + simulator = DPCMieSimulator( + wavelength_um=float(wavelength_um), + na_obj=float(na_obj), + sphere=sphere, + simulation=sim_space, + camera=cam_space, + led_grid_shape=(1, 1), + normalize_by_led_count=False, + led_subsample=1, + use_gpu=use_gpu, + mie_kwargs=mie_kwargs, + cache_dir=None, + reuse_cache=True, + ) + return simulator._simulate_led(float(na_x), float(na_y)) + + +def simulate_dpc_images_sphere( + *, + wavelength_um: float = 0.515, + na_obj: float = 0.8, + led_grid_shape: tuple[int, int] = (64, 64), + pitch_mm: float = 2.5, + camera_pixel_um: float = 2.4, + magnification: float = 20.0, + camera_oversample: int = 1, + esize_camera: tuple[int, int] = (256, 256), + sim_dxy_um: float | None = None, + sim_esize: tuple[int, int] | None = None, + z_plane_um: float = 0.0, + sphere: SphereSpec | None = None, + inner_na: float = 0.0, + include_center_led: bool = False, + pattern_order: Sequence[PatternName] = ("left", "right", "up", "down"), + normalize_by_led_count: bool = True, + led_subsample: int = 1, + use_gpu: bool = False, + mie_kwargs: dict | None = None, + cache_dir: str | Path | None = None, + reuse_cache: bool = True, + camera_gains: array | float = 1.0, # type: ignore + camera_offsets: array | float = 0.0, # type: ignore + camera_readout_noise_sds: array | float = 0.0, # type: ignore + camera_photon_shot_noise: bool = False, + camera_saturation: int | None = None, + camera_image_is_integer: bool = False, + exposure_time_ms: 
float = 1.0, + illumination_photons_per_s_per_um2: float = 1.0, + focal_stack_planes: int | None = None, + focal_stack_step_um: float = 0.5, +) -> tuple[array, dict[str, array]]: # type: ignore + """ + Generate four synthetic DPC images by incoherently summing per-LED irradiance. + + Parameters (and where they go) + ------------------------------ + wavelength_um : float + Vacuum wavelength (µm). Passed to `mie_fields.mie_efield`, the objective pupil + in `apply_objective_imaging`, and `propagate_homogeneous`. + na_obj : float + Objective NA. Sets the coherent cutoff in `apply_objective_imaging` via + `make_pupil`. + led_grid_shape : (int, int) + LED board shape (ny_led, nx_led). + pitch_mm : float + LED pitch on the board (mm), used to map board radius to illumination NA. + camera_pixel_um : float + Camera pixel pitch (µm) in the camera plane; combined with `magnification` to + derive object-plane sampling and the bin size for `localize_psf.camera.simulated_img`. + magnification : float + Object-to-camera magnification. Higher values increase the bin factor between + simulation grid and camera grid. + camera_oversample : int + Optional oversampling factor applied to the camera grid; influences the default + `sim_dxy_um` and `sim_esize` when not provided. + esize_camera : (int, int) + Final camera image shape (ny, nx) produced by `simulated_img`. + sim_dxy_um : float | None + Simulation grid sampling (µm). If None, derived from `camera_pixel_um / + magnification / camera_oversample` and forwarded to `mie_efield` and + `propagate_homogeneous`. + sim_esize : (int, int) | None + Simulation grid size (ny, nx). If None, defaults to `esize_camera * + camera_oversample`. + z_plane_um : float + Optional user-specified field evaluation plane for `mie_efield` (validated to + stay outside the sphere). If 0, a safe exterior plane is chosen automatically. + sphere : SphereSpec | None + Sphere parameters (radius, n_sphere, n_medium) passed to `mie_efield`. 
Defaults + to a polystyrene-like sphere in medium if None. + inner_na : float + Inner NA to form an annular LED mask in `make_led_na_positions`. + include_center_led : bool + Whether to include the origin LED (NA=0) in all patterns. + camera_gains : array or float, optional + Gains (ADU/e) for `simulated_img`. + camera_offsets : array or float, optional + Offsets (ADU) for `simulated_img`. + camera_readout_noise_sds : array or float, optional + Readout noise SD (ADU) for `simulated_img`. + camera_photon_shot_noise : bool, optional + Enable Poisson shot noise in `simulated_img`. + camera_saturation : int or None, optional + Saturation level passed to `simulated_img`. + camera_image_is_integer : bool, optional + If True, round simulated images to integers in `simulated_img`. + exposure_time_ms : float, optional + Exposure time (milliseconds). Multiplies simulated irradiance to convert to + expected photon counts before camera noise is applied. + illumination_photons_per_s_per_um2 : float, optional + Incident photon flux density (photons/s/µm^2) used to scale the Mie-derived + irradiance (relative units) into absolute photons/s per pixel area. + focal_stack_planes : int or None, optional + Number of focal planes to simulate. If None or <=1, a single plane is produced. + focal_stack_step_um : float, optional + Axial step (µm) between focal planes, centered on nominal focus. + pattern_order : Sequence[str] + Order of DPC patterns ("left","right","up","down") returned in the output stack. + normalize_by_led_count : bool + If True, divides each pattern sum by the number of LEDs contributing to it. + led_subsample : int + Subsample factor on the LED list before simulation (simulates every Nth LED). + use_gpu : bool + If True, uses CuPy-backed versions of `mie_efield`, FFTs, and `simulated_img` + where available. + mie_kwargs : dict | None + Extra keyword arguments forwarded directly to `mie_fields.mie_efield`. 
+ cache_dir : str | Path | None + Optional directory for Zarr-based per-LED camera irradiance caching + (v3-compatible; no .npy files). + reuse_cache : bool + If True, reuses cached per-LED camera images when present. + + Returns + ------- + dpc : array + Stack of four DPC images ordered by `pattern_order` (shape (4, ny, nx)). + meta : dict + Metadata including NA positions, grid sizes, bin factor, and LED counts for + downstream reconstruction/testing. + """ + if led_subsample < 1: + raise ValueError("led_subsample must be >= 1") + if camera_oversample < 1: + raise ValueError("camera_oversample must be >= 1") + + if sphere is None: + sphere = SphereSpec(radius_um=5.0, n_sphere=1.59 + 0.0j, n_medium=1.55) + + sim_dxy_val = sim_dxy_um if sim_dxy_um is not None else sample_pixel_size_um( + camera_pixel_um, magnification, oversample=int(camera_oversample) + ) + sim_esize_val = sim_esize if sim_esize is not None else ( + int(esize_camera[0]) * int(camera_oversample), + int(esize_camera[1]) * int(camera_oversample), + ) + + sim_space = SimulationSpace( + dxy_um=float(sim_dxy_val), + esize=(int(sim_esize_val[0]), int(sim_esize_val[1])), + z_plane_um=float(z_plane_um), + ) + cam_space = CameraSpace( + pixel_um=float(camera_pixel_um), + magnification=float(magnification), + shape=(int(esize_camera[0]), int(esize_camera[1])), + photon_shot_noise=bool(camera_photon_shot_noise), + readout_noise_sds=camera_readout_noise_sds, + gains=camera_gains, + offsets=camera_offsets, + image_is_integer=bool(camera_image_is_integer), + psf=None, + apodization=1, + saturation=camera_saturation, + ) + + simulator = DPCMieSimulator( + wavelength_um=float(wavelength_um), + na_obj=float(na_obj), + sphere=sphere, + simulation=sim_space, + camera=cam_space, + led_grid_shape=led_grid_shape, + pitch_mm=float(pitch_mm), + inner_na=float(inner_na), + include_center_led=bool(include_center_led), + pattern_order=pattern_order, + normalize_by_led_count=bool(normalize_by_led_count), + 
led_subsample=int(led_subsample), + use_gpu=bool(use_gpu), + mie_kwargs=mie_kwargs, + cache_dir=cache_dir, + reuse_cache=bool(reuse_cache), + exposure_time_ms=float(exposure_time_ms), + illumination_photons_per_s_per_um2=float(illumination_photons_per_s_per_um2), + focal_stack_planes=focal_stack_planes, + focal_stack_step_um=float(focal_stack_step_um), + ) + + dpc, meta = simulator.simulate_patterns() + + # Compatibility metadata with previous helper + meta.update({ + "dxy_um": np.asarray(sim_dxy_val, dtype=np.float32), + "camera_oversample": int(simulator.camera_bin_factor), + "esize_camera": (int(esize_camera[0]), int(esize_camera[1])), + "esize_mie": (int(sim_space.esize[0]), int(sim_space.esize[1])), + }) + + # Crop 10% border from each side to keep central region + def _crop_center(arr: array) -> array: # type: ignore + if arr.ndim == 3: + _, ny, nx = arr.shape + elif arr.ndim == 4: + _, _, ny, nx = arr.shape + else: + return arr + ypad = max(0, int(round(0.1 * ny))) + xpad = max(0, int(round(0.1 * nx))) + y1, y2 = ypad, ny - ypad + x1, x2 = xpad, nx - xpad + if y2 <= y1 or x2 <= x1: + return arr + if arr.ndim == 3: + return arr[:, y1:y2, x1:x2] + else: + return arr[:, :, y1:y2, x1:x2] + + dpc = _crop_center(dpc) + if dpc.ndim == 3: + _, ny_crop, nx_crop = dpc.shape + else: + _, _, ny_crop, nx_crop = dpc.shape + meta["camera_shape"] = (int(ny_crop), int(nx_crop)) + meta["esize_camera"] = (int(ny_crop), int(nx_crop)) + + return dpc, meta + + +def write_dpc_zarr( + zarr_path: str | Path, + dpc: array, # type: ignore + meta: dict[str, array], # type: ignore + *, + wavelength_um: float, + na_obj: float, + camera_pixel_um: float, + magnification: float, + n_medium: float, + led_grid_shape: tuple[int, int], + pattern_order: Sequence[PatternName], + sphere: SphereSpec, + inner_na: float, + include_center_led: bool, + camera_gains: array | float = 1.0, # type: ignore + camera_offsets: array | float = 0.0, # type: ignore + camera_readout_noise_sds: array | float = 0.0, 
# type: ignore + camera_photon_shot_noise: bool = False, + camera_saturation: int | None = None, + camera_image_is_integer: bool = False, + exposure_time_ms: float = 1.0, + illumination_photons_per_s_per_um2: float = 1.0, + nz: int | None = None, + z_span_um: float | None = None, + overwrite: bool = True, +) -> None: + """ + Write a DPC synthetic dataset to Zarr in the layout expected by tests. + + Parameters + ---------- + zarr_path : str or Path + Output Zarr group path. + dpc : array + DPC stack (4, ny, nx) to write. + meta : dict[str, array] + Metadata returned from `simulate_dpc_images_sphere` (includes NA positions). + wavelength_um : float + Vacuum wavelength (µm) for root attrs. + na_obj : float + Objective NA for root attrs. + camera_pixel_um : float + Camera pixel size (µm) for root attrs. + magnification : float + System magnification for root attrs. + n_medium : float + Medium index for root attrs. + led_grid_shape : tuple[int, int] + LED grid shape stored in root attrs. + pattern_order : Sequence[str] + Pattern order stored in root attrs. + sphere : SphereSpec + Sphere parameters; radius/index stored as extra attrs. + inner_na : float + Inner NA stored as extra attr. + camera_gains : array or float, optional + Camera gain(s) stored in attrs. + camera_offsets : array or float, optional + Camera offset(s) stored in attrs. + camera_readout_noise_sds : array or float, optional + Camera readout noise sds stored in attrs. + camera_photon_shot_noise : bool, optional + Camera shot noise flag stored in attrs. + camera_saturation : int or None, optional + Camera saturation stored in attrs. + camera_image_is_integer : bool, optional + Camera integer flag stored in attrs. + exposure_time_ms : float, optional + Exposure time stored in attrs. + illumination_photons_per_s_per_um2 : float, optional + Illumination intensity stored in attrs. + include_center_led : bool + Whether center LED included; stored as extra attr. 
+ nz : int or None, optional + Optional z-stacks count stored in attrs. + z_span_um : float or None, optional + Optional z-span stored in attrs. + overwrite : bool, optional + If True, overwrite existing Zarr group. + + Returns + ------- + None + """ + try: + import zarr # type: ignore + except Exception as e: # pragma: no cover + raise ImportError("zarr is required to write datasets") from e + + zarr_path = str(zarr_path) + mode = "w" if overwrite else "w-" + g = zarr.open_group(zarr_path, mode=mode) + + # --- Root attributes (exact keys expected by tests) --- + g.attrs["wavelength_um"] = float(wavelength_um) + g.attrs["na_obj"] = float(na_obj) + g.attrs["camera_pixel_um"] = float(camera_pixel_um) + g.attrs["magnification"] = float(magnification) + g.attrs["n_medium"] = float(n_medium) + g.attrs["led_grid_shape"] = [int(led_grid_shape[0]), int(led_grid_shape[1])] + g.attrs["pattern_order"] = [str(p) for p in pattern_order] + g.attrs["pitch_mm"] = float(meta.get("pitch_mm", np.nan)) + # Effective Mie evaluation plane used for simulation (µm). 
+ if "z_plane_um" in meta: + g.attrs["field_generation_z_um"] = float(np.asarray(meta["z_plane_um"])) + + # Ensure CPU ndarray for Zarr + if cp is not None and isinstance(dpc, cp.ndarray): + dpc_np = cp.asnumpy(dpc).astype(np.float32, copy=False) + else: + dpc_np = np.asarray(dpc, dtype=np.float32) + + if nz is not None: + g.attrs["nz"] = int(nz) + elif dpc_np.ndim == 4: + g.attrs["nz"] = int(dpc_np.shape[0]) + if z_span_um is not None: + g.attrs["z_span_um"] = float(z_span_um) + + # Optional extras (not required by tests) + g.attrs["sphere_radius_um"] = float(sphere.radius_um) + g.attrs["sphere_n_sphere"] = complex(sphere.n_sphere).__repr__() + g.attrs["sphere_n_medium"] = float(sphere.n_medium) + g.attrs["inner_na"] = float(inner_na) + g.attrs["include_center_led"] = bool(include_center_led) + g.attrs["dxy_um"] = float(np.asarray(meta.get("dxy_um", camera_pixel_um / magnification))) + g.attrs["camera_gains"] = ( + float(camera_gains) if np.isscalar(camera_gains) else np.asarray(camera_gains).tolist() + ) + g.attrs["camera_offsets"] = ( + float(camera_offsets) if np.isscalar(camera_offsets) else np.asarray(camera_offsets).tolist() + ) + g.attrs["camera_readout_noise_sds"] = ( + float(camera_readout_noise_sds) + if np.isscalar(camera_readout_noise_sds) + else np.asarray(camera_readout_noise_sds).tolist() + ) + g.attrs["camera_photon_shot_noise"] = bool(camera_photon_shot_noise) + g.attrs["camera_saturation"] = camera_saturation if camera_saturation is None else int(camera_saturation) + g.attrs["camera_image_is_integer"] = bool(camera_image_is_integer) + g.attrs["exposure_time_ms"] = float(exposure_time_ms) + g.attrs["illumination_photons_per_s_per_um2"] = float(illumination_photons_per_s_per_um2) + + if dpc_np.ndim == 3: + if dpc_np.shape[0] != 4: + raise ValueError(f"dpc must have shape (4, ny, nx), got {dpc_np.shape}") + chunks = (1, min(256, dpc_np.shape[1]), min(256, dpc_np.shape[2])) + elif dpc_np.ndim == 4: + if dpc_np.shape[1] != 4: + raise ValueError(f"dpc 
must have shape (nz, 4, ny, nx), got {dpc_np.shape}") + chunks = (1, 1, min(256, dpc_np.shape[2]), min(256, dpc_np.shape[3])) + else: + raise ValueError(f"dpc must have ndim 3 or 4, got {dpc_np.ndim}") + _zarr_write_array(g, "dpc", dpc_np, dtype="float32", chunks=chunks, overwrite=True) + + mg = g.require_group("meta") + if "na_xy" in meta: + na_xy_np = np.asarray(meta["na_xy"], dtype=np.float32) + _zarr_write_array(mg, "na_xy", na_xy_np, dtype="float32", overwrite=True) + + for pname in ("left", "right", "up", "down"): + if pname in meta: + pts_np = np.asarray(meta[pname], dtype=np.float32) + _zarr_write_array(mg, pname, pts_np, dtype="float32", overwrite=True) + if "focal_offsets_um" in meta: + offsets_np = np.asarray(meta["focal_offsets_um"], dtype=np.float32) + _zarr_write_array(mg, "focal_offsets_um", offsets_np, dtype="float32", overwrite=True) + + +def simulate_dpc_images_sphere_to_zarr( + zarr_path: str | Path, + *, + wavelength_um: float = 0.515, + na_obj: float = 0.8, + led_grid_shape: tuple[int, int] = (64, 64), + pitch_mm: float = 2.5, + camera_pixel_um: float = 2.4, + magnification: float = 20.0, + camera_oversample: int = 1, + esize_camera: tuple[int, int] = (256, 256), + sim_dxy_um: float | None = None, + sim_esize: tuple[int, int] | None = None, + z_plane_um: float = 0.0, + sphere: SphereSpec | None = None, + inner_na: float = 0.0, + include_center_led: bool = False, + pattern_order: Sequence[PatternName] = ("left", "right", "up", "down"), + normalize_by_led_count: bool = True, + led_subsample: int = 1, + use_gpu: bool = False, + mie_kwargs: dict | None = None, + cache_dir: str | Path | None = None, + reuse_cache: bool = True, + nz: int | None = None, + z_span_um: float | None = None, + overwrite: bool = True, + camera_gains: array | float = 1.0, # type: ignore + camera_offsets: array | float = 0.0, # type: ignore + camera_readout_noise_sds: array | float = 0.0, # type: ignore + camera_photon_shot_noise: bool = False, + camera_saturation: int | None 
= None, + camera_image_is_integer: bool = False, + exposure_time_ms: float = 1.0, + illumination_photons_per_s_per_um2: float = 1.0, + focal_stack_planes: int | None = None, + focal_stack_step_um: float = 0.5, +) -> None: + """ + Generate DPC images and write them to Zarr. + + Parameters + ---------- + zarr_path : str or Path + Output Zarr group path. + wavelength_um : float, optional + Vacuum wavelength (µm) passed to simulation and stored in attrs. + na_obj : float, optional + Objective NA for pupil cutoff and attrs. + led_grid_shape : tuple[int, int], optional + LED board shape for pattern generation. + pitch_mm : float, optional + LED pitch on the board (mm), used to map board radius to illumination NA. + camera_pixel_um : float, optional + Camera pixel size (µm) for attrs and camera sampling. + magnification : float, optional + Object-to-camera magnification used to derive camera binning. + camera_oversample : int, optional + Oversampling factor; affects default simulation sampling/size. + esize_camera : tuple[int, int], optional + Final camera image shape (ny, nx). + sim_dxy_um : float or None, optional + Simulation sampling (µm); defaults to camera_pixel_um / magnification / camera_oversample. + sim_esize : tuple[int, int] or None, optional + Simulation grid size; defaults to esize_camera * camera_oversample. + z_plane_um : float, optional + Field evaluation plane for `mie_efield`; validated to be outside the sphere. + sphere : SphereSpec or None, optional + Sphere parameters; defaults to a preset if None. + inner_na : float, optional + Inner NA for annular LED mask. + include_center_led : bool, optional + Whether to include the center LED. + pattern_order : Sequence[str], optional + Order of DPC patterns in output. + normalize_by_led_count : bool, optional + If True, normalize each pattern by its LED count. + led_subsample : int, optional + Subsample factor for LEDs. + use_gpu : bool, optional + If True, use CuPy-backed ops when available. 
+ mie_kwargs : dict or None, optional + Extra kwargs passed to `mie_fields.mie_efield`. + cache_dir : str or Path or None, optional + Directory for Zarr per-LED camera irradiance cache. + reuse_cache : bool, optional + If True, reuse cached per-LED images. + camera_gains : array or float, optional + Gains (ADU/e) for `simulated_img`. + camera_offsets : array or float, optional + Offsets (ADU) for `simulated_img`. + camera_readout_noise_sds : array or float, optional + Readout noise SD (ADU) for `simulated_img`. + camera_photon_shot_noise : bool, optional + Enable Poisson shot noise in `simulated_img`. + camera_saturation : int or None, optional + Saturation level for `simulated_img`. + camera_image_is_integer : bool, optional + If True, round simulated images to integers in `simulated_img`. + exposure_time_ms : float, optional + Exposure time (milliseconds). Multiplies simulated irradiance before camera noise. + illumination_photons_per_s_per_um2 : float, optional + Incident photon flux density (photons/s/µm^2) used to scale irradiance to photons. + focal_stack_planes : int or None, optional + Number of focal planes to simulate. If None or <=1, a single plane is produced. + focal_stack_step_um : float, optional + Axial step (µm) between focal planes, centered on nominal focus. + nz : int or None, optional + Optional z-stack count stored in attrs. + z_span_um : float or None, optional + Optional z-span stored in attrs. + overwrite : bool, optional + If True, overwrite existing Zarr group. + + Returns + ------- + None + """ + if cache_dir is None: + # Default cache directory adjacent to the output Zarr store. 
+ zp = Path(zarr_path) + cache_dir = zp.with_name(zp.name + ".mie_cache") + + dpc, meta = simulate_dpc_images_sphere( + wavelength_um=wavelength_um, + na_obj=na_obj, + led_grid_shape=led_grid_shape, + pitch_mm=pitch_mm, + camera_pixel_um=camera_pixel_um, + magnification=magnification, + camera_oversample=camera_oversample, + esize_camera=esize_camera, + sim_dxy_um=sim_dxy_um, + sim_esize=sim_esize, + z_plane_um=z_plane_um, + sphere=sphere, + inner_na=inner_na, + include_center_led=include_center_led, + pattern_order=pattern_order, + normalize_by_led_count=normalize_by_led_count, + led_subsample=led_subsample, + use_gpu=use_gpu, + mie_kwargs=mie_kwargs, + cache_dir=cache_dir, + reuse_cache=reuse_cache, + camera_gains=camera_gains, + camera_offsets=camera_offsets, + camera_readout_noise_sds=camera_readout_noise_sds, + camera_photon_shot_noise=camera_photon_shot_noise, + camera_saturation=camera_saturation, + camera_image_is_integer=camera_image_is_integer, + exposure_time_ms=exposure_time_ms, + illumination_photons_per_s_per_um2=illumination_photons_per_s_per_um2, + focal_stack_planes=focal_stack_planes, + focal_stack_step_um=focal_stack_step_um, + ) + + sphere_used = sphere if sphere is not None else SphereSpec(radius_um=5.0, n_sphere=1.59, n_medium=1.515) + + write_dpc_zarr( + zarr_path, + dpc, + meta, + wavelength_um=wavelength_um, + na_obj=na_obj, + camera_pixel_um=camera_pixel_um, + magnification=magnification, + n_medium=float(sphere_used.n_medium), + led_grid_shape=led_grid_shape, + pattern_order=pattern_order, + sphere=sphere_used, + inner_na=inner_na, + include_center_led=include_center_led, + camera_gains=camera_gains, + camera_offsets=camera_offsets, + camera_readout_noise_sds=camera_readout_noise_sds, + camera_photon_shot_noise=camera_photon_shot_noise, + camera_saturation=camera_saturation, + camera_image_is_integer=camera_image_is_integer, + exposure_time_ms=exposure_time_ms, + illumination_photons_per_s_per_um2=illumination_photons_per_s_per_um2, + 
nz=nz, + z_span_um=z_span_um, + overwrite=overwrite, + ) + +if __name__ == "__main__": # pragma: no cover + # Example usage: generate a DPC dataset and write to Zarr. + output_path = Path("/media/dps/data/synthetic_dpc_sphere.zarr") + + simulate_dpc_images_sphere_to_zarr( + output_path, + wavelength_um=0.515, + na_obj=0.8, + led_grid_shape=(64, 64), + camera_pixel_um=2.4, + magnification=20.0, + camera_oversample=2, + esize_camera=(512, 512), + z_plane_um=0.0, + sphere=SphereSpec(radius_um=2.5, n_sphere=1.59, n_medium=1.515), + inner_na=0.0, + include_center_led=True, + pattern_order=("left", "right", "up", "down"), + normalize_by_led_count=True, + led_subsample=1, + use_gpu=True, + mie_kwargs=None, + camera_gains=1.0, + camera_offsets=100.0, + camera_readout_noise_sds=3.31, + camera_photon_shot_noise=True, + camera_saturation=None, + camera_image_is_integer=True, + exposure_time_ms=100.0, + illumination_photons_per_s_per_um2=1e6, + focal_stack_planes=41, + focal_stack_step_um=.325, + overwrite=True, + ) + print(f"Synthetic DPC dataset written to {output_path}") diff --git a/mcsim/analysis/dpc_fista_solver.py b/mcsim/analysis/dpc_fista_solver.py new file mode 100644 index 0000000..7a28d57 --- /dev/null +++ b/mcsim/analysis/dpc_fista_solver.py @@ -0,0 +1,99 @@ +""" +Minimal LED board geometry helpers for DPC simulations. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass(frozen=True) +class LEDBoard: + """ + LED board definition for NA-space mapping. + + :param n_side: Number of LEDs per side (square board). + :param pitch_mm: LED pitch in mm. + :param na_obj: Objective NA. + :param wavelength_um: Wavelength in um. + :param n_medium: Medium refractive index. + """ + + n_side: int + pitch_mm: float + na_obj: float + wavelength_um: float + n_medium: float + + +@dataclass(frozen=True) +class LEDGeometry: + """ + Geometry container for LED NA positions. 
+ + :param na_components: NA positions, shape (N, 2). + """ + + na_components: np.ndarray + + +def _make_led_na_positions( + ny_led: int, + nx_led: int, + *, + na_obj: float, + inner_na: float = 0.0, + include_center: bool = True, +) -> np.ndarray: + if ny_led < 1 or nx_led < 1: + raise ValueError("ny_led and nx_led must be >= 1") + if inner_na < 0 or inner_na >= na_obj: + raise ValueError("inner_na must satisfy 0 <= inner_na < na_obj") + + cy = (ny_led - 1) / 2.0 + cx = (nx_led - 1) / 2.0 + + yy, xx = np.meshgrid( + np.arange(ny_led, dtype=np.float32), + np.arange(nx_led, dtype=np.float32), + indexing="ij", + ) + x = xx - cx + y = yy - cy + r = np.sqrt(x * x + y * y) + + r_circle = float(min(cx, cy)) if (cx > 0 and cy > 0) else float(np.max(r)) + mask = r <= r_circle + + if not include_center: + mask = mask & (r > 0) + + r_mask_max = float(np.max(r[mask])) if np.any(mask) else 1.0 + scale = float(na_obj) / r_mask_max + + na_x = x[mask] * scale + na_y = y[mask] * scale + na_r = np.sqrt(na_x * na_x + na_y * na_y) + + ann = na_r >= float(inner_na) + na_xy = np.stack([na_x[ann], na_y[ann]], axis=1) + return na_xy + + +def compute_led_geometry(board: LEDBoard) -> LEDGeometry: + """ + Compute LED NA positions for a square board. + + :param board: LED board definition. + :return: LEDGeometry with NA components. + """ + na_xy = _make_led_na_positions( + int(board.n_side), + int(board.n_side), + na_obj=float(board.na_obj), + inner_na=0.0, + include_center=True, + ) + return LEDGeometry(na_components=na_xy) diff --git a/mcsim/analysis/dpc_meta.py b/mcsim/analysis/dpc_meta.py new file mode 100644 index 0000000..c0983b1 --- /dev/null +++ b/mcsim/analysis/dpc_meta.py @@ -0,0 +1,89 @@ +""" +Metadata containers for DPC inverse solvers. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence + +import numpy as np + + +@dataclass(frozen=True) +class DPCMeta: + """ + Metadata required for DPC reconstruction. 
+ + :param wavelength_um: Vacuum wavelength (um). + :param n_background: Background refractive index. + :param NA_obj: Objective NA (also used for detection NA). + :param magnification: Objective magnification. + :param camera_pixel_pitch_um: Camera pixel pitch (um). + :param volume_shape_zyx: Reconstruction volume shape (nz, ny, nx). + :param voxel_size_um_zyx: Voxel size (dz, dy, dx) in um. + :param z_planes_um: Object-space z planes in um, length nz. + :param led_grid_shape: LED board grid shape (ny, nx). + :param inner_na: Inner NA for annular LED mask. + :param include_center_led: Include NA=0 LED if True. + :param pattern_order: Pattern order matching measured data. + """ + + wavelength_um: float + n_background: float + NA_obj: float + magnification: float + camera_pixel_pitch_um: float + volume_shape_zyx: tuple[int, int, int] + voxel_size_um_zyx: tuple[float, float, float] + z_planes_um: Sequence[float] + led_grid_shape: tuple[int, int] = (64, 64) + inner_na: float = 0.0 + include_center_led: bool = False + pattern_order: tuple[str, str, str, str] = ("left", "right", "up", "down") + + def __post_init__(self) -> None: + z_planes = np.asarray(self.z_planes_um, dtype=float) + object.__setattr__(self, "z_planes_um", z_planes) + if z_planes.ndim != 1: + raise ValueError("z_planes_um must be 1D") + if len(self.volume_shape_zyx) != 3: + raise ValueError("volume_shape_zyx must have length 3") + if len(self.voxel_size_um_zyx) != 3: + raise ValueError("voxel_size_um_zyx must have length 3") + if z_planes.size != int(self.volume_shape_zyx[0]): + raise ValueError("z_planes_um length must match volume_shape_zyx[0]") + allowed = ("left", "right", "up", "down") + if len(self.pattern_order) != 4 or set(self.pattern_order) != set(allowed): + raise ValueError(f"pattern_order must be a permutation of {allowed}") + + @property + def dxy_um(self) -> float: + """ + Object-space pixel size derived from the camera pitch and magnification. + + :return: Pixel size in um. 
+ """ + return float(self.camera_pixel_pitch_um) / float(self.magnification) + + def as_dict(self) -> dict[str, object]: + """ + Export metadata as a serializable dictionary. + + :return: Metadata dictionary. + """ + z_planes = np.asarray(self.z_planes_um, dtype=float).tolist() + return { + "wavelength_um": float(self.wavelength_um), + "n_background": float(self.n_background), + "NA_obj": float(self.NA_obj), + "magnification": float(self.magnification), + "camera_pixel_pitch_um": float(self.camera_pixel_pitch_um), + "volume_shape_zyx": tuple(int(v) for v in self.volume_shape_zyx), + "voxel_size_um_zyx": tuple(float(v) for v in self.voxel_size_um_zyx), + "z_planes_um": z_planes, + "led_grid_shape": tuple(int(v) for v in self.led_grid_shape), + "inner_na": float(self.inner_na), + "include_center_led": bool(self.include_center_led), + "pattern_order": tuple(self.pattern_order), + } diff --git a/mcsim/analysis/optimize.py b/mcsim/analysis/optimize.py index 4d2a54f..187bbcb 100644 --- a/mcsim/analysis/optimize.py +++ b/mcsim/analysis/optimize.py @@ -285,6 +285,7 @@ def run(self, gtol: float = 0.0, print_newline: bool = False, label: str = "", + iteration_callback=None, **kwargs) -> dict: """ @@ -536,6 +537,9 @@ def run(self, if stop: break + if iteration_callback is not None: + iteration_callback(ii, x) + # compute final cost if compute_cost: if compute_all_costs: diff --git a/mcsim/analysis/tv_prox_fast.py b/mcsim/analysis/tv_prox_fast.py new file mode 100644 index 0000000..1918b0d --- /dev/null +++ b/mcsim/analysis/tv_prox_fast.py @@ -0,0 +1,402 @@ +""" +Fast GPU-only 3D TV proximal operator (Chambolle dual projection) implemented +with CuPy + CUDA RawKernel. 
+ +Import into optimize.py, e.g.: + + from tv_prox_fast import tv_prox_fast + +Design constraints: +- 3D only +- GPU/CuPy only (raises if not CuPy array) +- No positivity constraint +- Fixed-iteration behavior consistent with the uploaded prox_fgp (default num_iter=10) +""" + +from functools import lru_cache +from typing import Any, Optional, Tuple + +import cupy as cp + +__all__ = ["tv_prox_fast"] + + +@lru_cache(maxsize=None) +def _tv_fast_kernels( + dtype_str: str, +) -> Tuple[cp.RawKernel, cp.RawKernel, cp.RawKernel, cp.RawKernel]: + """ + Compile and cache RawKernels used by tv_prox_fast(). + + :param dtype_str: Data type string ("float32" or "float64"). + :return: Tuple of (discontig_sub_kernel, tv_norm_kernel). + """ + if dtype_str == "float32": + ctype = "float" + sqrt_fn = "sqrtf" + elif dtype_str == "float64": + ctype = "double" + sqrt_fn = "sqrt" + else: + raise ValueError(f"Unsupported dtype for tv_prox_fast: {dtype_str}") + + code = rf""" + extern "C" __global__ + void discontig_sub( + const {ctype}* __restrict__ arr, + {ctype}* __restrict__ out, + unsigned long long step, + unsigned long long length, + int transpose, + unsigned long long n_total + ) {{ + unsigned long long i = (unsigned long long)blockDim.x * (unsigned long long)blockIdx.x + + (unsigned long long)threadIdx.x; + if (i >= n_total) return; + + unsigned long long pos = i % length; + if (!transpose) {{ + if (pos + step < length) {{ + out[i] = arr[i + step] - arr[i]; + }} else {{ + out[i] = ({ctype})0; + }} + }} else {{ + out[i] -= arr[i]; + if (pos + step < length) {{ + out[i + step] += arr[i]; + }} + }} + }} + + extern "C" __global__ + void tv_norm3( + {ctype}* __restrict__ out, // length n + const {ctype}* __restrict__ tv, // length 3*n, packed [x|y|z] + unsigned long long n + ) {{ + unsigned long long i = (unsigned long long)blockDim.x * (unsigned long long)blockIdx.x + + (unsigned long long)threadIdx.x; + if (i >= n) return; + + {ctype} a = tv[i]; + {ctype} b = tv[i + n]; + {ctype} c = 
tv[i + 2*n]; + + out[i] = {sqrt_fn}(a*a + b*b + c*c); + }} + """ + + code_grad = rf""" + extern "C" __global__ + void tv_grad_update( + const {ctype}* __restrict__ out, + {ctype}* __restrict__ p, + {ctype}* __restrict__ norms, + {ctype} step, + {ctype} weight, + {ctype} w_z, + {ctype} w_y, + {ctype} w_x, + unsigned long long nz, + unsigned long long ny, + unsigned long long nx, + unsigned long long n_total, + int write_norms + ) {{ + unsigned long long i = (unsigned long long)blockDim.x * (unsigned long long)blockIdx.x + + (unsigned long long)threadIdx.x; + if (i >= n_total) return; + + unsigned long long plane = ny * nx; + unsigned long long z = i / plane; + unsigned long long rem = i - z * plane; + unsigned long long y = rem / nx; + unsigned long long x = rem - y * nx; + + {ctype} g0 = ({ctype})0; + {ctype} g1 = ({ctype})0; + {ctype} g2 = ({ctype})0; + if (z + 1 < nz) {{ + g0 = (out[i + plane] - out[i]) * w_z; + }} + if (y + 1 < ny) {{ + g1 = (out[i + nx] - out[i]) * w_y; + }} + if (x + 1 < nx) {{ + g2 = (out[i + 1] - out[i]) * w_x; + }} + + {ctype} norm = {sqrt_fn}(g0 * g0 + g1 * g1 + g2 * g2); + {ctype} denom = ({ctype})1 + (step / weight) * norm; + + {ctype} p0 = p[i]; + {ctype} p1 = p[i + n_total]; + {ctype} p2 = p[i + 2 * n_total]; + + p[i] = (p0 - step * g0) / denom; + p[i + n_total] = (p1 - step * g1) / denom; + p[i + 2 * n_total] = (p2 - step * g2) / denom; + if (write_norms) {{ + norms[i] = norm; + }} + }} + """ + + code_div = rf""" + extern "C" __global__ + void tv_divergence( + const {ctype}* __restrict__ p, + const {ctype}* __restrict__ x, + {ctype}* __restrict__ out, + {ctype} w_z, + {ctype} w_y, + {ctype} w_x, + unsigned long long nz, + unsigned long long ny, + unsigned long long nx, + unsigned long long n_total + ) {{ + unsigned long long i = (unsigned long long)blockDim.x * (unsigned long long)blockIdx.x + + (unsigned long long)threadIdx.x; + if (i >= n_total) return; + + unsigned long long plane = ny * nx; + unsigned long long z = i / plane; + 
unsigned long long rem = i - z * plane; + unsigned long long y = rem / nx; + unsigned long long xind = rem - y * nx; + + {ctype} p0 = p[i]; + {ctype} p1 = p[i + n_total]; + {ctype} p2 = p[i + 2 * n_total]; + + {ctype} val = -(w_z * p0 + w_y * p1 + w_x * p2); + if (z > 0) {{ + val += w_z * p[i - plane]; + }} + if (y > 0) {{ + val += w_y * p[i + n_total - nx]; + }} + if (xind > 0) {{ + val += w_x * p[i + 2 * n_total - 1]; + }} + out[i] = x[i] + val; + }} + """ + + return ( + cp.RawKernel(code, "discontig_sub"), + cp.RawKernel(code, "tv_norm3"), + cp.RawKernel(code_div, "tv_divergence"), + cp.RawKernel(code_grad, "tv_grad_update"), + ) + + +def _launch_1d(kernel: cp.RawKernel, n: int, args: Tuple[Any, ...]) -> None: + """ + Launch a 1D CUDA kernel with a fixed thread block size. + + :param kernel: RawKernel to launch. + :param n: Total number of elements. + :param args: Kernel argument tuple. + :return: None. + """ + threads = 256 + blocks = (n + threads - 1) // threads + kernel((blocks,), (threads,), args) + + +def _discontig_sub_cupy( + arr: cp.ndarray, + out: cp.ndarray, + axis: int, + transpose: bool = False, +) -> cp.ndarray: + """ + Flattened-addressing difference/adjoint operator, equivalent to fista.py discontig_sub. + + :param arr: Input array. + :param out: Output array (overwritten for forward, accumulated for adjoint). + :param axis: Axis along which to compute the flattened-order difference. + :param transpose: False for forward, True for adjoint. + :return: Output array. 
+ """ + if arr.dtype != out.dtype: + raise ValueError("arr and out must have the same dtype") + + if not arr.flags.c_contiguous: + arr = cp.ascontiguousarray(arr) + if not out.flags.c_contiguous: + raise ValueError("out must be C-contiguous") + + axis = axis % arr.ndim + shape = arr.shape + + step = 1 + for s in shape[axis + 1 :]: + step *= int(s) + + length = step * int(shape[axis]) + n_total = int(arr.size) + + discontig, _, _, _ = _tv_fast_kernels(str(arr.dtype)) + _launch_1d( + discontig, + n_total, + ( + arr, + out, + cp.uint64(step), + cp.uint64(length), + cp.int32(1 if transpose else 0), + cp.uint64(n_total), + ), + ) + return out + + +def tv_prox_fast( + x: cp.ndarray, + tau: float, + num_iter: int = 10, + eps: float = 0.0, + voxel_size_zyx: Optional[Tuple[float, float, float]] = None, + weight_scale_zyx: Optional[Tuple[float, float, float]] = None, + out: Optional[cp.ndarray] = None, +) -> cp.ndarray: + """ + Fast GPU-only TV proximal operator for 3D arrays using a Chambolle dual-projection loop. + + This implements a proximal map: + + .. math:: + \\text{prox}_{\\tau \\, TV}(x) = \\arg\\min_y \\; 0.5\\|y - x\\|_2^2 + \\tau \\, TV(y) + + matching skimage/cucim's Chambolle formulation (no positivity projection). + + :param x: CuPy ndarray, 3D. + :param tau: TV proximal weight (must be >= 0). tau==0 returns identity. + :param num_iter: Fixed number of FGP iterations. + :param eps: Relative stopping tolerance, matching cucim TV when > 0. + :param voxel_size_zyx: Physical voxel size (dz, dy, dx). Defaults to (1, 1, 1). + :param weight_scale_zyx: Multiplicative scaling for (dz, dy, dx) weights. + :param out: Optional CuPy array to write the result into (must match shape/dtype). + :return: Proximal output (same shape/dtype as x, with float32/float64 enforced). 
+ """ + if not isinstance(x, cp.ndarray): + raise TypeError("tv_prox_fast is GPU-only: x must be a CuPy ndarray.") + if x.ndim != 3: + raise ValueError(f"tv_prox_fast is 3D-only; got x.shape={x.shape}") + if tau < 0: + raise ValueError("tau must be >= 0") + + if tau == 0.0: + if out is None: + return x + out[...] = x + return out + + if not x.flags.c_contiguous: + x = cp.ascontiguousarray(x) + + if x.dtype not in (cp.float32, cp.float64): + x = x.astype(cp.float32, copy=False) + + dtype = x.dtype + if voxel_size_zyx is None: + voxel_size_zyx = (1.0, 1.0, 1.0) + if len(voxel_size_zyx) != 3: + raise ValueError("voxel_size_zyx must be a 3-tuple (dz, dy, dx).") + dz, dy, dx = (float(v) for v in voxel_size_zyx) + if dz <= 0 or dy <= 0 or dx <= 0: + raise ValueError("voxel_size_zyx entries must be > 0.") + if weight_scale_zyx is None: + weight_scale_zyx = (1.0, 1.0, 1.0) + if len(weight_scale_zyx) != 3: + raise ValueError("weight_scale_zyx must be a 3-tuple (sz, sy, sx).") + sz, sy, sx = (float(v) for v in weight_scale_zyx) + if sz <= 0 or sy <= 0 or sx <= 0: + raise ValueError("weight_scale_zyx entries must be > 0.") + w_z = 1.0 / dz + w_y = 1.0 / dy + w_x = 1.0 / dx + w_z *= sz + w_y *= sy + w_x *= sx + step = 1.0 / (2.0 * (w_z * w_z + w_y * w_y + w_x * w_x)) + + p = cp.zeros((3,) + x.shape, dtype=dtype) + norms = cp.empty_like(x) if eps > 0 else None + proj = cp.empty_like(x) if out is None else out + + _, _, div_kernel, grad_kernel = _tv_fast_kernels(str(dtype)) + n_vox = int(x.size) + nz, ny, nx = (int(s) for s in x.shape) + step_val = dtype.type(step) + weight_val = dtype.type(tau) + wz_val = dtype.type(w_z) + wy_val = dtype.type(w_y) + wx_val = dtype.type(w_x) + write_norms = cp.int32(1 if eps > 0 else 0) + eps_val = float(eps) + e_init = None + e_prev = None + + for ii in range(int(num_iter)): + if ii > 0: + _launch_1d( + div_kernel, + n_vox, + ( + p, + x, + proj, + wz_val, + wy_val, + wx_val, + cp.uint64(nz), + cp.uint64(ny), + cp.uint64(nx), + cp.uint64(n_vox), 
+ ), + ) + else: + proj[...] = x + + _launch_1d( + grad_kernel, + n_vox, + ( + proj, + p, + norms if norms is not None else proj, + step_val, + weight_val, + wz_val, + wy_val, + wx_val, + cp.uint64(nz), + cp.uint64(ny), + cp.uint64(nx), + cp.uint64(n_vox), + write_norms, + ), + ) + + if eps_val > 0.0: + d = proj - x + e = cp.sum(d * d) + e += weight_val * cp.sum(norms) + e /= float(n_vox) + e_host = float(e) + if ii == 0: + e_init = e_host + e_prev = e_host + else: + if e_init is not None and e_prev is not None: + if abs(e_prev - e_host) < eps_val * e_init: + break + e_prev = e_host + + return proj diff --git a/mcsim/analysis/wotf_fista.py b/mcsim/analysis/wotf_fista.py new file mode 100644 index 0000000..0be485f --- /dev/null +++ b/mcsim/analysis/wotf_fista.py @@ -0,0 +1,591 @@ +""" +WOTF-based FISTA optimizer implemented without legacy inverse dependencies. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Sequence, Union, Any + +import numpy as np + +try: + import cupy as cp +except Exception: + cp = None + +from mcsim.analysis.fft import ft2, ift2, ft3, ift3, ftn +from mcsim.analysis.field_prop import get_v +from mcsim.analysis.optimize import Optimizer, to_cpu +from mcsim.analysis.tv_prox_fast import tv_prox_fast + + +if cp: + array = Union[np.ndarray, cp.ndarray] +else: + array = np.ndarray + + +@dataclass(frozen=True) +class WOTFParams: + wavelength_um: float + na_obj: float + na_in: float + n0: float + dxy_um: float + dz_um: float + nz: int + ny: int + nx: int + pupil_taper_na: float = 0.0 + + +def camera_to_photons( + I_cam: array, + *, + camera_offset_adu: Union[float, array], + camera_gain_photons_per_adu: Union[float, array], +) -> array: + """ + Convert camera units to photons using a linear model. 
+ """ + xp = cp if cp and isinstance(I_cam, cp.ndarray) else np + offset = xp.asarray(camera_offset_adu, dtype=xp.float32) + gain = xp.asarray(camera_gain_photons_per_adu, dtype=xp.float32) + I_phot = (I_cam - offset) * gain + return xp.maximum(I_phot, xp.float32(0.0)) + + +def _make_led_na_positions( + ny_led: int, + nx_led: int, + *, + na_obj: float, + na_in: float, + include_center: bool, +) -> np.ndarray: + yy, xx = np.meshgrid(np.arange(ny_led), np.arange(nx_led), indexing="ij") + yy = yy.astype(np.float32) + xx = xx.astype(np.float32) + cy = (ny_led - 1) / 2.0 + cx = (nx_led - 1) / 2.0 + x = xx - cx + y = yy - cy + r = np.sqrt(x * x + y * y) + r_circle = float(min(cx, cy)) if (cx > 0 and cy > 0) else float(np.max(r)) + mask = r <= r_circle + if not include_center: + mask = mask & (r > 0) + if not np.any(mask): + raise ValueError("LED grid mask is empty; check LED geometry.") + r_mask_max = float(np.max(r[mask])) + scale = float(na_obj) / r_mask_max + na_x = x[mask] * scale + na_y = y[mask] * scale + na_r = np.sqrt(na_x * na_x + na_y * na_y) + if float(na_in) > 0: + keep = na_r >= float(na_in) + else: + keep = np.ones_like(na_r, dtype=bool) + if not include_center: + keep &= na_r > 0 + na_x = na_x[keep] + na_y = na_y[keep] + na_xy = np.stack([na_x, na_y], axis=1) + return na_xy.astype(np.float32, copy=False) + + +def build_led_na_grid( + ny_led: int, + nx_led: int, + *, + na_obj: float, + na_in: float, + include_center: bool, + led_subsample: int = 1, +) -> np.ndarray: + """ + Build an (N, 2) list of LED NA positions on the LED grid. + + led_subsample is the subsample factor for the LED grid. 
+ """ + if led_subsample < 1: + raise ValueError("led_subsample must be >= 1.") + if led_subsample == 1: + return _make_led_na_positions( + ny_led, + nx_led, + na_obj=na_obj, + na_in=na_in, + include_center=include_center, + ) + cy = (ny_led - 1) / 2.0 + cx = (nx_led - 1) / 2.0 + step_y = int(max(1, np.floor(np.sqrt(float(led_subsample))))) + step_x = int(max(1, np.ceil(float(led_subsample) / float(step_y)))) + + yy_idx, xx_idx = np.meshgrid( + np.arange(ny_led, dtype=int), + np.arange(nx_led, dtype=int), + indexing="ij", + ) + yy = yy_idx.astype(np.float32) + xx = xx_idx.astype(np.float32) + x = xx - cx + y = yy - cy + r = np.sqrt(x * x + y * y) + r_circle = float(min(cx, cy)) if (cx > 0 and cy > 0) else float(np.max(r)) + mask_circle = r <= r_circle + if not include_center: + mask_circle = mask_circle & (r > 0) + if not np.any(mask_circle): + raise ValueError("LED grid mask is empty; check LED geometry.") + r_mask_max = float(np.max(r[mask_circle])) + scale = float(na_obj) / r_mask_max + + best_key = None + best_offsets = None + for off_y in range(step_y): + mask_y = (yy_idx - off_y) % step_y == 0 + for off_x in range(step_x): + mask_sub = mask_y & ((xx_idx - off_x) % step_x == 0) + mask = mask_circle & mask_sub + if not np.any(mask): + continue + na_x = x[mask] * scale + na_y = y[mask] * scale + na_r = np.sqrt(na_x * na_x + na_y * na_y) + if float(na_in) > 0: + keep = na_r >= float(na_in) + else: + keep = np.ones_like(na_r, dtype=bool) + if not include_center: + keep &= na_r > 0 + if not np.any(keep): + continue + na_x = na_x[keep] + na_y = na_y[keep] + left = int(np.sum(na_x < 0)) + right = int(np.sum(na_x > 0)) + up = int(np.sum(na_y > 0)) + down = int(np.sum(na_y < 0)) + if include_center: + center = int(np.sum((na_x == 0) & (na_y == 0))) + left += center + right += center + up += center + down += center + min_count = min(left, right, up, down) + total = int(na_x.size) + key = (min_count, total) + if best_key is None or key > best_key: + best_key = key + 
best_offsets = (off_y, off_x) + + if best_offsets is None: + raise ValueError("Subsampling removed all LEDs from the board.") + + off_y, off_x = best_offsets + mask_sub = ((yy_idx - off_y) % step_y == 0) & ((xx_idx - off_x) % step_x == 0) + mask = mask_circle & mask_sub + na_x = x[mask] * scale + na_y = y[mask] * scale + na_r = np.sqrt(na_x * na_x + na_y * na_y) + if float(na_in) > 0: + keep = na_r >= float(na_in) + else: + keep = np.ones_like(na_r, dtype=bool) + if not include_center: + keep &= na_r > 0 + na_x = na_x[keep] + na_y = na_y[keep] + na_xy = np.stack([na_x, na_y], axis=1) + if na_xy.size == 0: + raise ValueError("Subsampling removed all LEDs after inner NA filter.") + return na_xy.astype(np.float32, copy=False) + + +def build_led_pattern_membership( + led_na_xy: np.ndarray, + *, + order: Sequence[str], + include_center: bool, +) -> np.ndarray: + """ + Assign each LED to one of four patterns: left/right/up/down. + """ + if len(order) != 4: + raise ValueError("order must have 4 entries.") + na = np.asarray(led_na_xy, dtype=float) + membership = np.zeros((na.shape[0], 4), dtype=bool) + for ii, name in enumerate(order): + if name == "left": + membership[:, ii] = na[:, 0] < 0 + elif name == "right": + membership[:, ii] = na[:, 0] > 0 + elif name == "up": + membership[:, ii] = na[:, 1] > 0 + elif name == "down": + membership[:, ii] = na[:, 1] < 0 + else: + raise ValueError(f"Unknown pattern name: {name}") + if include_center: + center = np.all(na == 0, axis=1) + if np.any(center): + membership[center] = True + return membership + + +def _tukey_window(n: int, taper: int, xp: Any) -> Optional[array]: + if taper <= 0 or n <= 0: + return None + taper = min(int(taper), n // 2) + if taper <= 0: + return None + alpha = min(1.0, 2.0 * float(taper) / float(n)) + if xp is np: + try: + from scipy.signal.windows import tukey # type: ignore + + return tukey(n, alpha=alpha).astype(np.float32, copy=False) + except Exception: + pass + if cp and xp is cp: + try: + from 
cupyx.scipy.signal.windows import tukey # type: ignore + + return tukey(n, alpha=alpha).astype(cp.float32, copy=False) + except Exception: + pass + idx = xp.arange(n, dtype=xp.float32) + w = xp.ones(n, dtype=xp.float32) + edge = float(taper) + left = idx < edge + if xp.any(left): + w[left] = 0.5 * (1.0 - xp.cos(np.pi * idx[left] / edge)) + right = idx >= (n - edge) + if xp.any(right): + tail = idx[right] - (n - edge) + w[right] = 0.5 * (1.0 - xp.cos(np.pi * (edge - tail) / edge)) + return w + + + +def _source_flip_unshifted(source: np.ndarray) -> np.ndarray: + source_flip = np.fft.fftshift(source) + source_flip = source_flip[::-1, ::-1] + if source_flip.shape[0] % 2 == 0: + source_flip = np.roll(source_flip, 1, axis=0) + if source_flip.shape[1] % 2 == 0: + source_flip = np.roll(source_flip, 1, axis=1) + return np.fft.ifftshift(source_flip) + + +def _build_illumination_patterns_unshifted( + led_na_xy: np.ndarray, + led_pattern_membership: np.ndarray, + *, + ny: int, + nx: int, + dxy_um: float, + wavelength_um: float, + na_obj: float, +) -> np.ndarray: + na_xy = np.asarray(led_na_xy, dtype=float) + membership = np.asarray(led_pattern_membership, dtype=bool) + fx = np.fft.fftfreq(int(nx), float(dxy_um)) + fy = np.fft.fftfreq(int(ny), float(dxy_um)) + fmax = float(na_obj) / float(wavelength_um) + fx_led = na_xy[:, 0] / float(wavelength_um) + fy_led = na_xy[:, 1] / float(wavelength_um) + fx_grid, fy_grid = np.meshgrid(fx, fy, indexing="xy") + mask = (fx_grid * fx_grid + fy_grid * fy_grid) <= (fmax * fmax + 1.0e-12) + + patterns = np.zeros((membership.shape[1], int(ny), int(nx)), dtype=float) + for led_idx in range(na_xy.shape[0]): + dist2 = (fx_grid - fx_led[led_idx]) ** 2 + (fy_grid - fy_led[led_idx]) ** 2 + dist2 = np.where(mask, dist2, np.inf) + flat = int(np.argmin(dist2)) + iy, ix = np.unravel_index(flat, dist2.shape) + for pid in range(membership.shape[1]): + if membership[led_idx, pid]: + patterns[pid, iy, ix] += 1.0 + return patterns + + +def 
build_wotf_transfer( + led_na_xy: np.ndarray, + led_pattern_membership: np.ndarray, + params: WOTFParams, + *, + real_dtype: type[np.floating] = np.float32, + complex_dtype: type[np.complexfloating] = np.complex64, +) -> tuple[np.ndarray, np.ndarray]: + """ + Build WOTF transfer functions for each illumination pattern. + """ + fx = np.fft.fftfreq(int(params.nx), float(params.dxy_um)).astype(real_dtype, copy=False) + fy = np.fft.fftfreq(int(params.ny), float(params.dxy_um)).astype(real_dtype, copy=False) + fx_grid, fy_grid = np.meshgrid(fx, fy, indexing="xy") + na_obj = float(params.na_obj) + na_in = float(params.na_in) + wavelength_um = float(params.wavelength_um) + na_r = np.sqrt(fx_grid * fx_grid + fy_grid * fy_grid) * wavelength_um + pupil = (na_r <= na_obj).astype(real_dtype) + if na_in != 0.0: + pupil[na_r < na_in] = 0.0 + taper_na = max(float(params.pupil_taper_na), 0.0) + if taper_na > 0.0: + t_outer = np.clip((na_obj - na_r) / taper_na, 0.0, 1.0) + outer_weight = 0.5 - 0.5 * np.cos(np.pi * t_outer) + if na_in > 0.0: + t_inner = np.clip((na_r - na_in) / taper_na, 0.0, 1.0) + inner_weight = 0.5 - 0.5 * np.cos(np.pi * t_inner) + else: + inner_weight = 1.0 + pupil *= outer_weight * inner_weight + + term_defocus = (1.0 / wavelength_um) ** 2 - fx_grid * fx_grid - fy_grid * fy_grid + term_oblique = (float(params.n0) / wavelength_um) ** 2 - fx_grid * fx_grid - fy_grid * fy_grid + phase_defocus = pupil * (2.0 * np.pi) * np.sqrt(np.maximum(term_defocus, 0.0)) + oblique_factor = pupil / (4.0 * np.pi * np.sqrt(np.maximum(term_oblique, 1.0e-12))) + + z_lin = np.fft.ifftshift((np.arange(int(params.nz)) - int(params.nz) // 2) * float(params.dz_um)).astype( + real_dtype, copy=False + ) + prop_kernel = np.exp(1.0j * z_lin[None, None, :] * phase_defocus[:, :, None]).astype( + complex_dtype, copy=False + ) + window_z = np.fft.ifftshift(np.hamming(int(params.nz))).astype(real_dtype, copy=False) + + patterns = _build_illumination_patterns_unshifted( + led_na_xy, + 
        led_pattern_membership,
        ny=int(params.ny),
        nx=int(params.nx),
        dxy_um=float(params.dxy_um),
        wavelength_um=float(params.wavelength_um),
        na_obj=float(params.na_obj),
    ).astype(real_dtype, copy=False)

    # Spectral sample spacings in x and y: convert FFT sums into
    # frequency-domain integrals below.
    dfx = 1.0 / (float(params.nx) * float(params.dxy_um))
    dfy = 1.0 / (float(params.ny) * float(params.dxy_um))

    n_patterns = int(patterns.shape[0])
    # One (nz, ny, nx) transfer volume per illumination pattern; H_real
    # couples to the real part of the potential spectrum, H_imag to the
    # imaginary part.
    H_real = np.zeros((n_patterns, int(params.nz), int(params.ny), int(params.nx)), dtype=complex_dtype)
    H_imag = np.zeros_like(H_real)

    for pid in range(n_patterns):
        source = patterns[pid]
        # Flipped source S(-u) used in the WOTF cross-correlation.
        source_flip = _source_flip_unshifted(source)
        # Cross-spectra of (source-weighted pupil x defocus propagator) with
        # the obliquity-weighted pupil, per z-plane (broadcast on last axis).
        fsp = ft2(source_flip[:, :, None] * pupil[:, :, None] * prop_kernel, axes=(0, 1), shift=False)
        fpg = ft2(pupil[:, :, None] * prop_kernel * oblique_factor[:, :, None], axes=(0, 1), shift=False).conj()
        fsp_cfpg = fsp * fpg
        # Split into the components responding to the real (absorption-like)
        # and imaginary (phase-like) parts; dfx*dfy makes the sum an integral.
        h_real = 2.0 * ift2(1.0j * fsp_cfpg.imag * dfx * dfy, axes=(0, 1), shift=False)
        h_imag = 2.0 * ift2(fsp_cfpg.real * dfx * dfy, axes=(0, 1), shift=False)
        # Hamming window along z suppresses axial wrap-around ringing before
        # transforming to axial frequency.
        h_real *= window_z[None, None, :]
        h_imag *= window_z[None, None, :]
        # FFT along z (axis 2); dz scaling approximates the axial Fourier
        # integral.
        h_real = ftn(h_real, axes=(2,), shift=False).astype(complex_dtype, copy=False) * float(params.dz_um)
        h_imag = ftn(h_imag, axes=(2,), shift=False).astype(complex_dtype, copy=False) * float(params.dz_um)
        # Reorder axes (y, x, z) -> (z, y, x).
        h_real = np.transpose(h_real, (2, 0, 1))
        h_imag = np.transpose(h_imag, (2, 0, 1))
        # Normalize by the total source power transmitted through the pupil
        # so patterns with different LED counts are comparable.
        total_source = np.sum(source_flip * pupil * pupil.conj()) * dfx * dfy
        if total_source == 0:
            raise ValueError("Total source power is zero for a pattern.")
        H_real[pid] = h_real * (1.0j / total_source)
        H_imag[pid] = h_imag * (1.0 / total_source)

    # Return with the zero-frequency component centered along (z, y, x).
    H_real = np.fft.fftshift(H_real, axes=(1, 2, 3))
    H_imag = np.fft.fftshift(H_imag, axes=(1, 2, 3))
    return H_real, H_imag


class WOTFFISTAOptimizer(Optimizer):
    """
    FISTA optimizer for 3D WOTF pattern intensities (no DPC differencing).

    Fits a 3D volume ``x`` (presumably refractive index; ``get_v`` maps it to
    a scattering potential — confirm against that helper) so that the
    contrast predicted through per-pattern WOTF transfer functions matches
    the measured contrast ``I_meas / I0_pred - 1``.  Optional features:
    volume zero-padding (``pad_zyx``), Tukey tapering of the residual along
    z and x/y (``z_taper`` / ``xy_taper``), a lower-bound constraint at
    ``n0`` (``use_real_constraint``), and TV regularization via
    ``tv_prox_fast`` (GPU/CuPy only).
    """

    def __init__(
        self,
        I_meas: array,
        I0_pred: array,
        H_real: array,
        H_imag: array,
        *,
        n0: float,
        wavelength_um: float,
        wotf_sign: float = -1.0,
        eps: float = 1.0e-8,
        tv_weight: float = 0.0,
        tv_max_num_iter: int = 50,
        tv_eps: float = 2.0e-4,
        tv_voxel_size_zyx: Optional[Sequence[float]] = None,
        tv_weight_scale_zyx: Optional[Sequence[float]] = None,
        pad_zyx: Optional[Sequence[int]] = None,
        z_taper: int = 0,
        xy_taper: Optional[int] = None,
        use_real_constraint: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        I_meas :
            Measured intensity stack; indexed as (pattern, z, y, x) by the
            methods below.
        I0_pred :
            Predicted background (unscattered) intensity used to normalize
            measurements into contrast.
        H_real, H_imag :
            WOTF transfer functions per pattern.  When ``pad_zyx`` is
            nonzero their (z, y, x) shape must match the padded volume.
        n0 :
            Background value (medium index, presumably) used by ``get_v``
            and as the lower bound in :meth:`prox`.
        wavelength_um :
            Illumination wavelength in microns.
        wotf_sign :
            Sign convention relating contrast to the WOTF prediction.
        eps :
            Floor applied to ``I0_pred`` before division.
        tv_weight, tv_max_num_iter, tv_eps, tv_voxel_size_zyx,
        tv_weight_scale_zyx :
            Forwarded to ``tv_prox_fast`` when ``tv_weight != 0``.
        pad_zyx :
            Optional symmetric zero-padding (pz, py, px) of the volume.
        z_taper, xy_taper :
            Tukey taper widths for the residual; ``xy_taper`` defaults to
            ``z_taper``.
        use_real_constraint :
            If True, clip the volume to ``>= n0`` in :meth:`prox`.

        Raises
        ------
        ValueError
            On malformed 3-tuples, negative padding, or a transfer-function
            shape inconsistent with the padded volume.
        """
        super().__init__(n_samples=1, prox_parameters=None)
        self.n0 = float(n0)
        self.wavelength_um = float(wavelength_um)
        self.wotf_sign = float(wotf_sign)
        self.eps = float(eps)
        self.tv_weight = float(tv_weight)
        self.tv_max_num_iter = int(tv_max_num_iter)
        self.tv_eps = float(tv_eps)
        self.use_real_constraint = bool(use_real_constraint)
        if tv_voxel_size_zyx is None:
            self.tv_voxel_size_zyx = None
        else:
            if len(tv_voxel_size_zyx) != 3:
                raise ValueError("tv_voxel_size_zyx must be a 3-tuple (dz, dy, dx).")
            self.tv_voxel_size_zyx = tuple(float(v) for v in tv_voxel_size_zyx)
        if tv_weight_scale_zyx is None:
            self.tv_weight_scale_zyx = None
        else:
            if len(tv_weight_scale_zyx) != 3:
                raise ValueError("tv_weight_scale_zyx must be a 3-tuple (sz, sy, sx).")
            self.tv_weight_scale_zyx = tuple(float(v) for v in tv_weight_scale_zyx)
        # Keep CPU master copies; device copies are created lazily per
        # backend in _ensure_backend.
        self._I_meas_cpu = np.asarray(to_cpu(I_meas))
        self._I0_pred_cpu = np.asarray(to_cpu(I0_pred))
        self._H_real_cpu = np.asarray(to_cpu(H_real))
        self._H_imag_cpu = np.asarray(to_cpu(H_imag))
        # Unpadded volume shape (z, y, x), taken from the measurement stack.
        self._shape_zyx = self._I_meas_cpu.shape[1:]
        self._z_taper = int(z_taper)
        # xy taper width defaults to the z taper when not given explicitly.
        self._xy_taper = int(z_taper if xy_taper is None else xy_taper)
        if pad_zyx is None:
            self.pad_zyx = (0, 0, 0)
        else:
            if len(pad_zyx) != 3:
                raise ValueError("pad_zyx must be a 3-tuple (pz, py, px).")
            self.pad_zyx = tuple(int(v) for v in pad_zyx)
            if any(v < 0 for v in self.pad_zyx):
                raise ValueError("pad_zyx entries must be >= 0.")
        if any(self.pad_zyx):
            # Transfer functions are defined on the padded grid, so their
            # spatial shape must match it exactly.
            nz, ny, nx = self._shape_zyx
            pz, py, px = self.pad_zyx
            expected = (nz + 2 * pz, ny + 2 * py, nx + 2 * px)
            if self._H_real_cpu.shape[1:] != expected:
                raise ValueError(
                    "H_real/H_imag must match padded shape "
                    f"{expected} when pad_zyx={self.pad_zyx}."
                )
        # Lazily-populated (backend, arrays, windows) cache; see
        # _ensure_backend.
        self._xp_cache = None

    def _pad_volume(self, xp: Any, vol: array) -> array:
        """Zero-pad a 3D (z, y, x) volume symmetrically by ``pad_zyx``."""
        if not any(self.pad_zyx):
            return vol
        pz, py, px = self.pad_zyx
        return xp.pad(vol, ((pz, pz), (py, py), (px, px)), mode="constant")

    def _crop_volume(self, vol: array) -> array:
        """Crop a padded 3D volume back to the unpadded (z, y, x) shape."""
        if not any(self.pad_zyx):
            return vol
        pz, py, px = self.pad_zyx
        nz, ny, nx = self._shape_zyx
        return vol[pz : pz + nz, py : py + ny, px : px + nx]

    def _pad_contrast(self, xp: Any, contrast: array) -> array:
        """Zero-pad a 4D (pattern, z, y, x) stack; pattern axis untouched."""
        if not any(self.pad_zyx):
            return contrast
        pz, py, px = self.pad_zyx
        return xp.pad(contrast, ((0, 0), (pz, pz), (py, py), (px, px)), mode="constant")

    def _crop_contrast(self, contrast: array) -> array:
        """Crop a padded 4D (pattern, z, y, x) stack to the unpadded shape."""
        if not any(self.pad_zyx):
            return contrast
        pz, py, px = self.pad_zyx
        nz, ny, nx = self._shape_zyx
        return contrast[:, pz : pz + nz, py : py + ny, px : px + nx]

    def _ensure_backend(self, x: array) -> tuple:
        """
        Return (xp, I_meas, I0_pred, H_real, H_imag, z_weight, xy_weight)
        on the backend (NumPy or CuPy) matching ``x``, transferring and
        caching the arrays on first use or on a backend switch.
        """
        xp = cp if cp and isinstance(x, cp.ndarray) else np
        if self._xp_cache is None or self._xp_cache[0] is not xp:
            z_weight = None
            xy_weight = None
            # NOTE(review): xy_weight is only built when z_taper > 0, so an
            # xy_taper passed with z_taper == 0 has no effect — confirm this
            # is intended.
            if self._z_taper > 0:
                z_weight = _tukey_window(int(self._shape_zyx[0]), self._z_taper, xp)
                # NOTE(review): built from shape_zyx[1] (ny) but applied to
                # both the y and x axes in cost/gradient — assumes ny == nx.
                xy_weight = _tukey_window(int(self._shape_zyx[1]), self._xy_taper, xp)
            self._xp_cache = (
                xp,
                xp.asarray(self._I_meas_cpu),
                xp.asarray(self._I0_pred_cpu),
                xp.asarray(self._H_real_cpu),
                xp.asarray(self._H_imag_cpu),
                z_weight,
                xy_weight,
            )
        return self._xp_cache

    def _forward_contrast(self, n_now: array) -> array:
        """
        Forward model: volume -> per-pattern contrast stack via the WOTF
        transfer functions, on the padded grid, cropped back at the end.
        """
        xp, I_meas, I0_pred, H_real, H_imag, _, _ = self._ensure_backend(n_now)
        v_now = get_v(n_now, self.n0, self.wavelength_um)
        v_now = self._pad_volume(xp, v_now)
        v_ft = ft3(v_now, axes=(0, 1, 2), shift=True)
        # Real/imaginary parts of the potential spectrum couple through
        # separate transfer functions.
        pred_ft = H_real * v_ft.real + H_imag * v_ft.imag
        contrast = ift3(pred_ft, axes=(1, 2, 3), shift=True).real
        contrast = self._crop_contrast(contrast)
        return contrast

    def _forward_intensity(self, n_now: array) -> array:
        """Predicted intensity I0 * (1 + sign * contrast), clipped at 0."""
        xp, I_meas, I0_pred, H_real, H_imag, _, _ = self._ensure_backend(n_now)
        contrast = self._forward_contrast(n_now)
        I_pred = I0_pred * (1.0 + self.wotf_sign * contrast)
        return xp.maximum(I_pred, xp.float32(0.0))

    def cost(self, x: array, inds: Optional[Sequence[int]] = None) -> array:
        """
        0.5 * sum of squared (optionally Tukey-tapered) residuals between
        predicted and measured contrast.  Returns a length-1 array.
        """
        xp, I_meas, I0_pred, H_real, H_imag, z_weight, xy_weight = self._ensure_backend(x)
        contrast_pred = self.wotf_sign * self._forward_contrast(x)
        # Measured contrast; eps floors the division by the background.
        contrast_meas = I_meas / xp.maximum(I0_pred, xp.float32(self.eps)) - xp.float32(1.0)
        residual = contrast_pred - contrast_meas
        if z_weight is not None:
            residual = residual * z_weight[None, :, None, None]
        if xy_weight is not None:
            residual = residual * xy_weight[None, None, :, None]
            residual = residual * xy_weight[None, None, None, :]
        cost = 0.5 * xp.sum(residual ** 2)
        return xp.asarray(cost)[None]

    def gradient(self, x: array, inds: Optional[Sequence[int]] = None) -> array:
        """
        Gradient of :meth:`cost` w.r.t. the volume via the adjoint of the
        WOTF forward model.  Returns shape (1, z, y, x).
        """
        xp, I_meas, I0_pred, H_real, H_imag, z_weight, xy_weight = self._ensure_backend(x)
        contrast_pred = self.wotf_sign * self._forward_contrast(x)
        contrast_meas = I_meas / xp.maximum(I0_pred, xp.float32(self.eps)) - xp.float32(1.0)
        g_contrast = contrast_pred - contrast_meas
        # NOTE(review): taper weights are applied once here, but cost()
        # squares the weighted residual — an exact gradient of that cost
        # would apply the weights squared.  Confirm which weighting is
        # intended.
        if z_weight is not None:
            g_contrast = g_contrast * z_weight[None, :, None, None]
        if xy_weight is not None:
            g_contrast = g_contrast * xy_weight[None, None, :, None]
            g_contrast = g_contrast * xy_weight[None, None, None, :]
        g_contrast = self._pad_contrast(xp, g_contrast)
        # Adjoint of the forward chain: back through ift3, the transfer
        # functions (summing over patterns), then ft3.
        g_contrast_ft = ift3(g_contrast, axes=(1, 2, 3), shift=True, adjoint=True)
        grad_vft_real = self.wotf_sign * xp.sum(xp.conj(H_real) * g_contrast_ft, axis=0)
        grad_vft_imag = self.wotf_sign * xp.sum(xp.conj(H_imag) * g_contrast_ft, axis=0)
        grad_v = ft3(
            xp.real(grad_vft_real) + 1j * xp.real(grad_vft_imag),
            axes=(0, 1, 2),
            shift=True,
            adjoint=True,
        )
        grad_v = self._crop_volume(grad_v)
        # Chain rule through the potential: presumably v ∝ -k^2 (n^2 - n0^2)
        # so dv/dn = -2 k^2 n — confirm against get_v.
        k2 = (2.0 * np.pi / float(self.wavelength_um)) ** 2
        grad = -2.0 * k2 * x * grad_v.real
        return grad[None, ...]

    def prox(self, x: array, step: float) -> array:
        """
        Proximal step: clip the volume to >= n0 (if enabled), then apply TV
        denoising via ``tv_prox_fast`` (GPU-only) when tv_weight != 0.
        """
        xp = cp if cp and isinstance(x, cp.ndarray) else np
        n_out = x
        if self.use_real_constraint:
            # Lower-bound constraint at the background value.
            n_out = xp.maximum(n_out, float(self.n0))
        if self.tv_weight != 0.0:
            if not (cp and isinstance(n_out, cp.ndarray)):
                raise ValueError("tv_prox_fast requires GPU/CuPy.")
            n_out = tv_prox_fast(
                cp.asarray(n_out, dtype=cp.float32),
                float(self.tv_weight),
                num_iter=int(self.tv_max_num_iter),
                eps=float(self.tv_eps),
                voxel_size_zyx=self.tv_voxel_size_zyx,
                weight_scale_zyx=self.tv_weight_scale_zyx,
            )
        return n_out
 diff --git a/pyproject.toml b/pyproject.toml index 21844d6..370cc4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ gpu = ['cupy-cuda11x', # assuming 11.2 <= CUDA version < 12. Otherwise, manuall 'cucim ; platform_system != "win32"' ] optional = ['napari', - 'miepython', + 'miepython==2.5.5', 'psfmodels'] dev = ['sphinx'] all = ['mcsim[expt_ctrl, gpu, optional, dev]'] diff --git a/tests/mie_unittest.py b/tests/mie_unittest.py index 6b75198..65b651e 100644 --- a/tests/mie_unittest.py +++ b/tests/mie_unittest.py @@ -52,3 +52,6 @@ def test_yn_gpu(self): np.testing.assert_allclose(yn_scipy[mask], yn_cp[mask], atol=1e-8) np.testing.assert_allclose(dyn_scipy[mask], dyn_cp[mask], atol=1e-8) +
+if __name__ == "__main__":
+    unittest.main()