diff --git a/docs/api/core.md b/docs/api/core.md index e830563..aa51869 100644 --- a/docs/api/core.md +++ b/docs/api/core.md @@ -30,6 +30,14 @@ Low-level optical flow computation engine implementing variational optical flow .. autofunction:: pyflowreg.core.optical_flow.get_motion_tensor_gc ``` +```{eval-rst} +.. autofunction:: pyflowreg.core.optical_flow.get_motion_tensor_gray +``` + +```{eval-rst} +.. autofunction:: pyflowreg.core.optical_flow.get_motion_tensor_cs +``` + ### Boundary Handling ```{eval-rst} diff --git a/docs/api/session.md b/docs/api/session.md index 3dc543f..59e08cd 100644 --- a/docs/api/session.md +++ b/docs/api/session.md @@ -37,6 +37,7 @@ from pyflowreg.session.config import SessionConfig - sigma_smooth - alpha_between - iterations_between + - stage2_constancy_assumption **Configuration File Support** @@ -78,6 +79,7 @@ cc_upsample = 4 sigma_smooth = 6.0 alpha_between = 25.0 iterations_between = 100 +stage2_constancy_assumption = "gc" # Options: "gc", "gray", "cs" ``` ## Stage 1: Per-Recording Compensation @@ -321,13 +323,15 @@ config = SessionConfig( cc_upsample=4, sigma_smooth=6.0, alpha_between=25.0, - iterations_between=100 + iterations_between=100, + stage2_constancy_assumption="gc", ) # Stage 1: Motion correct each recording print("Running Stage 1...") config.flow_options = { "quality_setting": "balanced", + "constancy_assumption": "gc", "save_valid_idx": True, "save_w": False, } diff --git a/docs/theory/parameters.md b/docs/theory/parameters.md index 8e7eed2..c44bdca 100644 --- a/docs/theory/parameters.md +++ b/docs/theory/parameters.md @@ -38,6 +38,11 @@ The sublinear value follows best practices from {cite}`sun2010secrets` for handl - **Edge preservation**: Sublinear diffusion allows the model to handle brightness discontinuities at cell boundaries more gracefully - **Empirical validation**: This value has been validated across diverse 2-photon imaging datasets {cite}`flotho2022flow` +Optional GNC staging can be enabled with `gnc_schedule`, for example +`(0.0, 0.5, 1.0)`. This reruns the pyramid from quadratic to robust stages +while keeping the final `a_data` and `a_smooth` values unchanged. Leaving +`gnc_schedule=None` preserves the legacy solver path. + ## Spatial-Temporal Filtering: `sigma` The `sigma` parameter controls Gaussian filtering applied to the video data before optical flow computation. It is specified as `[σx, σy, σt]` for each channel, where: diff --git a/docs/user_guide/configuration.md b/docs/user_guide/configuration.md index 7e61ad4..01e0b52 100644 --- a/docs/user_guide/configuration.md +++ b/docs/user_guide/configuration.md @@ -47,9 +47,24 @@ options = OFOptions( # Nonlinear diffusion parameters a_smooth=1.0, # Smoothness diffusion parameter a_data=0.45, # Data term diffusion parameter + + # Optional solver-level GNC stages for sublinear penalties + gnc_schedule=(0.0, 0.5, 1.0), + + # Data term, default preserves MATLAB Flow-Registration behavior + constancy_assumption="gc", # Options: "gc", "gray", "cs" ) ``` +`constancy_assumption="gc"` is the default gradient constancy data term used by +the MATLAB Flow-Registration reference. `"gray"` selects gray-value constancy, +and `"cs"` selects census constancy. These data terms are implemented by the +native `flowreg` backend; the `diso` backend rejects non-default values. + +Set `gnc_schedule=None` to keep the default solver path. When provided, +PyFlowReg reruns the pyramid once per stage, warm-starting each stage from +the previous result. 
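+
+A minimal sketch of the two new options (extending the `OFOptions` snippet
+above; the remaining fields are assumed to keep their defaults, and the values
+here are illustrative, not recommendations):
+
+```python
+options = OFOptions(
+    constancy_assumption="census",  # alias, normalized to "cs" by the options validator
+    gnc_schedule=(0.0, 0.5, 1.0),   # optional; None keeps the legacy single-pass solver
+    # warping_steps=3,              # optional; GNC mode defaults to 10 warps per level
+)
+```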
+ ### Alpha (Smoothness Weight) Controls the tradeoff between fitting the data and enforcing smooth flow fields: diff --git a/docs/user_guide/multi_session.md b/docs/user_guide/multi_session.md index 14b1e92..b95c7ba 100644 --- a/docs/user_guide/multi_session.md +++ b/docs/user_guide/multi_session.md @@ -62,6 +62,7 @@ cc_upsample = 4 # Subpixel accuracy sigma_smooth = 6.0 # Gaussian smoothing alpha_between = 25.0 # Regularization iterations_between = 100 +stage2_constancy_assumption = "gc" # Options: "gc", "gray", "cs" ``` ### 3. Run Processing @@ -118,6 +119,9 @@ The `pyflowreg.session` pipeline always runs the same three deterministic stages - Temporal averages are reloaded from disk and the reference recording (center) is selected automatically or from `SessionConfig.center`. - `compute_between_displacement()` smooths both averages, applies phase cross-correlation for a rigid guess, then refines with the configured flow backend (`src/pyflowreg/session/stage2_between_avgs.py`). +- `stage2_constancy_assumption` controls the Stage 2 data term. The default + `"gc"` preserves MATLAB Flow-Registration behavior; `"cs"` enables the + census term for the native `flowreg` backend. - Results are written to `w_to_reference.npz` (separate `u`/`v` arrays) so MATLAB users can load them directly. **Outputs:** `w_to_reference.npz`, per-recording `status.json` updates, and `middle_idx` (0-based) pointing to the reference average. @@ -170,6 +174,7 @@ quality_setting = "fast" # Options: fast, balanced, quality buffer_size = 1000 # Frames per batch save_w = false # Don't save displacement fields save_valid_idx = true # Required for Stage 3 +# constancy_assumption = "cs" # Optional Stage 1 data term override ``` Alternatively, point to a saved MATLAB/Python options file: diff --git a/examples/session_config.toml b/examples/session_config.toml index ffefce4..a8c7170 100644 --- a/examples/session_config.toml +++ b/examples/session_config.toml @@ -42,6 +42,7 @@ sigma_smooth = 6.0 # Sigma for Gaussian filter # Optical flow refinement alpha_between = 25.0 # Regularization strength (higher = smoother) iterations_between = 100 # Solver iterations (higher = more accurate) +stage2_constancy_assumption = "gc" # Options: "gc", "gray", "cs" # === Stage 1 Flow Options (Optional) === # Provide inline overrides passed to OFOptions @@ -51,6 +52,7 @@ buffer_size = 1000 # Frames per batch save_w = false # Save displacement fields save_valid_idx = true # Required for Stage 3 save_meta_info = true # Save statistics +# constancy_assumption = "gc" # Options: "gc", "gray", "cs" # Alternatively reference a saved JSON file: # flow_options = "./saved_options/session_stage1.json" diff --git a/examples/session_config.yml b/examples/session_config.yml index 311f7c5..6d35c3e 100644 --- a/examples/session_config.yml +++ b/examples/session_config.yml @@ -31,6 +31,7 @@ cc_upsample: 4 # Subpixel accuracy (higher = more precise) sigma_smooth: 6.0 # Sigma for Gaussian filter alpha_between: 25.0 # Regularization strength (higher = smoother) iterations_between: 100 # Solver iterations (higher = more accurate) +stage2_constancy_assumption: gc # Options: gc, gray, cs # Stage 1 Flow Options (Optional) # Provide inline overrides passed to OFOptions @@ -40,6 +41,7 @@ flow_options: save_w: false # Save displacement fields save_valid_idx: true # Required for Stage 3 save_meta_info: true # Save statistics + # constancy_assumption: gc # Options: gc, gray, cs # Alternatively, reference a JSON file saved via OF_options: # flow_options: 
"./saved_options/session_stage1.json" diff --git a/src/pyflowreg/core/__init__.py b/src/pyflowreg/core/__init__.py index 59d1e34..29c21c0 100644 --- a/src/pyflowreg/core/__init__.py +++ b/src/pyflowreg/core/__init__.py @@ -145,6 +145,7 @@ def torch_level_solver( a_smooth, hx, hy, + gnc_beta=None, ): # Convert to tensors dtype_map = {"float64": torch.float64, "float32": torch.float32} @@ -171,6 +172,8 @@ def to_tensor(a): a_smooth, hx, hy, + update_lag_semantics="matlab" if gnc_beta is not None else "torch", + gnc_beta=gnc_beta, ) return du.cpu().numpy(), dv.cpu().numpy() @@ -232,6 +235,7 @@ def cuda_level_solver( a_smooth, hx, hy, + gnc_beta=None, ): # CUDA solver handles numpy/cupy conversion internally return level_solver_rbgs_cuda( @@ -251,6 +255,8 @@ def cuda_level_solver( a_smooth, hx, hy, + update_lag_semantics="matlab" if gnc_beta is not None else "torch", + gnc_beta=gnc_beta, ) # Return a partial function with the custom level solver diff --git a/src/pyflowreg/core/level_solver.py b/src/pyflowreg/core/level_solver.py index 1f162f1..29c66c4 100644 --- a/src/pyflowreg/core/level_solver.py +++ b/src/pyflowreg/core/level_solver.py @@ -367,3 +367,189 @@ def compute_flow( flow[:, :, 0] = du flow[:, :, 1] = dv return flow + + +@njit(fastmath=True, cache=True) +def compute_flow_gnc( + J11, + J22, + J33, + J12, + J13, + J23, + weight, + u, + v, + alpha_x, + alpha_y, + iterations, + update_lag, + a_data, + a_smooth, + hx, + hy, + gnc_beta, +): + """ + Iterative solver for one fixed GNC stage at a single pyramid level. + + This variant keeps the baseline solver unchanged and mixes the quadratic + and robust penalties with a fixed stage weight ``gnc_beta`` in ``[0, 1]``. + """ + m, n, n_channels = J11.shape + du = np.zeros((m, n)) + dv = np.zeros((m, n)) + psi = np.ones((m, n, n_channels)) + psi_smooth = np.ones((m, n)) + + OMEGA = 1.95 + alpha = np.array([alpha_x, alpha_y], dtype=np.float64) + mix_quadratic = 1.0 - gnc_beta + + for iteration_counter in range(iterations): + if (iteration_counter + 1) % update_lag == 0: + for k in range(n_channels): + a_k = a_data[k] + use_gnc_data = 0.0 < a_k < 1.0 and gnc_beta < 1.0 + for i in range(n): + for j in range(m): + val = ( + J11[j, i, k] * du[j, i] * du[j, i] + + J22[j, i, k] * dv[j, i] * dv[j, i] + + J23[j, i, k] * dv[j, i] + + 2.0 * J12[j, i, k] * du[j, i] * dv[j, i] + + 2.0 * J13[j, i, k] * du[j, i] + + J23[j, i, k] * dv[j, i] + + J33[j, i, k] + ) + if val < 0.0: + val = 0.0 + robust = a_k * (val + 0.00001) ** (a_k - 1.0) + if use_gnc_data: + psi[j, i, k] = mix_quadratic + gnc_beta * robust + else: + psi[j, i, k] = robust + + if a_smooth != 1.0: + nonlinearity_smoothness_2d( + psi_smooth, u, du, v, dv, m, n, a_smooth, hx, hy + ) + if 0.0 < a_smooth < 1.0 and gnc_beta < 1.0: + for i in range(n): + for j in range(m): + psi_smooth[j, i] = ( + mix_quadratic + gnc_beta * psi_smooth[j, i] + ) + else: + for i in range(n): + for j in range(m): + psi_smooth[j, i] = 1.0 + + set_boundary_2d(du) + set_boundary_2d(dv) + + for i in range(1, n - 1): + for j in range(1, m - 1): + denom_u = 0.0 + denom_v = 0.0 + num_u = 0.0 + num_v = 0.0 + + left = (j, i - 1) + right = (j, i + 1) + down = (j + 1, i) + up = (j - 1, i) + + if a_smooth != 1.0: + tmp = ( + 0.5 + * (psi_smooth[j, i] + psi_smooth[left]) + * (alpha[0] / (hx * hx)) + ) + num_u += tmp * (u[left] + du[left] - u[j, i]) + num_v += tmp * (v[left] + dv[left] - v[j, i]) + denom_u += tmp + denom_v += tmp + + tmp = ( + 0.5 + * (psi_smooth[j, i] + psi_smooth[right]) + * (alpha[0] / (hx * hx)) + ) + num_u += tmp * 
(u[right] + du[right] - u[j, i]) + num_v += tmp * (v[right] + dv[right] - v[j, i]) + denom_u += tmp + denom_v += tmp + + tmp = ( + 0.5 + * (psi_smooth[j, i] + psi_smooth[down]) + * (alpha[1] / (hy * hy)) + ) + num_u += tmp * (u[down] + du[down] - u[j, i]) + num_v += tmp * (v[down] + dv[down] - v[j, i]) + denom_u += tmp + denom_v += tmp + + tmp = ( + 0.5 + * (psi_smooth[j, i] + psi_smooth[up]) + * (alpha[1] / (hy * hy)) + ) + num_u += tmp * (u[up] + du[up] - u[j, i]) + num_v += tmp * (v[up] + dv[up] - v[j, i]) + denom_u += tmp + denom_v += tmp + else: + tmp = alpha[0] / (hx * hx) + num_u += tmp * (u[left] + du[left] - u[j, i]) + num_v += tmp * (v[left] + dv[left] - v[j, i]) + denom_u += tmp + denom_v += tmp + + tmp = alpha[0] / (hx * hx) + num_u += tmp * (u[right] + du[right] - u[j, i]) + num_v += tmp * (v[right] + dv[right] - v[j, i]) + denom_u += tmp + denom_v += tmp + + tmp = alpha[1] / (hy * hy) + num_u += tmp * (u[down] + du[down] - u[j, i]) + num_v += tmp * (v[down] + dv[down] - v[j, i]) + denom_u += tmp + denom_v += tmp + + tmp = alpha[1] / (hy * hy) + num_u += tmp * (u[up] + du[up] - u[j, i]) + num_v += tmp * (v[up] + dv[up] - v[j, i]) + denom_u += tmp + denom_v += tmp + + for k in range(n_channels): + val_u = ( + weight[j, i, k] + * psi[j, i, k] + * (J13[j, i, k] + J12[j, i, k] * dv[j, i]) + ) + num_u -= val_u + denom_u += weight[j, i, k] * psi[j, i, k] * J11[j, i, k] + denom_v += weight[j, i, k] * psi[j, i, k] * J22[j, i, k] + + du_kp1 = num_u / denom_u if denom_u != 0.0 else 0.0 + du[j, i] = (1.0 - OMEGA) * du[j, i] + OMEGA * du_kp1 + + num_v2 = num_v + for k in range(n_channels): + num_v2 -= ( + weight[j, i, k] + * psi[j, i, k] + * (J23[j, i, k] + J12[j, i, k] * du[j, i]) + ) + + dv_kp1 = num_v2 / denom_v if denom_v != 0.0 else 0.0 + dv[j, i] = (1.0 - OMEGA) * dv[j, i] + OMEGA * dv_kp1 + + flow = np.zeros((m, n, 2)) + flow[:, :, 0] = du + flow[:, :, 1] = dv + return flow diff --git a/src/pyflowreg/core/optical_flow.py b/src/pyflowreg/core/optical_flow.py index c9faf5a..a08bfc8 100644 --- a/src/pyflowreg/core/optical_flow.py +++ b/src/pyflowreg/core/optical_flow.py @@ -14,6 +14,10 @@ Main API function for computing optical flow between two frames get_motion_tensor_gc Compute motion tensor components for gradient constancy +get_motion_tensor_gray + Compute motion tensor components for gray-value constancy +get_motion_tensor_cs + Compute motion tensor components for census constancy imregister_wrapper Warp an image using computed displacement fields warpingDepth @@ -36,7 +40,7 @@ import numpy as np from scipy.ndimage import median_filter -from pyflowreg.core import compute_flow +from pyflowreg.core.level_solver import compute_flow, compute_flow_gnc from pyflowreg.core.warping import imregister_wrapper, warpingDepth from pyflowreg.util.resize_util import imresize_fused_gauss_cubic as resize @@ -222,7 +226,7 @@ def get_motion_tensor_gray(f1, f2, hy, hx): return J11, J22, J33, J12, J13, J23 -def get_motion_tensor_cs(f1, f2, hy, hx): +def get_motion_tensor_cs(f1, f2, hy, hx, eps=None): """ Compute motion tensor components for census-based constancy assumption. @@ -242,6 +246,14 @@ def get_motion_tensor_cs(f1, f2, hy, hx): Spatial grid spacing in y-direction. hx : float Spatial grid spacing in x-direction. + eps : float, optional + Smoothing width for the smoothed Heaviside function applied to + directional differences ``r = (neighbor - center) / dist``. 
If None, + uses ``0.1 / 255.0``, matching the Hafner/Demetz/Weickert + ``epsilon = 0.1`` convention for images scaled from ``[0, 255]`` to + approximately ``[0, 1]``. When ``hx`` or ``hy`` are physical units + rather than pixel-like units, callers may need to scale ``eps`` + consistently. Returns ------- @@ -252,26 +264,33 @@ def get_motion_tensor_cs(f1, f2, hy, hx): Notes ----- - The census transform is invariant to monotonically increasing grey-level - changes, improving robustness to illumination variations compared to - gray-value or gradient constancy. A smoothed Heaviside approximation - controlled by `eps` stabilizes the derivatives while preserving the - ordering information that drives the census constraint. - - Symmetric padding enforces Neumann boundary conditions, and the border - is zeroed after aggregation to avoid wrap-around effects from the - cyclic shifts used during neighbor comparisons. + The hard census transform is invariant to global monotonically increasing + grey-value transforms because it depends only on ordering. This + implementation uses finite differences, Gaussian-preprocessed inputs, a + smoothed Heaviside function, and a linearized motion tensor, so invariance + is approximate. + + Additive offsets cancel exactly in neighbor-center differences. Positive + multiplicative changes are approximately handled only when ``eps`` is small + relative to the directional-difference scale, or when ``eps`` is scaled + consistently with image intensity scale. Gamma and other nonlinear + monotone transforms preserve hard ordering but not exact smoothed + Heaviside values. References ---------- .. [1] Hafner, D., Demetz, O., and Weickert, J. "Why is the Census Transform Good for Robust Optic Flow Computation?", SSVM 2013. """ - eps = 0.1 + if eps is None: + eps = 0.1 / 255.0 eps2 = eps * eps + H, W = f1.shape f1p = np.pad(f1, ((1, 1), (1, 1)), mode="symmetric") f2p = np.pad(f2, ((1, 1), (1, 1)), mode="symmetric") + center1 = f1p[1:-1, 1:-1] + center2 = f2p[1:-1, 1:-1] offsets = [ (dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if not (dy == 0 and dx == 0) @@ -289,22 +308,22 @@ def get_motion_tensor_cs(f1, f2, hy, hx): for dy, dx in offsets: dist = float(np.sqrt((hy * dy) * (hy * dy) + (hx * dx) * (hx * dx))) - r1 = (np.roll(f1p, shift=(-dy, -dx), axis=(0, 1)) - f1p) / dist - r2 = (np.roll(f2p, shift=(-dy, -dx), axis=(0, 1)) - f2p) / dist + neigh1 = f1p[1 + dy : 1 + dy + H, 1 + dx : 1 + dx + W] + neigh2 = f2p[1 + dy : 1 + dy + H, 1 + dx : 1 + dx + W] + + r1_core = (neigh1 - center1) / dist + r2_core = (neigh2 - center2) / dist - s1 = 0.5 * (1.0 + r1 / np.sqrt(r1 * r1 + eps2)) - s2 = 0.5 * (1.0 + r2 / np.sqrt(r2 * r2 + eps2)) + s1_core = 0.5 * (1.0 + r1_core / np.sqrt(r1_core * r1_core + eps2)) + s2_core = 0.5 * (1.0 + r2_core / np.sqrt(r2_core * r2_core + eps2)) - s1[0, :] = s1[1, :] - s1[-1, :] = s1[-2, :] - s1[:, 0] = s1[:, 1] - s1[:, -1] = s1[:, -2] - s2[0, :] = s2[1, :] - s2[-1, :] = s2[-2, :] - s2[:, 0] = s2[:, 1] - s2[:, -1] = s2[:, -2] + s1 = np.pad(s1_core, 1, mode="edge") + s2 = np.pad(s2_core, 1, mode="edge") - sy, sx = np.gradient(s1, hy, hx) + sy1, sx1 = np.gradient(s1, hy, hx) + sy2, sx2 = np.gradient(s2, hy, hx) + sx = 0.5 * (sx1 + sx2) + sy = 0.5 * (sy1 + sy2) st = s2 - s1 J11 += sx * sx @@ -342,192 +361,135 @@ def level_solver( a_smooth, hx, hy, + gnc_beta=None, ): - result = compute_flow( - J11, - J22, - J33, - J12, - J13, - J23, - weight=weight, - u=u, - v=v, - alpha_x=alpha[0], - alpha_y=alpha[1], - iterations=iterations, - update_lag=update_lag, - 
a_data=a_data, - a_smooth=a_smooth, - hx=hx, - hy=hy, - ) + if gnc_beta is None: + result = compute_flow( + J11, + J22, + J33, + J12, + J13, + J23, + weight=weight, + u=u, + v=v, + alpha_x=alpha[0], + alpha_y=alpha[1], + iterations=iterations, + update_lag=update_lag, + a_data=a_data, + a_smooth=a_smooth, + hx=hx, + hy=hy, + ) + else: + result = compute_flow_gnc( + J11, + J22, + J33, + J12, + J13, + J23, + weight=weight, + u=u, + v=v, + alpha_x=alpha[0], + alpha_y=alpha[1], + iterations=iterations, + update_lag=update_lag, + a_data=a_data, + a_smooth=a_smooth, + hx=hx, + hy=hy, + gnc_beta=gnc_beta, + ) du = result[:, :, 0] dv = result[:, :, 1] return du, dv -def get_displacement( - fixed, - moving, - alpha=(2, 2), - update_lag=5, - iterations=50, - min_level=0, - levels=50, - eta=0.8, - a_smooth=1.0, - a_data=0.45, - const_assumption="gc", - uv=None, - weight=None, - level_solver_backend=None, -): - """ - Compute optical flow displacement field using variational approach. - - This function implements the main pyramid-based variational optical flow - algorithm using gradient constancy assumption with non-linear diffusion - regularization. Flow is computed from coarse to fine scales, with each - level initialized from the upsampled result of the previous coarser level. - - Parameters - ---------- - fixed : np.ndarray - Reference (fixed) image, shape (H, W) or (H, W, C) - moving : np.ndarray - Moving image to register, shape (H, W) or (H, W, C) - alpha : tuple of float, default=(2, 2) - Regularization strength (alpha_x, alpha_y) controlling smoothness. - Larger values enforce smoother flow fields. - update_lag : int, default=5 - Number of iterations between updates of non-linearity weights (psi). - Smaller values update more frequently (slower, potentially more accurate). - Larger values update less frequently (faster convergence). - iterations : int, default=50 - Number of SOR iterations per pyramid level - min_level : int, default=0 - Minimum (finest) pyramid level to compute. 0 = full resolution. - levels : int, default=50 - Maximum number of pyramid levels attempted - eta : float, default=0.8 - Pyramid downsampling factor per level (0 < eta <= 1). - Each level is eta times the size of the previous level. - a_smooth : float, default=1.0 - Exponent for generalized Charbonnier penalty on smoothness term. - Controls robustness of smoothness regularization via ρ(s²) = (s²)^a: - - a = 1.0: quadratic (L2) penalty, assumes smooth flow everywhere - - a = 0.5: linear (L1) penalty - - 0.5 < a < 1.0: sublinear, robust to local discontinuities - a_data : float, default=0.45 - Exponent for generalized Charbonnier penalty on data term. - Controls robustness to noise and outliers via ρ(d²) = (d²)^a: - - a = 1.0: quadratic (L2) penalty - - a = 0.5: linear (L1) penalty - - a = 0.45: sublinear, robust to noisy microscopy data - const_assumption : str, default='gc' - Constancy assumption: 'gc' for gradient constancy (only option implemented) - uv : np.ndarray, optional - Initial displacement field (H, W, 2) with [u, v] components to initialize - the coarsest (highest) pyramid level. If None, initializes with zeros. - weight : np.ndarray or list, optional - Channel weights for multi-channel registration. Can be: - - 1D array of length C: per-channel weights (normalized to sum to 1) - - 2D array (H, W): spatial weights broadcast to all channels - - 3D array (H, W, C): full spatial and channel weights - If None, uses equal weights (1/C) for all channels. 
- level_solver_backend : Callable, optional - Custom level solver function to use instead of the default CPU solver. - Used by GPU backends to inject accelerated solvers. Must have the same - signature as the default level_solver function. +def normalize_gnc_schedule(gnc_schedule): + """Validate and normalize an optional GNC schedule.""" + if gnc_schedule is None: + return None - Returns - ------- - flow : np.ndarray - Displacement field, shape (H, W, 2) where [:, :, 0] is horizontal (u) - and [:, :, 1] is vertical (v) displacement in pixels + schedule = np.asarray(gnc_schedule, dtype=np.float64) + if schedule.ndim != 1: + raise ValueError("gnc_schedule must be a 1D sequence of stage weights") + if schedule.size < 2: + raise ValueError("gnc_schedule must contain at least two stages") + if np.any(schedule < 0.0) or np.any(schedule > 1.0): + raise ValueError("gnc_schedule entries must lie in [0, 1]") + if not np.all(np.diff(schedule) >= 0.0): + raise ValueError("gnc_schedule must be monotone nondecreasing") + if not np.isclose(schedule[0], 0.0): + raise ValueError("gnc_schedule must start at 0.0") + if not np.isclose(schedule[-1], 1.0): + raise ValueError("gnc_schedule must end at 1.0") + return np.ascontiguousarray(schedule) - Notes - ----- - The pyramid depth is computed independently for each dimension (height and width), - stopping when that dimension becomes < 10 pixels. This allows narrow ROIs to - achieve large displacements along their longer dimension without being limited - by the shorter dimension. - The algorithm uses successive over-relaxation (SOR) with omega=1.95 for - fast convergence at each pyramid level. +def normalize_warping_steps(warping_steps): + """Validate and normalize an optional number of warping steps.""" + if warping_steps is None: + return None - When a_smooth=1.0 (quadratic), the smoothness penalty computation is - optimized by skipping weight updates (psi_smooth = 1.0). + steps = int(warping_steps) + if steps < 1: + raise ValueError("warping_steps must be a positive integer") + return steps - This implementation maintains compatibility with MATLAB Flow-Registration. - See Also - -------- - pyflowreg.motion_correction.compensate_arr : High-level API with OFOptions - get_motion_tensor_gc : Compute motion tensor for gradient constancy +def _resolve_motion_tensor_func(const_assumption): + """ + Resolve a constancy-assumption selector to a motion tensor function. - References - ---------- - .. [4] Flotho et al. "Software for Non-Parametric Image Registration of - 2-Photon Imaging Data", J Biophotonics, 2022. - https://doi.org/10.1002/jbio.202100330 + The default ``gc`` path is the MATLAB Flow-Registration behavior. Census + and gray-value constancy are explicit opt-in alternatives. 
""" - # Ensure fixed and moving have the same number of dimensions - assert ( - fixed.ndim == moving.ndim - ), f"Fixed and moving must have same dimensions: fixed.shape={fixed.shape}, moving.shape={moving.shape}" - fixed = fixed.astype(np.float64) - moving = moving.astype(np.float64) - if fixed.ndim == 3: - m, n, n_channels = fixed.shape - else: - m, n = fixed.shape - n_channels = 1 - fixed = fixed[:, :, np.newaxis] - moving = moving[:, :, np.newaxis] - if uv is not None: - u_init = uv[:, :, 0] - v_init = uv[:, :, 1] - else: - u_init = np.zeros((m, n), dtype=np.float64) - v_init = np.zeros((m, n), dtype=np.float64) - if weight is None: - weight = np.ones((m, n, n_channels), dtype=np.float64) / n_channels - else: - weight = weight.astype(np.float64) - if weight.ndim < 3: - # Handle 1D weight array - if weight.ndim == 1: - # If weight has fewer elements than channels, pad with 1/n_channels - if len(weight) < n_channels: - # Use default value for missing channels (MATLAB behavior) - default_weight = 1.0 / n_channels - weight_expanded = np.full( - n_channels, default_weight, dtype=np.float64 - ) - weight_expanded[: len(weight)] = weight - weight = weight_expanded - elif len(weight) > n_channels: - # Truncate if more weights than channels - weight = weight[:n_channels] - # Normalize weights to sum to 1 - weight = weight / weight.sum() - # Broadcast to spatial dimensions - weight = np.ones((m, n, n_channels), dtype=np.float64) * weight.reshape( - 1, 1, -1 - ) - else: - # 2D spatial weight - broadcast to all channels - weight = ( - np.ones((m, n, n_channels), dtype=np.float64) - * weight[..., np.newaxis] - ) - if not isinstance(a_data, np.ndarray): - a_data_arr = np.full(n_channels, a_data, dtype=np.float64) - else: - a_data_arr = a_data - a_data_arr = np.ascontiguousarray(a_data_arr) + if hasattr(const_assumption, "value"): + const_assumption = const_assumption.value + + key = str(const_assumption).strip().lower() + tensor_funcs = { + "gc": get_motion_tensor_gc, + "gradient": get_motion_tensor_gc, + "gray": get_motion_tensor_gray, + "brightness": get_motion_tensor_gray, + "cs": get_motion_tensor_cs, + "census": get_motion_tensor_cs, + } + + try: + return tensor_funcs[key] + except KeyError as e: + supported = "', '".join(sorted(tensor_funcs)) + raise ValueError( + f"Unknown constancy assumption: '{const_assumption}'. " + f"Supported values are: '{supported}'." 
+ ) from e + + +def _solve_displacement_stage( + fixed, + moving, + alpha, + update_lag, + iterations, + min_level, + levels, + eta, + a_smooth, + a_data_arr, + uv, + weight, + level_solver_backend, + motion_tensor_func, + gnc_beta, +): + """Solve one full pyramid pass for a fixed GNC stage.""" + m, n, n_channels = fixed.shape f1_low = fixed f2_low = moving max_level_y = warpingDepth(eta, levels, m, m) @@ -539,6 +501,12 @@ def get_displacement( min_level = max(max_level_x, max_level_y) - 1 if min_level < 0: min_level = 0 + if uv is not None: + u_init = uv[:, :, 0] + v_init = uv[:, :, 1] + else: + u_init = np.zeros((m, n), dtype=np.float64) + v_init = np.zeros((m, n), dtype=np.float64) u = None v = None for i in range(max(max_level_x, max_level_y), min_level - 1, -1): @@ -578,7 +546,7 @@ def get_displacement( J13 = np.zeros(J_size, dtype=np.float64) J23 = np.zeros(J_size, dtype=np.float64) for ch in range(n_channels): - J11_ch, J22_ch, J33_ch, J12_ch, J13_ch, J23_ch = get_motion_tensor_gc( + J11_ch, J22_ch, J33_ch, J12_ch, J13_ch, J23_ch = motion_tensor_func( f1_level[:, :, ch], tmp[:, :, ch], current_hx, current_hy ) J11[:, :, ch] = J11_ch @@ -625,6 +593,7 @@ def get_displacement( a_smooth, current_hx, current_hy, + gnc_beta, ) if min(level_size) > 5: du[1:-1, 1:-1] = median_filter(du[1:-1, 1:-1], size=(5, 5), mode="mirror") @@ -637,3 +606,308 @@ def get_displacement( if min_level > 0: w = cv2.resize(w, (n, m), interpolation=cv2.INTER_CUBIC) return w + + +def _solve_displacement_stage_gnc( + fixed, + moving, + alpha, + update_lag, + iterations, + min_level, + levels, + eta, + a_smooth, + a_data_arr, + uv, + weight, + level_solver_backend, + motion_tensor_func, + gnc_beta, + warping_steps, +): + """Solve one GNC stage with repeated warp/relinearize steps per level.""" + m, n, n_channels = fixed.shape + f1_low = fixed + f2_low = moving + max_level_y = warpingDepth(eta, levels, m, m) + max_level_x = warpingDepth(eta, levels, n, n) + max_level = min(max_level_x, max_level_y) * 4 + max_level_y = min(max_level_y, max_level) + max_level_x = min(max_level_x, max_level) + if max(max_level_x, max_level_y) <= min_level: + min_level = max(max_level_x, max_level_y) - 1 + if min_level < 0: + min_level = 0 + if uv is not None: + u_init = uv[:, :, 0] + v_init = uv[:, :, 1] + else: + u_init = np.zeros((m, n), dtype=np.float64) + v_init = np.zeros((m, n), dtype=np.float64) + + solver_func = ( + level_solver_backend if level_solver_backend is not None else level_solver + ) + u = None + v = None + for i in range(max(max_level_x, max_level_y), min_level - 1, -1): + level_size = ( + int(round(m * eta ** (min(i, max_level_y)))), + int(round(n * eta ** (min(i, max_level_x)))), + ) + f1_level = resize(f1_low, level_size) + f2_level = resize(f2_low, level_size) + if f1_level.ndim == 2: + f1_level = f1_level[:, :, np.newaxis] + f2_level = f2_level[:, :, np.newaxis] + current_hx = float(m) / f1_level.shape[0] + current_hy = float(n) / f1_level.shape[1] + + if i == max(max_level_x, max_level_y): + u = add_boundary(resize(u_init, level_size)) + v = add_boundary(resize(v_init, level_size)) + else: + u = add_boundary(resize(u[1:-1, 1:-1], level_size)) + v = add_boundary(resize(v[1:-1, 1:-1], level_size)) + + u = np.ascontiguousarray(u) + v = np.ascontiguousarray(v) + + weight_level = resize(weight, f1_level.shape[:2]) + weight_level = cv2.copyMakeBorder( + weight_level, 1, 1, 1, 1, borderType=cv2.BORDER_CONSTANT, value=0.0 + ) + if weight_level.ndim < 3: + weight_level = weight_level[:, :, np.newaxis] + weight_level = 
np.ascontiguousarray(weight_level) + + if i == min_level: + alpha_scaling = 1 + else: + alpha_scaling = eta ** (-0.5 * i) + alpha_tmp = [alpha_scaling * alpha[j] for j in range(len(alpha))] + + for _ in range(warping_steps): + tmp = imregister_wrapper( + f2_level, + u[1:-1, 1:-1] / current_hy, + v[1:-1, 1:-1] / current_hx, + f1_level, + ) + if tmp.ndim == 2: + tmp = tmp[:, :, np.newaxis] + + J_size = (f1_level.shape[0] + 2, f1_level.shape[1] + 2, n_channels) + J11 = np.zeros(J_size, dtype=np.float64) + J22 = np.zeros(J_size, dtype=np.float64) + J33 = np.zeros(J_size, dtype=np.float64) + J12 = np.zeros(J_size, dtype=np.float64) + J13 = np.zeros(J_size, dtype=np.float64) + J23 = np.zeros(J_size, dtype=np.float64) + for ch in range(n_channels): + J11_ch, J22_ch, J33_ch, J12_ch, J13_ch, J23_ch = motion_tensor_func( + f1_level[:, :, ch], tmp[:, :, ch], current_hx, current_hy + ) + J11[:, :, ch] = J11_ch + J22[:, :, ch] = J22_ch + J33[:, :, ch] = J33_ch + J12[:, :, ch] = J12_ch + J13[:, :, ch] = J13_ch + J23[:, :, ch] = J23_ch + + du, dv = solver_func( + np.ascontiguousarray(J11), + np.ascontiguousarray(J22), + np.ascontiguousarray(J33), + np.ascontiguousarray(J12), + np.ascontiguousarray(J13), + np.ascontiguousarray(J23), + weight_level, + u, + v, + alpha_tmp, + iterations, + update_lag, + 0, + a_data_arr, + a_smooth, + current_hx, + current_hy, + gnc_beta, + ) + u = u + du + v = v + dv + if min(level_size) > 5: + u = add_boundary( + median_filter(u[1:-1, 1:-1], size=(5, 5), mode="mirror") + ) + v = add_boundary( + median_filter(v[1:-1, 1:-1], size=(5, 5), mode="mirror") + ) + + w = np.zeros((u.shape[0] - 2, u.shape[1] - 2, 2), dtype=np.float64) + w[:, :, 0] = u[1:-1, 1:-1] + w[:, :, 1] = v[1:-1, 1:-1] + if min_level > 0: + w = cv2.resize(w, (n, m), interpolation=cv2.INTER_CUBIC) + return w + + +def get_displacement( + fixed, + moving, + alpha=(2, 2), + update_lag=5, + iterations=50, + min_level=0, + levels=50, + eta=0.8, + a_smooth=1.0, + a_data=0.45, + const_assumption="gc", + uv=None, + weight=None, + level_solver_backend=None, + gnc_schedule=None, + warping_steps=None, +): + """ + Compute optical flow displacement field using variational approach. + + This function implements the main pyramid-based variational optical flow + algorithm using gradient constancy assumption with non-linear diffusion + regularization. Flow is computed from coarse to fine scales, with each + level initialized from the upsampled result of the previous coarser level. + + Parameters + ---------- + fixed : np.ndarray + Reference (fixed) image, shape (H, W) or (H, W, C) + moving : np.ndarray + Moving image to register, shape (H, W) or (H, W, C) + alpha : tuple of float, default=(2, 2) + Regularization strength (alpha_x, alpha_y) controlling smoothness. + Larger values enforce smoother flow fields. + update_lag : int, default=5 + Number of iterations between updates of non-linearity weights (psi). + Smaller values update more frequently (slower, potentially more accurate). + Larger values update less frequently (faster convergence). + iterations : int, default=50 + Number of SOR iterations per pyramid level + min_level : int, default=0 + Minimum (finest) pyramid level to compute. 0 = full resolution. + levels : int, default=50 + Maximum number of pyramid levels attempted + eta : float, default=0.8 + Pyramid downsampling factor per level (0 < eta <= 1). + Each level is eta times the size of the previous level. + a_smooth : float, default=1.0 + Exponent for generalized Charbonnier penalty on smoothness term. 
+ a_data : float, default=0.45 + Exponent for generalized Charbonnier penalty on data term. + const_assumption : str, default='gc' + Constancy assumption. Supported values are 'gc'/'gradient', + 'gray'/'brightness', and 'cs'/'census'. + uv : np.ndarray, optional + Initial displacement field (H, W, 2) with [u, v] components. + weight : np.ndarray or list, optional + Channel weights for multi-channel registration. + level_solver_backend : Callable, optional + Custom level solver function to use instead of the default CPU solver. + gnc_schedule : sequence of float, optional + Opt-in stage weights interpolating from quadratic (0.0) to fully robust + (1.0). Each stage reruns the pyramid with the previous stage result used + as initialization. + warping_steps : int, optional + Number of warp/relinearize steps per pyramid level in optional GNC mode. + If omitted, GNC defaults to 10 steps per level. Ignored when GNC is off. + """ + assert ( + fixed.ndim == moving.ndim + ), f"Fixed and moving must have same dimensions: fixed.shape={fixed.shape}, moving.shape={moving.shape}" + motion_tensor_func = _resolve_motion_tensor_func(const_assumption) + fixed = fixed.astype(np.float64) + moving = moving.astype(np.float64) + if fixed.ndim == 3: + m, n, n_channels = fixed.shape + else: + m, n = fixed.shape + n_channels = 1 + fixed = fixed[:, :, np.newaxis] + moving = moving[:, :, np.newaxis] + + if weight is None: + weight = np.ones((m, n, n_channels), dtype=np.float64) / n_channels + else: + weight = weight.astype(np.float64) + if weight.ndim < 3: + if weight.ndim == 1: + if len(weight) < n_channels: + default_weight = 1.0 / n_channels + weight_expanded = np.full( + n_channels, default_weight, dtype=np.float64 + ) + weight_expanded[: len(weight)] = weight + weight = weight_expanded + elif len(weight) > n_channels: + weight = weight[:n_channels] + weight = weight / weight.sum() + weight = np.ones((m, n, n_channels), dtype=np.float64) * weight.reshape( + 1, 1, -1 + ) + else: + weight = ( + np.ones((m, n, n_channels), dtype=np.float64) + * weight[..., np.newaxis] + ) + + if not isinstance(a_data, np.ndarray): + a_data_arr = np.full(n_channels, a_data, dtype=np.float64) + else: + a_data_arr = a_data + a_data_arr = np.ascontiguousarray(a_data_arr) + + gnc_schedule_arr = normalize_gnc_schedule(gnc_schedule) + warping_steps = normalize_warping_steps(warping_steps) + if gnc_schedule_arr is None: + return _solve_displacement_stage( + fixed, + moving, + alpha, + update_lag, + iterations, + min_level, + levels, + eta, + a_smooth, + a_data_arr, + uv, + weight, + level_solver_backend, + motion_tensor_func, + None, + ) + + flow = uv + effective_warping_steps = 10 if warping_steps is None else warping_steps + for gnc_beta in gnc_schedule_arr: + flow = _solve_displacement_stage_gnc( + fixed, + moving, + alpha, + update_lag, + iterations, + min_level, + levels, + eta, + a_smooth, + a_data_arr, + flow, + weight, + level_solver_backend, + motion_tensor_func, + float(gnc_beta), + effective_warping_steps, + ) + return flow diff --git a/src/pyflowreg/cuda/core/level_solver.py b/src/pyflowreg/cuda/core/level_solver.py index f410b56..ab10a2d 100644 --- a/src/pyflowreg/cuda/core/level_solver.py +++ b/src/pyflowreg/cuda/core/level_solver.py @@ -30,6 +30,7 @@ def level_solver_rbgs_cuda( omega=1.95, eps=1e-6, update_lag_semantics="torch", + gnc_beta=None, ): """ Solve for flow increments using Red-Black Gauss-Seidel relaxation on GPU. 
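
For readers of the GPU hunks that follow: when `gnc_beta` is set, both the CUDA
and torch solvers blend the robust data weights toward the quadratic value 1.0
with a fixed stage weight. A minimal NumPy sketch of that blending rule, using
a hypothetical helper name `blend_psi_data` that is not part of this patch:

```python
import numpy as np


def blend_psi_data(psi_robust, a_vec, gnc_beta):
    """Hypothetical helper: stage-blend robust data weights toward quadratic."""
    # psi == 1 corresponds to the quadratic penalty; the robust weight is mixed
    # in only where the per-channel exponent is sublinear (0 < a < 1), matching
    # the masks applied in the CUDA and torch solvers below.
    mix_quadratic = 1.0 - gnc_beta
    data_mask = (a_vec > 0.0) & (a_vec < 1.0)
    return np.where(data_mask, mix_quadratic + gnc_beta * psi_robust, psi_robust)


# One sublinear channel (a = 0.45) gets blended, a quadratic channel (a = 1) stays put.
psi = np.array([0.3, 1.0]).reshape(1, 1, 2)
a_vec = np.array([0.45, 1.0]).reshape(1, 1, 2)
print(blend_psi_data(psi, a_vec, gnc_beta=0.5))  # [[[0.65 1.  ]]]
```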
@@ -148,6 +149,10 @@ def level_solver_rbgs_cuda( ax = float(alpha[0]) / (hx * hx) ay = float(alpha[1]) / (hy * hy) + use_gnc = gnc_beta is not None + beta = float(gnc_beta) if use_gnc else 1.0 + mix_quadratic = 1.0 - beta + data_mask = (a_vec > 0.0) & (a_vec < 1.0) if use_gnc else None mod = cp.RawModule( code=r""" @@ -374,6 +379,12 @@ def level_solver_rbgs_cuda( np.int32(K), ), ) + if use_gnc: + psi_data = cp.where( + data_mask.reshape(1, 1, K), + mix_quadratic + beta * psi_data, + psi_data, + ) k_denoms( g2, b2, @@ -411,6 +422,8 @@ def level_solver_rbgs_cuda( ) k_brow(g1_rows, b1, (psi_smooth, np.int32(m), np.int32(n))) k_bcol(g1_cols, b1, (psi_smooth, np.int32(m), np.int32(n))) + if use_gnc and 0.0 < float(a_smooth) < 1.0: + psi_smooth = mix_quadratic + beta * psi_smooth else: psi_smooth.fill(1.0) diff --git a/src/pyflowreg/motion_correction/OF_options.py b/src/pyflowreg/motion_correction/OF_options.py index 16008ff..98deb94 100644 --- a/src/pyflowreg/motion_correction/OF_options.py +++ b/src/pyflowreg/motion_correction/OF_options.py @@ -78,6 +78,22 @@ class InterpolationMethod(str, Enum): class ConstancyAssumption(str, Enum): GRAY = "gray" GRADIENT = "gc" + CENSUS = "cs" + + +def _normalize_constancy_assumption_value(v): + """Normalize constancy assumption aliases to serialized option values.""" + if hasattr(v, "value"): + v = v.value + if isinstance(v, str): + aliases = { + "gradient": ConstancyAssumption.GRADIENT.value, + "brightness": ConstancyAssumption.GRAY.value, + "census": ConstancyAssumption.CENSUS.value, + } + key = v.strip().lower() + return aliases.get(key, key) + return v class NamingConvention(str, Enum): @@ -126,6 +142,15 @@ class OFOptions(BaseModel): iterations: StrictInt = Field(50, ge=1, description="Iterations per level") a_smooth: float = Field(1.0, ge=0, description="Smoothness diffusion parameter") a_data: float = Field(0.45, gt=0, le=1, description="Data-term diffusion parameter") + gnc_schedule: Optional[Tuple[float, ...]] = Field( + None, + description="Optional graduated non-convexity stage weights from 0.0 to 1.0", + ) + warping_steps: Optional[StrictInt] = Field( + None, + ge=1, + description="Optional warp/relinearize steps per pyramid level in GNC mode", + ) # Preprocessing sigma: Any = Field( @@ -176,7 +201,8 @@ class OFOptions(BaseModel): NamingConvention.DEFAULT, description="Output filename style" ) constancy_assumption: ConstancyAssumption = Field( - ConstancyAssumption.GRADIENT, description="Constancy assumption" + ConstancyAssumption.GRADIENT, + description="Optical-flow data term: 'gc', 'gray', or 'cs'", ) # Backend configuration @@ -278,6 +304,34 @@ def normalize_sigma(cls, v): raise ValueError("Sigma must be [sx,sy,st] or (n_channels, 3)") return v + @field_validator("gnc_schedule", mode="before") + @classmethod + def normalize_gnc_schedule(cls, v): + """Normalize and validate an optional GNC stage schedule.""" + if v is None: + return None + + schedule = np.asarray(v, dtype=float) + if schedule.ndim != 1: + raise ValueError("gnc_schedule must be a 1D sequence") + if schedule.size < 2: + raise ValueError("gnc_schedule must contain at least two stages") + if np.any(schedule < 0.0) or np.any(schedule > 1.0): + raise ValueError("gnc_schedule entries must lie in [0, 1]") + if not np.all(np.diff(schedule) >= 0.0): + raise ValueError("gnc_schedule must be monotone nondecreasing") + if not np.isclose(schedule[0], 0.0): + raise ValueError("gnc_schedule must start at 0.0") + if not np.isclose(schedule[-1], 1.0): + raise ValueError("gnc_schedule must end 
at 1.0") + return tuple(float(x) for x in schedule.tolist()) + + @field_validator("constancy_assumption", mode="before") + @classmethod + def normalize_constancy_assumption(cls, v): + """Normalize constancy assumption aliases to serialized option values.""" + return _normalize_constancy_assumption_value(v) + @model_validator(mode="after") def validate_and_normalize(self) -> "OFOptions": """Normalize fields and maintain MATLAB parity.""" @@ -691,6 +745,16 @@ def resolve_get_displacement(self) -> Callable: # Priority 3: Registry backend from pyflowreg.core.backend_registry import get_backend + constancy_assumption = _normalize_constancy_assumption_value( + self.constancy_assumption + ) + if self.flow_backend == "diso" and constancy_assumption != "gc": + raise ValueError( + "The 'diso' backend does not support variational constancy " + f"assumption '{constancy_assumption}'. Use " + "flow_backend='flowreg' for 'gray' or 'cs'." + ) + factory = get_backend(self.flow_backend) return factory(**self.backend_params) @@ -706,7 +770,11 @@ def to_dict(self) -> dict: "update_lag": self.update_lag, "a_data": self.a_data, "a_smooth": self.a_smooth, - "const_assumption": self.constancy_assumption.value, # Fixed: use const_assumption for API compatibility + "gnc_schedule": self.gnc_schedule, + "warping_steps": self.warping_steps, + "const_assumption": _normalize_constancy_assumption_value( + self.constancy_assumption + ), } def __repr__(self) -> str: diff --git a/src/pyflowreg/motion_correction/compensate_recording.py b/src/pyflowreg/motion_correction/compensate_recording.py index ac629a8..4c8814b 100644 --- a/src/pyflowreg/motion_correction/compensate_recording.py +++ b/src/pyflowreg/motion_correction/compensate_recording.py @@ -211,6 +211,36 @@ def _resolve_displacement_func(self): """Resolve the displacement function to use based on options.""" self._get_disp = self.options.resolve_get_displacement() + def _get_flow_params(self) -> Dict[str, Any]: + """Build optical-flow parameters for the configured displacement backend.""" + if hasattr(self.options, "to_dict"): + flow_params = dict(self.options.to_dict()) + else: + flow_params = { + "alpha": self.options.alpha, + "levels": self.options.levels, + "min_level": getattr( + self.options, + "effective_min_level", + getattr(self.options, "min_level", 0), + ), + "eta": self.options.eta, + "update_lag": self.options.update_lag, + "iterations": self.options.iterations, + "a_smooth": self.options.a_smooth, + "a_data": self.options.a_data, + "gnc_schedule": getattr(self.options, "gnc_schedule", None), + "warping_steps": getattr(self.options, "warping_steps", None), + } + const_assumption = getattr(self.options, "constancy_assumption", None) + if const_assumption is not None: + flow_params["const_assumption"] = getattr( + const_assumption, "value", const_assumption + ) + + flow_params["weight"] = self.weight + return flow_params + def register_progress_callback(self, callback: Callable[[int, int], None]) -> None: """ Register a progress callback function. 
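
A quick check of the new forwarding path, assuming the patched `pyflowreg` is
importable and `OFOptions` constructs with the remaining fields at their
defaults (import path taken from this patch):

```python
from pyflowreg.motion_correction.OF_options import OFOptions

options = OFOptions(constancy_assumption="gradient", gnc_schedule=(0.0, 1.0))
params = options.to_dict()

assert params["const_assumption"] == "gc"    # alias normalized for the solver API
assert params["gnc_schedule"] == (0.0, 1.0)  # opt-in stages pass through unchanged
assert params["warping_steps"] is None       # no GNC warping override by default
```

`_get_flow_params()` then copies this dictionary and attaches the per-channel
`weight` before handing it to the displacement backend.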
@@ -396,21 +426,7 @@ def _compute_flow_single( w_init: Optional[np.ndarray] = None, ) -> np.ndarray: """Compute flow for a single frame.""" - flow_params = { - "alpha": self.options.alpha, - "weight": self.weight, - "levels": self.options.levels, - "min_level": getattr( - self.options, - "effective_min_level", - getattr(self.options, "min_level", 0), - ), - "eta": self.options.eta, - "update_lag": self.options.update_lag, - "iterations": self.options.iterations, - "a_smooth": self.options.a_smooth, - "a_data": self.options.a_data, - } + flow_params = self._get_flow_params() if w_init is not None: flow_params["uv"] = w_init @@ -433,22 +449,7 @@ def _process_batch_parallel( w_init: Initial displacement field task_id: Task identifier for progress tracking (default: "main") """ - # Build flow parameters dictionary - flow_params = { - "alpha": self.options.alpha, - "weight": self.weight, - "levels": self.options.levels, - "min_level": getattr( - self.options, - "effective_min_level", - getattr(self.options, "min_level", 0), - ), - "eta": self.options.eta, - "update_lag": self.options.update_lag, - "iterations": self.options.iterations, - "a_smooth": self.options.a_smooth, - "a_data": self.options.a_data, - } + flow_params = self._get_flow_params() # Get interpolation method interp_method = getattr(self.options, "interpolation_method", "cubic") diff --git a/src/pyflowreg/session/config.py b/src/pyflowreg/session/config.py index e95c422..a7709b6 100644 --- a/src/pyflowreg/session/config.py +++ b/src/pyflowreg/session/config.py @@ -57,6 +57,9 @@ class SessionConfig(BaseModel): Regularization for inter-sequence optical flow iterations_between : int, default=100 Iterations for inter-sequence optical flow + stage2_constancy_assumption : str, default="gc" + Constancy assumption for Stage 2 optical flow. The default "gc" + preserves the MATLAB Flow-Registration behavior. align_chunk_size : int, default=64 Number of frames to process per batch during Stage 3 video alignment align_output_format : str, default="TIFF" @@ -88,6 +91,7 @@ class SessionConfig(BaseModel): sigma_smooth: float = 6.0 alpha_between: float = 25.0 iterations_between: int = 100 + stage2_constancy_assumption: str = "gc" # Stage 3 parameters align_chunk_size: int = 64 diff --git a/src/pyflowreg/session/stage2_between_avgs.py b/src/pyflowreg/session/stage2_between_avgs.py index 2960eed..e886d84 100644 --- a/src/pyflowreg/session/stage2_between_avgs.py +++ b/src/pyflowreg/session/stage2_between_avgs.py @@ -26,6 +26,32 @@ from pyflowreg.util.xcorr_prealignment import estimate_rigid_xcorr_2d +def normalize_constancy_assumption(value) -> str: + """ + Normalize a constancy-assumption selector to the backend API value. + + The default ``gc`` value is the MATLAB Flow-Registration behavior. Other + data terms are explicit opt-in extensions for the native flowreg backend. + """ + if hasattr(value, "value"): + value = value.value + + key = str(value).strip().lower() + aliases = { + "gradient": "gc", + "brightness": "gray", + "census": "cs", + } + key = aliases.get(key, key) + if key not in {"gc", "gray", "cs"}: + raise ValueError( + f"Unknown constancy assumption: '{value}'. " + "Supported values are: 'gc', 'gradient', 'gray', 'brightness', " + "'cs', and 'census'." + ) + return key + + def mat2gray_ref(img: np.ndarray, ref: np.ndarray = None) -> np.ndarray: """ Normalize image to [0, 1] range. 
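
The Stage 2 selector normalization added above can be exercised directly,
assuming the patched `pyflowreg` is importable; `"ssim"` below is just an
arbitrary unsupported value:

```python
from pyflowreg.session.stage2_between_avgs import normalize_constancy_assumption

assert normalize_constancy_assumption("gradient") == "gc"
assert normalize_constancy_assumption("CENSUS") == "cs"  # aliases are case-insensitive

try:
    normalize_constancy_assumption("ssim")
except ValueError as err:
    print(err)  # lists the supported selector values
```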
@@ -117,6 +143,16 @@ def compute_between_displacement( img2 = img2.reshape(img2_dims) # Get displacement function based on configured backend + constancy_assumption = normalize_constancy_assumption( + config.stage2_constancy_assumption + ) + if config.flow_backend == "diso" and constancy_assumption != "gc": + raise ValueError( + "The 'diso' backend does not support variational constancy " + f"assumption '{constancy_assumption}'. Use flow_backend='flowreg' " + "for 'gray' or 'cs'." + ) + backend_factory = get_backend(config.flow_backend) get_displacement_func = backend_factory(**config.backend_params) @@ -131,6 +167,7 @@ def compute_between_displacement( img2, alpha=alpha, iterations=config.iterations_between, + const_assumption=constancy_assumption, ) return w + w_init diff --git a/src/pyflowreg/torch/core/level_solver.py b/src/pyflowreg/torch/core/level_solver.py index cade1c8..24d031b 100644 --- a/src/pyflowreg/torch/core/level_solver.py +++ b/src/pyflowreg/torch/core/level_solver.py @@ -28,6 +28,7 @@ def level_solver_rbgs_torch( omega=1.95, eps=1e-6, update_lag_semantics: str = "torch", + gnc_beta=None, ): with torch.no_grad(): m, n, K = J11.shape @@ -36,6 +37,12 @@ def level_solver_rbgs_torch( ax = torch.as_tensor(alpha[0], dtype=u.dtype, device=u.device) / (hx * hx) ay = torch.as_tensor(alpha[1], dtype=u.dtype, device=u.device) / (hy * hy) epsv = torch.as_tensor(eps, dtype=u.dtype, device=u.device) + use_gnc = gnc_beta is not None + beta = None + mix_quadratic = None + if use_gnc: + beta = torch.as_tensor(float(gnc_beta), dtype=u.dtype, device=u.device) + mix_quadratic = 1.0 - beta if isinstance(a_data, (float, int)): A_vec = torch.full( (1, 1, K), float(a_data), dtype=J11.dtype, device=J11.device @@ -82,6 +89,11 @@ def level_solver_rbgs_torch( ) E.clamp_min_(0) psi_data = A_vec * (E + epsv) ** (A_vec - 1) + if use_gnc: + data_mask = (A_vec > 0) & (A_vec < 1) + psi_data = torch.where( + data_mask, mix_quadratic + beta * psi_data, psi_data + ) if a_smooth != 1: uc = u + du vc = v + dv @@ -95,6 +107,8 @@ def level_solver_rbgs_torch( a_s = torch.as_tensor(a_smooth, dtype=u.dtype, device=u.device) psi_smooth[c1, c2] = a_s * (mag[c1, c2] + epsv) ** (a_s - 1) _set_boundary2d_(psi_smooth) + if use_gnc and 0.0 < float(a_smooth) < 1.0: + psi_smooth = mix_quadratic + beta * psi_smooth else: psi_smooth.fill_(1) denom_u_data = torch.sum(W * psi_data * J11, dim=2) diff --git a/tests/core/test_level_solver_gnc.py b/tests/core/test_level_solver_gnc.py new file mode 100644 index 0000000..d08ed3b --- /dev/null +++ b/tests/core/test_level_solver_gnc.py @@ -0,0 +1,106 @@ +"""Tests for low-level GNC penalty blending in the 2D solver.""" + +import numpy as np + +from pyflowreg.core.level_solver import compute_flow_gnc + + +def _data_term_case(a_data=0.45): + """Create a one-pixel data force with smoothness in the denominator.""" + shape = (3, 3, 1) + field_shape = shape[:2] + tensors = [np.zeros(shape, dtype=np.float64) for _ in range(6)] + J11, J22, J33, J12, J13, J23 = tensors + J11[1, 1, 0] = 1.0 + J13[1, 1, 0] = -1.0 + J33[1, 1, 0] = 4.0 + return dict( + J11=J11, + J22=J22, + J33=J33, + J12=J12, + J13=J13, + J23=J23, + weight=np.ones(shape, dtype=np.float64), + u=np.zeros(field_shape, dtype=np.float64), + v=np.zeros(field_shape, dtype=np.float64), + alpha_x=1.0, + alpha_y=1.0, + iterations=1, + a_data=np.array([a_data], dtype=np.float64), + a_smooth=1.0, + hx=1.0, + hy=1.0, + ) + + +def _smoothness_case(a_smooth): + """Create a non-flat initial flow so smoothness weights affect the solve.""" + shape = 
(5, 5, 1) + field_shape = shape[:2] + u = np.zeros(field_shape, dtype=np.float64) + u[2, 2] = 1.0 + return dict( + J11=np.zeros(shape, dtype=np.float64), + J22=np.zeros(shape, dtype=np.float64), + J33=np.zeros(shape, dtype=np.float64), + J12=np.zeros(shape, dtype=np.float64), + J13=np.zeros(shape, dtype=np.float64), + J23=np.zeros(shape, dtype=np.float64), + weight=np.ones(shape, dtype=np.float64), + u=u, + v=np.zeros(field_shape, dtype=np.float64), + alpha_x=1.0, + alpha_y=1.0, + iterations=1, + a_data=np.array([1.0], dtype=np.float64), + a_smooth=a_smooth, + hx=1.0, + hy=1.0, + ) + + +def test_compute_flow_gnc_updates_data_weights_on_matlab_cadence(): + """Test GNC data weights update on the MATLAB-style update-lag tick.""" + common = _data_term_case() + + lagged_quadratic = compute_flow_gnc(**common, update_lag=2, gnc_beta=0.0) + lagged_robust = compute_flow_gnc(**common, update_lag=2, gnc_beta=1.0) + np.testing.assert_allclose(lagged_quadratic, lagged_robust) + + updated_quadratic = compute_flow_gnc(**common, update_lag=1, gnc_beta=0.0) + updated_midpoint = compute_flow_gnc(**common, update_lag=1, gnc_beta=0.5) + updated_robust = compute_flow_gnc(**common, update_lag=1, gnc_beta=1.0) + + assert updated_quadratic[1, 1, 0] > updated_midpoint[1, 1, 0] + assert updated_midpoint[1, 1, 0] > updated_robust[1, 1, 0] + + +def test_compute_flow_gnc_leaves_quadratic_data_term_beta_invariant(): + """Test ``a_data=1`` is quadratic already and ignores the GNC beta.""" + common = _data_term_case(a_data=1.0) + + quadratic_stage = compute_flow_gnc(**common, update_lag=1, gnc_beta=0.0) + robust_stage = compute_flow_gnc(**common, update_lag=1, gnc_beta=1.0) + + np.testing.assert_allclose(quadratic_stage, robust_stage) + + +def test_compute_flow_gnc_blends_sublinear_smoothness_weights(): + """Test sublinear smoothness weights change with the GNC stage beta.""" + common = _smoothness_case(a_smooth=0.5) + + quadratic_stage = compute_flow_gnc(**common, update_lag=1, gnc_beta=0.0) + robust_stage = compute_flow_gnc(**common, update_lag=1, gnc_beta=1.0) + + assert not np.allclose(quadratic_stage, robust_stage) + + +def test_compute_flow_gnc_leaves_quadratic_smoothness_beta_invariant(): + """Test ``a_smooth=1`` is quadratic already and ignores the GNC beta.""" + common = _smoothness_case(a_smooth=1.0) + + quadratic_stage = compute_flow_gnc(**common, update_lag=1, gnc_beta=0.0) + robust_stage = compute_flow_gnc(**common, update_lag=1, gnc_beta=1.0) + + np.testing.assert_allclose(quadratic_stage, robust_stage) diff --git a/tests/core/test_optical_flow.py b/tests/core/test_optical_flow.py new file mode 100644 index 0000000..56dc40d --- /dev/null +++ b/tests/core/test_optical_flow.py @@ -0,0 +1,421 @@ +"""Tests for optical flow stage dispatch, GNC plumbing, and tensor helpers.""" + +import inspect + +import numpy as np + +import pyflowreg.core.optical_flow as optical_flow +from pyflowreg.core.optical_flow import get_displacement, get_motion_tensor_cs + + +def _compress_runs(values): + """Collapse consecutive duplicate values while preserving order.""" + compressed = [] + for value in values: + if not compressed or compressed[-1] != value: + compressed.append(value) + return compressed + + +def _make_zero_level_solver(sink=None): + """Create a fake level solver that records GNC stages.""" + + def fake_level_solver(*args): + u = args[7] + v = args[8] + gnc_beta = args[-1] + if sink is not None: + sink.append(gnc_beta) + return np.zeros_like(u), np.zeros_like(v) + + return fake_level_solver + + +def 
_make_tensor_stub(counter=None): + """Return zero motion tensors with the expected padded shape.""" + + def fake_get_motion_tensor_gc(f1, f2, hy, hx): + if counter is not None: + counter.append((f1.shape, f2.shape, hy, hx)) + shape = (f1.shape[0] + 2, f1.shape[1] + 2) + return tuple(np.zeros(shape, dtype=np.float64) for _ in range(6)) + + return fake_get_motion_tensor_gc + + +def _sample_images(shape=(8, 9)): + y = np.linspace(0.0, 1.0, shape[0])[:, np.newaxis] + x = np.linspace(0.0, 1.0, shape[1])[np.newaxis, :] + f1 = 0.25 + 0.35 * x + 0.20 * y + f2 = 0.30 + 0.25 * x * x + 0.15 * y + return f1.astype(np.float64), f2.astype(np.float64) + + +def _assert_zero_border(tensors): + for tensor in tensors: + assert np.array_equal(tensor[0, :], np.zeros_like(tensor[0, :])) + assert np.array_equal(tensor[-1, :], np.zeros_like(tensor[-1, :])) + assert np.array_equal(tensor[:, 0], np.zeros_like(tensor[:, 0])) + assert np.array_equal(tensor[:, -1], np.zeros_like(tensor[:, -1])) + + +def test_get_displacement_without_gnc_uses_baseline_stage(): + """Test the default path does not pass a GNC stage to the solver.""" + fixed = np.zeros((12, 12), dtype=np.float64) + moving = np.zeros((12, 12), dtype=np.float64) + seen_betas = [] + + flow = get_displacement( + fixed, + moving, + levels=0, + iterations=2, + update_lag=1, + level_solver_backend=_make_zero_level_solver(seen_betas), + ) + + assert flow.shape == (12, 12, 2) + assert _compress_runs(seen_betas) == [None] + + +def test_get_displacement_without_gnc_ignores_warping_steps(): + """Test warping_steps alone does not change the default path.""" + fixed = np.zeros((12, 12), dtype=np.float64) + moving = np.zeros((12, 12), dtype=np.float64) + seen_betas = [] + + flow = get_displacement( + fixed, + moving, + levels=1, + iterations=2, + update_lag=1, + level_solver_backend=_make_zero_level_solver(seen_betas), + warping_steps=3, + ) + + assert flow.shape == (12, 12, 2) + assert _compress_runs(seen_betas) == [None] + + +def test_get_displacement_with_gnc_repeats_pyramid_per_stage(): + """Test GNC reruns the pyramid with one fixed stage weight per pass.""" + fixed = np.zeros((12, 12), dtype=np.float64) + moving = np.zeros((12, 12), dtype=np.float64) + seen_betas = [] + + flow = get_displacement( + fixed, + moving, + levels=1, + iterations=2, + update_lag=1, + level_solver_backend=_make_zero_level_solver(seen_betas), + gnc_schedule=(0.0, 0.5, 1.0), + ) + + assert flow.shape == (12, 12, 2) + assert _compress_runs(seen_betas) == [0.0, 0.5, 1.0] + + +def test_get_displacement_with_gnc_warping_steps_repeats_solver_calls(): + """Test optional GNC warping steps repeat the per-level solver calls.""" + fixed = np.zeros((12, 12), dtype=np.float64) + moving = np.zeros((12, 12), dtype=np.float64) + base_seen_betas = [] + warp_seen_betas = [] + + get_displacement( + fixed, + moving, + levels=1, + iterations=2, + update_lag=1, + level_solver_backend=_make_zero_level_solver(base_seen_betas), + gnc_schedule=(0.0, 0.5, 1.0), + warping_steps=1, + ) + get_displacement( + fixed, + moving, + levels=1, + iterations=2, + update_lag=1, + level_solver_backend=_make_zero_level_solver(warp_seen_betas), + gnc_schedule=(0.0, 0.5, 1.0), + warping_steps=3, + ) + + assert _compress_runs(base_seen_betas) == [0.0, 0.5, 1.0] + for beta in (0.0, 0.5, 1.0): + assert warp_seen_betas.count(beta) == 3 * base_seen_betas.count(beta) + + +def test_get_displacement_with_gnc_default_warping_steps_is_ten(): + """Test omitted GNC warping steps use Sun-style ten warps per level.""" + fixed = np.zeros((12, 12), 
+    fixed = np.zeros((12, 12), dtype=np.float64)
+    moving = np.zeros((12, 12), dtype=np.float64)
+    base_seen_betas = []
+    seen_betas = []
+
+    get_displacement(
+        fixed,
+        moving,
+        levels=1,
+        iterations=2,
+        update_lag=1,
+        level_solver_backend=_make_zero_level_solver(base_seen_betas),
+        gnc_schedule=(0.0, 1.0),
+        warping_steps=1,
+    )
+    flow = get_displacement(
+        fixed,
+        moving,
+        levels=1,
+        iterations=2,
+        update_lag=1,
+        level_solver_backend=_make_zero_level_solver(seen_betas),
+        gnc_schedule=(0.0, 1.0),
+    )
+
+    assert flow.shape == (12, 12, 2)
+    assert seen_betas.count(0.0) == 10 * base_seen_betas.count(0.0)
+    assert seen_betas.count(1.0) == 10 * base_seen_betas.count(1.0)
+
+
+def test_get_displacement_with_gnc_rebuilds_tensors_each_warp(monkeypatch):
+    """Test each GNC warp rewarps images and rebuilds motion tensors."""
+    fixed = np.zeros((12, 12, 2), dtype=np.float64)
+    moving = np.zeros((12, 12, 2), dtype=np.float64)
+    warp_calls = []
+    tensor_calls = []
+    seen_betas = []
+
+    def fake_imregister_wrapper(f2_level, u, v, f1_level):
+        warp_calls.append((f2_level.shape, u.shape, v.shape, f1_level.shape))
+        return f2_level
+
+    monkeypatch.setattr(optical_flow, "imregister_wrapper", fake_imregister_wrapper)
+    monkeypatch.setattr(
+        optical_flow,
+        "get_motion_tensor_gc",
+        _make_tensor_stub(tensor_calls),
+    )
+
+    flow = optical_flow.get_displacement(
+        fixed,
+        moving,
+        levels=1,
+        iterations=2,
+        update_lag=1,
+        level_solver_backend=_make_zero_level_solver(seen_betas),
+        gnc_schedule=(0.0, 1.0),
+        warping_steps=2,
+    )
+
+    assert flow.shape == (12, 12, 2)
+    assert len(warp_calls) == len(seen_betas)
+    assert len(tensor_calls) == 2 * len(warp_calls)
+
+
+def test_get_displacement_with_gnc_median_filters_after_each_warp(monkeypatch):
+    """Test GNC applies the median filter after each warp-level solve."""
+    fixed = np.zeros((12, 12), dtype=np.float64)
+    moving = np.zeros((12, 12), dtype=np.float64)
+    median_calls = []
+    seen_betas = []
+
+    def fake_median_filter(arr, size, mode):
+        median_calls.append((arr.shape, size, mode))
+        return arr
+
+    monkeypatch.setattr(optical_flow, "median_filter", fake_median_filter)
+    monkeypatch.setattr(optical_flow, "get_motion_tensor_gc", _make_tensor_stub())
+
+    flow = optical_flow.get_displacement(
+        fixed,
+        moving,
+        levels=1,
+        iterations=2,
+        update_lag=1,
+        level_solver_backend=_make_zero_level_solver(seen_betas),
+        gnc_schedule=(0.0, 1.0),
+        warping_steps=3,
+    )
+
+    assert flow.shape == (12, 12, 2)
+    assert len(median_calls) == 2 * len(seen_betas)
+    assert {(size, mode) for _, size, mode in median_calls} == {((5, 5), "mirror")}
+    assert all(min(shape) > 5 for shape, _, _ in median_calls)
+
+
+def test_get_displacement_with_gnc_carries_flow_between_stages():
+    """Test each GNC stage starts from the previous stage displacement."""
+    fixed = np.zeros((12, 12), dtype=np.float64)
+    moving = np.zeros((12, 12), dtype=np.float64)
+    stage_initial_means = []
+
+    def fake_level_solver(*args):
+        u = args[7]
+        v = args[8]
+        gnc_beta = args[-1]
+        stage_initial_means.append((gnc_beta, float(u[1:-1, 1:-1].mean())))
+        du = np.full_like(u, gnc_beta + 1.0)
+        return du, np.zeros_like(v)
+
+    flow = get_displacement(
+        fixed,
+        moving,
+        levels=1,
+        iterations=2,
+        update_lag=1,
+        level_solver_backend=fake_level_solver,
+        gnc_schedule=(0.0, 1.0),
+        warping_steps=1,
+    )
+
+    assert flow.shape == (12, 12, 2)
+    beta0_means = [mean for beta, mean in stage_initial_means if beta == 0.0]
+    beta1_means = [mean for beta, mean in stage_initial_means if beta == 1.0]
+    assert beta0_means[0] == 0.0
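+    # The stub adds du = gnc_beta + 1.0 on every solve, so stage-0 initial
+    # means grow by one per call and stage 1 starts from the accumulated
+    # stage-0 displacement rather than from zero.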
+    np.testing.assert_allclose(beta0_means[-1], len(beta0_means) - 1, atol=1e-6)
+    np.testing.assert_allclose(beta1_means[0], len(beta0_means), atol=1e-6)
+
+
+def test_level_solver_dispatches_to_default_solver_without_gnc(monkeypatch):
+    """Test the CPU wrapper keeps the default solver branch separate."""
+    shape = (3, 3, 1)
+    field_shape = shape[:2]
+    seen = []
+
+    def fake_default(*args, **kwargs):
+        seen.append(("default", kwargs))
+        return np.zeros((*field_shape, 2), dtype=np.float64)
+
+    def fake_gnc(*args, **kwargs):
+        seen.append(("gnc", kwargs))
+        return np.zeros((*field_shape, 2), dtype=np.float64)
+
+    monkeypatch.setattr(optical_flow, "compute_flow", fake_default)
+    monkeypatch.setattr(optical_flow, "compute_flow_gnc", fake_gnc)
+
+    optical_flow.level_solver(
+        *[np.zeros(shape, dtype=np.float64) for _ in range(6)],
+        np.ones(shape, dtype=np.float64),
+        np.zeros(field_shape, dtype=np.float64),
+        np.zeros(field_shape, dtype=np.float64),
+        (1.0, 1.0),
+        1,
+        1,
+        0,
+        np.array([0.45], dtype=np.float64),
+        1.0,
+        1.0,
+        1.0,
+    )
+
+    assert len(seen) == 1
+    assert seen[0][0] == "default"
+    assert "gnc_beta" not in seen[0][1]
+
+
+def test_level_solver_dispatches_to_gnc_solver_with_beta(monkeypatch):
+    """Test the CPU wrapper only enters the GNC solver when requested."""
+    shape = (3, 3, 1)
+    field_shape = shape[:2]
+    seen = []
+
+    def fake_default(*args, **kwargs):
+        seen.append(("default", kwargs))
+        return np.zeros((*field_shape, 2), dtype=np.float64)
+
+    def fake_gnc(*args, **kwargs):
+        seen.append(("gnc", kwargs))
+        return np.zeros((*field_shape, 2), dtype=np.float64)
+
+    monkeypatch.setattr(optical_flow, "compute_flow", fake_default)
+    monkeypatch.setattr(optical_flow, "compute_flow_gnc", fake_gnc)
+
+    optical_flow.level_solver(
+        *[np.zeros(shape, dtype=np.float64) for _ in range(6)],
+        np.ones(shape, dtype=np.float64),
+        np.zeros(field_shape, dtype=np.float64),
+        np.zeros(field_shape, dtype=np.float64),
+        (1.0, 1.0),
+        1,
+        1,
+        0,
+        np.array([0.45], dtype=np.float64),
+        1.0,
+        1.0,
+        1.0,
+        gnc_beta=0.5,
+    )
+
+    assert len(seen) == 1
+    assert seen[0][0] == "gnc"
+    assert seen[0][1]["gnc_beta"] == 0.5
+
+
+def test_get_motion_tensor_cs_shape_and_zero_border():
+    """Returned tensors match the solver contract."""
+    f1, f2 = _sample_images()
+
+    tensors = get_motion_tensor_cs(f1, f2, hy=1.0, hx=1.0)
+
+    assert len(tensors) == 6
+    for tensor in tensors:
+        assert tensor.shape == (f1.shape[0] + 2, f1.shape[1] + 2)
+    _assert_zero_border(tensors)
+
+
+def test_get_motion_tensor_cs_constant_images_near_zero():
+    """Constant frames should not create census tensor energy."""
+    f1 = np.full((7, 6), 0.25, dtype=np.float64)
+    f2 = np.full((7, 6), 0.75, dtype=np.float64)
+
+    tensors = get_motion_tensor_cs(f1, f2, hy=1.0, hx=1.0)
+
+    for tensor in tensors:
+        np.testing.assert_allclose(tensor, 0.0, atol=1e-14)
+
+
+def test_get_motion_tensor_cs_additive_shift_invariance():
+    """Neighbor-center differences should cancel common additive offsets."""
+    f1, f2 = _sample_images()
+
+    base = get_motion_tensor_cs(f1, f2, hy=1.0, hx=1.0)
+    shifted = get_motion_tensor_cs(f1 + 4.0, f2 + 4.0, hy=1.0, hx=1.0)
+
+    for base_tensor, shifted_tensor in zip(base, shifted):
+        np.testing.assert_allclose(shifted_tensor, base_tensor, atol=1e-12)
+
+
+def test_get_motion_tensor_cs_default_eps_matches_normalized_convention():
+    """Default epsilon should be the normalized 0.1 / 255.0 value."""
+    f1, f2 = _sample_images()
+
+    default = get_motion_tensor_cs(f1, f2, hy=1.0, hx=1.0)
+    explicit = get_motion_tensor_cs(f1, f2, hy=1.0, hx=1.0, eps=0.1 / 255.0)
+
+    for default_tensor, explicit_tensor in zip(default, explicit):
+        np.testing.assert_allclose(default_tensor, explicit_tensor)
+
+
+def test_get_motion_tensor_cs_does_not_use_np_roll():
+    """Neighbor access should not use cyclic shifts."""
+    source = inspect.getsource(get_motion_tensor_cs)
+
+    assert "np.roll" not in source
+
+
+def test_get_motion_tensor_cs_anisotropic_spacing():
+    """Anisotropic spacing should preserve tensor shape and zero borders."""
+    f1, f2 = _sample_images(shape=(6, 10))
+
+    tensors = get_motion_tensor_cs(f1, f2, hy=0.5, hx=2.0)
+
+    for tensor in tensors:
+        assert tensor.shape == (f1.shape[0] + 2, f1.shape[1] + 2)
+        assert np.all(np.isfinite(tensor))
+    _assert_zero_border(tensors)
diff --git a/tests/core/test_optical_flow_dataterm.py b/tests/core/test_optical_flow_dataterm.py
new file mode 100644
index 0000000..797770c
--- /dev/null
+++ b/tests/core/test_optical_flow_dataterm.py
@@ -0,0 +1,128 @@
+"""Tests for optical-flow data term dispatch."""
+
+import numpy as np
+import pytest
+
+import pyflowreg.core.optical_flow as optical_flow
+
+
+def _zero_level_solver(
+    J11,
+    J22,
+    J33,
+    J12,
+    J13,
+    J23,
+    weight,
+    u,
+    v,
+    alpha,
+    iterations,
+    update_lag,
+    verbose,
+    a_data,
+    a_smooth,
+    hx,
+    hy,
+    gnc_beta=None,
+):
+    """Return no correction so tests isolate motion tensor dispatch."""
+    return np.zeros_like(u), np.zeros_like(v)
+
+
+def _make_motion_tensor(name, calls):
+    def motion_tensor(f1, f2, hy, hx):
+        calls.append(name)
+        shape = (f1.shape[0] + 2, f1.shape[1] + 2)
+        return tuple(np.zeros(shape, dtype=np.float64) for _ in range(6))
+
+    return motion_tensor
+
+
+@pytest.mark.parametrize(
+    ("value", "expected"),
+    [
+        (None, "gc"),
+        ("gc", "gc"),
+        ("gradient", "gc"),
+        ("gray", "gray"),
+        ("brightness", "gray"),
+        ("cs", "cs"),
+        ("census", "cs"),
+    ],
+)
+def test_get_displacement_dispatches_constancy_assumption(monkeypatch, value, expected):
+    """get_displacement should route each selector to its data term."""
+    calls = []
+    monkeypatch.setattr(
+        optical_flow, "get_motion_tensor_gc", _make_motion_tensor("gc", calls)
+    )
+    monkeypatch.setattr(
+        optical_flow, "get_motion_tensor_gray", _make_motion_tensor("gray", calls)
+    )
+    monkeypatch.setattr(
+        optical_flow, "get_motion_tensor_cs", _make_motion_tensor("cs", calls)
+    )
+
+    fixed = np.zeros((16, 16), dtype=np.float64)
+    moving = np.zeros((16, 16), dtype=np.float64)
+    kwargs = {} if value is None else {"const_assumption": value}
+
+    optical_flow.get_displacement(
+        fixed,
+        moving,
+        levels=1,
+        min_level=0,
+        iterations=1,
+        level_solver_backend=_zero_level_solver,
+        **kwargs,
+    )
+
+    assert calls
+    assert set(calls) == {expected}
+
+
+def test_get_displacement_rejects_unknown_constancy_assumption():
+    """Unknown data terms should fail before any solver work starts."""
+    fixed = np.zeros((16, 16), dtype=np.float64)
+    moving = np.zeros((16, 16), dtype=np.float64)
+
+    with pytest.raises(ValueError, match="Unknown constancy assumption"):
+        optical_flow.get_displacement(
+            fixed,
+            moving,
+            const_assumption="invalid",
+            level_solver_backend=_zero_level_solver,
+        )
+
+
+def test_get_displacement_with_gnc_dispatches_constancy_assumption(monkeypatch):
+    """GNC should rebuild tensors with the selected data term."""
+    calls = []
+    monkeypatch.setattr(
+        optical_flow, "get_motion_tensor_gc", _make_motion_tensor("gc", calls)
+    )
+    monkeypatch.setattr(
+        optical_flow, "get_motion_tensor_gray", _make_motion_tensor("gray", calls)
+    )
"get_motion_tensor_cs", _make_motion_tensor("cs", calls) + ) + + fixed = np.zeros((16, 16), dtype=np.float64) + moving = np.zeros((16, 16), dtype=np.float64) + + optical_flow.get_displacement( + fixed, + moving, + levels=1, + min_level=0, + iterations=1, + level_solver_backend=_zero_level_solver, + const_assumption="census", + gnc_schedule=(0.0, 1.0), + warping_steps=1, + ) + + assert calls + assert set(calls) == {"cs"} diff --git a/tests/motion_correction/test_OF_options.py b/tests/motion_correction/test_OF_options.py index 7efa4d1..8294e29 100644 --- a/tests/motion_correction/test_OF_options.py +++ b/tests/motion_correction/test_OF_options.py @@ -8,6 +8,7 @@ import numpy as np from pyflowreg.motion_correction.OF_options import ( + ConstancyAssumption, OFOptions, QualitySetting, ) @@ -162,6 +163,42 @@ def test_sigma_wrong_length_raises(self): OFOptions(sigma=[1.0, 1.0]) # Missing temporal component +class TestGNCScheduleValidation: + """Test GNC schedule validation and normalization.""" + + def test_gnc_schedule_normalized_to_tuple(self): + """Test valid GNC schedules are normalized to tuples.""" + opts = OFOptions(gnc_schedule=[0.0, 0.5, 1.0]) + + assert opts.gnc_schedule == (0.0, 0.5, 1.0) + + @pytest.mark.parametrize( + "schedule,match", + [ + ([0.5, 1.0], "must start at 0.0"), + ([0.0, 0.5], "must end at 1.0"), + ([0.0], "at least two stages"), + ([0.0, 0.75, 0.5, 1.0], "monotone nondecreasing"), + ([-0.1, 1.0], "must lie in \\[0, 1\\]"), + ], + ) + def test_gnc_schedule_invalid(self, schedule, match): + """Test invalid GNC schedules raise validation errors.""" + with pytest.raises(ValueError, match=match): + OFOptions(gnc_schedule=schedule) + + def test_warping_steps_positive_integer(self): + """Test valid warping_steps values are preserved.""" + opts = OFOptions(warping_steps=10) + + assert opts.warping_steps == 10 + + def test_warping_steps_invalid(self): + """Test invalid warping_steps values raise validation errors.""" + with pytest.raises(ValueError, match="greater than or equal to 1"): + OFOptions(warping_steps=0) + + class TestQualitySettingEffectiveMinLevel: """Test quality setting affects effective_min_level.""" @@ -186,6 +223,42 @@ def test_custom_min_level_override(self): assert opts.effective_min_level == 8 +class TestConstancyAssumption: + """Test optical-flow data term configuration.""" + + @pytest.mark.parametrize( + ("value", "expected"), + [ + ("gc", ConstancyAssumption.GRADIENT), + ("gradient", ConstancyAssumption.GRADIENT), + ("gray", ConstancyAssumption.GRAY), + ("brightness", ConstancyAssumption.GRAY), + ("cs", ConstancyAssumption.CENSUS), + ("census", ConstancyAssumption.CENSUS), + ], + ) + def test_constancy_assumption_aliases(self, value, expected): + """Aliases normalize to the enum value used by get_displacement.""" + opts = OFOptions(constancy_assumption=value) + + assert opts.constancy_assumption == expected + assert opts.to_dict()["const_assumption"] == expected.value + + def test_diso_rejects_non_default_constancy_assumption(self): + """DISO backend should not silently accept flowreg-only data terms.""" + opts = OFOptions(flow_backend="diso", constancy_assumption="census") + + with pytest.raises(ValueError, match="does not support"): + opts.resolve_get_displacement() + + def test_to_dict_normalizes_assignment_alias(self): + """Assignment-time aliases should serialize to backend API values.""" + opts = OFOptions() + opts.constancy_assumption = "census" + + assert opts.to_dict()["const_assumption"] == "cs" + + class TestGetWeightAt: """Test get_weight_at 
method.""" @@ -290,6 +363,22 @@ def test_to_dict_with_spatial_weights(self): assert isinstance(params["weight"], np.ndarray) assert params["weight"].shape == (H, W, C) + def test_to_dict_includes_gnc_schedule(self): + """Test to_dict forwards the optional GNC schedule.""" + opts = OFOptions(gnc_schedule=(0.0, 0.5, 1.0)) + + params = opts.to_dict() + + assert params["gnc_schedule"] == (0.0, 0.5, 1.0) + + def test_to_dict_includes_warping_steps(self): + """Test to_dict forwards the optional GNC warping step count.""" + opts = OFOptions(warping_steps=10) + + params = opts.to_dict() + + assert params["warping_steps"] == 10 + class TestExampleConfigurations: """Test that all example configurations can be created without errors.""" diff --git a/tests/motion_correction/test_compensate_recording.py b/tests/motion_correction/test_compensate_recording.py index 52e89e2..b2fd13d 100644 --- a/tests/motion_correction/test_compensate_recording.py +++ b/tests/motion_correction/test_compensate_recording.py @@ -13,7 +13,7 @@ RegistrationConfig, compensate_recording, ) -from pyflowreg.motion_correction.OF_options import OutputFormat +from pyflowreg.motion_correction.OF_options import OFOptions, OutputFormat from pyflowreg._runtime import RuntimeContext from pyflowreg.util.io.factory import get_video_file_reader @@ -111,6 +111,26 @@ def test_initialization_with_basic_options(self, basic_of_options): assert len(pipeline.mean_disp) == 0 assert len(pipeline.max_disp) == 0 + def test_flow_params_include_constancy_assumption(self, tmp_path): + """Batch flow calls should receive the configured data term.""" + options = OFOptions( + output_path=tmp_path, + quality_setting="fast", + constancy_assumption="census", + levels=1, + iterations=2, + ) + config = RegistrationConfig( + n_jobs=1, verbose=True, parallelization="sequential" + ) + pipeline = BatchMotionCorrector(options, config) + pipeline.weight = np.ones((8, 8, 1), dtype=np.float64) + + flow_params = pipeline._get_flow_params() + + assert flow_params["const_assumption"] == "cs" + assert flow_params["weight"] is pipeline.weight + class TestExecutorTypes: """Test different executor types work correctly.""" diff --git a/tests/session/test_config.py b/tests/session/test_config.py index 8ed16fe..ee9370b 100644 --- a/tests/session/test_config.py +++ b/tests/session/test_config.py @@ -320,6 +320,7 @@ def test_default_stage2_parameters(self, tmp_path): assert config.sigma_smooth == 6.0 assert config.alpha_between == 25.0 assert config.iterations_between == 100 + assert config.stage2_constancy_assumption == "gc" def test_custom_stage2_parameters(self, tmp_path): """Test setting custom Stage 2 parameters.""" @@ -329,12 +330,14 @@ def test_custom_stage2_parameters(self, tmp_path): sigma_smooth=4.5, alpha_between=20.0, iterations_between=150, + stage2_constancy_assumption="census", ) assert config.cc_upsample == 8 assert config.sigma_smooth == 4.5 assert config.alpha_between == 20.0 assert config.iterations_between == 150 + assert config.stage2_constancy_assumption == "census" class TestSessionConfigBackendParameters: diff --git a/tests/session/test_stage2_between_avgs.py b/tests/session/test_stage2_between_avgs.py new file mode 100644 index 0000000..4f4baf9 --- /dev/null +++ b/tests/session/test_stage2_between_avgs.py @@ -0,0 +1,77 @@ +"""Tests for Stage 2 inter-sequence optical-flow configuration.""" + +import numpy as np +import pytest + +from pyflowreg.session.config import SessionConfig +import pyflowreg.session.stage2_between_avgs as stage2 + + +@pytest.fixture +def 
+def stage2_lightweight_ops(monkeypatch):
+    """Patch expensive image operations so tests isolate configuration wiring."""
+    monkeypatch.setattr(stage2, "gaussian_filter", lambda image, sigma: image)
+    monkeypatch.setattr(
+        stage2,
+        "estimate_rigid_xcorr_2d",
+        lambda img1, img2, up: (0, 0),
+    )
+    monkeypatch.setattr(
+        stage2,
+        "imregister_wrapper",
+        lambda img, u, v, ref, interpolation_method="cubic": img,
+    )
+
+
+def test_compute_between_displacement_passes_stage2_constancy(
+    tmp_path, monkeypatch, stage2_lightweight_ops
+):
+    """Stage 2 should pass its data term selector into get_displacement."""
+    captured = {}
+
+    def fake_get_backend(name):
+        assert name == "flowreg"
+
+        def factory(**backend_params):
+            assert backend_params == {"sentinel": True}
+
+            def get_displacement(fixed, moving, **kwargs):
+                captured.update(kwargs)
+                return np.zeros((*fixed.shape, 2), dtype=np.float32)
+
+            return get_displacement
+
+        return factory
+
+    monkeypatch.setattr(stage2, "get_backend", fake_get_backend)
+
+    config = SessionConfig(
+        root=tmp_path,
+        backend_params={"sentinel": True},
+        stage2_constancy_assumption="census",
+    )
+    reference_avg = np.arange(16, dtype=np.float64).reshape(4, 4)
+    current_avg = reference_avg + 1
+
+    w = stage2.compute_between_displacement(reference_avg, current_avg, config)
+
+    assert captured["const_assumption"] == "cs"
+    assert captured["alpha"] == (25.0, 25.0)
+    assert captured["iterations"] == 100
+    assert w.shape == (4, 4, 2)
+
+
+def test_compute_between_displacement_rejects_diso_non_default_constancy(
+    tmp_path, stage2_lightweight_ops
+):
+    """DISO should fail explicitly for flowreg-only data terms."""
+    config = SessionConfig(
+        root=tmp_path,
+        flow_backend="diso",
+        stage2_constancy_assumption="gray",
+    )
+    reference_avg = np.arange(16, dtype=np.float64).reshape(4, 4)
+    current_avg = reference_avg + 1
+
+    with pytest.raises(ValueError, match="does not support"):
+        stage2.compute_between_displacement(reference_avg, current_avg, config)