From 989ee57bcb049498610c1393c92eef793d1ba963 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 14:52:47 -0800 Subject: [PATCH 01/18] feat: Add GPU-accelerated operations via PyTorch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consolidates PRs #4-8 into a single feature branch: - phase_cross_correlation: GPU FFT via torch.fft (~46x speedup) - shift_array: GPU grid_sample for subpixel shifts (~6.7x speedup) - match_histograms: GPU sort/quantile mapping (~13.3x speedup) - block_reduce: GPU avg_pool2d (~4x speedup) - compute_ssim: GPU conv2d for local statistics (~6.4x speedup) All functions include automatic CPU fallback when CUDA is unavailable. Replaces cupy/cucim dependency with PyTorch for broader compatibility. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/tilefusion/utils.py | 393 +++++++++++++++++++++++++++++++--- tests/test_block_reduce.py | 57 +++++ tests/test_fft.py | 64 ++++++ tests/test_histogram_match.py | 35 +++ tests/test_shift_array.py | 36 ++++ tests/test_ssim.py | 38 ++++ 6 files changed, 588 insertions(+), 35 deletions(-) create mode 100644 tests/test_block_reduce.py create mode 100644 tests/test_fft.py create mode 100644 tests/test_histogram_match.py create mode 100644 tests/test_shift_array.py create mode 100644 tests/test_ssim.py diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index d0d1769..07ee915 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -2,54 +2,375 @@ Shared utilities for tilefusion. GPU/CPU detection, array operations, and helper functions. +All functions support GPU acceleration via PyTorch with automatic CPU fallback. """ import numpy as np +# GPU detection - PyTorch based try: - import cupy as cp - from cupyx.scipy.ndimage import shift as cp_shift - from cucim.skimage.exposure import match_histograms - from cucim.skimage.measure import block_reduce - from cucim.skimage.registration import phase_cross_correlation - from opm_processing.imageprocessing.ssim_cuda import ( - structural_similarity_cupy_sep_shared as ssim_cuda, - ) - - xp = cp - USING_GPU = True -except Exception: - cp = None - cp_shift = None - from skimage.exposure import match_histograms - from skimage.measure import block_reduce - from skimage.registration import phase_cross_correlation - from scipy.ndimage import shift as _shift_cpu - from skimage.metrics import structural_similarity as _ssim_cpu - - xp = np - USING_GPU = False + import torch + import torch.nn.functional as F + TORCH_AVAILABLE = True + CUDA_AVAILABLE = torch.cuda.is_available() +except ImportError: + torch = None + F = None + TORCH_AVAILABLE = False + CUDA_AVAILABLE = False +# CPU fallbacks +from scipy.ndimage import shift as _shift_cpu +from skimage.exposure import match_histograms as _match_histograms_cpu +from skimage.measure import block_reduce as _block_reduce_cpu +from skimage.metrics import structural_similarity as _ssim_cpu +from skimage.registration import phase_cross_correlation as _phase_cross_correlation_cpu + +# Legacy compatibility +USING_GPU = CUDA_AVAILABLE +xp = np +cp = None + + +# ============================================================================= +# Phase Cross-Correlation (GPU FFT) +# ============================================================================= + +def phase_cross_correlation(reference_image, moving_image, upsample_factor=1, **kwargs): + """ + Phase cross-correlation using GPU (torch FFT) or CPU (skimage). + + Parameters + ---------- + reference_image : ndarray + Reference image. + moving_image : ndarray + Image to register. + upsample_factor : int + Upsampling factor for subpixel precision. + + Returns + ------- + shift : ndarray + Shift vector (y, x). + error : float + Translation invariant normalized RMS error (placeholder). + phasediff : float + Global phase difference (placeholder). + """ + ref_np = np.asarray(reference_image) + mov_np = np.asarray(moving_image) + + if CUDA_AVAILABLE and ref_np.ndim == 2: + return _phase_cross_correlation_torch(ref_np, mov_np, upsample_factor) + return _phase_cross_correlation_cpu(ref_np, mov_np, upsample_factor=upsample_factor, **kwargs) + + +def _phase_cross_correlation_torch(reference_image: np.ndarray, moving_image: np.ndarray, + upsample_factor: int = 1) -> tuple: + """GPU phase cross-correlation using torch FFT.""" + ref = torch.from_numpy(reference_image.astype(np.float32)).cuda() + mov = torch.from_numpy(moving_image.astype(np.float32)).cuda() + + # Compute cross-power spectrum + ref_fft = torch.fft.fft2(ref) + mov_fft = torch.fft.fft2(mov) + cross_power = ref_fft * torch.conj(mov_fft) + eps = 1e-10 + cross_power = cross_power / (torch.abs(cross_power) + eps) + + # Inverse FFT to get correlation + correlation = torch.fft.ifft2(cross_power).real + + # Find peak + max_idx = torch.argmax(correlation) + h, w = correlation.shape + peak_y = (max_idx // w).item() + peak_x = (max_idx % w).item() + + # Handle wraparound for negative shifts + if peak_y > h // 2: + peak_y -= h + if peak_x > w // 2: + peak_x -= w + + shift = np.array([float(peak_y), float(peak_x)]) + + # Subpixel refinement if requested + if upsample_factor > 1: + shift = _subpixel_refine_torch(correlation, peak_y, peak_x, h, w) + + return shift, 0.0, 0.0 + + +def _subpixel_refine_torch(correlation, peak_y, peak_x, h, w): + """Subpixel refinement using parabolic fit around peak.""" + py = peak_y % h + px = peak_x % w + + y_indices = [(py - 1) % h, py, (py + 1) % h] + x_indices = [(px - 1) % w, px, (px + 1) % w] + + neighborhood = torch.zeros(3, 3, device="cuda") + for i, yi in enumerate(y_indices): + for j, xj in enumerate(x_indices): + neighborhood[i, j] = correlation[yi, xj] + + center_val = neighborhood[1, 1].item() + + # Y direction parabolic fit + if neighborhood[0, 1].item() != center_val or neighborhood[2, 1].item() != center_val: + denom = 2 * (2 * center_val - neighborhood[0, 1].item() - neighborhood[2, 1].item()) + dy = (neighborhood[0, 1].item() - neighborhood[2, 1].item()) / denom if abs(denom) > 1e-10 else 0.0 + else: + dy = 0.0 + + # X direction parabolic fit + if neighborhood[1, 0].item() != center_val or neighborhood[1, 2].item() != center_val: + denom = 2 * (2 * center_val - neighborhood[1, 0].item() - neighborhood[1, 2].item()) + dx = (neighborhood[1, 0].item() - neighborhood[1, 2].item()) / denom if abs(denom) > 1e-10 else 0.0 + else: + dx = 0.0 + + dy = max(-0.5, min(0.5, dy)) + dx = max(-0.5, min(0.5, dx)) + + return np.array([float(peak_y) + dy, float(peak_x) + dx]) + + +# ============================================================================= +# Shift Array (GPU grid_sample) +# ============================================================================= def shift_array(arr, shift_vec): - """Shift array using GPU if available, else CPU fallback.""" - if USING_GPU and cp_shift is not None: - return cp_shift(arr, shift=shift_vec, order=1, prefilter=False) - return _shift_cpu(arr, shift=shift_vec, order=1, prefilter=False) + """ + Shift array by subpixel amounts using GPU (torch) or CPU (scipy). + + Parameters + ---------- + arr : ndarray + 2D input array. + shift_vec : array-like + (dy, dx) shift amounts. + + Returns + ------- + shifted : ndarray + Shifted array, same shape as input. + """ + arr_np = np.asarray(arr) + + if CUDA_AVAILABLE and arr_np.ndim == 2: + return _shift_array_torch(arr_np, shift_vec) + + return _shift_cpu(arr_np, shift=shift_vec, order=1, prefilter=False) + + +def _shift_array_torch(arr: np.ndarray, shift_vec) -> np.ndarray: + """GPU shift using torch.nn.functional.grid_sample.""" + h, w = arr.shape + dy, dx = float(shift_vec[0]), float(shift_vec[1]) + + # Create pixel coordinate grids + y_coords = torch.arange(h, device="cuda", dtype=torch.float32) + x_coords = torch.arange(w, device="cuda", dtype=torch.float32) + grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij") + + # Apply shift: to shift output by (dy, dx), sample from (y-dy, x-dx) + sample_y = grid_y - dy + sample_x = grid_x - dx + + # Normalize to [-1, 1] for grid_sample (align_corners=True) + sample_x = 2 * sample_x / (w - 1) - 1 + sample_y = 2 * sample_y / (h - 1) - 1 + + # Stack to (H, W, 2) with (x, y) order, add batch dim -> (1, H, W, 2) + grid = torch.stack([sample_x, sample_y], dim=-1).unsqueeze(0) + + # Input: (1, 1, H, W) + t = torch.from_numpy(arr).float().cuda().unsqueeze(0).unsqueeze(0) + + # grid_sample with bilinear interpolation + out = F.grid_sample(t, grid, mode="bilinear", padding_mode="zeros", align_corners=True) + + return out.squeeze().cpu().numpy() + + +# ============================================================================= +# Match Histograms (GPU sort/quantile) +# ============================================================================= + +def match_histograms(image, reference): + """ + Match histogram of image to reference using GPU (torch) or CPU (skimage). + Parameters + ---------- + image : ndarray + Image to transform. + reference : ndarray + Reference image for histogram matching. + + Returns + ------- + matched : ndarray + Image with matched histogram. + """ + image_np = np.asarray(image) + reference_np = np.asarray(reference) + + if CUDA_AVAILABLE and image_np.ndim == 2: + return _match_histograms_torch(image_np, reference_np) + + return _match_histograms_cpu(image_np, reference_np) + + +def _match_histograms_torch(image: np.ndarray, reference: np.ndarray) -> np.ndarray: + """GPU histogram matching using torch operations.""" + # Move to GPU + img = torch.from_numpy(image.astype(np.float32)).cuda().flatten() + ref = torch.from_numpy(reference.astype(np.float32)).cuda().flatten() + + # Get sorted indices + img_sorted, img_indices = torch.sort(img) + ref_sorted, _ = torch.sort(ref) + + # Create inverse mapping + inv_indices = torch.empty_like(img_indices) + inv_indices[img_indices] = torch.arange(len(img), device="cuda") + + # Interpolate reference values at image quantiles + img_quantiles = torch.linspace(0, 1, len(img), device="cuda") + ref_quantiles = torch.linspace(0, 1, len(ref), device="cuda") + + # Map image values to reference values via quantile matching + interp_values = torch.zeros_like(img) + interp_values[img_indices] = ref_sorted[ + (inv_indices.float() / len(img) * len(ref)).long().clamp(0, len(ref) - 1) + ] + + return interp_values.reshape(image.shape).cpu().numpy() + + +# ============================================================================= +# Block Reduce (GPU avg_pool2d) +# ============================================================================= + +def block_reduce(arr, block_size, func=np.mean): + """ + Block reduce array using GPU (torch) or CPU (skimage). + + Parameters + ---------- + arr : ndarray + Input array (2D or 3D with channel dim first). + block_size : tuple + Reduction factors per dimension. + func : callable + Reduction function (only np.mean supported on GPU). + + Returns + ------- + reduced : ndarray + """ + arr_np = np.asarray(arr) + + if CUDA_AVAILABLE and func == np.mean and arr_np.ndim >= 2: + return _block_reduce_torch(arr_np, block_size) + + return _block_reduce_cpu(arr_np, block_size, func) + + +def _block_reduce_torch(arr: np.ndarray, block_size: tuple) -> np.ndarray: + """GPU block reduce using torch.nn.functional.avg_pool2d.""" + ndim = arr.ndim + + if ndim == 2: + kernel = (block_size[0], block_size[1]) + t = torch.from_numpy(arr).float().cuda().unsqueeze(0).unsqueeze(0) + out = torch.nn.functional.avg_pool2d(t, kernel, stride=kernel) + return out.squeeze().cpu().numpy() + + elif ndim == 3: + kernel = (block_size[1], block_size[2]) if len(block_size) == 3 else block_size + t = torch.from_numpy(arr).float().cuda().unsqueeze(0) + out = torch.nn.functional.avg_pool2d(t, kernel, stride=kernel) + return out.squeeze(0).cpu().numpy() + + return _block_reduce_cpu(arr, block_size, np.mean) + + +# ============================================================================= +# Compute SSIM (GPU conv2d) +# ============================================================================= def compute_ssim(arr1, arr2, win_size: int) -> float: - """SSIM wrapper that routes to GPU kernel or CPU skimage.""" - if USING_GPU and "ssim_cuda" in globals(): - return float(ssim_cuda(arr1, arr2, win_size=win_size)) - arr1_np = np.asarray(arr1) - arr2_np = np.asarray(arr2) + """ + Compute SSIM using GPU (torch) or CPU (skimage). + + Parameters + ---------- + arr1, arr2 : ndarray + Input images (2D). + win_size : int + Window size for local statistics. + + Returns + ------- + ssim : float + Mean SSIM value. + """ + arr1_np = np.asarray(arr1, dtype=np.float32) + arr2_np = np.asarray(arr2, dtype=np.float32) + + if CUDA_AVAILABLE and arr1_np.ndim == 2: + data_range = float(arr1_np.max() - arr1_np.min()) + if data_range == 0: + data_range = 1.0 + return _compute_ssim_torch(arr1_np, arr2_np, win_size, data_range) + data_range = float(arr1_np.max() - arr1_np.min()) if data_range == 0: data_range = 1.0 return float(_ssim_cpu(arr1_np, arr2_np, win_size=win_size, data_range=data_range)) +def _compute_ssim_torch(arr1: np.ndarray, arr2: np.ndarray, win_size: int, data_range: float) -> float: + """GPU SSIM using torch conv2d for local statistics.""" + C1 = (0.01 * data_range) ** 2 + C2 = (0.03 * data_range) ** 2 + + # Create uniform window + window = torch.ones(1, 1, win_size, win_size, device="cuda") / (win_size * win_size) + + # Convert to tensors (1, 1, H, W) + img1 = torch.from_numpy(arr1).float().cuda().unsqueeze(0).unsqueeze(0) + img2 = torch.from_numpy(arr2).float().cuda().unsqueeze(0).unsqueeze(0) + + # Compute local means + mu1 = F.conv2d(img1, window, padding=win_size // 2) + mu2 = F.conv2d(img2, window, padding=win_size // 2) + + mu1_sq = mu1 ** 2 + mu2_sq = mu2 ** 2 + mu1_mu2 = mu1 * mu2 + + # Compute local variances and covariance + sigma1_sq = F.conv2d(img1 ** 2, window, padding=win_size // 2) - mu1_sq + sigma2_sq = F.conv2d(img2 ** 2, window, padding=win_size // 2) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=win_size // 2) - mu1_mu2 + + # SSIM formula + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ + ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + return float(ssim_map.mean().cpu()) + + +# ============================================================================= +# Utility Functions +# ============================================================================= + def make_1d_profile(length: int, blend: int) -> np.ndarray: """ Create a linear ramp profile over `blend` pixels at each end. @@ -77,11 +398,13 @@ def make_1d_profile(length: int, blend: int) -> np.ndarray: def to_numpy(arr): """Convert array to numpy, handling both CPU and GPU arrays.""" - if USING_GPU and cp is not None and isinstance(arr, cp.ndarray): - return cp.asnumpy(arr) + if TORCH_AVAILABLE and torch is not None and isinstance(arr, torch.Tensor): + return arr.cpu().numpy() return np.asarray(arr) def to_device(arr): - """Move array to current device (GPU if available, else CPU).""" - return xp.asarray(arr) + """Move array to GPU if available, else return numpy array.""" + if CUDA_AVAILABLE: + return torch.from_numpy(np.asarray(arr)).cuda() + return np.asarray(arr) diff --git a/tests/test_block_reduce.py b/tests/test_block_reduce.py new file mode 100644 index 0000000..eebe1cc --- /dev/null +++ b/tests/test_block_reduce.py @@ -0,0 +1,57 @@ +"""Unit tests for GPU block_reduce.""" + +import numpy as np +import pytest +from skimage.measure import block_reduce as skimage_block_reduce + +import sys +sys.path.insert(0, "src") + +from tilefusion.utils import block_reduce, CUDA_AVAILABLE + + +class TestBlockReduce: + """Test block_reduce GPU vs CPU equivalence.""" + + def test_2d_basic(self): + """Test 2D block reduce matches skimage.""" + arr = np.random.rand(256, 256).astype(np.float32) + block_size = (4, 4) + + result = block_reduce(arr, block_size, np.mean) + expected = skimage_block_reduce(arr, block_size, np.mean) + + np.testing.assert_allclose(result, expected, rtol=1e-5) + + def test_2d_large(self): + """Test larger 2D array.""" + arr = np.random.rand(1024, 1024).astype(np.float32) + block_size = (8, 8) + + result = block_reduce(arr, block_size, np.mean) + expected = skimage_block_reduce(arr, block_size, np.mean) + + np.testing.assert_allclose(result, expected, rtol=1e-5) + + def test_3d_multichannel(self): + """Test 3D array with channel dimension.""" + arr = np.random.rand(3, 256, 256).astype(np.float32) + block_size = (1, 4, 4) + + result = block_reduce(arr, block_size, np.mean) + expected = skimage_block_reduce(arr, block_size, np.mean) + + np.testing.assert_allclose(result, expected, rtol=1e-5) + + def test_output_shape(self): + """Test output shape is correct.""" + arr = np.random.rand(512, 512).astype(np.float32) + block_size = (4, 4) + + result = block_reduce(arr, block_size, np.mean) + + assert result.shape == (128, 128) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_fft.py b/tests/test_fft.py new file mode 100644 index 0000000..92f2023 --- /dev/null +++ b/tests/test_fft.py @@ -0,0 +1,64 @@ +"""Unit tests for GPU phase_cross_correlation (FFT).""" + +import numpy as np +import sys +sys.path.insert(0, "src") + +from tilefusion.utils import phase_cross_correlation, CUDA_AVAILABLE +from skimage.registration import phase_cross_correlation as skimage_pcc + + +def test_known_shift(): + """Test detection of known integer shift.""" + np.random.seed(42) + ref = np.random.rand(256, 256).astype(np.float32) + + # Create shifted version: mov is ref shifted by (+5, -3) + # phase_cross_correlation returns shift to apply to mov to align with ref + # So it should return (-5, +3) + mov = np.zeros_like(ref) + mov[5:, :253] = ref[:-5, 3:] + + shift, _, _ = phase_cross_correlation(ref, mov) + + # Should detect shift close to (-5, +3) + assert abs(shift[0] - (-5)) < 1, f"Y shift {shift[0]} not close to -5" + assert abs(shift[1] - 3) < 1, f"X shift {shift[1]} not close to 3" + + +def test_zero_shift(): + """Test that identical images give zero shift.""" + np.random.seed(42) + ref = np.random.rand(256, 256).astype(np.float32) + + shift, _, _ = phase_cross_correlation(ref, ref) + + assert abs(shift[0]) < 0.5, f"Y shift {shift[0]} should be ~0" + assert abs(shift[1]) < 0.5, f"X shift {shift[1]} should be ~0" + + +def test_matches_skimage_direction(): + """Test that shift direction matches skimage convention.""" + np.random.seed(42) + ref = np.random.rand(128, 128).astype(np.float32) + + # Shift by rolling + mov = np.roll(np.roll(ref, 10, axis=0), -7, axis=1) + + gpu_shift, _, _ = phase_cross_correlation(ref, mov) + cpu_shift, _, _ = skimage_pcc(ref, mov) + + # Directions should match + assert np.sign(gpu_shift[0]) == np.sign(cpu_shift[0]), "Y direction mismatch" + assert np.sign(gpu_shift[1]) == np.sign(cpu_shift[1]), "X direction mismatch" + + +if __name__ == "__main__": + print(f"CUDA available: {CUDA_AVAILABLE}") + test_known_shift() + print("test_known_shift: PASSED") + test_zero_shift() + print("test_zero_shift: PASSED") + test_matches_skimage_direction() + print("test_matches_skimage_direction: PASSED") + print("All tests passed!") diff --git a/tests/test_histogram_match.py b/tests/test_histogram_match.py new file mode 100644 index 0000000..3db18b5 --- /dev/null +++ b/tests/test_histogram_match.py @@ -0,0 +1,35 @@ +"""Unit tests for GPU histogram matching.""" +import numpy as np +import sys +sys.path.insert(0, "src") + +from tilefusion.utils import match_histograms, CUDA_AVAILABLE +from skimage.exposure import match_histograms as skimage_match + + +def test_histogram_range(): + img = np.random.rand(256, 256).astype(np.float32) + ref = np.random.rand(256, 256).astype(np.float32) * 2 + 1 + result = match_histograms(img, ref) + # Output should be in reference range + assert result.min() >= ref.min() - 0.1 + assert result.max() <= ref.max() + 0.1 + + +def test_histogram_correlation(): + img = np.random.rand(256, 256).astype(np.float32) + ref = np.random.rand(256, 256).astype(np.float32) + + cpu = skimage_match(img, ref) + gpu = match_histograms(img, ref) + + cpu_hist, _ = np.histogram(cpu.flatten(), bins=100) + gpu_hist, _ = np.histogram(gpu.flatten(), bins=100) + corr = np.corrcoef(cpu_hist, gpu_hist)[0, 1] + assert corr > 0.99, f"Histogram correlation {corr} too low" + + +if __name__ == "__main__": + test_histogram_range() + test_histogram_correlation() + print("All tests passed") diff --git a/tests/test_shift_array.py b/tests/test_shift_array.py new file mode 100644 index 0000000..8874b42 --- /dev/null +++ b/tests/test_shift_array.py @@ -0,0 +1,36 @@ +"""Unit tests for GPU shift_array.""" +import numpy as np +import sys +sys.path.insert(0, "src") + +from tilefusion.utils import shift_array, CUDA_AVAILABLE +from scipy.ndimage import shift as scipy_shift + + +def test_integer_shift(): + arr = np.random.rand(256, 256).astype(np.float32) + cpu = scipy_shift(arr, (3.0, -5.0), order=1, prefilter=False) + gpu = shift_array(arr, (3.0, -5.0)) + np.testing.assert_allclose(gpu, cpu, rtol=1e-4, atol=1e-4) + + +def test_subpixel_mean_error(): + arr = np.random.rand(256, 256).astype(np.float32) + cpu = scipy_shift(arr, (5.5, -3.2), order=1, prefilter=False) + gpu = shift_array(arr, (5.5, -3.2)) + mean_diff = np.abs(cpu - gpu).mean() + assert mean_diff < 0.01, f"Mean diff {mean_diff} too high" + + +def test_zero_shift(): + arr = np.random.rand(256, 256).astype(np.float32) + result = shift_array(arr, (0.0, 0.0)) + # Allow small tolerance due to grid_sample interpolation + np.testing.assert_allclose(result, arr, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + test_integer_shift() + test_subpixel_mean_error() + test_zero_shift() + print("All tests passed") diff --git a/tests/test_ssim.py b/tests/test_ssim.py new file mode 100644 index 0000000..ce7f1a0 --- /dev/null +++ b/tests/test_ssim.py @@ -0,0 +1,38 @@ +"""Unit tests for GPU SSIM.""" +import numpy as np +import sys +sys.path.insert(0, "src") + +from tilefusion.utils import compute_ssim, CUDA_AVAILABLE +from skimage.metrics import structural_similarity as skimage_ssim + + +def test_ssim_similar_images(): + arr1 = np.random.rand(256, 256).astype(np.float32) + arr2 = arr1 + np.random.rand(256, 256).astype(np.float32) * 0.1 + + data_range = arr1.max() - arr1.min() + cpu = skimage_ssim(arr1, arr2, win_size=15, data_range=data_range) + gpu = compute_ssim(arr1, arr2, win_size=15) + + assert abs(cpu - gpu) < 0.01, f"SSIM diff {abs(cpu-gpu)} too high" + + +def test_ssim_identical_images(): + arr = np.random.rand(256, 256).astype(np.float32) + ssim = compute_ssim(arr, arr, win_size=15) + assert ssim > 0.99, f"SSIM of identical images should be ~1.0, got {ssim}" + + +def test_ssim_different_images(): + arr1 = np.random.rand(256, 256).astype(np.float32) + arr2 = np.random.rand(256, 256).astype(np.float32) + ssim = compute_ssim(arr1, arr2, win_size=15) + assert ssim < 0.5, f"SSIM of random images should be low, got {ssim}" + + +if __name__ == "__main__": + test_ssim_similar_images() + test_ssim_identical_images() + test_ssim_different_images() + print("All tests passed") From 20d29bf06260e840216904501f54196b30835483 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 14:54:39 -0800 Subject: [PATCH 02/18] style: Apply black formatting --- src/tilefusion/utils.py | 41 +++++++++++++++++++++++++---------- tests/test_block_reduce.py | 1 + tests/test_fft.py | 1 + tests/test_histogram_match.py | 2 ++ tests/test_shift_array.py | 2 ++ tests/test_ssim.py | 2 ++ 6 files changed, 38 insertions(+), 11 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 07ee915..2e5da2b 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -11,6 +11,7 @@ try: import torch import torch.nn.functional as F + TORCH_AVAILABLE = True CUDA_AVAILABLE = torch.cuda.is_available() except ImportError: @@ -36,6 +37,7 @@ # Phase Cross-Correlation (GPU FFT) # ============================================================================= + def phase_cross_correlation(reference_image, moving_image, upsample_factor=1, **kwargs): """ Phase cross-correlation using GPU (torch FFT) or CPU (skimage). @@ -66,8 +68,9 @@ def phase_cross_correlation(reference_image, moving_image, upsample_factor=1, ** return _phase_cross_correlation_cpu(ref_np, mov_np, upsample_factor=upsample_factor, **kwargs) -def _phase_cross_correlation_torch(reference_image: np.ndarray, moving_image: np.ndarray, - upsample_factor: int = 1) -> tuple: +def _phase_cross_correlation_torch( + reference_image: np.ndarray, moving_image: np.ndarray, upsample_factor: int = 1 +) -> tuple: """GPU phase cross-correlation using torch FFT.""" ref = torch.from_numpy(reference_image.astype(np.float32)).cuda() mov = torch.from_numpy(moving_image.astype(np.float32)).cuda() @@ -121,14 +124,22 @@ def _subpixel_refine_torch(correlation, peak_y, peak_x, h, w): # Y direction parabolic fit if neighborhood[0, 1].item() != center_val or neighborhood[2, 1].item() != center_val: denom = 2 * (2 * center_val - neighborhood[0, 1].item() - neighborhood[2, 1].item()) - dy = (neighborhood[0, 1].item() - neighborhood[2, 1].item()) / denom if abs(denom) > 1e-10 else 0.0 + dy = ( + (neighborhood[0, 1].item() - neighborhood[2, 1].item()) / denom + if abs(denom) > 1e-10 + else 0.0 + ) else: dy = 0.0 # X direction parabolic fit if neighborhood[1, 0].item() != center_val or neighborhood[1, 2].item() != center_val: denom = 2 * (2 * center_val - neighborhood[1, 0].item() - neighborhood[1, 2].item()) - dx = (neighborhood[1, 0].item() - neighborhood[1, 2].item()) / denom if abs(denom) > 1e-10 else 0.0 + dx = ( + (neighborhood[1, 0].item() - neighborhood[1, 2].item()) / denom + if abs(denom) > 1e-10 + else 0.0 + ) else: dx = 0.0 @@ -142,6 +153,7 @@ def _subpixel_refine_torch(correlation, peak_y, peak_x, h, w): # Shift Array (GPU grid_sample) # ============================================================================= + def shift_array(arr, shift_vec): """ Shift array by subpixel amounts using GPU (torch) or CPU (scipy). @@ -200,6 +212,7 @@ def _shift_array_torch(arr: np.ndarray, shift_vec) -> np.ndarray: # Match Histograms (GPU sort/quantile) # ============================================================================= + def match_histograms(image, reference): """ Match histogram of image to reference using GPU (torch) or CPU (skimage). @@ -256,6 +269,7 @@ def _match_histograms_torch(image: np.ndarray, reference: np.ndarray) -> np.ndar # Block Reduce (GPU avg_pool2d) # ============================================================================= + def block_reduce(arr, block_size, func=np.mean): """ Block reduce array using GPU (torch) or CPU (skimage). @@ -304,6 +318,7 @@ def _block_reduce_torch(arr: np.ndarray, block_size: tuple) -> np.ndarray: # Compute SSIM (GPU conv2d) # ============================================================================= + def compute_ssim(arr1, arr2, win_size: int) -> float: """ Compute SSIM using GPU (torch) or CPU (skimage). @@ -335,7 +350,9 @@ def compute_ssim(arr1, arr2, win_size: int) -> float: return float(_ssim_cpu(arr1_np, arr2_np, win_size=win_size, data_range=data_range)) -def _compute_ssim_torch(arr1: np.ndarray, arr2: np.ndarray, win_size: int, data_range: float) -> float: +def _compute_ssim_torch( + arr1: np.ndarray, arr2: np.ndarray, win_size: int, data_range: float +) -> float: """GPU SSIM using torch conv2d for local statistics.""" C1 = (0.01 * data_range) ** 2 C2 = (0.03 * data_range) ** 2 @@ -351,18 +368,19 @@ def _compute_ssim_torch(arr1: np.ndarray, arr2: np.ndarray, win_size: int, data_ mu1 = F.conv2d(img1, window, padding=win_size // 2) mu2 = F.conv2d(img2, window, padding=win_size // 2) - mu1_sq = mu1 ** 2 - mu2_sq = mu2 ** 2 + mu1_sq = mu1**2 + mu2_sq = mu2**2 mu1_mu2 = mu1 * mu2 # Compute local variances and covariance - sigma1_sq = F.conv2d(img1 ** 2, window, padding=win_size // 2) - mu1_sq - sigma2_sq = F.conv2d(img2 ** 2, window, padding=win_size // 2) - mu2_sq + sigma1_sq = F.conv2d(img1**2, window, padding=win_size // 2) - mu1_sq + sigma2_sq = F.conv2d(img2**2, window, padding=win_size // 2) - mu2_sq sigma12 = F.conv2d(img1 * img2, window, padding=win_size // 2) - mu1_mu2 # SSIM formula - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ - ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ) return float(ssim_map.mean().cpu()) @@ -371,6 +389,7 @@ def _compute_ssim_torch(arr1: np.ndarray, arr2: np.ndarray, win_size: int, data_ # Utility Functions # ============================================================================= + def make_1d_profile(length: int, blend: int) -> np.ndarray: """ Create a linear ramp profile over `blend` pixels at each end. diff --git a/tests/test_block_reduce.py b/tests/test_block_reduce.py index eebe1cc..961fc2b 100644 --- a/tests/test_block_reduce.py +++ b/tests/test_block_reduce.py @@ -5,6 +5,7 @@ from skimage.measure import block_reduce as skimage_block_reduce import sys + sys.path.insert(0, "src") from tilefusion.utils import block_reduce, CUDA_AVAILABLE diff --git a/tests/test_fft.py b/tests/test_fft.py index 92f2023..3658d33 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -2,6 +2,7 @@ import numpy as np import sys + sys.path.insert(0, "src") from tilefusion.utils import phase_cross_correlation, CUDA_AVAILABLE diff --git a/tests/test_histogram_match.py b/tests/test_histogram_match.py index 3db18b5..a5fee91 100644 --- a/tests/test_histogram_match.py +++ b/tests/test_histogram_match.py @@ -1,6 +1,8 @@ """Unit tests for GPU histogram matching.""" + import numpy as np import sys + sys.path.insert(0, "src") from tilefusion.utils import match_histograms, CUDA_AVAILABLE diff --git a/tests/test_shift_array.py b/tests/test_shift_array.py index 8874b42..a3047df 100644 --- a/tests/test_shift_array.py +++ b/tests/test_shift_array.py @@ -1,6 +1,8 @@ """Unit tests for GPU shift_array.""" + import numpy as np import sys + sys.path.insert(0, "src") from tilefusion.utils import shift_array, CUDA_AVAILABLE diff --git a/tests/test_ssim.py b/tests/test_ssim.py index ce7f1a0..694e557 100644 --- a/tests/test_ssim.py +++ b/tests/test_ssim.py @@ -1,6 +1,8 @@ """Unit tests for GPU SSIM.""" + import numpy as np import sys + sys.path.insert(0, "src") from tilefusion.utils import compute_ssim, CUDA_AVAILABLE From 6364c26cd38a0124cf37701da9a661e5f99caf32 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:06:12 -0800 Subject: [PATCH 03/18] fix: Remove unused variables in _match_histograms_torch --- src/tilefusion/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 2e5da2b..c7c2adc 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -248,15 +248,12 @@ def _match_histograms_torch(image: np.ndarray, reference: np.ndarray) -> np.ndar img_sorted, img_indices = torch.sort(img) ref_sorted, _ = torch.sort(ref) - # Create inverse mapping + # Create inverse mapping (rank of each pixel) inv_indices = torch.empty_like(img_indices) inv_indices[img_indices] = torch.arange(len(img), device="cuda") - # Interpolate reference values at image quantiles - img_quantiles = torch.linspace(0, 1, len(img), device="cuda") - ref_quantiles = torch.linspace(0, 1, len(ref), device="cuda") - # Map image values to reference values via quantile matching + # For each pixel, find corresponding quantile in reference interp_values = torch.zeros_like(img) interp_values[img_indices] = ref_sorted[ (inv_indices.float() / len(img) * len(ref)).long().clamp(0, len(ref) - 1) From 27c239447605c3bbda5a00efc64d07ffbb2f681c Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:06:45 -0800 Subject: [PATCH 04/18] feat: Add dtype preservation to shift_array, match_histograms, block_reduce --- src/tilefusion/utils.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index c7c2adc..015c50b 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -154,7 +154,7 @@ def _subpixel_refine_torch(correlation, peak_y, peak_x, h, w): # ============================================================================= -def shift_array(arr, shift_vec): +def shift_array(arr, shift_vec, preserve_dtype=True): """ Shift array by subpixel amounts using GPU (torch) or CPU (scipy). @@ -164,6 +164,8 @@ def shift_array(arr, shift_vec): 2D input array. shift_vec : array-like (dy, dx) shift amounts. + preserve_dtype : bool + If True, output dtype matches input dtype. Default True. Returns ------- @@ -171,11 +173,16 @@ def shift_array(arr, shift_vec): Shifted array, same shape as input. """ arr_np = np.asarray(arr) + original_dtype = arr_np.dtype if CUDA_AVAILABLE and arr_np.ndim == 2: - return _shift_array_torch(arr_np, shift_vec) + result = _shift_array_torch(arr_np, shift_vec) + else: + result = _shift_cpu(arr_np, shift=shift_vec, order=1, prefilter=False) - return _shift_cpu(arr_np, shift=shift_vec, order=1, prefilter=False) + if preserve_dtype and result.dtype != original_dtype: + return result.astype(original_dtype) + return result def _shift_array_torch(arr: np.ndarray, shift_vec) -> np.ndarray: @@ -213,7 +220,7 @@ def _shift_array_torch(arr: np.ndarray, shift_vec) -> np.ndarray: # ============================================================================= -def match_histograms(image, reference): +def match_histograms(image, reference, preserve_dtype=True): """ Match histogram of image to reference using GPU (torch) or CPU (skimage). @@ -223,6 +230,8 @@ def match_histograms(image, reference): Image to transform. reference : ndarray Reference image for histogram matching. + preserve_dtype : bool + If True, output dtype matches input dtype. Default True. Returns ------- @@ -231,11 +240,16 @@ def match_histograms(image, reference): """ image_np = np.asarray(image) reference_np = np.asarray(reference) + original_dtype = image_np.dtype if CUDA_AVAILABLE and image_np.ndim == 2: - return _match_histograms_torch(image_np, reference_np) + result = _match_histograms_torch(image_np, reference_np) + else: + result = _match_histograms_cpu(image_np, reference_np) - return _match_histograms_cpu(image_np, reference_np) + if preserve_dtype and result.dtype != original_dtype: + return result.astype(original_dtype) + return result def _match_histograms_torch(image: np.ndarray, reference: np.ndarray) -> np.ndarray: @@ -267,7 +281,7 @@ def _match_histograms_torch(image: np.ndarray, reference: np.ndarray) -> np.ndar # ============================================================================= -def block_reduce(arr, block_size, func=np.mean): +def block_reduce(arr, block_size, func=np.mean, preserve_dtype=True): """ Block reduce array using GPU (torch) or CPU (skimage). @@ -279,17 +293,24 @@ def block_reduce(arr, block_size, func=np.mean): Reduction factors per dimension. func : callable Reduction function (only np.mean supported on GPU). + preserve_dtype : bool + If True, output dtype matches input dtype. Default True. Returns ------- reduced : ndarray """ arr_np = np.asarray(arr) + original_dtype = arr_np.dtype if CUDA_AVAILABLE and func == np.mean and arr_np.ndim >= 2: - return _block_reduce_torch(arr_np, block_size) + result = _block_reduce_torch(arr_np, block_size) + else: + result = _block_reduce_cpu(arr_np, block_size, func) - return _block_reduce_cpu(arr_np, block_size, func) + if preserve_dtype and result.dtype != original_dtype: + return result.astype(original_dtype) + return result def _block_reduce_torch(arr: np.ndarray, block_size: tuple) -> np.ndarray: From 58d4f567ae7e1372b3543d991b5f44e94563ae81 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:07:08 -0800 Subject: [PATCH 05/18] fix: Handle 2D block_size for 3D arrays in _block_reduce_torch --- src/tilefusion/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 015c50b..22841eb 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -324,7 +324,14 @@ def _block_reduce_torch(arr: np.ndarray, block_size: tuple) -> np.ndarray: return out.squeeze().cpu().numpy() elif ndim == 3: - kernel = (block_size[1], block_size[2]) if len(block_size) == 3 else block_size + # For 3D arrays (C, H, W), extract spatial kernel from block_size + if len(block_size) == 3: + # block_size is (c_factor, h_factor, w_factor) + # Only use spatial dimensions for avg_pool2d + kernel = (block_size[1], block_size[2]) + else: + # block_size is (h_factor, w_factor) - apply to spatial dims + kernel = (block_size[0], block_size[1]) t = torch.from_numpy(arr).float().cuda().unsqueeze(0) out = torch.nn.functional.avg_pool2d(t, kernel, stride=kernel) return out.squeeze(0).cpu().numpy() From 897d642db9153da1c9e182781e3ce5e2004d4584 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:07:29 -0800 Subject: [PATCH 06/18] refactor: Extract duplicate data_range calculation in compute_ssim --- src/tilefusion/utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 22841eb..2ee7b13 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -363,15 +363,14 @@ def compute_ssim(arr1, arr2, win_size: int) -> float: arr1_np = np.asarray(arr1, dtype=np.float32) arr2_np = np.asarray(arr2, dtype=np.float32) - if CUDA_AVAILABLE and arr1_np.ndim == 2: - data_range = float(arr1_np.max() - arr1_np.min()) - if data_range == 0: - data_range = 1.0 - return _compute_ssim_torch(arr1_np, arr2_np, win_size, data_range) - + # Compute data range once data_range = float(arr1_np.max() - arr1_np.min()) if data_range == 0: data_range = 1.0 + + if CUDA_AVAILABLE and arr1_np.ndim == 2: + return _compute_ssim_torch(arr1_np, arr2_np, win_size, data_range) + return float(_ssim_cpu(arr1_np, arr2_np, win_size=win_size, data_range=data_range)) From 6c74bfcedc089d825009964b305a922d8d083a35 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:07:59 -0800 Subject: [PATCH 07/18] refactor: Add named constants for magic numbers (_FFT_EPS, _SSIM_K1, _SSIM_K2) --- src/tilefusion/utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 2ee7b13..10d41d0 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -32,6 +32,11 @@ xp = np cp = None +# Constants +_FFT_EPS = 1e-10 # Epsilon for FFT normalization to avoid division by zero +_SSIM_K1 = 0.01 # SSIM constant K1 (luminance) +_SSIM_K2 = 0.03 # SSIM constant K2 (contrast) + # ============================================================================= # Phase Cross-Correlation (GPU FFT) @@ -79,8 +84,7 @@ def _phase_cross_correlation_torch( ref_fft = torch.fft.fft2(ref) mov_fft = torch.fft.fft2(mov) cross_power = ref_fft * torch.conj(mov_fft) - eps = 1e-10 - cross_power = cross_power / (torch.abs(cross_power) + eps) + cross_power = cross_power / (torch.abs(cross_power) + _FFT_EPS) # Inverse FFT to get correlation correlation = torch.fft.ifft2(cross_power).real @@ -378,8 +382,8 @@ def _compute_ssim_torch( arr1: np.ndarray, arr2: np.ndarray, win_size: int, data_range: float ) -> float: """GPU SSIM using torch conv2d for local statistics.""" - C1 = (0.01 * data_range) ** 2 - C2 = (0.03 * data_range) ** 2 + C1 = (_SSIM_K1 * data_range) ** 2 + C2 = (_SSIM_K2 * data_range) ** 2 # Create uniform window window = torch.ones(1, 1, win_size, win_size, device="cuda") / (win_size * win_size) From ab2d78445e292facf3b8d5a74cefbd54ef8ca949 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:09:51 -0800 Subject: [PATCH 08/18] test: Add CPU fallback and dtype preservation tests --- tests/conftest.py | 10 +++ tests/test_cpu_fallback.py | 137 +++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+) create mode 100644 tests/test_cpu_fallback.py diff --git a/tests/conftest.py b/tests/conftest.py index 2518116..a1a3105 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,3 +20,13 @@ def sample_tile(rng): def sample_multichannel_tile(rng): """Generate a sample multi-channel tile.""" return rng.random((3, 100, 100), dtype=np.float32) * 65535 + + +@pytest.fixture +def force_cpu(monkeypatch): + """Force CPU fallback by setting CUDA_AVAILABLE to False.""" + import tilefusion.utils as utils + + monkeypatch.setattr(utils, "CUDA_AVAILABLE", False) + yield + # monkeypatch automatically restores after test diff --git a/tests/test_cpu_fallback.py b/tests/test_cpu_fallback.py new file mode 100644 index 0000000..cf2543d --- /dev/null +++ b/tests/test_cpu_fallback.py @@ -0,0 +1,137 @@ +"""Tests for CPU fallback paths and dtype preservation.""" + +import numpy as np +import pytest +import sys + +sys.path.insert(0, "src") + +from tilefusion.utils import ( + phase_cross_correlation, + shift_array, + match_histograms, + block_reduce, + compute_ssim, +) + + +class TestCPUFallback: + """Test that CPU fallback paths work correctly.""" + + def test_phase_cross_correlation_cpu(self, force_cpu): + """Test phase_cross_correlation with CPU fallback.""" + np.random.seed(42) + ref = np.random.rand(128, 128).astype(np.float32) + mov = np.roll(ref, 5, axis=0) + + shift, error, phasediff = phase_cross_correlation(ref, mov) + + assert abs(shift[0] - (-5)) < 1, f"Y shift {shift[0]} not close to -5" + + def test_shift_array_cpu(self, force_cpu): + """Test shift_array with CPU fallback.""" + arr = np.random.rand(128, 128).astype(np.float32) + result = shift_array(arr, (3.0, -2.0)) + + assert result.shape == arr.shape + assert result.dtype == arr.dtype + + def test_match_histograms_cpu(self, force_cpu): + """Test match_histograms with CPU fallback.""" + img = np.random.rand(128, 128).astype(np.float32) + ref = np.random.rand(128, 128).astype(np.float32) * 2 + + result = match_histograms(img, ref) + + assert result.shape == img.shape + + def test_block_reduce_cpu(self, force_cpu): + """Test block_reduce with CPU fallback.""" + arr = np.random.rand(128, 128).astype(np.float32) + result = block_reduce(arr, (4, 4), np.mean) + + assert result.shape == (32, 32) + + def test_compute_ssim_cpu(self, force_cpu): + """Test compute_ssim with CPU fallback.""" + arr1 = np.random.rand(128, 128).astype(np.float32) + arr2 = arr1 + np.random.rand(128, 128).astype(np.float32) * 0.1 + + ssim = compute_ssim(arr1, arr2, win_size=7) + + assert 0.0 <= ssim <= 1.0 + + +class TestDtypePreservation: + """Test that dtype is preserved when preserve_dtype=True.""" + + @pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64]) + def test_shift_array_dtype(self, dtype, force_cpu): + """Test shift_array preserves dtype.""" + arr = (np.random.rand(64, 64) * 255).astype(dtype) + result = shift_array(arr, (1.5, -1.5), preserve_dtype=True) + + assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}" + + @pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64]) + def test_match_histograms_dtype(self, dtype, force_cpu): + """Test match_histograms preserves dtype.""" + img = (np.random.rand(64, 64) * 255).astype(dtype) + ref = (np.random.rand(64, 64) * 255).astype(dtype) + result = match_histograms(img, ref, preserve_dtype=True) + + assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}" + + @pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64]) + def test_block_reduce_dtype(self, dtype, force_cpu): + """Test block_reduce preserves dtype.""" + arr = (np.random.rand(64, 64) * 255).astype(dtype) + result = block_reduce(arr, (4, 4), np.mean, preserve_dtype=True) + + assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}" + + def test_shift_array_no_preserve(self, force_cpu): + """Test shift_array returns float when preserve_dtype=False.""" + arr = (np.random.rand(64, 64) * 255).astype(np.uint16) + result = shift_array(arr, (1.5, -1.5), preserve_dtype=False) + + # Should return float64 (scipy default) + assert result.dtype in [np.float32, np.float64] + + +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_shift_zero(self, force_cpu): + """Test that zero shift returns nearly identical array.""" + arr = np.random.rand(64, 64).astype(np.float32) + result = shift_array(arr, (0.0, 0.0)) + + np.testing.assert_allclose(result, arr, rtol=1e-5, atol=1e-5) + + def test_identical_images_ssim(self, force_cpu): + """Test SSIM of identical images is ~1.0.""" + arr = np.random.rand(64, 64).astype(np.float32) + ssim = compute_ssim(arr, arr, win_size=7) + + assert ssim > 0.99, f"SSIM of identical images should be ~1.0, got {ssim}" + + def test_block_reduce_3d(self, force_cpu): + """Test block_reduce with 3D array.""" + arr = np.random.rand(3, 64, 64).astype(np.float32) + result = block_reduce(arr, (1, 4, 4), np.mean) + + assert result.shape == (3, 16, 16) + + def test_different_size_histogram_match(self, force_cpu): + """Test histogram matching with different sized images.""" + img = np.random.rand(64, 64).astype(np.float32) + ref = np.random.rand(128, 128).astype(np.float32) + + result = match_histograms(img, ref) + + assert result.shape == img.shape + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 1af224e97f92321d90744e894742524b3f6c9475 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:10:34 -0800 Subject: [PATCH 09/18] test: Add subpixel phase correlation tests --- tests/test_fft.py | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/test_fft.py b/tests/test_fft.py index 3658d33..3a15eae 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -54,6 +54,52 @@ def test_matches_skimage_direction(): assert np.sign(gpu_shift[1]) == np.sign(cpu_shift[1]), "X direction mismatch" +def test_subpixel_refinement(): + """Test subpixel accuracy with upsample_factor > 1.""" + np.random.seed(42) + ref = np.random.rand(128, 128).astype(np.float32) + + # Use integer shift for ground truth (subpixel refinement should still work) + mov = np.roll(np.roll(ref, 7, axis=0), -4, axis=1) + + # Test with upsample_factor=10 for subpixel refinement + shift_subpixel, _, _ = phase_cross_correlation(ref, mov, upsample_factor=10) + shift_integer, _, _ = phase_cross_correlation(ref, mov, upsample_factor=1) + + # Both should detect the shift direction correctly + assert ( + abs(shift_subpixel[0] - (-7)) < 1 + ), f"Subpixel Y shift {shift_subpixel[0]} not close to -7" + assert abs(shift_subpixel[1] - 4) < 1, f"Subpixel X shift {shift_subpixel[1]} not close to 4" + + # Subpixel should give fractional values (may have decimal component) + # Just verify it returns reasonable values + assert -10 < shift_subpixel[0] < 0, f"Subpixel Y shift {shift_subpixel[0]} out of range" + assert 0 < shift_subpixel[1] < 10, f"Subpixel X shift {shift_subpixel[1]} out of range" + + +def test_subpixel_vs_integer_consistency(): + """Test that subpixel and integer modes give consistent direction.""" + np.random.seed(123) + ref = np.random.rand(64, 64).astype(np.float32) + mov = np.roll(np.roll(ref, 3, axis=0), -2, axis=1) + + shift_int, _, _ = phase_cross_correlation(ref, mov, upsample_factor=1) + shift_sub, _, _ = phase_cross_correlation(ref, mov, upsample_factor=10) + + # Signs should match + assert np.sign(shift_int[0]) == np.sign( + shift_sub[0] + ), "Y direction mismatch between int/subpixel" + assert np.sign(shift_int[1]) == np.sign( + shift_sub[1] + ), "X direction mismatch between int/subpixel" + + # Magnitudes should be close + assert abs(shift_int[0] - shift_sub[0]) < 1, "Y magnitude differs too much" + assert abs(shift_int[1] - shift_sub[1]) < 1, "X magnitude differs too much" + + if __name__ == "__main__": print(f"CUDA available: {CUDA_AVAILABLE}") test_known_shift() @@ -62,4 +108,8 @@ def test_matches_skimage_direction(): print("test_zero_shift: PASSED") test_matches_skimage_direction() print("test_matches_skimage_direction: PASSED") + test_subpixel_refinement() + print("test_subpixel_refinement: PASSED") + test_subpixel_vs_integer_consistency() + print("test_subpixel_vs_integer_consistency: PASSED") print("All tests passed!") From ea628b212c100648f1d2acef37ecd5254552d691 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:13:19 -0800 Subject: [PATCH 10/18] refactor: Add _PARABOLIC_EPS constant for subpixel refinement --- src/tilefusion/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 10d41d0..ba7b6f4 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -34,6 +34,7 @@ # Constants _FFT_EPS = 1e-10 # Epsilon for FFT normalization to avoid division by zero +_PARABOLIC_EPS = 1e-10 # Epsilon for parabolic fit denominator check _SSIM_K1 = 0.01 # SSIM constant K1 (luminance) _SSIM_K2 = 0.03 # SSIM constant K2 (contrast) @@ -130,7 +131,7 @@ def _subpixel_refine_torch(correlation, peak_y, peak_x, h, w): denom = 2 * (2 * center_val - neighborhood[0, 1].item() - neighborhood[2, 1].item()) dy = ( (neighborhood[0, 1].item() - neighborhood[2, 1].item()) / denom - if abs(denom) > 1e-10 + if abs(denom) > _PARABOLIC_EPS else 0.0 ) else: @@ -141,7 +142,7 @@ def _subpixel_refine_torch(correlation, peak_y, peak_x, h, w): denom = 2 * (2 * center_val - neighborhood[1, 0].item() - neighborhood[1, 2].item()) dx = ( (neighborhood[1, 0].item() - neighborhood[1, 2].item()) / denom - if abs(denom) > 1e-10 + if abs(denom) > _PARABOLIC_EPS else 0.0 ) else: From 5070944092ff965327251cd95b1428d41d2d76ce Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:13:35 -0800 Subject: [PATCH 11/18] fix: Guard against 1-pixel arrays in _shift_array_torch --- src/tilefusion/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index ba7b6f4..ea5a84d 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -193,6 +193,11 @@ def shift_array(arr, shift_vec, preserve_dtype=True): def _shift_array_torch(arr: np.ndarray, shift_vec) -> np.ndarray: """GPU shift using torch.nn.functional.grid_sample.""" h, w = arr.shape + + # Guard against degenerate arrays (need at least 2 pixels for interpolation) + if h < 2 or w < 2: + return _shift_cpu(arr, shift=shift_vec, order=1, prefilter=False) + dy, dx = float(shift_vec[0]), float(shift_vec[1]) # Create pixel coordinate grids From 6d0a9ebcaf8038f8ac9f937aa0e6dbfe360cc61d Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:14:55 -0800 Subject: [PATCH 12/18] refactor: Add __all__ export list and document legacy compatibility vars --- src/tilefusion/utils.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index ea5a84d..056d50e 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -7,6 +7,26 @@ import numpy as np +__all__ = [ + # GPU detection flags + "TORCH_AVAILABLE", + "CUDA_AVAILABLE", + "USING_GPU", + # Array module (legacy compatibility) + "xp", + "cp", + # Core functions + "phase_cross_correlation", + "shift_array", + "match_histograms", + "block_reduce", + "compute_ssim", + # Utility functions + "make_1d_profile", + "to_numpy", + "to_device", +] + # GPU detection - PyTorch based try: import torch @@ -27,10 +47,13 @@ from skimage.metrics import structural_similarity as _ssim_cpu from skimage.registration import phase_cross_correlation as _phase_cross_correlation_cpu -# Legacy compatibility +# Legacy compatibility - used by core.py and registration.py +# xp: array module (numpy, since cupy was removed) +# cp: cupy module (always None now, kept for API compatibility) +# USING_GPU: exported in __init__.py for user code USING_GPU = CUDA_AVAILABLE xp = np -cp = None +cp = None # cupy removed; GPU ops now use PyTorch internally # Constants _FFT_EPS = 1e-10 # Epsilon for FFT normalization to avoid division by zero From f631653c64f8ad45e2b285536c55729d5476f1c1 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:16:46 -0800 Subject: [PATCH 13/18] test: Refactor tests to use rng fixture and pytest class style --- tests/test_cpu_fallback.py | 61 +++++++------- tests/test_fft.py | 145 ++++++++++++++++------------------ tests/test_histogram_match.py | 54 ++++++++----- tests/test_shift_array.py | 62 +++++++++------ tests/test_ssim.py | 49 +++++++----- 5 files changed, 198 insertions(+), 173 deletions(-) diff --git a/tests/test_cpu_fallback.py b/tests/test_cpu_fallback.py index cf2543d..1b72206 100644 --- a/tests/test_cpu_fallback.py +++ b/tests/test_cpu_fallback.py @@ -18,44 +18,43 @@ class TestCPUFallback: """Test that CPU fallback paths work correctly.""" - def test_phase_cross_correlation_cpu(self, force_cpu): + def test_phase_cross_correlation_cpu(self, force_cpu, rng): """Test phase_cross_correlation with CPU fallback.""" - np.random.seed(42) - ref = np.random.rand(128, 128).astype(np.float32) + ref = rng.random((128, 128)).astype(np.float32) mov = np.roll(ref, 5, axis=0) shift, error, phasediff = phase_cross_correlation(ref, mov) assert abs(shift[0] - (-5)) < 1, f"Y shift {shift[0]} not close to -5" - def test_shift_array_cpu(self, force_cpu): + def test_shift_array_cpu(self, force_cpu, rng): """Test shift_array with CPU fallback.""" - arr = np.random.rand(128, 128).astype(np.float32) + arr = rng.random((128, 128)).astype(np.float32) result = shift_array(arr, (3.0, -2.0)) assert result.shape == arr.shape assert result.dtype == arr.dtype - def test_match_histograms_cpu(self, force_cpu): + def test_match_histograms_cpu(self, force_cpu, rng): """Test match_histograms with CPU fallback.""" - img = np.random.rand(128, 128).astype(np.float32) - ref = np.random.rand(128, 128).astype(np.float32) * 2 + img = rng.random((128, 128)).astype(np.float32) + ref = rng.random((128, 128)).astype(np.float32) * 2 result = match_histograms(img, ref) assert result.shape == img.shape - def test_block_reduce_cpu(self, force_cpu): + def test_block_reduce_cpu(self, force_cpu, rng): """Test block_reduce with CPU fallback.""" - arr = np.random.rand(128, 128).astype(np.float32) + arr = rng.random((128, 128)).astype(np.float32) result = block_reduce(arr, (4, 4), np.mean) assert result.shape == (32, 32) - def test_compute_ssim_cpu(self, force_cpu): + def test_compute_ssim_cpu(self, force_cpu, rng): """Test compute_ssim with CPU fallback.""" - arr1 = np.random.rand(128, 128).astype(np.float32) - arr2 = arr1 + np.random.rand(128, 128).astype(np.float32) * 0.1 + arr1 = rng.random((128, 128)).astype(np.float32) + arr2 = arr1 + rng.random((128, 128)).astype(np.float32) * 0.1 ssim = compute_ssim(arr1, arr2, win_size=7) @@ -66,33 +65,33 @@ class TestDtypePreservation: """Test that dtype is preserved when preserve_dtype=True.""" @pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64]) - def test_shift_array_dtype(self, dtype, force_cpu): + def test_shift_array_dtype(self, dtype, force_cpu, rng): """Test shift_array preserves dtype.""" - arr = (np.random.rand(64, 64) * 255).astype(dtype) + arr = (rng.random((64, 64)) * 255).astype(dtype) result = shift_array(arr, (1.5, -1.5), preserve_dtype=True) assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}" @pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64]) - def test_match_histograms_dtype(self, dtype, force_cpu): + def test_match_histograms_dtype(self, dtype, force_cpu, rng): """Test match_histograms preserves dtype.""" - img = (np.random.rand(64, 64) * 255).astype(dtype) - ref = (np.random.rand(64, 64) * 255).astype(dtype) + img = (rng.random((64, 64)) * 255).astype(dtype) + ref = (rng.random((64, 64)) * 255).astype(dtype) result = match_histograms(img, ref, preserve_dtype=True) assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}" @pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64]) - def test_block_reduce_dtype(self, dtype, force_cpu): + def test_block_reduce_dtype(self, dtype, force_cpu, rng): """Test block_reduce preserves dtype.""" - arr = (np.random.rand(64, 64) * 255).astype(dtype) + arr = (rng.random((64, 64)) * 255).astype(dtype) result = block_reduce(arr, (4, 4), np.mean, preserve_dtype=True) assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}" - def test_shift_array_no_preserve(self, force_cpu): + def test_shift_array_no_preserve(self, force_cpu, rng): """Test shift_array returns float when preserve_dtype=False.""" - arr = (np.random.rand(64, 64) * 255).astype(np.uint16) + arr = (rng.random((64, 64)) * 255).astype(np.uint16) result = shift_array(arr, (1.5, -1.5), preserve_dtype=False) # Should return float64 (scipy default) @@ -102,31 +101,31 @@ def test_shift_array_no_preserve(self, force_cpu): class TestEdgeCases: """Test edge cases and boundary conditions.""" - def test_shift_zero(self, force_cpu): + def test_shift_zero(self, force_cpu, rng): """Test that zero shift returns nearly identical array.""" - arr = np.random.rand(64, 64).astype(np.float32) + arr = rng.random((64, 64)).astype(np.float32) result = shift_array(arr, (0.0, 0.0)) np.testing.assert_allclose(result, arr, rtol=1e-5, atol=1e-5) - def test_identical_images_ssim(self, force_cpu): + def test_identical_images_ssim(self, force_cpu, rng): """Test SSIM of identical images is ~1.0.""" - arr = np.random.rand(64, 64).astype(np.float32) + arr = rng.random((64, 64)).astype(np.float32) ssim = compute_ssim(arr, arr, win_size=7) assert ssim > 0.99, f"SSIM of identical images should be ~1.0, got {ssim}" - def test_block_reduce_3d(self, force_cpu): + def test_block_reduce_3d(self, force_cpu, rng): """Test block_reduce with 3D array.""" - arr = np.random.rand(3, 64, 64).astype(np.float32) + arr = rng.random((3, 64, 64)).astype(np.float32) result = block_reduce(arr, (1, 4, 4), np.mean) assert result.shape == (3, 16, 16) - def test_different_size_histogram_match(self, force_cpu): + def test_different_size_histogram_match(self, force_cpu, rng): """Test histogram matching with different sized images.""" - img = np.random.rand(64, 64).astype(np.float32) - ref = np.random.rand(128, 128).astype(np.float32) + img = rng.random((64, 64)).astype(np.float32) + ref = rng.random((128, 128)).astype(np.float32) result = match_histograms(img, ref) diff --git a/tests/test_fft.py b/tests/test_fft.py index 3a15eae..ea69f6f 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -1,6 +1,7 @@ """Unit tests for GPU phase_cross_correlation (FFT).""" import numpy as np +import pytest import sys sys.path.insert(0, "src") @@ -9,107 +10,93 @@ from skimage.registration import phase_cross_correlation as skimage_pcc -def test_known_shift(): - """Test detection of known integer shift.""" - np.random.seed(42) - ref = np.random.rand(256, 256).astype(np.float32) +class TestPhaseCorrelation: + """Tests for phase_cross_correlation function.""" - # Create shifted version: mov is ref shifted by (+5, -3) - # phase_cross_correlation returns shift to apply to mov to align with ref - # So it should return (-5, +3) - mov = np.zeros_like(ref) - mov[5:, :253] = ref[:-5, 3:] + def test_known_shift(self, rng): + """Test detection of known integer shift.""" + ref = rng.random((256, 256)).astype(np.float32) - shift, _, _ = phase_cross_correlation(ref, mov) + # Create shifted version: mov is ref shifted by (+5, -3) + # phase_cross_correlation returns shift to apply to mov to align with ref + # So it should return (-5, +3) + mov = np.zeros_like(ref) + mov[5:, :253] = ref[:-5, 3:] - # Should detect shift close to (-5, +3) - assert abs(shift[0] - (-5)) < 1, f"Y shift {shift[0]} not close to -5" - assert abs(shift[1] - 3) < 1, f"X shift {shift[1]} not close to 3" + shift, _, _ = phase_cross_correlation(ref, mov) + assert abs(shift[0] - (-5)) < 1, f"Y shift {shift[0]} not close to -5" + assert abs(shift[1] - 3) < 1, f"X shift {shift[1]} not close to 3" -def test_zero_shift(): - """Test that identical images give zero shift.""" - np.random.seed(42) - ref = np.random.rand(256, 256).astype(np.float32) + def test_zero_shift(self, rng): + """Test that identical images give zero shift.""" + ref = rng.random((256, 256)).astype(np.float32) - shift, _, _ = phase_cross_correlation(ref, ref) + shift, _, _ = phase_cross_correlation(ref, ref) - assert abs(shift[0]) < 0.5, f"Y shift {shift[0]} should be ~0" - assert abs(shift[1]) < 0.5, f"X shift {shift[1]} should be ~0" + assert abs(shift[0]) < 0.5, f"Y shift {shift[0]} should be ~0" + assert abs(shift[1]) < 0.5, f"X shift {shift[1]} should be ~0" + def test_matches_skimage_direction(self, rng): + """Test that shift direction matches skimage convention.""" + ref = rng.random((128, 128)).astype(np.float32) -def test_matches_skimage_direction(): - """Test that shift direction matches skimage convention.""" - np.random.seed(42) - ref = np.random.rand(128, 128).astype(np.float32) + # Shift by rolling + mov = np.roll(np.roll(ref, 10, axis=0), -7, axis=1) - # Shift by rolling - mov = np.roll(np.roll(ref, 10, axis=0), -7, axis=1) + gpu_shift, _, _ = phase_cross_correlation(ref, mov) + cpu_shift, _, _ = skimage_pcc(ref, mov) - gpu_shift, _, _ = phase_cross_correlation(ref, mov) - cpu_shift, _, _ = skimage_pcc(ref, mov) + # Directions should match + assert np.sign(gpu_shift[0]) == np.sign(cpu_shift[0]), "Y direction mismatch" + assert np.sign(gpu_shift[1]) == np.sign(cpu_shift[1]), "X direction mismatch" - # Directions should match - assert np.sign(gpu_shift[0]) == np.sign(cpu_shift[0]), "Y direction mismatch" - assert np.sign(gpu_shift[1]) == np.sign(cpu_shift[1]), "X direction mismatch" +class TestSubpixelRefinement: + """Tests for subpixel phase correlation refinement.""" -def test_subpixel_refinement(): - """Test subpixel accuracy with upsample_factor > 1.""" - np.random.seed(42) - ref = np.random.rand(128, 128).astype(np.float32) + def test_subpixel_refinement(self, rng): + """Test subpixel accuracy with upsample_factor > 1.""" + ref = rng.random((128, 128)).astype(np.float32) - # Use integer shift for ground truth (subpixel refinement should still work) - mov = np.roll(np.roll(ref, 7, axis=0), -4, axis=1) + # Use integer shift for ground truth (subpixel refinement should still work) + mov = np.roll(np.roll(ref, 7, axis=0), -4, axis=1) - # Test with upsample_factor=10 for subpixel refinement - shift_subpixel, _, _ = phase_cross_correlation(ref, mov, upsample_factor=10) - shift_integer, _, _ = phase_cross_correlation(ref, mov, upsample_factor=1) + # Test with upsample_factor=10 for subpixel refinement + shift_subpixel, _, _ = phase_cross_correlation(ref, mov, upsample_factor=10) - # Both should detect the shift direction correctly - assert ( - abs(shift_subpixel[0] - (-7)) < 1 - ), f"Subpixel Y shift {shift_subpixel[0]} not close to -7" - assert abs(shift_subpixel[1] - 4) < 1, f"Subpixel X shift {shift_subpixel[1]} not close to 4" + # Should detect the shift direction correctly + assert ( + abs(shift_subpixel[0] - (-7)) < 1 + ), f"Subpixel Y shift {shift_subpixel[0]} not close to -7" + assert ( + abs(shift_subpixel[1] - 4) < 1 + ), f"Subpixel X shift {shift_subpixel[1]} not close to 4" - # Subpixel should give fractional values (may have decimal component) - # Just verify it returns reasonable values - assert -10 < shift_subpixel[0] < 0, f"Subpixel Y shift {shift_subpixel[0]} out of range" - assert 0 < shift_subpixel[1] < 10, f"Subpixel X shift {shift_subpixel[1]} out of range" + # Verify reasonable range + assert -10 < shift_subpixel[0] < 0, f"Subpixel Y shift {shift_subpixel[0]} out of range" + assert 0 < shift_subpixel[1] < 10, f"Subpixel X shift {shift_subpixel[1]} out of range" + def test_subpixel_vs_integer_consistency(self, rng): + """Test that subpixel and integer modes give consistent direction.""" + ref = rng.random((64, 64)).astype(np.float32) + mov = np.roll(np.roll(ref, 3, axis=0), -2, axis=1) -def test_subpixel_vs_integer_consistency(): - """Test that subpixel and integer modes give consistent direction.""" - np.random.seed(123) - ref = np.random.rand(64, 64).astype(np.float32) - mov = np.roll(np.roll(ref, 3, axis=0), -2, axis=1) + shift_int, _, _ = phase_cross_correlation(ref, mov, upsample_factor=1) + shift_sub, _, _ = phase_cross_correlation(ref, mov, upsample_factor=10) - shift_int, _, _ = phase_cross_correlation(ref, mov, upsample_factor=1) - shift_sub, _, _ = phase_cross_correlation(ref, mov, upsample_factor=10) + # Signs should match + assert np.sign(shift_int[0]) == np.sign( + shift_sub[0] + ), "Y direction mismatch between int/subpixel" + assert np.sign(shift_int[1]) == np.sign( + shift_sub[1] + ), "X direction mismatch between int/subpixel" - # Signs should match - assert np.sign(shift_int[0]) == np.sign( - shift_sub[0] - ), "Y direction mismatch between int/subpixel" - assert np.sign(shift_int[1]) == np.sign( - shift_sub[1] - ), "X direction mismatch between int/subpixel" - - # Magnitudes should be close - assert abs(shift_int[0] - shift_sub[0]) < 1, "Y magnitude differs too much" - assert abs(shift_int[1] - shift_sub[1]) < 1, "X magnitude differs too much" + # Magnitudes should be close + assert abs(shift_int[0] - shift_sub[0]) < 1, "Y magnitude differs too much" + assert abs(shift_int[1] - shift_sub[1]) < 1, "X magnitude differs too much" if __name__ == "__main__": - print(f"CUDA available: {CUDA_AVAILABLE}") - test_known_shift() - print("test_known_shift: PASSED") - test_zero_shift() - print("test_zero_shift: PASSED") - test_matches_skimage_direction() - print("test_matches_skimage_direction: PASSED") - test_subpixel_refinement() - print("test_subpixel_refinement: PASSED") - test_subpixel_vs_integer_consistency() - print("test_subpixel_vs_integer_consistency: PASSED") - print("All tests passed!") + pytest.main([__file__, "-v"]) diff --git a/tests/test_histogram_match.py b/tests/test_histogram_match.py index a5fee91..e0aa402 100644 --- a/tests/test_histogram_match.py +++ b/tests/test_histogram_match.py @@ -1,6 +1,7 @@ """Unit tests for GPU histogram matching.""" import numpy as np +import pytest import sys sys.path.insert(0, "src") @@ -9,29 +10,44 @@ from skimage.exposure import match_histograms as skimage_match -def test_histogram_range(): - img = np.random.rand(256, 256).astype(np.float32) - ref = np.random.rand(256, 256).astype(np.float32) * 2 + 1 - result = match_histograms(img, ref) - # Output should be in reference range - assert result.min() >= ref.min() - 0.1 - assert result.max() <= ref.max() + 0.1 +class TestMatchHistograms: + """Tests for match_histograms function.""" + def test_histogram_range(self, rng): + """Test output is in reference range.""" + img = rng.random((256, 256)).astype(np.float32) + ref = rng.random((256, 256)).astype(np.float32) * 2 + 1 + result = match_histograms(img, ref) + # Output should be in reference range + assert result.min() >= ref.min() - 0.1 + assert result.max() <= ref.max() + 0.1 -def test_histogram_correlation(): - img = np.random.rand(256, 256).astype(np.float32) - ref = np.random.rand(256, 256).astype(np.float32) + def test_histogram_correlation(self, rng): + """Test histogram correlation with skimage.""" + img = rng.random((256, 256)).astype(np.float32) + ref = rng.random((256, 256)).astype(np.float32) - cpu = skimage_match(img, ref) - gpu = match_histograms(img, ref) + cpu = skimage_match(img, ref) + gpu = match_histograms(img, ref) - cpu_hist, _ = np.histogram(cpu.flatten(), bins=100) - gpu_hist, _ = np.histogram(gpu.flatten(), bins=100) - corr = np.corrcoef(cpu_hist, gpu_hist)[0, 1] - assert corr > 0.99, f"Histogram correlation {corr} too low" + cpu_hist, _ = np.histogram(cpu.flatten(), bins=100) + gpu_hist, _ = np.histogram(gpu.flatten(), bins=100) + corr = np.corrcoef(cpu_hist, gpu_hist)[0, 1] + assert corr > 0.99, f"Histogram correlation {corr} too low" + + def test_same_image(self, rng): + """Test matching image to itself.""" + img = rng.random((128, 128)).astype(np.float32) + result = match_histograms(img, img) + np.testing.assert_allclose(result, img, rtol=1e-5) + + def test_different_sizes(self, rng): + """Test matching images of different sizes.""" + img = rng.random((64, 64)).astype(np.float32) + ref = rng.random((128, 128)).astype(np.float32) + result = match_histograms(img, ref) + assert result.shape == img.shape if __name__ == "__main__": - test_histogram_range() - test_histogram_correlation() - print("All tests passed") + pytest.main([__file__, "-v"]) diff --git a/tests/test_shift_array.py b/tests/test_shift_array.py index a3047df..530a7a5 100644 --- a/tests/test_shift_array.py +++ b/tests/test_shift_array.py @@ -1,6 +1,7 @@ """Unit tests for GPU shift_array.""" import numpy as np +import pytest import sys sys.path.insert(0, "src") @@ -9,30 +10,43 @@ from scipy.ndimage import shift as scipy_shift -def test_integer_shift(): - arr = np.random.rand(256, 256).astype(np.float32) - cpu = scipy_shift(arr, (3.0, -5.0), order=1, prefilter=False) - gpu = shift_array(arr, (3.0, -5.0)) - np.testing.assert_allclose(gpu, cpu, rtol=1e-4, atol=1e-4) - - -def test_subpixel_mean_error(): - arr = np.random.rand(256, 256).astype(np.float32) - cpu = scipy_shift(arr, (5.5, -3.2), order=1, prefilter=False) - gpu = shift_array(arr, (5.5, -3.2)) - mean_diff = np.abs(cpu - gpu).mean() - assert mean_diff < 0.01, f"Mean diff {mean_diff} too high" - - -def test_zero_shift(): - arr = np.random.rand(256, 256).astype(np.float32) - result = shift_array(arr, (0.0, 0.0)) - # Allow small tolerance due to grid_sample interpolation - np.testing.assert_allclose(result, arr, rtol=1e-4, atol=1e-4) +class TestShiftArray: + """Tests for shift_array function.""" + + def test_integer_shift(self, rng): + """Test integer shift matches scipy.""" + arr = rng.random((256, 256)).astype(np.float32) + cpu = scipy_shift(arr, (3.0, -5.0), order=1, prefilter=False) + gpu = shift_array(arr, (3.0, -5.0)) + np.testing.assert_allclose(gpu, cpu, rtol=1e-4, atol=1e-4) + + def test_subpixel_mean_error(self, rng): + """Test subpixel shift has low mean error vs scipy.""" + arr = rng.random((256, 256)).astype(np.float32) + cpu = scipy_shift(arr, (5.5, -3.2), order=1, prefilter=False) + gpu = shift_array(arr, (5.5, -3.2)) + mean_diff = np.abs(cpu - gpu).mean() + assert mean_diff < 0.01, f"Mean diff {mean_diff} too high" + + def test_zero_shift(self, rng): + """Test zero shift returns nearly identical array.""" + arr = rng.random((256, 256)).astype(np.float32) + result = shift_array(arr, (0.0, 0.0)) + # Allow small tolerance due to grid_sample interpolation + np.testing.assert_allclose(result, arr, rtol=1e-4, atol=1e-4) + + def test_small_array(self, rng): + """Test shift works on small arrays (edge case).""" + arr = rng.random((4, 4)).astype(np.float32) + result = shift_array(arr, (1.0, 1.0)) + assert result.shape == arr.shape + + def test_1pixel_fallback(self): + """Test 1-pixel array falls back to CPU.""" + arr = np.array([[1.0]], dtype=np.float32) + result = shift_array(arr, (0.5, 0.5)) + assert result.shape == (1, 1) if __name__ == "__main__": - test_integer_shift() - test_subpixel_mean_error() - test_zero_shift() - print("All tests passed") + pytest.main([__file__, "-v"]) diff --git a/tests/test_ssim.py b/tests/test_ssim.py index 694e557..2bc71b3 100644 --- a/tests/test_ssim.py +++ b/tests/test_ssim.py @@ -1,6 +1,7 @@ """Unit tests for GPU SSIM.""" import numpy as np +import pytest import sys sys.path.insert(0, "src") @@ -9,32 +10,40 @@ from skimage.metrics import structural_similarity as skimage_ssim -def test_ssim_similar_images(): - arr1 = np.random.rand(256, 256).astype(np.float32) - arr2 = arr1 + np.random.rand(256, 256).astype(np.float32) * 0.1 +class TestComputeSSIM: + """Tests for compute_ssim function.""" - data_range = arr1.max() - arr1.min() - cpu = skimage_ssim(arr1, arr2, win_size=15, data_range=data_range) - gpu = compute_ssim(arr1, arr2, win_size=15) + def test_ssim_similar_images(self, rng): + """Test SSIM of similar images matches skimage.""" + arr1 = rng.random((256, 256)).astype(np.float32) + arr2 = arr1 + rng.random((256, 256)).astype(np.float32) * 0.1 - assert abs(cpu - gpu) < 0.01, f"SSIM diff {abs(cpu-gpu)} too high" + data_range = arr1.max() - arr1.min() + cpu = skimage_ssim(arr1, arr2, win_size=15, data_range=data_range) + gpu = compute_ssim(arr1, arr2, win_size=15) + assert abs(cpu - gpu) < 0.01, f"SSIM diff {abs(cpu - gpu)} too high" -def test_ssim_identical_images(): - arr = np.random.rand(256, 256).astype(np.float32) - ssim = compute_ssim(arr, arr, win_size=15) - assert ssim > 0.99, f"SSIM of identical images should be ~1.0, got {ssim}" + def test_ssim_identical_images(self, rng): + """Test SSIM of identical images is ~1.0.""" + arr = rng.random((256, 256)).astype(np.float32) + ssim = compute_ssim(arr, arr, win_size=15) + assert ssim > 0.99, f"SSIM of identical images should be ~1.0, got {ssim}" + def test_ssim_different_images(self, rng): + """Test SSIM of random images is low.""" + arr1 = rng.random((256, 256)).astype(np.float32) + arr2 = rng.random((256, 256)).astype(np.float32) + ssim = compute_ssim(arr1, arr2, win_size=15) + assert ssim < 0.5, f"SSIM of random images should be low, got {ssim}" -def test_ssim_different_images(): - arr1 = np.random.rand(256, 256).astype(np.float32) - arr2 = np.random.rand(256, 256).astype(np.float32) - ssim = compute_ssim(arr1, arr2, win_size=15) - assert ssim < 0.5, f"SSIM of random images should be low, got {ssim}" + def test_ssim_range(self, rng): + """Test SSIM is in valid range [0, 1].""" + arr1 = rng.random((128, 128)).astype(np.float32) + arr2 = rng.random((128, 128)).astype(np.float32) + ssim = compute_ssim(arr1, arr2, win_size=7) + assert 0.0 <= ssim <= 1.0, f"SSIM {ssim} outside valid range" if __name__ == "__main__": - test_ssim_similar_images() - test_ssim_identical_images() - test_ssim_different_images() - print("All tests passed") + pytest.main([__file__, "-v"]) From a5ab6545a5c4a4bc6c7a83164b0057e8d7b8e48c Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:20:54 -0800 Subject: [PATCH 14/18] chore: Clean up unused imports and variables --- src/tilefusion/utils.py | 2 +- tests/test_block_reduce.py | 31 ++++++++++++++++++++----------- tests/test_fft.py | 2 +- tests/test_histogram_match.py | 2 +- tests/test_shift_array.py | 2 +- tests/test_ssim.py | 2 +- 6 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 056d50e..7dd0ddb 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -292,7 +292,7 @@ def _match_histograms_torch(image: np.ndarray, reference: np.ndarray) -> np.ndar ref = torch.from_numpy(reference.astype(np.float32)).cuda().flatten() # Get sorted indices - img_sorted, img_indices = torch.sort(img) + _, img_indices = torch.sort(img) ref_sorted, _ = torch.sort(ref) # Create inverse mapping (rank of each pixel) diff --git a/tests/test_block_reduce.py b/tests/test_block_reduce.py index 961fc2b..95ad033 100644 --- a/tests/test_block_reduce.py +++ b/tests/test_block_reduce.py @@ -2,21 +2,20 @@ import numpy as np import pytest -from skimage.measure import block_reduce as skimage_block_reduce - import sys +from skimage.measure import block_reduce as skimage_block_reduce sys.path.insert(0, "src") -from tilefusion.utils import block_reduce, CUDA_AVAILABLE +from tilefusion.utils import block_reduce class TestBlockReduce: """Test block_reduce GPU vs CPU equivalence.""" - def test_2d_basic(self): + def test_2d_basic(self, rng): """Test 2D block reduce matches skimage.""" - arr = np.random.rand(256, 256).astype(np.float32) + arr = rng.random((256, 256)).astype(np.float32) block_size = (4, 4) result = block_reduce(arr, block_size, np.mean) @@ -24,9 +23,9 @@ def test_2d_basic(self): np.testing.assert_allclose(result, expected, rtol=1e-5) - def test_2d_large(self): + def test_2d_large(self, rng): """Test larger 2D array.""" - arr = np.random.rand(1024, 1024).astype(np.float32) + arr = rng.random((1024, 1024)).astype(np.float32) block_size = (8, 8) result = block_reduce(arr, block_size, np.mean) @@ -34,9 +33,9 @@ def test_2d_large(self): np.testing.assert_allclose(result, expected, rtol=1e-5) - def test_3d_multichannel(self): + def test_3d_multichannel(self, rng): """Test 3D array with channel dimension.""" - arr = np.random.rand(3, 256, 256).astype(np.float32) + arr = rng.random((3, 256, 256)).astype(np.float32) block_size = (1, 4, 4) result = block_reduce(arr, block_size, np.mean) @@ -44,15 +43,25 @@ def test_3d_multichannel(self): np.testing.assert_allclose(result, expected, rtol=1e-5) - def test_output_shape(self): + def test_output_shape(self, rng): """Test output shape is correct.""" - arr = np.random.rand(512, 512).astype(np.float32) + arr = rng.random((512, 512)).astype(np.float32) block_size = (4, 4) result = block_reduce(arr, block_size, np.mean) assert result.shape == (128, 128) + def test_non_divisible_shape(self, rng): + """Test block reduce with non-divisible dimensions.""" + arr = rng.random((100, 100)).astype(np.float32) + block_size = (8, 8) + + result = block_reduce(arr, block_size, np.mean) + expected = skimage_block_reduce(arr, block_size, np.mean) + + np.testing.assert_allclose(result, expected, rtol=1e-5) + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_fft.py b/tests/test_fft.py index ea69f6f..27e19bd 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -6,7 +6,7 @@ sys.path.insert(0, "src") -from tilefusion.utils import phase_cross_correlation, CUDA_AVAILABLE +from tilefusion.utils import phase_cross_correlation from skimage.registration import phase_cross_correlation as skimage_pcc diff --git a/tests/test_histogram_match.py b/tests/test_histogram_match.py index e0aa402..26fa774 100644 --- a/tests/test_histogram_match.py +++ b/tests/test_histogram_match.py @@ -6,7 +6,7 @@ sys.path.insert(0, "src") -from tilefusion.utils import match_histograms, CUDA_AVAILABLE +from tilefusion.utils import match_histograms from skimage.exposure import match_histograms as skimage_match diff --git a/tests/test_shift_array.py b/tests/test_shift_array.py index 530a7a5..4843eeb 100644 --- a/tests/test_shift_array.py +++ b/tests/test_shift_array.py @@ -6,7 +6,7 @@ sys.path.insert(0, "src") -from tilefusion.utils import shift_array, CUDA_AVAILABLE +from tilefusion.utils import shift_array from scipy.ndimage import shift as scipy_shift diff --git a/tests/test_ssim.py b/tests/test_ssim.py index 2bc71b3..d143889 100644 --- a/tests/test_ssim.py +++ b/tests/test_ssim.py @@ -6,7 +6,7 @@ sys.path.insert(0, "src") -from tilefusion.utils import compute_ssim, CUDA_AVAILABLE +from tilefusion.utils import compute_ssim from skimage.metrics import structural_similarity as skimage_ssim From d7cf01725f8f048960626161798d78a5af1c8975 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:27:42 -0800 Subject: [PATCH 15/18] Add type hints to public functions in utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add type hints to phase_cross_correlation, shift_array, match_histograms, block_reduce, compute_ssim - Add return type hints to to_numpy and to_device - Import Callable, Any, Union from typing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/tilefusion/utils.py | 45 +++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 7dd0ddb..5082ba2 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -5,6 +5,8 @@ All functions support GPU acceleration via PyTorch with automatic CPU fallback. """ +from typing import Any, Callable, Union + import numpy as np __all__ = [ @@ -67,7 +69,12 @@ # ============================================================================= -def phase_cross_correlation(reference_image, moving_image, upsample_factor=1, **kwargs): +def phase_cross_correlation( + reference_image: np.ndarray, + moving_image: np.ndarray, + upsample_factor: int = 1, + **kwargs, +) -> tuple[np.ndarray, float, float]: """ Phase cross-correlation using GPU (torch FFT) or CPU (skimage). @@ -182,7 +189,11 @@ def _subpixel_refine_torch(correlation, peak_y, peak_x, h, w): # ============================================================================= -def shift_array(arr, shift_vec, preserve_dtype=True): +def shift_array( + arr: np.ndarray, + shift_vec: tuple[float, float], + preserve_dtype: bool = True, +) -> np.ndarray: """ Shift array by subpixel amounts using GPU (torch) or CPU (scipy). @@ -190,7 +201,7 @@ def shift_array(arr, shift_vec, preserve_dtype=True): ---------- arr : ndarray 2D input array. - shift_vec : array-like + shift_vec : tuple[float, float] (dy, dx) shift amounts. preserve_dtype : bool If True, output dtype matches input dtype. Default True. @@ -253,7 +264,11 @@ def _shift_array_torch(arr: np.ndarray, shift_vec) -> np.ndarray: # ============================================================================= -def match_histograms(image, reference, preserve_dtype=True): +def match_histograms( + image: np.ndarray, + reference: np.ndarray, + preserve_dtype: bool = True, +) -> np.ndarray: """ Match histogram of image to reference using GPU (torch) or CPU (skimage). @@ -314,7 +329,12 @@ def _match_histograms_torch(image: np.ndarray, reference: np.ndarray) -> np.ndar # ============================================================================= -def block_reduce(arr, block_size, func=np.mean, preserve_dtype=True): +def block_reduce( + arr: np.ndarray, + block_size: tuple[int, ...], + func: Callable = np.mean, + preserve_dtype: bool = True, +) -> np.ndarray: """ Block reduce array using GPU (torch) or CPU (skimage). @@ -322,9 +342,9 @@ def block_reduce(arr, block_size, func=np.mean, preserve_dtype=True): ---------- arr : ndarray Input array (2D or 3D with channel dim first). - block_size : tuple + block_size : tuple[int, ...] Reduction factors per dimension. - func : callable + func : Callable Reduction function (only np.mean supported on GPU). preserve_dtype : bool If True, output dtype matches input dtype. Default True. @@ -377,7 +397,7 @@ def _block_reduce_torch(arr: np.ndarray, block_size: tuple) -> np.ndarray: # ============================================================================= -def compute_ssim(arr1, arr2, win_size: int) -> float: +def compute_ssim(arr1: np.ndarray, arr2: np.ndarray, win_size: int) -> float: """ Compute SSIM using GPU (torch) or CPU (skimage). @@ -472,15 +492,18 @@ def make_1d_profile(length: int, blend: int) -> np.ndarray: return prof -def to_numpy(arr): +def to_numpy(arr) -> np.ndarray: """Convert array to numpy, handling both CPU and GPU arrays.""" if TORCH_AVAILABLE and torch is not None and isinstance(arr, torch.Tensor): return arr.cpu().numpy() return np.asarray(arr) -def to_device(arr): - """Move array to GPU if available, else return numpy array.""" +def to_device(arr) -> Union[Any, np.ndarray]: + """Move array to GPU if available, else return numpy array. + + Returns torch.Tensor on GPU if CUDA available, else np.ndarray. + """ if CUDA_AVAILABLE: return torch.from_numpy(np.asarray(arr)).cuda() return np.asarray(arr) From 66fd2e903234d57c27e4d422dc08a3f2bfc525be Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:34:04 -0800 Subject: [PATCH 16/18] Fix histogram matching bug and improve type hints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix incorrect pixel assignment in _match_histograms_torch The previous code used unnecessary indexing that permuted results incorrectly - Simplify to_device return type from Union[Any, np.ndarray] to Any - Remove unused Union import - Add pixel-by-pixel test comparing GPU vs skimage results 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/tilefusion/utils.py | 9 ++++----- tests/test_histogram_match.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 5082ba2..668370a 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -5,7 +5,7 @@ All functions support GPU acceleration via PyTorch with automatic CPU fallback. """ -from typing import Any, Callable, Union +from typing import Any, Callable import numpy as np @@ -315,9 +315,8 @@ def _match_histograms_torch(image: np.ndarray, reference: np.ndarray) -> np.ndar inv_indices[img_indices] = torch.arange(len(img), device="cuda") # Map image values to reference values via quantile matching - # For each pixel, find corresponding quantile in reference - interp_values = torch.zeros_like(img) - interp_values[img_indices] = ref_sorted[ + # inv_indices[i] = rank of pixel i, so look up ref value at that scaled rank + interp_values = ref_sorted[ (inv_indices.float() / len(img) * len(ref)).long().clamp(0, len(ref) - 1) ] @@ -499,7 +498,7 @@ def to_numpy(arr) -> np.ndarray: return np.asarray(arr) -def to_device(arr) -> Union[Any, np.ndarray]: +def to_device(arr) -> Any: """Move array to GPU if available, else return numpy array. Returns torch.Tensor on GPU if CUDA available, else np.ndarray. diff --git a/tests/test_histogram_match.py b/tests/test_histogram_match.py index 26fa774..9223c50 100644 --- a/tests/test_histogram_match.py +++ b/tests/test_histogram_match.py @@ -48,6 +48,17 @@ def test_different_sizes(self, rng): result = match_histograms(img, ref) assert result.shape == img.shape + def test_pixel_values_match_skimage(self, rng): + """Test pixel-by-pixel matching against skimage.""" + img = rng.random((64, 64)).astype(np.float32) + ref = rng.random((64, 64)).astype(np.float32) * 2 + 1 + + cpu = skimage_match(img, ref) + gpu = match_histograms(img, ref) + + # Pixel values should be close (not just histogram shape) + np.testing.assert_allclose(gpu, cpu, rtol=1e-4, atol=1e-4) + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 876d400284e21b2474f25f2f4e03dca779640d64 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 15:48:04 -0800 Subject: [PATCH 17/18] Cosmetic cleanups and fix shift_array CPU/GPU consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove redundant `torch is not None` check in to_numpy - Add type hint to _shift_array_torch shift_vec parameter - Fix shift_array CPU path to compute in float64 for API consistency (preserve_dtype=False now returns float on both GPU and CPU paths) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/tilefusion/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 668370a..61ab205 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -217,14 +217,16 @@ def shift_array( if CUDA_AVAILABLE and arr_np.ndim == 2: result = _shift_array_torch(arr_np, shift_vec) else: - result = _shift_cpu(arr_np, shift=shift_vec, order=1, prefilter=False) + # Compute in float for consistency with GPU path + arr_float = arr_np.astype(np.float64) + result = _shift_cpu(arr_float, shift=shift_vec, order=1, prefilter=False) if preserve_dtype and result.dtype != original_dtype: return result.astype(original_dtype) return result -def _shift_array_torch(arr: np.ndarray, shift_vec) -> np.ndarray: +def _shift_array_torch(arr: np.ndarray, shift_vec: tuple[float, float]) -> np.ndarray: """GPU shift using torch.nn.functional.grid_sample.""" h, w = arr.shape @@ -493,7 +495,7 @@ def make_1d_profile(length: int, blend: int) -> np.ndarray: def to_numpy(arr) -> np.ndarray: """Convert array to numpy, handling both CPU and GPU arrays.""" - if TORCH_AVAILABLE and torch is not None and isinstance(arr, torch.Tensor): + if TORCH_AVAILABLE and isinstance(arr, torch.Tensor): return arr.cpu().numpy() return np.asarray(arr) From 5fd466f539e8904160999b794e748f51e4139299 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Sun, 4 Jan 2026 16:17:52 -0800 Subject: [PATCH 18/18] Document GPU path placeholder values in phase_cross_correlation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The GPU implementation returns 0.0 for error and phasediff values since these are not computed. Added notes to docstring to clarify. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/tilefusion/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index 61ab205..944d6ec 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -92,9 +92,11 @@ def phase_cross_correlation( shift : ndarray Shift vector (y, x). error : float - Translation invariant normalized RMS error (placeholder). + Translation invariant normalized RMS error. + Note: GPU path returns 0.0 (not computed). phasediff : float - Global phase difference (placeholder). + Global phase difference. + Note: GPU path returns 0.0 (not computed). """ ref_np = np.asarray(reference_image) mov_np = np.asarray(moving_image)