From 2524f503e6abeab0cd5ac09d60cff2e8505a9527 Mon Sep 17 00:00:00 2001 From: dpshepherd Date: Wed, 3 Sep 2025 07:43:42 -0700 Subject: [PATCH 01/10] initial commit --- examples/reconstruct_dpc.py | 750 ++++++++++++++++++++++++++++++++++++ 1 file changed, 750 insertions(+) create mode 100644 examples/reconstruct_dpc.py diff --git a/examples/reconstruct_dpc.py b/examples/reconstruct_dpc.py new file mode 100644 index 0000000..999f928 --- /dev/null +++ b/examples/reconstruct_dpc.py @@ -0,0 +1,750 @@ +from typing import Optional, Sequence, Union +import numpy as np +from mcsim.optimize import Optimizer +from mcsim.optimize.prox import tv_prox, soft_threshold, median_prox +try: + import zarr # must be >= 3 +except Exception as _e: + zarr = None + _zarr_err = _e + +try: + import cupy as cp +except ImportError: + cp = None + +# ----------------------- +# Alias for CPU/GPU arrays +# ----------------------- +if cp: + array = Union[np.ndarray, cp.ndarray] +else: + array = np.ndarray + +def _to_cpu(a): + """Return NumPy array from NumPy/CuPy input without copying if possible.""" + if cp is not None and isinstance(a, cp.ndarray): + return cp.asnumpy(a) + return np.asarray(a) + +class DPC3D(Optimizer): + """ + 3D Differential Phase-Contrast inverse problem with 4 half-circle source patterns. + + Forward model (Chen–Tian–Waller, BOE 2016, Eq. (6)): + Ĩ_ℓ(kz, ky, kx) = H_Re,ℓ(kz, ky, kx)·V̂_Re(kz, ky, kx) + + H_Im,ℓ(kz, ky, kx)·V̂_Im(kz, ky, kx), + + with H_Re (Eq. (7)) and H_Im (Eq. (8)) the 3D phase/absorption WOTFs determined + by the source S (left/right/top/bottom half-circles) and detection pupil P. + Arrays are stored and operated in [pattern, z, y, x] order in *reciprocal space*. + + Parameters + ---------- + shape_zyx : tuple[int, int, int] + Spatial volume shape (Z, Y, X). + voxel_size_zyx : tuple[float, float, float] + Voxel size (dz, dy, dx) in meters. + wavelength : float + Illumination wavelength in meters. + n_medium : float + Refractive index of immersion medium. 
+ na_obj : float + Objective NA (detection pupil). + na_src : float + Illumination NA (condenser/source). + I_meas_4p : array | None + Optional measured intensity stacks with shape (4, Z, Y, X) in *space*. + If provided, they are normalized and FFT’d to produce I_hat_meas. + If None, you must later set self.I_hat_meas. + prox_parameters : dict | None + Prox controls: {'tv_re', 'tv_im', 'soft_im', 'positivity_re', 'median'}. + src_grid_N : int + Discretization of the source pupil per axis (odd recommended). + defocus_waves : float + Optional *defocus* aberration in waves (adds quadratic phase in P). + 0.0 means ideal pupil amplitude-only. + use_ortho_fft : bool + Orthonormal FFTs (recommended -> clean Lipschitz scaling). + use_gpu : bool + Force GPU if CuPy is available and True; else CPU. + """ + + PATTERNS = ("left", "right", "top", "bottom") + + def __init__(self, + shape_zyx: tuple[int, int, int], + voxel_size_zyx: tuple[float, float, float], + wavelength: float, + n_medium: float, + na_obj: float, + na_src: float, + I_meas_4p: Optional[array] = None, + prox_parameters: Optional[dict] = None, + src_grid_N: int = 65, + defocus_waves: float = 0.0, + use_ortho_fft: bool = True, + use_gpu: bool = False): + + Z, Y, X = map(int, shape_zyx) + dz, dy, dx = map(float, voxel_size_zyx) + + # ---------- backend ---------- + self.use_gpu = bool(use_gpu and (cp is not None)) + self.xp = cp if self.use_gpu else np + + # ---------- geometry ---------- + self.Z, self.Y, self.X = Z, Y, X + self.dz, self.dy, self.dx = dz, dy, dx + self.wavelength = float(wavelength) + self.n = float(n_medium) + self.na_obj = float(na_obj) + self.na_src = float(na_src) + self.k0_cyc = self.n / self.wavelength # cycles per meter + + # ---------- FFT norms ---------- + self.use_ortho_fft = bool(use_ortho_fft) + self._fft_norm = 'ortho' if self.use_ortho_fft else None + + # ---------- frequencies ---------- + xp = self.xp + fz = xp.fft.fftfreq(Z, d=dz) # cycles/m (unshifted) + fy = xp.fft.fftfreq(Y, 
d=dy) + fx = xp.fft.fftfreq(X, d=dx) + self.fz = fz + self.fy = fy + self.fx = fx + # Shifted z for splatting convenience + self.fz_shift = xp.fft.fftshift(fz) + self.dfz = float(self.fz_shift[1] - self.fz_shift[0]) + self.fz_min = float(self.fz_shift[0]) + self.fz_max = float(self.fz_shift[-1]) + + # ---------- pupil (amplitude 1 inside NA, optional defocus phase) ---------- + self.P = self._make_pupil(defocus_waves=defocus_waves) # (Y, X), complex + + # ---------- WOTFs for 4 patterns ---------- + self.H_re, self.H_im = self._build_wotf_4patterns(src_grid_N=src_grid_N) # (4,Z,Y,X), real + + # ---------- measured spectra ---------- + self.I_hat_meas = None + if I_meas_4p is not None: + self.set_measurements(I_meas_4p) + + # ---------- base Optimizer ---------- + super().__init__(n_samples=4, prox_parameters=prox_parameters or {}) + self._L_est = None # Lipschitz cache + + # =========================================================== + # Eq. (10) helpers: synthesize BF and DPC stacks from raw 4 patterns + # =========================================================== + @staticmethod + def synthesize_bf_dpc( + I_lrbt: array, + eps: float = 1e-12, + pair_norm: bool = False, + ) -> tuple[array, array, array]: + """ + Build brightfield (BF) and two DPC stacks from raw LEFT/RIGHT/TOP/BOTTOM. + + Parameters + ---------- + I_lrbt : array + Raw intensity stacks in *space*, shape (4, Z, Y, X) ordered as: + [left, right, top, bottom]. + eps : float + Numerical floor to avoid divide-by-zero. + pair_norm : bool + If False (default), use the paper's BF normalization: + I_BF = mean(L, R, T, B) + DPC_x = (R - L) / mean(I_BF) + DPC_y = (T - B) / mean(I_BF) + If True, use pairwise normalization (sometimes used in 2D DPC): + DPC_x = (R - L) / (R + L + eps) + DPC_y = (T - B) / (T + B + eps) + I_BF = (L + R + T + B)/4 (returned for completeness) + + Returns + ------- + I_BF : array + Brightfield stack, (Z, Y, X). + DPC_x : array + Left–Right differential (X-direction), (Z, Y, X). 
+ DPC_y : array + Top–Bottom differential (Y-direction), (Z, Y, X). + + Notes + ----- + This follows the synthesis described around Eq. (10) of Chen, Tian & Waller (2016). + They normalize by a DC/brightfield term; when a separate background BF is not + available, they approximate it with the *average* brightfield intensity + measured from the data (WOA regime), which is what the default branch implements. :contentReference[oaicite:0]{index=0} + """ + L, R, T, B = I_lrbt[0], I_lrbt[1], I_lrbt[2], I_lrbt[3] + I_BF = (L + R + T + B) / 4.0 + + if pair_norm: + DPC_x = (R - L) / (R + L + eps) + DPC_y = (T - B) / (T + B + eps) + else: + # DC/brightfield normalization (Eq. 10 discussion) + I_DC = I_BF.mean() # scalar , cf. paper text + DPC_x = (R - L) / (I_DC + eps) + DPC_y = (T - B) / (I_DC + eps) + + return I_BF, DPC_x, DPC_y + + def set_measurements_from_raw_patterns( + self, + I_lrbt_space: array, + normalize: bool = True, + pair_norm: bool = False, + ) -> None: + """ + Normalize background (optional) and FFT the *raw* 4-pattern stacks. + + Parameters + ---------- + I_lrbt_space : array + Raw LEFT/RIGHT/TOP/BOTTOM volumes in *space*, (4, Z, Y, X). + normalize : bool + If True, do the same background/DC normalization used in set_measurements(). + pair_norm : bool + If you plan to *also* compute and use BF/DPC stacks elsewhere, this lets + you reproduce pairwise DPC for inspection. This flag does not affect the + raw-pattern spectra used by the forward model here. + """ + xp = self.xp + assert I_lrbt_space.shape == (4, self.Z, self.Y, self.X), "Use [pattern,z,y,x] ordering" + + I = xp.array(I_lrbt_space, copy=True) + + if normalize: + # simple per-pattern background removal + bg = xp.min(I, axis=(1, 2, 3), keepdims=True) + I = I - bg + # global DC scale as in Eq. 
(10) paragraph (use BF mean as DC) + I_BF, _, _ = self.synthesize_bf_dpc(I, pair_norm=pair_norm) + dc = float(I_BF.mean()) + I = I / max(dc, 1e-12) + + # store Fourier-domain measurement spectra for the *4* raw patterns + self.I_hat_meas = xp.fft.fftn(I, axes=(-3, -2, -1), norm=self._fft_norm) + + # (optional) convenience if you want BF/DPC spectra too (not used by current forward model) + def bf_dpc_kspace(self, I_lrbt_space: array, pair_norm: bool = False) -> tuple[array, array, array]: + """ + FFT of the synthesized BF/DPC stacks (space -> k-space). + Returns (I_BF_hat, DPCx_hat, DPCy_hat), all (Z, Y, X). + """ + xp = self.xp + I_BF, DPCx, DPCy = self.synthesize_bf_dpc(I_lrbt_space, pair_norm=pair_norm) + F = lambda v: xp.fft.fftn(v, axes=(-3, -2, -1), norm=self._fft_norm) + return F(I_BF), F(DPCx), F(DPCy) + + + # =========================================================== + # Public API + # =========================================================== + def set_measurements(self, I_meas_4p: array, dc: Optional[float] = None) -> None: + """ + Normalize and FFT the 4 raw pattern stacks (left/right/top/bottom). + + Normalization (Eq. (10) spirit): subtract background (min over z/y/x), + divide by a DC scalar (provided or mean over brightfield proxy). 
+ """ + xp = self.xp + assert I_meas_4p.shape == (4, self.Z, self.Y, self.X) + I = xp.array(I_meas_4p, copy=True) + + # simple background remove (per-pattern) then single DC scale + bg = xp.min(I, axis=(1, 2, 3), keepdims=True) + I = I - bg + + if dc is None: + # approximate DC as average “brightfield” (mean over all patterns) + dc = float(xp.mean(I)) + I = I / max(dc, 1e-12) + + # 3D FFT per pattern to reciprocal space + self.I_hat_meas = xp.fft.fftn(I, axes=(-3, -2, -1), norm=self._fft_norm) + + # =========================================================== + # Optimizer interface: forward, adjoint, cost, grad, prox, step + # =========================================================== + def fwd_model(self, x: array, inds: Optional[Sequence[int]] = None) -> array: + xp = self.xp + if inds is None: + inds = range(4) + V_hat = xp.fft.fftn(x, axes=(-3, -2, -1), norm=self._fft_norm) # (Z,Y,X), complex + Vhat_re = V_hat.real + Vhat_im = V_hat.imag + Hre = self.H_re[xp.asarray(inds)] + Him = self.H_im[xp.asarray(inds)] + return Hre * Vhat_re + Him * Vhat_im # (len(inds),Z,Y,X), real + + def fwd_model_adjoint(self, y_hat: array, inds: Optional[Sequence[int]] = None) -> array: + xp = self.xp + if inds is None: + inds = range(4) + Hre = self.H_re[xp.asarray(inds)] + Him = self.H_im[xp.asarray(inds)] + # per-pattern k-space gradient parts (sum over patterns) + ghat_re = xp.sum(Hre * y_hat, axis=0) + ghat_im = xp.sum(Him * y_hat, axis=0) + g_hat = ghat_re + 1j * ghat_im + return xp.fft.ifftn(g_hat, axes=(-3, -2, -1), norm=self._fft_norm) + + def cost(self, x: array, inds: Optional[Sequence[int]] = None) -> array: + """ + Per-pattern data term; Optimizer takes the MEAN across patterns, + so the scalar descent criterion is automatically satisfied. + """ + xp = self.xp + if inds is None: + inds = range(4) + assert self.I_hat_meas is not None, "Call set_measurements(...) first." 
+ pred = self.fwd_model(x, inds=inds) + resid = pred - self.I_hat_meas[xp.asarray(inds)] + vol = self.Z * self.Y * self.X + return 0.5 * xp.sum(xp.abs(resid) ** 2, axis=(-3, -2, -1)) / vol # (len(inds),) + + def gradient(self, x: array, inds: Optional[Sequence[int]] = None) -> array: + xp = self.xp + if inds is None: + inds = range(4) + assert self.I_hat_meas is not None, "Call set_measurements(...) first." + pred = self.fwd_model(x, inds=inds) + resid = pred - self.I_hat_meas[xp.asarray(inds)] + Hre = self.H_re[xp.asarray(inds)] + Him = self.H_im[xp.asarray(inds)] + ghat_re = Hre * resid + ghat_im = Him * resid + g_hat = ghat_re + 1j * ghat_im + return xp.fft.ifftn(ghat_re + 1j * ghat_im, axes=(-3, -2, -1), norm=self._fft_norm) + + def prox(self, x: array, step: float) -> array: + # reuse your provided helpers (tv_prox, soft_threshold, median_prox) + xp = self.xp + v_re = x.real + v_im = x.imag + + tv_re = self.prox_parameters.get('tv_re', 0.0) or 0.0 + tv_im = self.prox_parameters.get('tv_im', 0.0) or 0.0 + if tv_re > 0: + v_re = tv_prox(v_re, tau=step * tv_re) + if tv_im > 0: + v_im = tv_prox(v_im, tau=step * tv_im) + + soft_im = self.prox_parameters.get('soft_im', 0.0) or 0.0 + if soft_im > 0: + v_im = soft_threshold(step * soft_im, v_im) + + med = self.prox_parameters.get('median', None) + if med is not None: + v_re = median_prox(v_re, size=med) + v_im = median_prox(v_im, size=med) + + if self.prox_parameters.get('positivity_re', False): + v_re = xp.maximum(v_re, 0) + + return v_re + 1j * v_im + + def guess_step(self, x: Optional[array] = None) -> float: + if self._L_est is None: + xp = self.xp + power = xp.abs(self.H_re) ** 2 + xp.abs(self.H_im) ** 2 # (4,Z,Y,X) + self._L_est = float(xp.max(xp.sum(power, axis=0))) # max_k sum over patterns + self._L_est = max(self._L_est, 1e-8) + return 0.9 / self._L_est + + # =========================================================== + # Internals: Pupil, Source, WOTF construction + # 
=========================================================== + def _make_pupil(self, defocus_waves: float = 0.0) -> array: + """ + Ideal circular pupil (amplitude 1 inside NA). Optional defocus phase. + Returns complex array of shape (Y, X). + """ + xp = self.xp + fy, fx = xp.meshgrid(self.fy, self.fx, indexing="ij") + rho = xp.sqrt(fx**2 + fy**2) / self.k0_cyc + P_amp = (rho <= self.na_obj).astype(self.xp.float32) + if defocus_waves == 0.0: + return P_amp.astype(self.xp.complex64) + # Defocus Zernike ~ 2*r^2 - 1 on unit disk (simple quadratic phase) + r = xp.clip(rho / self.na_obj, 0.0, 1.0) + phase = (2.0 * r**2 - 1.0) * (2.0 * np.pi * defocus_waves) + return P_amp * xp.exp(1j * phase) + + def _build_wotf_4patterns(self, src_grid_N: int = 65) -> tuple[array, array]: + """ + Numerically assemble H_Re/H_Im on the reciprocal grid for 4 DPC patterns. + + Implementation notes + -------------------- + * Non-paraxial kinematics: fz = kz(q+u) - kz(u), kz(w) = sqrt(f0^2 - |w|^2), + in *cycles/m* (no 2π). We splat both +fz and -fz contributions. + * Absorption TF is even in z (add both signs); Phase TF is odd (subtract). + * Amplitude-only pupil assumed (|P| in {0,1}). Aberration phase can be + added in P if desired; mixing terms then require complex products. 
+ """ + xp = self.xp + Z, Y, X = self.Z, self.Y, self.X + + # allocate shifted-z buffers for easy splatting, then unshift + Hre_s = xp.zeros((4, Z, Y, X), dtype=self.xp.float32) + Him_s = xp.zeros_like(Hre_s) + + f0 = self.k0_cyc + fy, fx = xp.meshgrid(self.fy, self.fx, indexing="ij") # (Y,X) + + # precompute |P(q+u)| mask efficiently by evaluating radius threshold + P_amp = xp.abs(self.P) # (Y,X) in {0,1} + + # discretize source pupil on a square grid, mask to circle of radius na_src + uu = xp.linspace(-self.na_src * f0, self.na_src * f0, src_grid_N) + uy_grid, ux_grid = xp.meshgrid(uu, uu, indexing="ij") + src_r = xp.sqrt(ux_grid**2 + uy_grid**2) + in_src = src_r <= (self.na_src * f0) + + # four half-circles: left (ux<0), right (ux>0), top (uy>0), bottom (uy<0) + half_masks = [ + in_src & (ux_grid < 0), # left + in_src & (ux_grid > 0), # right + in_src & (uy_grid > 0), # top + in_src & (uy_grid < 0), # bottom + ] + + # weights: uniform over active source samples per pattern + for pidx, src_mask in enumerate(half_masks): + ux_list = ux_grid[src_mask] + uy_list = uy_grid[src_mask] + if ux_list.size == 0: + continue + w = xp.ones_like(ux_list, dtype=self.xp.float32) + w /= float(ux_list.size) + + # iterate source samples (vectorized over q=(fy,fx)) + for uxi, uyi, wi in zip(ux_list, uy_list, w): + # unscattered ray must pass through pupil |P(u)|>0 + Pup_u = (xp.sqrt((uxi)**2 + (uyi)**2) <= (self.na_obj * f0)) + if not bool(Pup_u): + continue + + # lateral shift q+u + fx_shift = fx + uxi + fy_shift = fy + uyi + rad_shift = xp.sqrt(fx_shift**2 + fy_shift**2) + + # scattered ray must pass through pupil + pass_mask = (rad_shift <= (self.na_obj * f0)).astype(self.xp.float32) + if not xp.any(pass_mask): + continue + + # axial frequency difference (non-paraxial) + kz_u = xp.sqrt(xp.maximum(0.0, f0**2 - (uxi**2 + uyi**2))) + kz_qu = xp.sqrt(xp.maximum(0.0, f0**2 - rad_shift**2)) + dfz = kz_qu - kz_u # (Y,X) cycles/m + + # contributions only where valid (pass_mask) + A = wi 
* pass_mask # amplitude weight; |P(q+u)|*|P(u)|=1 here + + # splat +dfz and -dfz onto shifted fz grid + self._splat_plane(Hre_s, Him_s, pidx, A, +dfz) + self._splat_plane(Hre_s, Him_s, pidx, A, -dfz) + + # unshift z -> native FFT ordering + Hre = xp.fft.ifftshift(Hre_s, axes=(-3)) + Him = xp.fft.ifftshift(Him_s, axes=(-3)) + return Hre, Him + + def _splat_plane(self, + Hre_s: array, + Him_s: array, + pidx: int, + A_yx: array, + dfz_yx: array) -> None: + """ + Deposit weights at z-planes nearest to dfz(y,x) on *shifted* fz axis. + + H_Im: add +A + H_Re: add +A for +dfz, and -A for -dfz (caller passes sign via dfz) + """ + xp = self.xp + Z, Y, X = self.Z, self.Y, self.X + + # map dfz -> fractional index on shifted axis + # clamp to Nyquist + dfz_clamped = xp.clip(dfz_yx, self.fz_min, self.fz_max) + zf = (dfz_clamped - self.fz_min) / self.dfz # fractional in [0, Z-1] + z0 = xp.floor(zf).astype(self.xp.int32) + z1 = xp.clip(z0 + 1, 0, Z - 1) + alpha = (zf - z0).astype(self.xp.float32) + + # odd/even combination for phase/absorption + # H_Im receives +A; H_Re receives sign(dfz)*A (odd in z) + sign = xp.where(dfz_yx >= 0, 1.0, -1.0).astype(self.xp.float32) + + # Flatten indexing for scatter-add + yy, xx = xp.meshgrid(xp.arange(Y), xp.arange(X), indexing="ij") + w0 = (1.0 - alpha) * A_yx + w1 = alpha * A_yx + + # absorption (even) + self._add_at(Him_s, (pidx, z0, yy, xx), w0) + self._add_at(Him_s, (pidx, z1, yy, xx), w1) + + # phase (odd) + self._add_at(Hre_s, (pidx, z0, yy, xx), sign * w0) + self._add_at(Hre_s, (pidx, z1, yy, xx), sign * w1) + + def _add_at(self, arr4: array, idxs: tuple, weights: array) -> None: + """backend-safe scatter add into 4D [pattern,z,y,x] (shifted-z buffers).""" + xp = self.xp + p, z, y, x = idxs + if self.use_gpu: + # CuPy supports add.at + cp.add.at(arr4, (p, z, y, x), weights) + else: + np.add.at(arr4, (p, z, y, x), np.asarray(weights)) + + # ------------------------- + # Zarr v3 save / load + # ------------------------- + def 
save_wotf_zarr( + self, + store_path: str, + *, + overwrite: bool = False, + chunks: Optional[tuple[int, int, int, int]] = None, + include_pupil: bool = False, + compressor: Optional[dict] = None, + ) -> None: + """ + Save WOTFs to a Zarr v3 store on disk. + + Parameters + ---------- + store_path : str + Directory path for the Zarr store (e.g., "./wotf_cache.zarr"). + overwrite : bool + If True, remove any existing store_path first. + chunks : tuple[int, int, int, int] | None + Chunk shape for arrays in [pattern,z,y,x]. Default is (1, min(16,Z), min(128,Y), min(128,X)). + include_pupil : bool + If True, also save Pupil (as P_real/P_imag with shape (Y, X)). + compressor : dict | None + Optional Zarr v3 compressor config, e.g. {"id": "zstd", "level": 3}. + """ + if zarr is None: + raise RuntimeError(f"zarr>=3 required but not available: {_zarr_err}") + + Z, Y, X = self.Z, self.Y, self.X + H_re_cpu = _to_cpu(self.H_re).astype(np.float32, copy=False) + H_im_cpu = _to_cpu(self.H_im).astype(np.float32, copy=False) + + if chunks is None: + chunks = (1, min(16, Z), min(128, Y), min(128, X)) + + # Prepare store directory + if os.path.exists(store_path): + if overwrite: + # safest removal: remove directory tree + import shutil + shutil.rmtree(store_path) + else: + raise FileExistsError(f"{store_path} exists. 
Use overwrite=True to replace.") + + root = zarr.open_group(store_path, mode="w", zarr_version=3) + + # Root metadata + meta = dict( + version=1, + order="pattern,z,y,x", + shape_zyx=(Z, Y, X), + voxel_size_zyx=(self.dz, self.dy, self.dx), + wavelength=self.wavelength, + n_medium=self.n, + na_obj=self.na_obj, + na_src=self.na_src, + use_ortho_fft=bool(self.use_ortho_fft), + dtype="float32", + notes="H_re/H_im are real-valued WOTFs; pupil saved as P_real/P_imag if requested.", + ) + root.attrs["dpc3d_meta_json"] = json.dumps(meta) + + # Create arrays + aHre = root.create_array( + "H_re", + shape=(4, Z, Y, X), + chunks=chunks, + dtype="f4", + compressor=compressor, + ) + aHim = root.create_array( + "H_im", + shape=(4, Z, Y, X), + chunks=chunks, + dtype="f4", + compressor=compressor, + ) + + aHre[:] = H_re_cpu + aHim[:] = H_im_cpu + + if include_pupil: + P = _to_cpu(self.P) + root.create_array("P_real", shape=(Y, X), chunks=(min(256, Y), min(256, X)), + dtype="f4", compressor=compressor)[:] = P.real.astype(np.float32, copy=False) + root.create_array("P_imag", shape=(Y, X), chunks=(min(256, Y), min(256, X)), + dtype="f4", compressor=compressor)[:] = P.imag.astype(np.float32, copy=False) + + def load_wotf_from_zarr( + self, + store_path: str, + *, + strict_geometry: bool = True, + map_to_gpu: Optional[bool] = None, + ) -> None: + """ + Load WOTFs from a Zarr v3 store and install into this instance. + + Parameters + ---------- + store_path : str + Directory path of the Zarr store (e.g., "./wotf_cache.zarr"). + strict_geometry : bool + If True, validate Z,Y,X and optics (λ, n, NA) against this instance. + map_to_gpu : bool | None + If True (and CuPy available), map arrays to GPU. If None, follow current instance backend. 
+ """ + if zarr is None: + raise RuntimeError(f"zarr>=3 required but not available: {_zarr_err}") + + root = zarr.open_group(store_path, mode="r", zarr_version=3) + + # Load metadata + meta_json = root.attrs.get("dpc3d_meta_json", "{}") + try: + meta = json.loads(meta_json) + except Exception: + meta = {} + + # Read arrays (NumPy on load) + H_re_np = np.asarray(root["H_re"][:], dtype=np.float32) + H_im_np = np.asarray(root["H_im"][:], dtype=np.float32) + + if strict_geometry: + mZ, mY, mX = tuple(meta.get("shape_zyx", ())) + if (mZ, mY, mX) != (self.Z, self.Y, self.X): + raise ValueError(f"Cached WOTF shape ZYX {mZ,mY,mX} != current {(self.Z,self.Y,self.X)}") + # optics check (tolerant) + def _close(a, b, tol=1e-9) -> bool: + return abs(float(a) - float(b)) <= tol * max(1.0, abs(float(a)), abs(float(b))) + if not ( + _close(meta.get("wavelength", self.wavelength), self.wavelength) and + _close(meta.get("n_medium", self.n), self.n) and + _close(meta.get("na_obj", self.na_obj), self.na_obj) and + _close(meta.get("na_src", self.na_src), self.na_src) + ): + raise ValueError("Cached optics (λ, n, NA) differ from current instance.") + + # Map to GPU if requested + use_gpu = self.use_gpu if map_to_gpu is None else (bool(map_to_gpu) and (cp is not None)) + if use_gpu: + self.H_re = cp.asarray(H_re_np) + self.H_im = cp.asarray(H_im_np) + # Optional pupil + if "P_real" in root and "P_imag" in root: + Pre = cp.asarray(np.asarray(root["P_real"][:], dtype=np.float32)) + Pim = cp.asarray(np.asarray(root["P_imag"][:], dtype=np.float32)) + self.P = (Pre + 1j * Pim).astype(cp.complex64, copy=False) + else: + self.H_re = H_re_np + self.H_im = H_im_np + if "P_real" in root and "P_imag" in root: + Pre = np.asarray(root["P_real"][:], dtype=np.float32) + Pim = np.asarray(root["P_imag"][:], dtype=np.float32) + self.P = (Pre + 1j * Pim).astype(np.complex64, copy=False) + + self._L_est = None # recompute step bound on next guess_step() + + @classmethod + def from_cached_wotf_zarr( + cls, + 
store_path: str, + *, + shape_zyx: tuple[int, int, int], + voxel_size_zyx: tuple[float, float, float], + wavelength: float, + n_medium: float, + na_obj: float, + na_src: float, + prox_parameters: Optional[dict] = None, + use_ortho_fft: bool = True, + use_gpu: bool = False, + ): + """ + Construct a DPC3D instance and populate WOTFs from a Zarr v3 cache. + + Notes + ----- + Pupil is rebuilt analytically; if P_real/P_imag exist in the store, they + overwrite the analytic P for exact reproducibility. + """ + obj = cls( + shape_zyx=shape_zyx, + voxel_size_zyx=voxel_size_zyx, + wavelength=wavelength, + n_medium=n_medium, + na_obj=na_obj, + na_src=na_src, + I_meas_4p=None, + prox_parameters=prox_parameters, + src_grid_N=3, # placeholder; WOTF will be loaded + defocus_waves=0.0, + use_ortho_fft=use_ortho_fft, + use_gpu=use_gpu, + ) + obj.load_wotf_from_zarr(store_path, strict_geometry=True, map_to_gpu=use_gpu) + return obj + +if __name__ == "__main__": + xp = cp if cp is not None else np # choose GPU if available + + # Geometry (example) + shape = (100, 256, 256) # Z,Y,X voxels + voxel = (1e-6, 0.2e-6, 0.2e-6) # dz,dy,dx in meters + λ = 520e-9 # wavelength (m) + n0 = 1.33 # medium RI + NA_obj = 0.65 + NA_src = 0.65 + + # Build problem + dpc = DPC3D(shape, voxel, λ, n0, NA_obj, NA_src, + I_meas_4p=your_raw_4pattern_stacks, # (4,Z,Y,X) in space + prox_parameters=dict(tv_re=1e-3, positivity_re=True), + src_grid_N=65, + use_ortho_fft=True, + use_gpu=(cp is not None)) + + # Initial guess and step + x0 = xp.zeros(shape, dtype=xp.complex64) + step = dpc.guess_step() + + # Run APGD/FISTA + res = dpc.run(x_start=x0, + step=step, + max_iterations=200, + use_fista=True, + n_batch=None, # use all 4 patterns each iter + compute_batch_grad_parallel=True, + compute_cost=True, + compute_all_costs=True, + line_search_iter_limit=25, + line_search_factor=0.5, + xtol=1e-4, + label="[DPC3D] ") + + V_est = res["x"] # complex scattering potential (Z,Y,X) \ No newline at end of file From 
591e4af5ecd778ec19165adc2c6818b64b90771f Mon Sep 17 00:00:00 2001 From: dpshepherd Date: Wed, 31 Dec 2025 19:07:55 -0700 Subject: [PATCH 02/10] add mie theory based simulation for raw DPC data --- .gitignore | 1 + examples/reconstruct_dpc.py | 750 ---------- examples/simulate_dpc_microsphere.py | 2016 ++++++++++++++++++++++++++ pyproject.toml | 2 +- tests/mie_unittest.py | 3 + 5 files changed, 2021 insertions(+), 751 deletions(-) delete mode 100644 examples/reconstruct_dpc.py create mode 100644 examples/simulate_dpc_microsphere.py diff --git a/.gitignore b/.gitignore index a5cb8ca..39a7793 100644 --- a/.gitignore +++ b/.gitignore @@ -297,6 +297,7 @@ dask-worker-space/ # temporary files *.*~ +build/** # scratch space examples/data/** diff --git a/examples/reconstruct_dpc.py b/examples/reconstruct_dpc.py deleted file mode 100644 index 999f928..0000000 --- a/examples/reconstruct_dpc.py +++ /dev/null @@ -1,750 +0,0 @@ -from typing import Optional, Sequence, Union -import numpy as np -from mcsim.optimize import Optimizer -from mcsim.optimize.prox import tv_prox, soft_threshold, median_prox -try: - import zarr # must be >= 3 -except Exception as _e: - zarr = None - _zarr_err = _e - -try: - import cupy as cp -except ImportError: - cp = None - -# ----------------------- -# Alias for CPU/GPU arrays -# ----------------------- -if cp: - array = Union[np.ndarray, cp.ndarray] -else: - array = np.ndarray - -def _to_cpu(a): - """Return NumPy array from NumPy/CuPy input without copying if possible.""" - if cp is not None and isinstance(a, cp.ndarray): - return cp.asnumpy(a) - return np.asarray(a) - -class DPC3D(Optimizer): - """ - 3D Differential Phase-Contrast inverse problem with 4 half-circle source patterns. - - Forward model (Chen–Tian–Waller, BOE 2016, Eq. (6)): - Ĩ_ℓ(kz, ky, kx) = H_Re,ℓ(kz, ky, kx)·V̂_Re(kz, ky, kx) - + H_Im,ℓ(kz, ky, kx)·V̂_Im(kz, ky, kx), - - with H_Re (Eq. (7)) and H_Im (Eq. 
(8)) the 3D phase/absorption WOTFs determined - by the source S (left/right/top/bottom half-circles) and detection pupil P. - Arrays are stored and operated in [pattern, z, y, x] order in *reciprocal space*. - - Parameters - ---------- - shape_zyx : tuple[int, int, int] - Spatial volume shape (Z, Y, X). - voxel_size_zyx : tuple[float, float, float] - Voxel size (dz, dy, dx) in meters. - wavelength : float - Illumination wavelength in meters. - n_medium : float - Refractive index of immersion medium. - na_obj : float - Objective NA (detection pupil). - na_src : float - Illumination NA (condenser/source). - I_meas_4p : array | None - Optional measured intensity stacks with shape (4, Z, Y, X) in *space*. - If provided, they are normalized and FFT’d to produce I_hat_meas. - If None, you must later set self.I_hat_meas. - prox_parameters : dict | None - Prox controls: {'tv_re', 'tv_im', 'soft_im', 'positivity_re', 'median'}. - src_grid_N : int - Discretization of the source pupil per axis (odd recommended). - defocus_waves : float - Optional *defocus* aberration in waves (adds quadratic phase in P). - 0.0 means ideal pupil amplitude-only. - use_ortho_fft : bool - Orthonormal FFTs (recommended -> clean Lipschitz scaling). - use_gpu : bool - Force GPU if CuPy is available and True; else CPU. 
- """ - - PATTERNS = ("left", "right", "top", "bottom") - - def __init__(self, - shape_zyx: tuple[int, int, int], - voxel_size_zyx: tuple[float, float, float], - wavelength: float, - n_medium: float, - na_obj: float, - na_src: float, - I_meas_4p: Optional[array] = None, - prox_parameters: Optional[dict] = None, - src_grid_N: int = 65, - defocus_waves: float = 0.0, - use_ortho_fft: bool = True, - use_gpu: bool = False): - - Z, Y, X = map(int, shape_zyx) - dz, dy, dx = map(float, voxel_size_zyx) - - # ---------- backend ---------- - self.use_gpu = bool(use_gpu and (cp is not None)) - self.xp = cp if self.use_gpu else np - - # ---------- geometry ---------- - self.Z, self.Y, self.X = Z, Y, X - self.dz, self.dy, self.dx = dz, dy, dx - self.wavelength = float(wavelength) - self.n = float(n_medium) - self.na_obj = float(na_obj) - self.na_src = float(na_src) - self.k0_cyc = self.n / self.wavelength # cycles per meter - - # ---------- FFT norms ---------- - self.use_ortho_fft = bool(use_ortho_fft) - self._fft_norm = 'ortho' if self.use_ortho_fft else None - - # ---------- frequencies ---------- - xp = self.xp - fz = xp.fft.fftfreq(Z, d=dz) # cycles/m (unshifted) - fy = xp.fft.fftfreq(Y, d=dy) - fx = xp.fft.fftfreq(X, d=dx) - self.fz = fz - self.fy = fy - self.fx = fx - # Shifted z for splatting convenience - self.fz_shift = xp.fft.fftshift(fz) - self.dfz = float(self.fz_shift[1] - self.fz_shift[0]) - self.fz_min = float(self.fz_shift[0]) - self.fz_max = float(self.fz_shift[-1]) - - # ---------- pupil (amplitude 1 inside NA, optional defocus phase) ---------- - self.P = self._make_pupil(defocus_waves=defocus_waves) # (Y, X), complex - - # ---------- WOTFs for 4 patterns ---------- - self.H_re, self.H_im = self._build_wotf_4patterns(src_grid_N=src_grid_N) # (4,Z,Y,X), real - - # ---------- measured spectra ---------- - self.I_hat_meas = None - if I_meas_4p is not None: - self.set_measurements(I_meas_4p) - - # ---------- base Optimizer ---------- - 
super().__init__(n_samples=4, prox_parameters=prox_parameters or {}) - self._L_est = None # Lipschitz cache - - # =========================================================== - # Eq. (10) helpers: synthesize BF and DPC stacks from raw 4 patterns - # =========================================================== - @staticmethod - def synthesize_bf_dpc( - I_lrbt: array, - eps: float = 1e-12, - pair_norm: bool = False, - ) -> tuple[array, array, array]: - """ - Build brightfield (BF) and two DPC stacks from raw LEFT/RIGHT/TOP/BOTTOM. - - Parameters - ---------- - I_lrbt : array - Raw intensity stacks in *space*, shape (4, Z, Y, X) ordered as: - [left, right, top, bottom]. - eps : float - Numerical floor to avoid divide-by-zero. - pair_norm : bool - If False (default), use the paper's BF normalization: - I_BF = mean(L, R, T, B) - DPC_x = (R - L) / mean(I_BF) - DPC_y = (T - B) / mean(I_BF) - If True, use pairwise normalization (sometimes used in 2D DPC): - DPC_x = (R - L) / (R + L + eps) - DPC_y = (T - B) / (T + B + eps) - I_BF = (L + R + T + B)/4 (returned for completeness) - - Returns - ------- - I_BF : array - Brightfield stack, (Z, Y, X). - DPC_x : array - Left–Right differential (X-direction), (Z, Y, X). - DPC_y : array - Top–Bottom differential (Y-direction), (Z, Y, X). - - Notes - ----- - This follows the synthesis described around Eq. (10) of Chen, Tian & Waller (2016). - They normalize by a DC/brightfield term; when a separate background BF is not - available, they approximate it with the *average* brightfield intensity - measured from the data (WOA regime), which is what the default branch implements. :contentReference[oaicite:0]{index=0} - """ - L, R, T, B = I_lrbt[0], I_lrbt[1], I_lrbt[2], I_lrbt[3] - I_BF = (L + R + T + B) / 4.0 - - if pair_norm: - DPC_x = (R - L) / (R + L + eps) - DPC_y = (T - B) / (T + B + eps) - else: - # DC/brightfield normalization (Eq. 10 discussion) - I_DC = I_BF.mean() # scalar , cf. 
paper text - DPC_x = (R - L) / (I_DC + eps) - DPC_y = (T - B) / (I_DC + eps) - - return I_BF, DPC_x, DPC_y - - def set_measurements_from_raw_patterns( - self, - I_lrbt_space: array, - normalize: bool = True, - pair_norm: bool = False, - ) -> None: - """ - Normalize background (optional) and FFT the *raw* 4-pattern stacks. - - Parameters - ---------- - I_lrbt_space : array - Raw LEFT/RIGHT/TOP/BOTTOM volumes in *space*, (4, Z, Y, X). - normalize : bool - If True, do the same background/DC normalization used in set_measurements(). - pair_norm : bool - If you plan to *also* compute and use BF/DPC stacks elsewhere, this lets - you reproduce pairwise DPC for inspection. This flag does not affect the - raw-pattern spectra used by the forward model here. - """ - xp = self.xp - assert I_lrbt_space.shape == (4, self.Z, self.Y, self.X), "Use [pattern,z,y,x] ordering" - - I = xp.array(I_lrbt_space, copy=True) - - if normalize: - # simple per-pattern background removal - bg = xp.min(I, axis=(1, 2, 3), keepdims=True) - I = I - bg - # global DC scale as in Eq. (10) paragraph (use BF mean as DC) - I_BF, _, _ = self.synthesize_bf_dpc(I, pair_norm=pair_norm) - dc = float(I_BF.mean()) - I = I / max(dc, 1e-12) - - # store Fourier-domain measurement spectra for the *4* raw patterns - self.I_hat_meas = xp.fft.fftn(I, axes=(-3, -2, -1), norm=self._fft_norm) - - # (optional) convenience if you want BF/DPC spectra too (not used by current forward model) - def bf_dpc_kspace(self, I_lrbt_space: array, pair_norm: bool = False) -> tuple[array, array, array]: - """ - FFT of the synthesized BF/DPC stacks (space -> k-space). - Returns (I_BF_hat, DPCx_hat, DPCy_hat), all (Z, Y, X). 
- """ - xp = self.xp - I_BF, DPCx, DPCy = self.synthesize_bf_dpc(I_lrbt_space, pair_norm=pair_norm) - F = lambda v: xp.fft.fftn(v, axes=(-3, -2, -1), norm=self._fft_norm) - return F(I_BF), F(DPCx), F(DPCy) - - - # =========================================================== - # Public API - # =========================================================== - def set_measurements(self, I_meas_4p: array, dc: Optional[float] = None) -> None: - """ - Normalize and FFT the 4 raw pattern stacks (left/right/top/bottom). - - Normalization (Eq. (10) spirit): subtract background (min over z/y/x), - divide by a DC scalar (provided or mean over brightfield proxy). - """ - xp = self.xp - assert I_meas_4p.shape == (4, self.Z, self.Y, self.X) - I = xp.array(I_meas_4p, copy=True) - - # simple background remove (per-pattern) then single DC scale - bg = xp.min(I, axis=(1, 2, 3), keepdims=True) - I = I - bg - - if dc is None: - # approximate DC as average “brightfield” (mean over all patterns) - dc = float(xp.mean(I)) - I = I / max(dc, 1e-12) - - # 3D FFT per pattern to reciprocal space - self.I_hat_meas = xp.fft.fftn(I, axes=(-3, -2, -1), norm=self._fft_norm) - - # =========================================================== - # Optimizer interface: forward, adjoint, cost, grad, prox, step - # =========================================================== - def fwd_model(self, x: array, inds: Optional[Sequence[int]] = None) -> array: - xp = self.xp - if inds is None: - inds = range(4) - V_hat = xp.fft.fftn(x, axes=(-3, -2, -1), norm=self._fft_norm) # (Z,Y,X), complex - Vhat_re = V_hat.real - Vhat_im = V_hat.imag - Hre = self.H_re[xp.asarray(inds)] - Him = self.H_im[xp.asarray(inds)] - return Hre * Vhat_re + Him * Vhat_im # (len(inds),Z,Y,X), real - - def fwd_model_adjoint(self, y_hat: array, inds: Optional[Sequence[int]] = None) -> array: - xp = self.xp - if inds is None: - inds = range(4) - Hre = self.H_re[xp.asarray(inds)] - Him = self.H_im[xp.asarray(inds)] - # per-pattern k-space gradient 
parts (sum over patterns) - ghat_re = xp.sum(Hre * y_hat, axis=0) - ghat_im = xp.sum(Him * y_hat, axis=0) - g_hat = ghat_re + 1j * ghat_im - return xp.fft.ifftn(g_hat, axes=(-3, -2, -1), norm=self._fft_norm) - - def cost(self, x: array, inds: Optional[Sequence[int]] = None) -> array: - """ - Per-pattern data term; Optimizer takes the MEAN across patterns, - so the scalar descent criterion is automatically satisfied. - """ - xp = self.xp - if inds is None: - inds = range(4) - assert self.I_hat_meas is not None, "Call set_measurements(...) first." - pred = self.fwd_model(x, inds=inds) - resid = pred - self.I_hat_meas[xp.asarray(inds)] - vol = self.Z * self.Y * self.X - return 0.5 * xp.sum(xp.abs(resid) ** 2, axis=(-3, -2, -1)) / vol # (len(inds),) - - def gradient(self, x: array, inds: Optional[Sequence[int]] = None) -> array: - xp = self.xp - if inds is None: - inds = range(4) - assert self.I_hat_meas is not None, "Call set_measurements(...) first." - pred = self.fwd_model(x, inds=inds) - resid = pred - self.I_hat_meas[xp.asarray(inds)] - Hre = self.H_re[xp.asarray(inds)] - Him = self.H_im[xp.asarray(inds)] - ghat_re = Hre * resid - ghat_im = Him * resid - g_hat = ghat_re + 1j * ghat_im - return xp.fft.ifftn(ghat_re + 1j * ghat_im, axes=(-3, -2, -1), norm=self._fft_norm) - - def prox(self, x: array, step: float) -> array: - # reuse your provided helpers (tv_prox, soft_threshold, median_prox) - xp = self.xp - v_re = x.real - v_im = x.imag - - tv_re = self.prox_parameters.get('tv_re', 0.0) or 0.0 - tv_im = self.prox_parameters.get('tv_im', 0.0) or 0.0 - if tv_re > 0: - v_re = tv_prox(v_re, tau=step * tv_re) - if tv_im > 0: - v_im = tv_prox(v_im, tau=step * tv_im) - - soft_im = self.prox_parameters.get('soft_im', 0.0) or 0.0 - if soft_im > 0: - v_im = soft_threshold(step * soft_im, v_im) - - med = self.prox_parameters.get('median', None) - if med is not None: - v_re = median_prox(v_re, size=med) - v_im = median_prox(v_im, size=med) - - if 
self.prox_parameters.get('positivity_re', False): - v_re = xp.maximum(v_re, 0) - - return v_re + 1j * v_im - - def guess_step(self, x: Optional[array] = None) -> float: - if self._L_est is None: - xp = self.xp - power = xp.abs(self.H_re) ** 2 + xp.abs(self.H_im) ** 2 # (4,Z,Y,X) - self._L_est = float(xp.max(xp.sum(power, axis=0))) # max_k sum over patterns - self._L_est = max(self._L_est, 1e-8) - return 0.9 / self._L_est - - # =========================================================== - # Internals: Pupil, Source, WOTF construction - # =========================================================== - def _make_pupil(self, defocus_waves: float = 0.0) -> array: - """ - Ideal circular pupil (amplitude 1 inside NA). Optional defocus phase. - Returns complex array of shape (Y, X). - """ - xp = self.xp - fy, fx = xp.meshgrid(self.fy, self.fx, indexing="ij") - rho = xp.sqrt(fx**2 + fy**2) / self.k0_cyc - P_amp = (rho <= self.na_obj).astype(self.xp.float32) - if defocus_waves == 0.0: - return P_amp.astype(self.xp.complex64) - # Defocus Zernike ~ 2*r^2 - 1 on unit disk (simple quadratic phase) - r = xp.clip(rho / self.na_obj, 0.0, 1.0) - phase = (2.0 * r**2 - 1.0) * (2.0 * np.pi * defocus_waves) - return P_amp * xp.exp(1j * phase) - - def _build_wotf_4patterns(self, src_grid_N: int = 65) -> tuple[array, array]: - """ - Numerically assemble H_Re/H_Im on the reciprocal grid for 4 DPC patterns. - - Implementation notes - -------------------- - * Non-paraxial kinematics: fz = kz(q+u) - kz(u), kz(w) = sqrt(f0^2 - |w|^2), - in *cycles/m* (no 2π). We splat both +fz and -fz contributions. - * Absorption TF is even in z (add both signs); Phase TF is odd (subtract). - * Amplitude-only pupil assumed (|P| in {0,1}). Aberration phase can be - added in P if desired; mixing terms then require complex products. 
- """ - xp = self.xp - Z, Y, X = self.Z, self.Y, self.X - - # allocate shifted-z buffers for easy splatting, then unshift - Hre_s = xp.zeros((4, Z, Y, X), dtype=self.xp.float32) - Him_s = xp.zeros_like(Hre_s) - - f0 = self.k0_cyc - fy, fx = xp.meshgrid(self.fy, self.fx, indexing="ij") # (Y,X) - - # precompute |P(q+u)| mask efficiently by evaluating radius threshold - P_amp = xp.abs(self.P) # (Y,X) in {0,1} - - # discretize source pupil on a square grid, mask to circle of radius na_src - uu = xp.linspace(-self.na_src * f0, self.na_src * f0, src_grid_N) - uy_grid, ux_grid = xp.meshgrid(uu, uu, indexing="ij") - src_r = xp.sqrt(ux_grid**2 + uy_grid**2) - in_src = src_r <= (self.na_src * f0) - - # four half-circles: left (ux<0), right (ux>0), top (uy>0), bottom (uy<0) - half_masks = [ - in_src & (ux_grid < 0), # left - in_src & (ux_grid > 0), # right - in_src & (uy_grid > 0), # top - in_src & (uy_grid < 0), # bottom - ] - - # weights: uniform over active source samples per pattern - for pidx, src_mask in enumerate(half_masks): - ux_list = ux_grid[src_mask] - uy_list = uy_grid[src_mask] - if ux_list.size == 0: - continue - w = xp.ones_like(ux_list, dtype=self.xp.float32) - w /= float(ux_list.size) - - # iterate source samples (vectorized over q=(fy,fx)) - for uxi, uyi, wi in zip(ux_list, uy_list, w): - # unscattered ray must pass through pupil |P(u)|>0 - Pup_u = (xp.sqrt((uxi)**2 + (uyi)**2) <= (self.na_obj * f0)) - if not bool(Pup_u): - continue - - # lateral shift q+u - fx_shift = fx + uxi - fy_shift = fy + uyi - rad_shift = xp.sqrt(fx_shift**2 + fy_shift**2) - - # scattered ray must pass through pupil - pass_mask = (rad_shift <= (self.na_obj * f0)).astype(self.xp.float32) - if not xp.any(pass_mask): - continue - - # axial frequency difference (non-paraxial) - kz_u = xp.sqrt(xp.maximum(0.0, f0**2 - (uxi**2 + uyi**2))) - kz_qu = xp.sqrt(xp.maximum(0.0, f0**2 - rad_shift**2)) - dfz = kz_qu - kz_u # (Y,X) cycles/m - - # contributions only where valid (pass_mask) - A = wi 
* pass_mask # amplitude weight; |P(q+u)|*|P(u)|=1 here - - # splat +dfz and -dfz onto shifted fz grid - self._splat_plane(Hre_s, Him_s, pidx, A, +dfz) - self._splat_plane(Hre_s, Him_s, pidx, A, -dfz) - - # unshift z -> native FFT ordering - Hre = xp.fft.ifftshift(Hre_s, axes=(-3)) - Him = xp.fft.ifftshift(Him_s, axes=(-3)) - return Hre, Him - - def _splat_plane(self, - Hre_s: array, - Him_s: array, - pidx: int, - A_yx: array, - dfz_yx: array) -> None: - """ - Deposit weights at z-planes nearest to dfz(y,x) on *shifted* fz axis. - - H_Im: add +A - H_Re: add +A for +dfz, and -A for -dfz (caller passes sign via dfz) - """ - xp = self.xp - Z, Y, X = self.Z, self.Y, self.X - - # map dfz -> fractional index on shifted axis - # clamp to Nyquist - dfz_clamped = xp.clip(dfz_yx, self.fz_min, self.fz_max) - zf = (dfz_clamped - self.fz_min) / self.dfz # fractional in [0, Z-1] - z0 = xp.floor(zf).astype(self.xp.int32) - z1 = xp.clip(z0 + 1, 0, Z - 1) - alpha = (zf - z0).astype(self.xp.float32) - - # odd/even combination for phase/absorption - # H_Im receives +A; H_Re receives sign(dfz)*A (odd in z) - sign = xp.where(dfz_yx >= 0, 1.0, -1.0).astype(self.xp.float32) - - # Flatten indexing for scatter-add - yy, xx = xp.meshgrid(xp.arange(Y), xp.arange(X), indexing="ij") - w0 = (1.0 - alpha) * A_yx - w1 = alpha * A_yx - - # absorption (even) - self._add_at(Him_s, (pidx, z0, yy, xx), w0) - self._add_at(Him_s, (pidx, z1, yy, xx), w1) - - # phase (odd) - self._add_at(Hre_s, (pidx, z0, yy, xx), sign * w0) - self._add_at(Hre_s, (pidx, z1, yy, xx), sign * w1) - - def _add_at(self, arr4: array, idxs: tuple, weights: array) -> None: - """backend-safe scatter add into 4D [pattern,z,y,x] (shifted-z buffers).""" - xp = self.xp - p, z, y, x = idxs - if self.use_gpu: - # CuPy supports add.at - cp.add.at(arr4, (p, z, y, x), weights) - else: - np.add.at(arr4, (p, z, y, x), np.asarray(weights)) - - # ------------------------- - # Zarr v3 save / load - # ------------------------- - def 
save_wotf_zarr( - self, - store_path: str, - *, - overwrite: bool = False, - chunks: Optional[tuple[int, int, int, int]] = None, - include_pupil: bool = False, - compressor: Optional[dict] = None, - ) -> None: - """ - Save WOTFs to a Zarr v3 store on disk. - - Parameters - ---------- - store_path : str - Directory path for the Zarr store (e.g., "./wotf_cache.zarr"). - overwrite : bool - If True, remove any existing store_path first. - chunks : tuple[int, int, int, int] | None - Chunk shape for arrays in [pattern,z,y,x]. Default is (1, min(16,Z), min(128,Y), min(128,X)). - include_pupil : bool - If True, also save Pupil (as P_real/P_imag with shape (Y, X)). - compressor : dict | None - Optional Zarr v3 compressor config, e.g. {"id": "zstd", "level": 3}. - """ - if zarr is None: - raise RuntimeError(f"zarr>=3 required but not available: {_zarr_err}") - - Z, Y, X = self.Z, self.Y, self.X - H_re_cpu = _to_cpu(self.H_re).astype(np.float32, copy=False) - H_im_cpu = _to_cpu(self.H_im).astype(np.float32, copy=False) - - if chunks is None: - chunks = (1, min(16, Z), min(128, Y), min(128, X)) - - # Prepare store directory - if os.path.exists(store_path): - if overwrite: - # safest removal: remove directory tree - import shutil - shutil.rmtree(store_path) - else: - raise FileExistsError(f"{store_path} exists. 
Use overwrite=True to replace.") - - root = zarr.open_group(store_path, mode="w", zarr_version=3) - - # Root metadata - meta = dict( - version=1, - order="pattern,z,y,x", - shape_zyx=(Z, Y, X), - voxel_size_zyx=(self.dz, self.dy, self.dx), - wavelength=self.wavelength, - n_medium=self.n, - na_obj=self.na_obj, - na_src=self.na_src, - use_ortho_fft=bool(self.use_ortho_fft), - dtype="float32", - notes="H_re/H_im are real-valued WOTFs; pupil saved as P_real/P_imag if requested.", - ) - root.attrs["dpc3d_meta_json"] = json.dumps(meta) - - # Create arrays - aHre = root.create_array( - "H_re", - shape=(4, Z, Y, X), - chunks=chunks, - dtype="f4", - compressor=compressor, - ) - aHim = root.create_array( - "H_im", - shape=(4, Z, Y, X), - chunks=chunks, - dtype="f4", - compressor=compressor, - ) - - aHre[:] = H_re_cpu - aHim[:] = H_im_cpu - - if include_pupil: - P = _to_cpu(self.P) - root.create_array("P_real", shape=(Y, X), chunks=(min(256, Y), min(256, X)), - dtype="f4", compressor=compressor)[:] = P.real.astype(np.float32, copy=False) - root.create_array("P_imag", shape=(Y, X), chunks=(min(256, Y), min(256, X)), - dtype="f4", compressor=compressor)[:] = P.imag.astype(np.float32, copy=False) - - def load_wotf_from_zarr( - self, - store_path: str, - *, - strict_geometry: bool = True, - map_to_gpu: Optional[bool] = None, - ) -> None: - """ - Load WOTFs from a Zarr v3 store and install into this instance. - - Parameters - ---------- - store_path : str - Directory path of the Zarr store (e.g., "./wotf_cache.zarr"). - strict_geometry : bool - If True, validate Z,Y,X and optics (λ, n, NA) against this instance. - map_to_gpu : bool | None - If True (and CuPy available), map arrays to GPU. If None, follow current instance backend. 
- """ - if zarr is None: - raise RuntimeError(f"zarr>=3 required but not available: {_zarr_err}") - - root = zarr.open_group(store_path, mode="r", zarr_version=3) - - # Load metadata - meta_json = root.attrs.get("dpc3d_meta_json", "{}") - try: - meta = json.loads(meta_json) - except Exception: - meta = {} - - # Read arrays (NumPy on load) - H_re_np = np.asarray(root["H_re"][:], dtype=np.float32) - H_im_np = np.asarray(root["H_im"][:], dtype=np.float32) - - if strict_geometry: - mZ, mY, mX = tuple(meta.get("shape_zyx", ())) - if (mZ, mY, mX) != (self.Z, self.Y, self.X): - raise ValueError(f"Cached WOTF shape ZYX {mZ,mY,mX} != current {(self.Z,self.Y,self.X)}") - # optics check (tolerant) - def _close(a, b, tol=1e-9) -> bool: - return abs(float(a) - float(b)) <= tol * max(1.0, abs(float(a)), abs(float(b))) - if not ( - _close(meta.get("wavelength", self.wavelength), self.wavelength) and - _close(meta.get("n_medium", self.n), self.n) and - _close(meta.get("na_obj", self.na_obj), self.na_obj) and - _close(meta.get("na_src", self.na_src), self.na_src) - ): - raise ValueError("Cached optics (λ, n, NA) differ from current instance.") - - # Map to GPU if requested - use_gpu = self.use_gpu if map_to_gpu is None else (bool(map_to_gpu) and (cp is not None)) - if use_gpu: - self.H_re = cp.asarray(H_re_np) - self.H_im = cp.asarray(H_im_np) - # Optional pupil - if "P_real" in root and "P_imag" in root: - Pre = cp.asarray(np.asarray(root["P_real"][:], dtype=np.float32)) - Pim = cp.asarray(np.asarray(root["P_imag"][:], dtype=np.float32)) - self.P = (Pre + 1j * Pim).astype(cp.complex64, copy=False) - else: - self.H_re = H_re_np - self.H_im = H_im_np - if "P_real" in root and "P_imag" in root: - Pre = np.asarray(root["P_real"][:], dtype=np.float32) - Pim = np.asarray(root["P_imag"][:], dtype=np.float32) - self.P = (Pre + 1j * Pim).astype(np.complex64, copy=False) - - self._L_est = None # recompute step bound on next guess_step() - - @classmethod - def from_cached_wotf_zarr( - cls, - 
store_path: str, - *, - shape_zyx: tuple[int, int, int], - voxel_size_zyx: tuple[float, float, float], - wavelength: float, - n_medium: float, - na_obj: float, - na_src: float, - prox_parameters: Optional[dict] = None, - use_ortho_fft: bool = True, - use_gpu: bool = False, - ): - """ - Construct a DPC3D instance and populate WOTFs from a Zarr v3 cache. - - Notes - ----- - Pupil is rebuilt analytically; if P_real/P_imag exist in the store, they - overwrite the analytic P for exact reproducibility. - """ - obj = cls( - shape_zyx=shape_zyx, - voxel_size_zyx=voxel_size_zyx, - wavelength=wavelength, - n_medium=n_medium, - na_obj=na_obj, - na_src=na_src, - I_meas_4p=None, - prox_parameters=prox_parameters, - src_grid_N=3, # placeholder; WOTF will be loaded - defocus_waves=0.0, - use_ortho_fft=use_ortho_fft, - use_gpu=use_gpu, - ) - obj.load_wotf_from_zarr(store_path, strict_geometry=True, map_to_gpu=use_gpu) - return obj - -if __name__ == "__main__": - xp = cp if cp is not None else np # choose GPU if available - - # Geometry (example) - shape = (100, 256, 256) # Z,Y,X voxels - voxel = (1e-6, 0.2e-6, 0.2e-6) # dz,dy,dx in meters - λ = 520e-9 # wavelength (m) - n0 = 1.33 # medium RI - NA_obj = 0.65 - NA_src = 0.65 - - # Build problem - dpc = DPC3D(shape, voxel, λ, n0, NA_obj, NA_src, - I_meas_4p=your_raw_4pattern_stacks, # (4,Z,Y,X) in space - prox_parameters=dict(tv_re=1e-3, positivity_re=True), - src_grid_N=65, - use_ortho_fft=True, - use_gpu=(cp is not None)) - - # Initial guess and step - x0 = xp.zeros(shape, dtype=xp.complex64) - step = dpc.guess_step() - - # Run APGD/FISTA - res = dpc.run(x_start=x0, - step=step, - max_iterations=200, - use_fista=True, - n_batch=None, # use all 4 patterns each iter - compute_batch_grad_parallel=True, - compute_cost=True, - compute_all_costs=True, - line_search_iter_limit=25, - line_search_factor=0.5, - xtol=1e-4, - label="[DPC3D] ") - - V_est = res["x"] # complex scattering potential (Z,Y,X) \ No newline at end of file diff --git 
a/examples/simulate_dpc_microsphere.py b/examples/simulate_dpc_microsphere.py new file mode 100644 index 0000000..5f89f16 --- /dev/null +++ b/examples/simulate_dpc_microsphere.py @@ -0,0 +1,2016 @@ +""" +Synthetic DPC image generation using Mie-theory scattered fields from a sphere, +with a simple synthetic imaging model (objective pupil + camera sampling). + +Overview +-------- +For each LED on a centered board, we: +1) map its board position to an incidence direction (NA_x, NA_y), +2) compute the complex vector electric field around/through a sphere using Mie theory + (`mie_fields.mie_efield`), +3) apply a coherent imaging operator representing the objective/tube-lens relay + (modeled as a circular pupil in spatial-frequency space with cutoff NA/lambda), +4) compute irradiance I(x,y) = sum_{c in {x,y,z}} |E_c(x,y)|^2, +5) optionally bin/average to camera pixels (camera integration / sampling), +6) incoherently sum irradiances for LEDs in left/right/up/down half-plane patterns. + +Key implementation requirement +------------------------------ +A 64 x 64 board has 4096 LEDs. This module simulates each active LED exactly once +(after circular/annular mask and optional subsampling), and accumulates that LED's +irradiance into any applicable DPC pattern sums (left/right/up/down). This avoids +redundant Mie computations across patterns. + +Optical model included (requested) +---------------------------------- +- Objective: circular pupil with coherent cutoff f_c = NA / lambda (cycles/µm). +- Tube lens: assumed ideal/infinity-corrected; its role is magnification. +- Camera: sampling and optional pixel integration via binning from an oversampled grid. + +This is a first-order synthetic imaging model suitable for generating test data and +debugging inverse-model implementations. 
+ +Zarr output +----------- +Helpers are provided to write a Zarr dataset in the exact layout expected by the +unit tests in `test_dpc_idt_reconstruction.py`: +- root array: /dpc, shape (4, ny, nx), dtype float32 +- root attrs: wavelength_um, na_obj, camera_pixel_um, magnification, n_medium, + led_grid_shape, pattern_order, plus optional nz and z_span_um + +Dependencies +------------ +- mie_fields.py (user-provided): provides `mie_efield(...)`. + +""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Literal, Sequence + +import numpy as np +import hashlib +import json +import os +from tqdm import trange + +try: + import cupy as cp # Optional; mie_fields supports GPU mode +except ImportError: # pragma: no cover + cp = None + +try: + from localize_psf.camera import simulated_img +except ImportError: # pragma: no cover + simulated_img = None + +from mcsim.analysis.mie_fields import mie_efield +from mcsim.analysis.fft import ft2, ift2 +try: + from mcsim.analysis.field_prop import propagate_homogeneous +except Exception as _e: # pragma: no cover + propagate_homogeneous = None + + +if cp: + array = np.ndarray | cp.ndarray +else: + array = np.ndarray + + +PatternName = Literal["left", "right", "up", "down"] + + +@dataclass(frozen=True, slots=True) +class SphereSpec: + """ + Optical / geometric description of the sphere for Mie simulation. + + Parameters + ---------- + radius_um : float + Sphere radius (µm). Forwarded to `mie_fields.mie_efield` and used to pick + a safe evaluation plane outside the sphere. + n_sphere : complex or float + Complex refractive index of the sphere (real part for phase, imaginary for + absorption). Passed directly to `mie_fields.mie_efield`. + n_medium : float + Surrounding medium refractive index used by `mie_efield` and + `na_to_incidence_angles`. 
+ """ + radius_um: float + n_sphere: complex | float + n_medium: float = 1.0 + + +@dataclass(frozen=True, slots=True) +class SimulationSpace: + """ + Grid used for field simulation (object space). + + Parameters + ---------- + dxy_um : float + Sampling pitch (µm) on the simulation grid passed to `mie_fields.mie_efield` + and `field_prop.propagate_homogeneous`. + esize : tuple[int, int] + Grid shape (ny, nx) for the simulated field prior to camera binning. + z_plane_um : float or None + Axial distance from the sphere center for field evaluation before refocus. + If None, a plane just outside the sphere is chosen for stability. + """ + dxy_um: float + esize: tuple[int, int] + z_plane_um: float | None = None + + +@dataclass(frozen=True, slots=True) +class CameraSpace: + """ + Description of the camera grid and noise model (image space). + + Parameters + ---------- + pixel_um : float + Camera pixel pitch (µm) in the camera plane; combined with `magnification` + to compute the `bin_size` for `localize_psf.camera.simulated_img`. + magnification : float + Object-to-camera magnification; larger values increase the bin factor between + simulation and camera grids. + shape : tuple[int, int] + Final camera image shape (ny, nx) returned by `simulated_img`. + psf : array or None + Optional PSF forwarded to `simulated_img` for blurring. + apodization : array or int or float + Apodization factor forwarded to `simulated_img` during PSF blurring (1 disables). + gains : array or float + Multiplicative conversion from photons to ADU (ADU/e) applied in `simulated_img`. + offsets : array or float + Additive camera offset (ADU) applied after gain and readout noise. + readout_noise_sds : array or float + Standard deviation (ADU) of Gaussian readout noise added in `simulated_img`. + photon_shot_noise : bool + If True, enable Poisson shot noise in `simulated_img`. + saturation : int or None + Clip simulated images above this value in `simulated_img`. 
+ image_is_integer : bool + If True, round the final simulated image to integers in `simulated_img`. + """ + pixel_um: float + magnification: float + shape: tuple[int, int] + psf: array | None = None # type: ignore + apodization: array | int | float = 1 # type: ignore + gains: array | float = 1.0 # type: ignore + offsets: array | float = 0.0 # type: ignore + readout_noise_sds: array | float = 0.0 # type: ignore + photon_shot_noise: bool = False + saturation: int | None = None + image_is_integer: bool = False + + @property + def object_pixel_um(self) -> float: + return float(self.pixel_um) / float(self.magnification) + + +def _get_xp(use_gpu: bool): + """ + Select NumPy or CuPy module based on GPU usage. + + Parameters + ---------- + use_gpu : bool + If True and CuPy is available, return `cupy`; otherwise return `numpy`. + + Returns + ------- + module + Backend array module. + """ + if use_gpu and (cp is not None): + return cp + return np + + +def _to_xp(x: array, *, use_gpu: bool) -> array: # type: ignore + """ + Explicitly move arrays to the requested backend. + + Parameters + ---------- + x : array + Input array (NumPy, CuPy, or dask). + use_gpu : bool + If True, ensure the output is a CuPy array; otherwise ensure NumPy. + + Returns + ------- + array + Array on the requested backend, avoiding implicit host/device transfers. + """ + if use_gpu: + if cp is None: + raise ImportError("use_gpu=True requested but CuPy is not available.") + return cp.asarray(x) + if cp is not None and isinstance(x, cp.ndarray): + return cp.asnumpy(x) + return x + + +def _nan_to_zero(x: array, *, use_gpu: bool) -> array: # type: ignore + """ + Replace NaN/Inf with 0 on the requested backend. + + Parameters + ---------- + x : array + Input array possibly containing NaN/Inf. + use_gpu : bool + If True, operate with CuPy; otherwise NumPy. + + Returns + ------- + array + Cleaned array on the requested backend with NaN/Inf replaced by 0. 
+ """ + x = _to_xp(x, use_gpu=use_gpu) + xp = cp if (use_gpu and cp is not None) else np + # Both NumPy and CuPy implement nan_to_num (including for complex dtypes). + return xp.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) + +def _json_dumps_stable(obj) -> str: + """ + Stable JSON dump for hashing parameters. + + Parameters + ---------- + obj : Any + Object to serialize. + + Returns + ------- + str + JSON string with deterministic key ordering. + """ + return json.dumps(obj, sort_keys=True, separators=(",", ":"), default=str) + + +def _cache_key(params: dict) -> str: + """ + Create a short stable cache key for a simulation run. + + Parameters + ---------- + params : dict + Parameter dictionary to hash. + + Returns + ------- + str + SHA1-based short key (first 16 hex chars). + """ + import hashlib + + s = _json_dumps_stable(params).encode("utf-8") + return hashlib.sha1(s).hexdigest()[:16] + + +def _cache_run_dir(cache_dir: str | Path, params: dict) -> Path: + """ + Create (or reuse) the cache directory for a parameter set. + + Parameters + ---------- + cache_dir : str or Path + Base directory for simulation caches. + params : dict + Parameter dictionary hashed to form a unique run key. + + Returns + ------- + Path + Path to the run-specific cache directory containing metadata and Zarr stores. + """ + cache_dir = Path(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + key = _cache_key(params) + run_dir = cache_dir / key + run_dir.mkdir(parents=True, exist_ok=True) + meta_path = run_dir / "meta.json" + if not meta_path.exists(): + meta_path.write_text(_json_dumps_stable(params) + "\n") + return run_dir + + +def _open_led_cache_zarr( + run_dir: Path, + *, + n_led: int, + n_planes: int, + ny: int, + nx: int, + chunks: tuple[int, int, int, int], +): + """ + Open (or create) a per-LED irradiance cache as a Zarr store. + + Parameters + ---------- + run_dir : Path + Directory in which to create the cache. + n_led : int + Number of LED entries to cache. 
+ n_planes : int + Number of axial planes per LED (for focal stacks). + ny, nx : int + Camera image dimensions. + chunks : tuple[int, int, int, int] + Chunk shape for the cached array (n_led, n_planes, ny, nx). + + Returns + ------- + group, I_arr, done + Zarr group, cached camera irradiance array, and completion mask. + + Notes + ----- + We cache the binned camera irradiance (float32) to reduce size and avoid recomputation. + """ + try: + import zarr # type: ignore + except Exception as e: # pragma: no cover + raise ImportError("zarr is required for caching") from e + + store_path = run_dir / "led_cache.zarr" + g = zarr.open_group(str(store_path), mode="a") + + g.attrs["n_led"] = int(n_led) + g.attrs["n_planes"] = int(n_planes) + g.attrs["ny"] = int(ny) + g.attrs["nx"] = int(nx) + + if "I_cam" in g: + I_arr = g["I_cam"] + if tuple(I_arr.shape) != (int(n_led), int(n_planes), int(ny), int(nx)): + raise ValueError( + f"Existing LED cache shape {tuple(I_arr.shape)} != {(int(n_led), int(n_planes), int(ny), int(nx))}" + ) + else: + create_kwargs = { + "shape": (int(n_led), int(n_planes), int(ny), int(nx)), + "dtype": "float32", + } + # Use chunks kwarg for compatibility with current Zarr API + I_arr = g.create_array("I_cam", chunks=chunks, **create_kwargs) # type: ignore[call-arg] + + if "done" in g: + done = g["done"] + if tuple(done.shape) != (int(n_led),): + raise ValueError(f"Existing LED cache done shape {tuple(done.shape)} != {(int(n_led),)}") + else: + try: + done = g.create_array("done", shape=(int(n_led),), dtype="u1", overwrite=False) + except (TypeError, AttributeError): + done = g.create_array("done", shape=(int(n_led),), dtype="u1") + done[...] = 0 + + return g, I_arr, done + + +def _zarr_write_array(g, name: str, data: np.ndarray, *, dtype: str = "float32", chunks=None, overwrite: bool = True): + """ + Write an array using Zarr v3 `Group.create_array`. + + Parameters + ---------- + g : zarr.Group + Target group. + name : str + Dataset name. 
+ data : np.ndarray + Array to write (converted to dtype). + dtype : str + Target dtype for the stored array. + chunks : tuple[int, ...] or None + Chunk shape passed as `chunk_shape` to `create_array`. If None, let Zarr decide. + overwrite : bool + Included for API compatibility; Zarr v3 `create_array` overwrites only if allowed by mode. + + Returns + ------- + zarr.Array + Written array. + """ + shape = tuple(int(s) for s in data.shape) + create_kwargs = { + "shape": shape, + "dtype": dtype, + } + + if chunks is not None: + arr = g.create_array(name, chunks=chunks, **create_kwargs) # type: ignore[call-arg] + else: + arr = g.create_array(name, **create_kwargs) + arr[...] = data + return arr + +def sample_pixel_size_um(camera_pixel_um: float, magnification: float, *, oversample: int = 1) -> float: + """ + Convert camera pixel size to object-plane pixel size, optionally oversampled. + + Parameters + ---------- + camera_pixel_um : float + Camera pixel pitch (micrometers). This matches `CameraSpace.pixel_um`. + magnification : float + System magnification (object→camera). Higher magnification reduces the + effective object-plane sampling. + oversample : int + Oversampling factor. If >1, the object-plane sampling is divided by this + factor (finer simulation grid), and images can be binned back to camera + sampling using `bin2_average` or `simulated_img(bin_size=oversample)`. + + Returns + ------- + dxy_um : float + Object-plane pixel size (micrometers). + """ + if magnification <= 0: + raise ValueError("magnification must be > 0") + if oversample < 1: + raise ValueError("oversample must be >= 1") + return float(camera_pixel_um) / float(magnification) / float(oversample) + + + +def _default_z_plane_um(*, sphere_radius_um: float, dxy_um: float, wavelength_um: float) -> float: + """ + Choose a stable observation plane for Mie evaluation. + + The Mie scattered-field expansion is unstable near r=0; this picks an exterior + plane to evaluate fields. 
+ + Parameters + ---------- + sphere_radius_um : float + Sphere radius (µm). + dxy_um : float + Object-plane sampling (µm). + wavelength_um : float + Vacuum wavelength (µm). + + Returns + ------- + float + Axial distance (µm) from sphere center to evaluation plane. + """ + safety = float(max(float(dxy_um), 0.25 * float(wavelength_um))) + return float(sphere_radius_um) + safety + + +def _safe_z_plane_um( + z_plane_um: float, + *, + sphere_radius_um: float, + dxy_um: float, + wavelength_um: float, +) -> float: + """Ensure the Mie evaluation plane lies outside the sphere. + + If |z_plane_um| is less than radius + safety_margin, we override it to + sign(z_plane_um) * (radius + safety_margin). If z_plane_um==0, we choose +. + + This avoids evaluating the scattered-field expansion at/near r=0 and prevents + overflow/NaN in irradiance formation. + """ + safety = float(max(float(dxy_um), 0.01 * float(wavelength_um))) + z_min = float(sphere_radius_um) + safety + z = float(z_plane_um) + if abs(z) < z_min: + sgn = 1.0 if z == 0.0 else float(np.sign(z)) + return sgn * z_min + return z + + + +def make_led_na_positions( + ny_led: int, + nx_led: int, + *, + na_obj: float, + inner_na: float = 0.0, + include_center: bool = False, +) -> np.ndarray: + """ + Generate LED positions in NA space for a centered rectangular board, masked to an inscribed circle. + + Assumptions + ---------- + - Board is centered on the objective. + - The outermost LED inside the circular mask maps to `na_obj`. + - LEDs lie on a regular grid; positions are scaled to NA units. + + Parameters + ---------- + ny_led, nx_led : int + LED grid shape (64, 64 for the requested board). + na_obj : float + Objective NA; sets max LED radius in NA space. Passed later to + `DPCMieSimulator` to form the pupil cutoff. + inner_na : float + Inner NA radius (>=0) to form an annulus. Default 0 (filled disk). + include_center : bool + If False, excludes the center LED (NA=0). 
+ + Returns + ------- + na_xy : np.ndarray + array of shape (N, 2) with columns (NA_x, NA_y). + """ + if ny_led < 1 or nx_led < 1: + raise ValueError("ny_led and nx_led must be >= 1") + if inner_na < 0 or inner_na >= na_obj: + raise ValueError("inner_na must satisfy 0 <= inner_na < na_obj") + + cy = (ny_led - 1) / 2.0 + cx = (nx_led - 1) / 2.0 + + yy, xx = np.meshgrid( + np.arange(ny_led, dtype=np.float32), + np.arange(nx_led, dtype=np.float32), + indexing="ij", + ) + x = xx - cx + y = yy - cy + r = np.sqrt(x * x + y * y) + + # Use inscribed circle to avoid corners. + r_circle = float(min(cx, cy)) if (cx > 0 and cy > 0) else float(np.max(r)) + mask = r <= r_circle + + if not include_center: + mask = mask & (r > 0) + + r_mask_max = float(np.max(r[mask])) if np.any(mask) else 1.0 + scale = float(na_obj) / r_mask_max + + na_x = x[mask] * scale + na_y = y[mask] * scale + na_r = np.sqrt(na_x * na_x + na_y * na_y) + + ann = na_r >= float(inner_na) + na_xy = np.stack([na_x[ann], na_y[ann]], axis=1) + return na_xy + + +def split_dpc_patterns( + na_xy: np.ndarray, + *, + order: Sequence[PatternName] = ("left", "right", "up", "down"), +) -> dict[PatternName, np.ndarray]: + """ + Split LED NA positions into canonical DPC half-circle patterns. + + Pattern membership is based on the sign of NA components: + - left: NA_x < 0 + - right: NA_x > 0 + - up: NA_y > 0 + - down: NA_y < 0 + + Parameters + ---------- + na_xy : np.ndarray + LED positions in NA units, shape (N, 2). + order : Sequence[PatternName] + Must be a permutation of ("left","right","up","down"). Used for validation. + + Returns + ------- + patterns : dict[PatternName, np.ndarray] + Mapping pattern name -> LED NA positions in that pattern, each of shape (N_p, 2). + These positions are later iterated in `DPCMieSimulator.simulate_patterns` + when calling `mie_efield` for each LED. + """ + allowed: tuple[PatternName, ...] 
= ("left", "right", "up", "down") + if len(order) != 4 or set(order) != set(allowed): + raise ValueError(f"order must be a permutation of {allowed}; got {tuple(order)}") + + na_x = na_xy[:, 0] + na_y = na_xy[:, 1] + + masks = { + "left": na_x < 0, + "right": na_x > 0, + "up": na_y > 0, + "down": na_y < 0, + } + + out: dict[PatternName, np.ndarray] = {} + for name in allowed: + pts = na_xy[masks[name]] + if pts.shape[0] == 0: + raise ValueError(f"pattern '{name}' has zero LEDs (check board geometry).") + out[name] = pts + + return out + + +def na_to_incidence_angles( + na_x: float, + na_y: float, + *, + n_medium: float = 1.0, +) -> tuple[float, float]: + """ + Convert (NA_x, NA_y) into incidence direction angles for `mie_efield`. + + We interpret NA components as: + NA_x = n_medium * sin(theta) * cos(phi) + NA_y = n_medium * sin(theta) * sin(phi) + + therefore: + sin(theta) = sqrt(NA_x^2 + NA_y^2) / n_medium + phi = atan2(NA_y, NA_x) + + Parameters + ---------- + na_x, na_y : float + NA components (dimensionless). + n_medium : float + Medium refractive index. + + Returns + ------- + beam_theta, beam_phi : float + Angles (radians) to pass as `beam_theta` and `beam_phi`. + `beam_psi` can be set to 0 for x-polarized input in mie_fields. + These are forwarded unmodified to `mie_fields.mie_efield`. + """ + na_r = float(np.sqrt(na_x * na_x + na_y * na_y)) + s = na_r / float(n_medium) + s = float(np.clip(s, 0.0, 1.0)) + theta = float(np.arcsin(s)) + phi = float(np.arctan2(na_y, na_x)) + return theta, phi + + +def _fftshift_freq_grid(n: int, d: float, xp) -> array: # type: ignore + """ + Frequency samples (cycles/µm) aligned to fftshifted spectra. + """ + f = xp.fft.fftfreq(n, d=d) + return xp.fft.fftshift(f) + + +def make_pupil( + ny: int, + nx: int, + *, + dxy_um: float, + wavelength_um: float, + na_obj: float, + xp, +) -> array: # type: ignore + """ + Create a circular coherent pupil mask for the objective. 
+ + Parameters + ---------- + ny, nx : int + Output mask dimensions. + dxy_um : float + Sampling pitch (µm) of the simulation grid. + wavelength_um : float + Vacuum wavelength (µm). + na_obj : float + Objective NA determining cutoff f_c = NA / wavelength. + xp : module + Backend array module (NumPy or CuPy). + + Returns + ------- + array + Pupil mask of shape (ny, nx) with ones inside the cutoff and zeros outside, + used by `apply_objective_imaging`. + """ + fx = _fftshift_freq_grid(nx, dxy_um, xp) # (nx,) + fy = _fftshift_freq_grid(ny, dxy_um, xp) # (ny,) + fxx, fyy = xp.meshgrid(fx, fy, indexing="xy") + fr = xp.sqrt(fxx * fxx + fyy * fyy) + + fc = float(na_obj) / float(wavelength_um) + return (fr <= fc).astype(xp.float32) + + +# -----------------------------------------------------------------------------# +# Coherent imaging helpers (pupil + optional apodization) with caching +# -----------------------------------------------------------------------------# + +_APOD_CACHE: dict[tuple, array] = {} # type: ignore + + +def _tukey_1d(n: int, alpha: float, xp) -> array: # type: ignore + """ + Create a 1D Tukey window on the requested backend. + + Parameters + ---------- + n : int + Number of samples (must be > 0). + alpha : float + Taper parameter in [0,1]; 0 yields a rectangular window, 1 yields Hann. + xp : module + Backend array module (NumPy or CuPy). + + Returns + ------- + array + Window of shape (n,) on the requested backend. 
+ """ + if n <= 0: + raise ValueError("n must be > 0") + if alpha <= 0: + return xp.ones((n,), dtype=xp.float32) + if alpha >= 1: + # Hann + x = xp.arange(n, dtype=xp.float32) + return (0.5 - 0.5 * xp.cos(2 * xp.pi * x / (n - 1))).astype(xp.float32) + + x = xp.linspace(0.0, 1.0, n, dtype=xp.float32) + w = xp.ones((n,), dtype=xp.float32) + edge = alpha / 2.0 + + m1 = x < edge + m2 = x > (1.0 - edge) + # rising cosine + w[m1] = 0.5 * (1.0 + xp.cos(xp.pi * ((2.0 * x[m1] / alpha) - 1.0))) + # falling cosine + w[m2] = 0.5 * (1.0 + xp.cos(xp.pi * ((2.0 * x[m2] / alpha) - (2.0 / alpha) + 1.0))) + return w.astype(xp.float32) + + +def _get_cached_apodization( + ny: int, + nx: int, + *, + alpha: float, + use_gpu: bool, +): + """ + Retrieve or build a cached Tukey apodization window on the requested backend. + + Parameters + ---------- + ny, nx : int + Window dimensions. + alpha : float + Tukey alpha parameter passed to `_tukey_1d`. + use_gpu : bool + If True, cache CuPy arrays; otherwise NumPy. + + Returns + ------- + array + 2D apodization window cached for reuse. + """ + xp = cp if (use_gpu and cp is not None) else np + key = (int(ny), int(nx), float(alpha), bool(use_gpu)) + apo = _APOD_CACHE.get(key) + if apo is None: + wy = _tukey_1d(int(ny), float(alpha), xp) + wx = _tukey_1d(int(nx), float(alpha), xp) + apo = (wy[:, None] * wx[None, :]).astype(xp.float32, copy=False) + _APOD_CACHE[key] = apo + return apo + + +def apply_objective_imaging( + e_vec: array, # type: ignore + *, + dxy_um: float, + wavelength_um: float, + na_obj: float, + use_gpu: bool = False, +) -> array: # type: ignore + """ + Apply a simple coherent imaging operator to a vector field. + + This models an ideal infinity-corrected objective + tube lens as a circular pupil + (coherent transfer) in spatial-frequency space. + `dxy_um`, `wavelength_um`, and `na_obj` jointly define the pupil generated by + `make_pupil`, which is applied component-wise to `e_vec` before returning to + the spatial domain. 
The resulting field is later converted to intensity and + handed to `localize_psf.camera.simulated_img` for camera binning/noise. + + GPU/CPU behavior + ---------------- + This function explicitly converts inputs to the requested backend to avoid + implicit host<->device transfers. + + Parameters + ---------- + e_vec : array + Vector field, shape (3, ny, nx). + dxy_um : float + Sampling (µm) for e_vec grid. + wavelength_um : float + Vacuum wavelength (µm). + na_obj : float + Objective NA. + use_gpu : bool + If True, uses CuPy and runs FFTs on the GPU (requires CuPy). + + Returns + ------- + e_img : array + Filtered vector field, shape (3, ny, nx), on the requested backend. + """ + if e_vec.shape[0] != 3: + raise ValueError(f"e_vec must have shape (3, ny, nx), got {e_vec.shape}") + + # Ensure e_vec lives on the requested backend. + e_vec = _to_xp(e_vec, use_gpu=use_gpu) + xp = cp if (use_gpu and cp is not None) else np + + ny, nx = int(e_vec.shape[-2]), int(e_vec.shape[-1]) + P = make_pupil( + ny, + nx, + dxy_um=float(dxy_um), + wavelength_um=float(wavelength_um), + na_obj=float(na_obj), + xp=xp, + ) + + # Apply per-component in Fourier domain (centered FT). + out = xp.empty_like(e_vec) + for c in range(3): + Ec = _nan_to_zero(e_vec[c], use_gpu=use_gpu) + Ec_ft = ft2(Ec, axes=(-2, -1), shift=True, adjoint=False) + Ec_ft = Ec_ft * P + out[c] = ift2(Ec_ft, axes=(-2, -1), shift=True, adjoint=False) + return out + + +def bin2_average(im: array, factor: int, *, use_gpu: bool = False) -> array: # type: ignore + """ + Bin/average a 2D image by an integer factor. + + This approximates camera pixel integration when `factor` corresponds to an + oversampling ratio (e.g., when simulation sampling is finer than camera pixels). + It mirrors the binning performed inside `localize_psf.camera.simulated_img` + when `bin_size` matches `factor`. 
+ + Backend behavior + ---------------- + This function explicitly converts `im` to the requested backend to avoid + implicit host<->device transfers. + + Parameters + ---------- + im : array + 2D image, shape (ny, nx). + factor : int + Binning factor (>=1). + use_gpu : bool + If True, operate on the GPU (requires CuPy). + + Returns + ------- + binned : array + Binned image, shape (ny//factor, nx//factor), on the requested backend. + """ + if factor < 1: + raise ValueError("factor must be >= 1") + if factor == 1: + return _to_xp(im, use_gpu=use_gpu) + + im = _to_xp(im, use_gpu=use_gpu) + xp = cp if (use_gpu and cp is not None) else np + + ny, nx = int(im.shape[-2]), int(im.shape[-1]) + ny2 = (ny // factor) * factor + nx2 = (nx // factor) * factor + + im2 = im[:ny2, :nx2] + im2 = im2.reshape(ny2 // factor, factor, nx2 // factor, factor) + return xp.mean(im2, axis=(1, 3)) + + +class DPCMieSimulator: + """ + End-to-end simulator that keeps a clear boundary between the simulation grid + (Mie field generation + propagation) and the camera grid (pixel binning + noise). + + Parameters + ---------- + wavelength_um : float + Vacuum wavelength (µm) for field generation and pupil cutoff. + na_obj : float + Objective NA used in `apply_objective_imaging`. + sphere : SphereSpec + Sphere properties forwarded to `mie_fields.mie_efield`. + simulation : SimulationSpace + Simulation grid settings (sampling, size, z-plane). + camera : CameraSpace + Camera grid/noise settings forwarded to `localize_psf.camera.simulated_img`. + led_grid_shape : tuple[int, int], optional + LED board dimensions. + inner_na : float, optional + Inner NA for annular illumination mask. + include_center_led : bool, optional + Whether to include NA=0 LED. + pattern_order : Sequence[str], optional + Order of DPC patterns in outputs. + normalize_by_led_count : bool, optional + Normalize patterns by contributing LED count. + led_subsample : int, optional + Subsample factor for LEDs simulated. 
+ use_gpu : bool, optional + If True, use CuPy-backed computations where available. + mie_kwargs : dict or None, optional + Extra kwargs passed to `mie_fields.mie_efield`. + cache_dir : str or Path or None, optional + Directory for Zarr per-LED camera irradiance cache. + reuse_cache : bool, optional + If True, reuse cached per-LED images when present. + """ + + def __init__( + self, + *, + wavelength_um: float, + na_obj: float, + sphere: SphereSpec, + simulation: SimulationSpace, + camera: CameraSpace, + led_grid_shape: tuple[int, int] = (64, 64), + inner_na: float = 0.0, + include_center_led: bool = False, + pattern_order: Sequence[PatternName] = ("left", "right", "up", "down"), + normalize_by_led_count: bool = True, + led_subsample: int = 1, + use_gpu: bool = False, + mie_kwargs: dict | None = None, + cache_dir: str | Path | None = None, + reuse_cache: bool = True, + exposure_time_ms: float = 1.0, + illumination_photons_per_s_per_um2: float = 1.0, + focal_stack_planes: int | None = None, + focal_stack_step_um: float = 0.5, + ): + if simulated_img is None: + raise ImportError("localize_psf.camera.simulated_img is required but not importable.") + if propagate_homogeneous is None: + raise ImportError("mcsim.analysis.field_prop.propagate_homogeneous is required but not importable.") + + allowed: tuple[PatternName, ...] 
= ("left", "right", "up", "down") + if len(pattern_order) != 4 or set(pattern_order) != set(allowed): + raise ValueError(f"pattern_order must be a permutation of {allowed}; got {tuple(pattern_order)}") + if led_subsample < 1: + raise ValueError("led_subsample must be >= 1") + + self.wavelength_um = float(wavelength_um) + self.na_obj = float(na_obj) + self.sphere = sphere + self.simulation = simulation + self.camera = camera + self.led_grid_shape = (int(led_grid_shape[0]), int(led_grid_shape[1])) + self.inner_na = float(inner_na) + self.include_center_led = bool(include_center_led) + self.pattern_order = pattern_order + self.normalize_by_led_count = bool(normalize_by_led_count) + self.led_subsample = int(led_subsample) + self.use_gpu = bool(use_gpu) + self.mie_kwargs = mie_kwargs or {} + self.cache_dir = cache_dir + self.reuse_cache = bool(reuse_cache) + self.exposure_time_s = float(exposure_time_ms) / 1000.0 + self.illumination_photons_per_s_per_um2 = float(illumination_photons_per_s_per_um2) + if focal_stack_planes is None or focal_stack_planes <= 1: + self.focal_offsets_um = np.array([0.0], dtype=float) + else: + n = int(focal_stack_planes) + offsets = (np.arange(n) - (n - 1) / 2.0) * float(focal_stack_step_um) + self.focal_offsets_um = offsets.astype(float) + + self.xp = _get_xp(use_gpu) + self.camera_bin_factor = self._camera_bin_factor() + + def _camera_bin_factor(self) -> int: + """ + Integer ratio between camera sampling (object plane) and simulation sampling. + + Returns + ------- + int + Bin factor (`camera.object_pixel_um / simulation.dxy_um`) used as + `bin_size` for `localize_psf.camera.simulated_img`. 
+ """ + ratio = float(self.camera.object_pixel_um) / float(self.simulation.dxy_um) + ratio_round = int(round(ratio)) + if ratio_round < 1 or not np.isclose(ratio, ratio_round, rtol=1e-3, atol=1e-6): + raise ValueError( + f"Camera pixel size / simulation sampling must be a positive integer; got ratio={ratio:.4f}" + ) + return ratio_round + + def _field_plane_z(self) -> float: + """ + Choose a stable evaluation plane for the Mie field. + + This is the `dz` plane passed to `mie_fields.mie_efield`. If the user provides + `simulation.z_plane_um`, it is validated with `_safe_z_plane_um`; otherwise a + default just outside the sphere surface is selected. + """ + if self.simulation.z_plane_um is not None: + return float(_safe_z_plane_um( + self.simulation.z_plane_um, + sphere_radius_um=float(self.sphere.radius_um), + dxy_um=float(self.simulation.dxy_um), + wavelength_um=self.wavelength_um, + )) + return _safe_z_plane_um( + _default_z_plane_um( + sphere_radius_um=float(self.sphere.radius_um), + dxy_um=float(self.simulation.dxy_um), + wavelength_um=self.wavelength_um, + ), + sphere_radius_um=float(self.sphere.radius_um), + dxy_um=float(self.simulation.dxy_um), + wavelength_um=self.wavelength_um, + ) + + def _simulate_led_ground_truth(self, na_x: float, na_y: float) -> array: # type: ignore + """ + Simulate pupil-filtered field on the simulation grid for a single LED (before camera). + + Parameters + ---------- + na_x, na_y : float + Illumination NA components converted to angles for `mie_efield`. + + Returns + ------- + array + Complex field on the simulation grid after refocus and pupil filtering. + + Notes + ----- + Uses `mie_fields.mie_efield` with `simulation.dxy_um`, `simulation.esize`, + and `_field_plane_z()`, refocuses via `propagate_homogeneous`, and applies + `apply_objective_imaging`. 
+ """ + xp = self.xp + theta, phi = na_to_incidence_angles( + float(na_x), + float(na_y), + n_medium=float(self.sphere.n_medium), + ) + + z_plane = self._field_plane_z() + ny_sim, nx_sim = self.simulation.esize + + e_scatt_vec, _, e_inc_vec, _ = mie_efield( + float(self.wavelength_um), + float(self.sphere.n_medium), + float(self.sphere.radius_um), + complex(self.sphere.n_sphere), + float(self.simulation.dxy_um), + (int(ny_sim), int(nx_sim)), + dz=float(z_plane), + beam_theta=float(theta), + beam_phi=float(phi), + use_gpu=bool(self.use_gpu), + **self.mie_kwargs, + ) + + # Scalar field (x-component) and gentle apodization to suppress wrap-around. + e0 = (e_scatt_vec[0] + e_inc_vec[0]).astype(xp.complex64, copy=False) + apod = _get_cached_apodization(int(ny_sim), int(nx_sim), alpha=0.1, use_gpu=bool(self.use_gpu)) + e0 = (e0 * apod.astype(e0.real.dtype, copy=False)).astype(xp.complex64, copy=False) + + e_in = e0[None, None, :, :] # shape (1, 1, ny, nx) for propagate_homogeneous + e_prop = propagate_homogeneous( + e_in, + [-float(z_plane)], + float(self.sphere.n_medium), + (float(self.simulation.dxy_um), float(self.simulation.dxy_um)), + float(self.wavelength_um), + ) + + if e_prop.ndim >= 5: + e_foc = e_prop[..., 0, :, :][0, 0] + elif e_prop.ndim == 4: + e_foc = e_prop[0, 0] + else: + e_foc = e_prop + + # Reuse the existing objective pupil helper. + e_vec = xp.zeros((3,) + e_foc.shape, dtype=e_foc.dtype) + e_vec[0] = e_foc + e_img_vec = apply_objective_imaging( + e_vec, + dxy_um=float(self.simulation.dxy_um), + wavelength_um=float(self.wavelength_um), + na_obj=float(self.na_obj), + use_gpu=bool(self.use_gpu), + ) + return e_img_vec[0] + + def _render_camera(self, irradiance: array) -> array: # type: ignore + """ + Project simulation irradiance onto the camera grid (binning + optional noise). + + Parameters + ---------- + irradiance : array + Simulation-grid irradiance (pre-camera). 
+ + Returns + ------- + array + Camera-grid image after binning/noise from `localize_psf.camera.simulated_img`. + + Notes + ----- + `bin_size` is derived from `camera.object_pixel_um / simulation.dxy_um`; all + camera noise parameters (gains, offsets, readout_noise_sds, photon_shot_noise, + saturation, image_is_integer) are forwarded from `CameraSpace`. PSF input is + not supported; the coherent transfer is modeled via the pupil internally. + """ + xp = self.xp + gt = xp.asarray(irradiance, dtype=xp.float32) + gt = gt * self.exposure_time_s + + gains = xp.asarray(self.camera.gains, dtype=gt.dtype) + offsets = xp.asarray(self.camera.offsets, dtype=gt.dtype) + readout = xp.asarray(self.camera.readout_noise_sds, dtype=gt.dtype) + + psf = None + if self.camera.psf is not None: + psf = xp.asarray(self.camera.psf, dtype=gt.dtype) + + apo = self.camera.apodization + if isinstance(apo, (np.ndarray,)) or (cp is not None and isinstance(apo, cp.ndarray)): + apo = xp.asarray(apo, dtype=gt.dtype) + + img, _ = simulated_img( + ground_truth=gt, + gains=gains, + offsets=offsets, + readout_noise_sds=readout, + psf=psf, + photon_shot_noise=bool(self.camera.photon_shot_noise), + bin_size=int(self.camera_bin_factor), + apodization=apo, + saturation=self.camera.saturation, + image_is_integer=bool(self.camera.image_is_integer), + ) + return xp.asarray(img, dtype=xp.float32) + + def _simulate_led(self, na_x: float, na_y: float) -> array: # type: ignore + """ + Full pipeline for one LED: field generation → optional defocus → camera rendering. + + Returns a focal stack with shape (n_planes, ny_cam, nx_cam). 
+ """ + xp = self.xp + field0 = self._simulate_led_ground_truth(na_x, na_y) + + flux_scale = float(self.illumination_photons_per_s_per_um2) * float(self.simulation.dxy_um) ** 2 + stack = [] + for dz in self.focal_offsets_um: + if dz == 0: + f_def = field0 + else: + e_in = field0[None, None, :, :] + e_prop = propagate_homogeneous( + e_in, + [float(dz)], + float(self.sphere.n_medium), + (float(self.simulation.dxy_um), float(self.simulation.dxy_um)), + float(self.wavelength_um), + ) + if e_prop.ndim >= 5: + f_def = e_prop[..., 0, :, :][0, 0] + elif e_prop.ndim == 4: + f_def = e_prop[0, 0] + else: + f_def = e_prop + + I_sim = (xp.abs(f_def) ** 2).astype(xp.float32, copy=False) * flux_scale + stack.append(self._render_camera(I_sim)) + + return xp.stack(stack, axis=0) + + def simulate_patterns(self) -> tuple[array, dict[str, array]]: # type: ignore + """ + Simulate all LEDs and accumulate into DPC patterns. + + Returns + ------- + tuple[array, dict[str, array]] + DPC stack ordered by `pattern_order` and metadata including NA positions, + grid sizes, bin factor, and LED counts. + + Notes + ----- + - LED NA grid and pattern membership come from `make_led_na_positions` and + `split_dpc_patterns` using `led_grid_shape`, `inner_na`, `include_center_led`, + and `pattern_order`. + - Each LED image is produced by `_simulate_led` (internally + `mie_fields.mie_efield` → `propagate_homogeneous` → `apply_objective_imaging` + → `localize_psf.camera.simulated_img`). + - Optional per-LED cache stored in Zarr (v3 only). + """ + allowed: tuple[PatternName, ...] 
= ("left", "right", "up", "down") + ny_led, nx_led = self.led_grid_shape + na_xy_all = make_led_na_positions( + ny_led, + nx_led, + na_obj=float(self.na_obj), + inner_na=float(self.inner_na), + include_center=bool(self.include_center_led), + ) + patterns_full = split_dpc_patterns(na_xy_all, order=self.pattern_order) + + na_xy = na_xy_all[:: self.led_subsample] + na_x = na_xy[:, 0] + na_y = na_xy[:, 1] + + membership = { + "left": na_x < 0, + "right": na_x > 0, + "up": na_y > 0, + "down": na_y < 0, + } + + if self.include_center_led: + center_mask = (na_x == 0) & (na_y == 0) + if np.any(center_mask): + for k in membership: + membership[k] = membership[k] | center_mask + + xp = self.xp + ny_cam, nx_cam = self.camera.shape + n_planes = int(len(self.focal_offsets_um)) + acc = {k: xp.zeros((n_planes, ny_cam, nx_cam), dtype=xp.float32) for k in allowed} + led_counts = {k: int(np.sum(np.asarray(membership[k], dtype=bool))) for k in allowed} + + run_dir = None + I_cache = None + done_cache = None + + if self.cache_dir is not None: + cache_params = { + "wavelength_um": float(self.wavelength_um), + "na_obj": float(self.na_obj), + "led_grid_shape": [int(self.led_grid_shape[0]), int(self.led_grid_shape[1])], + "camera_pixel_um": float(self.camera.pixel_um), + "magnification": float(self.camera.magnification), + "camera_shape": [int(ny_cam), int(nx_cam)], + "simulation_dxy_um": float(self.simulation.dxy_um), + "simulation_esize": [int(self.simulation.esize[0]), int(self.simulation.esize[1])], + "field_generation_z_um": float(self._field_plane_z()), + "sphere": { + "radius_um": float(self.sphere.radius_um), + "n_sphere": str(self.sphere.n_sphere), + "n_medium": float(self.sphere.n_medium), + }, + "inner_na": float(self.inner_na), + "include_center_led": bool(self.include_center_led), + "pattern_order": [str(p) for p in self.pattern_order], + "normalize_by_led_count": bool(self.normalize_by_led_count), + "led_subsample": int(self.led_subsample), + "mie_kwargs": self.mie_kwargs, + 
"camera_bin_factor": int(self.camera_bin_factor), + "exposure_time_s": float(self.exposure_time_s), + "illumination_photons_per_s_per_um2": float(self.illumination_photons_per_s_per_um2), + "focal_offsets_um": self.focal_offsets_um.tolist(), + } + run_dir = _cache_run_dir(self.cache_dir, cache_params) + chunks = (1, 1, min(256, int(ny_cam)), min(256, int(nx_cam))) + _, I_cache, done_cache = _open_led_cache_zarr( + Path(run_dir), + n_led=int(na_xy.shape[0]), + n_planes=int(n_planes), + ny=int(ny_cam), + nx=int(nx_cam), + chunks=chunks, + ) + + for j in trange(int(na_xy.shape[0]), desc="Simulating LEDs"): + nax = float(na_xy[j, 0]) + nay = float(na_xy[j, 1]) + + I_cam = None + if I_cache is not None and done_cache is not None and self.reuse_cache and int(done_cache[j]) == 1: + I_cam_np = np.asarray(I_cache[j, ...], dtype=np.float32) + I_cam = cp.asarray(I_cam_np) if (self.use_gpu and cp is not None) else I_cam_np + I_cam = I_cam.astype(xp.float32, copy=False) + + if I_cam is None: + I_cam = self._simulate_led(nax, nay).astype(xp.float32, copy=False) + + if I_cache is not None and done_cache is not None: + I_cam_save = cp.asnumpy(I_cam) if (cp is not None and hasattr(I_cam, "__cuda_array_interface__")) else np.asarray(I_cam) + I_cache[j, ...] 
= np.asarray(I_cam_save, dtype=np.float32) + done_cache[j] = 1 + + if not (xp.isfinite(I_cam).all()) or xp.isinf(I_cam).any(): + print(f"isfinite check: {xp.isfinite(I_cam).all()}") + print(f"isinf check: {xp.isinf(I_cam).any()}") + + if membership["left"][j]: + acc["left"] += I_cam + if membership["right"][j]: + acc["right"] += I_cam + if membership["up"][j]: + acc["up"] += I_cam + if membership["down"][j]: + acc["down"] += I_cam + + if self.normalize_by_led_count: + for k in allowed: + if led_counts[k] <= 0: + raise RuntimeError(f"No LEDs accumulated for pattern '{k}' (after subsampling).") + acc[k] = acc[k] / float(led_counts[k]) + + # Sanity check: ensure no pattern is identically zero after accumulation + for k in allowed: + if float(xp.sum(acc[k])) == 0.0: + raise RuntimeError(f"Accumulated pattern '{k}' is zero; check illumination masks and cache settings.") + + dpc = xp.stack([acc[p].copy() for p in self.pattern_order], axis=0) + dpc = xp.moveaxis(dpc, 0, 1) # (n_planes, 4, ny, nx) + dpc_out: array + if dpc.shape[0] == 1: + dpc_out = dpc[0] + else: + dpc_out = dpc + + meta: dict[str, array] = { # type: ignore + "na_xy": na_xy_all, + "simulation_dxy_um": np.asarray(self.simulation.dxy_um, dtype=np.float32), + "simulation_esize": (int(self.simulation.esize[0]), int(self.simulation.esize[1])), + "camera_shape": (int(ny_cam), int(nx_cam)), + "camera_bin_factor": int(self.camera_bin_factor), + "led_counts": {k: int(led_counts[k]) for k in allowed}, + "z_plane_um": float(self._field_plane_z()), + "exposure_time_s": float(self.exposure_time_s), + "illumination_photons_per_s_per_um2": float(self.illumination_photons_per_s_per_um2), + "pattern_sums": {k: float(xp.sum(acc[k])) for k in allowed}, + "focal_offsets_um": self.focal_offsets_um.tolist(), + } + for k, v in patterns_full.items(): + meta[k] = v + + return dpc_out, meta + + + +def simulate_led_irradiance( + *, + wavelength_um: float, + sphere: SphereSpec, + dxy_um: float, + esize: tuple[int, int], + 
z_plane_um: float, + na_x: float, + na_y: float, + na_obj: float, + camera_bin: int = 1, + use_gpu: bool = False, + mie_kwargs: dict | None = None, +) -> array: # type: ignore + """ + Simulate the (binned) camera-plane irradiance for a single LED illumination. + + Parameters + ---------- + wavelength_um : float + Vacuum wavelength (µm) for `mie_efield` and the objective pupil. + sphere : SphereSpec + Sphere properties forwarded to `mie_efield`. + dxy_um : float + Simulation sampling (µm) passed to `mie_efield` and propagation. + esize : tuple[int, int] + Simulation grid size (ny, nx). + z_plane_um : float + Field evaluation plane for `mie_efield`; if inside the sphere, a safe plane is chosen. + na_x, na_y : float + Illumination NA components converted to angles for `mie_efield`. + na_obj : float + Objective NA for the coherent pupil. + camera_bin : int, optional + Bin size for `simulated_img`, derived from simulation vs camera sampling. + use_gpu : bool, optional + If True, use CuPy-backed operations when available. + mie_kwargs : dict or None, optional + Extra keyword arguments passed directly to `mie_fields.mie_efield`. + + Returns + ------- + array + Binned camera-plane irradiance for the single LED. + + Notes + ----- + Internally builds a `DPCMieSimulator` with matching grids and calls its single-LED + pipeline: `mie_efield` → `propagate_homogeneous` → `apply_objective_imaging` → + `localize_psf.camera.simulated_img`. 
+ """ + if sphere is None: + sphere = SphereSpec(radius_um=5.0, n_sphere=1.59 + 0.0j, n_medium=1.55) + + sim_space = SimulationSpace( + dxy_um=float(dxy_um), + esize=(int(esize[0]), int(esize[1])), + z_plane_um=float(z_plane_um), + ) + cam_shape = (int(esize[0] // camera_bin), int(esize[1] // camera_bin)) + cam_space = CameraSpace( + pixel_um=float(dxy_um) * float(camera_bin), + magnification=1.0, + shape=cam_shape, + photon_shot_noise=False, + readout_noise_sds=0.0, + gains=1.0, + offsets=0.0, + image_is_integer=False, + ) + + simulator = DPCMieSimulator( + wavelength_um=float(wavelength_um), + na_obj=float(na_obj), + sphere=sphere, + simulation=sim_space, + camera=cam_space, + led_grid_shape=(1, 1), + normalize_by_led_count=False, + led_subsample=1, + use_gpu=use_gpu, + mie_kwargs=mie_kwargs, + cache_dir=None, + reuse_cache=True, + ) + return simulator._simulate_led(float(na_x), float(na_y)) + + +def simulate_dpc_images_sphere( + *, + wavelength_um: float = 0.515, + na_obj: float = 0.8, + led_grid_shape: tuple[int, int] = (64, 64), + camera_pixel_um: float = 2.4, + magnification: float = 20.0, + camera_oversample: int = 1, + esize_camera: tuple[int, int] = (256, 256), + sim_dxy_um: float | None = None, + sim_esize: tuple[int, int] | None = None, + z_plane_um: float = 0.0, + sphere: SphereSpec | None = None, + inner_na: float = 0.0, + include_center_led: bool = False, + pattern_order: Sequence[PatternName] = ("left", "right", "up", "down"), + normalize_by_led_count: bool = True, + led_subsample: int = 1, + use_gpu: bool = False, + mie_kwargs: dict | None = None, + cache_dir: str | Path | None = None, + reuse_cache: bool = True, + camera_gains: array | float = 1.0, # type: ignore + camera_offsets: array | float = 0.0, # type: ignore + camera_readout_noise_sds: array | float = 0.0, # type: ignore + camera_photon_shot_noise: bool = False, + camera_saturation: int | None = None, + camera_image_is_integer: bool = False, + exposure_time_ms: float = 1.0, + 
illumination_photons_per_s_per_um2: float = 1.0, + focal_stack_planes: int | None = None, + focal_stack_step_um: float = 0.5, +) -> tuple[array, dict[str, array]]: # type: ignore + """ + Generate four synthetic DPC images by incoherently summing per-LED irradiance. + + Parameters (and where they go) + ------------------------------ + wavelength_um : float + Vacuum wavelength (µm). Passed to `mie_fields.mie_efield`, the objective pupil + in `apply_objective_imaging`, and `propagate_homogeneous`. + na_obj : float + Objective NA. Sets the coherent cutoff in `apply_objective_imaging` via + `make_pupil`. + led_grid_shape : (int, int) + LED board shape (ny_led, nx_led) for `make_led_na_positions`. + camera_pixel_um : float + Camera pixel pitch (µm) in the camera plane; combined with `magnification` to + derive object-plane sampling and the bin size for `localize_psf.camera.simulated_img`. + magnification : float + Object-to-camera magnification. Higher values increase the bin factor between + simulation grid and camera grid. + camera_oversample : int + Optional oversampling factor applied to the camera grid; influences the default + `sim_dxy_um` and `sim_esize` when not provided. + esize_camera : (int, int) + Final camera image shape (ny, nx) produced by `simulated_img`. + sim_dxy_um : float | None + Simulation grid sampling (µm). If None, derived from `camera_pixel_um / + magnification / camera_oversample` and forwarded to `mie_efield` and + `propagate_homogeneous`. + sim_esize : (int, int) | None + Simulation grid size (ny, nx). If None, defaults to `esize_camera * + camera_oversample`. + z_plane_um : float + Optional user-specified field evaluation plane for `mie_efield` (validated to + stay outside the sphere). If 0, a safe exterior plane is chosen automatically. + sphere : SphereSpec | None + Sphere parameters (radius, n_sphere, n_medium) passed to `mie_efield`. Defaults + to a polystyrene-like sphere in medium if None. 
+ inner_na : float + Inner NA to form an annular LED mask in `make_led_na_positions`. + include_center_led : bool + Whether to include the origin LED (NA=0) in all patterns. + camera_gains : array or float, optional + Gains (ADU/e) for `simulated_img`. + camera_offsets : array or float, optional + Offsets (ADU) for `simulated_img`. + camera_readout_noise_sds : array or float, optional + Readout noise SD (ADU) for `simulated_img`. + camera_photon_shot_noise : bool, optional + Enable Poisson shot noise in `simulated_img`. + camera_saturation : int or None, optional + Saturation level passed to `simulated_img`. + camera_image_is_integer : bool, optional + If True, round simulated images to integers in `simulated_img`. + exposure_time_ms : float, optional + Exposure time (milliseconds). Multiplies simulated irradiance to convert to + expected photon counts before camera noise is applied. + illumination_photons_per_s_per_um2 : float, optional + Incident photon flux density (photons/s/µm^2) used to scale the Mie-derived + irradiance (relative units) into absolute photons/s per pixel area. + focal_stack_planes : int or None, optional + Number of focal planes to simulate. If None or <=1, a single plane is produced. + focal_stack_step_um : float, optional + Axial step (µm) between focal planes, centered on nominal focus. + pattern_order : Sequence[str] + Order of DPC patterns ("left","right","up","down") returned in the output stack. + normalize_by_led_count : bool + If True, divides each pattern sum by the number of LEDs contributing to it. + led_subsample : int + Subsample factor on the LED list before simulation (simulates every Nth LED). + use_gpu : bool + If True, uses CuPy-backed versions of `mie_efield`, FFTs, and `simulated_img` + where available. + mie_kwargs : dict | None + Extra keyword arguments forwarded directly to `mie_fields.mie_efield`. 
+ cache_dir : str | Path | None + Optional directory for Zarr-based per-LED camera irradiance caching + (v3-compatible; no .npy files). + reuse_cache : bool + If True, reuses cached per-LED camera images when present. + + Returns + ------- + dpc : array + Stack of four DPC images ordered by `pattern_order` (shape (4, ny, nx)). + meta : dict + Metadata including NA positions, grid sizes, bin factor, and LED counts for + downstream reconstruction/testing. + """ + if led_subsample < 1: + raise ValueError("led_subsample must be >= 1") + if camera_oversample < 1: + raise ValueError("camera_oversample must be >= 1") + + if sphere is None: + sphere = SphereSpec(radius_um=5.0, n_sphere=1.59 + 0.0j, n_medium=1.55) + + sim_dxy_val = sim_dxy_um if sim_dxy_um is not None else sample_pixel_size_um( + camera_pixel_um, magnification, oversample=int(camera_oversample) + ) + sim_esize_val = sim_esize if sim_esize is not None else ( + int(esize_camera[0]) * int(camera_oversample), + int(esize_camera[1]) * int(camera_oversample), + ) + + sim_space = SimulationSpace( + dxy_um=float(sim_dxy_val), + esize=(int(sim_esize_val[0]), int(sim_esize_val[1])), + z_plane_um=float(z_plane_um), + ) + cam_space = CameraSpace( + pixel_um=float(camera_pixel_um), + magnification=float(magnification), + shape=(int(esize_camera[0]), int(esize_camera[1])), + photon_shot_noise=bool(camera_photon_shot_noise), + readout_noise_sds=camera_readout_noise_sds, + gains=camera_gains, + offsets=camera_offsets, + image_is_integer=bool(camera_image_is_integer), + psf=None, + apodization=1, + saturation=camera_saturation, + ) + + simulator = DPCMieSimulator( + wavelength_um=float(wavelength_um), + na_obj=float(na_obj), + sphere=sphere, + simulation=sim_space, + camera=cam_space, + led_grid_shape=led_grid_shape, + inner_na=float(inner_na), + include_center_led=bool(include_center_led), + pattern_order=pattern_order, + normalize_by_led_count=bool(normalize_by_led_count), + led_subsample=int(led_subsample), + 
use_gpu=bool(use_gpu), + mie_kwargs=mie_kwargs, + cache_dir=cache_dir, + reuse_cache=bool(reuse_cache), + exposure_time_ms=float(exposure_time_ms), + illumination_photons_per_s_per_um2=float(illumination_photons_per_s_per_um2), + focal_stack_planes=focal_stack_planes, + focal_stack_step_um=float(focal_stack_step_um), + ) + + dpc, meta = simulator.simulate_patterns() + + # Compatibility metadata with previous helper + meta.update({ + "dxy_um": np.asarray(sim_dxy_val, dtype=np.float32), + "camera_oversample": int(simulator.camera_bin_factor), + "esize_camera": (int(esize_camera[0]), int(esize_camera[1])), + "esize_mie": (int(sim_space.esize[0]), int(sim_space.esize[1])), + }) + + # Crop 10% border from each side to keep central region + def _crop_center(arr: array) -> array: + if arr.ndim == 3: + _, ny, nx = arr.shape + elif arr.ndim == 4: + _, _, ny, nx = arr.shape + else: + return arr + ypad = max(0, int(round(0.1 * ny))) + xpad = max(0, int(round(0.1 * nx))) + y1, y2 = ypad, ny - ypad + x1, x2 = xpad, nx - xpad + if y2 <= y1 or x2 <= x1: + return arr + if arr.ndim == 3: + return arr[:, y1:y2, x1:x2] + else: + return arr[:, :, y1:y2, x1:x2] + + dpc = _crop_center(dpc) + if dpc.ndim == 3: + _, ny_crop, nx_crop = dpc.shape + else: + _, _, ny_crop, nx_crop = dpc.shape + meta["camera_shape"] = (int(ny_crop), int(nx_crop)) + meta["esize_camera"] = (int(ny_crop), int(nx_crop)) + + return dpc, meta + + +def write_dpc_zarr( + zarr_path: str | Path, + dpc: array, # type: ignore + meta: dict[str, array], # type: ignore + *, + wavelength_um: float, + na_obj: float, + camera_pixel_um: float, + magnification: float, + n_medium: float, + led_grid_shape: tuple[int, int], + pattern_order: Sequence[PatternName], + sphere: SphereSpec, + inner_na: float, + include_center_led: bool, + nz: int | None = None, + z_span_um: float | None = None, + overwrite: bool = True, +) -> None: + """ + Write a DPC synthetic dataset to Zarr in the layout expected by tests. 
+ + Parameters + ---------- + zarr_path : str or Path + Output Zarr group path. + dpc : array + DPC stack (4, ny, nx) to write. + meta : dict[str, array] + Metadata returned from `simulate_dpc_images_sphere` (includes NA positions). + wavelength_um : float + Vacuum wavelength (µm) for root attrs. + na_obj : float + Objective NA for root attrs. + camera_pixel_um : float + Camera pixel size (µm) for root attrs. + magnification : float + System magnification for root attrs. + n_medium : float + Medium index for root attrs. + led_grid_shape : tuple[int, int] + LED grid shape stored in root attrs. + pattern_order : Sequence[str] + Pattern order stored in root attrs. + sphere : SphereSpec + Sphere parameters; radius/index stored as extra attrs. + inner_na : float + Inner NA stored as extra attr. + include_center_led : bool + Whether center LED included; stored as extra attr. + nz : int or None, optional + Optional z-stacks count stored in attrs. + z_span_um : float or None, optional + Optional z-span stored in attrs. + overwrite : bool, optional + If True, overwrite existing Zarr group. + + Returns + ------- + None + """ + try: + import zarr # type: ignore + except Exception as e: # pragma: no cover + raise ImportError("zarr is required to write datasets") from e + + zarr_path = str(zarr_path) + mode = "w" if overwrite else "w-" + g = zarr.open_group(zarr_path, mode=mode) + + # --- Root attributes (exact keys expected by tests) --- + g.attrs["wavelength_um"] = float(wavelength_um) + g.attrs["na_obj"] = float(na_obj) + g.attrs["camera_pixel_um"] = float(camera_pixel_um) + g.attrs["magnification"] = float(magnification) + g.attrs["n_medium"] = float(n_medium) + g.attrs["led_grid_shape"] = [int(led_grid_shape[0]), int(led_grid_shape[1])] + g.attrs["pattern_order"] = [str(p) for p in pattern_order] + # Effective Mie evaluation plane used for simulation (µm). 
+ if "z_plane_um" in meta: + g.attrs["field_generation_z_um"] = float(np.asarray(meta["z_plane_um"])) + + # Ensure CPU ndarray for Zarr + if cp is not None and isinstance(dpc, cp.ndarray): + dpc_np = cp.asnumpy(dpc).astype(np.float32, copy=False) + else: + dpc_np = np.asarray(dpc, dtype=np.float32) + + if nz is not None: + g.attrs["nz"] = int(nz) + elif dpc_np.ndim == 4: + g.attrs["nz"] = int(dpc_np.shape[0]) + if z_span_um is not None: + g.attrs["z_span_um"] = float(z_span_um) + + # Optional extras (not required by tests) + g.attrs["sphere_radius_um"] = float(sphere.radius_um) + g.attrs["sphere_n_sphere"] = complex(sphere.n_sphere).__repr__() + g.attrs["sphere_n_medium"] = float(sphere.n_medium) + g.attrs["inner_na"] = float(inner_na) + g.attrs["include_center_led"] = bool(include_center_led) + g.attrs["dxy_um"] = float(np.asarray(meta.get("dxy_um", camera_pixel_um / magnification))) + + if dpc_np.ndim == 3: + if dpc_np.shape[0] != 4: + raise ValueError(f"dpc must have shape (4, ny, nx), got {dpc_np.shape}") + chunks = (1, min(256, dpc_np.shape[1]), min(256, dpc_np.shape[2])) + elif dpc_np.ndim == 4: + if dpc_np.shape[1] != 4: + raise ValueError(f"dpc must have shape (nz, 4, ny, nx), got {dpc_np.shape}") + chunks = (1, 1, min(256, dpc_np.shape[2]), min(256, dpc_np.shape[3])) + else: + raise ValueError(f"dpc must have ndim 3 or 4, got {dpc_np.ndim}") + _zarr_write_array(g, "dpc", dpc_np, dtype="float32", chunks=chunks, overwrite=True) + + mg = g.require_group("meta") + if "na_xy" in meta: + na_xy_np = np.asarray(meta["na_xy"], dtype=np.float32) + _zarr_write_array(mg, "na_xy", na_xy_np, dtype="float32", overwrite=True) + + for pname in ("left", "right", "up", "down"): + if pname in meta: + pts_np = np.asarray(meta[pname], dtype=np.float32) + _zarr_write_array(mg, pname, pts_np, dtype="float32", overwrite=True) + if "focal_offsets_um" in meta: + offsets_np = np.asarray(meta["focal_offsets_um"], dtype=np.float32) + _zarr_write_array(mg, "focal_offsets_um", 
offsets_np, dtype="float32", overwrite=True) + + +def simulate_dpc_images_sphere_to_zarr( + zarr_path: str | Path, + *, + wavelength_um: float = 0.515, + na_obj: float = 0.8, + led_grid_shape: tuple[int, int] = (64, 64), + camera_pixel_um: float = 2.4, + magnification: float = 20.0, + camera_oversample: int = 1, + esize_camera: tuple[int, int] = (256, 256), + sim_dxy_um: float | None = None, + sim_esize: tuple[int, int] | None = None, + z_plane_um: float = 0.0, + sphere: SphereSpec | None = None, + inner_na: float = 0.0, + include_center_led: bool = False, + pattern_order: Sequence[PatternName] = ("left", "right", "up", "down"), + normalize_by_led_count: bool = True, + led_subsample: int = 1, + use_gpu: bool = False, + mie_kwargs: dict | None = None, + cache_dir: str | Path | None = None, + reuse_cache: bool = True, + nz: int | None = None, + z_span_um: float | None = None, + overwrite: bool = True, + camera_gains: array | float = 1.0, # type: ignore + camera_offsets: array | float = 0.0, # type: ignore + camera_readout_noise_sds: array | float = 0.0, # type: ignore + camera_photon_shot_noise: bool = False, + camera_saturation: int | None = None, + camera_image_is_integer: bool = False, + exposure_time_ms: float = 1.0, + illumination_photons_per_s_per_um2: float = 1.0, + focal_stack_planes: int | None = None, + focal_stack_step_um: float = 0.5, +) -> None: + """ + Generate DPC images and write them to Zarr. + + Parameters + ---------- + zarr_path : str or Path + Output Zarr group path. + wavelength_um : float, optional + Vacuum wavelength (µm) passed to simulation and stored in attrs. + na_obj : float, optional + Objective NA for pupil cutoff and attrs. + led_grid_shape : tuple[int, int], optional + LED board shape for pattern generation. + camera_pixel_um : float, optional + Camera pixel size (µm) for attrs and camera sampling. + magnification : float, optional + Object-to-camera magnification used to derive camera binning. 
+ camera_oversample : int, optional + Oversampling factor; affects default simulation sampling/size. + esize_camera : tuple[int, int], optional + Final camera image shape (ny, nx). + sim_dxy_um : float or None, optional + Simulation sampling (µm); defaults to camera_pixel_um / magnification / camera_oversample. + sim_esize : tuple[int, int] or None, optional + Simulation grid size; defaults to esize_camera * camera_oversample. + z_plane_um : float, optional + Field evaluation plane for `mie_efield`; validated to be outside the sphere. + sphere : SphereSpec or None, optional + Sphere parameters; defaults to a preset if None. + inner_na : float, optional + Inner NA for annular LED mask. + include_center_led : bool, optional + Whether to include the center LED. + pattern_order : Sequence[str], optional + Order of DPC patterns in output. + normalize_by_led_count : bool, optional + If True, normalize each pattern by its LED count. + led_subsample : int, optional + Subsample factor for LEDs. + use_gpu : bool, optional + If True, use CuPy-backed ops when available. + mie_kwargs : dict or None, optional + Extra kwargs passed to `mie_fields.mie_efield`. + cache_dir : str or Path or None, optional + Directory for Zarr per-LED camera irradiance cache. + reuse_cache : bool, optional + If True, reuse cached per-LED images. + camera_gains : array or float, optional + Gains (ADU/e) for `simulated_img`. + camera_offsets : array or float, optional + Offsets (ADU) for `simulated_img`. + camera_readout_noise_sds : array or float, optional + Readout noise SD (ADU) for `simulated_img`. + camera_photon_shot_noise : bool, optional + Enable Poisson shot noise in `simulated_img`. + camera_saturation : int or None, optional + Saturation level for `simulated_img`. + camera_image_is_integer : bool, optional + If True, round simulated images to integers in `simulated_img`. + exposure_time_ms : float, optional + Exposure time (milliseconds). Multiplies simulated irradiance before camera noise. 
+    illumination_photons_per_s_per_um2 : float, optional
+        Incident photon flux density (photons/s/µm^2) used to scale irradiance to photons.
+    focal_stack_planes : int or None, optional
+        Number of focal planes to simulate. If None or <=1, a single plane is produced.
+    focal_stack_step_um : float, optional
+        Axial step (µm) between focal planes, centered on nominal focus.
+    nz : int or None, optional
+        Optional z-stack count stored in attrs.
+    z_span_um : float or None, optional
+        Optional z-span stored in attrs.
+    overwrite : bool, optional
+        If True, overwrite existing Zarr group.
+
+    Returns
+    -------
+    None
+    """
+    if cache_dir is None:
+        # Default cache directory adjacent to the output Zarr store.
+        zp = Path(zarr_path)
+        cache_dir = zp.with_name(zp.name + ".mie_cache")
+
+    dpc, meta = simulate_dpc_images_sphere(
+        wavelength_um=wavelength_um,
+        na_obj=na_obj,
+        led_grid_shape=led_grid_shape,
+        camera_pixel_um=camera_pixel_um,
+        magnification=magnification,
+        camera_oversample=camera_oversample,
+        esize_camera=esize_camera,
+        sim_dxy_um=sim_dxy_um,
+        sim_esize=sim_esize,
+        z_plane_um=z_plane_um,
+        sphere=sphere,
+        inner_na=inner_na,
+        include_center_led=include_center_led,
+        pattern_order=pattern_order,
+        normalize_by_led_count=normalize_by_led_count,
+        led_subsample=led_subsample,
+        use_gpu=use_gpu,
+        mie_kwargs=mie_kwargs,
+        cache_dir=cache_dir,
+        reuse_cache=reuse_cache,
+        camera_gains=camera_gains,
+        camera_offsets=camera_offsets,
+        camera_readout_noise_sds=camera_readout_noise_sds,
+        camera_photon_shot_noise=camera_photon_shot_noise,
+        camera_saturation=camera_saturation,
+        camera_image_is_integer=camera_image_is_integer,
+        exposure_time_ms=exposure_time_ms,
+        illumination_photons_per_s_per_um2=illumination_photons_per_s_per_um2,
+        focal_stack_planes=focal_stack_planes,
+        focal_stack_step_um=focal_stack_step_um,
+    )
+
+    # NOTE(review): fallback must mirror the default sphere actually used by
+    # simulate_dpc_images_sphere (n_medium=1.55), otherwise the Zarr attrs
+    # record a wrong medium index for data simulated with the default sphere.
+    sphere_used = sphere if sphere is not None else SphereSpec(radius_um=5.0, n_sphere=1.59 + 0.0j, n_medium=1.55)
+
+    write_dpc_zarr(
+        zarr_path,
+ dpc, + meta, + wavelength_um=wavelength_um, + na_obj=na_obj, + camera_pixel_um=camera_pixel_um, + magnification=magnification, + n_medium=float(sphere_used.n_medium), + led_grid_shape=led_grid_shape, + pattern_order=pattern_order, + sphere=sphere_used, + inner_na=inner_na, + include_center_led=include_center_led, + nz=nz, + z_span_um=z_span_um, + overwrite=overwrite, + ) + +if __name__ == "__main__": # pragma: no cover + # Example usage: generate a DPC dataset and write to Zarr. + output_path = Path("/media/dps/data/synthetic_dpc_sphere.zarr") + + simulate_dpc_images_sphere_to_zarr( + output_path, + wavelength_um=0.515, + na_obj=0.8, + led_grid_shape=(64, 64), + camera_pixel_um=2.4, + magnification=20.0, + camera_oversample=2, + esize_camera=(512, 512), + z_plane_um=0.0, + sphere=SphereSpec(radius_um=2.5, n_sphere=1.59, n_medium=1.515), + inner_na=0.0, + include_center_led=True, + pattern_order=("left", "right", "up", "down"), + normalize_by_led_count=True, + led_subsample=1, + use_gpu=True, + mie_kwargs=None, + camera_gains=1.0, + camera_offsets=100.0, + camera_readout_noise_sds=3.31, + camera_photon_shot_noise=True, + camera_saturation=None, + camera_image_is_integer=True, + exposure_time_ms=100.0, + illumination_photons_per_s_per_um2=1e6, + focal_stack_planes=41, + focal_stack_step_um=.325, + overwrite=True, + ) + print(f"Synthetic DPC dataset written to {output_path}") diff --git a/pyproject.toml b/pyproject.toml index 21844d6..370cc4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ gpu = ['cupy-cuda11x', # assuming 11.2 <= CUDA version < 12. 
Otherwise, manuall 'cucim ; platform_system != "win32"' ] optional = ['napari', - 'miepython', + 'miepython==2.5.5', 'psfmodels'] dev = ['sphinx'] all = ['mcsim[expt_ctrl, gpu, optional, dev]'] diff --git a/tests/mie_unittest.py b/tests/mie_unittest.py index 6b75198..65b651e 100644 --- a/tests/mie_unittest.py +++ b/tests/mie_unittest.py @@ -52,3 +52,6 @@ def test_yn_gpu(self): np.testing.assert_allclose(yn_scipy[mask], yn_cp[mask], atol=1e-8) np.testing.assert_allclose(dyn_scipy[mask], dyn_cp[mask], atol=1e-8) + +if __name__ == "__main__": + unittest.main() From f7cb394780299c6e18dbb11eea11cf756ed9e580 Mon Sep 17 00:00:00 2001 From: dpshepherd Date: Thu, 1 Jan 2026 20:07:28 -0700 Subject: [PATCH 03/10] initial work on linear prop. based DPC inverse solver --- mcsim/analysis/dpc_inverse.py | 322 ++++++++++++++++++++++++++++++++++ tests/test_dpc_inverse.py | 139 +++++++++++++++ 2 files changed, 461 insertions(+) create mode 100644 mcsim/analysis/dpc_inverse.py create mode 100644 tests/test_dpc_inverse.py diff --git a/mcsim/analysis/dpc_inverse.py b/mcsim/analysis/dpc_inverse.py new file mode 100644 index 0000000..d840d4e --- /dev/null +++ b/mcsim/analysis/dpc_inverse.py @@ -0,0 +1,322 @@ +""" +Inverse-model solver for differential phase contrast (DPC) stacks using a linear +Rytov forward model and FISTA with a plug-and-play median proximal operator. + +The solver maps a 3D refractive-index volume to DPC measurements generated by +half-plane LED patterns, optionally across a focal stack. It reuses the +GPU-aware FFT and linear scattering tools already in the codebase. 
+""" + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import scipy.sparse as sp +from scipy.sparse.linalg import svds + +try: + import cupy as cp + import cupyx.scipy.sparse as sp_gpu +except ImportError: # pragma: no cover + cp = None + sp_gpu = None + +from mcsim.analysis.fft import ft2, ift2 +from mcsim.analysis.field_prop import ( + fwd_model_linear, + get_angular_spectrum_kernel, + get_v, +) +from mcsim.analysis.optimize import Optimizer, median_prox + + +if cp: + array = np.ndarray | cp.ndarray +else: + array = np.ndarray + + +def _get_xp(use_gpu: bool): + return cp if (use_gpu and cp is not None) else np + + +def _split_dpc_patterns(na_xy: np.ndarray, order: Sequence[str]) -> list[np.ndarray]: + """ + Split LED NA coordinates into canonical DPC half-planes. + """ + allowed = ("left", "right", "up", "down") + if len(order) != 4 or set(order) != set(allowed): + raise ValueError(f"order must be a permutation of {allowed}, got {order}") + + na_x = na_xy[:, 0] + na_y = na_xy[:, 1] + masks = { + "left": na_x < 0, + "right": na_x > 0, + "up": na_y > 0, + "down": na_y < 0, + } + + indices: list[np.ndarray] = [] + for name in order: + idx = np.nonzero(masks[name])[0] + if idx.size == 0: + raise ValueError(f"pattern '{name}' has zero LEDs") + indices.append(idx) + return indices + + +def _defocus_kernels( + ny: int, + nx: int, + dxy: float, + zs: Sequence[float], + wavelength: float, + n_medium: float, + xp, +) -> array: + """ + Precompute angular spectrum defocus kernels for each focal offset. + """ + fx = xp.expand_dims(xp.fft.fftfreq(nx, dxy), axis=0) + fy = xp.expand_dims(xp.fft.fftfreq(ny, dxy), axis=1) + kernels = [get_angular_spectrum_kernel(fx, fy, z, wavelength, n_medium) for z in zs] + return xp.stack(kernels, axis=0) + + +@dataclass +class DPCGeometry: + """ + Lightweight container for optical and sampling parameters. 
+    """
+
+    wavelength_um: float
+    n_medium: float
+    na_obj: float
+    camera_pixel_um: float
+    magnification: float
+    focal_offsets_um: Sequence[float]
+    pattern_order: Sequence[str]
+    normalize_by_led_count: bool = True
+
+    @property
+    def dxy_um(self) -> float:
+        return float(self.camera_pixel_um) / float(self.magnification)
+
+
+class DPCRytovInverse(Optimizer):
+    """
+    FISTA-compatible optimizer for DPC stacks under a linear Rytov model.
+    """
+
+    def __init__(
+        self,
+        measured: array,
+        led_na: np.ndarray,
+        *,
+        geom: DPCGeometry,
+        n_shape: Sequence[int],
+        drs_n: Sequence[float],
+        median_filter_size: Sequence[int] = (3, 3, 3),
+        use_gpu: bool = False,
+    ) -> None:
+        if measured.ndim == 3:
+            measured = measured[None, ...]
+
+        n_planes, n_patterns, ny, nx = measured.shape
+        if n_patterns != 4:
+            raise ValueError(f"expected 4 DPC patterns, got {n_patterns}")
+
+        xp = _get_xp(use_gpu)
+        self.use_gpu = bool(use_gpu)
+        self.xp = xp
+        self.geom = geom
+        self.n_shape = tuple(int(s) for s in n_shape)
+        self.drs_n = tuple(float(d) for d in drs_n)
+        self.ny = int(ny)
+        self.nx = int(nx)
+        self.n_planes = int(n_planes)
+        self.n_patterns = int(n_patterns)
+        self.pattern_order = tuple(geom.pattern_order)
+        self.normalize_by_led_count = bool(geom.normalize_by_led_count)
+        self.median_filter_size = tuple(int(s) for s in median_filter_size)
+
+        self.data = xp.asarray(measured)
+        self.led_na = np.asarray(led_na, dtype=float)
+        self.pattern_led_indices = _split_dpc_patterns(self.led_na, self.pattern_order)
+        self.pattern_led_indices_gpu = (
+            [xp.asarray(idx) for idx in self.pattern_led_indices] if xp is cp else self.pattern_led_indices
+        )
+
+        # Sampling/grid
+        self.dxy = float(geom.dxy_um)
+        self.drs_e = (self.dxy, self.dxy)
+
+        # Forward operator (Rytov linear scattering) for all LEDs at once
+        beam_fx = xp.asarray(self.led_na[:, 0] / geom.wavelength_um)
+        beam_fy = xp.asarray(self.led_na[:, 1] / geom.wavelength_um)
+        beam_fz = xp.sqrt((geom.n_medium / geom.wavelength_um) ** 2 - beam_fx**2 - beam_fy**2)
+        self.model = fwd_model_linear(
+            beam_fx,
+            beam_fy,
+            beam_fz,
+            geom.n_medium,
+            geom.na_obj,
+            geom.wavelength_um,
+            (self.ny, self.nx),
+            self.drs_e,
+            self.n_shape,
+            self.drs_n,
+            mode="rytov",
+            interpolate=False,
+            use_gpu=self.use_gpu,
+        )
+
+        # Defocus transfer functions per plane
+        if geom.focal_offsets_um and len(geom.focal_offsets_um) not in (1, n_planes):
+            raise ValueError(
+                f"focal_offsets_um length ({len(geom.focal_offsets_um)}) must match n_planes ({n_planes}) or be 1"
+            )
+        self.focal_offsets = [float(z) for z in geom.focal_offsets_um] if geom.focal_offsets_um else [0.0]
+        if len(self.focal_offsets) == 1 and n_planes > 1:
+            self.focal_offsets = [self.focal_offsets[0]] * n_planes
+        if len(self.focal_offsets) != n_planes:
+            raise ValueError(f"n_planes={n_planes} but {len(self.focal_offsets)} focal offsets provided")
+        self.defocus_kernels = _defocus_kernels(
+            self.ny,
+            self.nx,
+            self.dxy,
+            self.focal_offsets,
+            geom.wavelength_um,
+            geom.n_medium,
+            xp,
+        )
+        self.e0 = xp.ones((self.n_planes, self.ny, self.nx), dtype=complex)
+
+        # Map sample indices -> (plane, pattern)
+        self.sample_plane = np.repeat(np.arange(self.n_planes), self.n_patterns)
+        self.sample_pattern = np.tile(np.arange(self.n_patterns), self.n_planes)
+
+        super().__init__(self.n_planes * self.n_patterns, prox_parameters={"median_filter_size": self.median_filter_size})
+
+    # ---------------------------
+    # helpers
+    # ---------------------------
+    def _select_indices(self, inds: Optional[Sequence[int]]) -> np.ndarray:
+        if inds is None:
+            return np.arange(self.n_samples)
+        return np.asarray(inds, dtype=int)
+
+    def _predict_fields(self, n_volume: array) -> tuple[array, array]:
+        """
+        Compute per-pattern irradiance predictions and per-LED defocused fields.
+ """ + xp = self.xp + v = get_v(n_volume, self.geom.n_medium, self.geom.wavelength_um) + v_vec = xp.asarray(v).ravel() + + es_focus = self.model.dot(v_vec).reshape((self.led_na.shape[0], self.ny, self.nx)) + es_ft = ft2(es_focus, shift=False) + es_ft = es_ft[:, None, :, :] * self.defocus_kernels[None, :, :, :] + es_z = ift2(es_ft, shift=False) # (n_led, n_planes, ny, nx) + + # move plane axis first for easier broadcasting + es_z = xp.moveaxis(es_z, 1, 0) # (n_planes, n_led, ny, nx) + intensity = xp.abs(self.e0[:, None, :, :] + es_z) ** 2 + + patterns = [] + idx_list = self.pattern_led_indices_gpu + for idx in idx_list: + pat = intensity[:, idx, :, :].sum(axis=1) + if self.normalize_by_led_count: + pat = pat / idx.shape[0] + patterns.append(pat) + + pred = xp.stack(patterns, axis=1) # (n_planes, 4, ny, nx) + return pred, es_z + + # --------------------------- + # Optimizer interface + # --------------------------- + def fwd_model(self, x: array, inds: Optional[Sequence[int]] = None) -> array: + pred, _ = self._predict_fields(x) + pred_flat = pred.reshape((self.n_samples, self.ny, self.nx)) + sel = self._select_indices(inds) + return pred_flat[sel] + + def cost(self, x: array, inds: Optional[Sequence[int]] = None) -> array: + xp = self.xp + pred, _ = self._predict_fields(x) + residual = pred - self.data + residual_flat = residual.reshape((self.n_samples, self.ny, self.nx)) + sel = self._select_indices(inds) + r = residual_flat[sel] + return 0.5 * xp.sum(xp.abs(r) ** 2, axis=(-1, -2)) + + def gradient(self, x: array, inds: Optional[Sequence[int]] = None) -> array: + xp = self.xp + pred, es_z = self._predict_fields(x) + n_led = self.led_na.shape[0] + + # mask residuals to selected samples + mask = np.zeros(self.n_samples, dtype=bool) + sel = self._select_indices(inds) + mask[sel] = True + mask_stack = xp.asarray(mask.reshape((self.n_planes, self.n_patterns, 1, 1))) + + residual = (pred - self.data) * mask_stack + + # accumulate per-LED gradient at focus plane + 
grad_es0 = xp.zeros((n_led, self.ny, self.nx), dtype=complex) + idx_list = self.pattern_led_indices_gpu + for p_idx, led_idx in enumerate(idx_list): + res_p = residual[:, p_idx, :, :] # (n_planes, ny, nx) + scale = 1.0 / led_idx.shape[0] if self.normalize_by_led_count else 1.0 + for l in led_idx.tolist(): + # gradient of irradiance wrt field at each plane + g_z = res_p * (self.e0 + es_z[:, l, :, :]) * (2.0 * scale) + # backpropagate defocus for this LED across planes + g_ft = ft2(g_z, adjoint=True, shift=False) # (n_planes, ny, nx) + back_ft = xp.sum(g_ft * xp.conj(self.defocus_kernels), axis=0) + grad_es0[l] += ift2(back_ft, shift=False) + + # adjoint of linear model + grad_v_vec = self.model.getH().dot(grad_es0.ravel()) + grad_v = grad_v_vec.reshape(self.n_shape) + + # chain rule dv/dn = -2*(2*pi/lambda)^2 * n + n_volume = xp.asarray(x) + factor = -2 * (2 * np.pi / self.geom.wavelength_um) ** 2 * n_volume + grad_n = grad_v * factor + # broadcast gradient to requested batch; each selected sample shares the same volume gradient + return xp.broadcast_to(xp.expand_dims(grad_n, axis=0), (len(sel),) + grad_n.shape) + + def prox(self, x: array, step: float) -> array: + xp = self.xp + x_real = median_prox(x.real, self.median_filter_size) + return xp.asarray(x_real, dtype=x.dtype) + + def guess_step(self, x: Optional[array] = None) -> float: + """ + Estimate a Lipschitz-consistent step size using the same spectral norm + logic as `LinearScatt.guess_step`, reusing the sparse linear Rytov model. + + We ignore the mild nonlinearity from intensity formation; this gives a + conservative step similar to other linear-scattering solvers in mcsim. 
+ """ + # dominant singular value of the linear model + try: + m_for_svd = self.model.get() if sp_gpu and isinstance(self.model, sp_gpu.csr_matrix) else self.model + u, s, vh = svds(m_for_svd, k=1, which="LM") + sigma = float(s[0]) + except Exception: + # fallback if svds not available + sigma = 1.0 + + # exactly match LinearScatt: L ~ sigma^2 / (n_samples * ny * nx) + lipschitz_estimate = sigma**2 / (self.n_samples * self.ny * self.nx) + if not np.isfinite(lipschitz_estimate) or lipschitz_estimate <= 0: + return 1e-3 + + return float(1.0 / lipschitz_estimate) diff --git a/tests/test_dpc_inverse.py b/tests/test_dpc_inverse.py new file mode 100644 index 0000000..e4f97c2 --- /dev/null +++ b/tests/test_dpc_inverse.py @@ -0,0 +1,139 @@ +import numpy as np +import pytest + +from mcsim.analysis.dpc_inverse import DPCGeometry, DPCRytovInverse + + +def _make_simple_solver(ny=16, nx=16, n_planes=2): + # Four LEDs, one per half-plane + led_na = np.array( + [ + [-0.1, 0.0], # left + [0.1, 0.0], # right + [0.0, 0.1], # up + [0.0, -0.1], # down + ], + dtype=float, + ) + + geom = DPCGeometry( + wavelength_um=0.5, + n_medium=1.0, + na_obj=0.8, + camera_pixel_um=2.0, + magnification=10.0, + focal_offsets_um=[0.0] * n_planes, + pattern_order=("left", "right", "up", "down"), + normalize_by_led_count=True, + ) + + n_shape = (2, ny, nx) + drs_n = (0.5, geom.dxy_um, geom.dxy_um) + + # simple ground truth RI + n0 = np.full(n_shape, geom.n_medium, dtype=np.float32) + n0[0, ny // 2, nx // 2] += 1e-3 # small perturbation + + # simulate data from the forward model + solver = DPCRytovInverse( + np.zeros((n_planes, 4, ny, nx), dtype=np.float32), + led_na, + geom=geom, + n_shape=n_shape, + drs_n=drs_n, + use_gpu=False, + ) + dpc_pred, _ = solver._predict_fields(n0) + + # re-instantiate with the simulated data + solver = DPCRytovInverse( + dpc_pred, + led_na, + geom=geom, + n_shape=n_shape, + drs_n=drs_n, + use_gpu=False, + ) + return solver, n0 + + +def test_forward_shape_and_values(): + 
solver, n0 = _make_simple_solver() + pred, _ = solver._predict_fields(n0) + assert pred.shape == solver.data.shape == (solver.n_planes, 4, solver.ny, solver.nx) + # Forward evaluated at ground truth should match data + np.testing.assert_allclose(pred, solver.data, rtol=1e-5, atol=1e-6) + + +def test_gradient_matches_numeric(): + solver, n0 = _make_simple_solver() + g, gn = solver.test_gradient(n0, jind=0, dx=1e-6) + np.testing.assert_allclose(g, gn, rtol=1e-3, atol=1e-5) + + +def _make_multiled_solver(ny=16, nx=16, n_planes=2): + # Multiple LEDs per half-plane + led_na = np.array( + [ + [-0.15, 0.0], + [-0.05, 0.05], + [0.15, 0.0], + [0.05, -0.05], + [0.0, 0.15], + [-0.05, 0.1], + [0.0, -0.15], + [0.05, -0.1], + ], + dtype=float, + ) + + geom = DPCGeometry( + wavelength_um=0.5, + n_medium=1.0, + na_obj=0.8, + camera_pixel_um=2.0, + magnification=10.0, + focal_offsets_um=[0.0] * n_planes, + pattern_order=("left", "right", "up", "down"), + normalize_by_led_count=True, + ) + + n_shape = (2, ny, nx) + drs_n = (0.5, geom.dxy_um, geom.dxy_um) + + n0 = np.full(n_shape, geom.n_medium, dtype=np.float32) + n0[1, ny // 2, nx // 2] += 5e-4 + + solver = DPCRytovInverse( + np.zeros((n_planes, 4, ny, nx), dtype=np.float32), + led_na, + geom=geom, + n_shape=n_shape, + drs_n=drs_n, + use_gpu=False, + ) + dpc_pred, _ = solver._predict_fields(n0) + + solver = DPCRytovInverse( + dpc_pred, + led_na, + geom=geom, + n_shape=n_shape, + drs_n=drs_n, + use_gpu=False, + ) + return solver, n0 + + +def test_forward_multiled_shape_and_values(): + solver, n0 = _make_multiled_solver() + pred, _ = solver._predict_fields(n0) + assert pred.shape == solver.data.shape == (solver.n_planes, 4, solver.ny, solver.nx) + np.testing.assert_allclose(pred, solver.data, rtol=1e-5, atol=1e-6) + + +@pytest.mark.slow +def test_gradient_multiled_matches_numeric(): + solver, n0 = _make_multiled_solver() + g, gn = solver.test_gradient(n0, jind=0, dx=1e-6) + np.testing.assert_allclose(g, gn, rtol=1e-3, atol=1e-5) From 
ae9f6cd21c1305158d130c0a5558a298c4400100 Mon Sep 17 00:00:00 2001 From: dpshepherd Date: Thu, 1 Jan 2026 20:10:26 -0700 Subject: [PATCH 04/10] enable GPU in test --- tests/test_dpc_inverse.py | 64 +++++++++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/tests/test_dpc_inverse.py b/tests/test_dpc_inverse.py index e4f97c2..6fe3106 100644 --- a/tests/test_dpc_inverse.py +++ b/tests/test_dpc_inverse.py @@ -1,10 +1,15 @@ import numpy as np import pytest +try: + import cupy as cp # type: ignore +except ImportError: + cp = None + from mcsim.analysis.dpc_inverse import DPCGeometry, DPCRytovInverse -def _make_simple_solver(ny=16, nx=16, n_planes=2): +def _make_simple_solver(ny=16, nx=16, n_planes=2, use_gpu: bool = False): # Four LEDs, one per half-plane led_na = np.array( [ @@ -33,6 +38,8 @@ def _make_simple_solver(ny=16, nx=16, n_planes=2): # simple ground truth RI n0 = np.full(n_shape, geom.n_medium, dtype=np.float32) n0[0, ny // 2, nx // 2] += 1e-3 # small perturbation + if use_gpu and cp is not None: + n0 = cp.asarray(n0) # simulate data from the forward model solver = DPCRytovInverse( @@ -41,7 +48,7 @@ def _make_simple_solver(ny=16, nx=16, n_planes=2): geom=geom, n_shape=n_shape, drs_n=drs_n, - use_gpu=False, + use_gpu=use_gpu, ) dpc_pred, _ = solver._predict_fields(n0) @@ -52,26 +59,35 @@ def _make_simple_solver(ny=16, nx=16, n_planes=2): geom=geom, n_shape=n_shape, drs_n=drs_n, - use_gpu=False, + use_gpu=use_gpu, ) return solver, n0 -def test_forward_shape_and_values(): - solver, n0 = _make_simple_solver() +@pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) +def test_forward_shape_and_values(use_gpu): + solver, n0 = _make_simple_solver(use_gpu=use_gpu) pred, _ = solver._predict_fields(n0) assert pred.shape == solver.data.shape == (solver.n_planes, 4, solver.ny, solver.nx) # Forward evaluated at ground truth should match data - np.testing.assert_allclose(pred, solver.data, rtol=1e-5, atol=1e-6) 
+ np.testing.assert_allclose( + cp.asnumpy(pred) if cp and use_gpu else pred, + cp.asnumpy(solver.data) if cp and use_gpu else solver.data, + rtol=1e-5, + atol=1e-6, + ) -def test_gradient_matches_numeric(): - solver, n0 = _make_simple_solver() +@pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) +def test_gradient_matches_numeric(use_gpu): + solver, n0 = _make_simple_solver(use_gpu=use_gpu) g, gn = solver.test_gradient(n0, jind=0, dx=1e-6) - np.testing.assert_allclose(g, gn, rtol=1e-3, atol=1e-5) + g_np = cp.asnumpy(g) if cp and use_gpu else g + gn_np = cp.asnumpy(gn) if cp and use_gpu else gn + np.testing.assert_allclose(g_np, gn_np, rtol=1e-3, atol=1e-5) -def _make_multiled_solver(ny=16, nx=16, n_planes=2): +def _make_multiled_solver(ny=16, nx=16, n_planes=2, use_gpu: bool = False): # Multiple LEDs per half-plane led_na = np.array( [ @@ -103,6 +119,8 @@ def _make_multiled_solver(ny=16, nx=16, n_planes=2): n0 = np.full(n_shape, geom.n_medium, dtype=np.float32) n0[1, ny // 2, nx // 2] += 5e-4 + if use_gpu and cp is not None: + n0 = cp.asarray(n0) solver = DPCRytovInverse( np.zeros((n_planes, 4, ny, nx), dtype=np.float32), @@ -110,7 +128,7 @@ def _make_multiled_solver(ny=16, nx=16, n_planes=2): geom=geom, n_shape=n_shape, drs_n=drs_n, - use_gpu=False, + use_gpu=use_gpu, ) dpc_pred, _ = solver._predict_fields(n0) @@ -120,20 +138,28 @@ def _make_multiled_solver(ny=16, nx=16, n_planes=2): geom=geom, n_shape=n_shape, drs_n=drs_n, - use_gpu=False, + use_gpu=use_gpu, ) return solver, n0 -def test_forward_multiled_shape_and_values(): - solver, n0 = _make_multiled_solver() +@pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) +def test_forward_multiled_shape_and_values(use_gpu): + solver, n0 = _make_multiled_solver(use_gpu=use_gpu) pred, _ = solver._predict_fields(n0) assert pred.shape == solver.data.shape == (solver.n_planes, 4, solver.ny, solver.nx) - np.testing.assert_allclose(pred, solver.data, rtol=1e-5, atol=1e-6) + 
np.testing.assert_allclose( + cp.asnumpy(pred) if cp and use_gpu else pred, + cp.asnumpy(solver.data) if cp and use_gpu else solver.data, + rtol=1e-5, + atol=1e-6, + ) -@pytest.mark.slow -def test_gradient_multiled_matches_numeric(): - solver, n0 = _make_multiled_solver() +@pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) +def test_gradient_multiled_matches_numeric(use_gpu): + solver, n0 = _make_multiled_solver(use_gpu=use_gpu) g, gn = solver.test_gradient(n0, jind=0, dx=1e-6) - np.testing.assert_allclose(g, gn, rtol=1e-3, atol=1e-5) + g_np = cp.asnumpy(g) if cp and use_gpu else g + gn_np = cp.asnumpy(gn) if cp and use_gpu else gn + np.testing.assert_allclose(g_np, gn_np, rtol=1e-3, atol=1e-5) From 177cdfee0511f165f08bb34db9d5e74eacbdf536 Mon Sep 17 00:00:00 2001 From: dpshepherd Date: Thu, 1 Jan 2026 20:14:10 -0700 Subject: [PATCH 05/10] more test coverage --- tests/test_dpc_inverse.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_dpc_inverse.py b/tests/test_dpc_inverse.py index 6fe3106..9919050 100644 --- a/tests/test_dpc_inverse.py +++ b/tests/test_dpc_inverse.py @@ -163,3 +163,29 @@ def test_gradient_multiled_matches_numeric(use_gpu): g_np = cp.asnumpy(g) if cp and use_gpu else g gn_np = cp.asnumpy(gn) if cp and use_gpu else gn np.testing.assert_allclose(g_np, gn_np, rtol=1e-3, atol=1e-5) + + +@pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) +def test_reconstruction_improves_mse(use_gpu): + solver, n_true = _make_simple_solver(ny=12, nx=12, n_planes=1, use_gpu=use_gpu) + xp = cp if (use_gpu and cp is not None) else np + + n_init = xp.full_like(n_true, solver.geom.n_medium) + + step = solver.guess_step() + res = solver.run( + n_init, + step=step, + max_iterations=6, + use_fista=True, + compute_cost=False, + verbose=False, + compute_all_costs=False, + line_search_iter_limit=None, + label="recon-test ", + ) + n_rec = res["x"] + + mse_init = xp.mean(xp.abs(n_init - 
n_true) ** 2) + mse_final = xp.mean(xp.abs(n_rec - n_true) ** 2) + assert float(mse_final) < float(mse_init) From 8f6ef1e24e396dbece54fb26075c1bf388124bb8 Mon Sep 17 00:00:00 2001 From: dpshepherd Date: Thu, 1 Jan 2026 20:15:59 -0700 Subject: [PATCH 06/10] increase test strigency --- tests/test_dpc_inverse.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_dpc_inverse.py b/tests/test_dpc_inverse.py index 9919050..fe1a273 100644 --- a/tests/test_dpc_inverse.py +++ b/tests/test_dpc_inverse.py @@ -176,7 +176,7 @@ def test_reconstruction_improves_mse(use_gpu): res = solver.run( n_init, step=step, - max_iterations=6, + max_iterations=50, use_fista=True, compute_cost=False, verbose=False, @@ -188,4 +188,5 @@ def test_reconstruction_improves_mse(use_gpu): mse_init = xp.mean(xp.abs(n_init - n_true) ** 2) mse_final = xp.mean(xp.abs(n_rec - n_true) ** 2) - assert float(mse_final) < float(mse_init) + # Expect a substantial reduction + assert float(mse_final) < 0.1 * float(mse_init) From c6577007f58a63e28a14e30625bf3b6ad0565062 Mon Sep 17 00:00:00 2001 From: dpshepherd Date: Thu, 1 Jan 2026 20:17:55 -0700 Subject: [PATCH 07/10] test wip --- tests/test_dpc_inverse.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_dpc_inverse.py b/tests/test_dpc_inverse.py index fe1a273..cf1c263 100644 --- a/tests/test_dpc_inverse.py +++ b/tests/test_dpc_inverse.py @@ -9,7 +9,7 @@ from mcsim.analysis.dpc_inverse import DPCGeometry, DPCRytovInverse -def _make_simple_solver(ny=16, nx=16, n_planes=2, use_gpu: bool = False): +def _make_simple_solver(ny=16, nx=16, n_planes=2, use_gpu: bool = False, delta: float = 1e-3): # Four LEDs, one per half-plane led_na = np.array( [ @@ -37,7 +37,7 @@ def _make_simple_solver(ny=16, nx=16, n_planes=2, use_gpu: bool = False): # simple ground truth RI n0 = np.full(n_shape, geom.n_medium, dtype=np.float32) - n0[0, ny // 2, nx // 2] += 1e-3 # small perturbation + n0[0, ny // 2, nx // 2] += 
float(delta) # small perturbation if use_gpu and cp is not None: n0 = cp.asarray(n0) @@ -167,7 +167,7 @@ def test_gradient_multiled_matches_numeric(use_gpu): @pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) def test_reconstruction_improves_mse(use_gpu): - solver, n_true = _make_simple_solver(ny=12, nx=12, n_planes=1, use_gpu=use_gpu) + solver, n_true = _make_simple_solver(ny=12, nx=12, n_planes=1, use_gpu=use_gpu, delta=5e-3) xp = cp if (use_gpu and cp is not None) else np n_init = xp.full_like(n_true, solver.geom.n_medium) @@ -176,7 +176,7 @@ def test_reconstruction_improves_mse(use_gpu): res = solver.run( n_init, step=step, - max_iterations=50, + max_iterations=80, use_fista=True, compute_cost=False, verbose=False, @@ -188,5 +188,5 @@ def test_reconstruction_improves_mse(use_gpu): mse_init = xp.mean(xp.abs(n_init - n_true) ** 2) mse_final = xp.mean(xp.abs(n_rec - n_true) ** 2) - # Expect a substantial reduction - assert float(mse_final) < 0.1 * float(mse_init) + # Expect a clear reduction + assert float(mse_final) < 0.5 * float(mse_init) From 2cc2955a597212b111aa76e14911f24a671b01e7 Mon Sep 17 00:00:00 2001 From: dpshepherd Date: Fri, 13 Feb 2026 21:26:58 -0700 Subject: [PATCH 08/10] Tikhonov and FISTA DPC models --- examples/build_dpc_mie_stack.py | 112 +++++ examples/run_wotf_fista_mie.py | 277 +++++++++++++ examples/simulate_dpc_microsphere.py | 127 +++++- mcsim/analysis/dpc_fista_solver.py | 99 +++++ mcsim/analysis/dpc_inverse.py | 322 --------------- mcsim/analysis/dpc_meta.py | 89 ++++ mcsim/analysis/optimize.py | 4 + mcsim/analysis/tv_prox_fast.py | 402 ++++++++++++++++++ mcsim/analysis/wotf_fista.py | 591 +++++++++++++++++++++++++++ tests/test_dpc_inverse.py | 192 --------- 10 files changed, 1680 insertions(+), 535 deletions(-) create mode 100644 examples/build_dpc_mie_stack.py create mode 100644 examples/run_wotf_fista_mie.py create mode 100644 mcsim/analysis/dpc_fista_solver.py delete mode 100644 
mcsim/analysis/dpc_inverse.py create mode 100644 mcsim/analysis/dpc_meta.py create mode 100644 mcsim/analysis/tv_prox_fast.py create mode 100644 mcsim/analysis/wotf_fista.py delete mode 100644 tests/test_dpc_inverse.py diff --git a/examples/build_dpc_mie_stack.py b/examples/build_dpc_mie_stack.py new file mode 100644 index 0000000..a0f4ed7 --- /dev/null +++ b/examples/build_dpc_mie_stack.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +""" +Generate the default DPC Mie simulation stack used by the example scripts. +""" + +from __future__ import annotations + +from pathlib import Path +from dataclasses import dataclass + +import numpy as np + +try: + import cupy as cp # type: ignore +except Exception: + cp = None + +from simulate_dpc_microsphere import SphereSpec, simulate_dpc_images_sphere_to_zarr + + +@dataclass(frozen=True) +class SimParams: + wavelength_um: float = 0.515 + na_obj: float = 0.8 + led_grid_shape: tuple[int, int] = (64, 64) + pitch_mm: float = 2.5 + camera_pixel_um: float = 2.4 + magnification: float = 20.0 + camera_oversample: int = 1 + esize_camera: tuple[int, int] = (256, 256) + z_plane_um: float = 0.0 + sphere: SphereSpec = SphereSpec(radius_um=1., n_sphere=1.59, n_medium=1.55) + inner_na: float = 0.0 + include_center_led: bool = False + pattern_order: tuple[str, str, str, str] = ("left", "right", "up", "down") + normalize_by_led_count: bool = True + led_subsample: int = 1 + mie_kwargs: dict | None = None + camera_gains: float = 2.0 + camera_offsets: float = 100.0 + camera_readout_noise_sds: float = 3.0 + camera_photon_shot_noise: bool = True + camera_saturation: float | None = None + camera_image_is_integer: bool = True + exposure_time_ms: float = 10.0 + illumination_photons_per_s_per_um2: float = 1e7 + focal_stack_planes: int = 31 + focal_stack_step_um: float = 0.325 + + @property + def z_planes_um(self) -> np.ndarray: + return ( + np.arange(self.focal_stack_planes) - self.focal_stack_planes // 2 + ) * self.focal_stack_step_um + + +def 
ensure_dpc_mie_stack(zarr_path: Path, *, use_gpu: bool) -> None: + """ + Ensure the default DPC Mie simulation stack exists at the provided path. + """ + if zarr_path.exists(): + return + params = SimParams() + zarr_path.parent.mkdir(parents=True, exist_ok=True) + simulate_dpc_images_sphere_to_zarr( + zarr_path, + wavelength_um=params.wavelength_um, + na_obj=params.na_obj, + led_grid_shape=params.led_grid_shape, + pitch_mm=params.pitch_mm, + camera_pixel_um=params.camera_pixel_um, + magnification=params.magnification, + camera_oversample=params.camera_oversample, + esize_camera=params.esize_camera, + z_plane_um=params.z_plane_um, + sphere=params.sphere, + inner_na=params.inner_na, + include_center_led=params.include_center_led, + pattern_order=params.pattern_order, + normalize_by_led_count=params.normalize_by_led_count, + led_subsample=params.led_subsample, + use_gpu=use_gpu, + mie_kwargs=params.mie_kwargs, + camera_gains=params.camera_gains, + camera_offsets=params.camera_offsets, + camera_readout_noise_sds=params.camera_readout_noise_sds, + camera_photon_shot_noise=params.camera_photon_shot_noise, + camera_saturation=params.camera_saturation, + camera_image_is_integer=params.camera_image_is_integer, + exposure_time_ms=params.exposure_time_ms, + illumination_photons_per_s_per_um2=params.illumination_photons_per_s_per_um2, + focal_stack_planes=params.focal_stack_planes, + focal_stack_step_um=params.focal_stack_step_um, + overwrite=True, + reuse_cache=True, + ) + + +def main() -> None: + default_path = Path("build/dpc_mie_stack.zarr") + use_gpu = False + if cp is not None: + try: + use_gpu = bool(cp.is_available()) + except Exception: + use_gpu = False + ensure_dpc_mie_stack(default_path, use_gpu=use_gpu) + print(f"Simulation written to {default_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/run_wotf_fista_mie.py b/examples/run_wotf_fista_mie.py new file mode 100644 index 0000000..39d4882 --- /dev/null +++ b/examples/run_wotf_fista_mie.py @@ 
-0,0 +1,277 @@ +#!/usr/bin/env python3 +""" +Run WOTF FISTA reconstruction on a Mie-simulated DPC stack. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import numpy as np +import zarr # type: ignore + +try: + import cupy as cp # type: ignore +except Exception: + cp = None + +from mcsim.analysis.dpc_meta import DPCMeta +from mcsim.analysis.fft import ft3, ift3 +from mcsim.analysis.field_prop import get_n +from mcsim.analysis.optimize import to_cpu +from mcsim.analysis.wotf_fista import ( + WOTFParams, + WOTFFISTAOptimizer, + build_led_na_grid, + build_led_pattern_membership, + build_wotf_transfer, + camera_to_photons, +) +from build_dpc_mie_stack import ensure_dpc_mie_stack + + +def _load_simulated_stack(zarr_path: Path) -> tuple[np.ndarray, dict, np.ndarray]: + g = zarr.open_group(str(zarr_path), mode="r") + if "dpc" not in g: + raise ValueError("Missing 'dpc' dataset in simulation output.") + dpc = np.asarray(g["dpc"], dtype=np.float32) + + meta_group = g.get("meta") + if meta_group is None or "focal_offsets_um" not in meta_group: + raise ValueError("Missing 'meta/focal_offsets_um' in simulation output.") + z_planes_um = np.asarray(meta_group["focal_offsets_um"], dtype=float) + + if dpc.ndim == 3: + if dpc.shape[0] != 4: + raise ValueError(f"Expected dpc shape (4, ny, nx), got {dpc.shape}") + dpc = dpc[None, ...] 
+ elif dpc.ndim == 4: + if dpc.shape[1] != 4: + raise ValueError(f"Expected dpc shape (nz, 4, ny, nx), got {dpc.shape}") + else: + raise ValueError(f"Unexpected dpc ndim: {dpc.ndim}") + + I_cam = np.moveaxis(dpc, 1, 0) + return I_cam, dict(g.attrs), z_planes_um + + +def _build_meta(I_cam: np.ndarray, attrs: dict, z_planes_um: np.ndarray) -> DPCMeta: + nz = int(I_cam.shape[1]) + ny = int(I_cam.shape[2]) + nx = int(I_cam.shape[3]) + if z_planes_um.size != nz: + raise ValueError("z_planes_um length must match stack depth.") + + wavelength_um = float(attrs["wavelength_um"]) + na_obj = float(attrs["na_obj"]) + camera_pixel_um = float(attrs["camera_pixel_um"]) + magnification = float(attrs["magnification"]) + n_medium = float(attrs.get("n_medium", 1.0)) + led_grid_shape = tuple(int(v) for v in attrs.get("led_grid_shape", (64, 64))) + pattern_order = tuple(attrs.get("pattern_order", ("left", "right", "up", "down"))) + inner_na = float(attrs.get("inner_na", 0.0)) + include_center_led = bool(attrs.get("include_center_led", False)) + + dz = 0.0 + if z_planes_um.size > 1: + dzs = np.diff(z_planes_um) + if not np.allclose(dzs, dzs[0]): + raise ValueError("z_planes_um must be evenly spaced.") + dz = float(dzs[0]) + + dxy_um = camera_pixel_um / magnification + + return DPCMeta( + wavelength_um=wavelength_um, + n_background=n_medium, + NA_obj=na_obj, + magnification=magnification, + camera_pixel_pitch_um=camera_pixel_um, + volume_shape_zyx=(nz, ny, nx), + voxel_size_um_zyx=(dz, dxy_um, dxy_um), + z_planes_um=z_planes_um, + led_grid_shape=led_grid_shape, + inner_na=inner_na, + include_center_led=include_center_led, + pattern_order=pattern_order, # type: ignore[arg-type] + ) + + +def _tikhonov_wotf( + contrast: np.ndarray, + H_real: np.ndarray, + *, + reg: float, +) -> np.ndarray: + contrast_ft = ft3(contrast, axes=(1, 2, 3), shift=True) + num = np.sum(np.conj(H_real) * contrast_ft, axis=0) + den = np.sum(np.abs(H_real) ** 2, axis=0) + float(reg) + v_ft_real = num / den + v_real = 
ift3(v_ft_real, axes=(0, 1, 2), shift=True).real + return v_real.astype(np.float32, copy=False) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--sim-zarr", type=Path, default=Path("build/dpc_mie_stack.zarr")) + parser.add_argument("--out", type=Path, default=Path("build/wotf_fista_mie.zarr")) + parser.add_argument("--use-gpu", action="store_true") + parser.add_argument("--step", type=float, default=5.0e-5) + parser.add_argument("--iters", type=int, default=50) + parser.add_argument("--tv-weight", type=float, default=0.0) + parser.add_argument("--tv-max-num-iter", type=int, default=50) + parser.add_argument("--tv-eps", type=float, default=2.0e-4) + parser.add_argument("--tv-aniso-z", type=float, default=1.0) + parser.add_argument("--z-taper", type=int, default=8) + parser.add_argument("--xy-taper", type=int, default=None) + parser.add_argument("--pupil-taper-na", type=float, default=0.0) + parser.add_argument("--eps", type=float, default=1.0e-8) + parser.add_argument("--line-search", action="store_true") + parser.add_argument("--line-search-factor", type=float, default=0.5) + parser.add_argument("--restart-line-search", action="store_true") + parser.add_argument("--tikhonov-reg", type=float, default=5.0e-4) + parser.add_argument("--led-subsample", type=int, default=1) + parser.add_argument("--pad-z", type=int, default=64) + parser.add_argument("--pad-yx", type=int, default=64) + args = parser.parse_args() + + use_gpu = bool(args.use_gpu and cp is not None) + print(f"Using GPU: {use_gpu}") + + ensure_dpc_mie_stack(args.sim_zarr, use_gpu=use_gpu) + I_cam, attrs, z_planes_um = _load_simulated_stack(args.sim_zarr) + meta = _build_meta(I_cam, attrs, z_planes_um) + + camera_gain = attrs.get("camera_gains", 1.0) + camera_offset = attrs.get("camera_offsets", 0.0) + I_cam_xp = cp.asarray(I_cam) if use_gpu else np.asarray(I_cam) + I_meas_phot = camera_to_photons( + I_cam_xp, + camera_offset_adu=camera_offset, + 
camera_gain_photons_per_adu=camera_gain, + ) + + I_meas = np.asarray(to_cpu(I_meas_phot), dtype=np.float32) + I0_pred = np.mean(I_meas, axis=(2, 3), keepdims=True).astype(np.float32, copy=False) + + led_na_xy = build_led_na_grid( + meta.led_grid_shape[0], + meta.led_grid_shape[1], + na_obj=meta.NA_obj, + na_in=meta.inner_na, + include_center=meta.include_center_led, + led_subsample=int(args.led_subsample), + ) + membership = build_led_pattern_membership( + led_na_xy, + order=meta.pattern_order, + include_center=meta.include_center_led, + ) + + params = WOTFParams( + wavelength_um=meta.wavelength_um, + na_obj=meta.NA_obj, + na_in=meta.inner_na, + n0=meta.n_background, + dxy_um=meta.dxy_um, + dz_um=float(meta.voxel_size_um_zyx[0]), + nz=meta.volume_shape_zyx[0] + 2 * int(args.pad_z), + ny=meta.volume_shape_zyx[1] + 2 * int(args.pad_yx), + nx=meta.volume_shape_zyx[2] + 2 * int(args.pad_yx), + pupil_taper_na=float(args.pupil_taper_na), + ) + H_real, H_imag = build_wotf_transfer(led_na_xy, membership, params) + + xp = cp if use_gpu else np + n0_vol = xp.full(meta.volume_shape_zyx, float(meta.n_background), dtype=xp.float32) + + contrast = (I_meas / I0_pred - 1.0) / -1.0 + pad_z = int(args.pad_z) + pad_yx = int(args.pad_yx) + if pad_z or pad_yx: + contrast = np.pad( + contrast, + ((0, 0), (pad_z, pad_z), (pad_yx, pad_yx), (pad_yx, pad_yx)), + mode="constant", + ) + v_real = _tikhonov_wotf(contrast, H_real, reg=float(args.tikhonov_reg)) + if pad_z or pad_yx: + nz, ny, nx = meta.volume_shape_zyx + v_real = v_real[pad_z : pad_z + nz, pad_yx : pad_yx + ny, pad_yx : pad_yx + nx] + n_tikh = get_n(v_real, meta.n_background, meta.wavelength_um).real.astype(np.float32, copy=False) + + optimizer = WOTFFISTAOptimizer( + xp.asarray(I_meas), + xp.asarray(I0_pred), + xp.asarray(H_real), + xp.asarray(H_imag), + n0=meta.n_background, + wavelength_um=meta.wavelength_um, + eps=float(args.eps), + tv_weight=float(args.tv_weight), + tv_max_num_iter=int(args.tv_max_num_iter), + 
tv_eps=float(args.tv_eps), + tv_voxel_size_zyx=(params.dz_um, params.dxy_um, params.dxy_um), + tv_weight_scale_zyx=(float(args.tv_aniso_z), 1.0, 1.0), + pad_zyx=(int(args.pad_z), int(args.pad_yx), int(args.pad_yx)), + z_taper=int(args.z_taper), + xy_taper=args.xy_taper, + use_real_constraint=True, + ) + + n_init = (cp.asarray if use_gpu else np.asarray)( + np.full(meta.volume_shape_zyx, meta.n_background, dtype=np.float32) + ) + line_search_iter_limit = None if args.line_search else 0 + result = optimizer.run( + n_init, + step=float(args.step), + max_iterations=int(args.iters), + use_fista=True, + n_batch=1, + compute_batch_grad_parallel=True, + verbose=True, + compute_cost=True, + line_search_iter_limit=line_search_iter_limit, + line_search_factor=float(args.line_search_factor), + restart_line_search=bool(args.restart_line_search), + ) + n_fista = np.asarray(to_cpu(result["x"]), dtype=np.float32) + + I_pred_fista = np.asarray( + to_cpu( + optimizer._forward_intensity( # type: ignore[attr-defined] + cp.asarray(n_fista) if use_gpu else n_fista + ) + ), + dtype=np.float32, + ) + + out_path = args.out + out_path.parent.mkdir(parents=True, exist_ok=True) + g = zarr.open_group(str(out_path), mode="w", zarr_format=3) + g.attrs.update(meta.as_dict()) + g.attrs["use_gpu"] = bool(use_gpu) + g.attrs["iters"] = int(args.iters) + g.attrs["step"] = float(args.step) + g.attrs["tv_weight"] = float(args.tv_weight) + g.attrs["z_taper"] = int(args.z_taper) + g.attrs["xy_taper"] = None if args.xy_taper is None else int(args.xy_taper) + g.attrs["pupil_taper_na"] = float(args.pupil_taper_na) + g.attrs["tikhonov_reg"] = float(args.tikhonov_reg) + g.attrs["init_from_tikhonov"] = True + g.create_array("I_meas", shape=I_meas.shape, dtype="float32")[...] = I_meas + g.create_array("I0_pred", shape=I0_pred.shape, dtype="float32")[...] = I0_pred + g.create_array("n_fista", shape=n_fista.shape, dtype="float32")[...] = n_fista + g.create_array("n_tikh", shape=n_tikh.shape, dtype="float32")[...] 
= n_tikh + g.create_array("I_pred_fista", shape=I_pred_fista.shape, dtype="float32")[...] = I_pred_fista + g.create_array("led_na_xy", shape=led_na_xy.shape, dtype="float32")[...] = led_na_xy.astype(np.float32, copy=False) + g.create_array("pattern_membership", shape=membership.shape, dtype="uint8")[...] = membership.astype( + np.uint8, copy=False + ) + print(f"Wrote outputs to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/simulate_dpc_microsphere.py b/examples/simulate_dpc_microsphere.py index 5f89f16..fe5e928 100644 --- a/examples/simulate_dpc_microsphere.py +++ b/examples/simulate_dpc_microsphere.py @@ -71,6 +71,8 @@ except Exception as _e: # pragma: no cover propagate_homogeneous = None +from mcsim.analysis.dpc_fista_solver import LEDBoard, compute_led_geometry + if cp: array = np.ndarray | cp.ndarray @@ -356,12 +358,13 @@ def _open_led_cache_zarr( f"Existing LED cache shape {tuple(I_arr.shape)} != {(int(n_led), int(n_planes), int(ny), int(nx))}" ) else: - create_kwargs = { - "shape": (int(n_led), int(n_planes), int(ny), int(nx)), - "dtype": "float32", - } # Use chunks kwarg for compatibility with current Zarr API - I_arr = g.create_array("I_cam", chunks=chunks, **create_kwargs) # type: ignore[call-arg] + I_arr = g.create_array( + "I_cam", + shape=(int(n_led), int(n_planes), int(ny), int(nx)), + dtype="float32", + chunks=chunks, + ) # type: ignore[call-arg] if "done" in g: done = g["done"] @@ -377,7 +380,16 @@ def _open_led_cache_zarr( return g, I_arr, done -def _zarr_write_array(g, name: str, data: np.ndarray, *, dtype: str = "float32", chunks=None, overwrite: bool = True): +def _zarr_write_array( + g, + name: str, + data: np.ndarray, + *, + dtype: str = "float32", + chunks=None, + compressor=None, + overwrite: bool = True, +): """ Write an array using Zarr v3 `Group.create_array`. @@ -402,10 +414,12 @@ def _zarr_write_array(g, name: str, data: np.ndarray, *, dtype: str = "float32", Written array. 
""" shape = tuple(int(s) for s in data.shape) - create_kwargs = { - "shape": shape, - "dtype": dtype, - } + create_kwargs = {"shape": shape, "dtype": dtype} + if compressor is not None: + if isinstance(compressor, dict): + create_kwargs.update(compressor) + else: + create_kwargs["compressor"] = compressor if chunks is not None: arr = g.create_array(name, chunks=chunks, **create_kwargs) # type: ignore[call-arg] @@ -916,6 +930,8 @@ class DPCMieSimulator: Camera grid/noise settings forwarded to `localize_psf.camera.simulated_img`. led_grid_shape : tuple[int, int], optional LED board dimensions. + pitch_mm : float, optional + LED pitch (mm) on the board. inner_na : float, optional Inner NA for annular illumination mask. include_center_led : bool, optional @@ -945,6 +961,7 @@ def __init__( simulation: SimulationSpace, camera: CameraSpace, led_grid_shape: tuple[int, int] = (64, 64), + pitch_mm: float = 2.5, inner_na: float = 0.0, include_center_led: bool = False, pattern_order: Sequence[PatternName] = ("left", "right", "up", "down"), @@ -976,6 +993,7 @@ def __init__( self.simulation = simulation self.camera = camera self.led_grid_shape = (int(led_grid_shape[0]), int(led_grid_shape[1])) + self.pitch_mm = float(pitch_mm) self.inner_na = float(inner_na) self.include_center_led = bool(include_center_led) self.pattern_order = pattern_order @@ -1216,9 +1234,9 @@ def simulate_patterns(self) -> tuple[array, dict[str, array]]: # type: ignore Notes ----- - - LED NA grid and pattern membership come from `make_led_na_positions` and - `split_dpc_patterns` using `led_grid_shape`, `inner_na`, `include_center_led`, - and `pattern_order`. + - LED NA grid and pattern membership come from board geometry (n_side, pitch_mm) + mapped to NA via inscribed circle matched to na_obj, then split by + `split_dpc_patterns` using `inner_na`, `include_center_led`, and `pattern_order`. 
- Each LED image is produced by `_simulate_led` (internally `mie_fields.mie_efield` → `propagate_homogeneous` → `apply_objective_imaging` → `localize_psf.camera.simulated_img`). @@ -1226,13 +1244,21 @@ def simulate_patterns(self) -> tuple[array, dict[str, array]]: # type: ignore """ allowed: tuple[PatternName, ...] = ("left", "right", "up", "down") ny_led, nx_led = self.led_grid_shape - na_xy_all = make_led_na_positions( - ny_led, - nx_led, + board = LEDBoard( + n_side=ny_led, + pitch_mm=float(self.pitch_mm), na_obj=float(self.na_obj), - inner_na=float(self.inner_na), - include_center=bool(self.include_center_led), + wavelength_um=float(self.wavelength_um), + n_medium=float(self.sphere.n_medium), ) + geom = compute_led_geometry(board) + na_xy_all = geom.na_components + if not self.include_center_led: + r = np.linalg.norm(na_xy_all, axis=1) + na_xy_all = na_xy_all[r > 0] + if float(self.inner_na) > 0: + r = np.linalg.norm(na_xy_all, axis=1) + na_xy_all = na_xy_all[r >= float(self.inner_na)] patterns_full = split_dpc_patterns(na_xy_all, order=self.pattern_order) na_xy = na_xy_all[:: self.led_subsample] @@ -1267,6 +1293,7 @@ def simulate_patterns(self) -> tuple[array, dict[str, array]]: # type: ignore "wavelength_um": float(self.wavelength_um), "na_obj": float(self.na_obj), "led_grid_shape": [int(self.led_grid_shape[0]), int(self.led_grid_shape[1])], + "pitch_mm": float(self.pitch_mm), "camera_pixel_um": float(self.camera.pixel_um), "magnification": float(self.camera.magnification), "camera_shape": [int(ny_cam), int(nx_cam)], @@ -1344,7 +1371,7 @@ def simulate_patterns(self) -> tuple[array, dict[str, array]]: # type: ignore dpc = xp.stack([acc[p].copy() for p in self.pattern_order], axis=0) dpc = xp.moveaxis(dpc, 0, 1) # (n_planes, 4, ny, nx) - dpc_out: array + dpc_out: array # type: ignore if dpc.shape[0] == 1: dpc_out = dpc[0] else: @@ -1352,6 +1379,7 @@ def simulate_patterns(self) -> tuple[array, dict[str, array]]: # type: ignore meta: dict[str, array] = { # type: 
ignore "na_xy": na_xy_all, + "pitch_mm": float(self.pitch_mm), "simulation_dxy_um": np.asarray(self.simulation.dxy_um, dtype=np.float32), "simulation_esize": (int(self.simulation.esize[0]), int(self.simulation.esize[1])), "camera_shape": (int(ny_cam), int(nx_cam)), @@ -1463,6 +1491,7 @@ def simulate_dpc_images_sphere( wavelength_um: float = 0.515, na_obj: float = 0.8, led_grid_shape: tuple[int, int] = (64, 64), + pitch_mm: float = 2.5, camera_pixel_um: float = 2.4, magnification: float = 20.0, camera_oversample: int = 1, @@ -1503,7 +1532,9 @@ def simulate_dpc_images_sphere( Objective NA. Sets the coherent cutoff in `apply_objective_imaging` via `make_pupil`. led_grid_shape : (int, int) - LED board shape (ny_led, nx_led) for `make_led_na_positions`. + LED board shape (ny_led, nx_led). + pitch_mm : float + LED pitch on the board (mm), used to map board radius to illumination NA. camera_pixel_um : float Camera pixel pitch (µm) in the camera plane; combined with `magnification` to derive object-plane sampling and the bin size for `localize_psf.camera.simulated_img`. 
@@ -1621,6 +1652,7 @@ def simulate_dpc_images_sphere( simulation=sim_space, camera=cam_space, led_grid_shape=led_grid_shape, + pitch_mm=float(pitch_mm), inner_na=float(inner_na), include_center_led=bool(include_center_led), pattern_order=pattern_order, @@ -1647,7 +1679,7 @@ def simulate_dpc_images_sphere( }) # Crop 10% border from each side to keep central region - def _crop_center(arr: array) -> array: + def _crop_center(arr: array) -> array: # type: ignore if arr.ndim == 3: _, ny, nx = arr.shape elif arr.ndim == 4: @@ -1691,6 +1723,14 @@ def write_dpc_zarr( sphere: SphereSpec, inner_na: float, include_center_led: bool, + camera_gains: array | float = 1.0, # type: ignore + camera_offsets: array | float = 0.0, # type: ignore + camera_readout_noise_sds: array | float = 0.0, # type: ignore + camera_photon_shot_noise: bool = False, + camera_saturation: int | None = None, + camera_image_is_integer: bool = False, + exposure_time_ms: float = 1.0, + illumination_photons_per_s_per_um2: float = 1.0, nz: int | None = None, z_span_um: float | None = None, overwrite: bool = True, @@ -1724,6 +1764,22 @@ def write_dpc_zarr( Sphere parameters; radius/index stored as extra attrs. inner_na : float Inner NA stored as extra attr. + camera_gains : array or float, optional + Camera gain(s) stored in attrs. + camera_offsets : array or float, optional + Camera offset(s) stored in attrs. + camera_readout_noise_sds : array or float, optional + Camera readout noise sds stored in attrs. + camera_photon_shot_noise : bool, optional + Camera shot noise flag stored in attrs. + camera_saturation : int or None, optional + Camera saturation stored in attrs. + camera_image_is_integer : bool, optional + Camera integer flag stored in attrs. + exposure_time_ms : float, optional + Exposure time stored in attrs. + illumination_photons_per_s_per_um2 : float, optional + Illumination intensity stored in attrs. include_center_led : bool Whether center LED included; stored as extra attr. 
nz : int or None, optional @@ -1754,6 +1810,7 @@ def write_dpc_zarr( g.attrs["n_medium"] = float(n_medium) g.attrs["led_grid_shape"] = [int(led_grid_shape[0]), int(led_grid_shape[1])] g.attrs["pattern_order"] = [str(p) for p in pattern_order] + g.attrs["pitch_mm"] = float(meta.get("pitch_mm", np.nan)) # Effective Mie evaluation plane used for simulation (µm). if "z_plane_um" in meta: g.attrs["field_generation_z_um"] = float(np.asarray(meta["z_plane_um"])) @@ -1778,6 +1835,22 @@ def write_dpc_zarr( g.attrs["inner_na"] = float(inner_na) g.attrs["include_center_led"] = bool(include_center_led) g.attrs["dxy_um"] = float(np.asarray(meta.get("dxy_um", camera_pixel_um / magnification))) + g.attrs["camera_gains"] = ( + float(camera_gains) if np.isscalar(camera_gains) else np.asarray(camera_gains).tolist() + ) + g.attrs["camera_offsets"] = ( + float(camera_offsets) if np.isscalar(camera_offsets) else np.asarray(camera_offsets).tolist() + ) + g.attrs["camera_readout_noise_sds"] = ( + float(camera_readout_noise_sds) + if np.isscalar(camera_readout_noise_sds) + else np.asarray(camera_readout_noise_sds).tolist() + ) + g.attrs["camera_photon_shot_noise"] = bool(camera_photon_shot_noise) + g.attrs["camera_saturation"] = camera_saturation if camera_saturation is None else int(camera_saturation) + g.attrs["camera_image_is_integer"] = bool(camera_image_is_integer) + g.attrs["exposure_time_ms"] = float(exposure_time_ms) + g.attrs["illumination_photons_per_s_per_um2"] = float(illumination_photons_per_s_per_um2) if dpc_np.ndim == 3: if dpc_np.shape[0] != 4: @@ -1811,6 +1884,7 @@ def simulate_dpc_images_sphere_to_zarr( wavelength_um: float = 0.515, na_obj: float = 0.8, led_grid_shape: tuple[int, int] = (64, 64), + pitch_mm: float = 2.5, camera_pixel_um: float = 2.4, magnification: float = 20.0, camera_oversample: int = 1, @@ -1855,6 +1929,8 @@ def simulate_dpc_images_sphere_to_zarr( Objective NA for pupil cutoff and attrs. 
led_grid_shape : tuple[int, int], optional LED board shape for pattern generation. + pitch_mm : float, optional + LED pitch on the board (mm), used to map board radius to illumination NA. camera_pixel_um : float, optional Camera pixel size (µm) for attrs and camera sampling. magnification : float, optional @@ -1929,6 +2005,7 @@ def simulate_dpc_images_sphere_to_zarr( wavelength_um=wavelength_um, na_obj=na_obj, led_grid_shape=led_grid_shape, + pitch_mm=pitch_mm, camera_pixel_um=camera_pixel_um, magnification=magnification, camera_oversample=camera_oversample, @@ -1974,6 +2051,14 @@ def simulate_dpc_images_sphere_to_zarr( sphere=sphere_used, inner_na=inner_na, include_center_led=include_center_led, + camera_gains=camera_gains, + camera_offsets=camera_offsets, + camera_readout_noise_sds=camera_readout_noise_sds, + camera_photon_shot_noise=camera_photon_shot_noise, + camera_saturation=camera_saturation, + camera_image_is_integer=camera_image_is_integer, + exposure_time_ms=exposure_time_ms, + illumination_photons_per_s_per_um2=illumination_photons_per_s_per_um2, nz=nz, z_span_um=z_span_um, overwrite=overwrite, diff --git a/mcsim/analysis/dpc_fista_solver.py b/mcsim/analysis/dpc_fista_solver.py new file mode 100644 index 0000000..7a28d57 --- /dev/null +++ b/mcsim/analysis/dpc_fista_solver.py @@ -0,0 +1,99 @@ +""" +Minimal LED board geometry helpers for DPC simulations. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass(frozen=True) +class LEDBoard: + """ + LED board definition for NA-space mapping. + + :param n_side: Number of LEDs per side (square board). + :param pitch_mm: LED pitch in mm. + :param na_obj: Objective NA. + :param wavelength_um: Wavelength in um. + :param n_medium: Medium refractive index. + """ + + n_side: int + pitch_mm: float + na_obj: float + wavelength_um: float + n_medium: float + + +@dataclass(frozen=True) +class LEDGeometry: + """ + Geometry container for LED NA positions. 
+ + :param na_components: NA positions, shape (N, 2). + """ + + na_components: np.ndarray + + +def _make_led_na_positions( + ny_led: int, + nx_led: int, + *, + na_obj: float, + inner_na: float = 0.0, + include_center: bool = True, +) -> np.ndarray: + if ny_led < 1 or nx_led < 1: + raise ValueError("ny_led and nx_led must be >= 1") + if inner_na < 0 or inner_na >= na_obj: + raise ValueError("inner_na must satisfy 0 <= inner_na < na_obj") + + cy = (ny_led - 1) / 2.0 + cx = (nx_led - 1) / 2.0 + + yy, xx = np.meshgrid( + np.arange(ny_led, dtype=np.float32), + np.arange(nx_led, dtype=np.float32), + indexing="ij", + ) + x = xx - cx + y = yy - cy + r = np.sqrt(x * x + y * y) + + r_circle = float(min(cx, cy)) if (cx > 0 and cy > 0) else float(np.max(r)) + mask = r <= r_circle + + if not include_center: + mask = mask & (r > 0) + + r_mask_max = float(np.max(r[mask])) if np.any(mask) else 1.0 + scale = float(na_obj) / r_mask_max + + na_x = x[mask] * scale + na_y = y[mask] * scale + na_r = np.sqrt(na_x * na_x + na_y * na_y) + + ann = na_r >= float(inner_na) + na_xy = np.stack([na_x[ann], na_y[ann]], axis=1) + return na_xy + + +def compute_led_geometry(board: LEDBoard) -> LEDGeometry: + """ + Compute LED NA positions for a square board. + + :param board: LED board definition. + :return: LEDGeometry with NA components. + """ + na_xy = _make_led_na_positions( + int(board.n_side), + int(board.n_side), + na_obj=float(board.na_obj), + inner_na=0.0, + include_center=True, + ) + return LEDGeometry(na_components=na_xy) diff --git a/mcsim/analysis/dpc_inverse.py b/mcsim/analysis/dpc_inverse.py deleted file mode 100644 index d840d4e..0000000 --- a/mcsim/analysis/dpc_inverse.py +++ /dev/null @@ -1,322 +0,0 @@ -""" -Inverse-model solver for differential phase contrast (DPC) stacks using a linear -Rytov forward model and FISTA with a plug-and-play median proximal operator. 
- -The solver maps a 3D refractive-index volume to DPC measurements generated by -half-plane LED patterns, optionally across a focal stack. It reuses the -GPU-aware FFT and linear scattering tools already in the codebase. -""" - -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Optional - -import numpy as np -import scipy.sparse as sp -from scipy.sparse.linalg import svds - -try: - import cupy as cp - import cupyx.scipy.sparse as sp_gpu -except ImportError: # pragma: no cover - cp = None - sp_gpu = None - -from mcsim.analysis.fft import ft2, ift2 -from mcsim.analysis.field_prop import ( - fwd_model_linear, - get_angular_spectrum_kernel, - get_v, -) -from mcsim.analysis.optimize import Optimizer, median_prox - - -if cp: - array = np.ndarray | cp.ndarray -else: - array = np.ndarray - - -def _get_xp(use_gpu: bool): - return cp if (use_gpu and cp is not None) else np - - -def _split_dpc_patterns(na_xy: np.ndarray, order: Sequence[str]) -> list[np.ndarray]: - """ - Split LED NA coordinates into canonical DPC half-planes. - """ - allowed = ("left", "right", "up", "down") - if len(order) != 4 or set(order) != set(allowed): - raise ValueError(f"order must be a permutation of {allowed}, got {order}") - - na_x = na_xy[:, 0] - na_y = na_xy[:, 1] - masks = { - "left": na_x < 0, - "right": na_x > 0, - "up": na_y > 0, - "down": na_y < 0, - } - - indices: list[np.ndarray] = [] - for name in order: - idx = np.nonzero(masks[name])[0] - if idx.size == 0: - raise ValueError(f"pattern '{name}' has zero LEDs") - indices.append(idx) - return indices - - -def _defocus_kernels( - ny: int, - nx: int, - dxy: float, - zs: Sequence[float], - wavelength: float, - n_medium: float, - xp, -) -> array: - """ - Precompute angular spectrum defocus kernels for each focal offset. 
- """ - fx = xp.expand_dims(xp.fft.fftfreq(nx, dxy), axis=0) - fy = xp.expand_dims(xp.fft.fftfreq(ny, dxy), axis=1) - kernels = [get_angular_spectrum_kernel(fx, fy, z, wavelength, n_medium) for z in zs] - return xp.stack(kernels, axis=0) - - -@dataclass -class DPCGeometry: - """ - Lightweight container for optical and sampling parameters. - """ - - wavelength_um: float - n_medium: float - na_obj: float - camera_pixel_um: float - magnification: float - focal_offsets_um: Sequence[float] - pattern_order: Sequence[str] - normalize_by_led_count: bool = True - - @property - def dxy_um(self) -> float: - return float(self.camera_pixel_um) / float(self.magnification) - - -class DPCRytovInverse(Optimizer): - """ - FISTA-compatible optimizer for DPC stacks under a linear Rytov model. - """ - - def __init__( - self, - measured: array, - led_na: np.ndarray, - *, - geom: DPCGeometry, - n_shape: Sequence[int, int, int], - drs_n: Sequence[float, float, float], - median_filter_size: Sequence[int, int, int] = (3, 3, 3), - use_gpu: bool = False, - ) -> None: - if measured.ndim == 3: - measured = measured[None, ...] 
- - n_planes, n_patterns, ny, nx = measured.shape - if n_patterns != 4: - raise ValueError(f"expected 4 DPC patterns, got {n_patterns}") - - xp = _get_xp(use_gpu) - self.use_gpu = bool(use_gpu) - self.xp = xp - self.geom = geom - self.n_shape = tuple(int(s) for s in n_shape) - self.drs_n = tuple(float(d) for d in drs_n) - self.ny = int(ny) - self.nx = int(nx) - self.n_planes = int(n_planes) - self.n_patterns = int(n_patterns) - self.pattern_order = tuple(geom.pattern_order) - self.normalize_by_led_count = bool(geom.normalize_by_led_count) - self.median_filter_size = tuple(int(s) for s in median_filter_size) - - self.data = xp.asarray(measured) - self.led_na = np.asarray(led_na, dtype=float) - self.pattern_led_indices = _split_dpc_patterns(self.led_na, self.pattern_order) - self.pattern_led_indices_gpu = ( - [xp.asarray(idx) for idx in self.pattern_led_indices] if xp is cp else self.pattern_led_indices - ) - - # Sampling/grid - self.dxy = float(geom.dxy_um) - self.drs_e = (self.dxy, self.dxy) - - # Forward operator (Rytov linear scattering) for all LEDs at once - beam_fx = xp.asarray(self.led_na[:, 0] / geom.wavelength_um) - beam_fy = xp.asarray(self.led_na[:, 1] / geom.wavelength_um) - beam_fz = xp.sqrt((geom.n_medium / geom.wavelength_um) ** 2 - beam_fx**2 - beam_fy**2) - self.model = fwd_model_linear( - beam_fx, - beam_fy, - beam_fz, - geom.n_medium, - geom.na_obj, - geom.wavelength_um, - (self.ny, self.nx), - self.drs_e, - self.n_shape, - self.drs_n, - mode="rytov", - interpolate=False, - use_gpu=self.use_gpu, - ) - - # Defocus transfer functions per plane - if geom.focal_offsets_um and len(geom.focal_offsets_um) not in (1, n_planes): - raise ValueError( - f"focal_offsets_um length ({len(geom.focal_offsets_um)}) must match n_planes ({n_planes}) or be 1" - ) - self.focal_offsets = [float(z) for z in geom.focal_offsets_um] if geom.focal_offsets_um else [0.0] - if len(self.focal_offsets) == 1 and n_planes > 1: - self.focal_offsets = [self.focal_offsets[0]] * 
n_planes - if len(self.focal_offsets) != n_planes: - raise ValueError(f"n_planes={n_planes} but {len(self.focal_offsets)} focal offsets provided") - self.defocus_kernels = _defocus_kernels( - self.ny, - self.nx, - self.dxy, - self.focal_offsets, - geom.wavelength_um, - geom.n_medium, - xp, - ) - self.e0 = xp.ones((self.n_planes, self.ny, self.nx), dtype=complex) - - # Map sample indices -> (plane, pattern) - self.sample_plane = np.repeat(np.arange(self.n_planes), self.n_patterns) - self.sample_pattern = np.tile(np.arange(self.n_patterns), self.n_planes) - - super().__init__(self.n_planes * self.n_patterns, prox_parameters={"median_filter_size": self.median_filter_size}) - - # --------------------------- - # helpers - # --------------------------- - def _select_indices(self, inds: Optional[Sequence[int]]) -> np.ndarray: - if inds is None: - return np.arange(self.n_samples) - return np.asarray(inds, dtype=int) - - def _predict_fields(self, n_volume: array) -> tuple[array, array]: - """ - Compute per-pattern irradiance predictions and per-LED defocused fields. 
- """ - xp = self.xp - v = get_v(n_volume, self.geom.n_medium, self.geom.wavelength_um) - v_vec = xp.asarray(v).ravel() - - es_focus = self.model.dot(v_vec).reshape((self.led_na.shape[0], self.ny, self.nx)) - es_ft = ft2(es_focus, shift=False) - es_ft = es_ft[:, None, :, :] * self.defocus_kernels[None, :, :, :] - es_z = ift2(es_ft, shift=False) # (n_led, n_planes, ny, nx) - - # move plane axis first for easier broadcasting - es_z = xp.moveaxis(es_z, 1, 0) # (n_planes, n_led, ny, nx) - intensity = xp.abs(self.e0[:, None, :, :] + es_z) ** 2 - - patterns = [] - idx_list = self.pattern_led_indices_gpu - for idx in idx_list: - pat = intensity[:, idx, :, :].sum(axis=1) - if self.normalize_by_led_count: - pat = pat / idx.shape[0] - patterns.append(pat) - - pred = xp.stack(patterns, axis=1) # (n_planes, 4, ny, nx) - return pred, es_z - - # --------------------------- - # Optimizer interface - # --------------------------- - def fwd_model(self, x: array, inds: Optional[Sequence[int]] = None) -> array: - pred, _ = self._predict_fields(x) - pred_flat = pred.reshape((self.n_samples, self.ny, self.nx)) - sel = self._select_indices(inds) - return pred_flat[sel] - - def cost(self, x: array, inds: Optional[Sequence[int]] = None) -> array: - xp = self.xp - pred, _ = self._predict_fields(x) - residual = pred - self.data - residual_flat = residual.reshape((self.n_samples, self.ny, self.nx)) - sel = self._select_indices(inds) - r = residual_flat[sel] - return 0.5 * xp.sum(xp.abs(r) ** 2, axis=(-1, -2)) - - def gradient(self, x: array, inds: Optional[Sequence[int]] = None) -> array: - xp = self.xp - pred, es_z = self._predict_fields(x) - n_led = self.led_na.shape[0] - - # mask residuals to selected samples - mask = np.zeros(self.n_samples, dtype=bool) - sel = self._select_indices(inds) - mask[sel] = True - mask_stack = xp.asarray(mask.reshape((self.n_planes, self.n_patterns, 1, 1))) - - residual = (pred - self.data) * mask_stack - - # accumulate per-LED gradient at focus plane - 
grad_es0 = xp.zeros((n_led, self.ny, self.nx), dtype=complex) - idx_list = self.pattern_led_indices_gpu - for p_idx, led_idx in enumerate(idx_list): - res_p = residual[:, p_idx, :, :] # (n_planes, ny, nx) - scale = 1.0 / led_idx.shape[0] if self.normalize_by_led_count else 1.0 - for l in led_idx.tolist(): - # gradient of irradiance wrt field at each plane - g_z = res_p * (self.e0 + es_z[:, l, :, :]) * (2.0 * scale) - # backpropagate defocus for this LED across planes - g_ft = ft2(g_z, adjoint=True, shift=False) # (n_planes, ny, nx) - back_ft = xp.sum(g_ft * xp.conj(self.defocus_kernels), axis=0) - grad_es0[l] += ift2(back_ft, shift=False) - - # adjoint of linear model - grad_v_vec = self.model.getH().dot(grad_es0.ravel()) - grad_v = grad_v_vec.reshape(self.n_shape) - - # chain rule dv/dn = -2*(2*pi/lambda)^2 * n - n_volume = xp.asarray(x) - factor = -2 * (2 * np.pi / self.geom.wavelength_um) ** 2 * n_volume - grad_n = grad_v * factor - # broadcast gradient to requested batch; each selected sample shares the same volume gradient - return xp.broadcast_to(xp.expand_dims(grad_n, axis=0), (len(sel),) + grad_n.shape) - - def prox(self, x: array, step: float) -> array: - xp = self.xp - x_real = median_prox(x.real, self.median_filter_size) - return xp.asarray(x_real, dtype=x.dtype) - - def guess_step(self, x: Optional[array] = None) -> float: - """ - Estimate a Lipschitz-consistent step size using the same spectral norm - logic as `LinearScatt.guess_step`, reusing the sparse linear Rytov model. - - We ignore the mild nonlinearity from intensity formation; this gives a - conservative step similar to other linear-scattering solvers in mcsim. 
- """ - # dominant singular value of the linear model - try: - m_for_svd = self.model.get() if sp_gpu and isinstance(self.model, sp_gpu.csr_matrix) else self.model - u, s, vh = svds(m_for_svd, k=1, which="LM") - sigma = float(s[0]) - except Exception: - # fallback if svds not available - sigma = 1.0 - - # exactly match LinearScatt: L ~ sigma^2 / (n_samples * ny * nx) - lipschitz_estimate = sigma**2 / (self.n_samples * self.ny * self.nx) - if not np.isfinite(lipschitz_estimate) or lipschitz_estimate <= 0: - return 1e-3 - - return float(1.0 / lipschitz_estimate) diff --git a/mcsim/analysis/dpc_meta.py b/mcsim/analysis/dpc_meta.py new file mode 100644 index 0000000..c0983b1 --- /dev/null +++ b/mcsim/analysis/dpc_meta.py @@ -0,0 +1,89 @@ +""" +Metadata containers for DPC inverse solvers. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence + +import numpy as np + + +@dataclass(frozen=True) +class DPCMeta: + """ + Metadata required for DPC reconstruction. + + :param wavelength_um: Vacuum wavelength (um). + :param n_background: Background refractive index. + :param NA_obj: Objective NA (also used for detection NA). + :param magnification: Objective magnification. + :param camera_pixel_pitch_um: Camera pixel pitch (um). + :param volume_shape_zyx: Reconstruction volume shape (nz, ny, nx). + :param voxel_size_um_zyx: Voxel size (dz, dy, dx) in um. + :param z_planes_um: Object-space z planes in um, length nz. + :param led_grid_shape: LED board grid shape (ny, nx). + :param inner_na: Inner NA for annular LED mask. + :param include_center_led: Include NA=0 LED if True. + :param pattern_order: Pattern order matching measured data. 
+ """ + + wavelength_um: float + n_background: float + NA_obj: float + magnification: float + camera_pixel_pitch_um: float + volume_shape_zyx: tuple[int, int, int] + voxel_size_um_zyx: tuple[float, float, float] + z_planes_um: Sequence[float] + led_grid_shape: tuple[int, int] = (64, 64) + inner_na: float = 0.0 + include_center_led: bool = False + pattern_order: tuple[str, str, str, str] = ("left", "right", "up", "down") + + def __post_init__(self) -> None: + z_planes = np.asarray(self.z_planes_um, dtype=float) + object.__setattr__(self, "z_planes_um", z_planes) + if z_planes.ndim != 1: + raise ValueError("z_planes_um must be 1D") + if len(self.volume_shape_zyx) != 3: + raise ValueError("volume_shape_zyx must have length 3") + if len(self.voxel_size_um_zyx) != 3: + raise ValueError("voxel_size_um_zyx must have length 3") + if z_planes.size != int(self.volume_shape_zyx[0]): + raise ValueError("z_planes_um length must match volume_shape_zyx[0]") + allowed = ("left", "right", "up", "down") + if len(self.pattern_order) != 4 or set(self.pattern_order) != set(allowed): + raise ValueError(f"pattern_order must be a permutation of {allowed}") + + @property + def dxy_um(self) -> float: + """ + Object-space pixel size derived from the camera pitch and magnification. + + :return: Pixel size in um. + """ + return float(self.camera_pixel_pitch_um) / float(self.magnification) + + def as_dict(self) -> dict[str, object]: + """ + Export metadata as a serializable dictionary. + + :return: Metadata dictionary. 
+ """ + z_planes = np.asarray(self.z_planes_um, dtype=float).tolist() + return { + "wavelength_um": float(self.wavelength_um), + "n_background": float(self.n_background), + "NA_obj": float(self.NA_obj), + "magnification": float(self.magnification), + "camera_pixel_pitch_um": float(self.camera_pixel_pitch_um), + "volume_shape_zyx": tuple(int(v) for v in self.volume_shape_zyx), + "voxel_size_um_zyx": tuple(float(v) for v in self.voxel_size_um_zyx), + "z_planes_um": z_planes, + "led_grid_shape": tuple(int(v) for v in self.led_grid_shape), + "inner_na": float(self.inner_na), + "include_center_led": bool(self.include_center_led), + "pattern_order": tuple(self.pattern_order), + } diff --git a/mcsim/analysis/optimize.py b/mcsim/analysis/optimize.py index 4d2a54f..187bbcb 100644 --- a/mcsim/analysis/optimize.py +++ b/mcsim/analysis/optimize.py @@ -285,6 +285,7 @@ def run(self, gtol: float = 0.0, print_newline: bool = False, label: str = "", + iteration_callback=None, **kwargs) -> dict: """ @@ -536,6 +537,9 @@ def run(self, if stop: break + if iteration_callback is not None: + iteration_callback(ii, x) + # compute final cost if compute_cost: if compute_all_costs: diff --git a/mcsim/analysis/tv_prox_fast.py b/mcsim/analysis/tv_prox_fast.py new file mode 100644 index 0000000..1918b0d --- /dev/null +++ b/mcsim/analysis/tv_prox_fast.py @@ -0,0 +1,402 @@ +""" +Fast GPU-only 3D TV proximal operator (Chambolle dual projection) implemented +with CuPy + CUDA RawKernel. 
+ +Import into optimize.py, e.g.: + + from tv_prox_fast import tv_prox_fast + +Design constraints: +- 3D only +- GPU/CuPy only (raises if not CuPy array) +- No positivity constraint +- Fixed-iteration behavior consistent with the uploaded prox_fgp (default num_iter=10) +""" + +from functools import lru_cache +from typing import Any, Optional, Tuple + +import cupy as cp + +__all__ = ["tv_prox_fast"] + + +@lru_cache(maxsize=None) +def _tv_fast_kernels( + dtype_str: str, +) -> Tuple[cp.RawKernel, cp.RawKernel, cp.RawKernel, cp.RawKernel]: + """ + Compile and cache RawKernels used by tv_prox_fast(). + + :param dtype_str: Data type string ("float32" or "float64"). + :return: Tuple of (discontig_sub_kernel, tv_norm_kernel). + """ + if dtype_str == "float32": + ctype = "float" + sqrt_fn = "sqrtf" + elif dtype_str == "float64": + ctype = "double" + sqrt_fn = "sqrt" + else: + raise ValueError(f"Unsupported dtype for tv_prox_fast: {dtype_str}") + + code = rf""" + extern "C" __global__ + void discontig_sub( + const {ctype}* __restrict__ arr, + {ctype}* __restrict__ out, + unsigned long long step, + unsigned long long length, + int transpose, + unsigned long long n_total + ) {{ + unsigned long long i = (unsigned long long)blockDim.x * (unsigned long long)blockIdx.x + + (unsigned long long)threadIdx.x; + if (i >= n_total) return; + + unsigned long long pos = i % length; + if (!transpose) {{ + if (pos + step < length) {{ + out[i] = arr[i + step] - arr[i]; + }} else {{ + out[i] = ({ctype})0; + }} + }} else {{ + out[i] -= arr[i]; + if (pos + step < length) {{ + out[i + step] += arr[i]; + }} + }} + }} + + extern "C" __global__ + void tv_norm3( + {ctype}* __restrict__ out, // length n + const {ctype}* __restrict__ tv, // length 3*n, packed [x|y|z] + unsigned long long n + ) {{ + unsigned long long i = (unsigned long long)blockDim.x * (unsigned long long)blockIdx.x + + (unsigned long long)threadIdx.x; + if (i >= n) return; + + {ctype} a = tv[i]; + {ctype} b = tv[i + n]; + {ctype} c = 
tv[i + 2*n]; + + out[i] = {sqrt_fn}(a*a + b*b + c*c); + }} + """ + + code_grad = rf""" + extern "C" __global__ + void tv_grad_update( + const {ctype}* __restrict__ out, + {ctype}* __restrict__ p, + {ctype}* __restrict__ norms, + {ctype} step, + {ctype} weight, + {ctype} w_z, + {ctype} w_y, + {ctype} w_x, + unsigned long long nz, + unsigned long long ny, + unsigned long long nx, + unsigned long long n_total, + int write_norms + ) {{ + unsigned long long i = (unsigned long long)blockDim.x * (unsigned long long)blockIdx.x + + (unsigned long long)threadIdx.x; + if (i >= n_total) return; + + unsigned long long plane = ny * nx; + unsigned long long z = i / plane; + unsigned long long rem = i - z * plane; + unsigned long long y = rem / nx; + unsigned long long x = rem - y * nx; + + {ctype} g0 = ({ctype})0; + {ctype} g1 = ({ctype})0; + {ctype} g2 = ({ctype})0; + if (z + 1 < nz) {{ + g0 = (out[i + plane] - out[i]) * w_z; + }} + if (y + 1 < ny) {{ + g1 = (out[i + nx] - out[i]) * w_y; + }} + if (x + 1 < nx) {{ + g2 = (out[i + 1] - out[i]) * w_x; + }} + + {ctype} norm = {sqrt_fn}(g0 * g0 + g1 * g1 + g2 * g2); + {ctype} denom = ({ctype})1 + (step / weight) * norm; + + {ctype} p0 = p[i]; + {ctype} p1 = p[i + n_total]; + {ctype} p2 = p[i + 2 * n_total]; + + p[i] = (p0 - step * g0) / denom; + p[i + n_total] = (p1 - step * g1) / denom; + p[i + 2 * n_total] = (p2 - step * g2) / denom; + if (write_norms) {{ + norms[i] = norm; + }} + }} + """ + + code_div = rf""" + extern "C" __global__ + void tv_divergence( + const {ctype}* __restrict__ p, + const {ctype}* __restrict__ x, + {ctype}* __restrict__ out, + {ctype} w_z, + {ctype} w_y, + {ctype} w_x, + unsigned long long nz, + unsigned long long ny, + unsigned long long nx, + unsigned long long n_total + ) {{ + unsigned long long i = (unsigned long long)blockDim.x * (unsigned long long)blockIdx.x + + (unsigned long long)threadIdx.x; + if (i >= n_total) return; + + unsigned long long plane = ny * nx; + unsigned long long z = i / plane; + 
unsigned long long rem = i - z * plane; + unsigned long long y = rem / nx; + unsigned long long xind = rem - y * nx; + + {ctype} p0 = p[i]; + {ctype} p1 = p[i + n_total]; + {ctype} p2 = p[i + 2 * n_total]; + + {ctype} val = -(w_z * p0 + w_y * p1 + w_x * p2); + if (z > 0) {{ + val += w_z * p[i - plane]; + }} + if (y > 0) {{ + val += w_y * p[i + n_total - nx]; + }} + if (xind > 0) {{ + val += w_x * p[i + 2 * n_total - 1]; + }} + out[i] = x[i] + val; + }} + """ + + return ( + cp.RawKernel(code, "discontig_sub"), + cp.RawKernel(code, "tv_norm3"), + cp.RawKernel(code_div, "tv_divergence"), + cp.RawKernel(code_grad, "tv_grad_update"), + ) + + +def _launch_1d(kernel: cp.RawKernel, n: int, args: Tuple[Any, ...]) -> None: + """ + Launch a 1D CUDA kernel with a fixed thread block size. + + :param kernel: RawKernel to launch. + :param n: Total number of elements. + :param args: Kernel argument tuple. + :return: None. + """ + threads = 256 + blocks = (n + threads - 1) // threads + kernel((blocks,), (threads,), args) + + +def _discontig_sub_cupy( + arr: cp.ndarray, + out: cp.ndarray, + axis: int, + transpose: bool = False, +) -> cp.ndarray: + """ + Flattened-addressing difference/adjoint operator, equivalent to fista.py discontig_sub. + + :param arr: Input array. + :param out: Output array (overwritten for forward, accumulated for adjoint). + :param axis: Axis along which to compute the flattened-order difference. + :param transpose: False for forward, True for adjoint. + :return: Output array. 
+ """ + if arr.dtype != out.dtype: + raise ValueError("arr and out must have the same dtype") + + if not arr.flags.c_contiguous: + arr = cp.ascontiguousarray(arr) + if not out.flags.c_contiguous: + raise ValueError("out must be C-contiguous") + + axis = axis % arr.ndim + shape = arr.shape + + step = 1 + for s in shape[axis + 1 :]: + step *= int(s) + + length = step * int(shape[axis]) + n_total = int(arr.size) + + discontig, _, _, _ = _tv_fast_kernels(str(arr.dtype)) + _launch_1d( + discontig, + n_total, + ( + arr, + out, + cp.uint64(step), + cp.uint64(length), + cp.int32(1 if transpose else 0), + cp.uint64(n_total), + ), + ) + return out + + +def tv_prox_fast( + x: cp.ndarray, + tau: float, + num_iter: int = 10, + eps: float = 0.0, + voxel_size_zyx: Optional[Tuple[float, float, float]] = None, + weight_scale_zyx: Optional[Tuple[float, float, float]] = None, + out: Optional[cp.ndarray] = None, +) -> cp.ndarray: + """ + Fast GPU-only TV proximal operator for 3D arrays using a Chambolle dual-projection loop. + + This implements a proximal map: + + .. math:: + \\text{prox}_{\\tau \\, TV}(x) = \\arg\\min_y \\; 0.5\\|y - x\\|_2^2 + \\tau \\, TV(y) + + matching skimage/cucim's Chambolle formulation (no positivity projection). + + :param x: CuPy ndarray, 3D. + :param tau: TV proximal weight (must be >= 0). tau==0 returns identity. + :param num_iter: Fixed number of FGP iterations. + :param eps: Relative stopping tolerance, matching cucim TV when > 0. + :param voxel_size_zyx: Physical voxel size (dz, dy, dx). Defaults to (1, 1, 1). + :param weight_scale_zyx: Multiplicative scaling for (dz, dy, dx) weights. + :param out: Optional CuPy array to write the result into (must match shape/dtype). + :return: Proximal output (same shape/dtype as x, with float32/float64 enforced). 
+ """ + if not isinstance(x, cp.ndarray): + raise TypeError("tv_prox_fast is GPU-only: x must be a CuPy ndarray.") + if x.ndim != 3: + raise ValueError(f"tv_prox_fast is 3D-only; got x.shape={x.shape}") + if tau < 0: + raise ValueError("tau must be >= 0") + + if tau == 0.0: + if out is None: + return x + out[...] = x + return out + + if not x.flags.c_contiguous: + x = cp.ascontiguousarray(x) + + if x.dtype not in (cp.float32, cp.float64): + x = x.astype(cp.float32, copy=False) + + dtype = x.dtype + if voxel_size_zyx is None: + voxel_size_zyx = (1.0, 1.0, 1.0) + if len(voxel_size_zyx) != 3: + raise ValueError("voxel_size_zyx must be a 3-tuple (dz, dy, dx).") + dz, dy, dx = (float(v) for v in voxel_size_zyx) + if dz <= 0 or dy <= 0 or dx <= 0: + raise ValueError("voxel_size_zyx entries must be > 0.") + if weight_scale_zyx is None: + weight_scale_zyx = (1.0, 1.0, 1.0) + if len(weight_scale_zyx) != 3: + raise ValueError("weight_scale_zyx must be a 3-tuple (sz, sy, sx).") + sz, sy, sx = (float(v) for v in weight_scale_zyx) + if sz <= 0 or sy <= 0 or sx <= 0: + raise ValueError("weight_scale_zyx entries must be > 0.") + w_z = 1.0 / dz + w_y = 1.0 / dy + w_x = 1.0 / dx + w_z *= sz + w_y *= sy + w_x *= sx + step = 1.0 / (2.0 * (w_z * w_z + w_y * w_y + w_x * w_x)) + + p = cp.zeros((3,) + x.shape, dtype=dtype) + norms = cp.empty_like(x) if eps > 0 else None + proj = cp.empty_like(x) if out is None else out + + _, _, div_kernel, grad_kernel = _tv_fast_kernels(str(dtype)) + n_vox = int(x.size) + nz, ny, nx = (int(s) for s in x.shape) + step_val = dtype.type(step) + weight_val = dtype.type(tau) + wz_val = dtype.type(w_z) + wy_val = dtype.type(w_y) + wx_val = dtype.type(w_x) + write_norms = cp.int32(1 if eps > 0 else 0) + eps_val = float(eps) + e_init = None + e_prev = None + + for ii in range(int(num_iter)): + if ii > 0: + _launch_1d( + div_kernel, + n_vox, + ( + p, + x, + proj, + wz_val, + wy_val, + wx_val, + cp.uint64(nz), + cp.uint64(ny), + cp.uint64(nx), + cp.uint64(n_vox), 
+ ), + ) + else: + proj[...] = x + + _launch_1d( + grad_kernel, + n_vox, + ( + proj, + p, + norms if norms is not None else proj, + step_val, + weight_val, + wz_val, + wy_val, + wx_val, + cp.uint64(nz), + cp.uint64(ny), + cp.uint64(nx), + cp.uint64(n_vox), + write_norms, + ), + ) + + if eps_val > 0.0: + d = proj - x + e = cp.sum(d * d) + e += weight_val * cp.sum(norms) + e /= float(n_vox) + e_host = float(e) + if ii == 0: + e_init = e_host + e_prev = e_host + else: + if e_init is not None and e_prev is not None: + if abs(e_prev - e_host) < eps_val * e_init: + break + e_prev = e_host + + return proj diff --git a/mcsim/analysis/wotf_fista.py b/mcsim/analysis/wotf_fista.py new file mode 100644 index 0000000..0be485f --- /dev/null +++ b/mcsim/analysis/wotf_fista.py @@ -0,0 +1,591 @@ +""" +WOTF-based FISTA optimizer implemented without legacy inverse dependencies. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Sequence, Union, Any + +import numpy as np + +try: + import cupy as cp +except Exception: + cp = None + +from mcsim.analysis.fft import ft2, ift2, ft3, ift3, ftn +from mcsim.analysis.field_prop import get_v +from mcsim.analysis.optimize import Optimizer, to_cpu +from mcsim.analysis.tv_prox_fast import tv_prox_fast + + +if cp: + array = Union[np.ndarray, cp.ndarray] +else: + array = np.ndarray + + +@dataclass(frozen=True) +class WOTFParams: + wavelength_um: float + na_obj: float + na_in: float + n0: float + dxy_um: float + dz_um: float + nz: int + ny: int + nx: int + pupil_taper_na: float = 0.0 + + +def camera_to_photons( + I_cam: array, + *, + camera_offset_adu: Union[float, array], + camera_gain_photons_per_adu: Union[float, array], +) -> array: + """ + Convert camera units to photons using a linear model. 
+ """ + xp = cp if cp and isinstance(I_cam, cp.ndarray) else np + offset = xp.asarray(camera_offset_adu, dtype=xp.float32) + gain = xp.asarray(camera_gain_photons_per_adu, dtype=xp.float32) + I_phot = (I_cam - offset) * gain + return xp.maximum(I_phot, xp.float32(0.0)) + + +def _make_led_na_positions( + ny_led: int, + nx_led: int, + *, + na_obj: float, + na_in: float, + include_center: bool, +) -> np.ndarray: + yy, xx = np.meshgrid(np.arange(ny_led), np.arange(nx_led), indexing="ij") + yy = yy.astype(np.float32) + xx = xx.astype(np.float32) + cy = (ny_led - 1) / 2.0 + cx = (nx_led - 1) / 2.0 + x = xx - cx + y = yy - cy + r = np.sqrt(x * x + y * y) + r_circle = float(min(cx, cy)) if (cx > 0 and cy > 0) else float(np.max(r)) + mask = r <= r_circle + if not include_center: + mask = mask & (r > 0) + if not np.any(mask): + raise ValueError("LED grid mask is empty; check LED geometry.") + r_mask_max = float(np.max(r[mask])) + scale = float(na_obj) / r_mask_max + na_x = x[mask] * scale + na_y = y[mask] * scale + na_r = np.sqrt(na_x * na_x + na_y * na_y) + if float(na_in) > 0: + keep = na_r >= float(na_in) + else: + keep = np.ones_like(na_r, dtype=bool) + if not include_center: + keep &= na_r > 0 + na_x = na_x[keep] + na_y = na_y[keep] + na_xy = np.stack([na_x, na_y], axis=1) + return na_xy.astype(np.float32, copy=False) + + +def build_led_na_grid( + ny_led: int, + nx_led: int, + *, + na_obj: float, + na_in: float, + include_center: bool, + led_subsample: int = 1, +) -> np.ndarray: + """ + Build an (N, 2) list of LED NA positions on the LED grid. + + led_subsample is the subsample factor for the LED grid. 
+ """ + if led_subsample < 1: + raise ValueError("led_subsample must be >= 1.") + if led_subsample == 1: + return _make_led_na_positions( + ny_led, + nx_led, + na_obj=na_obj, + na_in=na_in, + include_center=include_center, + ) + cy = (ny_led - 1) / 2.0 + cx = (nx_led - 1) / 2.0 + step_y = int(max(1, np.floor(np.sqrt(float(led_subsample))))) + step_x = int(max(1, np.ceil(float(led_subsample) / float(step_y)))) + + yy_idx, xx_idx = np.meshgrid( + np.arange(ny_led, dtype=int), + np.arange(nx_led, dtype=int), + indexing="ij", + ) + yy = yy_idx.astype(np.float32) + xx = xx_idx.astype(np.float32) + x = xx - cx + y = yy - cy + r = np.sqrt(x * x + y * y) + r_circle = float(min(cx, cy)) if (cx > 0 and cy > 0) else float(np.max(r)) + mask_circle = r <= r_circle + if not include_center: + mask_circle = mask_circle & (r > 0) + if not np.any(mask_circle): + raise ValueError("LED grid mask is empty; check LED geometry.") + r_mask_max = float(np.max(r[mask_circle])) + scale = float(na_obj) / r_mask_max + + best_key = None + best_offsets = None + for off_y in range(step_y): + mask_y = (yy_idx - off_y) % step_y == 0 + for off_x in range(step_x): + mask_sub = mask_y & ((xx_idx - off_x) % step_x == 0) + mask = mask_circle & mask_sub + if not np.any(mask): + continue + na_x = x[mask] * scale + na_y = y[mask] * scale + na_r = np.sqrt(na_x * na_x + na_y * na_y) + if float(na_in) > 0: + keep = na_r >= float(na_in) + else: + keep = np.ones_like(na_r, dtype=bool) + if not include_center: + keep &= na_r > 0 + if not np.any(keep): + continue + na_x = na_x[keep] + na_y = na_y[keep] + left = int(np.sum(na_x < 0)) + right = int(np.sum(na_x > 0)) + up = int(np.sum(na_y > 0)) + down = int(np.sum(na_y < 0)) + if include_center: + center = int(np.sum((na_x == 0) & (na_y == 0))) + left += center + right += center + up += center + down += center + min_count = min(left, right, up, down) + total = int(na_x.size) + key = (min_count, total) + if best_key is None or key > best_key: + best_key = key + 
best_offsets = (off_y, off_x) + + if best_offsets is None: + raise ValueError("Subsampling removed all LEDs from the board.") + + off_y, off_x = best_offsets + mask_sub = ((yy_idx - off_y) % step_y == 0) & ((xx_idx - off_x) % step_x == 0) + mask = mask_circle & mask_sub + na_x = x[mask] * scale + na_y = y[mask] * scale + na_r = np.sqrt(na_x * na_x + na_y * na_y) + if float(na_in) > 0: + keep = na_r >= float(na_in) + else: + keep = np.ones_like(na_r, dtype=bool) + if not include_center: + keep &= na_r > 0 + na_x = na_x[keep] + na_y = na_y[keep] + na_xy = np.stack([na_x, na_y], axis=1) + if na_xy.size == 0: + raise ValueError("Subsampling removed all LEDs after inner NA filter.") + return na_xy.astype(np.float32, copy=False) + + +def build_led_pattern_membership( + led_na_xy: np.ndarray, + *, + order: Sequence[str], + include_center: bool, +) -> np.ndarray: + """ + Assign each LED to one of four patterns: left/right/up/down. + """ + if len(order) != 4: + raise ValueError("order must have 4 entries.") + na = np.asarray(led_na_xy, dtype=float) + membership = np.zeros((na.shape[0], 4), dtype=bool) + for ii, name in enumerate(order): + if name == "left": + membership[:, ii] = na[:, 0] < 0 + elif name == "right": + membership[:, ii] = na[:, 0] > 0 + elif name == "up": + membership[:, ii] = na[:, 1] > 0 + elif name == "down": + membership[:, ii] = na[:, 1] < 0 + else: + raise ValueError(f"Unknown pattern name: {name}") + if include_center: + center = np.all(na == 0, axis=1) + if np.any(center): + membership[center] = True + return membership + + +def _tukey_window(n: int, taper: int, xp: Any) -> Optional[array]: + if taper <= 0 or n <= 0: + return None + taper = min(int(taper), n // 2) + if taper <= 0: + return None + alpha = min(1.0, 2.0 * float(taper) / float(n)) + if xp is np: + try: + from scipy.signal.windows import tukey # type: ignore + + return tukey(n, alpha=alpha).astype(np.float32, copy=False) + except Exception: + pass + if cp and xp is cp: + try: + from 
cupyx.scipy.signal.windows import tukey # type: ignore + + return tukey(n, alpha=alpha).astype(cp.float32, copy=False) + except Exception: + pass + idx = xp.arange(n, dtype=xp.float32) + w = xp.ones(n, dtype=xp.float32) + edge = float(taper) + left = idx < edge + if xp.any(left): + w[left] = 0.5 * (1.0 - xp.cos(np.pi * idx[left] / edge)) + right = idx >= (n - edge) + if xp.any(right): + tail = idx[right] - (n - edge) + w[right] = 0.5 * (1.0 - xp.cos(np.pi * (edge - tail) / edge)) + return w + + + +def _source_flip_unshifted(source: np.ndarray) -> np.ndarray: + source_flip = np.fft.fftshift(source) + source_flip = source_flip[::-1, ::-1] + if source_flip.shape[0] % 2 == 0: + source_flip = np.roll(source_flip, 1, axis=0) + if source_flip.shape[1] % 2 == 0: + source_flip = np.roll(source_flip, 1, axis=1) + return np.fft.ifftshift(source_flip) + + +def _build_illumination_patterns_unshifted( + led_na_xy: np.ndarray, + led_pattern_membership: np.ndarray, + *, + ny: int, + nx: int, + dxy_um: float, + wavelength_um: float, + na_obj: float, +) -> np.ndarray: + na_xy = np.asarray(led_na_xy, dtype=float) + membership = np.asarray(led_pattern_membership, dtype=bool) + fx = np.fft.fftfreq(int(nx), float(dxy_um)) + fy = np.fft.fftfreq(int(ny), float(dxy_um)) + fmax = float(na_obj) / float(wavelength_um) + fx_led = na_xy[:, 0] / float(wavelength_um) + fy_led = na_xy[:, 1] / float(wavelength_um) + fx_grid, fy_grid = np.meshgrid(fx, fy, indexing="xy") + mask = (fx_grid * fx_grid + fy_grid * fy_grid) <= (fmax * fmax + 1.0e-12) + + patterns = np.zeros((membership.shape[1], int(ny), int(nx)), dtype=float) + for led_idx in range(na_xy.shape[0]): + dist2 = (fx_grid - fx_led[led_idx]) ** 2 + (fy_grid - fy_led[led_idx]) ** 2 + dist2 = np.where(mask, dist2, np.inf) + flat = int(np.argmin(dist2)) + iy, ix = np.unravel_index(flat, dist2.shape) + for pid in range(membership.shape[1]): + if membership[led_idx, pid]: + patterns[pid, iy, ix] += 1.0 + return patterns + + +def 
build_wotf_transfer( + led_na_xy: np.ndarray, + led_pattern_membership: np.ndarray, + params: WOTFParams, + *, + real_dtype: type[np.floating] = np.float32, + complex_dtype: type[np.complexfloating] = np.complex64, +) -> tuple[np.ndarray, np.ndarray]: + """ + Build WOTF transfer functions for each illumination pattern. + """ + fx = np.fft.fftfreq(int(params.nx), float(params.dxy_um)).astype(real_dtype, copy=False) + fy = np.fft.fftfreq(int(params.ny), float(params.dxy_um)).astype(real_dtype, copy=False) + fx_grid, fy_grid = np.meshgrid(fx, fy, indexing="xy") + na_obj = float(params.na_obj) + na_in = float(params.na_in) + wavelength_um = float(params.wavelength_um) + na_r = np.sqrt(fx_grid * fx_grid + fy_grid * fy_grid) * wavelength_um + pupil = (na_r <= na_obj).astype(real_dtype) + if na_in != 0.0: + pupil[na_r < na_in] = 0.0 + taper_na = max(float(params.pupil_taper_na), 0.0) + if taper_na > 0.0: + t_outer = np.clip((na_obj - na_r) / taper_na, 0.0, 1.0) + outer_weight = 0.5 - 0.5 * np.cos(np.pi * t_outer) + if na_in > 0.0: + t_inner = np.clip((na_r - na_in) / taper_na, 0.0, 1.0) + inner_weight = 0.5 - 0.5 * np.cos(np.pi * t_inner) + else: + inner_weight = 1.0 + pupil *= outer_weight * inner_weight + + term_defocus = (1.0 / wavelength_um) ** 2 - fx_grid * fx_grid - fy_grid * fy_grid + term_oblique = (float(params.n0) / wavelength_um) ** 2 - fx_grid * fx_grid - fy_grid * fy_grid + phase_defocus = pupil * (2.0 * np.pi) * np.sqrt(np.maximum(term_defocus, 0.0)) + oblique_factor = pupil / (4.0 * np.pi * np.sqrt(np.maximum(term_oblique, 1.0e-12))) + + z_lin = np.fft.ifftshift((np.arange(int(params.nz)) - int(params.nz) // 2) * float(params.dz_um)).astype( + real_dtype, copy=False + ) + prop_kernel = np.exp(1.0j * z_lin[None, None, :] * phase_defocus[:, :, None]).astype( + complex_dtype, copy=False + ) + window_z = np.fft.ifftshift(np.hamming(int(params.nz))).astype(real_dtype, copy=False) + + patterns = _build_illumination_patterns_unshifted( + led_na_xy, + 
led_pattern_membership, + ny=int(params.ny), + nx=int(params.nx), + dxy_um=float(params.dxy_um), + wavelength_um=float(params.wavelength_um), + na_obj=float(params.na_obj), + ).astype(real_dtype, copy=False) + + dfx = 1.0 / (float(params.nx) * float(params.dxy_um)) + dfy = 1.0 / (float(params.ny) * float(params.dxy_um)) + + n_patterns = int(patterns.shape[0]) + H_real = np.zeros((n_patterns, int(params.nz), int(params.ny), int(params.nx)), dtype=complex_dtype) + H_imag = np.zeros_like(H_real) + + for pid in range(n_patterns): + source = patterns[pid] + source_flip = _source_flip_unshifted(source) + fsp = ft2(source_flip[:, :, None] * pupil[:, :, None] * prop_kernel, axes=(0, 1), shift=False) + fpg = ft2(pupil[:, :, None] * prop_kernel * oblique_factor[:, :, None], axes=(0, 1), shift=False).conj() + fsp_cfpg = fsp * fpg + h_real = 2.0 * ift2(1.0j * fsp_cfpg.imag * dfx * dfy, axes=(0, 1), shift=False) + h_imag = 2.0 * ift2(fsp_cfpg.real * dfx * dfy, axes=(0, 1), shift=False) + h_real *= window_z[None, None, :] + h_imag *= window_z[None, None, :] + h_real = ftn(h_real, axes=(2,), shift=False).astype(complex_dtype, copy=False) * float(params.dz_um) + h_imag = ftn(h_imag, axes=(2,), shift=False).astype(complex_dtype, copy=False) * float(params.dz_um) + h_real = np.transpose(h_real, (2, 0, 1)) + h_imag = np.transpose(h_imag, (2, 0, 1)) + total_source = np.sum(source_flip * pupil * pupil.conj()) * dfx * dfy + if total_source == 0: + raise ValueError("Total source power is zero for a pattern.") + H_real[pid] = h_real * (1.0j / total_source) + H_imag[pid] = h_imag * (1.0 / total_source) + + H_real = np.fft.fftshift(H_real, axes=(1, 2, 3)) + H_imag = np.fft.fftshift(H_imag, axes=(1, 2, 3)) + return H_real, H_imag + + +class WOTFFISTAOptimizer(Optimizer): + """ + FISTA optimizer for 3D WOTF pattern intensities (no DPC differencing). 
+ """ + + def __init__( + self, + I_meas: array, + I0_pred: array, + H_real: array, + H_imag: array, + *, + n0: float, + wavelength_um: float, + wotf_sign: float = -1.0, + eps: float = 1.0e-8, + tv_weight: float = 0.0, + tv_max_num_iter: int = 50, + tv_eps: float = 2.0e-4, + tv_voxel_size_zyx: Optional[Sequence[float]] = None, + tv_weight_scale_zyx: Optional[Sequence[float]] = None, + pad_zyx: Optional[Sequence[int]] = None, + z_taper: int = 0, + xy_taper: Optional[int] = None, + use_real_constraint: bool = True, + ) -> None: + super().__init__(n_samples=1, prox_parameters=None) + self.n0 = float(n0) + self.wavelength_um = float(wavelength_um) + self.wotf_sign = float(wotf_sign) + self.eps = float(eps) + self.tv_weight = float(tv_weight) + self.tv_max_num_iter = int(tv_max_num_iter) + self.tv_eps = float(tv_eps) + self.use_real_constraint = bool(use_real_constraint) + if tv_voxel_size_zyx is None: + self.tv_voxel_size_zyx = None + else: + if len(tv_voxel_size_zyx) != 3: + raise ValueError("tv_voxel_size_zyx must be a 3-tuple (dz, dy, dx).") + self.tv_voxel_size_zyx = tuple(float(v) for v in tv_voxel_size_zyx) + if tv_weight_scale_zyx is None: + self.tv_weight_scale_zyx = None + else: + if len(tv_weight_scale_zyx) != 3: + raise ValueError("tv_weight_scale_zyx must be a 3-tuple (sz, sy, sx).") + self.tv_weight_scale_zyx = tuple(float(v) for v in tv_weight_scale_zyx) + self._I_meas_cpu = np.asarray(to_cpu(I_meas)) + self._I0_pred_cpu = np.asarray(to_cpu(I0_pred)) + self._H_real_cpu = np.asarray(to_cpu(H_real)) + self._H_imag_cpu = np.asarray(to_cpu(H_imag)) + self._shape_zyx = self._I_meas_cpu.shape[1:] + self._z_taper = int(z_taper) + self._xy_taper = int(z_taper if xy_taper is None else xy_taper) + if pad_zyx is None: + self.pad_zyx = (0, 0, 0) + else: + if len(pad_zyx) != 3: + raise ValueError("pad_zyx must be a 3-tuple (pz, py, px).") + self.pad_zyx = tuple(int(v) for v in pad_zyx) + if any(v < 0 for v in self.pad_zyx): + raise ValueError("pad_zyx entries must be 
>= 0.") + if any(self.pad_zyx): + nz, ny, nx = self._shape_zyx + pz, py, px = self.pad_zyx + expected = (nz + 2 * pz, ny + 2 * py, nx + 2 * px) + if self._H_real_cpu.shape[1:] != expected: + raise ValueError( + "H_real/H_imag must match padded shape " + f"{expected} when pad_zyx={self.pad_zyx}." + ) + self._xp_cache = None + + def _pad_volume(self, xp: Any, vol: array) -> array: + if not any(self.pad_zyx): + return vol + pz, py, px = self.pad_zyx + return xp.pad(vol, ((pz, pz), (py, py), (px, px)), mode="constant") + + def _crop_volume(self, vol: array) -> array: + if not any(self.pad_zyx): + return vol + pz, py, px = self.pad_zyx + nz, ny, nx = self._shape_zyx + return vol[pz : pz + nz, py : py + ny, px : px + nx] + + def _pad_contrast(self, xp: Any, contrast: array) -> array: + if not any(self.pad_zyx): + return contrast + pz, py, px = self.pad_zyx + return xp.pad(contrast, ((0, 0), (pz, pz), (py, py), (px, px)), mode="constant") + + def _crop_contrast(self, contrast: array) -> array: + if not any(self.pad_zyx): + return contrast + pz, py, px = self.pad_zyx + nz, ny, nx = self._shape_zyx + return contrast[:, pz : pz + nz, py : py + ny, px : px + nx] + + def _ensure_backend(self, x: array) -> tuple: + xp = cp if cp and isinstance(x, cp.ndarray) else np + if self._xp_cache is None or self._xp_cache[0] is not xp: + z_weight = None + xy_weight = None + if self._z_taper > 0: + z_weight = _tukey_window(int(self._shape_zyx[0]), self._z_taper, xp) + xy_weight = _tukey_window(int(self._shape_zyx[1]), self._xy_taper, xp) + self._xp_cache = ( + xp, + xp.asarray(self._I_meas_cpu), + xp.asarray(self._I0_pred_cpu), + xp.asarray(self._H_real_cpu), + xp.asarray(self._H_imag_cpu), + z_weight, + xy_weight, + ) + return self._xp_cache + + def _forward_contrast(self, n_now: array) -> array: + xp, I_meas, I0_pred, H_real, H_imag, _, _ = self._ensure_backend(n_now) + v_now = get_v(n_now, self.n0, self.wavelength_um) + v_now = self._pad_volume(xp, v_now) + v_ft = ft3(v_now, axes=(0, 1, 
2), shift=True) + pred_ft = H_real * v_ft.real + H_imag * v_ft.imag + contrast = ift3(pred_ft, axes=(1, 2, 3), shift=True).real + contrast = self._crop_contrast(contrast) + return contrast + + def _forward_intensity(self, n_now: array) -> array: + xp, I_meas, I0_pred, H_real, H_imag, _, _ = self._ensure_backend(n_now) + contrast = self._forward_contrast(n_now) + I_pred = I0_pred * (1.0 + self.wotf_sign * contrast) + return xp.maximum(I_pred, xp.float32(0.0)) + + def cost(self, x: array, inds: Optional[Sequence[int]] = None) -> array: + xp, I_meas, I0_pred, H_real, H_imag, z_weight, xy_weight = self._ensure_backend(x) + contrast_pred = self.wotf_sign * self._forward_contrast(x) + contrast_meas = I_meas / xp.maximum(I0_pred, xp.float32(self.eps)) - xp.float32(1.0) + residual = contrast_pred - contrast_meas + if z_weight is not None: + residual = residual * z_weight[None, :, None, None] + if xy_weight is not None: + residual = residual * xy_weight[None, None, :, None] + residual = residual * xy_weight[None, None, None, :] + cost = 0.5 * xp.sum(residual ** 2) + return xp.asarray(cost)[None] + + def gradient(self, x: array, inds: Optional[Sequence[int]] = None) -> array: + xp, I_meas, I0_pred, H_real, H_imag, z_weight, xy_weight = self._ensure_backend(x) + contrast_pred = self.wotf_sign * self._forward_contrast(x) + contrast_meas = I_meas / xp.maximum(I0_pred, xp.float32(self.eps)) - xp.float32(1.0) + g_contrast = contrast_pred - contrast_meas + if z_weight is not None: + g_contrast = g_contrast * z_weight[None, :, None, None] + if xy_weight is not None: + g_contrast = g_contrast * xy_weight[None, None, :, None] + g_contrast = g_contrast * xy_weight[None, None, None, :] + g_contrast = self._pad_contrast(xp, g_contrast) + g_contrast_ft = ift3(g_contrast, axes=(1, 2, 3), shift=True, adjoint=True) + grad_vft_real = self.wotf_sign * xp.sum(xp.conj(H_real) * g_contrast_ft, axis=0) + grad_vft_imag = self.wotf_sign * xp.sum(xp.conj(H_imag) * g_contrast_ft, axis=0) + grad_v = 
ft3( + xp.real(grad_vft_real) + 1j * xp.real(grad_vft_imag), + axes=(0, 1, 2), + shift=True, + adjoint=True, + ) + grad_v = self._crop_volume(grad_v) + k2 = (2.0 * np.pi / float(self.wavelength_um)) ** 2 + grad = -2.0 * k2 * x * grad_v.real + return grad[None, ...] + + def prox(self, x: array, step: float) -> array: + xp = cp if cp and isinstance(x, cp.ndarray) else np + n_out = x + if self.use_real_constraint: + n_out = xp.maximum(n_out, float(self.n0)) + if self.tv_weight != 0.0: + if not (cp and isinstance(n_out, cp.ndarray)): + raise ValueError("tv_prox_fast requires GPU/CuPy.") + n_out = tv_prox_fast( + cp.asarray(n_out, dtype=cp.float32), + float(self.tv_weight), + num_iter=int(self.tv_max_num_iter), + eps=float(self.tv_eps), + voxel_size_zyx=self.tv_voxel_size_zyx, + weight_scale_zyx=self.tv_weight_scale_zyx, + ) + return n_out diff --git a/tests/test_dpc_inverse.py b/tests/test_dpc_inverse.py deleted file mode 100644 index cf1c263..0000000 --- a/tests/test_dpc_inverse.py +++ /dev/null @@ -1,192 +0,0 @@ -import numpy as np -import pytest - -try: - import cupy as cp # type: ignore -except ImportError: - cp = None - -from mcsim.analysis.dpc_inverse import DPCGeometry, DPCRytovInverse - - -def _make_simple_solver(ny=16, nx=16, n_planes=2, use_gpu: bool = False, delta: float = 1e-3): - # Four LEDs, one per half-plane - led_na = np.array( - [ - [-0.1, 0.0], # left - [0.1, 0.0], # right - [0.0, 0.1], # up - [0.0, -0.1], # down - ], - dtype=float, - ) - - geom = DPCGeometry( - wavelength_um=0.5, - n_medium=1.0, - na_obj=0.8, - camera_pixel_um=2.0, - magnification=10.0, - focal_offsets_um=[0.0] * n_planes, - pattern_order=("left", "right", "up", "down"), - normalize_by_led_count=True, - ) - - n_shape = (2, ny, nx) - drs_n = (0.5, geom.dxy_um, geom.dxy_um) - - # simple ground truth RI - n0 = np.full(n_shape, geom.n_medium, dtype=np.float32) - n0[0, ny // 2, nx // 2] += float(delta) # small perturbation - if use_gpu and cp is not None: - n0 = cp.asarray(n0) - - # 
simulate data from the forward model - solver = DPCRytovInverse( - np.zeros((n_planes, 4, ny, nx), dtype=np.float32), - led_na, - geom=geom, - n_shape=n_shape, - drs_n=drs_n, - use_gpu=use_gpu, - ) - dpc_pred, _ = solver._predict_fields(n0) - - # re-instantiate with the simulated data - solver = DPCRytovInverse( - dpc_pred, - led_na, - geom=geom, - n_shape=n_shape, - drs_n=drs_n, - use_gpu=use_gpu, - ) - return solver, n0 - - -@pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) -def test_forward_shape_and_values(use_gpu): - solver, n0 = _make_simple_solver(use_gpu=use_gpu) - pred, _ = solver._predict_fields(n0) - assert pred.shape == solver.data.shape == (solver.n_planes, 4, solver.ny, solver.nx) - # Forward evaluated at ground truth should match data - np.testing.assert_allclose( - cp.asnumpy(pred) if cp and use_gpu else pred, - cp.asnumpy(solver.data) if cp and use_gpu else solver.data, - rtol=1e-5, - atol=1e-6, - ) - - -@pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) -def test_gradient_matches_numeric(use_gpu): - solver, n0 = _make_simple_solver(use_gpu=use_gpu) - g, gn = solver.test_gradient(n0, jind=0, dx=1e-6) - g_np = cp.asnumpy(g) if cp and use_gpu else g - gn_np = cp.asnumpy(gn) if cp and use_gpu else gn - np.testing.assert_allclose(g_np, gn_np, rtol=1e-3, atol=1e-5) - - -def _make_multiled_solver(ny=16, nx=16, n_planes=2, use_gpu: bool = False): - # Multiple LEDs per half-plane - led_na = np.array( - [ - [-0.15, 0.0], - [-0.05, 0.05], - [0.15, 0.0], - [0.05, -0.05], - [0.0, 0.15], - [-0.05, 0.1], - [0.0, -0.15], - [0.05, -0.1], - ], - dtype=float, - ) - - geom = DPCGeometry( - wavelength_um=0.5, - n_medium=1.0, - na_obj=0.8, - camera_pixel_um=2.0, - magnification=10.0, - focal_offsets_um=[0.0] * n_planes, - pattern_order=("left", "right", "up", "down"), - normalize_by_led_count=True, - ) - - n_shape = (2, ny, nx) - drs_n = (0.5, geom.dxy_um, geom.dxy_um) - - n0 = np.full(n_shape, geom.n_medium, 
dtype=np.float32) - n0[1, ny // 2, nx // 2] += 5e-4 - if use_gpu and cp is not None: - n0 = cp.asarray(n0) - - solver = DPCRytovInverse( - np.zeros((n_planes, 4, ny, nx), dtype=np.float32), - led_na, - geom=geom, - n_shape=n_shape, - drs_n=drs_n, - use_gpu=use_gpu, - ) - dpc_pred, _ = solver._predict_fields(n0) - - solver = DPCRytovInverse( - dpc_pred, - led_na, - geom=geom, - n_shape=n_shape, - drs_n=drs_n, - use_gpu=use_gpu, - ) - return solver, n0 - - -@pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) -def test_forward_multiled_shape_and_values(use_gpu): - solver, n0 = _make_multiled_solver(use_gpu=use_gpu) - pred, _ = solver._predict_fields(n0) - assert pred.shape == solver.data.shape == (solver.n_planes, 4, solver.ny, solver.nx) - np.testing.assert_allclose( - cp.asnumpy(pred) if cp and use_gpu else pred, - cp.asnumpy(solver.data) if cp and use_gpu else solver.data, - rtol=1e-5, - atol=1e-6, - ) - - -@pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) -def test_gradient_multiled_matches_numeric(use_gpu): - solver, n0 = _make_multiled_solver(use_gpu=use_gpu) - g, gn = solver.test_gradient(n0, jind=0, dx=1e-6) - g_np = cp.asnumpy(g) if cp and use_gpu else g - gn_np = cp.asnumpy(gn) if cp and use_gpu else gn - np.testing.assert_allclose(g_np, gn_np, rtol=1e-3, atol=1e-5) - - -@pytest.mark.parametrize("use_gpu", [False] if cp is None else [False, True]) -def test_reconstruction_improves_mse(use_gpu): - solver, n_true = _make_simple_solver(ny=12, nx=12, n_planes=1, use_gpu=use_gpu, delta=5e-3) - xp = cp if (use_gpu and cp is not None) else np - - n_init = xp.full_like(n_true, solver.geom.n_medium) - - step = solver.guess_step() - res = solver.run( - n_init, - step=step, - max_iterations=80, - use_fista=True, - compute_cost=False, - verbose=False, - compute_all_costs=False, - line_search_iter_limit=None, - label="recon-test ", - ) - n_rec = res["x"] - - mse_init = xp.mean(xp.abs(n_init - n_true) ** 2) - mse_final 
= xp.mean(xp.abs(n_rec - n_true) ** 2) - # Expect a clear reduction - assert float(mse_final) < 0.5 * float(mse_init) From 63a4e33ac34b9e23ea368266b732f6ee048546c3 Mon Sep 17 00:00:00 2001 From: dpshepherd Date: Sat, 14 Feb 2026 07:35:40 -0700 Subject: [PATCH 09/10] add WOTF comparison script --- examples/compare_wotf_predictions.py | 494 +++++++++++++++++++++++++++ 1 file changed, 494 insertions(+) create mode 100644 examples/compare_wotf_predictions.py diff --git a/examples/compare_wotf_predictions.py b/examples/compare_wotf_predictions.py new file mode 100644 index 0000000..0292447 --- /dev/null +++ b/examples/compare_wotf_predictions.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python3 +""" +Compare WOTF-predicted images for FISTA vs Tikhonov reconstructions. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import numpy as np +import zarr # type: ignore + +try: + import cupy as cp # type: ignore +except Exception: + cp = None +try: + import matplotlib.pyplot as plt # type: ignore +except Exception: + plt = None + +from mcsim.analysis.dpc_meta import DPCMeta +from mcsim.analysis.fft import ft3, ift3 +from mcsim.analysis.field_prop import get_n +from mcsim.analysis.optimize import to_cpu +from mcsim.analysis.wotf_fista import ( + WOTFParams, + WOTFFISTAOptimizer, + build_led_na_grid, + build_led_pattern_membership, + build_wotf_transfer, + camera_to_photons, +) +from build_dpc_mie_stack import ensure_dpc_mie_stack + + +def _load_simulated_stack(zarr_path: Path) -> tuple[np.ndarray, dict, np.ndarray]: + g = zarr.open_group(str(zarr_path), mode="r") + if "dpc" not in g: + raise ValueError("Missing 'dpc' dataset in simulation output.") + dpc = np.asarray(g["dpc"], dtype=np.float32) + + meta_group = g.get("meta") + if meta_group is None or "focal_offsets_um" not in meta_group: + raise ValueError("Missing 'meta/focal_offsets_um' in simulation output.") + z_planes_um = np.asarray(meta_group["focal_offsets_um"], dtype=float) + + if 
dpc.ndim == 3: + if dpc.shape[0] != 4: + raise ValueError(f"Expected dpc shape (4, ny, nx), got {dpc.shape}") + dpc = dpc[None, ...] + elif dpc.ndim == 4: + if dpc.shape[1] != 4: + raise ValueError(f"Expected dpc shape (nz, 4, ny, nx), got {dpc.shape}") + else: + raise ValueError(f"Unexpected dpc ndim: {dpc.ndim}") + + I_cam = np.moveaxis(dpc, 1, 0) + return I_cam, dict(g.attrs), z_planes_um + + +def _build_meta(I_cam: np.ndarray, attrs: dict, z_planes_um: np.ndarray) -> DPCMeta: + nz = int(I_cam.shape[1]) + ny = int(I_cam.shape[2]) + nx = int(I_cam.shape[3]) + if z_planes_um.size != nz: + raise ValueError("z_planes_um length must match stack depth.") + + wavelength_um = float(attrs["wavelength_um"]) + na_obj = float(attrs["na_obj"]) + camera_pixel_um = float(attrs["camera_pixel_um"]) + magnification = float(attrs["magnification"]) + n_medium = float(attrs.get("n_medium", 1.0)) + led_grid_shape = tuple(int(v) for v in attrs.get("led_grid_shape", (64, 64))) + pattern_order = tuple(attrs.get("pattern_order", ("left", "right", "up", "down"))) + inner_na = float(attrs.get("inner_na", 0.0)) + include_center_led = bool(attrs.get("include_center_led", False)) + + dz = 0.0 + if z_planes_um.size > 1: + dzs = np.diff(z_planes_um) + if not np.allclose(dzs, dzs[0]): + raise ValueError("z_planes_um must be evenly spaced.") + dz = float(dzs[0]) + + dxy_um = camera_pixel_um / magnification + + return DPCMeta( + wavelength_um=wavelength_um, + n_background=n_medium, + NA_obj=na_obj, + magnification=magnification, + camera_pixel_pitch_um=camera_pixel_um, + volume_shape_zyx=(nz, ny, nx), + voxel_size_um_zyx=(dz, dxy_um, dxy_um), + z_planes_um=z_planes_um, + led_grid_shape=led_grid_shape, + inner_na=inner_na, + include_center_led=include_center_led, + pattern_order=pattern_order, # type: ignore[arg-type] + ) + + +def _report_z_sampling(meta: DPCMeta) -> tuple[float, float]: + dzs = np.diff(np.asarray(meta.z_planes_um, dtype=float)) + dz_mean = float(np.mean(dzs)) if dzs.size else 0.0 
+ dz_std = float(np.std(dzs)) if dzs.size else 0.0 + dz_meta = float(meta.voxel_size_um_zyx[0]) + print( + "Z sampling:", + f"dz_meta={dz_meta:.6g} um", + f"dz_mean={dz_mean:.6g} um", + f"dz_std={dz_std:.3g} um", + f"nz={meta.volume_shape_zyx[0]}", + ) + if dzs.size and not np.isclose(dz_mean, dz_meta, rtol=1.0e-6, atol=1.0e-9): + print("Warning: dz in meta does not match z_planes spacing.") + return dz_mean, dz_meta + + +def _axial_transfer_summary(H: np.ndarray) -> np.ndarray: + return np.mean(np.abs(H), axis=(2, 3)) + + +def _save_wotf_slice_plots( + out_dir: Path, + H_real: np.ndarray, + H_imag: np.ndarray, + *, + suffix: str, +) -> None: + if plt is None: + print("matplotlib not available; skipping WOTF plot images.") + return + + n_patterns, nz, ny, nx = H_real.shape + z_mid = nz // 2 + y_mid = ny // 2 + + def _grid_plot(data: np.ndarray, title: str, out_name: str) -> None: + log_data = np.log10(data + 1.0e-12) + fig, axes = plt.subplots(2, 2, figsize=(8, 8), constrained_layout=True) + axes = axes.ravel() + for pid in range(min(4, n_patterns)): + axes[pid].imshow(log_data[pid], origin="lower", aspect="auto") + axes[pid].set_title(f"pattern {pid}") + for ax in axes[n_patterns:]: + ax.axis("off") + fig.suptitle(f"{title} (log10)") + fig.savefig(out_dir / out_name, dpi=150) + plt.close(fig) + + h_real_kxy = np.abs(H_real[:, z_mid]) + h_imag_kxy = np.abs(H_imag[:, z_mid]) + h_real_kzx = np.abs(H_real[:, :, y_mid, :]) + h_imag_kzx = np.abs(H_imag[:, :, y_mid, :]) + + _grid_plot(h_real_kxy, "H_real |kz mid| (kx-ky)", f"wotf_real_kxy_{suffix}.png") + _grid_plot(h_imag_kxy, "H_imag |kz mid| (kx-ky)", f"wotf_imag_kxy_{suffix}.png") + _grid_plot(h_real_kzx, "H_real |ky mid| (kz-kx)", f"wotf_real_kzx_{suffix}.png") + _grid_plot(h_imag_kzx, "H_imag |ky mid| (kz-kx)", f"wotf_imag_kzx_{suffix}.png") + + +def _tikhonov_wotf( + contrast: np.ndarray, + H_real: np.ndarray, + *, + reg: float, +) -> np.ndarray: + contrast_ft = ft3(contrast, axes=(1, 2, 3), shift=True) + num = 
# NOTE(review): this chunk is a collapsed git-format-patch hunk; the Python below has
# been re-indented conventionally. The statement opening the first expression
# (presumably ``num = ...`` inside ``_tikhonov_wotf``) lies above this chunk and is
# not visible here — only its continuation and the rest of the function body are.
        np.sum(np.conj(H_real) * contrast_ft, axis=0)
    # Wiener/Tikhonov denominator: sum of |H|^2 over patterns plus the regularizer.
    den = np.sum(np.abs(H_real) ** 2, axis=0) + float(reg)
    v_ft_real = num / den
    # Back to real space; keep only the real part of the scattering potential.
    v_real = ift3(v_ft_real, axes=(0, 1, 2), shift=True).real
    return v_real.astype(np.float32, copy=False)


def main() -> None:
    """Run the WOTF DPC reconstruction demo and save a FISTA-vs-Tikhonov comparison.

    Pipeline (as visible in this chunk):

    1. Ensure/load the simulated Mie DPC stack from a zarr store.
    2. Convert camera ADU to photons and form per-pattern mean intensities.
    3. Build the LED grid / pattern membership and the 3D WOTF transfer
       functions ``H_real`` / ``H_imag`` on a zero-padded grid.
    4. Tikhonov-invert the intensity contrast for an initial refractive-index
       volume ``n_tikh``.
    5. Optionally (``--adjoint-check``) run a forward/adjoint inner-product
       test and one-step cost diagnostics.
    6. Run FISTA and compare predicted intensities against the measurements.
    7. Write volumes, predictions, per-pattern RMSEs, and transfer-function
       summaries to a zarr-v3 group at ``--out``.
    """
    # ---------------- CLI ----------------
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--sim-zarr", type=Path, default=Path("build/dpc_mie_stack.zarr"))
    parser.add_argument("--out", type=Path, default=Path("build/wotf_pred_compare.zarr"))
    parser.add_argument("--use-gpu", action="store_true")
    parser.add_argument("--step", type=float, default=5.0e-5)
    parser.add_argument("--iters", type=int, default=50)
    parser.add_argument("--tv-weight", type=float, default=0.0)
    parser.add_argument("--tv-max-num-iter", type=int, default=50)
    parser.add_argument("--tv-eps", type=float, default=2.0e-4)
    parser.add_argument("--tv-aniso-z", type=float, default=1.0)
    parser.add_argument("--z-taper", type=int, default=8)
    parser.add_argument("--xy-taper", type=int, default=None)
    parser.add_argument("--pupil-taper-na", type=float, default=0.0)
    parser.add_argument("--eps", type=float, default=1.0e-8)
    parser.add_argument("--line-search", action="store_true")
    parser.add_argument("--line-search-factor", type=float, default=0.5)
    parser.add_argument("--restart-line-search", action="store_true")
    parser.add_argument("--tikhonov-reg", type=float, default=5.0e-4)
    parser.add_argument("--led-subsample", type=int, default=1)
    parser.add_argument("--pad-z", type=int, default=64)
    parser.add_argument("--pad-yx", type=int, default=64)
    parser.add_argument("--adjoint-check", action="store_true")
    parser.add_argument("--adjoint-crop-z", type=int, default=0)
    parser.add_argument("--adjoint-crop-yx", type=int, default=0)
    args = parser.parse_args()

    # GPU only if requested AND CuPy imported successfully.
    use_gpu = bool(args.use_gpu and cp is not None)
    print(f"Using GPU: {use_gpu}")

    # ---------------- load / simulate measurement stack ----------------
    ensure_dpc_mie_stack(args.sim_zarr, use_gpu=use_gpu)
    I_cam, attrs, z_planes_um = _load_simulated_stack(args.sim_zarr)
    meta = _build_meta(I_cam, attrs, z_planes_um)
    _report_z_sampling(meta)

    # Camera model: photons = (ADU - offset) * gain. Defaults assume unit gain,
    # zero offset when the store carries no calibration attrs.
    camera_gain = attrs.get("camera_gains", 1.0)
    camera_offset = attrs.get("camera_offsets", 0.0)
    I_cam_xp = cp.asarray(I_cam) if use_gpu else np.asarray(I_cam)
    I_meas_phot = camera_to_photons(
        I_cam_xp,
        camera_offset_adu=camera_offset,
        camera_gain_photons_per_adu=camera_gain,
    )

    I_meas = np.asarray(to_cpu(I_meas_phot), dtype=np.float32)
    # Per-pattern, per-z-plane mean intensity (kept broadcastable via keepdims);
    # used as the I0 normalization below.
    I0_pred = np.mean(I_meas, axis=(2, 3), keepdims=True).astype(np.float32, copy=False)

    # ---------------- LED geometry and WOTF transfer functions ----------------
    led_na_xy = build_led_na_grid(
        meta.led_grid_shape[0],
        meta.led_grid_shape[1],
        na_obj=meta.NA_obj,
        na_in=meta.inner_na,
        include_center=meta.include_center_led,
        led_subsample=int(args.led_subsample),
    )
    membership = build_led_pattern_membership(
        led_na_xy,
        order=meta.pattern_order,
        include_center=meta.include_center_led,
    )

    # Transfer functions are built on the padded grid (volume + 2*pad per axis).
    params = WOTFParams(
        wavelength_um=meta.wavelength_um,
        na_obj=meta.NA_obj,
        na_in=meta.inner_na,
        n0=meta.n_background,
        dxy_um=meta.dxy_um,
        dz_um=float(meta.voxel_size_um_zyx[0]),
        nz=meta.volume_shape_zyx[0] + 2 * int(args.pad_z),
        ny=meta.volume_shape_zyx[1] + 2 * int(args.pad_yx),
        nx=meta.volume_shape_zyx[2] + 2 * int(args.pad_yx),
        pupil_taper_na=float(args.pupil_taper_na),
    )
    H_real, H_imag = build_wotf_transfer(led_na_xy, membership, params)
    expected_shape = (
        meta.volume_shape_zyx[0] + 2 * int(args.pad_z),
        meta.volume_shape_zyx[1] + 2 * int(args.pad_yx),
        meta.volume_shape_zyx[2] + 2 * int(args.pad_yx),
    )
    if H_real.shape[1:] != expected_shape:
        print(f"Warning: H_real shape {H_real.shape} != expected {expected_shape}")

    xp = cp if use_gpu else np
    # NOTE(review): n0_vol (flat background volume) appears unused in this chunk —
    # confirm against the rest of the file before removing.
    n0_vol = xp.full(meta.volume_shape_zyx, float(meta.n_background), dtype=xp.float32)

    # ---------------- Tikhonov initialization ----------------
    # Contrast = (I/I0 - 1) / -1 == 1 - I/I0; the explicit /-1 presumably encodes
    # the sign convention expected by the WOTF forward model — confirm.
    contrast = (I_meas / I0_pred - 1.0) / -1.0
    pad_z = int(args.pad_z)
    pad_yx = int(args.pad_yx)
    if pad_z or pad_yx:
        contrast = np.pad(
            contrast,
            ((0, 0), (pad_z, pad_z), (pad_yx, pad_yx), (pad_yx, pad_yx)),
            mode="constant",
        )
    v_real = _tikhonov_wotf(contrast, H_real, reg=float(args.tikhonov_reg))
    if pad_z or pad_yx:
        # Crop the padded reconstruction back to the original volume.
        nz, ny, nx = meta.volume_shape_zyx
        v_real = v_real[pad_z : pad_z + nz, pad_yx : pad_yx + ny, pad_yx : pad_yx + nx]
    # Scattering potential -> refractive index (real part only).
    n_tikh = get_n(v_real, meta.n_background, meta.wavelength_um).real.astype(np.float32, copy=False)

    # ---------------- FISTA optimizer (with TV prox as configured) ----------------
    optimizer = WOTFFISTAOptimizer(
        xp.asarray(I_meas),
        xp.asarray(I0_pred),
        xp.asarray(H_real),
        xp.asarray(H_imag),
        n0=meta.n_background,
        wavelength_um=meta.wavelength_um,
        eps=float(args.eps),
        tv_weight=float(args.tv_weight),
        tv_max_num_iter=int(args.tv_max_num_iter),
        tv_eps=float(args.tv_eps),
        tv_voxel_size_zyx=(params.dz_um, params.dxy_um, params.dxy_um),
        tv_weight_scale_zyx=(float(args.tv_aniso_z), 1.0, 1.0),
        pad_zyx=(int(args.pad_z), int(args.pad_yx), int(args.pad_yx)),
        z_taper=int(args.z_taper),
        xy_taper=args.xy_taper,
        use_real_constraint=True,
    )

    # ---------------- optional adjoint (dot-product) test ----------------
    # Verifies <A v, g> == <v, A* g> for the linear forward model, on random
    # test vectors optionally restricted to a centered crop.
    if args.adjoint_check:
        nz, ny, nx = meta.volume_shape_zyx
        crop_z = int(args.adjoint_crop_z)
        crop_yx = int(args.adjoint_crop_yx)
        if crop_z > 0:
            crop_z = min(crop_z, nz)
        if crop_yx > 0:
            crop_yx = min(crop_yx, ny, nx)
        # Centered crop window; zero crop means "full axis".
        z0 = (nz - crop_z) // 2 if crop_z > 0 else 0
        y0 = (ny - crop_yx) // 2 if crop_yx > 0 else 0
        x0 = (nx - crop_yx) // 2 if crop_yx > 0 else 0
        z1 = z0 + (crop_z if crop_z > 0 else nz)
        y1 = y0 + (crop_yx if crop_yx > 0 else ny)
        x1 = x0 + (crop_yx if crop_yx > 0 else nx)
        # NOTE(review): this local `g` (dual test vector) is later shadowed by the
        # zarr group variable of the same name — consider renaming one of them.
        v = xp.zeros((nz, ny, nx), dtype=xp.float32)
        g = xp.zeros((H_real.shape[0], nz, ny, nx), dtype=xp.float32)
        rng = np.random.default_rng(0)  # fixed seed -> reproducible check
        v_cpu = rng.standard_normal((z1 - z0, y1 - y0, x1 - x0)).astype(np.float32)
        g_cpu = rng.standard_normal((H_real.shape[0], z1 - z0, y1 - y0, x1 - x0)).astype(np.float32)
        v[z0:z1, y0:y1, x0:x1] = xp.asarray(v_cpu)
        g[:, z0:z1, y0:y1, x0:x1] = xp.asarray(g_cpu)
        pad_z = int(args.pad_z)
        pad_yx = int(args.pad_yx)
        if pad_z or pad_yx:
            v_pad = xp.pad(v, ((pad_z, pad_z), (pad_yx, pad_yx), (pad_yx, pad_yx)), mode="constant")
            g_pad = xp.pad(g, ((0, 0), (pad_z, pad_z), (pad_yx, pad_yx), (pad_yx, pad_yx)), mode="constant")
        else:
            v_pad = v
            g_pad = g
        # Forward: contrast = Re IFFT( H_re * Re(V) + H_im * Im(V) ), then crop.
        v_ft = ft3(v_pad, axes=(0, 1, 2), shift=True)
        pred_ft = xp.asarray(H_real) * v_ft.real + xp.asarray(H_imag) * v_ft.imag
        contrast = ift3(pred_ft, axes=(1, 2, 3), shift=True).real
        if pad_z or pad_yx:
            contrast = contrast[:, pad_z : pad_z + nz, pad_yx : pad_yx + ny, pad_yx : pad_yx + nx]
        lhs = float(to_cpu(xp.sum(contrast * g)))
        # Adjoint: conjugate transfer functions, summed over patterns.
        g_ft = ift3(g_pad, axes=(1, 2, 3), shift=True, adjoint=True)
        grad_vft_real = xp.sum(xp.conj(xp.asarray(H_real)) * g_ft, axis=0)
        grad_vft_imag = xp.sum(xp.conj(xp.asarray(H_imag)) * g_ft, axis=0)
        grad_v = ft3(
            xp.real(grad_vft_real) + 1j * xp.real(grad_vft_imag),
            axes=(0, 1, 2),
            shift=True,
            adjoint=True,
        )
        if pad_z or pad_yx:
            grad_v = grad_v[pad_z : pad_z + nz, pad_yx : pad_yx + ny, pad_yx : pad_yx + nx]
        rhs = float(to_cpu(xp.sum(v * grad_v.real)))
        denom = max(1.0e-12, abs(lhs), abs(rhs))
        # NOTE(review): rel_err is computed but never printed or stored in this
        # chunk — presumably a report/assert was intended; confirm.
        rel_err = abs(lhs - rhs) / denom
    # NOTE(review): indentation below reconstructed from a collapsed diff; the
    # diagnostics are assumed to sit at function level (not inside the
    # adjoint-check branch) because the zarr attrs written later use them
    # unconditionally — confirm against the original patch.
    n_init = (cp.asarray if use_gpu else np.asarray)(
        np.full(meta.volume_shape_zyx, meta.n_background, dtype=np.float32)
    )
    # Regularizer-free clone of the optimizer, used only for cost/gradient
    # diagnostics at the flat-background initializer.
    diag_optimizer = WOTFFISTAOptimizer(
        xp.asarray(I_meas),
        xp.asarray(I0_pred),
        xp.asarray(H_real),
        xp.asarray(H_imag),
        n0=meta.n_background,
        wavelength_um=meta.wavelength_um,
        eps=float(args.eps),
        tv_weight=0.0,
        tv_max_num_iter=0,
        tv_eps=float(args.tv_eps),
        tv_voxel_size_zyx=(params.dz_um, params.dxy_um, params.dxy_um),
        tv_weight_scale_zyx=(1.0, 1.0, 1.0),
        pad_zyx=(int(args.pad_z), int(args.pad_yx), int(args.pad_yx)),
        z_taper=int(args.z_taper),
        xy_taper=args.xy_taper,
        use_real_constraint=True,
    )
    cost_tikh = float(to_cpu(diag_optimizer.cost(n_init))[0])
    grad_tikh = diag_optimizer.gradient(n_init)[0]
    grad_norm = float(to_cpu(xp.sqrt(xp.sum(grad_tikh**2))))
    step_size = float(args.step)
    # One explicit gradient step (raw, and after the prox) to sanity-check the
    # step size against the cost landscape.
    n_after_raw = n_init - step_size * grad_tikh
    n_after = diag_optimizer.prox(n_after_raw, step_size)
    cost_tikh_step = float(to_cpu(diag_optimizer.cost(n_after))[0])
    # NOTE(review): cost_tikh_step_raw, n_min, n0, n_below are computed but not
    # used or stored in this chunk — confirm intent (debug leftovers?).
    cost_tikh_step_raw = float(to_cpu(diag_optimizer.cost(n_after_raw))[0])
    dir_deriv = float(to_cpu(xp.sum(grad_tikh * grad_tikh)))
    n_min = float(to_cpu(xp.min(n_init)))
    n0 = float(meta.n_background)
    n_below = float(to_cpu(xp.mean(n_init < n0)))
    # Symmetric finite-difference cost probes along the gradient direction.
    eps_candidates = [1.0e-5, 1.0e-6, 1.0e-7, 1.0e-8]
    step_costs: list[tuple[float, float, float]] = []
    for eps_step in eps_candidates:
        n_pos = n_init - eps_step * grad_tikh
        n_neg = n_init + eps_step * grad_tikh
        cost_pos = float(to_cpu(diag_optimizer.cost(n_pos))[0])
        cost_neg = float(to_cpu(diag_optimizer.cost(n_neg))[0])
        step_costs.append((eps_step, cost_pos, cost_neg))
    # ---------------- FISTA reconstruction ----------------
    # line_search_iter_limit=0 disables line search; None means unlimited.
    line_search_iter_limit = None if args.line_search else 0
    # NOTE(review): the run starts from the flat-background n_init even though
    # the "init_from_tikhonov" attr below is set True — confirm which is intended.
    result = optimizer.run(
        n_init,
        step=float(args.step),
        max_iterations=int(args.iters),
        use_fista=True,
        n_batch=1,
        compute_batch_grad_parallel=True,
        verbose=True,
        compute_cost=True,
        line_search_iter_limit=line_search_iter_limit,
        line_search_factor=float(args.line_search_factor),
        restart_line_search=bool(args.restart_line_search),
    )
    n_fista = np.asarray(to_cpu(result["x"]), dtype=np.float32)

    # ---------------- forward predictions and per-pattern RMSE ----------------
    I_pred_fista = np.asarray(
        to_cpu(
            optimizer._forward_intensity(  # type: ignore[attr-defined]
                cp.asarray(n_fista) if use_gpu else n_fista
            )
        ),
        dtype=np.float32,
    )
    I_pred_tikh = np.asarray(
        to_cpu(
            optimizer._forward_intensity(  # type: ignore[attr-defined]
                cp.asarray(n_tikh) if use_gpu else n_tikh
            )
        ),
        dtype=np.float32,
    )
    rmse_fista = np.sqrt(np.mean((I_pred_fista - I_meas) ** 2, axis=(1, 2, 3)))
    rmse_tikh = np.sqrt(np.mean((I_pred_tikh - I_meas) ** 2, axis=(1, 2, 3)))
    print("RMSE per pattern (FISTA):", rmse_fista)
    print("RMSE per pattern (Tikhonov):", rmse_tikh)

    # ---------------- write outputs to a zarr-v3 group ----------------
    out_path = args.out
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): both plot calls receive the same H_real/H_imag and differ
    # only in suffix — producing identical plots twice; confirm intent.
    _save_wotf_slice_plots(out_path.parent, H_real, H_imag, suffix=f"{out_path.stem}_fista")
    _save_wotf_slice_plots(out_path.parent, H_real, H_imag, suffix=f"{out_path.stem}_tikh")
    g = zarr.open_group(str(out_path), mode="w", zarr_format=3)
    # Run parameters and diagnostics as group attributes.
    g.attrs.update(meta.as_dict())
    g.attrs["use_gpu"] = bool(use_gpu)
    g.attrs["iters"] = int(args.iters)
    g.attrs["step"] = float(args.step)
    g.attrs["tv_weight"] = float(args.tv_weight)
    g.attrs["z_taper"] = int(args.z_taper)
    g.attrs["xy_taper"] = None if args.xy_taper is None else int(args.xy_taper)
    g.attrs["pupil_taper_na"] = float(args.pupil_taper_na)
    g.attrs["tikhonov_reg"] = float(args.tikhonov_reg)
    g.attrs["init_from_tikhonov"] = True
    g.attrs["diag_cost_tikh"] = float(cost_tikh)
    g.attrs["diag_cost_tikh_step"] = float(cost_tikh_step)
    g.attrs["diag_step_size"] = float(step_size)
    g.attrs["diag_grad_norm"] = float(grad_norm)
    g.attrs["diag_dir_deriv"] = float(dir_deriv)
    g.attrs["diag_step_costs"] = [(float(eps), float(cpos), float(cneg)) for eps, cpos, cneg in step_costs]
    # Measured / predicted intensities and reconstructed volumes.
    g.create_array("I_meas", shape=I_meas.shape, dtype="float32")[...] = I_meas
    g.create_array("I0_pred", shape=I0_pred.shape, dtype="float32")[...] = I0_pred
    g.create_array("n_fista", shape=n_fista.shape, dtype="float32")[...] = n_fista
    g.create_array("n_tikh", shape=n_tikh.shape, dtype="float32")[...] = n_tikh
    g.create_array("rmse_fista", shape=rmse_fista.shape, dtype="float32")[...] = rmse_fista.astype(
        np.float32, copy=False
    )
    g.create_array("rmse_tikh", shape=rmse_tikh.shape, dtype="float32")[...] = rmse_tikh.astype(
        np.float32, copy=False
    )
    g.create_array("I_pred_fista", shape=I_pred_fista.shape, dtype="float32")[...] = I_pred_fista
    g.create_array("I_pred_tikh", shape=I_pred_tikh.shape, dtype="float32")[...] = I_pred_tikh
    # Transfer-function summaries: axial means plus central kxy / kzx slices.
    h_real_abs = _axial_transfer_summary(H_real)
    h_imag_abs = _axial_transfer_summary(H_imag)
    g.create_array("H_real_abs_mean", shape=h_real_abs.shape, dtype="float32")[...] = h_real_abs.astype(
        np.float32, copy=False
    )
    g.create_array("H_imag_abs_mean", shape=h_imag_abs.shape, dtype="float32")[...] = h_imag_abs.astype(
        np.float32, copy=False
    )
    z_mid = int(meta.volume_shape_zyx[0]) // 2
    y_mid = int(meta.volume_shape_zyx[1]) // 2
    h_real_kxy = np.abs(H_real[:, z_mid])
    h_imag_kxy = np.abs(H_imag[:, z_mid])
    h_real_kzx = np.abs(H_real[:, :, y_mid, :])
    h_imag_kzx = np.abs(H_imag[:, :, y_mid, :])
    g.create_array("H_real_kxy_abs", shape=h_real_kxy.shape, dtype="float32")[...] = h_real_kxy.astype(
        np.float32, copy=False
    )
    g.create_array("H_imag_kxy_abs", shape=h_imag_kxy.shape, dtype="float32")[...] = h_imag_kxy.astype(
        np.float32, copy=False
    )
    g.create_array("H_real_kzx_abs", shape=h_real_kzx.shape, dtype="float32")[...] = h_real_kzx.astype(
        np.float32, copy=False
    )
    g.create_array("H_imag_kzx_abs", shape=h_imag_kzx.shape, dtype="float32")[...] = h_imag_kzx.astype(
        np.float32, copy=False
    )
    # Shifted axial frequency axis (cycles/um) for plotting the summaries.
    kz = np.fft.fftshift(np.fft.fftfreq(int(meta.volume_shape_zyx[0]), float(meta.voxel_size_um_zyx[0])))
    g.create_array("kz_cycles_per_um", shape=kz.shape, dtype="float32")[...] = kz.astype(np.float32, copy=False)
    g.create_array("led_na_xy", shape=led_na_xy.shape, dtype="float32")[...] = led_na_xy.astype(np.float32, copy=False)
    g.create_array("pattern_membership", shape=membership.shape, dtype="uint8")[...] = membership.astype(
        np.uint8, copy=False
    )
    print(f"Wrote outputs to {out_path}")


if __name__ == "__main__":
    main()

From 975cc274cf6507a7af3b3fc1b2b7005213a54c7b Mon Sep 17 00:00:00 2001
From: dpshepherd
Date: Sat, 14 Feb 2026 14:18:50 -0700
Subject: [PATCH 10/10] update gitignore

---
 .gitignore | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 39a7793..ec27911 100644
--- a/.gitignore
+++ b/.gitignore
@@ -309,4 +309,6 @@ expt_ctrl/temp*.cfg
 *.hdf5
 
 # documentation
-docs/_*
\ No newline at end of file
+docs/_*
+
+.codex/**
\ No newline at end of file