From 8ad82f46ec3c86e62ca8c2037c293a48144e98ee Mon Sep 17 00:00:00 2001 From: Deepika Sundarraman Date: Wed, 12 Mar 2025 16:53:33 -0700 Subject: [PATCH 1/7] modified fancy threshold function to incorporate max foreground to clip hist beyond these values --- dexpv2/segmentation.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/dexpv2/segmentation.py b/dexpv2/segmentation.py index 90effa6..ff43766 100644 --- a/dexpv2/segmentation.py +++ b/dexpv2/segmentation.py @@ -46,7 +46,10 @@ def reconstruction_by_dilation( def fancy_otsu_threshold( - image: ArrayLike, remove_hist_mode: bool = False, min_foreground: float = 0.0 + image: ArrayLike, + remove_hist_mode: bool = False, + min_foreground: float = 0.0, + max_foreground: float = 0.0, ) -> float: """ Compute Otsu threshold with some additional features. @@ -61,6 +64,8 @@ def fancy_otsu_threshold( Removes histogram mode before computing otsu threshold, useful when background regions are being detected. min_foreground : float, optional Minimum threshold value, by default 0.0 + max_foreground: float, optional + Maximum threshold value, by default max value of image Returns ------- @@ -86,6 +91,11 @@ def fancy_otsu_threshold( hist, bin_centers = exposure.histogram(image, nbins) + # clip bins and histogram beyond max_foreground value + if max_foreground != 0.0: + hist = [hist[i] for i in range(len(hist)) if bins[i] < max_foreground] + bins = [bin for bin in bins if bin < max_foreground] + # histogram disconsidering pixels we are sure are background if remove_hist_mode: remaining_background_idx = hist.argmax() + 1 From 49d60b502953ad35f666d3b5039cd8584bc93afd Mon Sep 17 00:00:00 2001 From: dsundarraman <35620952+dsundarraman@users.noreply.github.com> Date: Fri, 14 Mar 2025 10:17:32 -0700 Subject: [PATCH 2/7] Update dexpv2/segmentation.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jordão Bragantini --- dexpv2/segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dexpv2/segmentation.py b/dexpv2/segmentation.py index ff43766..e2044db 100644 --- a/dexpv2/segmentation.py +++ b/dexpv2/segmentation.py @@ -49,7 +49,7 @@ def fancy_otsu_threshold( image: ArrayLike, remove_hist_mode: bool = False, min_foreground: float = 0.0, - max_foreground: float = 0.0, + max_foreground: Optional[float] = None, ) -> float: """ Compute Otsu threshold with some additional features. From 98b2aed399bd2aabc016218eee0d012da34129bf Mon Sep 17 00:00:00 2001 From: dsundarraman <35620952+dsundarraman@users.noreply.github.com> Date: Fri, 14 Mar 2025 10:17:41 -0700 Subject: [PATCH 3/7] Update dexpv2/segmentation.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jordão Bragantini --- dexpv2/segmentation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dexpv2/segmentation.py b/dexpv2/segmentation.py index e2044db..454965e 100644 --- a/dexpv2/segmentation.py +++ b/dexpv2/segmentation.py @@ -92,9 +92,10 @@ def fancy_otsu_threshold( hist, bin_centers = exposure.histogram(image, nbins) # clip bins and histogram beyond max_foreground value - if max_foreground != 0.0: - hist = [hist[i] for i in range(len(hist)) if bins[i] < max_foreground] - bins = [bin for bin in bins if bin < max_foreground] + if max_foreground is not None: + below_threshold_mask = bin < max_foreground + bins = bins[below_threshold_mask] + hist = hist[below_threshold_mask] # histogram disconsidering pixels we are sure are background if remove_hist_mode: From be1d9045ddef3894ecec0b27ba4c91810d66e7b4 Mon Sep 17 00:00:00 2001 From: Deepika Sundarraman Date: Fri, 14 Mar 2025 14:21:34 -0700 Subject: [PATCH 4/7] fixed issue with max foreground parameter and adding sigma to alter thresholding --- dexpv2/segmentation.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dexpv2/segmentation.py b/dexpv2/segmentation.py index 454965e..aa3d348 100644 --- a/dexpv2/segmentation.py +++ b/dexpv2/segmentation.py @@ -49,7 +49,7 @@ def fancy_otsu_threshold( image: ArrayLike, remove_hist_mode: bool = False, min_foreground: float = 0.0, - max_foreground: Optional[float] = None, + max_foreground: float = None, ) -> float: """ Compute Otsu threshold with some additional features. @@ -90,11 +90,12 @@ def fancy_otsu_threshold( LOG.info(f"Histogram with {nbins}") hist, bin_centers = exposure.histogram(image, nbins) - + print(len(bin_centers)) # clip bins and histogram beyond max_foreground value if max_foreground is not None: - below_threshold_mask = bin < max_foreground - bins = bins[below_threshold_mask] + below_threshold_mask = bin_centers < np.sqrt(max_foreground) + bin_centers = bin_centers[below_threshold_mask] + print(bin_centers) hist = hist[below_threshold_mask] # histogram disconsidering pixels we are sure are background @@ -109,7 +110,6 @@ def fancy_otsu_threshold( threshold = max(threshold, min_foreground) LOG.info(f"Threshold after minimum filtering {threshold}") - return threshold @@ -159,6 +159,7 @@ def detect_foreground( sigma: float = 15.0, remove_hist_mode: bool = False, min_foreground: float = 0.0, + max_foreground: float = None, ) -> ArrayLike: """ Detect foreground using morphological reconstruction by dilation and thresholding. @@ -194,6 +195,7 @@ def detect_foreground( small_foreground, remove_hist_mode=remove_hist_mode, min_foreground=min_foreground, + max_foreground=max_foreground, ) mask = foreground > threshold From 1c3a63e5b60fb0b8820d5f8fd6005db16fccf532 Mon Sep 17 00:00:00 2001 From: "ilan.theodoro" Date: Wed, 7 May 2025 13:37:15 -0700 Subject: [PATCH 5/7] Fix bug of grey_dilation on cp.float16 data --- dexpv2/_tests/test_segmentation.py | 31 +++++- dexpv2/segmentation.py | 149 ++++++++++++++++++++++++++++- 2 files changed, 177 insertions(+), 3 deletions(-) diff --git a/dexpv2/_tests/test_segmentation.py b/dexpv2/_tests/test_segmentation.py index 71ef18e..b4cb562 100644 --- a/dexpv2/_tests/test_segmentation.py +++ b/dexpv2/_tests/test_segmentation.py @@ -1,8 +1,9 @@ import logging from skimage.data import cells3d +import pytest -from dexpv2.segmentation import detect_foreground +from dexpv2.segmentation import detect_foreground, reconstruction_by_dilation from dexpv2.utils import to_cpu LOG = logging.getLogger(__name__) @@ -32,3 +33,31 @@ def test_foreground_detection(interactive_test: bool) -> None: viewer.add_labels(to_cpu(foreground)) napari.run() + + +def test_foreground_detection_with_float16() -> None: + # Test with float16 dat + # a + nuclei = xp.asarray(cells3d()[:, 1]) + nuclei = nuclei / nuclei.max() + nuclei = nuclei.astype(xp.float16) + mask = xp.copy(nuclei) + + # Ensure we are using cupy backend + import numpy as np + + if isinstance(nuclei, np.ndarray): + pytest.skip("Skipping test as cupy is not available.") + + foreground_cp = reconstruction_by_dilation(nuclei, mask, iterations=10) + foreground_cp = to_cpu(foreground_cp) + + # Convert to numpy for comparison + # Obs. skimage operations won't work with np.float16 so we need to convert + # to float32 and hope that the conversion doesn't change the result too much + nuclei_np = nuclei.astype(xp.float32) + mask_np = mask.astype(xp.float32) + foreground_np = reconstruction_by_dilation(nuclei_np, mask_np, iterations=10) + + # Check that the output is a binary mask + assert np.allclose(foreground_cp, foreground_np) diff --git a/dexpv2/segmentation.py b/dexpv2/segmentation.py index aa3d348..92c91db 100644 --- a/dexpv2/segmentation.py +++ b/dexpv2/segmentation.py @@ -1,4 +1,5 @@ import logging +from typing import Tuple, List import numpy as np from numpy.typing import ArrayLike @@ -10,6 +11,135 @@ LOG = logging.getLogger(__name__) LOG.setLevel(logging.INFO) +try: + import cupy as xp + + LOG.info("cupy found.") +except (ModuleNotFoundError, ImportError): + import numpy as xp + + LOG.info("cupy not found using numpy.") + + +def discretize_multiple_float16_to_uint16( + float16_arrays: List["xp.ndarray"], # type: ignore # 'xp' will be defined at runtime +) -> Tuple[List["xp.ndarray"], "xp.ndarray"]: # type: ignore + """ + Discretizes multiple arrays (e.g., CuPy or NumPy) of float16 values to uint16, + preserving order using a global mapping across all arrays. + + Args: + float16_arrays (List[xp.ndarray]): A list of arrays (e.g., CuPy or NumPy), + each with dtype float16. + 'xp' should be the array module (e.g., numpy or cupy). + + Returns: + tuple: A tuple containing: + - y_uint16_list (List[xp.ndarray]): A list of discretized arrays, + each with dtype uint16, corresponding + to the input arrays. + - uint16_to_float16_lookup (xp.ndarray): A single lookup table + (array of float16) for all arrays, where the index + is the uint16 value and the value is the corresponding + original float16 value. + + Raises: + TypeError: If any input array's dtype is not float16. + ValueError: If the list of arrays is empty, or if the total number + of unique values across all arrays exceeds the capacity + of uint16 (65536). + """ + # Ensure xp is defined (this is more of a runtime check if not using static analysis) + if "xp" not in globals() and "xp" not in locals(): + raise NameError( + "Array library 'xp' is not defined. Please import numpy as xp or cupy as xp." + ) + + if not float16_arrays: + raise ValueError("Input list of arrays cannot be empty.") + + # Validate input types and collect original shapes and sizes + original_shapes = [] + original_sizes = [] + for i, arr in enumerate(float16_arrays): + # Assuming 'xp' is defined, xp.ndarray would be the type to check against + # For simplicity and following user's snippet, primarily checking dtype. + if not hasattr(arr, "dtype") or arr.dtype != xp.float16: + raise TypeError( + f"Array at index {i} must be an 'xp.ndarray' with dtype xp.float16. " + f"Got type {type(arr)} with dtype {getattr(arr, 'dtype', 'N/A')}." + ) + original_shapes.append(arr.shape) + original_sizes.append(arr.size) + + # Handle case where all arrays combined are empty + if sum(original_sizes) == 0: + # Create empty uint16 arrays with original shapes + empty_uint16_list = [ + xp.array([], dtype=xp.uint16).reshape(shape) for shape in original_shapes + ] + return empty_uint16_list, xp.array([], dtype=xp.float16) + + # Concatenate all arrays into a single flat array for global unique value finding. + # We need to ensure that we only concatenate non-empty arrays if ravel() on empty + # arrays with certain shapes causes issues, or handle shapes carefully. + # xp.concatenate([arr.ravel() for arr in float16_arrays]) should generally work. + # Ravel ensures that even multi-dimensional arrays become 1D before concatenation. + try: + combined_float16_array = xp.concatenate([arr.ravel() for arr in float16_arrays]) + except Exception as e: + raise ValueError( + f"Error during concatenation of arrays: {e}. Ensure 'xp' is correctly defined (NumPy/CuPy)." + ) + + # Find unique values and their inverse indices from the combined array. + # unique_values will be sorted, which is crucial for order preservation. + # inverse_indices will correspond to the flattened combined_float16_array. + unique_values: "xp.ndarray" # type: ignore + inverse_indices: "xp.ndarray" # type: ignore + unique_values, inverse_indices = xp.unique( + combined_float16_array, return_inverse=True + ) + + # The unique_values array serves as the global uint16 to float16 lookup table. + uint16_to_float16_lookup: "xp.ndarray" = unique_values # type: ignore + + # Check if the number of unique values fits into uint16 + # xp.iinfo(xp.uint16).max gives the max value (e.g., 65535). + # Number of unique values can be up to (max_value + 1). + if len(unique_values) > xp.iinfo(xp.uint16).max + 1: + raise ValueError( + f"Number of unique values ({len(unique_values)}) across all arrays " + f"exceeds the maximum capacity of uint16 ({xp.iinfo(xp.uint16).max + 1})." + ) + + # The inverse_indices array contains the uint16 representations for the combined flat array. + # Cast it to uint16. + y_uint16_combined_flat: "xp.ndarray" = inverse_indices.astype(xp.uint16) # type: ignore + + # Split the combined flat uint16 array back into individual arrays and reshape them + y_uint16_list: List["xp.ndarray"] = [] # type: ignore + current_pos = 0 + for i in range(len(float16_arrays)): + size = original_sizes[i] + shape = original_shapes[i] + + if size == 0: + # Create an empty uint16 array with the original shape + y_uint16_list.append(xp.array([], dtype=xp.uint16).reshape(shape)) + else: + segment = y_uint16_combined_flat[current_pos : current_pos + size] + y_uint16_list.append(segment.reshape(shape)) + current_pos += size + + if current_pos != y_uint16_combined_flat.size: + # This should not happen if logic is correct, but good for sanity check + raise AssertionError( + "Mismatch in processed elements during splitting of combined array." + ) + + return y_uint16_list, uint16_to_float16_lookup + def reconstruction_by_dilation( seed: ArrayLike, mask: ArrayLike, iterations: int @@ -34,14 +164,29 @@ def reconstruction_by_dilation( ------- Image reconstructed by dilation. """ - ndi = import_module("scipy", "ndimage") + ndi = import_module("scipy", "ndimage", seed) + + import numpy as np - seed = np.minimum(seed, mask, out=seed) # just making sure + cupy_used = np != xp and not isinstance(seed, np.ndarray) + + lut = None + # quick-fix for the issue https://github.com/cupy/cupy/issues/9122 + if cupy_used and seed.dtype == xp.float16: + arrs, lut = discretize_multiple_float16_to_uint16([seed, mask]) + seed = arrs[0] + mask = arrs[1] + + seed = np.minimum(seed, mask, out=seed) for _ in range(iterations): seed = ndi.grey_dilation(seed, size=3, output=seed, mode="constant") seed = np.minimum(seed, mask, out=seed) + if lut is not None: + # convert back to float16 + seed = xp.take(lut, seed) + return seed From f334bb24d030374dd3f16733e1d85585e19d1c57 Mon Sep 17 00:00:00 2001 From: "ilan.theodoro" Date: Wed, 7 May 2025 13:45:33 -0700 Subject: [PATCH 6/7] Improve testing on new reconstruction function --- dexpv2/_tests/test_segmentation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dexpv2/_tests/test_segmentation.py b/dexpv2/_tests/test_segmentation.py index b4cb562..246e73f 100644 --- a/dexpv2/_tests/test_segmentation.py +++ b/dexpv2/_tests/test_segmentation.py @@ -52,12 +52,17 @@ def test_foreground_detection_with_float16() -> None: foreground_cp = reconstruction_by_dilation(nuclei, mask, iterations=10) foreground_cp = to_cpu(foreground_cp) + nuclei_f32 = nuclei.astype(xp.float32) + mask_f32 = mask.astype(xp.float32) + foreground_f32 = reconstruction_by_dilation(nuclei_f32, mask_f32, iterations=10) + # Convert to numpy for comparison # Obs. skimage operations won't work with np.float16 so we need to convert # to float32 and hope that the conversion doesn't change the result too much - nuclei_np = nuclei.astype(xp.float32) - mask_np = mask.astype(xp.float32) + nuclei_np = to_cpu(nuclei_f32) + mask_np = to_cpu(mask_f32) foreground_np = reconstruction_by_dilation(nuclei_np, mask_np, iterations=10) # Check that the output is a binary mask assert np.allclose(foreground_cp, foreground_np) + assert np.allclose(foreground_cp, foreground_f32) From e137d11ebe30eb5eb94ccea55fe250f0baefa4b9 Mon Sep 17 00:00:00 2001 From: "ilan.theodoro" Date: Thu, 8 May 2025 18:26:19 -0700 Subject: [PATCH 7/7] Improve fp16 discretization to reduce memory usage --- dexpv2/_tests/test_segmentation.py | 3 +- dexpv2/segmentation.py | 157 +++++++++++------------------ 2 files changed, 58 insertions(+), 102 deletions(-) diff --git a/dexpv2/_tests/test_segmentation.py b/dexpv2/_tests/test_segmentation.py index 246e73f..c904116 100644 --- a/dexpv2/_tests/test_segmentation.py +++ b/dexpv2/_tests/test_segmentation.py @@ -36,8 +36,7 @@ def test_foreground_detection(interactive_test: bool) -> None: def test_foreground_detection_with_float16() -> None: - # Test with float16 dat - # a + # Test with float16 data nuclei = xp.asarray(cells3d()[:, 1]) nuclei = nuclei / nuclei.max() nuclei = nuclei.astype(xp.float16) diff --git a/dexpv2/segmentation.py b/dexpv2/segmentation.py index 92c91db..bf00078 100644 --- a/dexpv2/segmentation.py +++ b/dexpv2/segmentation.py @@ -21,124 +21,83 @@ LOG.info("cupy not found using numpy.") -def discretize_multiple_float16_to_uint16( - float16_arrays: List["xp.ndarray"], # type: ignore # 'xp' will be defined at runtime -) -> Tuple[List["xp.ndarray"], "xp.ndarray"]: # type: ignore +def discretize_multiple_f16_to_u16( + f16_arrays: List[ArrayLike], +) -> Tuple[List[ArrayLike], ArrayLike]: """ - Discretizes multiple arrays (e.g., CuPy or NumPy) of float16 values to uint16, - preserving order using a global mapping across all arrays. - - Args: - float16_arrays (List[xp.ndarray]): A list of arrays (e.g., CuPy or NumPy), - each with dtype float16. - 'xp' should be the array module (e.g., numpy or cupy). - - Returns: - tuple: A tuple containing: - - y_uint16_list (List[xp.ndarray]): A list of discretized arrays, - each with dtype uint16, corresponding - to the input arrays. - - uint16_to_float16_lookup (xp.ndarray): A single lookup table - (array of float16) for all arrays, where the index - is the uint16 value and the value is the corresponding - original float16 value. - - Raises: + Discretizes multiple arrays (e.g., CuPy or NumPy) of float16 values to + uint16, preserving order using a global mapping across all arrays. + + Parameters + ---------- + f16_arrays : List[ArrayLike] + List of input arrays to be discretized. Each array must have dtype + float16. + + Returns + ------- + Tuple[List[ArrayLike], ArrayLike]: A tuple containing: + - u16_list (List[ArrayLike]): A list of discretized arrays, + each with dtype uint16, + corresponding to the input + arrays. + - u16_to_f16_lut (ArrayLike): A single lookup table (array of + float16) for all arrays, where the + index is the uint16 value and the + value is the corresponding original + float16 value. + + Raises + ------ TypeError: If any input array's dtype is not float16. ValueError: If the list of arrays is empty, or if the total number of unique values across all arrays exceeds the capacity of uint16 (65536). """ - # Ensure xp is defined (this is more of a runtime check if not using static analysis) - if "xp" not in globals() and "xp" not in locals(): - raise NameError( - "Array library 'xp' is not defined. Please import numpy as xp or cupy as xp." - ) - - if not float16_arrays: + if not f16_arrays: raise ValueError("Input list of arrays cannot be empty.") # Validate input types and collect original shapes and sizes original_shapes = [] - original_sizes = [] - for i, arr in enumerate(float16_arrays): - # Assuming 'xp' is defined, xp.ndarray would be the type to check against - # For simplicity and following user's snippet, primarily checking dtype. - if not hasattr(arr, "dtype") or arr.dtype != xp.float16: + for i, arr in enumerate(f16_arrays): + if arr.dtype != xp.float16: raise TypeError( f"Array at index {i} must be an 'xp.ndarray' with dtype xp.float16. " f"Got type {type(arr)} with dtype {getattr(arr, 'dtype', 'N/A')}." ) + if arr.size == 0: + raise ValueError( + f"Array at index {i} is empty. Cannot discretize empty arrays." + ) original_shapes.append(arr.shape) - original_sizes.append(arr.size) - - # Handle case where all arrays combined are empty - if sum(original_sizes) == 0: - # Create empty uint16 arrays with original shapes - empty_uint16_list = [ - xp.array([], dtype=xp.uint16).reshape(shape) for shape in original_shapes - ] - return empty_uint16_list, xp.array([], dtype=xp.float16) - - # Concatenate all arrays into a single flat array for global unique value finding. - # We need to ensure that we only concatenate non-empty arrays if ravel() on empty - # arrays with certain shapes causes issues, or handle shapes carefully. - # xp.concatenate([arr.ravel() for arr in float16_arrays]) should generally work. - # Ravel ensures that even multi-dimensional arrays become 1D before concatenation. - try: - combined_float16_array = xp.concatenate([arr.ravel() for arr in float16_arrays]) - except Exception as e: - raise ValueError( - f"Error during concatenation of arrays: {e}. Ensure 'xp' is correctly defined (NumPy/CuPy)." - ) - # Find unique values and their inverse indices from the combined array. - # unique_values will be sorted, which is crucial for order preservation. - # inverse_indices will correspond to the flattened combined_float16_array. - unique_values: "xp.ndarray" # type: ignore - inverse_indices: "xp.ndarray" # type: ignore - unique_values, inverse_indices = xp.unique( - combined_float16_array, return_inverse=True - ) + # Collect unique values and their indices + uniques, inverses = [], [] + for arr in f16_arrays: + unq, inv = np.unique(arr, return_inverse=True) + uniques.append(unq) + inverses.append(inv.astype(np.uint16)) - # The unique_values array serves as the global uint16 to float16 lookup table. - uint16_to_float16_lookup: "xp.ndarray" = unique_values # type: ignore - - # Check if the number of unique values fits into uint16 - # xp.iinfo(xp.uint16).max gives the max value (e.g., 65535). - # Number of unique values can be up to (max_value + 1). - if len(unique_values) > xp.iinfo(xp.uint16).max + 1: + # Concatenate all unique values and sort them + u16_to_f16_lut = np.sort(np.concatenate(uniques)) + if len(u16_to_f16_lut) > np.iinfo(np.uint16).max: raise ValueError( - f"Number of unique values ({len(unique_values)}) across all arrays " - f"exceeds the maximum capacity of uint16 ({xp.iinfo(xp.uint16).max + 1})." + "The total number of unique values across all arrays exceeds " + "the capacity of uint16 (65536)." ) - # The inverse_indices array contains the uint16 representations for the combined flat array. - # Cast it to uint16. - y_uint16_combined_flat: "xp.ndarray" = inverse_indices.astype(xp.uint16) # type: ignore - - # Split the combined flat uint16 array back into individual arrays and reshape them - y_uint16_list: List["xp.ndarray"] = [] # type: ignore - current_pos = 0 - for i in range(len(float16_arrays)): - size = original_sizes[i] - shape = original_shapes[i] - - if size == 0: - # Create an empty uint16 array with the original shape - y_uint16_list.append(xp.array([], dtype=xp.uint16).reshape(shape)) - else: - segment = y_uint16_combined_flat[current_pos : current_pos + size] - y_uint16_list.append(segment.reshape(shape)) - current_pos += size - - if current_pos != y_uint16_combined_flat.size: - # This should not happen if logic is correct, but good for sanity check - raise AssertionError( - "Mismatch in processed elements during splitting of combined array." - ) + # Fix inverses to preserve order + for unq_k, inv_k in zip(uniques, inverses): + new_idx_k = np.searchsorted(u16_to_f16_lut, unq_k) + inv_k[:] = new_idx_k[inv_k] + + # Reshape inverses to match original shapes + u16_list: list[ArrayLike] = [] + for shape, inv in zip(original_shapes, inverses): + inv_reshaped = inv.reshape(shape) + u16_list.append(inv_reshaped) - return y_uint16_list, uint16_to_float16_lookup + return u16_list, u16_to_f16_lut def reconstruction_by_dilation( @@ -173,9 +132,7 @@ def reconstruction_by_dilation( lut = None # quick-fix for the issue https://github.com/cupy/cupy/issues/9122 if cupy_used and seed.dtype == xp.float16: - arrs, lut = discretize_multiple_float16_to_uint16([seed, mask]) - seed = arrs[0] - mask = arrs[1] + (seed, mask), lut = discretize_multiple_f16_to_u16([seed, mask]) seed = np.minimum(seed, mask, out=seed)