From a516ad0180f9ed2f845d697497e28021ff51c2d0 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Thu, 20 Jun 2024 14:24:37 -0700 Subject: [PATCH] updating cellpose args documentation and fixing cellpose import --- ultrack/__init__.py | 7 +++++++ ultrack/imgproc/segmentation.py | 28 +++++++++++++++++++++------- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/ultrack/__init__.py b/ultrack/__init__.py index 70e79ab..4c9b848 100644 --- a/ultrack/__init__.py +++ b/ultrack/__init__.py @@ -7,6 +7,13 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) +# Cellpose and ultrack had conflicts due to torch/cuda leading to Segmentation Fault +# importing Cellpose first avoids the issue, https://github.com/royerlab/ultrack/issues/108 +try: + from cellpose.models import Cellpose # noqa: F401 +except (ImportError, ModuleNotFoundError): + pass + # ignoring small float32/64 zero flush warning warnings.filterwarnings("ignore", message="The value of the smallest subnormal for") diff --git a/ultrack/imgproc/segmentation.py b/ultrack/imgproc/segmentation.py index 033ea9c..e94d9a4 100644 --- a/ultrack/imgproc/segmentation.py +++ b/ultrack/imgproc/segmentation.py @@ -1,5 +1,6 @@ +import functools import logging -from typing import Optional +from typing import Callable, Optional import edt import numpy as np @@ -211,20 +212,33 @@ def inverted_edt( return dist +def _maybe_wrap(wrapper_name: str) -> Callable: + """Wraps function with cellpose model method if cellpose is available.""" + try: + from cellpose.models import CellposeModel as _Cellpose + except ImportError: + return lambda x: x + + return functools.wraps(getattr(_Cellpose, wrapper_name)) + + class Cellpose: + @_maybe_wrap("__init__") def __init__(self, **kwargs) -> None: - """See cellpose.models.Cellpose documentation for details.""" - from cellpose.models import CellposeModel as _Cellpose + try: + from cellpose.models import CellposeModel as _Cellpose + except ImportError as e: + raise ImportError( + "Cellpose not found, please install it." + "See for instructions https://github.com/MouseLand/cellpose" + ) from e if "pretrained_model" not in kwargs and "model_type" not in kwargs: kwargs["model_type"] = "cyto" self.model = _Cellpose(**kwargs) + @_maybe_wrap("eval") def __call__(self, image: ArrayLike, **kwargs) -> np.ndarray: - """ - Predicts image labels. - See cellpose.models.Cellpose.eval documentation for details. - """ labels, _, _ = self.model.eval(image, **kwargs) return labels