def pytest_ignore_collect(collection_path: Path, config: Any) -> bool | None:
    """Skip collection of the gradio remote backend sources on Python < 3.10.

    Args:
        collection_path: Path pytest is about to collect.
        config: The pytest config object (unused).

    Returns:
        `True` to ignore `collection_path`; `None` to leave the decision to
        other plugins and pytest's own ignore rules.
    """
    if sys.version_info >= (3, 10):
        # Return `None`, not `False`: this hook stops at the first non-None
        # result, so `False` would force collection of every path and bypass
        # `--ignore`, `collect_ignore` and other plugins.
        return None

    path = str(collection_path).replace("\\", "/")
    if "/src/bioimageio/core/remote_backends/gradio/" in path:
        return True

    return None
b/example/dataset_statistics_demo.ipynb @@ -29,7 +29,6 @@ "\n", "from pprint import pprint\n", "\n", - "import imageio\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import xarray as xr\n", @@ -39,7 +38,6 @@ "warnings.filterwarnings(\"ignore\")\n", "\n", "import bioimageio.core\n", - "from bioimageio.core.prediction import predict_with_tiling\n", "from bioimageio.core.prediction_pipeline import create_prediction_pipeline" ] }, @@ -196,7 +194,7 @@ "source": [ "def process_dataset(pp, dataset):\n", " stats = pp._ipt_stats.compute_measures()[\"per_dataset\"]\n", - " print(f\"initial stats:\")\n", + " print(\"initial stats:\")\n", " pprint(\n", " None\n", " if not stats\n", @@ -205,7 +203,7 @@ " stats = {}\n", " sample_dataset = [{\"input0\": s} for s in dataset]\n", " [pp.apply_preprocessing(s, stats) for s in sample_dataset]\n", - " print(f\"final stats:\")\n", + " print(\"final stats:\")\n", " pprint(\n", " None\n", " if not stats\n", diff --git a/example/export_cellpose_model/cellpose_original.py b/example/export_cellpose_model/cellpose_original.py index cd431446b..abe8f73b9 100644 --- a/example/export_cellpose_model/cellpose_original.py +++ b/example/export_cellpose_model/cellpose_original.py @@ -1,4 +1,6 @@ -"""Run original cellpose model and save an analog input and output for bioimageio tests""" +"""Run original cellpose model and save an analog input and output for bioimageio tests +Works on cellpose 4.1, but not since cellpose 4.2 +""" import os from pathlib import Path diff --git a/example/model_usage.ipynb b/example/model_usage.ipynb index 230ccb035..2e3db5d1b 100644 --- a/example/model_usage.ipynb +++ b/example/model_usage.ipynb @@ -495,7 +495,6 @@ ], "source": [ "from bioimageio.core.digest_spec import create_sample_for_model\n", - "from bioimageio.spec.utils import download\n", "\n", "input_paths = {ipt.id: ipt.test_tensor.source for ipt in model.inputs}\n", "print(f\"input paths: {input_paths}\")\n", diff --git a/mkdocs.yaml 
b/mkdocs.yaml index 8ebd9a4eb..8ad8b12c9 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -87,7 +87,7 @@ plugins: python: inventories: - https://docs.pydantic.dev/latest/objects.inv - - https://bioimage-io.github.io/spec-bioimage-io/v0.5.10.2/objects.inv + - https://bioimage-io.github.io/spec-bioimage-io/v0.5.11.0/objects.inv options: annotations_path: source backlinks: tree diff --git a/pyproject.toml b/pyproject.toml index 6ad14f53d..c78938134 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires-python = ">=3.9" readme = "README.md" dynamic = ["version"] dependencies = [ - "bioimageio.spec ==0.5.10.2", + "bioimageio.spec ==0.5.11.0", "imagecodecs", "imageio>=2.10", "loguru", @@ -52,9 +52,13 @@ partners = [ # "stardist", # for model testing and stardist postprocessing # TODO: add updated stardist to partners env ] stardist = ["stardist"] # for stardist postprocessing +gradio-server = ['gradio[mcp];python_version>="3.10"'] +gradio-client = ['gradio_client;python_version>="3.10"'] dev = [ - "cellpose", # for model testing + "cellpose<4.2", # for model testing "crick", + 'gradio[mcp];python_version>="3.10"', + 'gradio_client;python_version>="3.10"', "httpx", "jax", "jupyter", @@ -68,7 +72,7 @@ dev = [ 'onnx_ir!=0.1.14;python_version<"3.10"', # uses typing.Concatentate which requires py>=3.10 "packaging>=17.0", "pre-commit", - "pyright==1.1.408", + "pyright==1.1.410", "pytest-cov", "pytest", "python-dotenv", diff --git a/src/bioimageio/core/__init__.py b/src/bioimageio/core/__init__.py index ab35c4e2f..85b84a23b 100644 --- a/src/bioimageio/core/__init__.py +++ b/src/bioimageio/core/__init__.py @@ -16,12 +16,27 @@ """ # ruff: noqa: E402 -__version__ = "0.10.4" +__version__ = "0.11.0" from loguru import logger logger.disable("bioimageio.core") -import bioimageio.spec + +from bioimageio.spec import ValidationSummary as ValidationSummary +from bioimageio.spec import build_description as build_description +from bioimageio.spec import dump_description as 
dump_description +from bioimageio.spec import load_dataset_description as load_dataset_description +from bioimageio.spec import load_description as load_description +from bioimageio.spec import ( + load_description_and_validate_format_only as load_description_and_validate_format_only, +) +from bioimageio.spec import load_model_description as load_model_description +from bioimageio.spec import save_bioimageio_package as save_bioimageio_package +from bioimageio.spec import ( + save_bioimageio_package_as_folder as save_bioimageio_package_as_folder, +) +from bioimageio.spec import save_bioimageio_yaml_only as save_bioimageio_yaml_only +from bioimageio.spec import validate_format as validate_format from . import axis as axis from . import backends as backends @@ -31,7 +46,6 @@ from . import common as common from . import digest_spec as digest_spec from . import io as io -from . import model_adapters as model_adapters from . import prediction as prediction from . import proc_ops as proc_ops from . import proc_setup as proc_setup @@ -40,46 +54,40 @@ from . import stat_measures as stat_measures from . import tensor as tensor from . 
import weight_converters as weight_converters +from ._prediction_pipeline import IntermediatePrediction as IntermediatePrediction from ._prediction_pipeline import PredictionPipeline as PredictionPipeline +from ._prediction_pipeline import RemotePredictionPipeline as RemotePredictionPipeline from ._prediction_pipeline import ( create_prediction_pipeline as create_prediction_pipeline, ) +from ._prediction_pipeline import ( + create_remote_prediction_pipeline as create_remote_prediction_pipeline, +) from ._resource_tests import enable_determinism as enable_determinism from ._resource_tests import load_description_and_test as load_description_and_test from ._resource_tests import test_description as test_description from ._resource_tests import test_model as test_model +from ._sample_serializer import SampleSerializer as SampleSerializer from ._settings import Settings as Settings from ._settings import settings as settings -# reexports from bioimageio.spec -build_description = bioimageio.spec.build_description -dump_description = bioimageio.spec.dump_description -load_dataset_description = bioimageio.spec.load_dataset_description -load_description = bioimageio.spec.load_description -load_description_and_validate_format_only = ( - bioimageio.spec.load_description_and_validate_format_only -) -load_model_description = bioimageio.spec.load_model_description -save_bioimageio_package = bioimageio.spec.save_bioimageio_package -save_bioimageio_package_as_folder = bioimageio.spec.save_bioimageio_package_as_folder -save_bioimageio_yaml_only = bioimageio.spec.save_bioimageio_yaml_only -validate_format = bioimageio.spec.validate_format -ValidationSummary = bioimageio.spec.ValidationSummary - - # reexports from bioimageio.core submodules -add_weights = weight_converters.add_weights -Axis = axis.Axis -AxisId = axis.AxisId -BlockMeta = block_meta.BlockMeta -compute_dataset_measures = stat_calculators.compute_dataset_measures -create_model_adapter = backends.create_model_adapter 
-MemberId = common.MemberId -predict = prediction.predict -predict_many = prediction.predict_many -Sample = sample.Sample -Stat = stat_measures.Stat -Tensor = tensor.Tensor +from .axis import Axis as Axis +from .axis import AxisId as AxisId +from .backends import create_model_adapter as create_model_adapter +from .block_meta import BlockMeta as BlockMeta +from .common import MemberId as MemberId +from .prediction import predict as predict +from .prediction import predict_many as predict_many +from .sample import Sample as Sample +from .sample import SampleBlock as SampleBlock +from .sample import SampleBlockMeta as SampleBlockMeta +from .stat_calculators import compute_dataset_measures as compute_dataset_measures +from .stat_calculators import compute_measures as compute_measures +from .stat_calculators import compute_sample_measures as compute_sample_measures +from .stat_measures import Stat as Stat +from .tensor import Tensor as Tensor +from .weight_converters import add_weights as add_weights # aliases test_resource = test_description diff --git a/src/bioimageio/core/__main__.py b/src/bioimageio/core/__main__.py index 123b6a9c9..edcdea06a 100644 --- a/src/bioimageio/core/__main__.py +++ b/src/bioimageio/core/__main__.py @@ -14,7 +14,7 @@ + "{module} - {message}", ) -from .cli import Bioimageio +from .cli import Bioimageio # noqa: E402 def main(): diff --git a/src/bioimageio/core/_axis_annotations.py b/src/bioimageio/core/_axis_annotations.py new file mode 100644 index 000000000..8bd1a7e32 --- /dev/null +++ b/src/bioimageio/core/_axis_annotations.py @@ -0,0 +1,9 @@ +from typing import Annotated, TypeVar + +from ._common_annotations import PydanticMappingProxyAnnotation +from .axis import PerAxis + +_T = TypeVar("_T") + +PerAxisAnno = Annotated[PerAxis[_T], PydanticMappingProxyAnnotation] +"""PerAxis annotated with `PydanticMappingProxyAnnotation` to be compatible with pydantic models.""" diff --git a/src/bioimageio/core/_common_annotations.py 
class PydanticMappingProxyAnnotation:
    """Annotation marker providing a pydantic core schema for mapping proxies.

    Intended to be used inside `Annotated[...]` around a two-parameter
    mapping type (key type, value type).
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler
    ) -> CoreSchema:
        """Build the core schema for the annotated `source_type`.

        Validation chains three steps: validate as `dict[key, value]`,
        wrap the result via `_validate_from_mapping`, then check that the
        result is a `MappingProxyType`. Serialization converts with `dict`.
        """
        key_type, value_type = get_args(source_type)

        validation_steps = [
            handler.generate_schema(dict[key_type, value_type]),
            no_info_plain_validator_function(_validate_from_mapping),
            is_instance_schema(MappingProxyType),
        ]
        proxy_schema = chain_schema(validation_steps)

        # the same schema is used for JSON and Python input
        return json_or_python_schema(
            json_schema=proxy_schema,
            python_schema=proxy_schema,
            serialization=plain_serializer_function_ser_schema(dict),
        )
class DescriptionSerializer:
    """Description serializer intended for client/server communication, NOT for sharing resource descriptions.

    This serializer only includes local files to keep the serialized package small.
    """

    STRING_ENCODING = "ascii"

    # Serialized strings must be strictly longer than this many characters;
    # per the error message below, shorter strings would be treated as a URL
    # on the server side.
    _URL_LENGTH_LIMIT = 2083

    @staticmethod
    def serialize(rd: ResourceDescr) -> bytes:
        """Package `rd` (local files only) and return the raw package bytes."""
        stream = save_bioimageio_package_to_stream(rd, local_files_only=True)
        _ = stream.seek(0)
        return stream.read()

    @classmethod
    def serialize_to_string(cls, rd: ResourceDescr) -> str:
        """Serialize `rd` to a base64 string.

        Raises:
            RuntimeError: If the resulting string is not longer than 2083 characters.
        """
        return cls._encode_package(cls.serialize(rd))

    @staticmethod
    def _get_safe_bytes(raw_bytes: bytes) -> bytes:
        """Return `raw_bytes` base64 encoded."""
        return base64.b64encode(raw_bytes)

    @classmethod
    def _encode_package(cls, package_bytes: bytes) -> str:
        """Encode package bytes as a string and enforce the minimum length.

        Shared by `serialize_to_string` and `serialize_to_string_and_hash` so
        both apply exactly the same encoding and length check.

        Raises:
            RuntimeError: If the encoded string is not longer than `_URL_LENGTH_LIMIT`.
        """
        serialized_str = cls._get_safe_bytes(package_bytes).decode(cls.STRING_ENCODING)
        if len(serialized_str) <= cls._URL_LENGTH_LIMIT:
            raise RuntimeError(
                f"Serialized model description should be longer than {cls._URL_LENGTH_LIMIT} characters to not be treated as a URL on the server side."
            )
        return serialized_str

    @classmethod
    def deserialize_from_string(cls, serialized: str) -> ResourceDescr:
        """Decode a base64 string produced by this serializer and load the description."""
        package_bytes = base64.b64decode(serialized.encode(cls.STRING_ENCODING))
        return cls.deserialize(package_bytes)

    @staticmethod
    def deserialize(serialized: bytes) -> ResourceDescr:
        """Load a resource description from zipped package bytes (IO checks disabled).

        Raises:
            ValueError: If the loaded description is an `InvalidDescr`.
        """
        descr = load_description(ZipFile(BytesIO(serialized)), perform_io_checks=False)
        if isinstance(descr, InvalidDescr):
            raise ValueError(f"invalid serialized model package: {descr.get_reason()}")

        return descr

    @classmethod
    def serialize_to_string_and_hash(cls, rd: ResourceDescr) -> Tuple[str, Sha256]:
        """Serialize `rd` to a base64 string and return it with a SHA-256 hash.

        The hash is computed over the raw package bytes, not the base64 string.

        Raises:
            RuntimeError: If the resulting string is not longer than 2083 characters.
        """
        package_bytes = cls.serialize(rd)
        serialized_str = cls._encode_package(package_bytes)
        sha256 = Sha256(hashlib.sha256(package_bytes).hexdigest())
        return serialized_str, sha256
pyright: ignore[reportUnknownArgumentType] def __lt__(self, other: _Compatible) -> Self: return self._binary_op(other, operator.lt) @@ -102,25 +102,25 @@ def __rmul__(self, other: _Compatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) def __rpow__(self, other: _Compatible) -> Self: - return self._binary_op(other, operator.pow, reflexive=True) + return self._binary_op(other, operator.pow, reflexive=True) # pyright: ignore[reportUnknownArgumentType] def __rtruediv__(self, other: _Compatible) -> Self: - return self._binary_op(other, operator.truediv, reflexive=True) + return self._binary_op(other, operator.truediv, reflexive=True) # pyright: ignore[reportUnknownArgumentType] def __rfloordiv__(self, other: _Compatible) -> Self: - return self._binary_op(other, operator.floordiv, reflexive=True) + return self._binary_op(other, operator.floordiv, reflexive=True) # pyright: ignore[reportUnknownArgumentType] def __rmod__(self, other: _Compatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) def __rand__(self, other: _Compatible) -> Self: - return self._binary_op(other, operator.and_, reflexive=True) + return self._binary_op(other, operator.and_, reflexive=True) # pyright: ignore[reportUnknownArgumentType] def __rxor__(self, other: _Compatible) -> Self: - return self._binary_op(other, operator.xor, reflexive=True) + return self._binary_op(other, operator.xor, reflexive=True) # pyright: ignore[reportUnknownArgumentType] def __ror__(self, other: _Compatible) -> Self: - return self._binary_op(other, operator.or_, reflexive=True) + return self._binary_op(other, operator.or_, reflexive=True) # pyright: ignore[reportUnknownArgumentType] def _inplace_binary_op( self, other: _Compatible, f: Callable[[Any, Any], Any] @@ -128,40 +128,40 @@ def _inplace_binary_op( raise NotImplementedError def __iadd__(self, other: _Compatible) -> Self: - return self._inplace_binary_op(other, operator.iadd) + return self._inplace_binary_op(other, 
operator.iadd) # pyright: ignore[reportUnknownArgumentType] def __isub__(self, other: _Compatible) -> Self: - return self._inplace_binary_op(other, operator.isub) + return self._inplace_binary_op(other, operator.isub) # pyright: ignore[reportUnknownArgumentType] def __imul__(self, other: _Compatible) -> Self: - return self._inplace_binary_op(other, operator.imul) + return self._inplace_binary_op(other, operator.imul) # pyright: ignore[reportUnknownArgumentType] def __ipow__(self, other: _Compatible) -> Self: - return self._inplace_binary_op(other, operator.ipow) + return self._inplace_binary_op(other, operator.ipow) # pyright: ignore[reportUnknownArgumentType] def __itruediv__(self, other: _Compatible) -> Self: - return self._inplace_binary_op(other, operator.itruediv) + return self._inplace_binary_op(other, operator.itruediv) # pyright: ignore[reportUnknownArgumentType] def __ifloordiv__(self, other: _Compatible) -> Self: - return self._inplace_binary_op(other, operator.ifloordiv) + return self._inplace_binary_op(other, operator.ifloordiv) # pyright: ignore[reportUnknownArgumentType] def __imod__(self, other: _Compatible) -> Self: - return self._inplace_binary_op(other, operator.imod) + return self._inplace_binary_op(other, operator.imod) # pyright: ignore[reportUnknownArgumentType] def __iand__(self, other: _Compatible) -> Self: - return self._inplace_binary_op(other, operator.iand) + return self._inplace_binary_op(other, operator.iand) # pyright: ignore[reportUnknownArgumentType] def __ixor__(self, other: _Compatible) -> Self: - return self._inplace_binary_op(other, operator.ixor) + return self._inplace_binary_op(other, operator.ixor) # pyright: ignore[reportUnknownArgumentType] def __ior__(self, other: _Compatible) -> Self: - return self._inplace_binary_op(other, operator.ior) + return self._inplace_binary_op(other, operator.ior) # pyright: ignore[reportUnknownArgumentType] def __ilshift__(self, other: _Compatible) -> Self: - return 
class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import create_model_adapter, load_description

    model = load_description(...)

    # option 1:
    adapter = create_model_adapter(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with create_model_adapter(model) as adapter:
        adapter.forward(...)
    ```
    """

    def __init__(
        self, model_description: AnyModelDescr, devices: Optional[Sequence[str]]
    ):
        """Collect member ids and axes from the description, then `load()` the model.

        Args:
            model_description: Description of the model to adapt.
            devices: Requested devices (stored as given; interpretation is up to subclasses).
        """
        super().__init__()
        self._model_descr = model_description
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        if isinstance(model_description, v0_4.ModelDescr):
            # v0_4 inputs are all treated as required
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

        self._devices = devices
        self.load()

    def __enter__(self) -> "ModelAdapter":
        """Return the adapter itself; it is already loaded by `__init__`."""
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """Close the adapter on leaving the `with` block; exceptions are not suppressed."""
        self.close()

    @property
    def model_descr(self) -> AnyModelDescr:
        """The model description this adapter was created from."""
        return self._model_descr

    @abstractmethod
    def load(self) -> None:
        """Load the model; implementations call `super().load()` to mark the adapter as loaded."""
        self._loaded = True

    @abstractmethod
    def forward(
        self, inputs: PerMember[Optional[Tensor]]
    ) -> PerMember[Optional[Tensor]]: ...

    @abstractmethod
    def unload(self):
        """Unload model from any devices, freeing their memory.

        Note:
            The model adapter should be considered unusable afterwards.
        """
        self._loaded = False

    def close(self):
        """Close the model adapter, freeing any resources.

        Note:
            The model adapter should be considered unusable afterwards.
        """
        self.unload()
+ + def forward( + self, inputs: PerMember[Optional[Tensor]] + ) -> PerMember[Optional[Tensor]]: + """ + Run forward pass of model to get model predictions + + Note: sample id and stample stat attributes are passed through + """ + if not self._loaded: + raise RuntimeError("Model must be `.load()`ed before calling forward()") + + unexpected = [mid for mid in inputs if mid not in self._input_ids] + if unexpected: + warnings.warn(f"Got unexpected input tensor IDs: {unexpected}") + + input_arrays = [ + ( + None + if (a := inputs.get(in_id)) is None + else a.transpose(in_order).data.data + ) + for in_id, in_order in zip(self._input_ids, self._input_axes) + ] + logger.debug( + "NN input shapes: {}", + [a.shape if a is not None else None for a in input_arrays], + ) + device, model = self._model_queue.get() + try: + output_arrays = self._forward_impl(device, model, input_arrays) + finally: + self._model_queue.put((device, model)) + + logger.debug( + "NN output shapes: {}", + [a.shape if a is not None else None for a in output_arrays], + ) + if len(output_arrays) > len(self._output_ids): + warnings.warn( + f"Model produced more outputs ({len(output_arrays)}) than specified in the model description ({len(self._output_ids)}). Extra outputs will be ignored." 
def unload(self):
    """Remove every initialized model from the queue and clean up its device.

    Waits (via the blocking queue `get`) for models that are currently
    checked out by `forward`. Errors raised by the cleanup hooks are logged
    as warnings and do not abort unloading of the remaining devices.

    Calling `unload()` again (e.g. through `close()`) is a no-op for the
    queue: each retrieved model also removes one entry from
    `_initialized_devices`, so the loop never waits on an already empty queue.
    """
    while self._initialized_devices:
        device, model = self._model_queue.get()
        # Keep the bookkeeping in sync with the queue: one entry per queued
        # model. Without this a second unload() would block forever.
        _ = self._initialized_devices.pop()
        try:
            self._cleanup_pre_model_deletion(device, model)
        except Exception as e:
            logger.warning(
                "Got error during pre-deletion cleanup on device {}: {}", device, e
            )
        finally:
            # drop our reference even if the pre-deletion cleanup failed
            del model
        try:
            self._cleanup_post_model_deletion(device)
        except Exception as e:
            logger.warning(
                "Got error during post-deletion cleanup on device {}: {}", device, e
            )

    _ = gc.collect()  # deallocate memory
    super().unload()
class IntermediatePrediction(NamedTuple):
    """Represents an intermediate prediction of a sample with blocking, including the predicted sample so far and the last predicted block.

    The final `IntermediatePrediction` in a sequence holds the complete predicted (and postprocessed if applicable) sample.
    """

    # the sample as predicted so far
    sample: Sample
    # the block that was predicted last
    last_block: SampleBlock
+ skip_input_padding: if `True`, skip padding the input sample according to the model's (optional) output halos. + skip_output_cropping: if `True`, skip cropping any output halos from the model output. + """ + + def predict_sample_with_blocking( + self, + sample: Sample, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + ns: Optional[ + Union[ + v0_5.ParameterizedSize_N, + Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N], + ] + ] = None, + batch_size: Optional[int] = None, + ) -> Sample: + """Predict a sample by predicting sample blocks. + + Note: For fixed/known blocksizes use `predict_sample_with_fixed_blocking`. + + Args: + sample: The sample to predict on. + skip_preprocessing: If `True`, skip all preprocessing steps. + skip_postprocessing: If `True`, skip all postprocessing steps. + ns: Block size parameter(s) allows scaling the model's default input block size. + Blocksize parameters are only applied to parameterized input axes, all other axis sizes are fixed/derived or (for output axes) data dependent. + Unapplicable blocksize parameters are ignored. + batch_size: Batch size to use for prediction. + """ + output = None + for output in self.predict_sample_with_blocking_yield_intermediates( + sample, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ns=ns, + batch_size=batch_size, + )[1]: + pass + + assert output is not None, ( + "No blocks were predicted, cannot return final sample." + ) + return output.sample + + def predict_sample_with_fixed_blocking( + self, + sample: Sample, + input_block_shape: PerMember[PerAxis[int]], + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + ) -> Sample: + """Predict `sample` with given `input_block_shape`. + + Note: + - `input_block_shape` is expected to be a valid input shape for the model. + - Use `predict_sample_with_blocking` if you want to control block sizes via generic block size parameters rather than fixed block shapes. 
+ + Args: + sample: The sample to predict on. + input_block_shape: Mapping of input member id to mapping of axis id to block size for that axis. + skip_preprocessing: If `True`, skip all preprocessing steps. + skip_postprocessing: If `True`, skip all postprocessing steps. + """ + intermediate = None + for intermediate in self.predict_sample_with_fixed_blocking_yield_intermediates( + sample, + input_block_shape=input_block_shape, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + )[1]: + pass + + assert intermediate is not None, ( + "No blocks were predicted, cannot return final sample." + ) + return intermediate.sample + + def predict_sample_with_blocking_yield_intermediates( + self, + sample: Sample, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + ns: Optional[ + Union[ + v0_5.ParameterizedSize_N, + Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N], + ] + ] = None, + batch_size: Optional[int] = None, + ) -> Tuple[int, Iterable[IntermediatePrediction]]: + """Predict `sample` by predicting sample blocks and yield intermediate predictions if no samplewise postprocessing is included. + + Returns: + Tuple of number of blocks and an iterator of predicted intermediate samples with the last predicted block, + All samples, but the last one, are intermediate samples with more and more blocks predicted. + In case samplewise postprocessing needs to be applied, no intermediate results are yielded, but only the final sample after all blocks are predicted and postprocessed. + """ + if isinstance(self._model_descr, v0_4.ModelDescr): + raise NotImplementedError( + "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr" + + f" {self._model_descr.name}." 
+ + " Consider using `predict_sample_with_fixed_blocking`" + ) + + ns = ns or self._default_blocksize_parameter + if isinstance(ns, int): + ns = { + (ipt.id, a.id): ns + for ipt in self._model_descr.inputs + for a in ipt.axes + if isinstance(a.size, v0_5.ParameterizedSize) + } + input_block_shape = self._model_descr.get_tensor_sizes( + ns, batch_size or self._default_batch_size + ).inputs + + return self.predict_sample_with_fixed_blocking_yield_intermediates( + sample, + input_block_shape=input_block_shape, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ) + + @abstractmethod + def predict_sample_with_fixed_blocking_yield_intermediates( + self, + sample: Sample, + input_block_shape: PerMember[PerAxis[int]], + *, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + fill_value: float = float("nan"), + ) -> Tuple[int, Iterable[IntermediatePrediction]]: ... + + @abstractmethod + def predict_sample_block( + self, + sample_block: SampleBlock, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + ) -> SampleBlock: + """Predict a single sample block. + + Note that this does not apply samplewise preprocessing or postprocessing steps, but only blockwise ones. + + Args: + sample_block: The sample block to predict on. + skip_preprocessing: If `True`, skip blockwise preprocessing steps. + skip_postprocessing: If `True`, skip blockwise postprocessing steps. 
+ """ + + +class PredictionPipeline(_PredictionPipelineBase): """ Represents model computation including preprocessing and postprocessing Note: Ideally use the `PredictionPipeline` in a with statement @@ -63,19 +295,15 @@ def __init__( preprocessing: List[Processing], postprocessing: List[Processing], model_adapter: ModelAdapter, - default_ns: Optional[BlocksizeParameter] = None, default_blocksize_parameter: BlocksizeParameter = 10, default_batch_size: int = 1, ) -> None: """Consider using `create_prediction_pipeline` to create a `PredictionPipeline` with sensible defaults.""" - super().__init__() - default_blocksize_parameter = default_ns or default_blocksize_parameter - if default_ns is not None: - warnings.warn( - "Argument `default_ns` is deprecated in favor of" - + " `default_blocksize_paramter` and will be removed soon." - ) - del default_ns + super().__init__( + model_descr=model_description, + default_blocksize_parameter=default_blocksize_parameter, + default_batch_size=default_batch_size, + ) if model_description.run_mode: warnings.warn( @@ -108,40 +336,10 @@ def __init__( else: self._samplewise_postprocessing.append(op) - self.pad_mode = ( - {} - if isinstance(model_description, v0_4.ModelDescr) - else { - descr.id: descr.pad or v0_5.SymmetricPadding() - for descr in model_description.inputs - } - ) - self.model_description = model_description - if isinstance(model_description, v0_4.ModelDescr): - self._default_output_halo: PerMember[PerAxis[Halo]] = {} - self._default_input_halo: PerMember[PerAxis[Halo]] = {} - self._block_transform = None - else: - self._default_output_halo = { - t.id: { - a.id: Halo(a.halo, a.halo) - for a in t.axes - if isinstance(a, v0_5.WithHalo) - } - for t in model_description.outputs - } - self._default_input_halo = get_input_halo( - model_description, self._default_output_halo - ) - self._block_transform = get_block_transform(model_description) - - self._default_blocksize_parameter = default_blocksize_parameter - 
self._default_batch_size = default_batch_size - self._input_ids = get_member_ids(model_description.inputs) self._output_ids = get_member_ids(model_description.outputs) - self._adapter: ModelAdapter = model_adapter + self._adapter = model_adapter def __enter__(self): self.load() @@ -152,14 +350,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore return False @property - def has_blockwise_preprocessing(self) -> bool: - """`True` if all preprocessing operators in the pipeline are blockwise.""" - return bool(self._blockwise_preprocessing) + def has_non_blockwise_preprocessing(self) -> bool: + """`True` if any preprocessing operators in the pipeline are not applicable blockwise.""" + return bool(self._samplewise_preprocessing) @property - def has_blockwise_postprocessing(self) -> bool: - """`True` if all postprocessing operators in the pipeline are blockwise.""" - return bool(self._blockwise_postprocessing) + def has_non_blockwise_postprocessing(self) -> bool: + """`True` if any postprocessing operators in the pipeline are not applicable blockwise.""" + return bool(self._samplewise_postprocessing) def _raise_for_non_blockwise_processing( self, proc_type: Literal["preprocessing", "postprocessing"] @@ -197,28 +395,25 @@ def predict_sample_block( skip_preprocessing: bool = False, skip_postprocessing: bool = False, ) -> SampleBlock: - if isinstance(self.model_description, v0_4.ModelDescr): + if isinstance(self._model_descr, v0_4.ModelDescr): raise NotImplementedError( - f"predict_sample_block not implemented for model {self.model_description.format_version}" + f"predict_sample_block not implemented for model {self._model_descr.format_version}" ) else: assert self._block_transform is not None if not skip_preprocessing: - self.raise_for_non_blockwise_preprocessing() - - if not skip_postprocessing: - self.raise_for_non_blockwise_postprocessing() - - if not skip_preprocessing: - self.apply_preprocessing(sample_block) + 
self._apply_blockwise_preprocessing(sample_block) output_meta = sample_block.get_transformed_meta(self._block_transform) - local_output = self._adapter.forward(sample_block) + local_output = self._adapter.forward(sample_block.members) - output = output_meta.with_data(local_output.members, stat=local_output.stat) + output = output_meta.with_data( + {k: v for k, v in local_output.items() if v is not None}, + stat=sample_block.stat, + ) if not skip_postprocessing: - self.apply_postprocessing(output) + self._apply_blockwise_postprocessing(output) return output @@ -230,26 +425,21 @@ def predict_sample_without_blocking( skip_input_padding: bool = False, skip_output_cropping: bool = False, ) -> Sample: - """predict a whole sample - - Args: - sample: input sample - skip_preprocessing: if `True`, skip all preprocessing steps. - skip_postprocessing: if `True`, skip all postprocessing steps. - skip_input_padding: if `True`, skip padding the input sample according to the model's (optional) output halos. - skip_output_cropping: if `True`, skip cropping any output halos from the model output. - Note: - The sample's tensor shapes have to match the model's input tensor description. 
- If that is not the case, consider `predict_sample_with_blocking` - """ - if not skip_input_padding: sample = sample.pad(pad_width=self._default_input_halo, mode=self.pad_mode) if not skip_preprocessing: self.apply_preprocessing(sample) - output = self._adapter.forward(sample) + output = Sample( + members={ + k: v + for k, v in self._adapter.forward(sample.members).items() + if v is not None + }, + stat=sample.stat, + id=sample.id, + ) if not skip_postprocessing: self.apply_postprocessing(output) @@ -276,143 +466,225 @@ def get_output_sample_id(self, input_sample_id: SampleId): ) return input_sample_id - def predict_sample_with_fixed_blocking( + def predict_sample_with_blocking( + self, + sample: Sample, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + ns: Optional[ + Union[ + v0_5.ParameterizedSize_N, + Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N], + ] + ] = None, + batch_size: Optional[int] = None, + ) -> Sample: + output = None + for output in self.predict_sample_with_blocking_yield_intermediates( + sample, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ns=ns, + batch_size=batch_size, + )[1]: + pass + + assert output is not None, ( + "No blocks were predicted, cannot return final sample." + ) + return output.sample + + def predict_sample_with_fixed_blocking_yield_intermediates( self, sample: Sample, input_block_shape: Mapping[MemberId, Mapping[AxisId, int]], *, skip_preprocessing: bool = False, skip_postprocessing: bool = False, - ) -> Sample: - """Predict `sample` with given `input_block_shape`. + fill_value: float = float("nan"), + ) -> Tuple[int, Iterable[IntermediatePrediction]]: + """Predict `sample` with given `input_block_shape` and yield the full sample with intermediate results. + + Note: + - `input_block_shape` is expected to be a valid input shape for the model. 
+ - Use `predict_sample_with_blocking` if you want to control block sizes via generic block size parameters + rather than fixed block shapes. + - Postprocessing may only be complete for the final sample (if samplewise postprocessing steps are included + in the pipeline), intermediate samples may have some (blockwise applicable) postprocessing steps applied. - Note: - `input_block_shape` is expected to be a valid input shape for the model. + Args: + sample: The sample to predict on. + input_block_shape: Mapping of input member id to mapping of axis id to block size for that axis. + skip_preprocessing: If `True`, skip all preprocessing steps. + skip_postprocessing: If `True`, skip all postprocessing steps. + + Returns: + Tuple of number of blocks and an iterable of predicted intermediate samples with the last predicted block, + All samples, but the last one, are intermediate samples with more and more blocks predicted. """ + if not skip_preprocessing: - for op in self._samplewise_preprocessing: - op(sample) + self._apply_samplewise_preprocessing(sample) n_blocks, input_blocks = sample.split_into_blocks( input_block_shape, halo=self._default_input_halo, pad_mode=self.pad_mode, ) - input_blocks = list(input_blocks) - predicted_blocks: List[SampleBlock] = [] logger.info( "split sample shape {} into {} blocks of {}.", {k: dict(v) for k, v in sample.shape.items()}, n_blocks, {k: dict(v) for k, v in input_block_shape.items()}, ) - for b in tqdm( - input_blocks, - desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}", - unit="block", - unit_divisor=1, - total=n_blocks, - ): - if not skip_preprocessing: - for op in self._blockwise_preprocessing: - op(b) - - predicted_blocks.append( - self.predict_sample_block( + + def _predict_blocks(): + predicted_sample = None + for i, b in enumerate( + tqdm( + input_blocks, + desc=f"predict sample {sample.id or ''} with {self._model_descr.id or self._model_descr.name}", + unit="block", + 
unit_divisor=1, + total=n_blocks, + ) + ): + if not skip_preprocessing: + self._apply_blockwise_preprocessing(b) + + predicted_block = self.predict_sample_block( b, skip_preprocessing=True, skip_postprocessing=True ) - ) - if not skip_postprocessing: - for op in self._blockwise_postprocessing: - op(predicted_blocks[-1]) - predicted_sample = Sample.from_blocks(predicted_blocks) - if not skip_postprocessing: - for op in self._samplewise_postprocessing: - op(predicted_sample) + if not skip_postprocessing: + self._apply_blockwise_postprocessing(predicted_block) - return predicted_sample + if predicted_sample is None: + predicted_sample = Sample.from_blocks( + [predicted_block], fill_value=fill_value + ) + else: + predicted_sample.set_block(predicted_block) - def predict_sample_with_blocking( - self, - sample: Sample, - skip_preprocessing: bool = False, - skip_postprocessing: bool = False, - ns: Optional[ - Union[ - v0_5.ParameterizedSize_N, - Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N], - ] - ] = None, - batch_size: Optional[int] = None, - ) -> Sample: - """Predict a sample by splitting it into blocks according to the mode + if not skip_postprocessing and i == n_blocks - 1: + self._apply_samplewise_postprocessing(predicted_sample) + + yield IntermediatePrediction(predicted_sample, predicted_block) + + return n_blocks, _predict_blocks() - The `ns` parameter allow scaling the model's default input block size. + def _apply_samplewise_preprocessing(self, sample: Sample, /) -> None: + """Apply preprocessing operators up to and including the last samplewise operator in-place. + + Note: This skips all blockwise preprocessing steps after the last samplewise operator. """ + if isinstance(sample, SampleBlock): + self.raise_for_non_blockwise_preprocessing() - if isinstance(self.model_description, v0_4.ModelDescr): - raise NotImplementedError( - "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr" - + f" {self.model_description.name}." 
- + " Consider using `predict_sample_with_fixed_blocking`" - ) + for op in self._samplewise_preprocessing: + op(sample) - ns = ns or self._default_blocksize_parameter - if isinstance(ns, int): - ns = { - (ipt.id, a.id): ns - for ipt in self.model_description.inputs - for a in ipt.axes - if isinstance(a.size, v0_5.ParameterizedSize) - } - input_block_shape = self.model_description.get_tensor_sizes( - ns, batch_size or self._default_batch_size - ).inputs + def _apply_blockwise_preprocessing( + self, sample_block: Union[Sample, SampleBlock], / + ) -> None: + """Apply blockwise preprocessing operators in-place. - return self.predict_sample_with_fixed_blocking( - sample, - input_block_shape=input_block_shape, - skip_preprocessing=skip_preprocessing, - skip_postprocessing=skip_postprocessing, - ) + Note: This skips all preprocessing operators up to and including the last samplewise one. + """ + for op in self._blockwise_preprocessing: + op(sample_block) def apply_preprocessing(self, sample: Union[Sample, SampleBlock]) -> None: - """apply preprocessing in-place, also may updates sample stats""" - if isinstance(sample, SampleBlock): + """Apply preprocessing in-place, also may updates sample stats""" + + if isinstance(sample, Sample): + self._apply_samplewise_preprocessing(sample) + else: self.raise_for_non_blockwise_preprocessing() - for op in self._samplewise_preprocessing + self._blockwise_preprocessing: - if isinstance(sample, SampleBlock): - assert isinstance(op, BlockwiseOperator) - op(sample) - else: - op(sample) + self._apply_blockwise_preprocessing(sample) - def apply_postprocessing(self, sample: Union[Sample, SampleBlock]) -> None: - """apply postprocessing in-place, also may updates samples stats""" + def _apply_blockwise_postprocessing( + self, sample_block: Union[Sample, SampleBlock], / + ) -> None: + """Apply in-place blockwise postprocessing operators + + Note: This does not apply all postprocessing operators from the first samplewise one onwards. 
+ """ + for op in self._blockwise_postprocessing: + op(sample_block) + + def _apply_samplewise_postprocessing(self, sample: Sample, /) -> None: + """Apply in-place postprocessing operators starting from and including the first samplewise operator. + + Note: This skips all blockwise postprocessing steps before the first samplewise one. + """ if isinstance(sample, SampleBlock): self.raise_for_non_blockwise_postprocessing() - for op in self._blockwise_postprocessing + self._samplewise_postprocessing: - if isinstance(sample, SampleBlock): - assert isinstance(op, BlockwiseOperator) - op(sample) - else: - op(sample) + for op in self._samplewise_postprocessing: + op(sample) + + def apply_postprocessing(self, sample: Union[Sample, SampleBlock]) -> None: + """apply postprocessing in-place, also may updates samples stats""" + self._apply_blockwise_postprocessing(sample) + if isinstance(sample, Sample): + self._apply_samplewise_postprocessing(sample) + else: + self.raise_for_non_blockwise_postprocessing() def load(self): + """Prepare prediction pipeline for use. + + Reusable model adapters may be loaded and unloaded multiple times, but currently not all model adapters + cleanly unload and reload. + + Note: + For some model adapters loading is currently part of the constructor making them unusable after unloading. """ - optional step: load model onto devices before calling forward if not using it as context manager - """ - pass + self._adapter.load() def unload(self): - """ - free any device memory in use - """ + """Free any device memory in use. + + Note: + Currently prediction pipeline becomes unusable after unloading.""" self._adapter.unload() + def close(self): + """Permanently close the prediction pipeline and free any device memory in use. + This makes the prediction pipeline unusable afterwards.""" + self.unload() + + +class RemotePredictionPipeline(_PredictionPipelineBase): + """Abstract base class for fully remote prediction pipelines. 
+ + Note: A ("local") `PredictionPipeline` may also use a `RemoteModelAdapter` for remote model inference, but it may + still apply local preprocessing and postprocessing steps. + In contrast, a `RemotePredictionPipeline` is designed for the case where all steps including preprocessing and + postprocessing are performed remotely. + """ + + def __init__( + self, + model_descr: AnyModelDescr, + *, + server: str, + default_blocksize_parameter: BlocksizeParameter, + default_batch_size: int, + ) -> None: + super().__init__( + model_descr, + default_blocksize_parameter=default_blocksize_parameter, + default_batch_size=default_batch_size, + ) + self._server = server + + @property + def server(self) -> str: + return self._server + def create_prediction_pipeline( bioimageio_model: AnyModelDescr, @@ -499,3 +771,50 @@ def dataset(): postprocessing=postprocessing, default_blocksize_parameter=default_blocksize_parameter, ) + + +def create_remote_prediction_pipeline( + model_description: AnyModelDescr, + *, + server: Optional[str] = None, + server_type: Optional[Literal["gradio"]] = "gradio", + precomputed_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}), + default_blocksize_parameter: BlocksizeParameter = 10, # TODO: default to None and find smart blocksize params per axis to reduce overlap of blocks with large halo + default_batch_size: int = 1, +) -> RemotePredictionPipeline: + """Create a `RemotePredictionPipeline` for the given `model_description`. + + Args: + model_description: The model to run inference with. + server: The URL or Hugging Face space name of a running bioimageio server instance + server_type: The type of the remote server to connect to. Currently only "gradio" is supported. + precomputed_statistics: Precomputed dataset (and optionally sample) statistics. + Any included sample statistics will not be calculated on the fly and it is the callers + responsibility to use samples with the corresponding statistics availble in `sample.stat`. 
+ default_blocksize_parameter: Allows to control the default block size with a single parameter for blockwise predictions. (not all models support this) + default_batch_size: Default batch size to use + """ + + if server_type is None: + server_type = "gradio" + + try: + if server_type == "gradio": + from .remote_backends.gradio.client import ( + GradioPredictionPipeline as RemotePredictionPipelineImpl, + ) + else: + assert_never(server_type) + except ImportError as e: + raise ImportError( + f"Failed to import {server_type.capitalize()}PredictionPipeline. Make sure to install the '{server_type}-client' extra," + + f" e.g. with `pip install bioimageio.core[{server_type}-client]`." + ) from e + + return RemotePredictionPipelineImpl( + model_description, + server=server, + precomputed_statistics=precomputed_statistics, + default_blocksize_parameter=default_blocksize_parameter, + default_batch_size=default_batch_size, + ) diff --git a/src/bioimageio/core/_resource_tests.py b/src/bioimageio/core/_resource_tests.py index fa528fc11..1f0b4c180 100644 --- a/src/bioimageio/core/_resource_tests.py +++ b/src/bioimageio/core/_resource_tests.py @@ -818,6 +818,70 @@ def _get_tolerance( return rtol, atol, mismatched_tol +def evaluate_mismatched_elements( + actual: Tensor, expected: Tensor, rtol: float, atol: float, name: str +) -> Tuple[float, str, Optional[str]]: + try: + expected_np = expected.data.to_numpy().astype(np.float32) + dims = expected.dims + del expected + actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32) + del actual + + rtol_value = rtol * abs(expected_np) + abs_diff = abs(actual_np - expected_np) + mismatched = abs_diff > atol + rtol_value + mismatched_elements = mismatched.sum().item() + + mismatched_ppm = mismatched_elements / expected_np.size * 1e6 + abs_diff[~mismatched] = 0 # ignore non-mismatched elements + + r_max_idx_flat = (r_diff := (abs_diff / (abs(expected_np) + 1e-6))).argmax() + r_max_idx = np.unravel_index(r_max_idx_flat, 
r_diff.shape) + r_max = r_diff[r_max_idx].item() + r_actual = actual_np[r_max_idx].item() + r_expected = expected_np[r_max_idx].item() + + # Calculate the max absolute difference with the relative tolerance subtracted + abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value + a_max_idx = np.unravel_index(abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape) + + a_max = abs_diff[a_max_idx].item() + a_actual = actual_np[a_max_idx].item() + a_expected = expected_np[a_max_idx].item() + except Exception as e: + mismatched_ppm = -1 + msg = "" + error_msg = ( + f"Error while checking if '{name}' disagrees with expected values: {e}" + ) + else: + error_msg = None + if mismatched_elements: + msg = ( + f"Output '{name}': {mismatched_elements} of " + + f"{expected_np.size} elements disagree with expected values (" + + ( + f"{mismatched_ppm * 10_000:.1f}%" + if mismatched_ppm >= 1_000 + else f"{mismatched_ppm:.1f} ppm" + ) + + "). " + ) + else: + msg = f"Output `{name}`: all elements agree with expected values. 
" + + msg += ( + f"\nMax relative difference not accounted for by absolute tolerance ({atol:.2e}):\n{r_max:.2e}" + + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)" + + f" at {dict(zip(dims, r_max_idx))} " + + f"\nMax absolute difference not accounted for by relative tolerance ({rtol:.2e}):\n{a_max:.2e}" + + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}" + ) + + return mismatched_ppm, msg, error_msg + + def _test_recreate_test_outputs( model: Union[v0_4.ModelDescr, v0_5.ModelDescr], weight_format: SupportedWeightsFormat, @@ -917,7 +981,7 @@ def save_to_working_dir(name: str, tensor: Tensor) -> List[Path]: else: continue - if actual.dims != (dims := expected.dims): + if actual.dims != expected.dims: add_error_entry( f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}" ) @@ -944,72 +1008,32 @@ def save_to_working_dir(name: str, tensor: Tensor) -> List[Path]: results_not_postprocessed.members[m], ) ) + except Exception as e: + logger.error(f"Failed to save actual output tensor for '{m}': {e}") + output_paths = None - expected_np = expected.data.to_numpy().astype(np.float32) - del expected - actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32) + rtol, atol, mismatched_tol = _get_tolerance( + model, wf=weight_format, m=m, **deprecated + ) + mismatched_ppm, msg, error_msg = evaluate_mismatched_elements( + actual, expected, rtol, atol, m + ) + if error_msg is not None: + add_error_entry(error_msg) + if stop_early: + break - rtol, atol, mismatched_tol = _get_tolerance( - model, wf=weight_format, m=m, **deprecated - ) - rtol_value = rtol * abs(expected_np) - abs_diff = abs(actual_np - expected_np) - mismatched = abs_diff > atol + rtol_value - mismatched_elements = mismatched.sum().item() - - mismatched_ppm = mismatched_elements / expected_np.size * 1e6 - abs_diff[~mismatched] = 0 # ignore non-mismatched elements - - r_max_idx_flat = ( - r_diff := (abs_diff / (abs(expected_np) + 1e-6)) 
- ).argmax() - r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape) - r_max = r_diff[r_max_idx].item() - r_actual = actual_np[r_max_idx].item() - r_expected = expected_np[r_max_idx].item() - - # Calculate the max absolute difference with the relative tolerance subtracted - abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value - a_max_idx = np.unravel_index( - abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape - ) + if output_paths: + msg += f"\n Saved (intermediate) outputs to {output_paths}." - a_max = abs_diff[a_max_idx].item() - a_actual = actual_np[a_max_idx].item() - a_expected = expected_np[a_max_idx].item() - except Exception as e: - msg = f"Error while checking if '{m}' disagrees with expected values: {e}" + if mismatched_ppm > mismatched_tol: add_error_entry(msg) if stop_early: break else: - if mismatched_elements: - msg = ( - f"Output '{m}': {mismatched_elements} of " - + f"{expected_np.size} elements disagree with expected values." - + f" ({mismatched_ppm:.1f} ppm). " - ) - else: - msg = f"Output `{m}`: all elements agree with expected values. " - - msg += ( - f"\nMax relative difference not accounted for by absolute tolerance ({atol:.2e}):\n{r_max:.2e}" - + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)" - + f" at {dict(zip(dims, r_max_idx))} " - + f"\nMax absolute difference not accounted for by relative tolerance ({rtol:.2e}):\n{a_max:.2e}" - + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}" + add_warning_entry( + msg, severity=WARNING if mismatched_ppm != 0 else INFO ) - if output_paths: - msg += f"\n Saved (intermediate) outputs to {output_paths}." 
- - if mismatched_ppm > mismatched_tol: - add_error_entry(msg) - if stop_early: - break - else: - add_warning_entry( - msg, severity=WARNING if mismatched_elements else INFO - ) except Exception as e: if get_validation_context().raise_errors: diff --git a/src/bioimageio/core/_sample_serializer.py b/src/bioimageio/core/_sample_serializer.py new file mode 100644 index 000000000..2d3f461c6 --- /dev/null +++ b/src/bioimageio/core/_sample_serializer.py @@ -0,0 +1,85 @@ +from abc import ABC, abstractmethod +from typing import ( + Generic, + Iterable, + Tuple, + TypeVar, + Union, +) + +from bioimageio.spec.model import v0_5 + +from .axis import PerAxis +from .common import HaloLike, PadMode, PerMember +from .digest_spec import split_sample_into_blocks_for_model +from .sample import Sample, SampleBlock + +SerializedSampleBlockType = TypeVar("SerializedSampleBlockType") + + +class SampleSerializer(ABC, Generic[SerializedSampleBlockType]): + @classmethod + def serialize_sample( + cls, + sample: Sample, + ) -> Tuple[SerializedSampleBlockType]: + """Serialize a sample as a single block""" + return (cls.serialize_sample_block(sample.as_single_block()),) + + @classmethod + def deserialize_sample( + cls, + serialized: Iterable[SerializedSampleBlockType], + fill_value: float = float("nan"), + ) -> Sample: + return Sample.from_blocks( + (cls.deserialize_sample_block(s) for s in serialized), fill_value=fill_value + ) + + def serialize_sample_blockwise( + self, + sample: Sample, + *, + model: v0_5.ModelDescr, + blocksize_parameter: int, + batch_size: int = 1, + ) -> Iterable[SerializedSampleBlockType]: + """Split a sample into blocks according to the model's input specifications and `blocksize_parameter` and serialize each block.""" + + _n_blocks, blocks = split_sample_into_blocks_for_model( + sample, + model=model, + blocksize_parameter=blocksize_parameter, + batch_size=batch_size, + ) + for block in blocks: + yield self.serialize_sample_block(block) + + @classmethod + def 
serialize_sample_with_fixed_blocking( + cls, + sample: Sample, + *, + block_shapes: PerMember[PerAxis[int]], + halo: PerMember[PerAxis[HaloLike]], + pad_mode: Union[PadMode, PerMember[PadMode]] = "symmetric", + ) -> Iterable[SerializedSampleBlockType]: + + _n_blocks, input_blocks = sample.split_into_blocks( + block_shapes=block_shapes, + halo=halo, + pad_mode=pad_mode, + ) + for block in input_blocks: + yield cls.serialize_sample_block(block) + + @staticmethod + @abstractmethod + def serialize_sample_block( + sample_block: SampleBlock, + ) -> SerializedSampleBlockType: ... + + @staticmethod + @abstractmethod + def deserialize_sample_block(serialized: SerializedSampleBlockType) -> SampleBlock: + """Deserialize a sample block into a new sample or merge it into `output_sample` if provided.""" diff --git a/src/bioimageio/core/_settings.py b/src/bioimageio/core/_settings.py index 6e7675f56..16b460683 100644 --- a/src/bioimageio/core/_settings.py +++ b/src/bioimageio/core/_settings.py @@ -45,6 +45,18 @@ def _set_default_mps_fallback(cls, value: Optional[bool]): ) """URL to the bioimageio collection config""" + gradio_server: Optional[str] = None + """URL or Hugging Face space name to connect to with the remote gradio model adapter or remote gradio prediction pipeline. 
+ + Example: "bioimage-io/bioimage-io-gradio-server" + """ + + gradio_server_model_cache_max_size: int = 10 + """Max number of models to cache in the gradio server for prediction pipelines using the gradio backend.""" + + gradio_server_model_cache_max_memory: str = "40GB" + """Max memory to use for model caching in the gradio server for prediction pipelines using the gradio backend.""" + settings = Settings() """parsed environment variables for bioimageio.spec and bioimageio.core""" diff --git a/src/bioimageio/core/axis.py b/src/bioimageio/core/axis.py index 07ce7e042..6c85b8b3a 100644 --- a/src/bioimageio/core/axis.py +++ b/src/bioimageio/core/axis.py @@ -1,9 +1,15 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal, Mapping, Optional, TypeVar, Union +from typing import ( + Literal, + Mapping, + Optional, + TypeVar, + Union, +) -from typing_extensions import Protocol, assert_never, runtime_checkable +from typing_extensions import Protocol, TypeAlias, assert_never, runtime_checkable from bioimageio.spec.model import v0_5 @@ -33,11 +39,12 @@ def _guess_axis_type(a: str): S = TypeVar("S", bound=str) -AxisId = v0_5.AxisId +AxisId: TypeAlias = v0_5.AxisId """An axis identifier, e.g. 
'batch', 'channel', 'z', 'y', 'x'""" -T = TypeVar("T") -PerAxis = Mapping[AxisId, T] +_T = TypeVar("_T") +PerAxis = Mapping[AxisId, _T] + BatchSize = int diff --git a/src/bioimageio/core/backends/__init__.py b/src/bioimageio/core/backends/__init__.py index c39b58b58..7fc53d589 100644 --- a/src/bioimageio/core/backends/__init__.py +++ b/src/bioimageio/core/backends/__init__.py @@ -1,3 +1,127 @@ -from ._model_adapter import create_model_adapter +from typing import ( + List, + Optional, + Sequence, + Tuple, + Union, +) -__all__ = ["create_model_adapter"] +from exceptiongroup import ExceptionGroup +from typing_extensions import assert_never + +from bioimageio.spec.model import v0_4, v0_5 + +from ..common import SupportedWeightsFormat + +# Known weight formats in order of priority +# First match wins +DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = ( + "pytorch_state_dict", + "tensorflow_saved_model_bundle", + "torchscript", + "onnx", + "keras_v3", + "keras_hdf5", +) + + +def create_model_adapter( + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + *, + devices: Optional[Sequence[str]] = None, + weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None, +): + """Creates model adapter for `model_description`""" + if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)): + raise TypeError( + f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}" + ) + + weights = model_description.weights + errors: List[Exception] = [] + weight_format_priority_order = ( + DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER + if weight_format_priority_order is None + else weight_format_priority_order + ) + # limit weight formats to the ones present + weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [ + w for w in weight_format_priority_order if getattr(weights, w, None) is not None + ] + if not weight_format_priority_order_present: + raise ValueError( + f"None of the specified 
weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})" + ) + + for wf in weight_format_priority_order_present: + if wf == "pytorch_state_dict": + assert weights.pytorch_state_dict is not None + try: + from .pytorch_backend import PytorchModelAdapter + + return PytorchModelAdapter(model_description, devices=devices) + except Exception as e: + errors.append(e) + elif wf == "tensorflow_saved_model_bundle": + assert weights.tensorflow_saved_model_bundle is not None + try: + from .tensorflow_backend import create_tf_model_adapter + + return create_tf_model_adapter(model_description, devices=devices) + except Exception as e: + errors.append(e) + elif wf == "onnx": + assert weights.onnx is not None + try: + from .onnx_backend import ONNXModelAdapter + + return ONNXModelAdapter(model_description, devices=devices) + except Exception as e: + errors.append(e) + elif wf == "torchscript": + assert weights.torchscript is not None + try: + from .torchscript_backend import TorchscriptModelAdapter + + return TorchscriptModelAdapter(model_description, devices=devices) + except Exception as e: + errors.append(e) + elif wf == "keras_hdf5": + assert weights.keras_hdf5 is not None + # keras can either be installed as a separate package or used as part of tensorflow + # we try to first import the keras model adapter using the separate package and, + # if it is not available, try to load the one using tf + try: + try: + from .keras_backend import KerasModelAdapter + except Exception: + from .tensorflow_backend import KerasModelAdapter + + return KerasModelAdapter(model_description, devices=devices) + except Exception as e: + errors.append(e) + elif wf == "keras_v3": + assert not isinstance(weights, v0_4.WeightsDescr), ( + "keras_v3 weights not supported for v0.4 specs" + ) + assert weights.keras_v3 is not None + try: + from .keras_backend import KerasModelAdapter + + return KerasModelAdapter(model_description, devices=devices) + except 
Exception as e: + errors.append(e) + else: + assert_never(wf) + + assert errors + if len(weight_format_priority_order) == 1: + assert len(errors) == 1 + raise errors[0] + + else: + msg = ( + "None of the weight format specific model adapters could be created" + + " in this environment." + ) + raise ExceptionGroup(msg, errors) diff --git a/src/bioimageio/core/backends/_model_adapter.py b/src/bioimageio/core/backends/_model_adapter.py deleted file mode 100644 index e9bfeaa9c..000000000 --- a/src/bioimageio/core/backends/_model_adapter.py +++ /dev/null @@ -1,270 +0,0 @@ -import warnings -from abc import ABC, abstractmethod -from typing import ( - Any, - List, - Optional, - Sequence, - Tuple, - Union, - final, -) - -from exceptiongroup import ExceptionGroup -from loguru import logger -from numpy.typing import NDArray -from typing_extensions import assert_never - -from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 - -from ..common import SupportedWeightsFormat -from ..digest_spec import get_axes_infos, get_member_ids -from ..sample import Sample, SampleBlock -from ..tensor import Tensor - -# Known weight formats in order of priority -# First match wins -DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = ( - "pytorch_state_dict", - "tensorflow_saved_model_bundle", - "torchscript", - "onnx", - "keras_v3", - "keras_hdf5", -) - - -class ModelAdapter(ABC): - """ - Represents model *without* any preprocessing or postprocessing. - - ``` - from bioimageio.core import load_description - - model = load_description(...) - - # option 1: - adapter = ModelAdapter.create(model) - adapter.forward(...) - adapter.unload() - - # option 2: - with ModelAdapter.create(model) as adapter: - adapter.forward(...) 
- ``` - """ - - def __init__(self, model_description: AnyModelDescr): - super().__init__() - self._model_descr = model_description - self._input_ids = get_member_ids(model_description.inputs) - self._output_ids = get_member_ids(model_description.outputs) - self._input_axes = [ - tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs - ] - self._output_axes = [ - tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs - ] - if isinstance(model_description, v0_4.ModelDescr): - self._input_is_optional = [False] * len(model_description.inputs) - else: - self._input_is_optional = [ipt.optional for ipt in model_description.inputs] - - @final - @classmethod - def create( - cls, - model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], - *, - devices: Optional[Sequence[str]] = None, - weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None, - ): - """ - Creates model adapter based on the passed spec - Note: All specific adapters should happen inside this function to prevent different framework - initializations interfering with each other - """ - if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)): - raise TypeError( - f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}" - ) - - weights = model_description.weights - errors: List[Exception] = [] - weight_format_priority_order = ( - DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER - if weight_format_priority_order is None - else weight_format_priority_order - ) - # limit weight formats to the ones present - weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [ - w - for w in weight_format_priority_order - if getattr(weights, w, None) is not None - ] - if not weight_format_priority_order_present: - raise ValueError( - f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})" - ) - - for wf in weight_format_priority_order_present: - if wf 
== "pytorch_state_dict": - assert weights.pytorch_state_dict is not None - try: - from .pytorch_backend import PytorchModelAdapter - - return PytorchModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append(e) - elif wf == "tensorflow_saved_model_bundle": - assert weights.tensorflow_saved_model_bundle is not None - try: - from .tensorflow_backend import create_tf_model_adapter - - return create_tf_model_adapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append(e) - elif wf == "onnx": - assert weights.onnx is not None - try: - from .onnx_backend import ONNXModelAdapter - - return ONNXModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append(e) - elif wf == "torchscript": - assert weights.torchscript is not None - try: - from .torchscript_backend import TorchscriptModelAdapter - - return TorchscriptModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append(e) - elif wf == "keras_hdf5": - assert weights.keras_hdf5 is not None - # keras can either be installed as a separate package or used as part of tensorflow - # we try to first import the keras model adapter using the separate package and, - # if it is not available, try to load the one using tf - try: - try: - from .keras_backend import KerasModelAdapter - except Exception: - from .tensorflow_backend import KerasModelAdapter - - return KerasModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append(e) - elif wf == "keras_v3": - assert not isinstance(weights, v0_4.WeightsDescr), ( - "keras_v3 weights not supported for v0.4 specs" - ) - assert weights.keras_v3 is not None - try: - from .keras_backend import KerasModelAdapter - - return KerasModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - 
errors.append(e) - else: - assert_never(wf) - - assert errors - if len(weight_format_priority_order) == 1: - assert len(errors) == 1 - raise errors[0] - - else: - msg = ( - "None of the weight format specific model adapters could be created" - + " in this environment." - ) - raise ExceptionGroup(msg, errors) - - @final - def load(self, *, devices: Optional[Sequence[str]] = None) -> None: - warnings.warn("Deprecated. ModelAdapter is loaded on initialization") - - def forward(self, input_sample: Union[Sample, SampleBlock]) -> Sample: - """ - Run forward pass of model to get model predictions - - Note: sample id and stample stat attributes are passed through - """ - unexpected = [mid for mid in input_sample.members if mid not in self._input_ids] - if unexpected: - warnings.warn(f"Got unexpected input tensor IDs: {unexpected}") - - input_arrays = [ - ( - None - if (a := input_sample.members.get(in_id)) is None - else a.transpose(in_order).data.data - ) - for in_id, in_order in zip(self._input_ids, self._input_axes) - ] - logger.debug( - "NN input shapes: {}", - [a.shape if a is not None else None for a in input_arrays], - ) - output_arrays = self._forward_impl(input_arrays) - logger.debug( - "NN output shapes: {}", - [a.shape if a is not None else None for a in output_arrays], - ) - if len(output_arrays) > len(self._output_ids): - warnings.warn( - f"Model produced more outputs ({len(output_arrays)}) than specified in the model description ({len(self._output_ids)}). Extra outputs will be ignored." 
- ) - output_arrays = output_arrays[: len(self._output_ids)] - - output_tensors = [ - None if a is None else Tensor(a, dims=d) - for a, d in zip(output_arrays, self._output_axes) - ] - return Sample( - members={ - tid: out - for tid, out in zip( - self._output_ids, - output_tensors, - ) - if out is not None - }, - stat=input_sample.stat, - id=( - input_sample.id - if isinstance(input_sample, Sample) - else input_sample.sample_id - ), - ) - - @abstractmethod - def _forward_impl( - self, input_arrays: Sequence[Optional[NDArray[Any]]] - ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]: - """framework specific forward implementation""" - - @abstractmethod - def unload(self): - """ - Unload model from any devices, freeing their memory. - The moder adapter should be considered unusable afterwards. - """ - - def _get_input_args_numpy(self, input_sample: Sample): - """helper to extract tensor args as transposed numpy arrays""" - - -create_model_adapter = ModelAdapter.create diff --git a/src/bioimageio/core/backends/keras_backend.py b/src/bioimageio/core/backends/keras_backend.py index 9e56c6d19..fbdfe498d 100644 --- a/src/bioimageio/core/backends/keras_backend.py +++ b/src/bioimageio/core/backends/keras_backend.py @@ -2,7 +2,7 @@ import shutil from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Tuple from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs] legacy_h5_format, @@ -11,13 +11,12 @@ from numpy.typing import NDArray from bioimageio.spec._internal.version_type import Version -from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model import v0_4 +from .._model_adapter import LocalModelAdapter from .._settings import settings -from ..digest_spec import get_axes_infos from ..utils._compare import warn_about_version from ..utils._type_guards import is_list, is_tuple -from ._model_adapter import 
ModelAdapter os.environ["KERAS_BACKEND"] = settings.keras_backend @@ -35,25 +34,27 @@ tf_version = None -class KerasModelAdapter(ModelAdapter): - def __init__( - self, - *, - model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], - devices: Optional[Sequence[str]] = None, - ) -> None: - super().__init__(model_description=model_description) +class KerasModelAdapter(LocalModelAdapter[None, Any]): + def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]: + # TODO keras device management + if devices is not None: + logger.warning( + "Device management is not implemented for keras yet, ignoring the devices {}", + devices, + ) + return (None,) + def _init_model_on_device(self, device: None) -> Any: if ( - not isinstance(model_description, v0_4.ModelDescr) - and model_description.weights.keras_v3 is not None + not isinstance(self._model_descr, v0_4.ModelDescr) + and self._model_descr.weights.keras_v3 is not None ): - weight_reader = model_description.weights.keras_v3.get_reader() - backend, backend_version = model_description.weights.keras_v3.backend - elif model_description.weights.keras_hdf5 is not None: + weight_reader = self._model_descr.weights.keras_v3.get_reader() + backend, backend_version = self._model_descr.weights.keras_v3.backend + elif self._model_descr.weights.keras_hdf5 is not None: backend = "legacy_tensorflow" - backend_version = model_description.weights.keras_hdf5.tensorflow_version - weight_reader = model_description.weights.keras_hdf5.get_reader() + backend_version = self._model_descr.weights.keras_hdf5.tensorflow_version + weight_reader = self._model_descr.weights.keras_hdf5.get_reader() else: raise ValueError("model has no Keras weights") @@ -81,41 +82,33 @@ def __init__( jax_version = Version(jax.__version__) warn_about_version("jax", backend_version, jax_version) - # TODO keras device management - if devices is not None: - logger.warning( - "Device management is not implemented for keras yet, ignoring the devices {}", - 
devices, - ) - if weight_reader.suffix in (".h5", "hdf5"): import h5py # pyright: ignore[reportMissingTypeStubs] h5_file = h5py.File(weight_reader, mode="r") - self._network = legacy_h5_format.load_model_from_hdf5(h5_file) + return legacy_h5_format.load_model_from_hdf5(h5_file) # pyright: ignore[reportUnknownVariableType] else: with TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) / weight_reader.original_file_name with temp_path.open("wb") as f: shutil.copyfileobj(weight_reader, f) - self._network = keras.models.load_model(temp_path) - - self._output_axes = [ - tuple(a.id for a in get_axes_infos(out)) - for out in model_description.outputs - ] + return keras.models.load_model(temp_path) # pyright: ignore[reportUnknownVariableType] - def _forward_impl( # pyright: ignore[reportUnknownParameterType] - self, input_arrays: Sequence[Optional[NDArray[Any]]] + def _forward_impl( + self, + device: None, + model: Any, + input_arrays: Sequence[Optional[NDArray[Any]]], ): - network_output = self._network.predict(*input_arrays) # type: ignore + network_output = model.predict(*input_arrays) if is_list(network_output) or is_tuple(network_output): return network_output else: - return [network_output] # pyright: ignore[reportUnknownVariableType] + return [network_output] + + def _cleanup_pre_model_deletion(self, device: None, model: Any) -> None: + return - def unload(self) -> None: - logger.warning( - "Device management is not implemented for keras yet, cannot unload model" - ) + def _cleanup_post_model_deletion(self, device: None) -> None: + return diff --git a/src/bioimageio/core/backends/onnx_backend.py b/src/bioimageio/core/backends/onnx_backend.py index 6a752f8fb..0dab80c12 100644 --- a/src/bioimageio/core/backends/onnx_backend.py +++ b/src/bioimageio/core/backends/onnx_backend.py @@ -1,35 +1,37 @@ # pyright: reportUnknownVariableType=false import shutil import tempfile -import warnings from contextlib import contextmanager, nullcontext from pathlib import Path 
from typing import Any, List, Optional, Sequence, Union, cast import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs] -from exceptiongroup import ExceptionGroup from loguru import logger from numpy.typing import NDArray from bioimageio.spec.model import v0_4, v0_5 -from ..model_adapters import ModelAdapter +from .._model_adapter import LocalModelAdapter from ..utils._type_guards import is_list, is_tuple -class ONNXModelAdapter(ModelAdapter): +class ONNXModelAdapter(LocalModelAdapter[Optional[str], rt.InferenceSession]): def __init__( self, - *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None, ): - super().__init__(model_description=model_description) - onnx_descr = model_description.weights.onnx if onnx_descr is None: raise ValueError("No ONNX weights specified for {model_description.name}") + self._onnx_descr = onnx_descr + self._input_names: Optional[List[str]] = None + super().__init__(model_description=model_description, devices=devices) + + def _parse_devices( + self, devices: Optional[Sequence[str]] + ) -> Sequence[Optional[str]]: available_providers: Any = None if hasattr(rt, "get_available_providers"): available_providers = cast(Any, rt.get_available_providers()) @@ -40,8 +42,35 @@ def __init__( else: providers = available_providers else: + available_providers = [available_providers] providers = [available_providers] + if devices is not None: + available_devices = [d for d in devices if d in providers] + unavailable_devices = [d for d in devices if d not in providers] + if available_devices: + if unavailable_devices: + logger.warning( + "The following requested devices are not available for ONNX Runtime and will be ignored: {}.\nSelected available providers/devices are: {}\nOther available providers are: {}", + unavailable_devices, + available_devices, + [p for p in providers if p not in devices], + ) + + providers = available_devices + elif not available_providers: + logger.error( + "ONNX 
Runtime does not report any available providers. Attempting to load model with default providers, but this will likely fail." + ) + else: + logger.warning( + "None of the requested devices are available for ONNX Runtime, falling back to default, available providers: {}", + available_providers, + ) + return providers + + def _init_model_on_device(self, device: Optional[str]) -> rt.InferenceSession: + onnx_descr = self._onnx_descr if ( isinstance(onnx_descr, v0_5.OnnxWeightsDescr) and onnx_descr.external_data is not None @@ -88,51 +117,30 @@ def source_context_func(): with source_context as s: assert isinstance(s, bytes) or s.exists() + session = rt.InferenceSession( + s, + providers=None if device is None else [device], + ) - # try providers in order until one works - # TODO: check if issue with backup providers is fixed and evaluate handing over all available providers - # currently (onnxruntime 1.23.2) if a higher priority providers fails a RUNTIME_EXCEPTION may be raised - # stating 'model_path must not be empty' instead of trying the next provider, see # TODO: reference issue - provider_exceptions: List[Exception] = [] - for p in providers: - try: - self._session = rt.InferenceSession( - s, - providers=None if p is None else [p], - ) - except Exception as e: - provider_exceptions.append(e) - else: - for bad_p, e in zip( - providers[: len(provider_exceptions)], provider_exceptions - ): - logger.warning( - "Failed to load ONNX model with provider {}: {}", - bad_p, - e, - ) - - break - else: - raise ExceptionGroup( - "Failed to load ONNX model with any of the available providers.", - provider_exceptions, - ) - - onnx_inputs = self._session.get_inputs() - self._input_names: List[str] = [ipt.name for ipt in onnx_inputs] - - if devices is not None: - warnings.warn( - f"Device management is not implemented for onnx yet, ignoring the devices {devices}" + onnx_inputs = session.get_inputs() + onnx_input_names = [str(ipt.name) for ipt in onnx_inputs] # pyright: 
ignore[reportUnknownArgumentType] + if self._input_names is None: + self._input_names = onnx_input_names + elif self._input_names != onnx_input_names: + raise RuntimeError( + f"Input names of the ONNX model {onnx_input_names} do not match expected input names {self._input_names} from previous model initialization." ) + return session + def _forward_impl( - self, input_arrays: Sequence[Optional[NDArray[Any]]] + self, + device: Optional[str], + model: rt.InferenceSession, + input_arrays: Sequence[Optional[NDArray[Any]]], ) -> List[Optional[NDArray[Any]]]: - result: Any = self._session.run( - None, dict(zip(self._input_names, input_arrays)) - ) + assert self._input_names is not None, "set during model initialization" + result: Any = model.run(None, dict(zip(self._input_names, input_arrays))) if is_list(result) or is_tuple(result): result_seq = list(result) else: @@ -140,7 +148,10 @@ def _forward_impl( return result_seq - def unload(self) -> None: - warnings.warn( - "Device management is not implemented for onnx yet, cannot unload model" - ) + def _cleanup_pre_model_deletion( + self, device: Optional[str], model: rt.InferenceSession + ) -> None: + return + + def _cleanup_post_model_deletion(self, device: Optional[str]) -> None: + return diff --git a/src/bioimageio/core/backends/pytorch_backend.py b/src/bioimageio/core/backends/pytorch_backend.py index 9ca5903d1..2d4634522 100644 --- a/src/bioimageio/core/backends/pytorch_backend.py +++ b/src/bioimageio/core/backends/pytorch_backend.py @@ -1,5 +1,4 @@ import gc -import warnings from abc import abstractmethod from contextlib import nullcontext from io import BytesIO, TextIOWrapper @@ -17,9 +16,9 @@ from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.utils import download +from .._model_adapter import LocalModelAdapter from ..digest_spec import import_callable from ..utils._type_guards import is_list, is_ndarray, is_tuple -from ._model_adapter import ModelAdapter @runtime_checkable @@ -48,37 
+47,46 @@ def eval(self) -> Self: return self -class PytorchModelAdapter(ModelAdapter): +class PytorchModelAdapter(LocalModelAdapter[torch.device, nn.Module]): def __init__( self, - *, model_description: AnyModelDescr, - devices: Optional[Sequence[Union[str, torch.device]]] = None, mode: Literal["eval", "train"] = "eval", + devices: Optional[Sequence[str]] = None, ): - super().__init__(model_description=model_description) weights = model_description.weights.pytorch_state_dict if weights is None: raise ValueError("No `pytorch_state_dict` weights found") - devices = get_devices(devices) - self._model = load_torch_model(weights, load_state=True, devices=devices) - if mode == "eval": - self._model = self._model.eval() - elif mode == "train": - self._model = self._model.train() + self._weights = weights + self._mode: Literal["eval", "train"] = mode + super().__init__(model_description=model_description, devices=devices) + + def _parse_devices( + self, devices: Optional[Sequence[str]] + ) -> Sequence[torch.device]: + return get_devices(devices) + + def _init_model_on_device(self, device: torch.device) -> nn.Module: + model = load_torch_model(self._weights, load_state=True, devices=[device]) + + if self._mode == "eval": + model = model.eval() + elif self._mode == "train": + model = model.train() else: - assert_never(mode) + assert_never(self._mode) - self._mode: Literal["eval", "train"] = mode - self._primary_device = devices[0] + return model def _forward_impl( - self, input_arrays: Sequence[Optional[NDArray[Any]]] + self, + device: torch.device, + model: nn.Module, + input_arrays: Sequence[Optional[NDArray[Any]]], ) -> List[Optional[NDArray[Any]]]: tensors = [ - None if a is None else torch.from_numpy(a).to(self._primary_device) - for a in input_arrays + None if a is None else torch.from_numpy(a).to(device) for a in input_arrays ] if self._mode == "eval": @@ -89,7 +97,7 @@ def _forward_impl( assert_never(self._mode) with ctxt(): - model_out = self._model(*tensors) + 
model_out = model(*tensors) if is_tuple(model_out) or is_list(model_out): model_out_seq = model_out @@ -112,11 +120,15 @@ def _forward_impl( return result - def unload(self) -> None: - del self._model + def _cleanup_pre_model_deletion( + self, device: torch.device, model: nn.Module + ) -> None: + return + + def _cleanup_post_model_deletion(self, device: torch.device) -> None: _ = gc.collect() # deallocate memory - assert torch is not None - torch.cuda.empty_cache() # release reserved memory + if device.type == "cuda": + torch.cuda.empty_cache() # release reserved memory def load_torch_model( @@ -232,18 +244,24 @@ def get_devices( ) -> List[torch.device]: if not devices: if torch.cuda.is_available(): - torch_devices = [torch.device("cuda")] + torch_devices = [ + torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count()) + ] elif torch.backends.mps.is_available(): torch_devices = [torch.device("mps")] else: - torch_devices = [torch.device("cpu")] + try: + if ( + torch.accelerator.is_available() + and (current_accelerator := torch.accelerator.current_accelerator()) + is not None + ): + torch_devices = [current_accelerator] + else: + torch_devices = [torch.device("cpu")] + except Exception: + torch_devices = [torch.device("cpu")] else: torch_devices = [torch.device(d) for d in devices] - if len(torch_devices) > 1: - warnings.warn( - f"Multiple devices for pytorch model not yet implemented; ignoring {torch_devices[1:]}" - ) - torch_devices = torch_devices[:1] - return torch_devices diff --git a/src/bioimageio/core/backends/tensorflow_backend.py b/src/bioimageio/core/backends/tensorflow_backend.py index 8d13a7822..263dc8783 100644 --- a/src/bioimageio/core/backends/tensorflow_backend.py +++ b/src/bioimageio/core/backends/tensorflow_backend.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import tensorflow as tf @@ -8,42 +8,49 @@ from 
bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 +from .._model_adapter import LocalModelAdapter from ..io import ensure_unzipped -from ._model_adapter import ModelAdapter -class TensorflowModelAdapter(ModelAdapter): +class TensorflowModelAdapter(LocalModelAdapter[None, Any]): + """Adapter for TensorFlow 1 models""" + weight_format = "tensorflow_saved_model_bundle" def __init__( self, - *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None, ): - super().__init__(model_description=model_description) - weight_file = model_description.weights.tensorflow_saved_model_bundle if model_description.weights.tensorflow_saved_model_bundle is None: raise ValueError("No `tensorflow_saved_model_bundle` weights found") + if isinstance(model_description, v0_4.ModelDescr): + self._weight_src = ( + model_description.weights.tensorflow_saved_model_bundle.source + ) + else: + self._weight_src = model_description.weights.tensorflow_saved_model_bundle + + self._graph = None + self._io_names: Optional[Tuple[List[str], List[str]]] = None + super().__init__(model_description=model_description, devices=devices) + + def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]: if devices is not None: logger.warning( f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}" ) + return (None,) + + def _init_model_on_device(self, device: Optional[str]) -> Any: # TODO: check how to load tf weights without unzipping weight_file = ensure_unzipped( - model_description.weights.tensorflow_saved_model_bundle.source, - Path("bioimageio_unzipped_tf_weights"), + self._weight_src, Path("bioimageio_unzipped_tf_weights") ) - self._network = str(weight_file) - # TODO currently we relaod the model every time. 
it would be better to keep the graph and session - # alive in between of forward passes (but then the sessions need to be properly opened / closed) - def _forward_impl( # pyright: ignore[reportUnknownParameterType] - self, input_arrays: Sequence[Optional[NDArray[Any]]] - ): # TODO read from spec tag = ( # pyright: ignore[reportUnknownVariableType] tf.saved_model.tag_constants.SERVING @@ -52,92 +59,94 @@ def _forward_impl( # pyright: ignore[reportUnknownParameterType] tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY ) - graph = tf.Graph() - with graph.as_default(): - with tf.Session(graph=graph) as sess: # pyright: ignore[reportUnknownVariableType] - # load the model and the signature - graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType] - sess, [tag], self._network - ) - signature = ( # pyright: ignore[reportUnknownVariableType] - graph_def.signature_def - ) + self._graph = tf.Graph() + with self._graph.as_default(): + sess = tf.Session(graph=self._graph) # pyright: ignore[reportUnknownVariableType] + # load the model and the signature + graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType] + sess, [tag], str(weight_file) + ) + signature = ( # pyright: ignore[reportUnknownVariableType] + graph_def.signature_def + ) - # get the tensors into the graph - in_names = [ # pyright: ignore[reportUnknownVariableType] - signature[signature_key].inputs[key].name for key in self._input_ids - ] - out_names = [ # pyright: ignore[reportUnknownVariableType] - signature[signature_key].outputs[key].name - for key in self._output_ids - ] - in_tf_tensors = [ - graph.get_tensor_by_name( - name # pyright: ignore[reportUnknownArgumentType] - ) - for name in in_names # pyright: ignore[reportUnknownVariableType] - ] - out_tf_tensors = [ - graph.get_tensor_by_name( - name # pyright: ignore[reportUnknownArgumentType] - ) - for name in out_names # pyright: ignore[reportUnknownVariableType] - ] - - # run 
prediction - res = sess.run( # pyright: ignore[reportUnknownVariableType] - dict( - zip( - out_names, # pyright: ignore[reportUnknownArgumentType] - out_tf_tensors, - ) - ), - dict(zip(in_tf_tensors, input_arrays)), - ) - # from dict to list of tensors - res = [ # pyright: ignore[reportUnknownVariableType] - res[out] - for out in out_names # pyright: ignore[reportUnknownVariableType] - ] + # get the tensors into the graph + in_names = [ # pyright: ignore[reportUnknownVariableType] + signature[signature_key].inputs[key].name for key in self._input_ids + ] + out_names = [ # pyright: ignore[reportUnknownVariableType] + signature[signature_key].outputs[key].name for key in self._output_ids + ] + self._io_names = (in_names, out_names) - return res # pyright: ignore[reportUnknownVariableType] + return sess # pyright: ignore[reportUnknownVariableType] - def unload(self) -> None: - logger.warning( - "Device management is not implemented for tensorflow 1, cannot unload model" + def _forward_impl( + self, device: None, model: Any, input_arrays: Sequence[Optional[NDArray[Any]]] + ): + assert self._io_names is not None + assert self._graph is not None + + in_names, out_names = self._io_names + in_tf_tensors = [self._graph.get_tensor_by_name(name) for name in in_names] + out_tf_tensors = [self._graph.get_tensor_by_name(name) for name in out_names] + + # run prediction + res = model.run( + dict(zip(out_names, out_tf_tensors)), + dict(zip(in_tf_tensors, input_arrays)), ) + # from dict to list of tensors + res = [res[out] for out in out_names] + + return res + + def _cleanup_pre_model_deletion(self, device: Optional[str], model: Any) -> None: + return + + def _cleanup_post_model_deletion(self, device: Optional[str]) -> None: + return -class KerasModelAdapter(ModelAdapter): +class KerasModelAdapter(LocalModelAdapter[None, Any]): def __init__( self, - *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None, ): if 
model_description.weights.tensorflow_saved_model_bundle is None: raise ValueError("No `tensorflow_saved_model_bundle` weights found") - super().__init__(model_description=model_description) + if isinstance(model_description, v0_4.ModelDescr): + self._weight_src = ( + model_description.weights.tensorflow_saved_model_bundle.source + ) + else: + self._weight_src = model_description.weights.tensorflow_saved_model_bundle + + super().__init__(model_description=model_description, devices=devices) + + def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]: if devices is not None: logger.warning( f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}" ) + return (None,) + def _init_model_on_device(self, device: None) -> Any: # TODO: check how to load tf weights without unzipping - weight_file = ensure_unzipped( - model_description.weights.tensorflow_saved_model_bundle.source, - Path("bioimageio_unzipped_tf_weights"), + weight_file = str( + ensure_unzipped(self._weight_src, Path("bioimageio_unzipped_tf_weights")) ) try: - self._network = tf.keras.layers.TFSMLayer( + tfsm_layer = tf.keras.layers.TFSMLayer( # pyright: ignore[reportUnknownVariableType] weight_file, call_endpoint="serve", ) except Exception as e: try: - self._network = tf.keras.layers.TFSMLayer( + tfsm_layer = tf.keras.layers.TFSMLayer( # pyright: ignore[reportUnknownVariableType] weight_file, call_endpoint="serving_default" ) except Exception as ee: @@ -146,16 +155,16 @@ def __init__( ) raise e + return tfsm_layer # pyright: ignore[reportUnknownVariableType] + def _forward_impl( # pyright: ignore[reportUnknownParameterType] - self, input_arrays: Sequence[Optional[NDArray[Any]]] + self, device: None, model: Any, input_arrays: Sequence[Optional[NDArray[Any]]] ): assert tf is not None tf_tensor = [ None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_arrays ] - - result = self._network(*tf_tensor) # pyright: ignore[reportUnknownVariableType] - + 
result = model(*tf_tensor) assert isinstance(result, dict) # TODO: Use RDF's `outputs[i].id` here @@ -168,15 +177,15 @@ def _forward_impl( # pyright: ignore[reportUnknownParameterType] for r in result # pyright: ignore[reportUnknownVariableType] ] - def unload(self) -> None: - logger.warning( - "Device management is not implemented for tensorflow>=2 models" - + f" using `{self.__class__.__name__}`, cannot unload model" - ) + def _cleanup_pre_model_deletion(self, device: Optional[str], model: Any) -> None: + return + + def _cleanup_post_model_deletion(self, device: Optional[str]) -> None: + return def create_tf_model_adapter( - model_description: AnyModelDescr, devices: Optional[Sequence[str]] + model_description: AnyModelDescr, devices: Optional[Sequence[str]] = None ): tf_version = v0_5.Version(tf.__version__) # type: ignore[reportUnknownVariableType] weights = model_description.weights.tensorflow_saved_model_bundle @@ -203,8 +212,6 @@ def create_tf_model_adapter( ) if tf_version.major <= 1: - return TensorflowModelAdapter( - model_description=model_description, devices=devices - ) + return TensorflowModelAdapter(model_description, devices=devices) else: - return KerasModelAdapter(model_description=model_description, devices=devices) + return KerasModelAdapter(model_description, devices=devices) diff --git a/src/bioimageio/core/backends/torchscript_backend.py b/src/bioimageio/core/backends/torchscript_backend.py index 4e7fe0092..f1f81ed35 100644 --- a/src/bioimageio/core/backends/torchscript_backend.py +++ b/src/bioimageio/core/backends/torchscript_backend.py @@ -3,45 +3,57 @@ from typing import Any, List, Optional, Sequence, Union import torch +from loguru import logger from numpy.typing import NDArray from bioimageio.spec.model import v0_4, v0_5 -from ..model_adapters import ModelAdapter +from .._model_adapter import LocalModelAdapter from ..utils._type_guards import is_list, is_tuple from .pytorch_backend import get_devices -class 
TorchscriptModelAdapter(ModelAdapter): +class TorchscriptModelAdapter(LocalModelAdapter[torch.device, Any]): def __init__( self, - *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None, ): - super().__init__(model_description=model_description) if model_description.weights.torchscript is None: raise ValueError( f"No torchscript weights found for model {model_description.name}" ) - self.devices = get_devices(devices) + self._weight_descr = model_description.weights.torchscript + super().__init__(model_description=model_description, devices=devices) - weight_reader = model_description.weights.torchscript.get_reader() - self._model = torch.jit.load(weight_reader) + def _parse_devices( + self, devices: Optional[Sequence[str]] + ) -> Sequence[torch.device]: + return get_devices(devices) - self._model.to(self.devices[0]) - self._model = self._model.eval() + def _init_model_on_device(self, device: torch.device) -> Any: + model = torch.jit.load(self._weight_descr.get_reader(), map_location=device) + try: + model.eval() + except Exception as e: + logger.warning( + f"Failed to set model to evaluation mode for torchscript model on {device}: {e}" + ) + return model def _forward_impl( - self, input_arrays: Sequence[Optional[NDArray[Any]]] + self, + device: torch.device, + model: Any, + input_arrays: Sequence[Optional[NDArray[Any]]], ) -> List[Optional[NDArray[Any]]]: with torch.no_grad(): torch_tensor = [ - None if a is None else torch.from_numpy(a).to(self.devices[0]) + None if a is None else torch.from_numpy(a).to(device) for a in input_arrays ] - output: Any = self._model.forward(*torch_tensor) + output: Any = model.forward(*torch_tensor) if is_list(output) or is_tuple(output): output_seq: Sequence[Any] = output else: @@ -58,8 +70,10 @@ def _forward_impl( for r in output_seq ] - def unload(self) -> None: - self._devices = None - del self._model + def _cleanup_pre_model_deletion(self, device: torch.device, model: Any) -> None: 
+ return + + def _cleanup_post_model_deletion(self, device: torch.device) -> None: _ = gc.collect() # deallocate memory - torch.cuda.empty_cache() # release reserved memory + if device.type == "cuda": + torch.cuda.empty_cache() # release reserved memory diff --git a/src/bioimageio/core/block.py b/src/bioimageio/core/block.py index d1d9d7d80..69af9118a 100644 --- a/src/bioimageio/core/block.py +++ b/src/bioimageio/core/block.py @@ -87,6 +87,15 @@ def from_meta(cls, meta: BlockMeta, data: Tensor) -> Self: data=data, ) + def get_meta(self) -> BlockMeta: + return BlockMeta( + sample_shape=self.sample_shape, + inner_slice=self.inner_slice, + halo=self.halo, + block_index=self.block_index, + blocks_in_sample=self.blocks_in_sample, + ) + def split_tensor_into_blocks( tensor: Tensor, diff --git a/src/bioimageio/core/block_meta.py b/src/bioimageio/core/block_meta.py index f2849e02c..d4ab3c873 100644 --- a/src/bioimageio/core/block_meta.py +++ b/src/bioimageio/core/block_meta.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from functools import cached_property from math import floor, prod +from types import MappingProxyType from typing import ( Any, Callable, @@ -15,13 +16,14 @@ Union, ) +import pydantic from loguru import logger from typing_extensions import Self +from ._axis_annotations import PerAxisAnno from .axis import AxisId, PerAxis from .common import ( BlockIndex, - Frozen, Halo, HaloLike, MemberId, @@ -42,7 +44,7 @@ def compute(self, s: int, round: Callable[[float], int] = floor) -> int: return round(s * self.scale) + self.offset -@dataclass(frozen=True) +@pydantic.dataclasses.dataclass(frozen=True) class BlockMeta: """Block meta data of a sample member (a tensor in a sample) @@ -76,13 +78,13 @@ class BlockMeta: """ - sample_shape: PerAxis[int] + sample_shape: PerAxisAnno[int] """the axis sizes of the whole (unblocked) sample""" - inner_slice: PerAxis[SliceInfo] + inner_slice: PerAxisAnno[SliceInfo] """inner region (without halo) wrt the sample""" - halo: 
PerAxis[Halo] + halo: PerAxisAnno[Halo] """halo enlarging the inner region to the block's sizes""" block_index: BlockIndex @@ -94,7 +96,7 @@ class BlockMeta: @cached_property def shape(self) -> PerAxis[int]: """axis lengths of the block""" - return Frozen( + return MappingProxyType( { a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0) for a, s in self.inner_slice.items() @@ -105,7 +107,7 @@ def shape(self) -> PerAxis[int]: def padding(self) -> PerAxis[PadWidth]: """padding to realize the halo at the sample edge where we cannot simply enlarge the inner slice""" - return Frozen( + return MappingProxyType( { a: PadWidth( ( @@ -128,7 +130,7 @@ def padding(self) -> PerAxis[PadWidth]: @cached_property def outer_slice(self) -> PerAxis[SliceInfo]: """slice of the outer block (without padding) wrt the sample""" - return Frozen( + return MappingProxyType( { a: SliceInfo( max( @@ -154,16 +156,22 @@ def outer_slice(self) -> PerAxis[SliceInfo]: @cached_property def inner_shape(self) -> PerAxis[int]: """axis lengths of the inner region (without halo)""" - return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()}) + return MappingProxyType( + {a: s.stop - s.start for a, s in self.inner_slice.items()} + ) @cached_property def local_slice(self) -> PerAxis[SliceInfo]: """inner slice wrt the block, **not** the sample""" - return Frozen( + return MappingProxyType( { - a: SliceInfo( - self.halo[a].left, - self.halo[a].left + self.inner_shape[a], + a: ( + SliceInfo( + h.left, + h.left + self.inner_shape[a], + ) + if (h := self.halo.get(a)) is not None + else SliceInfo(0, self.inner_shape[a]) ) for a in self.inner_slice } @@ -189,16 +197,6 @@ def inner_slice_wo_overlap(self) -> PerAxis[SliceInfo]: return self.inner_slice def __post_init__(self): - # freeze mutable inputs - if not isinstance(self.sample_shape, Frozen): - object.__setattr__(self, "sample_shape", Frozen(self.sample_shape)) - - if not isinstance(self.inner_slice, Frozen): - 
object.__setattr__(self, "inner_slice", Frozen(self.inner_slice)) - - if not isinstance(self.halo, Frozen): - object.__setattr__(self, "halo", Frozen(self.halo)) - assert all(a in self.sample_shape for a in self.inner_slice), ( "block has axes not present in sample" ) @@ -254,10 +252,12 @@ def split_shape_into_blocks( halo: PerAxis[HaloLike], stride: Optional[PerAxis[int]] = None, ) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]: - assert all(a in shape for a in block_shape), ( - tuple(shape), - set(block_shape), - ) + unknown_axes = [a for a in block_shape if a not in shape] + if unknown_axes: + raise ValueError( + f"unknown axes in block_shape: {unknown_axes} for shape {shape}" + ) + if any(shape[a] < block_shape[a] for a in block_shape): # TODO: allow larger blockshape raise ValueError(f"shape {shape} is smaller than block shape {block_shape}") diff --git a/src/bioimageio/core/cli.py b/src/bioimageio/core/cli.py index 318d61512..0a8dd8a0f 100644 --- a/src/bioimageio/core/cli.py +++ b/src/bioimageio/core/cli.py @@ -83,19 +83,19 @@ write_yaml, ) +from ._prediction_pipeline import ( + create_prediction_pipeline, + create_remote_prediction_pipeline, +) from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test from .common import MemberId, SampleId, SupportedWeightsFormat from .digest_spec import get_member_ids, load_sample_for_model from .io import load_stat, save_sample, save_stat -from .prediction import create_prediction_pipeline -from .proc_setup import ( - Measure, - MeasureValue, - StatsCalculator, - get_required_dataset_measures, -) +from .proc_setup import get_required_dataset_measures +from .remote_backends import create_remote_model_adapter from .sample import Sample -from .stat_measures import Stat +from .stat_calculators import StatsCalculator +from .stat_measures import Measure, MeasureValue, Stat from .utils import compare from .weight_converters._add_weights import add_weights @@ -520,6 +520,20 @@ class 
PredictCmd(CmdBase, WithSource): ) """Device(s) to use""" + server: Optional[str] = None + """The URL or Hugging Face space name of a running bioimageio (gradio) server instance to use as a remote backend for prediction.""" + + pre_post_processing_location: Literal["local", "remote"] = Field( + "local", alias="pre-post-processing-location" + ) + """Where to run preprocessing/postprocessing operations when using `--server`. + + - `local`: Run preprocessing/postprocessing locally and only model inference on the server. + - `remote`: Run preprocessing/postprocessing on the server as well. + +   + """ + example: bool = False """generate and run an example @@ -794,11 +808,25 @@ def input_dataset(stat: Stat): ).items() ) - pp = create_prediction_pipeline( - model_descr, - weight_format=None if self.weight_format == "any" else self.weight_format, - devices=self.devices, - ) + if self.server is not None and self.pre_post_processing_location == "remote": + pp = create_remote_prediction_pipeline(model_descr, server=self.server) + else: + if self.server is None: + model_adapter = None + else: + assert self.pre_post_processing_location == "local" + model_adapter = create_remote_model_adapter( + model_descr, server=self.server + ) + + pp = create_prediction_pipeline( + model_descr, + weight_format=None + if self.weight_format == "any" + else self.weight_format, + devices=self.devices, + model_adapter=model_adapter, + ) if blockwise: predict_method = partial( @@ -920,13 +948,40 @@ def cli_cmd(self): self.log(updated_model_descr) -class EmptyCache(CmdBase): +class EmptyCacheCmd(CmdBase): """Empty the bioimageio cache directory.""" def cli_cmd(self): empty_cache() +class ServerCmd(CmdBase): + """Start a server to connect to with remote model adapters or remote prediction pipelines.""" + + backend: Literal["gradio"] = "gradio" + """The remote backend to use.""" + + port: Optional[int] = None + """The port to start the server on. 
If not given, a free port will be used.""" + + def cli_cmd(self) -> None: + try: + if self.backend == "gradio": + from .remote_backends.gradio.server import main + else: + assert_never(self.backend) + except ImportError as e: + raise ImportError( + f"{self.backend.capitalize()} is not installed. Please install the '{self.backend}-server' extra to use this command," + + f" e.g. with `pip install bioimageio.core[{self.backend}-server]`." + ) from e + + local_server_url = main(port=self.port) + logger.info( + "{} server shutdown at {}", self.backend.capitalize(), local_server_url + ) + + JSON_FILE = "bioimageio-cli.json" YAML_FILE = "bioimageio-cli.yaml" @@ -972,9 +1027,12 @@ class Bioimageio( add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights") """Add additional weights to a model description by converting from available formats.""" - empty_cache: CliSubCommand[EmptyCache] = Field(alias="empty-cache") + empty_cache: CliSubCommand[EmptyCacheCmd] = Field(alias="empty-cache") """Empty the bioimageio cache directory.""" + server: CliSubCommand[ServerCmd] + """Start a server to connect to with remote model adapters or remote prediction pipelines.""" + @classmethod def settings_customise_sources( cls, diff --git a/src/bioimageio/core/common.py b/src/bioimageio/core/common.py index 981e1efc8..2c4af41e8 100644 --- a/src/bioimageio/core/common.py +++ b/src/bioimageio/core/common.py @@ -1,6 +1,5 @@ from __future__ import annotations -from types import MappingProxyType from typing import ( Hashable, Literal, @@ -11,12 +10,10 @@ Union, ) -from typing_extensions import Self, assert_never +from typing_extensions import Self, TypeAlias, assert_never from bioimageio.spec.model import v0_5 -from .axis import AxisId - SupportedWeightsFormat = Literal[ "keras_hdf5", "keras_v3", @@ -112,10 +109,10 @@ class PadWidth(_LeftRight): pass -PadWidthLike = _LeftRightLike[PadWidth] -Padding = v0_5.Padding -PadMode = Union[Literal["constant", "edge", "reflect", "symmetric"], 
Padding] -PadWhere = _Where +PadWidthLike: TypeAlias = _LeftRightLike[PadWidth] +Padding: TypeAlias = v0_5.Padding +PadMode: TypeAlias = Union[Literal["constant", "edge", "reflect", "symmetric"], Padding] +PadWhere: TypeAlias = _Where class SliceInfo(NamedTuple): @@ -128,49 +125,17 @@ class SliceInfo(NamedTuple): MemberId = v0_5.TensorId """ID of a `Sample` member, see `bioimageio.core.sample.Sample`""" -BlocksizeParameter = Union[ +BlocksizeParameter: TypeAlias = Union[ v0_5.ParameterizedSize_N, - Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N], + Mapping[Tuple[MemberId, v0_5.AxisId], v0_5.ParameterizedSize_N], ] """ Parameter to determine a concrete size for paramtrized axis sizes defined by `bioimageio.spec.model.v0_5.ParameterizedSize`. """ -T = TypeVar("T") -PerMember = Mapping[MemberId, T] +_T = TypeVar("_T") +PerMember = Mapping[MemberId, _T] BlockIndex = int TotalNumberOfBlocks = int - - -K = TypeVar("K", bound=Hashable) -V = TypeVar("V") - -Frozen = MappingProxyType -# class Frozen(Mapping[K, V]): # adapted from xarray.core.utils.Frozen -# """Wrapper around an object implementing the mapping interface to make it -# immutable.""" - -# __slots__ = ("mapping",) - -# def __init__(self, mapping: Mapping[K, V]): -# super().__init__() -# self.mapping = deepcopy( -# mapping -# ) # added deepcopy (compared to xarray.core.utils.Frozen) - -# def __getitem__(self, key: K) -> V: -# return self.mapping[key] - -# def __iter__(self) -> Iterator[K]: -# return iter(self.mapping) - -# def __len__(self) -> int: -# return len(self.mapping) - -# def __contains__(self, key: object) -> bool: -# return key in self.mapping - -# def __repr__(self) -> str: -# return f"{type(self).__name__}({self.mapping!r})" diff --git a/src/bioimageio/core/digest_spec.py b/src/bioimageio/core/digest_spec.py index 280cadb87..4819f4178 100644 --- a/src/bioimageio/core/digest_spec.py +++ b/src/bioimageio/core/digest_spec.py @@ -25,7 +25,7 @@ import xarray as xr from loguru import logger from 
numpy.typing import NDArray -from typing_extensions import Unpack, assert_never +from typing_extensions import TypeAlias, Unpack, assert_never from bioimageio.spec._internal.io import HashKwargs, PermissiveFileSource from bioimageio.spec.common import FileDescr, FileSource @@ -46,12 +46,15 @@ LinearSampleAxisTransform, Sample, SampleBlockMeta, + SampleBlockWithOrigin, sample_block_meta_generator, ) from .stat_measures import Stat from .tensor import Tensor -TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], PermissiveFileSource] +TensorSource: TypeAlias = Union[ + Tensor, xr.DataArray, NDArray[Any], PermissiveFileSource +] def import_callable( @@ -291,12 +294,23 @@ class IO_SampleBlockMeta(NamedTuple): output: SampleBlockMeta -def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]): +def get_input_halo( + model: v0_5.ModelDescr, output_halo: Optional[PerMember[PerAxis[Halo]]] = None +): """returns which halo input tensors need to be divided into blocks with, such that `output_halo` can be cropped from their outputs without introducing gaps.""" input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {} outputs = {t.id: t for t in model.outputs} all_tensors = {**{t.id: t for t in model.inputs}, **outputs} + if output_halo is None: + output_halo = { + t.id: { + a.id: Halo(a.halo, a.halo) + for a in t.axes + if isinstance(a, v0_5.WithHalo) + } + for t in model.outputs + } for t, th in output_halo.items(): axes = {a.id: a for a in outputs[t].axes} @@ -547,3 +561,33 @@ def load_sample_for_model( stat={} if stat is None else stat, id=sample_id or tuple(sorted(paths.values())), ) + + +def split_sample_into_blocks_for_model( + sample: Sample, + model: v0_5.ModelDescr, + blocksize_parameter: int, + batch_size: int = 1, +) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]: + if isinstance(model, v0_4.ModelDescr): + raise NotImplementedError( + "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr" + + f" {model.name}." 
+ + " Consider using `predict_sample_with_fixed_blocking` or update the model description to format version 0.5." + ) + + ns = { + (ipt.id, a.id): blocksize_parameter + for ipt in model.inputs + for a in ipt.axes + if isinstance(a.size, v0_5.ParameterizedSize) + } + halo = get_input_halo(model) + + input_block_shape = model.get_tensor_sizes(ns, batch_size=batch_size).inputs + + return sample.split_into_blocks( + block_shapes=input_block_shape, + halo=halo, + pad_mode={ipt.id: ipt.pad or "symmetric" for ipt in model.inputs}, + ) diff --git a/src/bioimageio/core/io.py b/src/bioimageio/core/io.py index 179070619..82cb512f6 100644 --- a/src/bioimageio/core/io.py +++ b/src/bioimageio/core/io.py @@ -7,6 +7,7 @@ from pathlib import Path from shutil import copyfileobj from typing import ( + TYPE_CHECKING, Any, Dict, List, @@ -22,6 +23,8 @@ from loguru import logger from numpy.typing import NDArray from pydantic import BaseModel, RootModel +from typing_extensions import TypeAlias +from typing_extensions import TypeAliasType as _TypeAliasType from bioimageio.spec._internal.io import get_reader, interprete_file_source from bioimageio.spec._internal.type_guards import is_ndarray @@ -38,13 +41,23 @@ from .axis import AxisId, AxisLike from .common import PerMember -from .sample import Sample, Stat -from .stat_measures import DatasetMeasure, MeasureValue, SampleMeasure +from .sample import Sample +from .stat_measures import DatasetMeasure, MeasureValue, SampleMeasure, Stat from .tensor import Tensor -JsonValue = Union[ - bool, int, float, str, None, List["JsonValue"], Dict[str, "JsonValue"] -] +if TYPE_CHECKING: + JsonValue: TypeAlias = Union[ + bool, int, float, str, None, List["JsonValue"], Dict[str, "JsonValue"] + ] # note: order relevant for deserializing + +else: + # for pydantic validation we need to use `TypeAliasType`, + # see https://docs.pydantic.dev/latest/concepts/types/#named-recursive-types + # however this results in a partially unknown type with the current 
pyright 1.1.388 + JsonValue: TypeAlias = _TypeAliasType( + "JsonValue", + Union[bool, int, float, str, None, List["JsonValue"], Dict[str, "JsonValue"]], + ) def load_image( diff --git a/src/bioimageio/core/model_adapters.py b/src/bioimageio/core/model_adapters.py deleted file mode 100644 index db92d013a..000000000 --- a/src/bioimageio/core/model_adapters.py +++ /dev/null @@ -1,22 +0,0 @@ -"""DEPRECATED""" - -from typing import List - -from .backends._model_adapter import ( - DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER, - ModelAdapter, - create_model_adapter, -) - -__all__ = [ - "ModelAdapter", - "create_model_adapter", - "get_weight_formats", -] - - -def get_weight_formats() -> List[str]: - """ - Return list of supported weight types - """ - return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER) diff --git a/src/bioimageio/core/prediction.py b/src/bioimageio/core/prediction.py index 65e0802cd..d00585be9 100644 --- a/src/bioimageio/core/prediction.py +++ b/src/bioimageio/core/prediction.py @@ -65,7 +65,7 @@ def predict( """ if isinstance(model, PredictionPipeline): pp = model - model = pp.model_description + model = pp.model_descr else: if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)): loaded = load_description(model) @@ -78,54 +78,56 @@ def predict( fixed_dataset_statistics=inputs.stat if isinstance(inputs, Sample) else {}, ) - if save_output_path is not None: - if ( - "{output_id}" not in str(save_output_path) - and "{member_id}" not in str(save_output_path) - and len(model.outputs) > 1 - ): - raise ValueError( - f"Missing `{{output_id}}` in save_output_path={save_output_path} to " - + "distinguish model outputs " - + str([get_member_id(d) for d in model.outputs]) + with pp: + model = pp.model_descr + if save_output_path is not None: + if ( + "{output_id}" not in str(save_output_path) + and "{member_id}" not in str(save_output_path) + and len(model.outputs) > 1 + ): + raise ValueError( + f"Missing `{{output_id}}` in save_output_path={save_output_path} to " + + 
"distinguish model outputs " + + str([get_member_id(d) for d in model.outputs]) + ) + + if isinstance(inputs, Sample): + sample = inputs + else: + sample = create_sample_for_model( + pp.model_descr, inputs=inputs, sample_id=sample_id ) - if isinstance(inputs, Sample): - sample = inputs - else: - sample = create_sample_for_model( - pp.model_description, inputs=inputs, sample_id=sample_id - ) - - if input_block_shape is not None: - if blocksize_parameter is not None: - logger.warning( - "ignoring blocksize_parameter={} in favor of input_block_shape={}", - blocksize_parameter, - input_block_shape, + if input_block_shape is not None: + if blocksize_parameter is not None: + logger.warning( + "ignoring blocksize_parameter={} in favor of input_block_shape={}", + blocksize_parameter, + input_block_shape, + ) + + output = pp.predict_sample_with_fixed_blocking( + sample, + input_block_shape=input_block_shape, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, ) - - output = pp.predict_sample_with_fixed_blocking( - sample, - input_block_shape=input_block_shape, - skip_preprocessing=skip_preprocessing, - skip_postprocessing=skip_postprocessing, - ) - elif blocksize_parameter is not None: - output = pp.predict_sample_with_blocking( - sample, - skip_preprocessing=skip_preprocessing, - skip_postprocessing=skip_postprocessing, - ns=blocksize_parameter, - ) - else: - output = pp.predict_sample_without_blocking( - sample, - skip_preprocessing=skip_preprocessing, - skip_postprocessing=skip_postprocessing, - ) - if save_output_path: - save_sample(save_output_path, output) + elif blocksize_parameter is not None: + output = pp.predict_sample_with_blocking( + sample, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ns=blocksize_parameter, + ) + else: + output = pp.predict_sample_without_blocking( + sample, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ) + if save_output_path: + 
save_sample(save_output_path, output) return output diff --git a/src/bioimageio/core/proc_ops.py b/src/bioimageio/core/proc_ops.py index 052767fe0..cce52970e 100644 --- a/src/bioimageio/core/proc_ops.py +++ b/src/bioimageio/core/proc_ops.py @@ -60,7 +60,7 @@ def _convert_axis_ids( if mode == "per_sample": ret = [] elif mode == "per_dataset": - ret = [v0_5.BATCH_AXIS_ID] + ret = [AxisId(v0_5.BATCH_AXIS_ID)] else: assert_never(mode) @@ -562,7 +562,9 @@ def get_descr(self): return v0_5.ScaleRangeDescr( kwargs=v0_5.ScaleRangeKwargs( - axes=self.lower.axes, + axes=None + if self.lower.axes is None + else [v0_5.AxisId(a) for a in self.lower.axes], min_percentile=self.lower.q * 100, max_percentile=self.upper.q * 100, eps=self.eps, @@ -625,7 +627,7 @@ def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self: return cls(input=member_id, output=member_id, axis=descr.kwargs.axis) def get_descr(self): - return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis)) + return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=v0_5.AxisId(self.axis))) @dataclass diff --git a/src/bioimageio/core/remote_backends/__init__.py b/src/bioimageio/core/remote_backends/__init__.py new file mode 100644 index 000000000..1e2be1d07 --- /dev/null +++ b/src/bioimageio/core/remote_backends/__init__.py @@ -0,0 +1,38 @@ +from typing import TYPE_CHECKING, Literal, Optional + +from typing_extensions import assert_never + +from bioimageio.spec.model import AnyModelDescr + +if TYPE_CHECKING: + from .gradio.client import GradioModelAdapter + + +def create_remote_model_adapter( + model_description: AnyModelDescr, + server: Optional[str] = None, + server_type: Optional[Literal["gradio"]] = None, +) -> "GradioModelAdapter": + """Create a remote model adapter + + Args: + model_description: The model to run inference with. + server: The URL or Hugging Face space name of a running bioimageio server instance + server_type: The type of the remote server to connect to. 
Currently only "gradio" is supported. + """ + + if server_type is None: + server_type = "gradio" + + try: + if server_type == "gradio": + from .gradio.client import GradioModelAdapter as RemoteModelAdapterImpl + else: + assert_never(server_type) + except ImportError as e: + raise ImportError( + f"Failed to import {server_type.capitalize()}ModelAdapter. Make sure to install the '{server_type}-client' extra," + + f" e.g. with `pip install bioimageio.core[{server_type}-client]`." + ) from e + + return RemoteModelAdapterImpl(model_description=model_description, server=server) diff --git a/src/bioimageio/core/remote_backends/gradio/__init__.py b/src/bioimageio/core/remote_backends/gradio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/bioimageio/core/remote_backends/gradio/client.py b/src/bioimageio/core/remote_backends/gradio/client.py new file mode 100644 index 000000000..8a368e0c8 --- /dev/null +++ b/src/bioimageio/core/remote_backends/gradio/client.py @@ -0,0 +1,316 @@ +from types import MappingProxyType +from typing import Dict, Iterable, Literal, Mapping, Optional, Tuple, Union + +from gradio_client import Client +from loguru import logger + +from bioimageio.spec import AnyModelDescr, ValidationSummary +from bioimageio.spec.model import v0_4 + +from ..._description_serializer import DescriptionSerializer as DescriptionSerializer +from ..._model_adapter import RemoteModelAdapter +from ..._prediction_pipeline import IntermediatePrediction, RemotePredictionPipeline +from ..._settings import settings +from ...axis import PerAxis +from ...common import BlocksizeParameter, PerMember +from ...io import JsonValue +from ...sample import Sample, SampleBlock +from ...stat_measures import Measure, MeasureValue +from .serializer import GradioSampleSerializer + +SerializedSampleBlock = Dict[str, JsonValue] + + +class GradioModelAdapter(RemoteModelAdapter[SerializedSampleBlock]): + """Model adapter to use the bioimage-io-gradio-runner as a backend 
for model inference.""" + + def __init__( + self, model_description: AnyModelDescr, *, server: Optional[str] = None + ): + """Initialize the GradioModelAdapter. + + Note: + - This adapter requires an environment with the same gradio version as the one used on the bioimage-io-gradio-runner server. + + Args: + model_description: The model to run inference with. + server: The URL of a running bioimage-io-gradio-server instance (default server might not be availability/compatible). + """ + server = server or settings.gradio_server + if server is None: + raise ValueError( + "No gradio server specified. Please provide a server URL or set the 'BIOIMAGEIO_GRADIO_SERVER' environment variable." + ) + + self._client = Client(server, httpx_kwargs={"timeout": 60}) + self._serialized_model, self._sha256 = ( + DescriptionSerializer.serialize_to_string_and_hash(model_description) + ) + super().__init__( + model_description, server=server, sample_serializer=GradioSampleSerializer() + ) + + def _forward_impl( + self, serialized_input_sample: Iterable[SerializedSampleBlock] + ) -> Iterable[SerializedSampleBlock]: + return _call_predict_api( + self._client, + self._serialized_model, + self._sha256, + serialized_input_sample, + blocksize=None, + skip_preprocessing=True, + skip_postprocessing=True, + skip_input_padding=True, + skip_output_cropping=True, + batch_size=None, + ) + + def unload(self): + return super().unload() + + def load(self) -> None: + for model_data in ("", self._serialized_model): + try: + result = self._client.submit( + api_name="/load_model", model=model_data, sha256=self._sha256 + ).result() + except Exception as e: + if model_data: + logger.warning( + "Failed to load model on server with model_data, error was: {}", + len(model_data), + e, + ) + else: + if result: + break + + def test(self) -> Optional[ValidationSummary]: + for model_data in ("", self._serialized_model): + try: + result = self._client.submit( + api_name="/test_model", model=model_data, 
sha256=self._sha256 + ).result() + except Exception as e: + if model_data: + logger.warning( + "Failed to test model on server with model_data, error was: {}", + len(model_data), + e, + ) + else: + if result: + return ValidationSummary.model_validate_json(result) + + return None + + +class GradioPredictionPipeline(RemotePredictionPipeline): + """Prediction pipeline to use the bioimage-io-gradio-runner as a fully remote prediction pipeline.""" + + def __init__( + self, + model_description: AnyModelDescr, + *, + server: Optional[str] = None, + precomputed_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}), + default_blocksize_parameter: BlocksizeParameter = 10, + default_batch_size: int = 1, + ): + """ + Note: + - This pipeline requires an environment with the same gradio version as the one used on the bioimage-io-gradio-runner server. + + Args: + model_description: The model to run inference with. + server: The URL or Hugging Face space name of a running bioimageio gradio server instance (Note: default server might not be availabile/compatible!). + """ + server = server or settings.gradio_server + if server is None: + raise ValueError( + "No gradio server specified. Please provide a server URL or set the 'BIOIMAGEIO_GRADIO_SERVER' environment variable." 
+ ) + + super().__init__( + model_description, + server=server, + default_blocksize_parameter=default_blocksize_parameter, + default_batch_size=default_batch_size, + ) + self._client = Client(self.server, httpx_kwargs={"timeout": 60}) + self._serialized_model, self._sha256 = ( + DescriptionSerializer.serialize_to_string_and_hash(model_description) + ) + self._serializer = GradioSampleSerializer + self._precomputed_statistics = dict(precomputed_statistics) + + def predict_sample_block( + self, + sample_block: SampleBlock, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + ) -> SampleBlock: + if isinstance(self._model_descr, v0_4.ModelDescr): + raise NotImplementedError( + f"predict_sample_block not implemented for model {self._model_descr.format_version}" + ) + else: + assert self._block_transform is not None + + sample_block.stat.update(self._precomputed_statistics) + output_block = self._serializer.deserialize_sample( + _call_predict_api( + self._client, + self._serialized_model, + self._sha256, + serialized_input_sample=self._serializer.serialize_sample( + sample_block.as_sample() + ), + blocksize=None, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + skip_input_padding=True, + skip_output_cropping=True, + batch_size=self._default_batch_size, + ) + ) + output_meta = sample_block.get_transformed_meta(self._block_transform) + return output_meta.with_data(output_block.members, stat=sample_block.stat) + + def predict_sample_without_blocking( + self, + sample: Sample, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + skip_input_padding: bool = False, + skip_output_cropping: bool = False, + ) -> Sample: + sample.stat.update(self._precomputed_statistics) + return self._serializer.deserialize_sample( + _call_predict_api( + self._client, + self._serialized_model, + self._sha256, + serialized_input_sample=self._serializer.serialize_sample(sample), + blocksize=None, + 
skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + skip_input_padding=skip_input_padding, + skip_output_cropping=skip_output_cropping, + batch_size=self._default_batch_size, + ) + ) + + def predict_sample_with_fixed_blocking_yield_intermediates( + self, + sample: Sample, + input_block_shape: PerMember[PerAxis[int]], + *, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + fill_value: float = float("nan"), + ) -> Tuple[int, Iterable[IntermediatePrediction]]: + sample.stat.update(self._precomputed_statistics) + + # blocking for serialization is not really important, but we might as well block + # the same way we want the backend to block for blockwise prediction + serialized_input_sample = self._serializer.serialize_sample_with_fixed_blocking( + sample, block_shapes=input_block_shape, halo=self._default_input_halo + ) + + def _predict_blocks(): + output_sample = None + for serialized_output_block in _call_predict_api( + self._client, + self._serialized_model, + self._sha256, + serialized_input_sample=serialized_input_sample, + blocksize=input_block_shape, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + skip_input_padding=False, + skip_output_cropping=False, + batch_size=self._default_batch_size, + ): + output_block = self._serializer.deserialize_sample_block( + serialized_output_block + ) + if output_sample is None: + output_sample = Sample.from_blocks( + [output_block], fill_value=fill_value + ) + else: + output_sample.set_block(output_block) + + yield IntermediatePrediction(output_sample, output_block) + + block_iterator = _predict_blocks() + first_intermediate = next(block_iterator) + + def _intermediate_predictions() -> Iterable[IntermediatePrediction]: + yield first_intermediate + yield from block_iterator + + return ( + first_intermediate.last_block.blocks_in_sample, + _intermediate_predictions(), + ) + + +def _call_predict_api( + client: Client, + serialized_model: 
str, + sha256: str, + serialized_input_sample: Iterable[SerializedSampleBlock], + blocksize: Optional[ + Union[int, Literal["blockwise_as_serialized"], PerMember[PerAxis[int]]] + ], + skip_preprocessing: bool, + skip_postprocessing: bool, + skip_input_padding: bool, + skip_output_cropping: bool, + batch_size: Optional[int], +) -> Iterable[SerializedSampleBlock]: + def submit(model: str): + return client.submit( + api_name="/predict", + model=model, + sha256=sha256, + input_sample=serialized_input_sample, + blocksize={ + str(k): {str(kk): vv for kk, vv in v.items()} + for k, v in blocksize.items() + } + if not (blocksize is None or isinstance(blocksize, (int, str))) + else blocksize, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + skip_input_padding=skip_input_padding, + skip_output_cropping=skip_output_cropping, + batch_size=batch_size, + ) + + try_with_model_upload = True + try: + job = submit("") + for block in job: # pyright: ignore[reportUnknownVariableType] + yield block # pyright: ignore[reportReturnType] + # we got one response, so the model cache was hit... + try_with_model_upload = False + except Exception as e: + # A raised exception on the server seems to simply return an empty response sequence, + # so this except is likely not triggered at all. + # Below we retry on empty return value, too. 
+ if try_with_model_upload: + logger.warning( + "Failed to submit job without model upload, trying with model upload, error was: {}", + e, + ) + else: + raise e + + if try_with_model_upload: + job = submit(serialized_model) + for block in job: # pyright: ignore[reportUnknownVariableType] + yield block # pyright: ignore[reportReturnType] diff --git a/src/bioimageio/core/remote_backends/gradio/serializer.py b/src/bioimageio/core/remote_backends/gradio/serializer.py new file mode 100644 index 000000000..3800b3a6c --- /dev/null +++ b/src/bioimageio/core/remote_backends/gradio/serializer.py @@ -0,0 +1,72 @@ +import tempfile +from pathlib import Path +from typing import Dict, List, Mapping, Union + +import numpy as np +from gradio_client import handle_file +from pydantic import BaseModel +from typing_extensions import Self + +from ..._common_annotations import PerMemberAnno +from ..._description_serializer import DescriptionSerializer as DescriptionSerializer +from ..._sample_serializer import SampleSerializer +from ...common import MemberId +from ...io import JsonValue, load_stat, save_tensor, serialize_stat +from ...sample import SampleBlock, SampleBlockMeta +from ...tensor import Tensor + + +class _SerializableBlock(BaseModel, frozen=True): + path: Path + meta: Mapping[str, str] + orig_name: str + + @classmethod + def from_tensor(cls, tensor: Tensor) -> Self: + with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp: + save_tensor(tmp.name, tensor) + + handled = handle_file(Path(tmp.name)) + return cls.model_validate(handled) + + +class _SerializableSampleBlock(BaseModel, frozen=True): + meta: SampleBlockMeta + data: PerMemberAnno[Union[_SerializableBlock, Path]] + serialized_stat: List[JsonValue] + + +SerializedSampleBlock = Dict[str, JsonValue] + + +class GradioSampleSerializer(SampleSerializer[SerializedSampleBlock]): + @staticmethod + def serialize_sample_block(sample_block: SampleBlock) -> SerializedSampleBlock: + handled_members: Dict[MemberId, 
_SerializableBlock] = {} + for m, t in sample_block.members.items(): + handled_members[m] = _SerializableBlock.from_tensor(t) + + serializable = _SerializableSampleBlock( + data=handled_members, + meta=sample_block.get_meta(), + serialized_stat=serialize_stat(sample_block.stat), + ) + serialized = serializable.model_dump(mode="json") + return serialized + + @staticmethod + def deserialize_sample_block(serialized: SerializedSampleBlock) -> SampleBlock: + deserializable_sample = _SerializableSampleBlock.model_validate(serialized) + sample_meta = deserializable_sample.meta + members = { + k: Tensor.from_numpy( + np.load(v if isinstance(v, Path) else v.path), + dims=list(sample_meta.shape[k]), + ) + for k, v in deserializable_sample.data.items() + } + return SampleBlock.from_meta( + sample_meta, + data=members, + stat=load_stat(deserializable_sample.serialized_stat), + ) diff --git a/src/bioimageio/core/remote_backends/gradio/server.py b/src/bioimageio/core/remote_backends/gradio/server.py new file mode 100644 index 000000000..cf1cc2361 --- /dev/null +++ b/src/bioimageio/core/remote_backends/gradio/server.py @@ -0,0 +1,249 @@ +from itertools import chain +from typing import ( + Any, + Dict, + Iterable, + Literal, + Optional, + Union, +) + +import gradio as gr +from loguru import logger + +import bioimageio.core +from bioimageio.core import AxisId, Stat +from bioimageio.core.axis import PerAxis +from bioimageio.core.backends import create_model_adapter +from bioimageio.core.common import PerMember +from bioimageio.core.remote_backends.gradio.serializer import ( + DescriptionSerializer, + GradioSampleSerializer, + SerializedSampleBlock, +) +from bioimageio.spec import load_model_description +from bioimageio.spec.common import Sha256 +from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 + +try: + import spaces # pyright: ignore +except ImportError: + logger.warning("Failed to import 'spaces' package") + + class spaces: + @staticmethod + def GPU(func: Any): + return 
func + + +logger.enable("bioimageio") + +app = gr.Server() + + +@app.api(name="predict") # pyright: ignore[reportUntypedFunctionDecorator] +@spaces.GPU +def predict( + model: str, + sha256: str, + input_sample: Iterable[SerializedSampleBlock], + blocksize: Optional[ + Union[int, Literal["blockwise_as_serialized"], PerMember[PerAxis[int]]] + ] = None, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + skip_input_padding: bool = False, + skip_output_cropping: bool = False, + batch_size: Optional[int] = None, +) -> Iterable[SerializedSampleBlock]: + """Run prediction on a sample + + Args: + input_sample: Input sample as a sequence of serialized sample blocks. + Use bioimageio.core.backends.gradio_backend.GradioModelAdapter.serialize_sample to create this from a Sample object. + model: A model source: URL, nickname or base64 encoded model package (if len(model) > 2083). + sha256: Sha256 hash of the model's bioimageio.yaml file at the model source or of the encoded model package. + blocksize: + - None (default): run non-blockwise, full-sample prediction. + - integer: run blockwise prediction with a block size derived from the model and this blocksize parameter. + - "blockwise_as_serialized": run blockwise prediction with the same blocking as the serialized input sample. + (Non-blockwise pre- and postprocessing steps will be ignored.) + - PerMember[PerAxis[int]]: run blockwise prediction with a fixed block shape given for each sample member. + skip_preprocessing: If True, skip preprocessing steps defined in the model. + skip_postprocessing: If True, skip postprocessing steps defined in the model. + skip_input_padding: If True, skip input padding for non-blockwise prediction. + Set this flag when predicting an (overlapping) sample block rather than a full sample. + skip_output_cropping: If True, skip output cropping for non-blockwise prediction. + Set this flag when predicting an (overlapping) sample block rather than a full sample. 
+ batch_size: Optional batch size only applicable to predicting input samples with batch dimension. + """ + + def setup(stat: Stat): + model_adapter = _get_model_adapter(model, sha256=sha256) + return bioimageio.core.create_prediction_pipeline( + model_adapter.model_descr, fixed_dataset_statistics=stat + ) + + if blocksize == "blockwise_as_serialized": + sample_block_iterator = iter(input_sample) + deserialized_input_block = GradioSampleSerializer.deserialize_sample_block( + next(sample_block_iterator) + ) + pp = setup(deserialized_input_block.stat) + for block in chain( + [deserialized_input_block], + ( + GradioSampleSerializer.deserialize_sample_block(b) + for b in sample_block_iterator + ), + ): + output_block = pp.predict_sample_block( + block, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ) + yield GradioSampleSerializer.serialize_sample_block(output_block) + else: + deserialized_input_sample = GradioSampleSerializer.deserialize_sample( + input_sample + ) + pp = setup(deserialized_input_sample.stat) + + output_sample = None + if isinstance(blocksize, int): + try: + if pp.has_non_blockwise_postprocessing and not skip_postprocessing: + output_sample = pp.predict_sample_with_blocking( + deserialized_input_sample, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ns=blocksize, + batch_size=batch_size, + ) + else: + for output in pp.predict_sample_with_blocking_yield_intermediates( + deserialized_input_sample, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ns=blocksize, + batch_size=batch_size, + )[1]: + # with purely blockwise postprocesssing or with postprocessing skipped, + # predicted blocks are part of the final result, so we yield them immediately. 
+ yield GradioSampleSerializer.serialize_sample_block( + output.last_block + ) + + return + + except Exception as e: + logger.warning( + "Falling back to full-sample prediction for model {}: {}", + pp.model_descr.id or pp.model_descr.name, + e, + ) + if output_sample is None: + output_sample = pp.predict_sample_without_blocking( + deserialized_input_sample, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + skip_input_padding=skip_input_padding, + skip_output_cropping=skip_output_cropping, + ) + + if all( + axes.get(AxisId("batch"), 1) > 1 for axes in output_sample.shape.values() + ): + # yield batches + yield from GradioSampleSerializer.serialize_sample_with_fixed_blocking( + output_sample, + block_shapes={ + m: {AxisId("batch"): batch_size or 1} for m in output_sample.shape + }, + halo={}, + ) + else: + yield from GradioSampleSerializer.serialize_sample(output_sample) + + +@app.api(name="load_model") # pyright: ignore[reportUntypedFunctionDecorator] +def load_model( + model: str, + sha256: str, +) -> dict[Literal["message"], str]: + """Load a model into the server's model cache. This can be used to pre-load a model before running predictions to avoid the overhead of loading the model during the first prediction request.""" + _ = _get_model_adapter(model, sha256=sha256) + return {"message": "Model loaded successfully"} + + +@app.api(name="test_model") # pyright: ignore[reportUntypedFunctionDecorator] +def test_model( + model: str, + sha256: str, +) -> str: + """Run the bioimageio model test and return the validation summary. 
Returns None if testing failed.""" + model_adapter = _get_model_adapter(model, sha256=sha256) + summary = bioimageio.core.test_model(model_adapter.model_descr) + return summary.model_dump_json() + + +def _cache_key(kwargs: Dict[str, Any]) -> str: + return kwargs["sha256"] + + +@gr.cache( # pyright: ignore[reportUntypedFunctionDecorator] + key=_cache_key, + max_size=bioimageio.core.settings.gradio_server_model_cache_max_size, + max_memory=bioimageio.core.settings.gradio_server_model_cache_max_memory, + per_session=False, +) +def _get_model_adapter( + model: str, + *, + sha256: str, +): + """Get a model adapter for the given model + + Args: + model: A model source: URL (len(model) <= 2083)) or model base64 encoded package bytes (len(model) > 2083). + sha256: Sha256 hash of the model source at model URL or of the encoded model package bytes. + """ + if not model: + raise ValueError("Model source cannot be empty") + + model_descr = _get_model(model, sha256=sha256) + return create_model_adapter(model_description=model_descr) + + +def _get_model( + model: str, + *, + sha256: str, +) -> AnyModelDescr: + if len(model) > 2083: + ret = DescriptionSerializer.deserialize_from_string(model) + if not isinstance(ret, (v0_4.ModelDescr, v0_5.ModelDescr)): + raise ValueError( + f"Deserialized model description is not a valid model description: got {ret.type} {ret.format_version}" + ) + return ret + else: + return load_model_description(model, sha256=Sha256(sha256) if sha256 else None) + + +@app.get("/") +def root(): + return { + "message": f"Running bioimageio.core {bioimageio.core.__version__} gradio server." 
+ } + + +def main(port: Optional[int] = None) -> str: + _app, local_url, _share_url = app.launch( + mcp_server=True, show_error=True, server_port=port + ) + return local_url + + +if __name__ == "__main__": + _ = main() diff --git a/src/bioimageio/core/sample.py b/src/bioimageio/core/sample.py index 8d35c7b2a..8f0b771f9 100644 --- a/src/bioimageio/core/sample.py +++ b/src/bioimageio/core/sample.py @@ -3,6 +3,7 @@ import collections.abc from dataclasses import dataclass from math import ceil, floor +from types import MappingProxyType from typing import ( Any, Callable, @@ -17,10 +18,12 @@ ) import numpy as np +import pydantic import xarray as xr from numpy.typing import NDArray from typing_extensions import Self +from ._common_annotations import PerMemberAnno from .axis import AxisId, PerAxis from .block import Block from .block_meta import ( @@ -84,6 +87,29 @@ def __getitem__( id=self.id, ) + def set_block(self, block: SampleBlock) -> None: + """Set values of `block`. + + Note: + - Updates only existing sample members (extra block members are ignored) + - Ignores missing block members (i.e. members in the sample but not in the block are not modified) + + Raises: + ValueError if block and sample members do not overlap at all. + """ + no_overlap = True + for m in self.members: + if m not in block.blocks: + continue + b = block.blocks[m] + self.members[m][b.inner_slice] = b.inner_data + no_overlap = False + + if no_overlap: + raise ValueError( + f"block with members {list(block.blocks)} does not overlap with sample members {list(self.members)}" + ) + @property def shape(self) -> PerMember[PerAxis[int]]: return {tid: t.sizes for tid, t in self.members.items()} @@ -146,21 +172,58 @@ def from_blocks( *, fill_value: float = float("nan"), ) -> Self: - members: PerMember[Tensor] = {} - stat: Stat = {} - sample_id = None + """Create a `Sample` from an iterable of `SampleBlock`s. + + Note: + All sample blocks must have the same `sample_id`. 
+ + Args: + sample_blocks: The blocks to create the sample from. + fill_value: The value to fill missing values with (default: `nan`). + """ + output = None + for output in cls.from_blocks_yield_intermediates( + sample_blocks, fill_value=fill_value + ): + pass + + if output is None: + raise ValueError("no sample blocks provided") + + return output + + @classmethod + def from_blocks_yield_intermediates( + cls, + sample_blocks: Iterable[SampleBlock], + *, + fill_value: float = float("nan"), + ): + """Create a `Sample` from an iterable of `SampleBlock`s, yielding the intermediate sample after each block. + + Args: + sample_blocks: The blocks to create the sample from. + fill_value: The value to fill missing values with (default: `nan`). + """ + output = cls(members={}, stat={}, id=None) for sample_block in sample_blocks: - assert sample_id is None or sample_id == sample_block.sample_id - sample_id = sample_block.sample_id - stat = sample_block.stat + if output.id is None: + output.id = sample_block.sample_id + else: + assert output.id == sample_block.sample_id, ( + "sample id changed between sample blocks" + ) + + output.stat = sample_block.stat + for m, block in sample_block.blocks.items(): - if m not in members: + if m not in output.members: if -1 in block.sample_shape.values(): raise NotImplementedError( "merging blocks with data dependent axis not yet implemented" ) - members[m] = Tensor( + output.members[m] = Tensor( np.full( tuple(block.sample_shape[a] for a in block.data.dims), fill_value, @@ -169,9 +232,10 @@ def from_blocks( dims=block.data.dims, ) - members[m][block.inner_slice] = block.inner_data + output.members[m][block.inner_slice] = block.inner_data + yield output - return cls(members=members, stat=stat, id=sample_id) + yield output def pad( self, @@ -199,20 +263,20 @@ def pad( ) -BlockT = TypeVar("BlockT", Block, BlockMeta) +BlockT = TypeVar("BlockT", bound=BlockMeta) -@dataclass +@pydantic.dataclasses.dataclass(frozen=True) class 
SampleBlockBase(Generic[BlockT]): """base class for `SampleBlockMeta` and `SampleBlock`""" - sample_shape: PerMember[PerAxis[int]] + sample_shape: PerMemberAnno[PerAxis[int]] """the sample shape this block represents a part of""" sample_id: SampleId """identifier for the sample within its dataset""" - blocks: Dict[MemberId, BlockT] + blocks: PerMemberAnno[BlockT] """Individual tensor blocks comprising this sample block""" block_index: BlockIndex @@ -223,11 +287,11 @@ class SampleBlockBase(Generic[BlockT]): @property def shape(self) -> PerMember[PerAxis[int]]: - return {mid: b.shape for mid, b in self.blocks.items()} + return MappingProxyType({mid: b.shape for mid, b in self.blocks.items()}) @property def inner_shape(self) -> PerMember[PerAxis[int]]: - return {mid: b.inner_shape for mid, b in self.blocks.items()} + return MappingProxyType({mid: b.inner_shape for mid, b in self.blocks.items()}) @dataclass @@ -235,7 +299,7 @@ class LinearSampleAxisTransform(LinearAxisTransform): member: MemberId -@dataclass +@pydantic.dataclasses.dataclass(frozen=True) class SampleBlockMeta(SampleBlockBase[BlockMeta]): """Meta data of a dataset sample block""" @@ -329,10 +393,13 @@ def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock: ) -@dataclass +@dataclass(frozen=True) class SampleBlock(SampleBlockBase[Block]): """A block of a dataset sample""" + blocks: Dict[MemberId, Block] + """Individual tensor blocks comprising this sample block""" + stat: Stat """computed statistics""" @@ -352,8 +419,45 @@ def get_transformed_meta( blocks_in_sample=self.blocks_in_sample, ).get_transformed(new_axes) + @classmethod + def from_meta( + cls, meta: SampleBlockMeta, data: PerMember[Tensor], stat: Stat + ) -> Self: + return cls( + sample_shape=meta.sample_shape, + sample_id=meta.sample_id, + blocks={ + m: Block.from_meta(b, data=data[m]) for m, b in meta.blocks.items() + }, + stat=stat, + block_index=meta.block_index, + blocks_in_sample=meta.blocks_in_sample, + ) -@dataclass + 
def get_meta(self) -> SampleBlockMeta: + return SampleBlockMeta( + sample_id=self.sample_id, + blocks={m: b.get_meta() for m, b in self.blocks.items()}, + sample_shape=self.sample_shape, + block_index=self.block_index, + blocks_in_sample=self.blocks_in_sample, + ) + + def as_sample(self) -> Sample: + """Convert this sample block to a `Sample` with the shape of this block. + + Note: + If you want to convert one or more sample block to a sample with the shape of the original, whole sample, + use `Sample.from_blocks()` instead. + """ + return Sample( + members=dict(self.members), + stat=dict(self.stat), + id=self.sample_id, + ) + + +@dataclass(frozen=True) class SampleBlockWithOrigin(SampleBlock): """A `SampleBlock` with a reference (`origin`) to the whole `Sample`""" diff --git a/src/bioimageio/core/tensor.py b/src/bioimageio/core/tensor.py index f89e2eb70..63d545b51 100644 --- a/src/bioimageio/core/tensor.py +++ b/src/bioimageio/core/tensor.py @@ -511,8 +511,8 @@ def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): ndim = array.ndim if ndim == 2: current_axes = ( - v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]), - v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("y"), size=array.shape[0]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("x"), size=array.shape[1]), ) elif ndim == 3 and any(s <= 3 for s in array.shape): current_axes = ( @@ -521,14 +521,14 @@ def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) ] ), - v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), - v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("y"), size=array.shape[1]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("x"), size=array.shape[2]), ) elif ndim == 3: current_axes = ( - v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]), - v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), - 
v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("z"), size=array.shape[0]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("y"), size=array.shape[1]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("x"), size=array.shape[2]), ) elif ndim == 4: current_axes = ( @@ -537,9 +537,9 @@ def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) ] ), - v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]), - v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]), - v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("z"), size=array.shape[1]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("y"), size=array.shape[2]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("x"), size=array.shape[3]), ) elif ndim == 5: current_axes = ( @@ -549,9 +549,9 @@ def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): v0_5.Identifier(f"channel{i}") for i in range(array.shape[1]) ] ), - v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]), - v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]), - v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("z"), size=array.shape[2]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("y"), size=array.shape[3]), + v0_5.SpaceInputAxis(id=v0_5.AxisId("x"), size=array.shape[4]), ) else: raise ValueError(f"Could not guess an axis mapping for {array.shape}") diff --git a/src/bioimageio/core/utils/_type_guards.py b/src/bioimageio/core/utils/_type_guards.py index 0a33b8084..ad61c6911 100644 --- a/src/bioimageio/core/utils/_type_guards.py +++ b/src/bioimageio/core/utils/_type_guards.py @@ -6,3 +6,5 @@ is_list = type_guards.is_list is_ndarray = type_guards.is_ndarray is_tuple = type_guards.is_tuple +is_dict = type_guards.is_dict +is_kwargs = type_guards.is_kwargs diff --git a/src/bioimageio/core/weight_converters/torchscript_to_onnx.py b/src/bioimageio/core/weight_converters/torchscript_to_onnx.py 
index 26a479dd1..3da2cdbeb 100644 --- a/src/bioimageio/core/weight_converters/torchscript_to_onnx.py +++ b/src/bioimageio/core/weight_converters/torchscript_to_onnx.py @@ -2,6 +2,7 @@ from typing import Optional, Sequence, Union import torch.jit +from loguru import logger from torch._export.converter import TS2EPConverter from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr @@ -57,10 +58,16 @@ def convert( model, # pyright: ignore[reportUnknownArgumentType] torch_sample_inputs, ).convert() + exported_module = exported_program.module() + + try: + exported_module = exported_module.eval() + except Exception as e: + logger.warning("Failed to set TS2EPConverter program to evaluation mode: {}", e) return export_to_onnx( model_descr, - exported_program.module(), + exported_module, output_path, verbose, opset_version, diff --git a/tests/conftest.py b/tests/conftest.py index fe55efd06..7675c93e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import subprocess from itertools import chain +from pathlib import Path from typing import Dict, List from dotenv import load_dotenv @@ -50,7 +51,14 @@ logger.warning("testing with bioimageio.spec {}", bioimageio_spec_version) -EXAMPLE_DESCRIPTIONS = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/" +EXAMPLE_DESCRIPTIONS = ( + LOCAL_EXAMPLE_DESCRIPTIONS.as_posix() + "/" + if ( + LOCAL_EXAMPLE_DESCRIPTIONS := Path(__file__).parent + / "../../spec-bioimage-io/example_descriptions/" + ).exists() + else "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/" +) # TODO: use models from new collection on S3 MODEL_SOURCES: Dict[str, str] = { @@ -116,7 +124,7 @@ "unet2d_nuclei_broad_model", ] ) -ONNX_MODELS = [] if onnxruntime is None else ["hpa_densenet"] +ONNX_MODELS = [] if onnxruntime is None else ["unet2d_nuclei_broad_model"] TENSORFLOW_MODELS = ( [] if tensorflow is None @@ -138,9 +146,8 @@ TENSORFLOW_JS_MODELS: List[str] = 
[] # TODO: add a tensorflow_js example model ALL_MODELS = sorted( - { - m - for m in chain( + set( + chain( TORCH_MODELS, TORCHSCRIPT_MODELS, ONNX_MODELS, @@ -148,7 +155,7 @@ KERAS_MODELS, TENSORFLOW_JS_MODELS, ) - } + ) ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 7c7545ff3..ddfd92b23 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,7 +6,7 @@ import pytest from pydantic import FilePath -from bioimageio.spec import load_description, settings +from bioimageio.spec import settings def run_subprocess( @@ -72,13 +72,13 @@ def test_cli( def test_empty_cache(tmp_path: Path, unet2d_nuclei_broad_model: str): - from bioimageio.spec.utils import empty_cache + from bioimageio.spec.utils import empty_cache, get_reader origingal_cache_path = settings.cache_path try: settings.cache_path = tmp_path / "cache" assert not settings.cache_path.exists() - _ = load_description(unet2d_nuclei_broad_model, perform_io_checks=False) + _ = get_reader("https://example.com") assert ( len([fn for fn in settings.cache_path.iterdir() if fn.suffix != ".lock"]) == 1 diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py index 8a595cf8d..fd7077abb 100644 --- a/tests/test_prediction_pipeline.py +++ b/tests/test_prediction_pipeline.py @@ -1,13 +1,24 @@ +from concurrent.futures import ThreadPoolExecutor +from functools import partial from pathlib import Path -from numpy.testing import assert_array_almost_equal - +from bioimageio.core import Sample +from bioimageio.core._resource_tests import evaluate_mismatched_elements from bioimageio.core.common import SupportedWeightsFormat from bioimageio.spec import load_description from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 from bioimageio.spec.model.v0_5 import ModelDescr +def _alter_sample(sample: Sample, offset: float) -> Sample: + # add 1 to all values to get a different sample with the same shape and axes + return Sample( + id=f"{sample.id}_altered", + members={m: t + offset for m, t 
in sample.members.items()}, + stat=sample.stat, + ) + + def _test_prediction_pipeline( model_package: Path, weights_format: SupportedWeightsFormat ): @@ -21,22 +32,42 @@ def _test_prediction_pipeline( assert isinstance(bio_model, (ModelDescr, ModelDescr04)), ( bio_model.validation_summary.format() ) + pp = create_prediction_pipeline( - bioimageio_model=bio_model, weight_format=weights_format + bioimageio_model=bio_model, weight_format=weights_format, devices=["cpu", "cpu"] ) inputs = get_test_input_sample(bio_model) - outputs = pp.predict_sample_without_blocking( - inputs, skip_input_padding=True, skip_output_cropping=True - ) + + # test in a multi-threaded setting + multiple_inputs = [inputs, _alter_sample(inputs, offset=100.0)] + with ThreadPoolExecutor(max_workers=3) as executor: + multiple_outputs = list( + executor.map( + partial( + pp.predict_sample_without_blocking, + skip_input_padding=True, + skip_output_cropping=True, + ), + multiple_inputs, + ) + ) + + outputs = multiple_outputs[0] expected_outputs = get_test_output_sample(bio_model) assert len(outputs.shape) == len(expected_outputs.shape) for m in expected_outputs.members: - out = outputs.members[m].data + out = outputs.members[m] assert out is not None - exp = expected_outputs.members[m].data - assert_array_almost_equal(out, exp, decimal=4) + exp = expected_outputs.members[m] + mismatched_ppm, msg, error_msg = evaluate_mismatched_elements( + out, exp, rtol=0.01, atol=0.1, name=m + ) + if error_msg is not None: + raise AssertionError(error_msg) + elif mismatched_ppm > 50_000: + raise AssertionError(msg) def test_prediction_pipeline_torch(any_torch_model: Path): diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py index 7c06bcb46..ad84e97b4 100644 --- a/tests/test_prediction_pipeline_device_management.py +++ b/tests/test_prediction_pipeline_device_management.py @@ -8,7 +8,9 @@ from bioimageio.spec.model.v0_5 import ModelDescr -def 
_test_device_management(model_package: Path, weight_format: SupportedWeightsFormat): +def _test_device_management( + model_package: Path, weight_format: SupportedWeightsFormat, device: str +): import torch from bioimageio.core import load_description @@ -18,18 +20,22 @@ def _test_device_management(model_package: Path, weight_format: SupportedWeights get_test_output_sample, ) - if not hasattr(torch, "cuda") or torch.cuda.device_count() == 0: + if device == "cuda" and ( + not hasattr(torch, "cuda") or torch.cuda.device_count() == 0 + ): pytest.skip("Need at least one cuda device for this test") bio_model = load_description(model_package) assert isinstance(bio_model, (ModelDescr, ModelDescr04)) pred_pipe = create_prediction_pipeline( - bioimageio_model=bio_model, weight_format=weight_format, devices=["cuda:0"] + bioimageio_model=bio_model, weight_format=weight_format, devices=[device] ) inputs = get_test_input_sample(bio_model) with pred_pipe as pp: - outputs = pp.predict_sample_without_blocking(inputs) + outputs = pp.predict_sample_without_blocking( + inputs, skip_input_padding=True, skip_output_cropping=True + ) expected_outputs = get_test_output_sample(bio_model) @@ -38,35 +44,44 @@ def _test_device_management(model_package: Path, weight_format: SupportedWeights out = outputs.members[m].data assert out is not None exp = expected_outputs.members[m].data - assert_array_almost_equal(out, exp, decimal=4) + assert_array_almost_equal(out, exp, decimal=2) # repeat inference with context manager to test load/predict/unload/load/predict with pred_pipe as pp: - outputs = pp.predict_sample_without_blocking(inputs) + outputs = pp.predict_sample_without_blocking( + inputs, skip_input_padding=True, skip_output_cropping=True + ) assert len(outputs.shape) == len(expected_outputs.shape) for m in expected_outputs.members: out = outputs.members[m].data assert out is not None exp = expected_outputs.members[m].data - assert_array_almost_equal(out, exp, decimal=4) + 
assert_array_almost_equal(out, exp, decimal=2) -def test_device_management_torch(any_torch_model: Path): - _test_device_management(any_torch_model, "pytorch_state_dict") +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_device_management_torch(any_torch_model: Path, device: str): + _test_device_management(any_torch_model, "pytorch_state_dict", device=device) -def test_device_management_torchscript(any_torchscript_model: Path): - _test_device_management(any_torchscript_model, "torchscript") +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_device_management_torchscript(any_torchscript_model: Path, device: str): + _test_device_management(any_torchscript_model, "torchscript", device=device) -def test_device_management_onnx(any_onnx_model: Path): - _test_device_management(any_onnx_model, "onnx") +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_device_management_onnx(any_onnx_model: Path, device: str): + _test_device_management(any_onnx_model, "onnx", device=device) -def test_device_management_tensorflow(any_tensorflow_model: Path): - _test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle") +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_device_management_tensorflow(any_tensorflow_model: Path, device: str): + _test_device_management( + any_tensorflow_model, "tensorflow_saved_model_bundle", device=device + ) -def test_device_management_keras(any_keras_model: Path): - _test_device_management(any_keras_model, "keras_hdf5") +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_device_management_keras(any_keras_model: Path, device: str): + _test_device_management(any_keras_model, "keras_hdf5", device=device) diff --git a/tests/test_remote_backends/test_gradio.py b/tests/test_remote_backends/test_gradio.py new file mode 100644 index 000000000..48fee4d4c --- /dev/null +++ b/tests/test_remote_backends/test_gradio.py @@ -0,0 +1,77 @@ +import socket +import sys +import time +from 
concurrent.futures import ThreadPoolExecutor +from multiprocessing import Process +from typing import List, Tuple + +import pytest +from loguru import logger + +from bioimageio.core import Tensor +from bioimageio.core.common import PerMember + + +def _is_port_in_use(port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python 3.10 or higher") +def test_gradio_backend(): + from bioimageio.core import load_model + from bioimageio.core.digest_spec import get_test_input_sample + from bioimageio.core.remote_backends.gradio.client import GradioModelAdapter + from bioimageio.core.remote_backends.gradio.server import main as gradio_server_main + + port = 7860 + if _is_port_in_use(port): + sock = socket.socket() + sock.bind(("", 0)) + _host, port = sock.getsockname() + + server_process = Process(target=gradio_server_main, kwargs={"port": port}) + server_process.start() + + try: + deadline = time.monotonic() + 30 + while time.monotonic() < deadline: + try: + with socket.create_connection(("localhost", port), timeout=1): + break + except OSError: + pass + time.sleep(0.2) + else: + raise TimeoutError(f"gradio server did not become ready on port {port}") + + server_url = f"http://localhost:{port}/" + prepared: List[Tuple[str, GradioModelAdapter, PerMember[Tensor]]] = [] + for model_id in ("affable-shark", "ambitious-sloth"): + model = load_model( + model_id, format_version="latest", perform_io_checks=False + ) + sample = get_test_input_sample(model) + + logger.debug("connecting adapter to {} for {}", server_url, model_id) + adapter = GradioModelAdapter(model, server=server_url) + prepared.append((model_id, adapter, sample.members)) + + # Exercise concurrent requests pooled across both loaded models. 
+ with ThreadPoolExecutor(max_workers=4) as executor: + future_to_model_id = { + executor.submit(adapter.forward, sample_members): model_id + for model_id, adapter, sample_members in prepared + for _ in range(2) + } + + for future, model_id in future_to_model_id.items(): + assert future.result() is not None, model_id + + for model_id, adapter, _sample_members in prepared: + summary = adapter.test() + assert summary is not None + assert summary.status == "passed", f"{model_id}: {summary.display()}" + finally: + server_process.terminate() + server_process.join() diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py index 5b09176b7..892fbe77b 100644 --- a/tests/test_resource_tests.py +++ b/tests/test_resource_tests.py @@ -1,7 +1,9 @@ from pathlib import Path import numpy as np +import pytest +from bioimageio.core import __version__ from bioimageio.spec import InvalidDescr, ValidationContext @@ -48,6 +50,10 @@ def test_loading_description_multiple_times(unet2d_nuclei_broad_model: str): assert not isinstance(model_descr, InvalidDescr) +@pytest.mark.skipif( + __version__ == "0.11.0", + reason="Previously released bioimageio.core 0.10.4 is incompatible with updated unet2d nuclei broad model description 0.5.11", +) def test_test_description_runtime_env(unet2d_nuclei_broad_model: str): from bioimageio.core._resource_tests import test_description diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index e4cd9d926..5437dd53f 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -19,9 +19,11 @@ @pytest.mark.parametrize( "name,axes", - product( - ["mean", "var", "std"], - [None, (AxisId("c"),), (AxisId("x"), AxisId("y"))], + list( + product( + ["mean", "var", "std"], + [None, (AxisId("c"),), (AxisId("x"), AxisId("y"))], + ) ), ) def test_individual_normal_measure( diff --git a/tests/test_weight_converters.py b/tests/test_weight_converters.py index fa7cabdd3..42647e98c 100644 --- a/tests/test_weight_converters.py +++ 
b/tests/test_weight_converters.py @@ -18,7 +18,7 @@ def test_pytorch_to_torchscript(any_torch_model, tmp_path): pytest.skip("cannot convert to old 0.4 format") out_path = tmp_path / "weights.pt" - ret_val = convert(model_descr, out_path) + ret_val = convert(model_descr, out_path, devices=["cpu"]) assert out_path.exists() assert isinstance(ret_val, v0_5.TorchscriptWeightsDescr) assert ret_val.source == out_path