diff --git a/.gitignore b/.gitignore
index fed8b4316..97361fc06 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@
.env
.idea/
.tox/
+Unet 2D Nuclei Broad_example/
__pycache__/
bioimageio_cache/
bioimageio_unzipped_tf_weights/
@@ -15,7 +16,8 @@ coverage.xml
dist/
docs/api/
dogfood/
+example/**/output/
+output.zip
pkgs/
site/
typings/pooch/
-example/**/output/
diff --git a/changelog.md b/changelog.md
index 131bc461d..bec314f0a 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,3 +1,10 @@
+### 0.11.0
+
+- bump spec to 0.5.11.0
+- support ONNX provider choice (as 'devices')
+- added experimental bioimageio.core.backends.gradio_backend
+- improved prediction pipeline and model adapter interfaces
+
### 0.10.4
- fix postprocessing order
diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml
index 2552cf93d..89d98362b 100644
--- a/conda-recipe/meta.yaml
+++ b/conda-recipe/meta.yaml
@@ -55,7 +55,7 @@ test:
requires:
{% for dep in pyproject['project']['optional-dependencies']['dev'] %}
{% if 'torch' not in dep %} # can't install pytorch>=2.8 from conda-forge smh
- - {{ dep.replace(';python_version<"3.10"', '').lower().replace('_', '-') }}
+ - {{ dep.replace('[mcp]', '').replace(';python_version>="3.10"', '').replace(';python_version<"3.10"', '').lower().replace('_', '-') }}
{% endif %}
{% endfor %}
commands:
diff --git a/conftest.py b/conftest.py
new file mode 100644
index 000000000..d58a89b2f
--- /dev/null
+++ b/conftest.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+from typing import Any
+
+
+def pytest_ignore_collect(collection_path: Path, config: Any) -> bool:
+ if sys.version_info >= (3, 10):
+ return False
+
+ path = str(collection_path).replace("\\", "/")
+ return "/src/bioimageio/core/remote_backends/gradio/" in path
diff --git a/example/dataset_statistics_demo.ipynb b/example/dataset_statistics_demo.ipynb
index 705fe2957..7e9aff6b8 100644
--- a/example/dataset_statistics_demo.ipynb
+++ b/example/dataset_statistics_demo.ipynb
@@ -29,7 +29,6 @@
"\n",
"from pprint import pprint\n",
"\n",
- "import imageio\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import xarray as xr\n",
@@ -39,7 +38,6 @@
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"import bioimageio.core\n",
- "from bioimageio.core.prediction import predict_with_tiling\n",
"from bioimageio.core.prediction_pipeline import create_prediction_pipeline"
]
},
@@ -196,7 +194,7 @@
"source": [
"def process_dataset(pp, dataset):\n",
" stats = pp._ipt_stats.compute_measures()[\"per_dataset\"]\n",
- " print(f\"initial stats:\")\n",
+ " print(\"initial stats:\")\n",
" pprint(\n",
" None\n",
" if not stats\n",
@@ -205,7 +203,7 @@
" stats = {}\n",
" sample_dataset = [{\"input0\": s} for s in dataset]\n",
" [pp.apply_preprocessing(s, stats) for s in sample_dataset]\n",
- " print(f\"final stats:\")\n",
+ " print(\"final stats:\")\n",
" pprint(\n",
" None\n",
" if not stats\n",
diff --git a/example/export_cellpose_model/cellpose_original.py b/example/export_cellpose_model/cellpose_original.py
index cd431446b..abe8f73b9 100644
--- a/example/export_cellpose_model/cellpose_original.py
+++ b/example/export_cellpose_model/cellpose_original.py
@@ -1,4 +1,6 @@
-"""Run original cellpose model and save an analog input and output for bioimageio tests"""
+"""Run original cellpose model and save an analog input and output for bioimageio tests
+Works on cellpose 4.1, but not since cellpose 4.2
+"""
import os
from pathlib import Path
diff --git a/example/model_usage.ipynb b/example/model_usage.ipynb
index 230ccb035..2e3db5d1b 100644
--- a/example/model_usage.ipynb
+++ b/example/model_usage.ipynb
@@ -495,7 +495,6 @@
],
"source": [
"from bioimageio.core.digest_spec import create_sample_for_model\n",
- "from bioimageio.spec.utils import download\n",
"\n",
"input_paths = {ipt.id: ipt.test_tensor.source for ipt in model.inputs}\n",
"print(f\"input paths: {input_paths}\")\n",
diff --git a/mkdocs.yaml b/mkdocs.yaml
index 8ebd9a4eb..8ad8b12c9 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -87,7 +87,7 @@ plugins:
python:
inventories:
- https://docs.pydantic.dev/latest/objects.inv
- - https://bioimage-io.github.io/spec-bioimage-io/v0.5.10.2/objects.inv
+ - https://bioimage-io.github.io/spec-bioimage-io/v0.5.11.0/objects.inv
options:
annotations_path: source
backlinks: tree
diff --git a/pyproject.toml b/pyproject.toml
index 6ad14f53d..c78938134 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ requires-python = ">=3.9"
readme = "README.md"
dynamic = ["version"]
dependencies = [
- "bioimageio.spec ==0.5.10.2",
+ "bioimageio.spec ==0.5.11.0",
"imagecodecs",
"imageio>=2.10",
"loguru",
@@ -52,9 +52,13 @@ partners = [
# "stardist", # for model testing and stardist postprocessing # TODO: add updated stardist to partners env
]
stardist = ["stardist"] # for stardist postprocessing
+gradio-server = ['gradio[mcp];python_version>="3.10"']
+gradio-client = ['gradio_client;python_version>="3.10"']
dev = [
- "cellpose", # for model testing
+ "cellpose<4.2", # for model testing
"crick",
+ 'gradio[mcp];python_version>="3.10"',
+ 'gradio_client;python_version>="3.10"',
"httpx",
"jax",
"jupyter",
@@ -68,7 +72,7 @@ dev = [
'onnx_ir!=0.1.14;python_version<"3.10"', # uses typing.Concatentate which requires py>=3.10
"packaging>=17.0",
"pre-commit",
- "pyright==1.1.408",
+ "pyright==1.1.410",
"pytest-cov",
"pytest",
"python-dotenv",
diff --git a/src/bioimageio/core/__init__.py b/src/bioimageio/core/__init__.py
index ab35c4e2f..85b84a23b 100644
--- a/src/bioimageio/core/__init__.py
+++ b/src/bioimageio/core/__init__.py
@@ -16,12 +16,27 @@
"""
# ruff: noqa: E402
-__version__ = "0.10.4"
+__version__ = "0.11.0"
from loguru import logger
logger.disable("bioimageio.core")
-import bioimageio.spec
+
+from bioimageio.spec import ValidationSummary as ValidationSummary
+from bioimageio.spec import build_description as build_description
+from bioimageio.spec import dump_description as dump_description
+from bioimageio.spec import load_dataset_description as load_dataset_description
+from bioimageio.spec import load_description as load_description
+from bioimageio.spec import (
+ load_description_and_validate_format_only as load_description_and_validate_format_only,
+)
+from bioimageio.spec import load_model_description as load_model_description
+from bioimageio.spec import save_bioimageio_package as save_bioimageio_package
+from bioimageio.spec import (
+ save_bioimageio_package_as_folder as save_bioimageio_package_as_folder,
+)
+from bioimageio.spec import save_bioimageio_yaml_only as save_bioimageio_yaml_only
+from bioimageio.spec import validate_format as validate_format
from . import axis as axis
from . import backends as backends
@@ -31,7 +46,6 @@
from . import common as common
from . import digest_spec as digest_spec
from . import io as io
-from . import model_adapters as model_adapters
from . import prediction as prediction
from . import proc_ops as proc_ops
from . import proc_setup as proc_setup
@@ -40,46 +54,40 @@
from . import stat_measures as stat_measures
from . import tensor as tensor
from . import weight_converters as weight_converters
+from ._prediction_pipeline import IntermediatePrediction as IntermediatePrediction
from ._prediction_pipeline import PredictionPipeline as PredictionPipeline
+from ._prediction_pipeline import RemotePredictionPipeline as RemotePredictionPipeline
from ._prediction_pipeline import (
create_prediction_pipeline as create_prediction_pipeline,
)
+from ._prediction_pipeline import (
+ create_remote_prediction_pipeline as create_remote_prediction_pipeline,
+)
from ._resource_tests import enable_determinism as enable_determinism
from ._resource_tests import load_description_and_test as load_description_and_test
from ._resource_tests import test_description as test_description
from ._resource_tests import test_model as test_model
+from ._sample_serializer import SampleSerializer as SampleSerializer
from ._settings import Settings as Settings
from ._settings import settings as settings
-# reexports from bioimageio.spec
-build_description = bioimageio.spec.build_description
-dump_description = bioimageio.spec.dump_description
-load_dataset_description = bioimageio.spec.load_dataset_description
-load_description = bioimageio.spec.load_description
-load_description_and_validate_format_only = (
- bioimageio.spec.load_description_and_validate_format_only
-)
-load_model_description = bioimageio.spec.load_model_description
-save_bioimageio_package = bioimageio.spec.save_bioimageio_package
-save_bioimageio_package_as_folder = bioimageio.spec.save_bioimageio_package_as_folder
-save_bioimageio_yaml_only = bioimageio.spec.save_bioimageio_yaml_only
-validate_format = bioimageio.spec.validate_format
-ValidationSummary = bioimageio.spec.ValidationSummary
-
-
# reexports from bioimageio.core submodules
-add_weights = weight_converters.add_weights
-Axis = axis.Axis
-AxisId = axis.AxisId
-BlockMeta = block_meta.BlockMeta
-compute_dataset_measures = stat_calculators.compute_dataset_measures
-create_model_adapter = backends.create_model_adapter
-MemberId = common.MemberId
-predict = prediction.predict
-predict_many = prediction.predict_many
-Sample = sample.Sample
-Stat = stat_measures.Stat
-Tensor = tensor.Tensor
+from .axis import Axis as Axis
+from .axis import AxisId as AxisId
+from .backends import create_model_adapter as create_model_adapter
+from .block_meta import BlockMeta as BlockMeta
+from .common import MemberId as MemberId
+from .prediction import predict as predict
+from .prediction import predict_many as predict_many
+from .sample import Sample as Sample
+from .sample import SampleBlock as SampleBlock
+from .sample import SampleBlockMeta as SampleBlockMeta
+from .stat_calculators import compute_dataset_measures as compute_dataset_measures
+from .stat_calculators import compute_measures as compute_measures
+from .stat_calculators import compute_sample_measures as compute_sample_measures
+from .stat_measures import Stat as Stat
+from .tensor import Tensor as Tensor
+from .weight_converters import add_weights as add_weights
# aliases
test_resource = test_description
diff --git a/src/bioimageio/core/__main__.py b/src/bioimageio/core/__main__.py
index 123b6a9c9..edcdea06a 100644
--- a/src/bioimageio/core/__main__.py
+++ b/src/bioimageio/core/__main__.py
@@ -14,7 +14,7 @@
+ "{module} - {message}",
)
-from .cli import Bioimageio
+from .cli import Bioimageio # noqa: E402
def main():
diff --git a/src/bioimageio/core/_axis_annotations.py b/src/bioimageio/core/_axis_annotations.py
new file mode 100644
index 000000000..8bd1a7e32
--- /dev/null
+++ b/src/bioimageio/core/_axis_annotations.py
@@ -0,0 +1,9 @@
+from typing import Annotated, TypeVar
+
+from ._common_annotations import PydanticMappingProxyAnnotation
+from .axis import PerAxis
+
+_T = TypeVar("_T")
+
+PerAxisAnno = Annotated[PerAxis[_T], PydanticMappingProxyAnnotation]
+"""PerAxis annotated with `PydanticMappingProxyAnnotation` to be compatible with pydantic models."""
diff --git a/src/bioimageio/core/_common_annotations.py b/src/bioimageio/core/_common_annotations.py
new file mode 100644
index 000000000..bb255e35e
--- /dev/null
+++ b/src/bioimageio/core/_common_annotations.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from types import MappingProxyType
+from typing import (
+ Annotated,
+ Any,
+ Hashable,
+ Mapping,
+ TypeVar,
+)
+
+import pydantic
+from pydantic_core.core_schema import (
+ CoreSchema,
+ chain_schema,
+ is_instance_schema,
+ json_or_python_schema,
+ no_info_plain_validator_function,
+ plain_serializer_function_ser_schema,
+)
+from typing_extensions import get_args
+
+from .common import PerMember
+
+_K = TypeVar("_K", bound=Hashable)
+_V = TypeVar("_V")
+
+
+def _validate_from_mapping(d: Mapping[_K, _V]) -> MappingProxyType[_K, _V]:
+ return MappingProxyType(dict(d))
+
+
+class PydanticMappingProxyAnnotation:
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler
+ ) -> CoreSchema:
+
+ k_type, v_type = get_args(source_type)
+ mapping_proxy_schema = chain_schema(
+ [
+ handler.generate_schema(dict[k_type, v_type]),
+ no_info_plain_validator_function(_validate_from_mapping),
+ is_instance_schema(MappingProxyType),
+ ]
+ )
+ return json_or_python_schema(
+ json_schema=mapping_proxy_schema,
+ python_schema=mapping_proxy_schema,
+ serialization=plain_serializer_function_ser_schema(dict),
+ )
+
+
+_T = TypeVar("_T")
+
+PerMemberAnno = Annotated[PerMember[_T], PydanticMappingProxyAnnotation]
diff --git a/src/bioimageio/core/_description_serializer.py b/src/bioimageio/core/_description_serializer.py
new file mode 100644
index 000000000..516cda2c4
--- /dev/null
+++ b/src/bioimageio/core/_description_serializer.py
@@ -0,0 +1,69 @@
+import base64
+import hashlib
+from io import BytesIO
+from typing import Tuple
+from zipfile import ZipFile
+
+from bioimageio.spec import (
+ InvalidDescr,
+ ResourceDescr,
+ load_description,
+ save_bioimageio_package_to_stream,
+)
+from bioimageio.spec.common import Sha256
+
+
+class DescriptionSerializer:
+ """Description serializer intended for client/server communication, NOT for sharing resource descriptions.
+
+ This serializer only includes local files to keep the serialized package small.
+ """
+
+ STRING_ENCODING = "ascii"
+
+ @staticmethod
+ def serialize(rd: ResourceDescr) -> bytes:
+ stream = save_bioimageio_package_to_stream(rd, local_files_only=True)
+ _ = stream.seek(0)
+ return stream.read()
+
+ @classmethod
+ def serialize_to_string(cls, rd: ResourceDescr) -> str:
+ package_bytes = cls.serialize(rd)
+
+ safe_bytes = cls._get_safe_bytes(package_bytes)
+ serialized_str = safe_bytes.decode(cls.STRING_ENCODING)
+ if len(serialized_str) <= 2083:
+ raise RuntimeError(
+ "Serialized model description should be longer than 2083 characters to not be treated as a URL on the server side."
+ )
+ return serialized_str
+
+ @staticmethod
+ def _get_safe_bytes(raw_bytes: bytes) -> bytes:
+ return base64.b64encode(raw_bytes)
+
+ @classmethod
+ def deserialize_from_string(cls, serialized: str) -> ResourceDescr:
+ package_bytes = base64.b64decode(serialized.encode(cls.STRING_ENCODING))
+ return cls.deserialize(package_bytes)
+
+ @staticmethod
+ def deserialize(serialized: bytes) -> ResourceDescr:
+ descr = load_description(ZipFile(BytesIO(serialized)), perform_io_checks=False)
+ if isinstance(descr, InvalidDescr):
+ raise ValueError(f"invalid serialized model package: {descr.get_reason()}")
+
+ return descr
+
+ @classmethod
+ def serialize_to_string_and_hash(cls, rd: ResourceDescr) -> Tuple[str, Sha256]:
+ package_bytes = cls.serialize(rd)
+ safe_bytes = cls._get_safe_bytes(package_bytes)
+ serialized_str = safe_bytes.decode(cls.STRING_ENCODING)
+ if len(serialized_str) <= 2083:
+ raise RuntimeError(
+ "Serialized model description should be longer than 2083 characters to not be treated as a URL on the server side."
+ )
+ sha256 = Sha256(hashlib.sha256(package_bytes).hexdigest())
+ return serialized_str, sha256
diff --git a/src/bioimageio/core/_magic_tensor_ops.py b/src/bioimageio/core/_magic_tensor_ops.py
index 39084bccb..e60b7a736 100644
--- a/src/bioimageio/core/_magic_tensor_ops.py
+++ b/src/bioimageio/core/_magic_tensor_ops.py
@@ -38,31 +38,31 @@ def __mul__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.mul)
def __pow__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.pow)
+ return self._binary_op(other, operator.pow) # pyright: ignore[reportUnknownArgumentType]
def __truediv__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.truediv)
+ return self._binary_op(other, operator.truediv) # pyright: ignore[reportUnknownArgumentType]
def __floordiv__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.floordiv)
+ return self._binary_op(other, operator.floordiv) # pyright: ignore[reportUnknownArgumentType]
def __mod__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.mod)
def __and__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.and_)
+ return self._binary_op(other, operator.and_) # pyright: ignore[reportUnknownArgumentType]
def __xor__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.xor)
+ return self._binary_op(other, operator.xor) # pyright: ignore[reportUnknownArgumentType]
def __or__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.or_)
+ return self._binary_op(other, operator.or_) # pyright: ignore[reportUnknownArgumentType]
def __lshift__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.lshift)
+ return self._binary_op(other, operator.lshift) # pyright: ignore[reportUnknownArgumentType]
def __rshift__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.rshift)
+ return self._binary_op(other, operator.rshift) # pyright: ignore[reportUnknownArgumentType]
def __lt__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.lt)
@@ -102,25 +102,25 @@ def __rmul__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.mul, reflexive=True)
def __rpow__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.pow, reflexive=True)
+ return self._binary_op(other, operator.pow, reflexive=True) # pyright: ignore[reportUnknownArgumentType]
def __rtruediv__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.truediv, reflexive=True)
+ return self._binary_op(other, operator.truediv, reflexive=True) # pyright: ignore[reportUnknownArgumentType]
def __rfloordiv__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.floordiv, reflexive=True)
+ return self._binary_op(other, operator.floordiv, reflexive=True) # pyright: ignore[reportUnknownArgumentType]
def __rmod__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.mod, reflexive=True)
def __rand__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.and_, reflexive=True)
+ return self._binary_op(other, operator.and_, reflexive=True) # pyright: ignore[reportUnknownArgumentType]
def __rxor__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.xor, reflexive=True)
+ return self._binary_op(other, operator.xor, reflexive=True) # pyright: ignore[reportUnknownArgumentType]
def __ror__(self, other: _Compatible) -> Self:
- return self._binary_op(other, operator.or_, reflexive=True)
+ return self._binary_op(other, operator.or_, reflexive=True) # pyright: ignore[reportUnknownArgumentType]
def _inplace_binary_op(
self, other: _Compatible, f: Callable[[Any, Any], Any]
@@ -128,40 +128,40 @@ def _inplace_binary_op(
raise NotImplementedError
def __iadd__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.iadd)
+ return self._inplace_binary_op(other, operator.iadd) # pyright: ignore[reportUnknownArgumentType]
def __isub__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.isub)
+ return self._inplace_binary_op(other, operator.isub) # pyright: ignore[reportUnknownArgumentType]
def __imul__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.imul)
+ return self._inplace_binary_op(other, operator.imul) # pyright: ignore[reportUnknownArgumentType]
def __ipow__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.ipow)
+ return self._inplace_binary_op(other, operator.ipow) # pyright: ignore[reportUnknownArgumentType]
def __itruediv__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.itruediv)
+ return self._inplace_binary_op(other, operator.itruediv) # pyright: ignore[reportUnknownArgumentType]
def __ifloordiv__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.ifloordiv)
+ return self._inplace_binary_op(other, operator.ifloordiv) # pyright: ignore[reportUnknownArgumentType]
def __imod__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.imod)
+ return self._inplace_binary_op(other, operator.imod) # pyright: ignore[reportUnknownArgumentType]
def __iand__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.iand)
+ return self._inplace_binary_op(other, operator.iand) # pyright: ignore[reportUnknownArgumentType]
def __ixor__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.ixor)
+ return self._inplace_binary_op(other, operator.ixor) # pyright: ignore[reportUnknownArgumentType]
def __ior__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.ior)
+ return self._inplace_binary_op(other, operator.ior) # pyright: ignore[reportUnknownArgumentType]
def __ilshift__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.ilshift)
+ return self._inplace_binary_op(other, operator.ilshift) # pyright: ignore[reportUnknownArgumentType]
def __irshift__(self, other: _Compatible) -> Self:
- return self._inplace_binary_op(other, operator.irshift)
+ return self._inplace_binary_op(other, operator.irshift) # pyright: ignore[reportUnknownArgumentType]
def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError
diff --git a/src/bioimageio/core/_model_adapter.py b/src/bioimageio/core/_model_adapter.py
new file mode 100644
index 000000000..cfbe8033b
--- /dev/null
+++ b/src/bioimageio/core/_model_adapter.py
@@ -0,0 +1,279 @@
+import gc
+import warnings
+from abc import ABC, abstractmethod
+from queue import LifoQueue
+from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Tuple, Union
+
+from exceptiongroup import ExceptionGroup
+from loguru import logger
+from numpy.typing import NDArray
+from typing_extensions import TypeVar
+
+from bioimageio.spec import ValidationSummary
+from bioimageio.spec.model import AnyModelDescr, v0_4
+
+from ._sample_serializer import SampleSerializer, SerializedSampleBlockType
+from .common import PerMember
+from .digest_spec import get_axes_infos, get_member_ids
+from .sample import Sample
+from .tensor import Tensor
+
+
+class ModelAdapter(ABC):
+ """
+ Represents model *without* any preprocessing or postprocessing.
+
+ ```
+ from bioimageio.core import load_description
+
+ model = load_description(...)
+
+ # option 1:
+ adapter = create_model_adapter(model)
+ adapter.forward(...)
+ adapter.unload()
+
+ # option 2:
+ with create_model_adapter(model) as adapter:
+ adapter.forward(...)
+ ```
+ """
+
+ def __init__(
+ self, model_description: AnyModelDescr, devices: Optional[Sequence[str]]
+ ):
+ super().__init__()
+ self._model_descr = model_description
+ self._input_ids = get_member_ids(model_description.inputs)
+ self._output_ids = get_member_ids(model_description.outputs)
+ self._input_axes = [
+ tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
+ ]
+ self._output_axes = [
+ tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
+ ]
+ if isinstance(model_description, v0_4.ModelDescr):
+ self._input_is_optional = [False] * len(model_description.inputs)
+ else:
+ self._input_is_optional = [ipt.optional for ipt in model_description.inputs]
+
+ self._devices = devices
+ self.load()
+
+ @property
+ def model_descr(self) -> AnyModelDescr:
+ return self._model_descr
+
+ @abstractmethod
+ def load(self) -> None:
+ self._loaded = True
+
+ @abstractmethod
+ def forward(
+ self, inputs: PerMember[Optional[Tensor]]
+ ) -> PerMember[Optional[Tensor]]: ...
+
+ @abstractmethod
+ def unload(self):
+ """Unload model from any devices, freeing their memory.
+
+ Note:
+ The moder adapter should be considered unusable afterwards.
+ """
+ self._loaded = False
+
+ def close(self):
+ """Close the model adapter, freeing any resources.
+
+ Note:
+ The moder adapter should be considered unusable afterwards.
+ """
+ self.unload()
+
+
+DeviceType = TypeVar("DeviceType")
+ModelType = TypeVar("ModelType")
+
+
+class LocalModelAdapter(ModelAdapter, ABC, Generic[DeviceType, ModelType]):
+ def load(self) -> None:
+ devices = self._devices
+ self._model_queue: LifoQueue[Tuple[DeviceType, ModelType]] = LifoQueue()
+ parsed_devices = self._parse_devices(devices)
+ assert parsed_devices
+ # prioritize devices by order specified by user
+ device_exceptions: Dict[str, Exception] = {}
+ self._initialized_devices: List[str] = []
+ for d in parsed_devices[::-1]:
+ try:
+ model = self._init_model_on_device(d)
+ except Exception as e:
+ device_exceptions[str(d)] = e
+ else:
+ self._model_queue.put((d, model))
+ self._initialized_devices.insert(0, str(d))
+
+ if self._model_queue.empty():
+ if len(device_exceptions) == 1:
+ raise next(iter(device_exceptions.values()))
+ else:
+ raise ExceptionGroup(
+ "Failed to initialize model on any of the requested devices.",
+ list(device_exceptions.values())[::-1],
+ )
+
+ if device_exceptions:
+ logger.warning(
+ "Failed to initialize model on some of the requested devices. Successfully initialized on {}, but got the following errors for other devices: {}",
+ self._initialized_devices,
+ device_exceptions,
+ )
+
+ super().load()
+
+ @abstractmethod
+ def _parse_devices(self, devices: Optional[Sequence[str]]) -> Sequence[DeviceType]:
+ """Parse devices
+
+ Note:
+ - May not return an empty sequence.
+ - The order of devices in the returned sequence determines the priority of device usage in the forward pass.
+ First devices has highgest priority, last device has lowest priority.
+ """
+
+ @abstractmethod
+ def _init_model_on_device(self, device: DeviceType) -> ModelType: ...
+
+ def forward(
+ self, inputs: PerMember[Optional[Tensor]]
+ ) -> PerMember[Optional[Tensor]]:
+ """
+ Run forward pass of model to get model predictions
+
+ Note: sample id and stample stat attributes are passed through
+ """
+ if not self._loaded:
+ raise RuntimeError("Model must be `.load()`ed before calling forward()")
+
+ unexpected = [mid for mid in inputs if mid not in self._input_ids]
+ if unexpected:
+ warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")
+
+ input_arrays = [
+ (
+ None
+ if (a := inputs.get(in_id)) is None
+ else a.transpose(in_order).data.data
+ )
+ for in_id, in_order in zip(self._input_ids, self._input_axes)
+ ]
+ logger.debug(
+ "NN input shapes: {}",
+ [a.shape if a is not None else None for a in input_arrays],
+ )
+ device, model = self._model_queue.get()
+ try:
+ output_arrays = self._forward_impl(device, model, input_arrays)
+ finally:
+ self._model_queue.put((device, model))
+
+ logger.debug(
+ "NN output shapes: {}",
+ [a.shape if a is not None else None for a in output_arrays],
+ )
+ if len(output_arrays) > len(self._output_ids):
+ warnings.warn(
+ f"Model produced more outputs ({len(output_arrays)}) than specified in the model description ({len(self._output_ids)}). Extra outputs will be ignored."
+ )
+ output_arrays = output_arrays[: len(self._output_ids)]
+
+ output_tensors = [
+ None if a is None else Tensor(a, dims=d)
+ for a, d in zip(output_arrays, self._output_axes)
+ ]
+ return {
+ tid: out
+ for tid, out in zip(
+ self._output_ids,
+ output_tensors,
+ )
+ if out is not None
+ }
+
+ @abstractmethod
+ def _forward_impl(
+ self,
+ device: DeviceType,
+ model: ModelType,
+ input_arrays: Sequence[Optional[NDArray[Any]]],
+ ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]], ...]]:
+ """framework specific forward implementation"""
+
+ def unload(self):
+ for _ in range(len(self._initialized_devices)):
+ device, model = self._model_queue.get()
+ try:
+ self._cleanup_pre_model_deletion(device, model)
+ except Exception as e:
+ logger.warning(
+ "Got error during pre-deletion cleanup on device {}: {}", device, e
+ )
+ finally:
+ del model
+ try:
+ self._cleanup_post_model_deletion(device)
+ except Exception as e:
+ logger.warning(
+ "Got error during post-deletion cleanup on device {}: {}", device, e
+ )
+
+ _ = gc.collect() # deallocate memory
+ super().unload()
+
+ @abstractmethod
+ def _cleanup_pre_model_deletion(self, device: DeviceType, model: ModelType) -> None:
+ """Clean up before model reference deletion"""
+
+ @abstractmethod
+ def _cleanup_post_model_deletion(self, device: DeviceType) -> None:
+ """Clean up after model reference deletion"""
+
+
+class RemoteModelAdapter(ModelAdapter, ABC, Generic[SerializedSampleBlockType]):
+ """Model adapter to use a remote service for model inference."""
+
+ def __init__(
+ self,
+ model_description: AnyModelDescr,
+ server: str,
+ sample_serializer: SampleSerializer[SerializedSampleBlockType],
+ ):
+ super().__init__(model_description, devices=None)
+ self._server = server
+ self._serializer = sample_serializer
+
+ @property
+ def server(self) -> str:
+ return self._server
+
+ def forward(
+ self, inputs: PerMember[Optional[Tensor]]
+ ) -> PerMember[Optional[Tensor]]:
+ serialized_input = self._serializer.serialize_sample(
+ Sample(
+ members={k: v for k, v in inputs.items() if v is not None},
+ stat={},
+ id=None,
+ )
+ )
+ serialized_output = self._forward_impl(serialized_input)
+ return self._serializer.deserialize_sample(serialized_output).members
+
+ @abstractmethod
+ def _forward_impl(
+ self, serialized_input_sample: Iterable[SerializedSampleBlockType]
+ ) -> Iterable[SerializedSampleBlockType]: ...
+
+ @abstractmethod
+ def test(self) -> Optional[ValidationSummary]:
+ """Run the bioimageio model test."""
diff --git a/src/bioimageio/core/_prediction_pipeline.py b/src/bioimageio/core/_prediction_pipeline.py
index db769a3aa..04fac449a 100644
--- a/src/bioimageio/core/_prediction_pipeline.py
+++ b/src/bioimageio/core/_prediction_pipeline.py
@@ -1,4 +1,5 @@
import warnings
+from abc import ABC, abstractmethod
from types import MappingProxyType
from typing import (
Any,
@@ -6,6 +7,7 @@
List,
Literal,
Mapping,
+ NamedTuple,
Optional,
Sequence,
Tuple,
@@ -15,11 +17,14 @@
from loguru import logger
from tqdm import tqdm
+from typing_extensions import assert_never
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
+from ._model_adapter import ModelAdapter
from ._op_base import BlockwiseOperator, SamplewiseOperator
from .axis import AxisId, PerAxis
+from .backends import create_model_adapter
from .common import (
BlocksizeParameter,
Halo,
@@ -33,8 +38,6 @@
get_input_halo,
get_member_ids,
)
-from .model_adapters import ModelAdapter, create_model_adapter
-from .model_adapters import get_weight_formats as get_weight_formats
from .proc_ops import Processing
from .proc_setup import setup_pre_and_postprocessing
from .sample import Sample, SampleBlock
@@ -48,7 +51,236 @@
)
-class PredictionPipeline:
+class IntermediatePrediction(NamedTuple):
+ """Represents an intermediate prediction of a sample with blocking, including the predicted sample so far and the last predicted block.
+
+ The final `IntermediatePrediction` in a sequence holds the complete predicted (and postprocessed if applicable) sample."""
+
+ sample: Sample
+ last_block: SampleBlock
+
+
+class _PredictionPipelineBase(ABC):
+ def __init__(
+ self,
+ model_descr: AnyModelDescr,
+ *,
+ default_blocksize_parameter: BlocksizeParameter,
+ default_batch_size: int,
+ ) -> None:
+ super().__init__()
+ self._model_descr = model_descr
+ self._default_blocksize_parameter = default_blocksize_parameter
+ self._default_batch_size = default_batch_size
+
+ if isinstance(model_descr, v0_4.ModelDescr):
+ self._default_output_halo: PerMember[PerAxis[Halo]] = {}
+ self._default_input_halo: PerMember[PerAxis[Halo]] = {}
+ self._block_transform = None
+ else:
+ self._default_output_halo = {
+ t.id: {
+ a.id: Halo(a.halo, a.halo)
+ for a in t.axes
+ if isinstance(a, v0_5.WithHalo)
+ }
+ for t in model_descr.outputs
+ }
+ self._default_input_halo = get_input_halo(
+ model_descr, self._default_output_halo
+ )
+ self._block_transform = get_block_transform(model_descr)
+
+ self.pad_mode = (
+ {}
+ if isinstance(model_descr, v0_4.ModelDescr)
+ else {
+ descr.id: descr.pad or v0_5.SymmetricPadding()
+ for descr in model_descr.inputs
+ }
+ )
+
+ @property
+ def model_descr(self) -> AnyModelDescr:
+ return self._model_descr
+
+ @property
+ def model_description(self) -> AnyModelDescr:
+ return self._model_descr
+
+ @abstractmethod
+ def predict_sample_without_blocking(
+ self,
+ sample: Sample,
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ skip_input_padding: bool = False,
+ skip_output_cropping: bool = False,
+ ) -> Sample:
+ """Predict a whole sample at once.
+
+ Note:
+ The sample's tensor shapes have to match the model's input tensor description.
+ If that is not the case, consider `predict_sample_with_blocking`
+
+ Args:
+ sample: input sample
+ skip_preprocessing: if `True`, skip all preprocessing steps.
+ skip_postprocessing: if `True`, skip all postprocessing steps.
+ skip_input_padding: if `True`, skip padding the input sample according to the model's (optional) output halos.
+ skip_output_cropping: if `True`, skip cropping any output halos from the model output.
+ """
+
+ def predict_sample_with_blocking(
+ self,
+ sample: Sample,
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ ns: Optional[
+ Union[
+ v0_5.ParameterizedSize_N,
+ Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
+ ]
+ ] = None,
+ batch_size: Optional[int] = None,
+ ) -> Sample:
+ """Predict a sample by predicting sample blocks.
+
+ Note: For fixed/known blocksizes use `predict_sample_with_fixed_blocking`.
+
+ Args:
+ sample: The sample to predict on.
+ skip_preprocessing: If `True`, skip all preprocessing steps.
+ skip_postprocessing: If `True`, skip all postprocessing steps.
+ ns: Block size parameter(s) allows scaling the model's default input block size.
+ Blocksize parameters are only applied to parameterized input axes, all other axis sizes are fixed/derived or (for output axes) data dependent.
+ Unapplicable blocksize parameters are ignored.
+ batch_size: Batch size to use for prediction.
+ """
+ output = None
+ for output in self.predict_sample_with_blocking_yield_intermediates(
+ sample,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ ns=ns,
+ batch_size=batch_size,
+ )[1]:
+ pass
+
+ assert output is not None, (
+ "No blocks were predicted, cannot return final sample."
+ )
+ return output.sample
+
+ def predict_sample_with_fixed_blocking(
+ self,
+ sample: Sample,
+ input_block_shape: PerMember[PerAxis[int]],
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ ) -> Sample:
+ """Predict `sample` with given `input_block_shape`.
+
+ Note:
+ - `input_block_shape` is expected to be a valid input shape for the model.
+ - Use `predict_sample_with_blocking` if you want to control block sizes via generic block size parameters rather than fixed block shapes.
+
+ Args:
+ sample: The sample to predict on.
+ input_block_shape: Mapping of input member id to mapping of axis id to block size for that axis.
+ skip_preprocessing: If `True`, skip all preprocessing steps.
+ skip_postprocessing: If `True`, skip all postprocessing steps.
+ """
+ intermediate = None
+ for intermediate in self.predict_sample_with_fixed_blocking_yield_intermediates(
+ sample,
+ input_block_shape=input_block_shape,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ )[1]:
+ pass
+
+ assert intermediate is not None, (
+ "No blocks were predicted, cannot return final sample."
+ )
+ return intermediate.sample
+
+ def predict_sample_with_blocking_yield_intermediates(
+ self,
+ sample: Sample,
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ ns: Optional[
+ Union[
+ v0_5.ParameterizedSize_N,
+ Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
+ ]
+ ] = None,
+ batch_size: Optional[int] = None,
+ ) -> Tuple[int, Iterable[IntermediatePrediction]]:
+ """Predict `sample` by predicting sample blocks and yield intermediate predictions if no samplewise postprocessing is included.
+
+ Returns:
+ Tuple of number of blocks and an iterator of predicted intermediate samples with the last predicted block,
+ All samples, but the last one, are intermediate samples with more and more blocks predicted.
+ In case samplewise postprocessing needs to be applied, no intermediate results are yielded, but only the final sample after all blocks are predicted and postprocessed.
+ """
+ if isinstance(self._model_descr, v0_4.ModelDescr):
+ raise NotImplementedError(
+ "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
+ + f" {self._model_descr.name}."
+ + " Consider using `predict_sample_with_fixed_blocking`"
+ )
+
+ ns = ns or self._default_blocksize_parameter
+ if isinstance(ns, int):
+ ns = {
+ (ipt.id, a.id): ns
+ for ipt in self._model_descr.inputs
+ for a in ipt.axes
+ if isinstance(a.size, v0_5.ParameterizedSize)
+ }
+ input_block_shape = self._model_descr.get_tensor_sizes(
+ ns, batch_size or self._default_batch_size
+ ).inputs
+
+ return self.predict_sample_with_fixed_blocking_yield_intermediates(
+ sample,
+ input_block_shape=input_block_shape,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ )
+
+ @abstractmethod
+ def predict_sample_with_fixed_blocking_yield_intermediates(
+ self,
+ sample: Sample,
+ input_block_shape: PerMember[PerAxis[int]],
+ *,
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ fill_value: float = float("nan"),
+ ) -> Tuple[int, Iterable[IntermediatePrediction]]: ...
+
+ @abstractmethod
+ def predict_sample_block(
+ self,
+ sample_block: SampleBlock,
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ ) -> SampleBlock:
+ """Predict a single sample block.
+
+ Note that this does not apply samplewise preprocessing or postprocessing steps, but only blockwise ones.
+
+ Args:
+ sample_block: The sample block to predict on.
+ skip_preprocessing: If `True`, skip blockwise preprocessing steps.
+ skip_postprocessing: If `True`, skip blockwise postprocessing steps.
+ """
+
+
+class PredictionPipeline(_PredictionPipelineBase):
"""
Represents model computation including preprocessing and postprocessing
Note: Ideally use the `PredictionPipeline` in a with statement
@@ -63,19 +295,15 @@ def __init__(
preprocessing: List[Processing],
postprocessing: List[Processing],
model_adapter: ModelAdapter,
- default_ns: Optional[BlocksizeParameter] = None,
default_blocksize_parameter: BlocksizeParameter = 10,
default_batch_size: int = 1,
) -> None:
"""Consider using `create_prediction_pipeline` to create a `PredictionPipeline` with sensible defaults."""
- super().__init__()
- default_blocksize_parameter = default_ns or default_blocksize_parameter
- if default_ns is not None:
- warnings.warn(
- "Argument `default_ns` is deprecated in favor of"
- + " `default_blocksize_paramter` and will be removed soon."
- )
- del default_ns
+ super().__init__(
+ model_descr=model_description,
+ default_blocksize_parameter=default_blocksize_parameter,
+ default_batch_size=default_batch_size,
+ )
if model_description.run_mode:
warnings.warn(
@@ -108,40 +336,10 @@ def __init__(
else:
self._samplewise_postprocessing.append(op)
- self.pad_mode = (
- {}
- if isinstance(model_description, v0_4.ModelDescr)
- else {
- descr.id: descr.pad or v0_5.SymmetricPadding()
- for descr in model_description.inputs
- }
- )
- self.model_description = model_description
- if isinstance(model_description, v0_4.ModelDescr):
- self._default_output_halo: PerMember[PerAxis[Halo]] = {}
- self._default_input_halo: PerMember[PerAxis[Halo]] = {}
- self._block_transform = None
- else:
- self._default_output_halo = {
- t.id: {
- a.id: Halo(a.halo, a.halo)
- for a in t.axes
- if isinstance(a, v0_5.WithHalo)
- }
- for t in model_description.outputs
- }
- self._default_input_halo = get_input_halo(
- model_description, self._default_output_halo
- )
- self._block_transform = get_block_transform(model_description)
-
- self._default_blocksize_parameter = default_blocksize_parameter
- self._default_batch_size = default_batch_size
-
self._input_ids = get_member_ids(model_description.inputs)
self._output_ids = get_member_ids(model_description.outputs)
- self._adapter: ModelAdapter = model_adapter
+ self._adapter = model_adapter
def __enter__(self):
self.load()
@@ -152,14 +350,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
return False
@property
- def has_blockwise_preprocessing(self) -> bool:
- """`True` if all preprocessing operators in the pipeline are blockwise."""
- return bool(self._blockwise_preprocessing)
+ def has_non_blockwise_preprocessing(self) -> bool:
+ """`True` if any preprocessing operators in the pipeline are not applicable blockwise."""
+ return bool(self._samplewise_preprocessing)
@property
- def has_blockwise_postprocessing(self) -> bool:
- """`True` if all postprocessing operators in the pipeline are blockwise."""
- return bool(self._blockwise_postprocessing)
+ def has_non_blockwise_postprocessing(self) -> bool:
+ """`True` if any postprocessing operators in the pipeline are not applicable blockwise."""
+ return bool(self._samplewise_postprocessing)
def _raise_for_non_blockwise_processing(
self, proc_type: Literal["preprocessing", "postprocessing"]
@@ -197,28 +395,25 @@ def predict_sample_block(
skip_preprocessing: bool = False,
skip_postprocessing: bool = False,
) -> SampleBlock:
- if isinstance(self.model_description, v0_4.ModelDescr):
+ if isinstance(self._model_descr, v0_4.ModelDescr):
raise NotImplementedError(
- f"predict_sample_block not implemented for model {self.model_description.format_version}"
+ f"predict_sample_block not implemented for model {self._model_descr.format_version}"
)
else:
assert self._block_transform is not None
if not skip_preprocessing:
- self.raise_for_non_blockwise_preprocessing()
-
- if not skip_postprocessing:
- self.raise_for_non_blockwise_postprocessing()
-
- if not skip_preprocessing:
- self.apply_preprocessing(sample_block)
+ self._apply_blockwise_preprocessing(sample_block)
output_meta = sample_block.get_transformed_meta(self._block_transform)
- local_output = self._adapter.forward(sample_block)
+ local_output = self._adapter.forward(sample_block.members)
- output = output_meta.with_data(local_output.members, stat=local_output.stat)
+ output = output_meta.with_data(
+ {k: v for k, v in local_output.items() if v is not None},
+ stat=sample_block.stat,
+ )
if not skip_postprocessing:
- self.apply_postprocessing(output)
+ self._apply_blockwise_postprocessing(output)
return output
@@ -230,26 +425,21 @@ def predict_sample_without_blocking(
skip_input_padding: bool = False,
skip_output_cropping: bool = False,
) -> Sample:
- """predict a whole sample
-
- Args:
- sample: input sample
- skip_preprocessing: if `True`, skip all preprocessing steps.
- skip_postprocessing: if `True`, skip all postprocessing steps.
- skip_input_padding: if `True`, skip padding the input sample according to the model's (optional) output halos.
- skip_output_cropping: if `True`, skip cropping any output halos from the model output.
- Note:
- The sample's tensor shapes have to match the model's input tensor description.
- If that is not the case, consider `predict_sample_with_blocking`
- """
-
if not skip_input_padding:
sample = sample.pad(pad_width=self._default_input_halo, mode=self.pad_mode)
if not skip_preprocessing:
self.apply_preprocessing(sample)
- output = self._adapter.forward(sample)
+ output = Sample(
+ members={
+ k: v
+ for k, v in self._adapter.forward(sample.members).items()
+ if v is not None
+ },
+ stat=sample.stat,
+ id=sample.id,
+ )
if not skip_postprocessing:
self.apply_postprocessing(output)
@@ -276,143 +466,225 @@ def get_output_sample_id(self, input_sample_id: SampleId):
)
return input_sample_id
- def predict_sample_with_fixed_blocking(
+ def predict_sample_with_blocking(
+ self,
+ sample: Sample,
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ ns: Optional[
+ Union[
+ v0_5.ParameterizedSize_N,
+ Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
+ ]
+ ] = None,
+ batch_size: Optional[int] = None,
+ ) -> Sample:
+ output = None
+ for output in self.predict_sample_with_blocking_yield_intermediates(
+ sample,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ ns=ns,
+ batch_size=batch_size,
+ )[1]:
+ pass
+
+ assert output is not None, (
+ "No blocks were predicted, cannot return final sample."
+ )
+ return output.sample
+
+ def predict_sample_with_fixed_blocking_yield_intermediates(
self,
sample: Sample,
input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
*,
skip_preprocessing: bool = False,
skip_postprocessing: bool = False,
- ) -> Sample:
- """Predict `sample` with given `input_block_shape`.
+ fill_value: float = float("nan"),
+ ) -> Tuple[int, Iterable[IntermediatePrediction]]:
+ """Predict `sample` with given `input_block_shape` and yield the full sample with intermediate results.
+
+ Note:
+ - `input_block_shape` is expected to be a valid input shape for the model.
+ - Use `predict_sample_with_blocking` if you want to control block sizes via generic block size parameters
+ rather than fixed block shapes.
+ - Postprocessing may only be complete for the final sample (if samplewise postprocessing steps are included
+ in the pipeline), intermediate samples may have some (blockwise applicable) postprocessing steps applied.
- Note:
- `input_block_shape` is expected to be a valid input shape for the model.
+ Args:
+ sample: The sample to predict on.
+ input_block_shape: Mapping of input member id to mapping of axis id to block size for that axis.
+ skip_preprocessing: If `True`, skip all preprocessing steps.
+ skip_postprocessing: If `True`, skip all postprocessing steps.
+
+ Returns:
+ Tuple of number of blocks and an iterable of predicted intermediate samples with the last predicted block,
+ All samples, but the last one, are intermediate samples with more and more blocks predicted.
"""
+
if not skip_preprocessing:
- for op in self._samplewise_preprocessing:
- op(sample)
+ self._apply_samplewise_preprocessing(sample)
n_blocks, input_blocks = sample.split_into_blocks(
input_block_shape,
halo=self._default_input_halo,
pad_mode=self.pad_mode,
)
- input_blocks = list(input_blocks)
- predicted_blocks: List[SampleBlock] = []
logger.info(
"split sample shape {} into {} blocks of {}.",
{k: dict(v) for k, v in sample.shape.items()},
n_blocks,
{k: dict(v) for k, v in input_block_shape.items()},
)
- for b in tqdm(
- input_blocks,
- desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
- unit="block",
- unit_divisor=1,
- total=n_blocks,
- ):
- if not skip_preprocessing:
- for op in self._blockwise_preprocessing:
- op(b)
-
- predicted_blocks.append(
- self.predict_sample_block(
+
+ def _predict_blocks():
+ predicted_sample = None
+ for i, b in enumerate(
+ tqdm(
+ input_blocks,
+ desc=f"predict sample {sample.id or ''} with {self._model_descr.id or self._model_descr.name}",
+ unit="block",
+ unit_divisor=1,
+ total=n_blocks,
+ )
+ ):
+ if not skip_preprocessing:
+ self._apply_blockwise_preprocessing(b)
+
+ predicted_block = self.predict_sample_block(
b, skip_preprocessing=True, skip_postprocessing=True
)
- )
- if not skip_postprocessing:
- for op in self._blockwise_postprocessing:
- op(predicted_blocks[-1])
- predicted_sample = Sample.from_blocks(predicted_blocks)
- if not skip_postprocessing:
- for op in self._samplewise_postprocessing:
- op(predicted_sample)
+ if not skip_postprocessing:
+ self._apply_blockwise_postprocessing(predicted_block)
- return predicted_sample
+ if predicted_sample is None:
+ predicted_sample = Sample.from_blocks(
+ [predicted_block], fill_value=fill_value
+ )
+ else:
+ predicted_sample.set_block(predicted_block)
- def predict_sample_with_blocking(
- self,
- sample: Sample,
- skip_preprocessing: bool = False,
- skip_postprocessing: bool = False,
- ns: Optional[
- Union[
- v0_5.ParameterizedSize_N,
- Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
- ]
- ] = None,
- batch_size: Optional[int] = None,
- ) -> Sample:
- """Predict a sample by splitting it into blocks according to the mode
+ if not skip_postprocessing and i == n_blocks - 1:
+ self._apply_samplewise_postprocessing(predicted_sample)
+
+ yield IntermediatePrediction(predicted_sample, predicted_block)
+
+ return n_blocks, _predict_blocks()
- The `ns` parameter allow scaling the model's default input block size.
+ def _apply_samplewise_preprocessing(self, sample: Sample, /) -> None:
+ """Apply preprocessing operators up to and including the last samplewise operator in-place.
+
+ Note: This skips all blockwise preprocessing steps after the last samplewise operator.
"""
+ if isinstance(sample, SampleBlock):
+ self.raise_for_non_blockwise_preprocessing()
- if isinstance(self.model_description, v0_4.ModelDescr):
- raise NotImplementedError(
- "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
- + f" {self.model_description.name}."
- + " Consider using `predict_sample_with_fixed_blocking`"
- )
+ for op in self._samplewise_preprocessing:
+ op(sample)
- ns = ns or self._default_blocksize_parameter
- if isinstance(ns, int):
- ns = {
- (ipt.id, a.id): ns
- for ipt in self.model_description.inputs
- for a in ipt.axes
- if isinstance(a.size, v0_5.ParameterizedSize)
- }
- input_block_shape = self.model_description.get_tensor_sizes(
- ns, batch_size or self._default_batch_size
- ).inputs
+ def _apply_blockwise_preprocessing(
+ self, sample_block: Union[Sample, SampleBlock], /
+ ) -> None:
+ """Apply blockwise preprocessing operators in-place.
- return self.predict_sample_with_fixed_blocking(
- sample,
- input_block_shape=input_block_shape,
- skip_preprocessing=skip_preprocessing,
- skip_postprocessing=skip_postprocessing,
- )
+ Note: This skips all preprocessing operators up to and including the last samplewise one.
+ """
+ for op in self._blockwise_preprocessing:
+ op(sample_block)
def apply_preprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
- """apply preprocessing in-place, also may updates sample stats"""
- if isinstance(sample, SampleBlock):
+ """Apply preprocessing in-place, also may updates sample stats"""
+
+ if isinstance(sample, Sample):
+ self._apply_samplewise_preprocessing(sample)
+ else:
self.raise_for_non_blockwise_preprocessing()
- for op in self._samplewise_preprocessing + self._blockwise_preprocessing:
- if isinstance(sample, SampleBlock):
- assert isinstance(op, BlockwiseOperator)
- op(sample)
- else:
- op(sample)
+ self._apply_blockwise_preprocessing(sample)
- def apply_postprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
- """apply postprocessing in-place, also may updates samples stats"""
+ def _apply_blockwise_postprocessing(
+ self, sample_block: Union[Sample, SampleBlock], /
+ ) -> None:
+ """Apply in-place blockwise postprocessing operators
+
+ Note: This does not apply all postprocessing operators from the first samplewise one onwards.
+ """
+ for op in self._blockwise_postprocessing:
+ op(sample_block)
+
+ def _apply_samplewise_postprocessing(self, sample: Sample, /) -> None:
+ """Apply in-place postprocessing operators starting from and including the first samplewise operator.
+
+ Note: This skips all blockwise postprocessing steps before the first samplewise one.
+ """
if isinstance(sample, SampleBlock):
self.raise_for_non_blockwise_postprocessing()
- for op in self._blockwise_postprocessing + self._samplewise_postprocessing:
- if isinstance(sample, SampleBlock):
- assert isinstance(op, BlockwiseOperator)
- op(sample)
- else:
- op(sample)
+ for op in self._samplewise_postprocessing:
+ op(sample)
+
+ def apply_postprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
+ """apply postprocessing in-place, also may updates samples stats"""
+ self._apply_blockwise_postprocessing(sample)
+ if isinstance(sample, Sample):
+ self._apply_samplewise_postprocessing(sample)
+ else:
+ self.raise_for_non_blockwise_postprocessing()
def load(self):
+ """Prepare prediction pipeline for use.
+
+ Reusable model adapters may be loaded and unloaded multiple times, but currently not all model adapters
+ cleanly unload and reload.
+
+ Note:
+ For some model adapters loading is currently part of the constructor making them unusable after unloading.
"""
- optional step: load model onto devices before calling forward if not using it as context manager
- """
- pass
+ self._adapter.load()
def unload(self):
- """
- free any device memory in use
- """
+ """Free any device memory in use.
+
+ Note:
+ Currently prediction pipeline becomes unusable after unloading."""
self._adapter.unload()
+ def close(self):
+ """Permanently close the prediction pipeline and free any device memory in use.
+ This makes the prediction pipeline unusable afterwards."""
+ self.unload()
+
+
+class RemotePredictionPipeline(_PredictionPipelineBase):
+ """Abstract base class for fully remote prediction pipelines.
+
+ Note: A ("local") `PredictionPipeline` may also use a `RemoteModelAdapter` for remote model inference, but it may
+ still apply local preprocessing and postprocessing steps.
+ In contrast, a `RemotePredictionPipeline` is designed for the case where all steps including preprocessing and
+ postprocessing are performed remotely.
+ """
+
+ def __init__(
+ self,
+ model_descr: AnyModelDescr,
+ *,
+ server: str,
+ default_blocksize_parameter: BlocksizeParameter,
+ default_batch_size: int,
+ ) -> None:
+ super().__init__(
+ model_descr,
+ default_blocksize_parameter=default_blocksize_parameter,
+ default_batch_size=default_batch_size,
+ )
+ self._server = server
+
+ @property
+ def server(self) -> str:
+ return self._server
+
def create_prediction_pipeline(
bioimageio_model: AnyModelDescr,
@@ -499,3 +771,50 @@ def dataset():
postprocessing=postprocessing,
default_blocksize_parameter=default_blocksize_parameter,
)
+
+
+def create_remote_prediction_pipeline(
+ model_description: AnyModelDescr,
+ *,
+ server: Optional[str] = None,
+ server_type: Optional[Literal["gradio"]] = "gradio",
+ precomputed_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}),
+ default_blocksize_parameter: BlocksizeParameter = 10, # TODO: default to None and find smart blocksize params per axis to reduce overlap of blocks with large halo
+ default_batch_size: int = 1,
+) -> RemotePredictionPipeline:
+ """Create a `RemotePredictionPipeline` for the given `model_description`.
+
+ Args:
+ model_description: The model to run inference with.
+ server: The URL or Hugging Face space name of a running bioimageio server instance
+ server_type: The type of the remote server to connect to. Currently only "gradio" is supported.
+ precomputed_statistics: Precomputed dataset (and optionally sample) statistics.
+ Any included sample statistics will not be calculated on the fly and it is the callers
+ responsibility to use samples with the corresponding statistics availble in `sample.stat`.
+ default_blocksize_parameter: Allows to control the default block size with a single parameter for blockwise predictions. (not all models support this)
+ default_batch_size: Default batch size to use
+ """
+
+ if server_type is None:
+ server_type = "gradio"
+
+ try:
+ if server_type == "gradio":
+ from .remote_backends.gradio.client import (
+ GradioPredictionPipeline as RemotePredictionPipelineImpl,
+ )
+ else:
+ assert_never(server_type)
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to import {server_type.capitalize()}PredictionPipeline. Make sure to install the '{server_type}-client' extra,"
+ + f" e.g. with `pip install bioimageio.core[{server_type}-client]`."
+ ) from e
+
+ return RemotePredictionPipelineImpl(
+ model_description,
+ server=server,
+ precomputed_statistics=precomputed_statistics,
+ default_blocksize_parameter=default_blocksize_parameter,
+ default_batch_size=default_batch_size,
+ )
diff --git a/src/bioimageio/core/_resource_tests.py b/src/bioimageio/core/_resource_tests.py
index fa528fc11..1f0b4c180 100644
--- a/src/bioimageio/core/_resource_tests.py
+++ b/src/bioimageio/core/_resource_tests.py
@@ -818,6 +818,70 @@ def _get_tolerance(
return rtol, atol, mismatched_tol
+def evaluate_mismatched_elements(
+ actual: Tensor, expected: Tensor, rtol: float, atol: float, name: str
+) -> Tuple[float, str, Optional[str]]:
+ try:
+ expected_np = expected.data.to_numpy().astype(np.float32)
+ dims = expected.dims
+ del expected
+ actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)
+ del actual
+
+ rtol_value = rtol * abs(expected_np)
+ abs_diff = abs(actual_np - expected_np)
+ mismatched = abs_diff > atol + rtol_value
+ mismatched_elements = mismatched.sum().item()
+
+ mismatched_ppm = mismatched_elements / expected_np.size * 1e6
+ abs_diff[~mismatched] = 0 # ignore non-mismatched elements
+
+ r_max_idx_flat = (r_diff := (abs_diff / (abs(expected_np) + 1e-6))).argmax()
+ r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
+ r_max = r_diff[r_max_idx].item()
+ r_actual = actual_np[r_max_idx].item()
+ r_expected = expected_np[r_max_idx].item()
+
+ # Calculate the max absolute difference with the relative tolerance subtracted
+ abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
+ a_max_idx = np.unravel_index(abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape)
+
+ a_max = abs_diff[a_max_idx].item()
+ a_actual = actual_np[a_max_idx].item()
+ a_expected = expected_np[a_max_idx].item()
+ except Exception as e:
+ mismatched_ppm = -1
+ msg = ""
+ error_msg = (
+ f"Error while checking if '{name}' disagrees with expected values: {e}"
+ )
+ else:
+ error_msg = None
+ if mismatched_elements:
+ msg = (
+ f"Output '{name}': {mismatched_elements} of "
+ + f"{expected_np.size} elements disagree with expected values ("
+ + (
+ f"{mismatched_ppm * 10_000:.1f}%"
+ if mismatched_ppm >= 1_000
+ else f"{mismatched_ppm:.1f} ppm"
+ )
+ + "). "
+ )
+ else:
+ msg = f"Output `{name}`: all elements agree with expected values. "
+
+ msg += (
+ f"\nMax relative difference not accounted for by absolute tolerance ({atol:.2e}):\n{r_max:.2e}"
+ + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
+ + f" at {dict(zip(dims, r_max_idx))} "
+ + f"\nMax absolute difference not accounted for by relative tolerance ({rtol:.2e}):\n{a_max:.2e}"
+ + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
+ )
+
+ return mismatched_ppm, msg, error_msg
+
+
def _test_recreate_test_outputs(
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
weight_format: SupportedWeightsFormat,
@@ -917,7 +981,7 @@ def save_to_working_dir(name: str, tensor: Tensor) -> List[Path]:
else:
continue
- if actual.dims != (dims := expected.dims):
+ if actual.dims != expected.dims:
add_error_entry(
f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}"
)
@@ -944,72 +1008,32 @@ def save_to_working_dir(name: str, tensor: Tensor) -> List[Path]:
results_not_postprocessed.members[m],
)
)
+ except Exception as e:
+ logger.error(f"Failed to save actual output tensor for '{m}': {e}")
+ output_paths = None
- expected_np = expected.data.to_numpy().astype(np.float32)
- del expected
- actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)
+ rtol, atol, mismatched_tol = _get_tolerance(
+ model, wf=weight_format, m=m, **deprecated
+ )
+ mismatched_ppm, msg, error_msg = evaluate_mismatched_elements(
+ actual, expected, rtol, atol, m
+ )
+ if error_msg is not None:
+ add_error_entry(error_msg)
+ if stop_early:
+ break
- rtol, atol, mismatched_tol = _get_tolerance(
- model, wf=weight_format, m=m, **deprecated
- )
- rtol_value = rtol * abs(expected_np)
- abs_diff = abs(actual_np - expected_np)
- mismatched = abs_diff > atol + rtol_value
- mismatched_elements = mismatched.sum().item()
-
- mismatched_ppm = mismatched_elements / expected_np.size * 1e6
- abs_diff[~mismatched] = 0 # ignore non-mismatched elements
-
- r_max_idx_flat = (
- r_diff := (abs_diff / (abs(expected_np) + 1e-6))
- ).argmax()
- r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
- r_max = r_diff[r_max_idx].item()
- r_actual = actual_np[r_max_idx].item()
- r_expected = expected_np[r_max_idx].item()
-
- # Calculate the max absolute difference with the relative tolerance subtracted
- abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
- a_max_idx = np.unravel_index(
- abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
- )
+ if output_paths:
+ msg += f"\n Saved (intermediate) outputs to {output_paths}."
- a_max = abs_diff[a_max_idx].item()
- a_actual = actual_np[a_max_idx].item()
- a_expected = expected_np[a_max_idx].item()
- except Exception as e:
- msg = f"Error while checking if '{m}' disagrees with expected values: {e}"
+ if mismatched_ppm > mismatched_tol:
add_error_entry(msg)
if stop_early:
break
else:
- if mismatched_elements:
- msg = (
- f"Output '{m}': {mismatched_elements} of "
- + f"{expected_np.size} elements disagree with expected values."
- + f" ({mismatched_ppm:.1f} ppm). "
- )
- else:
- msg = f"Output `{m}`: all elements agree with expected values. "
-
- msg += (
- f"\nMax relative difference not accounted for by absolute tolerance ({atol:.2e}):\n{r_max:.2e}"
- + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
- + f" at {dict(zip(dims, r_max_idx))} "
- + f"\nMax absolute difference not accounted for by relative tolerance ({rtol:.2e}):\n{a_max:.2e}"
- + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
+ add_warning_entry(
+ msg, severity=WARNING if mismatched_ppm != 0 else INFO
)
- if output_paths:
- msg += f"\n Saved (intermediate) outputs to {output_paths}."
-
- if mismatched_ppm > mismatched_tol:
- add_error_entry(msg)
- if stop_early:
- break
- else:
- add_warning_entry(
- msg, severity=WARNING if mismatched_elements else INFO
- )
except Exception as e:
if get_validation_context().raise_errors:
diff --git a/src/bioimageio/core/_sample_serializer.py b/src/bioimageio/core/_sample_serializer.py
new file mode 100644
index 000000000..2d3f461c6
--- /dev/null
+++ b/src/bioimageio/core/_sample_serializer.py
@@ -0,0 +1,85 @@
+from abc import ABC, abstractmethod
+from typing import (
+ Generic,
+ Iterable,
+ Tuple,
+ TypeVar,
+ Union,
+)
+
+from bioimageio.spec.model import v0_5
+
+from .axis import PerAxis
+from .common import HaloLike, PadMode, PerMember
+from .digest_spec import split_sample_into_blocks_for_model
+from .sample import Sample, SampleBlock
+
+SerializedSampleBlockType = TypeVar("SerializedSampleBlockType")
+
+
+class SampleSerializer(ABC, Generic[SerializedSampleBlockType]):
+ @classmethod
+ def serialize_sample(
+ cls,
+ sample: Sample,
+ ) -> Tuple[SerializedSampleBlockType]:
+ """Serialize a sample as a single block"""
+ return (cls.serialize_sample_block(sample.as_single_block()),)
+
+ @classmethod
+ def deserialize_sample(
+ cls,
+ serialized: Iterable[SerializedSampleBlockType],
+ fill_value: float = float("nan"),
+ ) -> Sample:
+ return Sample.from_blocks(
+ (cls.deserialize_sample_block(s) for s in serialized), fill_value=fill_value
+ )
+
+ def serialize_sample_blockwise(
+ self,
+ sample: Sample,
+ *,
+ model: v0_5.ModelDescr,
+ blocksize_parameter: int,
+ batch_size: int = 1,
+ ) -> Iterable[SerializedSampleBlockType]:
+ """Split a sample into blocks according to the model's input specifications and `blocksize_parameter` and serialize each block."""
+
+ _n_blocks, blocks = split_sample_into_blocks_for_model(
+ sample,
+ model=model,
+ blocksize_parameter=blocksize_parameter,
+ batch_size=batch_size,
+ )
+ for block in blocks:
+ yield self.serialize_sample_block(block)
+
+ @classmethod
+ def serialize_sample_with_fixed_blocking(
+ cls,
+ sample: Sample,
+ *,
+ block_shapes: PerMember[PerAxis[int]],
+ halo: PerMember[PerAxis[HaloLike]],
+ pad_mode: Union[PadMode, PerMember[PadMode]] = "symmetric",
+ ) -> Iterable[SerializedSampleBlockType]:
+
+ _n_blocks, input_blocks = sample.split_into_blocks(
+ block_shapes=block_shapes,
+ halo=halo,
+ pad_mode=pad_mode,
+ )
+ for block in input_blocks:
+ yield cls.serialize_sample_block(block)
+
+ @staticmethod
+ @abstractmethod
+ def serialize_sample_block(
+ sample_block: SampleBlock,
+ ) -> SerializedSampleBlockType: ...
+
+ @staticmethod
+ @abstractmethod
+ def deserialize_sample_block(serialized: SerializedSampleBlockType) -> SampleBlock:
+ """Deserialize a sample block into a new sample or merge it into `output_sample` if provided."""
diff --git a/src/bioimageio/core/_settings.py b/src/bioimageio/core/_settings.py
index 6e7675f56..16b460683 100644
--- a/src/bioimageio/core/_settings.py
+++ b/src/bioimageio/core/_settings.py
@@ -45,6 +45,18 @@ def _set_default_mps_fallback(cls, value: Optional[bool]):
)
"""URL to the bioimageio collection config"""
+ gradio_server: Optional[str] = None
+ """URL or Hugging Face space name to connect to with the remote gradio model adapter or remote gradio prediction pipeline.
+
+ Example: "bioimage-io/bioimage-io-gradio-server"
+ """
+
+ gradio_server_model_cache_max_size: int = 10
+ """Max number of models to cache in the gradio server for prediction pipelines using the gradio backend."""
+
+ gradio_server_model_cache_max_memory: str = "40GB"
+ """Max memory to use for model caching in the gradio server for prediction pipelines using the gradio backend."""
+
settings = Settings()
"""parsed environment variables for bioimageio.spec and bioimageio.core"""
diff --git a/src/bioimageio/core/axis.py b/src/bioimageio/core/axis.py
index 07ce7e042..6c85b8b3a 100644
--- a/src/bioimageio/core/axis.py
+++ b/src/bioimageio/core/axis.py
@@ -1,9 +1,15 @@
from __future__ import annotations
from dataclasses import dataclass
-from typing import Literal, Mapping, Optional, TypeVar, Union
+from typing import (
+ Literal,
+ Mapping,
+ Optional,
+ TypeVar,
+ Union,
+)
-from typing_extensions import Protocol, assert_never, runtime_checkable
+from typing_extensions import Protocol, TypeAlias, assert_never, runtime_checkable
from bioimageio.spec.model import v0_5
@@ -33,11 +39,12 @@ def _guess_axis_type(a: str):
S = TypeVar("S", bound=str)
-AxisId = v0_5.AxisId
+AxisId: TypeAlias = v0_5.AxisId
"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'"""
-T = TypeVar("T")
-PerAxis = Mapping[AxisId, T]
+_T = TypeVar("_T")
+PerAxis = Mapping[AxisId, _T]
+
BatchSize = int
diff --git a/src/bioimageio/core/backends/__init__.py b/src/bioimageio/core/backends/__init__.py
index c39b58b58..7fc53d589 100644
--- a/src/bioimageio/core/backends/__init__.py
+++ b/src/bioimageio/core/backends/__init__.py
@@ -1,3 +1,127 @@
-from ._model_adapter import create_model_adapter
+from typing import (
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
-__all__ = ["create_model_adapter"]
+from exceptiongroup import ExceptionGroup
+from typing_extensions import assert_never
+
+from bioimageio.spec.model import v0_4, v0_5
+
+from ..common import SupportedWeightsFormat
+
+# Known weight formats in order of priority
+# First match wins
+DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
+ "pytorch_state_dict",
+ "tensorflow_saved_model_bundle",
+ "torchscript",
+ "onnx",
+ "keras_v3",
+ "keras_hdf5",
+)
+
+
+def create_model_adapter(
+ model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
+ *,
+ devices: Optional[Sequence[str]] = None,
+ weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
+):
+ """Creates model adapter for `model_descritption`"""
+ if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
+ raise TypeError(
+ f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
+ )
+
+ weights = model_description.weights
+ errors: List[Exception] = []
+ weight_format_priority_order = (
+ DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
+ if weight_format_priority_order is None
+ else weight_format_priority_order
+ )
+ # limit weight formats to the ones present
+ weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
+ w for w in weight_format_priority_order if getattr(weights, w, None) is not None
+ ]
+ if not weight_format_priority_order_present:
+ raise ValueError(
+ f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
+ )
+
+ for wf in weight_format_priority_order_present:
+ if wf == "pytorch_state_dict":
+ assert weights.pytorch_state_dict is not None
+ try:
+ from .pytorch_backend import PytorchModelAdapter
+
+ return PytorchModelAdapter(model_description, devices=devices)
+ except Exception as e:
+ errors.append(e)
+ elif wf == "tensorflow_saved_model_bundle":
+ assert weights.tensorflow_saved_model_bundle is not None
+ try:
+ from .tensorflow_backend import create_tf_model_adapter
+
+ return create_tf_model_adapter(model_description, devices=devices)
+ except Exception as e:
+ errors.append(e)
+ elif wf == "onnx":
+ assert weights.onnx is not None
+ try:
+ from .onnx_backend import ONNXModelAdapter
+
+ return ONNXModelAdapter(model_description, devices=devices)
+ except Exception as e:
+ errors.append(e)
+ elif wf == "torchscript":
+ assert weights.torchscript is not None
+ try:
+ from .torchscript_backend import TorchscriptModelAdapter
+
+ return TorchscriptModelAdapter(model_description, devices=devices)
+ except Exception as e:
+ errors.append(e)
+ elif wf == "keras_hdf5":
+ assert weights.keras_hdf5 is not None
+ # keras can either be installed as a separate package or used as part of tensorflow
+ # we try to first import the keras model adapter using the separate package and,
+ # if it is not available, try to load the one using tf
+ try:
+ try:
+ from .keras_backend import KerasModelAdapter
+ except Exception:
+ from .tensorflow_backend import KerasModelAdapter
+
+ return KerasModelAdapter(model_description, devices=devices)
+ except Exception as e:
+ errors.append(e)
+ elif wf == "keras_v3":
+ assert not isinstance(weights, v0_4.WeightsDescr), (
+ "keras_v3 weights not supported for v0.4 specs"
+ )
+ assert weights.keras_v3 is not None
+ try:
+ from .keras_backend import KerasModelAdapter
+
+ return KerasModelAdapter(model_description, devices=devices)
+ except Exception as e:
+ errors.append(e)
+ else:
+ assert_never(wf)
+
+ assert errors
+ if len(weight_format_priority_order) == 1:
+ assert len(errors) == 1
+ raise errors[0]
+
+ else:
+ msg = (
+ "None of the weight format specific model adapters could be created"
+ + " in this environment."
+ )
+ raise ExceptionGroup(msg, errors)
diff --git a/src/bioimageio/core/backends/_model_adapter.py b/src/bioimageio/core/backends/_model_adapter.py
deleted file mode 100644
index e9bfeaa9c..000000000
--- a/src/bioimageio/core/backends/_model_adapter.py
+++ /dev/null
@@ -1,270 +0,0 @@
-import warnings
-from abc import ABC, abstractmethod
-from typing import (
- Any,
- List,
- Optional,
- Sequence,
- Tuple,
- Union,
- final,
-)
-
-from exceptiongroup import ExceptionGroup
-from loguru import logger
-from numpy.typing import NDArray
-from typing_extensions import assert_never
-
-from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
-
-from ..common import SupportedWeightsFormat
-from ..digest_spec import get_axes_infos, get_member_ids
-from ..sample import Sample, SampleBlock
-from ..tensor import Tensor
-
-# Known weight formats in order of priority
-# First match wins
-DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
- "pytorch_state_dict",
- "tensorflow_saved_model_bundle",
- "torchscript",
- "onnx",
- "keras_v3",
- "keras_hdf5",
-)
-
-
-class ModelAdapter(ABC):
- """
- Represents model *without* any preprocessing or postprocessing.
-
- ```
- from bioimageio.core import load_description
-
- model = load_description(...)
-
- # option 1:
- adapter = ModelAdapter.create(model)
- adapter.forward(...)
- adapter.unload()
-
- # option 2:
- with ModelAdapter.create(model) as adapter:
- adapter.forward(...)
- ```
- """
-
- def __init__(self, model_description: AnyModelDescr):
- super().__init__()
- self._model_descr = model_description
- self._input_ids = get_member_ids(model_description.inputs)
- self._output_ids = get_member_ids(model_description.outputs)
- self._input_axes = [
- tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
- ]
- self._output_axes = [
- tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
- ]
- if isinstance(model_description, v0_4.ModelDescr):
- self._input_is_optional = [False] * len(model_description.inputs)
- else:
- self._input_is_optional = [ipt.optional for ipt in model_description.inputs]
-
- @final
- @classmethod
- def create(
- cls,
- model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
- *,
- devices: Optional[Sequence[str]] = None,
- weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
- ):
- """
- Creates model adapter based on the passed spec
- Note: All specific adapters should happen inside this function to prevent different framework
- initializations interfering with each other
- """
- if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
- raise TypeError(
- f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
- )
-
- weights = model_description.weights
- errors: List[Exception] = []
- weight_format_priority_order = (
- DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
- if weight_format_priority_order is None
- else weight_format_priority_order
- )
- # limit weight formats to the ones present
- weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
- w
- for w in weight_format_priority_order
- if getattr(weights, w, None) is not None
- ]
- if not weight_format_priority_order_present:
- raise ValueError(
- f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
- )
-
- for wf in weight_format_priority_order_present:
- if wf == "pytorch_state_dict":
- assert weights.pytorch_state_dict is not None
- try:
- from .pytorch_backend import PytorchModelAdapter
-
- return PytorchModelAdapter(
- model_description=model_description, devices=devices
- )
- except Exception as e:
- errors.append(e)
- elif wf == "tensorflow_saved_model_bundle":
- assert weights.tensorflow_saved_model_bundle is not None
- try:
- from .tensorflow_backend import create_tf_model_adapter
-
- return create_tf_model_adapter(
- model_description=model_description, devices=devices
- )
- except Exception as e:
- errors.append(e)
- elif wf == "onnx":
- assert weights.onnx is not None
- try:
- from .onnx_backend import ONNXModelAdapter
-
- return ONNXModelAdapter(
- model_description=model_description, devices=devices
- )
- except Exception as e:
- errors.append(e)
- elif wf == "torchscript":
- assert weights.torchscript is not None
- try:
- from .torchscript_backend import TorchscriptModelAdapter
-
- return TorchscriptModelAdapter(
- model_description=model_description, devices=devices
- )
- except Exception as e:
- errors.append(e)
- elif wf == "keras_hdf5":
- assert weights.keras_hdf5 is not None
- # keras can either be installed as a separate package or used as part of tensorflow
- # we try to first import the keras model adapter using the separate package and,
- # if it is not available, try to load the one using tf
- try:
- try:
- from .keras_backend import KerasModelAdapter
- except Exception:
- from .tensorflow_backend import KerasModelAdapter
-
- return KerasModelAdapter(
- model_description=model_description, devices=devices
- )
- except Exception as e:
- errors.append(e)
- elif wf == "keras_v3":
- assert not isinstance(weights, v0_4.WeightsDescr), (
- "keras_v3 weights not supported for v0.4 specs"
- )
- assert weights.keras_v3 is not None
- try:
- from .keras_backend import KerasModelAdapter
-
- return KerasModelAdapter(
- model_description=model_description, devices=devices
- )
- except Exception as e:
- errors.append(e)
- else:
- assert_never(wf)
-
- assert errors
- if len(weight_format_priority_order) == 1:
- assert len(errors) == 1
- raise errors[0]
-
- else:
- msg = (
- "None of the weight format specific model adapters could be created"
- + " in this environment."
- )
- raise ExceptionGroup(msg, errors)
-
- @final
- def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
- warnings.warn("Deprecated. ModelAdapter is loaded on initialization")
-
- def forward(self, input_sample: Union[Sample, SampleBlock]) -> Sample:
- """
- Run forward pass of model to get model predictions
-
- Note: sample id and stample stat attributes are passed through
- """
- unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
- if unexpected:
- warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")
-
- input_arrays = [
- (
- None
- if (a := input_sample.members.get(in_id)) is None
- else a.transpose(in_order).data.data
- )
- for in_id, in_order in zip(self._input_ids, self._input_axes)
- ]
- logger.debug(
- "NN input shapes: {}",
- [a.shape if a is not None else None for a in input_arrays],
- )
- output_arrays = self._forward_impl(input_arrays)
- logger.debug(
- "NN output shapes: {}",
- [a.shape if a is not None else None for a in output_arrays],
- )
- if len(output_arrays) > len(self._output_ids):
- warnings.warn(
- f"Model produced more outputs ({len(output_arrays)}) than specified in the model description ({len(self._output_ids)}). Extra outputs will be ignored."
- )
- output_arrays = output_arrays[: len(self._output_ids)]
-
- output_tensors = [
- None if a is None else Tensor(a, dims=d)
- for a, d in zip(output_arrays, self._output_axes)
- ]
- return Sample(
- members={
- tid: out
- for tid, out in zip(
- self._output_ids,
- output_tensors,
- )
- if out is not None
- },
- stat=input_sample.stat,
- id=(
- input_sample.id
- if isinstance(input_sample, Sample)
- else input_sample.sample_id
- ),
- )
-
- @abstractmethod
- def _forward_impl(
- self, input_arrays: Sequence[Optional[NDArray[Any]]]
- ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]:
- """framework specific forward implementation"""
-
- @abstractmethod
- def unload(self):
- """
- Unload model from any devices, freeing their memory.
- The moder adapter should be considered unusable afterwards.
- """
-
- def _get_input_args_numpy(self, input_sample: Sample):
- """helper to extract tensor args as transposed numpy arrays"""
-
-
-create_model_adapter = ModelAdapter.create
diff --git a/src/bioimageio/core/backends/keras_backend.py b/src/bioimageio/core/backends/keras_backend.py
index 9e56c6d19..fbdfe498d 100644
--- a/src/bioimageio/core/backends/keras_backend.py
+++ b/src/bioimageio/core/backends/keras_backend.py
@@ -2,7 +2,7 @@
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Optional, Sequence, Tuple
from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs]
legacy_h5_format,
@@ -11,13 +11,12 @@
from numpy.typing import NDArray
from bioimageio.spec._internal.version_type import Version
-from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec.model import v0_4
+from .._model_adapter import LocalModelAdapter
from .._settings import settings
-from ..digest_spec import get_axes_infos
from ..utils._compare import warn_about_version
from ..utils._type_guards import is_list, is_tuple
-from ._model_adapter import ModelAdapter
os.environ["KERAS_BACKEND"] = settings.keras_backend
@@ -35,25 +34,27 @@
tf_version = None
-class KerasModelAdapter(ModelAdapter):
- def __init__(
- self,
- *,
- model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
- devices: Optional[Sequence[str]] = None,
- ) -> None:
- super().__init__(model_description=model_description)
+class KerasModelAdapter(LocalModelAdapter[None, Any]):
+ def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]:
+ # TODO keras device management
+ if devices is not None:
+ logger.warning(
+ "Device management is not implemented for keras yet, ignoring the devices {}",
+ devices,
+ )
+ return (None,)
+ def _init_model_on_device(self, device: None) -> Any:
if (
- not isinstance(model_description, v0_4.ModelDescr)
- and model_description.weights.keras_v3 is not None
+ not isinstance(self._model_descr, v0_4.ModelDescr)
+ and self._model_descr.weights.keras_v3 is not None
):
- weight_reader = model_description.weights.keras_v3.get_reader()
- backend, backend_version = model_description.weights.keras_v3.backend
- elif model_description.weights.keras_hdf5 is not None:
+ weight_reader = self._model_descr.weights.keras_v3.get_reader()
+ backend, backend_version = self._model_descr.weights.keras_v3.backend
+ elif self._model_descr.weights.keras_hdf5 is not None:
backend = "legacy_tensorflow"
- backend_version = model_description.weights.keras_hdf5.tensorflow_version
- weight_reader = model_description.weights.keras_hdf5.get_reader()
+ backend_version = self._model_descr.weights.keras_hdf5.tensorflow_version
+ weight_reader = self._model_descr.weights.keras_hdf5.get_reader()
else:
raise ValueError("model has no Keras weights")
@@ -81,41 +82,33 @@ def __init__(
jax_version = Version(jax.__version__)
warn_about_version("jax", backend_version, jax_version)
- # TODO keras device management
- if devices is not None:
- logger.warning(
- "Device management is not implemented for keras yet, ignoring the devices {}",
- devices,
- )
-
if weight_reader.suffix in (".h5", "hdf5"):
import h5py # pyright: ignore[reportMissingTypeStubs]
h5_file = h5py.File(weight_reader, mode="r")
- self._network = legacy_h5_format.load_model_from_hdf5(h5_file)
+ return legacy_h5_format.load_model_from_hdf5(h5_file) # pyright: ignore[reportUnknownVariableType]
else:
with TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir) / weight_reader.original_file_name
with temp_path.open("wb") as f:
shutil.copyfileobj(weight_reader, f)
- self._network = keras.models.load_model(temp_path)
-
- self._output_axes = [
- tuple(a.id for a in get_axes_infos(out))
- for out in model_description.outputs
- ]
+ return keras.models.load_model(temp_path) # pyright: ignore[reportUnknownVariableType]
- def _forward_impl( # pyright: ignore[reportUnknownParameterType]
- self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ def _forward_impl(
+ self,
+ device: None,
+ model: Any,
+ input_arrays: Sequence[Optional[NDArray[Any]]],
):
- network_output = self._network.predict(*input_arrays) # type: ignore
+ network_output = model.predict(*input_arrays)
if is_list(network_output) or is_tuple(network_output):
return network_output
else:
- return [network_output] # pyright: ignore[reportUnknownVariableType]
+ return [network_output]
+
+ def _cleanup_pre_model_deletion(self, device: None, model: Any) -> None:
+ return
- def unload(self) -> None:
- logger.warning(
- "Device management is not implemented for keras yet, cannot unload model"
- )
+ def _cleanup_post_model_deletion(self, device: None) -> None:
+ return
diff --git a/src/bioimageio/core/backends/onnx_backend.py b/src/bioimageio/core/backends/onnx_backend.py
index 6a752f8fb..0dab80c12 100644
--- a/src/bioimageio/core/backends/onnx_backend.py
+++ b/src/bioimageio/core/backends/onnx_backend.py
@@ -1,35 +1,37 @@
# pyright: reportUnknownVariableType=false
import shutil
import tempfile
-import warnings
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Any, List, Optional, Sequence, Union, cast
import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
-from exceptiongroup import ExceptionGroup
from loguru import logger
from numpy.typing import NDArray
from bioimageio.spec.model import v0_4, v0_5
-from ..model_adapters import ModelAdapter
+from .._model_adapter import LocalModelAdapter
from ..utils._type_guards import is_list, is_tuple
-class ONNXModelAdapter(ModelAdapter):
+class ONNXModelAdapter(LocalModelAdapter[Optional[str], rt.InferenceSession]):
def __init__(
self,
- *,
model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
devices: Optional[Sequence[str]] = None,
):
- super().__init__(model_description=model_description)
-
onnx_descr = model_description.weights.onnx
if onnx_descr is None:
raise ValueError("No ONNX weights specified for {model_description.name}")
+ self._onnx_descr = onnx_descr
+ self._input_names: Optional[List[str]] = None
+ super().__init__(model_description=model_description, devices=devices)
+
+ def _parse_devices(
+ self, devices: Optional[Sequence[str]]
+ ) -> Sequence[Optional[str]]:
available_providers: Any = None
if hasattr(rt, "get_available_providers"):
available_providers = cast(Any, rt.get_available_providers())
@@ -40,8 +42,35 @@ def __init__(
else:
providers = available_providers
else:
+ available_providers = [available_providers]
providers = [available_providers]
+ if devices is not None:
+ available_devices = [d for d in devices if d in providers]
+ unavailable_devices = [d for d in devices if d not in providers]
+ if available_devices:
+ if unavailable_devices:
+ logger.warning(
+ "The following requested devices are not available for ONNX Runtime and will be ignored: {}.\nSelected available providers/devices are: {}\nOther available providers are: {}",
+ unavailable_devices,
+ available_devices,
+ [p for p in providers if p not in devices],
+ )
+
+ providers = available_devices
+ elif not available_providers:
+ logger.error(
+ "ONNX Runtime does not report any available providers. Attempting to load model with default providers, but this will likely fail."
+ )
+ else:
+ logger.warning(
+ "None of the requested devices are available for ONNX Runtime, falling back to default, available providers: {}",
+ available_providers,
+ )
+ return providers
+
+ def _init_model_on_device(self, device: Optional[str]) -> rt.InferenceSession:
+ onnx_descr = self._onnx_descr
if (
isinstance(onnx_descr, v0_5.OnnxWeightsDescr)
and onnx_descr.external_data is not None
@@ -88,51 +117,30 @@ def source_context_func():
with source_context as s:
assert isinstance(s, bytes) or s.exists()
+ session = rt.InferenceSession(
+ s,
+ providers=None if device is None else [device],
+ )
- # try providers in order until one works
- # TODO: check if issue with backup providers is fixed and evaluate handing over all available providers
- # currently (onnxruntime 1.23.2) if a higher priority providers fails a RUNTIME_EXCEPTION may be raised
- # stating 'model_path must not be empty' instead of trying the next provider, see # TODO: reference issue
- provider_exceptions: List[Exception] = []
- for p in providers:
- try:
- self._session = rt.InferenceSession(
- s,
- providers=None if p is None else [p],
- )
- except Exception as e:
- provider_exceptions.append(e)
- else:
- for bad_p, e in zip(
- providers[: len(provider_exceptions)], provider_exceptions
- ):
- logger.warning(
- "Failed to load ONNX model with provider {}: {}",
- bad_p,
- e,
- )
-
- break
- else:
- raise ExceptionGroup(
- "Failed to load ONNX model with any of the available providers.",
- provider_exceptions,
- )
-
- onnx_inputs = self._session.get_inputs()
- self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]
-
- if devices is not None:
- warnings.warn(
- f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
+ onnx_inputs = session.get_inputs()
+ onnx_input_names = [str(ipt.name) for ipt in onnx_inputs] # pyright: ignore[reportUnknownArgumentType]
+ if self._input_names is None:
+ self._input_names = onnx_input_names
+ elif self._input_names != onnx_input_names:
+ raise RuntimeError(
+ f"Input names of the ONNX model {onnx_input_names} do not match expected input names {self._input_names} from previous model initialization."
)
+ return session
+
def _forward_impl(
- self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ self,
+ device: Optional[str],
+ model: rt.InferenceSession,
+ input_arrays: Sequence[Optional[NDArray[Any]]],
) -> List[Optional[NDArray[Any]]]:
- result: Any = self._session.run(
- None, dict(zip(self._input_names, input_arrays))
- )
+ assert self._input_names is not None, "set during model initialization"
+ result: Any = model.run(None, dict(zip(self._input_names, input_arrays)))
if is_list(result) or is_tuple(result):
result_seq = list(result)
else:
@@ -140,7 +148,10 @@ def _forward_impl(
return result_seq
- def unload(self) -> None:
- warnings.warn(
- "Device management is not implemented for onnx yet, cannot unload model"
- )
+ def _cleanup_pre_model_deletion(
+ self, device: Optional[str], model: rt.InferenceSession
+ ) -> None:
+ return
+
+ def _cleanup_post_model_deletion(self, device: Optional[str]) -> None:
+ return
diff --git a/src/bioimageio/core/backends/pytorch_backend.py b/src/bioimageio/core/backends/pytorch_backend.py
index 9ca5903d1..2d4634522 100644
--- a/src/bioimageio/core/backends/pytorch_backend.py
+++ b/src/bioimageio/core/backends/pytorch_backend.py
@@ -1,5 +1,4 @@
import gc
-import warnings
from abc import abstractmethod
from contextlib import nullcontext
from io import BytesIO, TextIOWrapper
@@ -17,9 +16,9 @@
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.utils import download
+from .._model_adapter import LocalModelAdapter
from ..digest_spec import import_callable
from ..utils._type_guards import is_list, is_ndarray, is_tuple
-from ._model_adapter import ModelAdapter
@runtime_checkable
@@ -48,37 +47,46 @@ def eval(self) -> Self:
return self
-class PytorchModelAdapter(ModelAdapter):
+class PytorchModelAdapter(LocalModelAdapter[torch.device, nn.Module]):
def __init__(
self,
- *,
model_description: AnyModelDescr,
- devices: Optional[Sequence[Union[str, torch.device]]] = None,
mode: Literal["eval", "train"] = "eval",
+ devices: Optional[Sequence[str]] = None,
):
- super().__init__(model_description=model_description)
weights = model_description.weights.pytorch_state_dict
if weights is None:
raise ValueError("No `pytorch_state_dict` weights found")
- devices = get_devices(devices)
- self._model = load_torch_model(weights, load_state=True, devices=devices)
- if mode == "eval":
- self._model = self._model.eval()
- elif mode == "train":
- self._model = self._model.train()
+ self._weights = weights
+ self._mode: Literal["eval", "train"] = mode
+ super().__init__(model_description=model_description, devices=devices)
+
+ def _parse_devices(
+ self, devices: Optional[Sequence[str]]
+ ) -> Sequence[torch.device]:
+ return get_devices(devices)
+
+ def _init_model_on_device(self, device: torch.device) -> nn.Module:
+ model = load_torch_model(self._weights, load_state=True, devices=[device])
+
+ if self._mode == "eval":
+ model = model.eval()
+ elif self._mode == "train":
+ model = model.train()
else:
- assert_never(mode)
+ assert_never(self._mode)
- self._mode: Literal["eval", "train"] = mode
- self._primary_device = devices[0]
+ return model
def _forward_impl(
- self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ self,
+ device: torch.device,
+ model: nn.Module,
+ input_arrays: Sequence[Optional[NDArray[Any]]],
) -> List[Optional[NDArray[Any]]]:
tensors = [
- None if a is None else torch.from_numpy(a).to(self._primary_device)
- for a in input_arrays
+ None if a is None else torch.from_numpy(a).to(device) for a in input_arrays
]
if self._mode == "eval":
@@ -89,7 +97,7 @@ def _forward_impl(
assert_never(self._mode)
with ctxt():
- model_out = self._model(*tensors)
+ model_out = model(*tensors)
if is_tuple(model_out) or is_list(model_out):
model_out_seq = model_out
@@ -112,11 +120,15 @@ def _forward_impl(
return result
- def unload(self) -> None:
- del self._model
+ def _cleanup_pre_model_deletion(
+ self, device: torch.device, model: nn.Module
+ ) -> None:
+ return
+
+ def _cleanup_post_model_deletion(self, device: torch.device) -> None:
_ = gc.collect() # deallocate memory
- assert torch is not None
- torch.cuda.empty_cache() # release reserved memory
+ if device.type == "cuda":
+ torch.cuda.empty_cache() # release reserved memory
def load_torch_model(
@@ -232,18 +244,24 @@ def get_devices(
) -> List[torch.device]:
if not devices:
if torch.cuda.is_available():
- torch_devices = [torch.device("cuda")]
+ torch_devices = [
+ torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())
+ ]
elif torch.backends.mps.is_available():
torch_devices = [torch.device("mps")]
else:
- torch_devices = [torch.device("cpu")]
+ try:
+ if (
+ torch.accelerator.is_available()
+ and (current_accelerator := torch.accelerator.current_accelerator())
+ is not None
+ ):
+ torch_devices = [current_accelerator]
+ else:
+ torch_devices = [torch.device("cpu")]
+ except Exception:
+ torch_devices = [torch.device("cpu")]
else:
torch_devices = [torch.device(d) for d in devices]
- if len(torch_devices) > 1:
- warnings.warn(
- f"Multiple devices for pytorch model not yet implemented; ignoring {torch_devices[1:]}"
- )
- torch_devices = torch_devices[:1]
-
return torch_devices
diff --git a/src/bioimageio/core/backends/tensorflow_backend.py b/src/bioimageio/core/backends/tensorflow_backend.py
index 8d13a7822..263dc8783 100644
--- a/src/bioimageio/core/backends/tensorflow_backend.py
+++ b/src/bioimageio/core/backends/tensorflow_backend.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import Any, Optional, Sequence, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
import tensorflow as tf
@@ -8,42 +8,49 @@
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
+from .._model_adapter import LocalModelAdapter
from ..io import ensure_unzipped
-from ._model_adapter import ModelAdapter
-class TensorflowModelAdapter(ModelAdapter):
+class TensorflowModelAdapter(LocalModelAdapter[None, Any]):
+ """Adapter for TensorFlow 1 models"""
+
weight_format = "tensorflow_saved_model_bundle"
def __init__(
self,
- *,
model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
devices: Optional[Sequence[str]] = None,
):
- super().__init__(model_description=model_description)
- weight_file = model_description.weights.tensorflow_saved_model_bundle
if model_description.weights.tensorflow_saved_model_bundle is None:
raise ValueError("No `tensorflow_saved_model_bundle` weights found")
+ if isinstance(model_description, v0_4.ModelDescr):
+ self._weight_src = (
+ model_description.weights.tensorflow_saved_model_bundle.source
+ )
+ else:
+ self._weight_src = model_description.weights.tensorflow_saved_model_bundle
+
+ self._graph = None
+ self._io_names: Optional[Tuple[List[str], List[str]]] = None
+ super().__init__(model_description=model_description, devices=devices)
+
+ def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]:
if devices is not None:
logger.warning(
f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
)
+ return (None,)
+
+ def _init_model_on_device(self, device: Optional[str]) -> Any:
# TODO: check how to load tf weights without unzipping
weight_file = ensure_unzipped(
- model_description.weights.tensorflow_saved_model_bundle.source,
- Path("bioimageio_unzipped_tf_weights"),
+ self._weight_src, Path("bioimageio_unzipped_tf_weights")
)
- self._network = str(weight_file)
- # TODO currently we relaod the model every time. it would be better to keep the graph and session
- # alive in between of forward passes (but then the sessions need to be properly opened / closed)
- def _forward_impl( # pyright: ignore[reportUnknownParameterType]
- self, input_arrays: Sequence[Optional[NDArray[Any]]]
- ):
# TODO read from spec
tag = ( # pyright: ignore[reportUnknownVariableType]
tf.saved_model.tag_constants.SERVING
@@ -52,92 +59,94 @@ def _forward_impl( # pyright: ignore[reportUnknownParameterType]
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
)
- graph = tf.Graph()
- with graph.as_default():
- with tf.Session(graph=graph) as sess: # pyright: ignore[reportUnknownVariableType]
- # load the model and the signature
- graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType]
- sess, [tag], self._network
- )
- signature = ( # pyright: ignore[reportUnknownVariableType]
- graph_def.signature_def
- )
+ self._graph = tf.Graph()
+ with self._graph.as_default():
+ sess = tf.Session(graph=self._graph) # pyright: ignore[reportUnknownVariableType]
+ # load the model and the signature
+ graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType]
+ sess, [tag], str(weight_file)
+ )
+ signature = ( # pyright: ignore[reportUnknownVariableType]
+ graph_def.signature_def
+ )
- # get the tensors into the graph
- in_names = [ # pyright: ignore[reportUnknownVariableType]
- signature[signature_key].inputs[key].name for key in self._input_ids
- ]
- out_names = [ # pyright: ignore[reportUnknownVariableType]
- signature[signature_key].outputs[key].name
- for key in self._output_ids
- ]
- in_tf_tensors = [
- graph.get_tensor_by_name(
- name # pyright: ignore[reportUnknownArgumentType]
- )
- for name in in_names # pyright: ignore[reportUnknownVariableType]
- ]
- out_tf_tensors = [
- graph.get_tensor_by_name(
- name # pyright: ignore[reportUnknownArgumentType]
- )
- for name in out_names # pyright: ignore[reportUnknownVariableType]
- ]
-
- # run prediction
- res = sess.run( # pyright: ignore[reportUnknownVariableType]
- dict(
- zip(
- out_names, # pyright: ignore[reportUnknownArgumentType]
- out_tf_tensors,
- )
- ),
- dict(zip(in_tf_tensors, input_arrays)),
- )
- # from dict to list of tensors
- res = [ # pyright: ignore[reportUnknownVariableType]
- res[out]
- for out in out_names # pyright: ignore[reportUnknownVariableType]
- ]
+ # get the tensors into the graph
+ in_names = [ # pyright: ignore[reportUnknownVariableType]
+ signature[signature_key].inputs[key].name for key in self._input_ids
+ ]
+ out_names = [ # pyright: ignore[reportUnknownVariableType]
+ signature[signature_key].outputs[key].name for key in self._output_ids
+ ]
+ self._io_names = (in_names, out_names)
- return res # pyright: ignore[reportUnknownVariableType]
+ return sess # pyright: ignore[reportUnknownVariableType]
- def unload(self) -> None:
- logger.warning(
- "Device management is not implemented for tensorflow 1, cannot unload model"
+ def _forward_impl(
+ self, device: None, model: Any, input_arrays: Sequence[Optional[NDArray[Any]]]
+ ):
+ assert self._io_names is not None
+ assert self._graph is not None
+
+ in_names, out_names = self._io_names
+ in_tf_tensors = [self._graph.get_tensor_by_name(name) for name in in_names]
+ out_tf_tensors = [self._graph.get_tensor_by_name(name) for name in out_names]
+
+ # run prediction
+ res = model.run(
+ dict(zip(out_names, out_tf_tensors)),
+ dict(zip(in_tf_tensors, input_arrays)),
)
+ # from dict to list of tensors
+ res = [res[out] for out in out_names]
+
+ return res
+
+ def _cleanup_pre_model_deletion(self, device: Optional[str], model: Any) -> None:
+ return
+
+ def _cleanup_post_model_deletion(self, device: Optional[str]) -> None:
+ return
-class KerasModelAdapter(ModelAdapter):
+class KerasModelAdapter(LocalModelAdapter[None, Any]):
def __init__(
self,
- *,
model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
devices: Optional[Sequence[str]] = None,
):
if model_description.weights.tensorflow_saved_model_bundle is None:
raise ValueError("No `tensorflow_saved_model_bundle` weights found")
- super().__init__(model_description=model_description)
+ if isinstance(model_description, v0_4.ModelDescr):
+ self._weight_src = (
+ model_description.weights.tensorflow_saved_model_bundle.source
+ )
+ else:
+ self._weight_src = model_description.weights.tensorflow_saved_model_bundle
+
+ super().__init__(model_description=model_description, devices=devices)
+
+ def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]:
if devices is not None:
logger.warning(
f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
)
+ return (None,)
+ def _init_model_on_device(self, device: None) -> Any:
# TODO: check how to load tf weights without unzipping
- weight_file = ensure_unzipped(
- model_description.weights.tensorflow_saved_model_bundle.source,
- Path("bioimageio_unzipped_tf_weights"),
+ weight_file = str(
+ ensure_unzipped(self._weight_src, Path("bioimageio_unzipped_tf_weights"))
)
try:
- self._network = tf.keras.layers.TFSMLayer(
+ tfsm_layer = tf.keras.layers.TFSMLayer( # pyright: ignore[reportUnknownVariableType]
weight_file,
call_endpoint="serve",
)
except Exception as e:
try:
- self._network = tf.keras.layers.TFSMLayer(
+ tfsm_layer = tf.keras.layers.TFSMLayer( # pyright: ignore[reportUnknownVariableType]
weight_file, call_endpoint="serving_default"
)
except Exception as ee:
@@ -146,16 +155,16 @@ def __init__(
)
raise e
+ return tfsm_layer # pyright: ignore[reportUnknownVariableType]
+
def _forward_impl( # pyright: ignore[reportUnknownParameterType]
- self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ self, device: None, model: Any, input_arrays: Sequence[Optional[NDArray[Any]]]
):
assert tf is not None
tf_tensor = [
None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_arrays
]
-
- result = self._network(*tf_tensor) # pyright: ignore[reportUnknownVariableType]
-
+ result = model(*tf_tensor)
assert isinstance(result, dict)
# TODO: Use RDF's `outputs[i].id` here
@@ -168,15 +177,15 @@ def _forward_impl( # pyright: ignore[reportUnknownParameterType]
for r in result # pyright: ignore[reportUnknownVariableType]
]
- def unload(self) -> None:
- logger.warning(
- "Device management is not implemented for tensorflow>=2 models"
- + f" using `{self.__class__.__name__}`, cannot unload model"
- )
+ def _cleanup_pre_model_deletion(self, device: Optional[str], model: Any) -> None:
+ return
+
+ def _cleanup_post_model_deletion(self, device: Optional[str]) -> None:
+ return
def create_tf_model_adapter(
- model_description: AnyModelDescr, devices: Optional[Sequence[str]]
+ model_description: AnyModelDescr, devices: Optional[Sequence[str]] = None
):
tf_version = v0_5.Version(tf.__version__) # type: ignore[reportUnknownVariableType]
weights = model_description.weights.tensorflow_saved_model_bundle
@@ -203,8 +212,6 @@ def create_tf_model_adapter(
)
if tf_version.major <= 1:
- return TensorflowModelAdapter(
- model_description=model_description, devices=devices
- )
+ return TensorflowModelAdapter(model_description, devices=devices)
else:
- return KerasModelAdapter(model_description=model_description, devices=devices)
+ return KerasModelAdapter(model_description, devices=devices)
diff --git a/src/bioimageio/core/backends/torchscript_backend.py b/src/bioimageio/core/backends/torchscript_backend.py
index 4e7fe0092..f1f81ed35 100644
--- a/src/bioimageio/core/backends/torchscript_backend.py
+++ b/src/bioimageio/core/backends/torchscript_backend.py
@@ -3,45 +3,57 @@
from typing import Any, List, Optional, Sequence, Union
import torch
+from loguru import logger
from numpy.typing import NDArray
from bioimageio.spec.model import v0_4, v0_5
-from ..model_adapters import ModelAdapter
+from .._model_adapter import LocalModelAdapter
from ..utils._type_guards import is_list, is_tuple
from .pytorch_backend import get_devices
-class TorchscriptModelAdapter(ModelAdapter):
+class TorchscriptModelAdapter(LocalModelAdapter[torch.device, Any]):
def __init__(
self,
- *,
model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
devices: Optional[Sequence[str]] = None,
):
- super().__init__(model_description=model_description)
if model_description.weights.torchscript is None:
raise ValueError(
f"No torchscript weights found for model {model_description.name}"
)
- self.devices = get_devices(devices)
+ self._weight_descr = model_description.weights.torchscript
+ super().__init__(model_description=model_description, devices=devices)
- weight_reader = model_description.weights.torchscript.get_reader()
- self._model = torch.jit.load(weight_reader)
+ def _parse_devices(
+ self, devices: Optional[Sequence[str]]
+ ) -> Sequence[torch.device]:
+ return get_devices(devices)
- self._model.to(self.devices[0])
- self._model = self._model.eval()
+ def _init_model_on_device(self, device: torch.device) -> Any:
+ model = torch.jit.load(self._weight_descr.get_reader(), map_location=device)
+ try:
+ model.eval()
+ except Exception as e:
+ logger.warning(
+ f"Failed to set model to evaluation mode for torchscript model on {device}: {e}"
+ )
+ return model
def _forward_impl(
- self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ self,
+ device: torch.device,
+ model: Any,
+ input_arrays: Sequence[Optional[NDArray[Any]]],
) -> List[Optional[NDArray[Any]]]:
with torch.no_grad():
torch_tensor = [
- None if a is None else torch.from_numpy(a).to(self.devices[0])
+ None if a is None else torch.from_numpy(a).to(device)
for a in input_arrays
]
- output: Any = self._model.forward(*torch_tensor)
+ output: Any = model.forward(*torch_tensor)
if is_list(output) or is_tuple(output):
output_seq: Sequence[Any] = output
else:
@@ -58,8 +70,10 @@ def _forward_impl(
for r in output_seq
]
- def unload(self) -> None:
- self._devices = None
- del self._model
+ def _cleanup_pre_model_deletion(self, device: torch.device, model: Any) -> None:
+ return
+
+ def _cleanup_post_model_deletion(self, device: torch.device) -> None:
_ = gc.collect() # deallocate memory
- torch.cuda.empty_cache() # release reserved memory
+ if device.type == "cuda":
+ torch.cuda.empty_cache() # release reserved memory
diff --git a/src/bioimageio/core/block.py b/src/bioimageio/core/block.py
index d1d9d7d80..69af9118a 100644
--- a/src/bioimageio/core/block.py
+++ b/src/bioimageio/core/block.py
@@ -87,6 +87,15 @@ def from_meta(cls, meta: BlockMeta, data: Tensor) -> Self:
data=data,
)
+ def get_meta(self) -> BlockMeta:
+ return BlockMeta(
+ sample_shape=self.sample_shape,
+ inner_slice=self.inner_slice,
+ halo=self.halo,
+ block_index=self.block_index,
+ blocks_in_sample=self.blocks_in_sample,
+ )
+
def split_tensor_into_blocks(
tensor: Tensor,
diff --git a/src/bioimageio/core/block_meta.py b/src/bioimageio/core/block_meta.py
index f2849e02c..d4ab3c873 100644
--- a/src/bioimageio/core/block_meta.py
+++ b/src/bioimageio/core/block_meta.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass
from functools import cached_property
from math import floor, prod
+from types import MappingProxyType
from typing import (
Any,
Callable,
@@ -15,13 +16,14 @@
Union,
)
+import pydantic
from loguru import logger
from typing_extensions import Self
+from ._axis_annotations import PerAxisAnno
from .axis import AxisId, PerAxis
from .common import (
BlockIndex,
- Frozen,
Halo,
HaloLike,
MemberId,
@@ -42,7 +44,7 @@ def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
return round(s * self.scale) + self.offset
-@dataclass(frozen=True)
+@pydantic.dataclasses.dataclass(frozen=True)
class BlockMeta:
"""Block meta data of a sample member (a tensor in a sample)
@@ -76,13 +78,13 @@ class BlockMeta:
"""
- sample_shape: PerAxis[int]
+ sample_shape: PerAxisAnno[int]
"""the axis sizes of the whole (unblocked) sample"""
- inner_slice: PerAxis[SliceInfo]
+ inner_slice: PerAxisAnno[SliceInfo]
"""inner region (without halo) wrt the sample"""
- halo: PerAxis[Halo]
+ halo: PerAxisAnno[Halo]
"""halo enlarging the inner region to the block's sizes"""
block_index: BlockIndex
@@ -94,7 +96,7 @@ class BlockMeta:
@cached_property
def shape(self) -> PerAxis[int]:
"""axis lengths of the block"""
- return Frozen(
+ return MappingProxyType(
{
a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
for a, s in self.inner_slice.items()
@@ -105,7 +107,7 @@ def shape(self) -> PerAxis[int]:
def padding(self) -> PerAxis[PadWidth]:
"""padding to realize the halo at the sample edge
where we cannot simply enlarge the inner slice"""
- return Frozen(
+ return MappingProxyType(
{
a: PadWidth(
(
@@ -128,7 +130,7 @@ def padding(self) -> PerAxis[PadWidth]:
@cached_property
def outer_slice(self) -> PerAxis[SliceInfo]:
"""slice of the outer block (without padding) wrt the sample"""
- return Frozen(
+ return MappingProxyType(
{
a: SliceInfo(
max(
@@ -154,16 +156,22 @@ def outer_slice(self) -> PerAxis[SliceInfo]:
@cached_property
def inner_shape(self) -> PerAxis[int]:
"""axis lengths of the inner region (without halo)"""
- return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})
+ return MappingProxyType(
+ {a: s.stop - s.start for a, s in self.inner_slice.items()}
+ )
@cached_property
def local_slice(self) -> PerAxis[SliceInfo]:
"""inner slice wrt the block, **not** the sample"""
- return Frozen(
+ return MappingProxyType(
{
- a: SliceInfo(
- self.halo[a].left,
- self.halo[a].left + self.inner_shape[a],
+ a: (
+ SliceInfo(
+ h.left,
+ h.left + self.inner_shape[a],
+ )
+ if (h := self.halo.get(a)) is not None
+ else SliceInfo(0, self.inner_shape[a])
)
for a in self.inner_slice
}
@@ -189,16 +197,6 @@ def inner_slice_wo_overlap(self) -> PerAxis[SliceInfo]:
return self.inner_slice
def __post_init__(self):
- # freeze mutable inputs
- if not isinstance(self.sample_shape, Frozen):
- object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))
-
- if not isinstance(self.inner_slice, Frozen):
- object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))
-
- if not isinstance(self.halo, Frozen):
- object.__setattr__(self, "halo", Frozen(self.halo))
-
assert all(a in self.sample_shape for a in self.inner_slice), (
"block has axes not present in sample"
)
@@ -254,10 +252,12 @@ def split_shape_into_blocks(
halo: PerAxis[HaloLike],
stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
- assert all(a in shape for a in block_shape), (
- tuple(shape),
- set(block_shape),
- )
+ unknown_axes = [a for a in block_shape if a not in shape]
+ if unknown_axes:
+ raise ValueError(
+ f"unknown axes in block_shape: {unknown_axes} for shape {shape}"
+ )
+
if any(shape[a] < block_shape[a] for a in block_shape):
# TODO: allow larger blockshape
raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")
diff --git a/src/bioimageio/core/cli.py b/src/bioimageio/core/cli.py
index 318d61512..0a8dd8a0f 100644
--- a/src/bioimageio/core/cli.py
+++ b/src/bioimageio/core/cli.py
@@ -83,19 +83,19 @@
write_yaml,
)
+from ._prediction_pipeline import (
+ create_prediction_pipeline,
+ create_remote_prediction_pipeline,
+)
from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test
from .common import MemberId, SampleId, SupportedWeightsFormat
from .digest_spec import get_member_ids, load_sample_for_model
from .io import load_stat, save_sample, save_stat
-from .prediction import create_prediction_pipeline
-from .proc_setup import (
- Measure,
- MeasureValue,
- StatsCalculator,
- get_required_dataset_measures,
-)
+from .proc_setup import get_required_dataset_measures
+from .remote_backends import create_remote_model_adapter
from .sample import Sample
-from .stat_measures import Stat
+from .stat_calculators import StatsCalculator
+from .stat_measures import Measure, MeasureValue, Stat
from .utils import compare
from .weight_converters._add_weights import add_weights
@@ -520,6 +520,20 @@ class PredictCmd(CmdBase, WithSource):
)
"""Device(s) to use"""
+ server: Optional[str] = None
+ """The URL or Hugging Face space name of a running bioimageio (gradio) server instance to use as a remote backend for prediction."""
+
+ pre_post_processing_location: Literal["local", "remote"] = Field(
+ "local", alias="pre-post-processing-location"
+ )
+ """Where to run preprocessing/postprocessing operations when using `--server`.
+
+ - `local`: Run preprocessing/postprocessing locally and only model inference on the server.
+ - `remote`: Run preprocessing/postprocessing on the server as well.
+
+  
+ """
+
example: bool = False
"""generate and run an example
@@ -794,11 +808,25 @@ def input_dataset(stat: Stat):
).items()
)
- pp = create_prediction_pipeline(
- model_descr,
- weight_format=None if self.weight_format == "any" else self.weight_format,
- devices=self.devices,
- )
+ if self.server is not None and self.pre_post_processing_location == "remote":
+ pp = create_remote_prediction_pipeline(model_descr, server=self.server)
+ else:
+ if self.server is None:
+ model_adapter = None
+ else:
+ assert self.pre_post_processing_location == "local"
+ model_adapter = create_remote_model_adapter(
+ model_descr, server=self.server
+ )
+
+ pp = create_prediction_pipeline(
+ model_descr,
+ weight_format=None
+ if self.weight_format == "any"
+ else self.weight_format,
+ devices=self.devices,
+ model_adapter=model_adapter,
+ )
if blockwise:
predict_method = partial(
@@ -920,13 +948,40 @@ def cli_cmd(self):
self.log(updated_model_descr)
-class EmptyCache(CmdBase):
+class EmptyCacheCmd(CmdBase):
"""Empty the bioimageio cache directory."""
def cli_cmd(self):
empty_cache()
+class ServerCmd(CmdBase):
+ """Start a server to connect to with remote model adapters or remote prediction pipelines."""
+
+ backend: Literal["gradio"] = "gradio"
+ """The remote backend to use."""
+
+ port: Optional[int] = None
+ """The port to start the server on. If not given, a free port will be used."""
+
+ def cli_cmd(self) -> None:
+ try:
+ if self.backend == "gradio":
+ from .remote_backends.gradio.server import main
+ else:
+ assert_never(self.backend)
+ except ImportError as e:
+ raise ImportError(
+ f"{self.backend.capitalize()} is not installed. Please install the '{self.backend}-server' extra to use this command,"
+ + f" e.g. with `pip install bioimageio.core[{self.backend}-server]`."
+ ) from e
+
+ local_server_url = main(port=self.port)
+ logger.info(
+ "{} server shutdown at {}", self.backend.capitalize(), local_server_url
+ )
+
+
JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"
@@ -972,9 +1027,12 @@ class Bioimageio(
add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
"""Add additional weights to a model description by converting from available formats."""
- empty_cache: CliSubCommand[EmptyCache] = Field(alias="empty-cache")
+ empty_cache: CliSubCommand[EmptyCacheCmd] = Field(alias="empty-cache")
"""Empty the bioimageio cache directory."""
+ server: CliSubCommand[ServerCmd]
+ """Start a server to connect to with remote model adapters or remote prediction pipelines."""
+
@classmethod
def settings_customise_sources(
cls,
diff --git a/src/bioimageio/core/common.py b/src/bioimageio/core/common.py
index 981e1efc8..2c4af41e8 100644
--- a/src/bioimageio/core/common.py
+++ b/src/bioimageio/core/common.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-from types import MappingProxyType
from typing import (
Hashable,
Literal,
@@ -11,12 +10,10 @@
Union,
)
-from typing_extensions import Self, assert_never
+from typing_extensions import Self, TypeAlias, assert_never
from bioimageio.spec.model import v0_5
-from .axis import AxisId
-
SupportedWeightsFormat = Literal[
"keras_hdf5",
"keras_v3",
@@ -112,10 +109,10 @@ class PadWidth(_LeftRight):
pass
-PadWidthLike = _LeftRightLike[PadWidth]
-Padding = v0_5.Padding
-PadMode = Union[Literal["constant", "edge", "reflect", "symmetric"], Padding]
-PadWhere = _Where
+PadWidthLike: TypeAlias = _LeftRightLike[PadWidth]
+Padding: TypeAlias = v0_5.Padding
+PadMode: TypeAlias = Union[Literal["constant", "edge", "reflect", "symmetric"], Padding]
+PadWhere: TypeAlias = _Where
class SliceInfo(NamedTuple):
@@ -128,49 +125,17 @@ class SliceInfo(NamedTuple):
MemberId = v0_5.TensorId
"""ID of a `Sample` member, see `bioimageio.core.sample.Sample`"""
-BlocksizeParameter = Union[
+BlocksizeParameter: TypeAlias = Union[
v0_5.ParameterizedSize_N,
- Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
+ Mapping[Tuple[MemberId, v0_5.AxisId], v0_5.ParameterizedSize_N],
]
"""
Parameter to determine a concrete size for paramtrized axis sizes defined by
`bioimageio.spec.model.v0_5.ParameterizedSize`.
"""
-T = TypeVar("T")
-PerMember = Mapping[MemberId, T]
+_T = TypeVar("_T")
+PerMember = Mapping[MemberId, _T]
BlockIndex = int
TotalNumberOfBlocks = int
-
-
-K = TypeVar("K", bound=Hashable)
-V = TypeVar("V")
-
-Frozen = MappingProxyType
-# class Frozen(Mapping[K, V]): # adapted from xarray.core.utils.Frozen
-# """Wrapper around an object implementing the mapping interface to make it
-# immutable."""
-
-# __slots__ = ("mapping",)
-
-# def __init__(self, mapping: Mapping[K, V]):
-# super().__init__()
-# self.mapping = deepcopy(
-# mapping
-# ) # added deepcopy (compared to xarray.core.utils.Frozen)
-
-# def __getitem__(self, key: K) -> V:
-# return self.mapping[key]
-
-# def __iter__(self) -> Iterator[K]:
-# return iter(self.mapping)
-
-# def __len__(self) -> int:
-# return len(self.mapping)
-
-# def __contains__(self, key: object) -> bool:
-# return key in self.mapping
-
-# def __repr__(self) -> str:
-# return f"{type(self).__name__}({self.mapping!r})"
diff --git a/src/bioimageio/core/digest_spec.py b/src/bioimageio/core/digest_spec.py
index 280cadb87..4819f4178 100644
--- a/src/bioimageio/core/digest_spec.py
+++ b/src/bioimageio/core/digest_spec.py
@@ -25,7 +25,7 @@
import xarray as xr
from loguru import logger
from numpy.typing import NDArray
-from typing_extensions import Unpack, assert_never
+from typing_extensions import TypeAlias, Unpack, assert_never
from bioimageio.spec._internal.io import HashKwargs, PermissiveFileSource
from bioimageio.spec.common import FileDescr, FileSource
@@ -46,12 +46,15 @@
LinearSampleAxisTransform,
Sample,
SampleBlockMeta,
+ SampleBlockWithOrigin,
sample_block_meta_generator,
)
from .stat_measures import Stat
from .tensor import Tensor
-TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], PermissiveFileSource]
+TensorSource: TypeAlias = Union[
+ Tensor, xr.DataArray, NDArray[Any], PermissiveFileSource
+]
def import_callable(
@@ -291,12 +294,23 @@ class IO_SampleBlockMeta(NamedTuple):
output: SampleBlockMeta
-def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
+def get_input_halo(
+ model: v0_5.ModelDescr, output_halo: Optional[PerMember[PerAxis[Halo]]] = None
+):
"""returns which halo input tensors need to be divided into blocks with, such that
`output_halo` can be cropped from their outputs without introducing gaps."""
input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
outputs = {t.id: t for t in model.outputs}
all_tensors = {**{t.id: t for t in model.inputs}, **outputs}
+ if output_halo is None:
+ output_halo = {
+ t.id: {
+ a.id: Halo(a.halo, a.halo)
+ for a in t.axes
+ if isinstance(a, v0_5.WithHalo)
+ }
+ for t in model.outputs
+ }
for t, th in output_halo.items():
axes = {a.id: a for a in outputs[t].axes}
@@ -547,3 +561,33 @@ def load_sample_for_model(
stat={} if stat is None else stat,
id=sample_id or tuple(sorted(paths.values())),
)
+
+
+def split_sample_into_blocks_for_model(
+ sample: Sample,
+ model: v0_5.ModelDescr,
+ blocksize_parameter: int,
+ batch_size: int = 1,
+) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
+ if isinstance(model, v0_4.ModelDescr):
+ raise NotImplementedError(
+ "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
+ + f" {model.name}."
+ + " Consider using `predict_sample_with_fixed_blocking` or update the model description to format version 0.5."
+ )
+
+ ns = {
+ (ipt.id, a.id): blocksize_parameter
+ for ipt in model.inputs
+ for a in ipt.axes
+ if isinstance(a.size, v0_5.ParameterizedSize)
+ }
+ halo = get_input_halo(model)
+
+ input_block_shape = model.get_tensor_sizes(ns, batch_size=batch_size).inputs
+
+ return sample.split_into_blocks(
+ block_shapes=input_block_shape,
+ halo=halo,
+ pad_mode={ipt.id: ipt.pad or "symmetric" for ipt in model.inputs},
+ )
diff --git a/src/bioimageio/core/io.py b/src/bioimageio/core/io.py
index 179070619..82cb512f6 100644
--- a/src/bioimageio/core/io.py
+++ b/src/bioimageio/core/io.py
@@ -7,6 +7,7 @@
from pathlib import Path
from shutil import copyfileobj
from typing import (
+ TYPE_CHECKING,
Any,
Dict,
List,
@@ -22,6 +23,8 @@
from loguru import logger
from numpy.typing import NDArray
from pydantic import BaseModel, RootModel
+from typing_extensions import TypeAlias
+from typing_extensions import TypeAliasType as _TypeAliasType
from bioimageio.spec._internal.io import get_reader, interprete_file_source
from bioimageio.spec._internal.type_guards import is_ndarray
@@ -38,13 +41,23 @@
from .axis import AxisId, AxisLike
from .common import PerMember
-from .sample import Sample, Stat
-from .stat_measures import DatasetMeasure, MeasureValue, SampleMeasure
+from .sample import Sample
+from .stat_measures import DatasetMeasure, MeasureValue, SampleMeasure, Stat
from .tensor import Tensor
-JsonValue = Union[
- bool, int, float, str, None, List["JsonValue"], Dict[str, "JsonValue"]
-]
+if TYPE_CHECKING:
+ JsonValue: TypeAlias = Union[
+ bool, int, float, str, None, List["JsonValue"], Dict[str, "JsonValue"]
+ ] # note: order relevant for deserializing
+
+else:
+ # for pydantic validation we need to use `TypeAliasType`,
+ # see https://docs.pydantic.dev/latest/concepts/types/#named-recursive-types
+ # however this results in a partially unknown type with the current pyright 1.1.388
+ JsonValue: TypeAlias = _TypeAliasType(
+ "JsonValue",
+ Union[bool, int, float, str, None, List["JsonValue"], Dict[str, "JsonValue"]],
+ )
def load_image(
diff --git a/src/bioimageio/core/model_adapters.py b/src/bioimageio/core/model_adapters.py
deleted file mode 100644
index db92d013a..000000000
--- a/src/bioimageio/core/model_adapters.py
+++ /dev/null
@@ -1,22 +0,0 @@
-"""DEPRECATED"""
-
-from typing import List
-
-from .backends._model_adapter import (
- DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER,
- ModelAdapter,
- create_model_adapter,
-)
-
-__all__ = [
- "ModelAdapter",
- "create_model_adapter",
- "get_weight_formats",
-]
-
-
-def get_weight_formats() -> List[str]:
- """
- Return list of supported weight types
- """
- return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER)
diff --git a/src/bioimageio/core/prediction.py b/src/bioimageio/core/prediction.py
index 65e0802cd..d00585be9 100644
--- a/src/bioimageio/core/prediction.py
+++ b/src/bioimageio/core/prediction.py
@@ -65,7 +65,7 @@ def predict(
"""
if isinstance(model, PredictionPipeline):
pp = model
- model = pp.model_description
+ model = pp.model_descr
else:
if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
loaded = load_description(model)
@@ -78,54 +78,56 @@ def predict(
fixed_dataset_statistics=inputs.stat if isinstance(inputs, Sample) else {},
)
- if save_output_path is not None:
- if (
- "{output_id}" not in str(save_output_path)
- and "{member_id}" not in str(save_output_path)
- and len(model.outputs) > 1
- ):
- raise ValueError(
- f"Missing `{{output_id}}` in save_output_path={save_output_path} to "
- + "distinguish model outputs "
- + str([get_member_id(d) for d in model.outputs])
+ with pp:
+ model = pp.model_descr
+ if save_output_path is not None:
+ if (
+ "{output_id}" not in str(save_output_path)
+ and "{member_id}" not in str(save_output_path)
+ and len(model.outputs) > 1
+ ):
+ raise ValueError(
+ f"Missing `{{output_id}}` in save_output_path={save_output_path} to "
+ + "distinguish model outputs "
+ + str([get_member_id(d) for d in model.outputs])
+ )
+
+ if isinstance(inputs, Sample):
+ sample = inputs
+ else:
+ sample = create_sample_for_model(
+ pp.model_descr, inputs=inputs, sample_id=sample_id
)
- if isinstance(inputs, Sample):
- sample = inputs
- else:
- sample = create_sample_for_model(
- pp.model_description, inputs=inputs, sample_id=sample_id
- )
-
- if input_block_shape is not None:
- if blocksize_parameter is not None:
- logger.warning(
- "ignoring blocksize_parameter={} in favor of input_block_shape={}",
- blocksize_parameter,
- input_block_shape,
+ if input_block_shape is not None:
+ if blocksize_parameter is not None:
+ logger.warning(
+ "ignoring blocksize_parameter={} in favor of input_block_shape={}",
+ blocksize_parameter,
+ input_block_shape,
+ )
+
+ output = pp.predict_sample_with_fixed_blocking(
+ sample,
+ input_block_shape=input_block_shape,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
)
-
- output = pp.predict_sample_with_fixed_blocking(
- sample,
- input_block_shape=input_block_shape,
- skip_preprocessing=skip_preprocessing,
- skip_postprocessing=skip_postprocessing,
- )
- elif blocksize_parameter is not None:
- output = pp.predict_sample_with_blocking(
- sample,
- skip_preprocessing=skip_preprocessing,
- skip_postprocessing=skip_postprocessing,
- ns=blocksize_parameter,
- )
- else:
- output = pp.predict_sample_without_blocking(
- sample,
- skip_preprocessing=skip_preprocessing,
- skip_postprocessing=skip_postprocessing,
- )
- if save_output_path:
- save_sample(save_output_path, output)
+ elif blocksize_parameter is not None:
+ output = pp.predict_sample_with_blocking(
+ sample,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ ns=blocksize_parameter,
+ )
+ else:
+ output = pp.predict_sample_without_blocking(
+ sample,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ )
+ if save_output_path:
+ save_sample(save_output_path, output)
return output
diff --git a/src/bioimageio/core/proc_ops.py b/src/bioimageio/core/proc_ops.py
index 052767fe0..cce52970e 100644
--- a/src/bioimageio/core/proc_ops.py
+++ b/src/bioimageio/core/proc_ops.py
@@ -60,7 +60,7 @@ def _convert_axis_ids(
if mode == "per_sample":
ret = []
elif mode == "per_dataset":
- ret = [v0_5.BATCH_AXIS_ID]
+ ret = [AxisId(v0_5.BATCH_AXIS_ID)]
else:
assert_never(mode)
@@ -562,7 +562,9 @@ def get_descr(self):
return v0_5.ScaleRangeDescr(
kwargs=v0_5.ScaleRangeKwargs(
- axes=self.lower.axes,
+ axes=None
+ if self.lower.axes is None
+ else [v0_5.AxisId(a) for a in self.lower.axes],
min_percentile=self.lower.q * 100,
max_percentile=self.upper.q * 100,
eps=self.eps,
@@ -625,7 +627,7 @@ def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)
def get_descr(self):
- return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))
+ return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=v0_5.AxisId(self.axis)))
@dataclass
diff --git a/src/bioimageio/core/remote_backends/__init__.py b/src/bioimageio/core/remote_backends/__init__.py
new file mode 100644
index 000000000..1e2be1d07
--- /dev/null
+++ b/src/bioimageio/core/remote_backends/__init__.py
@@ -0,0 +1,38 @@
+from typing import TYPE_CHECKING, Literal, Optional
+
+from typing_extensions import assert_never
+
+from bioimageio.spec.model import AnyModelDescr
+
+if TYPE_CHECKING:
+ from .gradio.client import GradioModelAdapter
+
+
+def create_remote_model_adapter(
+ model_description: AnyModelDescr,
+ server: Optional[str] = None,
+ server_type: Optional[Literal["gradio"]] = None,
+) -> "GradioModelAdapter":
+ """Create a remote model adapter
+
+ Args:
+ model_description: The model to run inference with.
+ server: The URL or Hugging Face space name of a running bioimageio server instance
+ server_type: The type of the remote server to connect to. Currently only "gradio" is supported.
+ """
+
+ if server_type is None:
+ server_type = "gradio"
+
+ try:
+ if server_type == "gradio":
+ from .gradio.client import GradioModelAdapter as RemoteModelAdapterImpl
+ else:
+ assert_never(server_type)
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to import {server_type.capitalize()}ModelAdapter. Make sure to install the '{server_type}-client' extra,"
+ + f" e.g. with `pip install bioimageio.core[{server_type}-client]`."
+ ) from e
+
+ return RemoteModelAdapterImpl(model_description=model_description, server=server)
diff --git a/src/bioimageio/core/remote_backends/gradio/__init__.py b/src/bioimageio/core/remote_backends/gradio/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/bioimageio/core/remote_backends/gradio/client.py b/src/bioimageio/core/remote_backends/gradio/client.py
new file mode 100644
index 000000000..8a368e0c8
--- /dev/null
+++ b/src/bioimageio/core/remote_backends/gradio/client.py
@@ -0,0 +1,316 @@
+from types import MappingProxyType
+from typing import Dict, Iterable, Literal, Mapping, Optional, Tuple, Union
+
+from gradio_client import Client
+from loguru import logger
+
+from bioimageio.spec import AnyModelDescr, ValidationSummary
+from bioimageio.spec.model import v0_4
+
+from ..._description_serializer import DescriptionSerializer as DescriptionSerializer
+from ..._model_adapter import RemoteModelAdapter
+from ..._prediction_pipeline import IntermediatePrediction, RemotePredictionPipeline
+from ..._settings import settings
+from ...axis import PerAxis
+from ...common import BlocksizeParameter, PerMember
+from ...io import JsonValue
+from ...sample import Sample, SampleBlock
+from ...stat_measures import Measure, MeasureValue
+from .serializer import GradioSampleSerializer
+
+SerializedSampleBlock = Dict[str, JsonValue]
+
+
+class GradioModelAdapter(RemoteModelAdapter[SerializedSampleBlock]):
+ """Model adapter to use the bioimage-io-gradio-runner as a backend for model inference."""
+
+ def __init__(
+ self, model_description: AnyModelDescr, *, server: Optional[str] = None
+ ):
+ """Initialize the GradioModelAdapter.
+
+ Note:
+ - This adapter requires an environment with the same gradio version as the one used on the bioimage-io-gradio-runner server.
+
+ Args:
+ model_description: The model to run inference with.
+ server: The URL of a running bioimage-io-gradio-server instance (default server might not be availability/compatible).
+ """
+ server = server or settings.gradio_server
+ if server is None:
+ raise ValueError(
+ "No gradio server specified. Please provide a server URL or set the 'BIOIMAGEIO_GRADIO_SERVER' environment variable."
+ )
+
+ self._client = Client(server, httpx_kwargs={"timeout": 60})
+ self._serialized_model, self._sha256 = (
+ DescriptionSerializer.serialize_to_string_and_hash(model_description)
+ )
+ super().__init__(
+ model_description, server=server, sample_serializer=GradioSampleSerializer()
+ )
+
+ def _forward_impl(
+ self, serialized_input_sample: Iterable[SerializedSampleBlock]
+ ) -> Iterable[SerializedSampleBlock]:
+ return _call_predict_api(
+ self._client,
+ self._serialized_model,
+ self._sha256,
+ serialized_input_sample,
+ blocksize=None,
+ skip_preprocessing=True,
+ skip_postprocessing=True,
+ skip_input_padding=True,
+ skip_output_cropping=True,
+ batch_size=None,
+ )
+
+ def unload(self):
+ return super().unload()
+
+ def load(self) -> None:
+ for model_data in ("", self._serialized_model):
+ try:
+ result = self._client.submit(
+ api_name="/load_model", model=model_data, sha256=self._sha256
+ ).result()
+ except Exception as e:
+ if model_data:
+ logger.warning(
+ "Failed to load model on server with model_data, error was: {}",
+ len(model_data),
+ e,
+ )
+ else:
+ if result:
+ break
+
+ def test(self) -> Optional[ValidationSummary]:
+ for model_data in ("", self._serialized_model):
+ try:
+ result = self._client.submit(
+ api_name="/test_model", model=model_data, sha256=self._sha256
+ ).result()
+ except Exception as e:
+ if model_data:
+ logger.warning(
+ "Failed to test model on server with model_data, error was: {}",
+ len(model_data),
+ e,
+ )
+ else:
+ if result:
+ return ValidationSummary.model_validate_json(result)
+
+ return None
+
+
+class GradioPredictionPipeline(RemotePredictionPipeline):
+ """Prediction pipeline to use the bioimage-io-gradio-runner as a fully remote prediction pipeline."""
+
+ def __init__(
+ self,
+ model_description: AnyModelDescr,
+ *,
+ server: Optional[str] = None,
+ precomputed_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}),
+ default_blocksize_parameter: BlocksizeParameter = 10,
+ default_batch_size: int = 1,
+ ):
+ """
+ Note:
+ - This pipeline requires an environment with the same gradio version as the one used on the bioimage-io-gradio-runner server.
+
+ Args:
+ model_description: The model to run inference with.
+ server: The URL or Hugging Face space name of a running bioimageio gradio server instance (Note: default server might not be availabile/compatible!).
+ """
+ server = server or settings.gradio_server
+ if server is None:
+ raise ValueError(
+ "No gradio server specified. Please provide a server URL or set the 'BIOIMAGEIO_GRADIO_SERVER' environment variable."
+ )
+
+ super().__init__(
+ model_description,
+ server=server,
+ default_blocksize_parameter=default_blocksize_parameter,
+ default_batch_size=default_batch_size,
+ )
+ self._client = Client(self.server, httpx_kwargs={"timeout": 60})
+ self._serialized_model, self._sha256 = (
+ DescriptionSerializer.serialize_to_string_and_hash(model_description)
+ )
+ self._serializer = GradioSampleSerializer
+ self._precomputed_statistics = dict(precomputed_statistics)
+
+ def predict_sample_block(
+ self,
+ sample_block: SampleBlock,
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ ) -> SampleBlock:
+ if isinstance(self._model_descr, v0_4.ModelDescr):
+ raise NotImplementedError(
+ f"predict_sample_block not implemented for model {self._model_descr.format_version}"
+ )
+ else:
+ assert self._block_transform is not None
+
+ sample_block.stat.update(self._precomputed_statistics)
+ output_block = self._serializer.deserialize_sample(
+ _call_predict_api(
+ self._client,
+ self._serialized_model,
+ self._sha256,
+ serialized_input_sample=self._serializer.serialize_sample(
+ sample_block.as_sample()
+ ),
+ blocksize=None,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ skip_input_padding=True,
+ skip_output_cropping=True,
+ batch_size=self._default_batch_size,
+ )
+ )
+ output_meta = sample_block.get_transformed_meta(self._block_transform)
+ return output_meta.with_data(output_block.members, stat=sample_block.stat)
+
+ def predict_sample_without_blocking(
+ self,
+ sample: Sample,
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ skip_input_padding: bool = False,
+ skip_output_cropping: bool = False,
+ ) -> Sample:
+ sample.stat.update(self._precomputed_statistics)
+ return self._serializer.deserialize_sample(
+ _call_predict_api(
+ self._client,
+ self._serialized_model,
+ self._sha256,
+ serialized_input_sample=self._serializer.serialize_sample(sample),
+ blocksize=None,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ skip_input_padding=skip_input_padding,
+ skip_output_cropping=skip_output_cropping,
+ batch_size=self._default_batch_size,
+ )
+ )
+
+ def predict_sample_with_fixed_blocking_yield_intermediates(
+ self,
+ sample: Sample,
+ input_block_shape: PerMember[PerAxis[int]],
+ *,
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ fill_value: float = float("nan"),
+ ) -> Tuple[int, Iterable[IntermediatePrediction]]:
+ sample.stat.update(self._precomputed_statistics)
+
+ # blocking for serialization is not really important, but we might as well block
+ # the same way we want the backend to block for blockwise prediction
+ serialized_input_sample = self._serializer.serialize_sample_with_fixed_blocking(
+ sample, block_shapes=input_block_shape, halo=self._default_input_halo
+ )
+
+ def _predict_blocks():
+ output_sample = None
+ for serialized_output_block in _call_predict_api(
+ self._client,
+ self._serialized_model,
+ self._sha256,
+ serialized_input_sample=serialized_input_sample,
+ blocksize=input_block_shape,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ skip_input_padding=False,
+ skip_output_cropping=False,
+ batch_size=self._default_batch_size,
+ ):
+ output_block = self._serializer.deserialize_sample_block(
+ serialized_output_block
+ )
+ if output_sample is None:
+ output_sample = Sample.from_blocks(
+ [output_block], fill_value=fill_value
+ )
+ else:
+ output_sample.set_block(output_block)
+
+ yield IntermediatePrediction(output_sample, output_block)
+
+ block_iterator = _predict_blocks()
+ first_intermediate = next(block_iterator)
+
+ def _intermediate_predictions() -> Iterable[IntermediatePrediction]:
+ yield first_intermediate
+ yield from block_iterator
+
+ return (
+ first_intermediate.last_block.blocks_in_sample,
+ _intermediate_predictions(),
+ )
+
+
+def _call_predict_api(
+ client: Client,
+ serialized_model: str,
+ sha256: str,
+ serialized_input_sample: Iterable[SerializedSampleBlock],
+ blocksize: Optional[
+ Union[int, Literal["blockwise_as_serialized"], PerMember[PerAxis[int]]]
+ ],
+ skip_preprocessing: bool,
+ skip_postprocessing: bool,
+ skip_input_padding: bool,
+ skip_output_cropping: bool,
+ batch_size: Optional[int],
+) -> Iterable[SerializedSampleBlock]:
+ def submit(model: str):
+ return client.submit(
+ api_name="/predict",
+ model=model,
+ sha256=sha256,
+ input_sample=serialized_input_sample,
+ blocksize={
+ str(k): {str(kk): vv for kk, vv in v.items()}
+ for k, v in blocksize.items()
+ }
+ if not (blocksize is None or isinstance(blocksize, (int, str)))
+ else blocksize,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ skip_input_padding=skip_input_padding,
+ skip_output_cropping=skip_output_cropping,
+ batch_size=batch_size,
+ )
+
+ try_with_model_upload = True
+ try:
+ job = submit("")
+ for block in job: # pyright: ignore[reportUnknownVariableType]
+ yield block # pyright: ignore[reportReturnType]
+ # we got one response, so the model cache was hit...
+ try_with_model_upload = False
+ except Exception as e:
+ # A raised exception on the server seems to simply return an empty response sequence,
+ # so this except is likely not triggered at all.
+ # Below we retry on empty return value, too.
+ if try_with_model_upload:
+ logger.warning(
+ "Failed to submit job without model upload, trying with model upload, error was: {}",
+ e,
+ )
+ else:
+ raise e
+
+ if try_with_model_upload:
+ job = submit(serialized_model)
+ for block in job: # pyright: ignore[reportUnknownVariableType]
+ yield block # pyright: ignore[reportReturnType]
diff --git a/src/bioimageio/core/remote_backends/gradio/serializer.py b/src/bioimageio/core/remote_backends/gradio/serializer.py
new file mode 100644
index 000000000..3800b3a6c
--- /dev/null
+++ b/src/bioimageio/core/remote_backends/gradio/serializer.py
@@ -0,0 +1,72 @@
+import tempfile
+from pathlib import Path
+from typing import Dict, List, Mapping, Union
+
+import numpy as np
+from gradio_client import handle_file
+from pydantic import BaseModel
+from typing_extensions import Self
+
+from ..._common_annotations import PerMemberAnno
+from ..._description_serializer import DescriptionSerializer as DescriptionSerializer
+from ..._sample_serializer import SampleSerializer
+from ...common import MemberId
+from ...io import JsonValue, load_stat, save_tensor, serialize_stat
+from ...sample import SampleBlock, SampleBlockMeta
+from ...tensor import Tensor
+
+
+class _SerializableBlock(BaseModel, frozen=True):
+ path: Path
+ meta: Mapping[str, str]
+ orig_name: str
+
+ @classmethod
+ def from_tensor(cls, tensor: Tensor) -> Self:
+ with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
+ save_tensor(tmp.name, tensor)
+
+ handled = handle_file(Path(tmp.name))
+ return cls.model_validate(handled)
+
+
+class _SerializableSampleBlock(BaseModel, frozen=True):
+ meta: SampleBlockMeta
+ data: PerMemberAnno[Union[_SerializableBlock, Path]]
+ serialized_stat: List[JsonValue]
+
+
+SerializedSampleBlock = Dict[str, JsonValue]
+
+
+class GradioSampleSerializer(SampleSerializer[SerializedSampleBlock]):
+ @staticmethod
+ def serialize_sample_block(sample_block: SampleBlock) -> SerializedSampleBlock:
+ handled_members: Dict[MemberId, _SerializableBlock] = {}
+ for m, t in sample_block.members.items():
+ handled_members[m] = _SerializableBlock.from_tensor(t)
+
+ serializable = _SerializableSampleBlock(
+ data=handled_members,
+ meta=sample_block.get_meta(),
+ serialized_stat=serialize_stat(sample_block.stat),
+ )
+ serialized = serializable.model_dump(mode="json")
+ return serialized
+
+ @staticmethod
+ def deserialize_sample_block(serialized: SerializedSampleBlock) -> SampleBlock:
+ deserializable_sample = _SerializableSampleBlock.model_validate(serialized)
+ sample_meta = deserializable_sample.meta
+ members = {
+ k: Tensor.from_numpy(
+ np.load(v if isinstance(v, Path) else v.path),
+ dims=list(sample_meta.shape[k]),
+ )
+ for k, v in deserializable_sample.data.items()
+ }
+ return SampleBlock.from_meta(
+ sample_meta,
+ data=members,
+ stat=load_stat(deserializable_sample.serialized_stat),
+ )
diff --git a/src/bioimageio/core/remote_backends/gradio/server.py b/src/bioimageio/core/remote_backends/gradio/server.py
new file mode 100644
index 000000000..cf1cc2361
--- /dev/null
+++ b/src/bioimageio/core/remote_backends/gradio/server.py
@@ -0,0 +1,249 @@
+from itertools import chain
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ Literal,
+ Optional,
+ Union,
+)
+
+import gradio as gr
+from loguru import logger
+
+import bioimageio.core
+from bioimageio.core import AxisId, Stat
+from bioimageio.core.axis import PerAxis
+from bioimageio.core.backends import create_model_adapter
+from bioimageio.core.common import PerMember
+from bioimageio.core.remote_backends.gradio.serializer import (
+ DescriptionSerializer,
+ GradioSampleSerializer,
+ SerializedSampleBlock,
+)
+from bioimageio.spec import load_model_description
+from bioimageio.spec.common import Sha256
+from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
+
+try:
+ import spaces # pyright: ignore
+except ImportError:
+ logger.warning("Failed to import 'spaces' package")
+
+ class spaces:
+ @staticmethod
+ def GPU(func: Any):
+ return func
+
+
+logger.enable("bioimageio")
+
+app = gr.Server()
+
+
+@app.api(name="predict") # pyright: ignore[reportUntypedFunctionDecorator]
+@spaces.GPU
+def predict(
+ model: str,
+ sha256: str,
+ input_sample: Iterable[SerializedSampleBlock],
+ blocksize: Optional[
+ Union[int, Literal["blockwise_as_serialized"], PerMember[PerAxis[int]]]
+ ] = None,
+ skip_preprocessing: bool = False,
+ skip_postprocessing: bool = False,
+ skip_input_padding: bool = False,
+ skip_output_cropping: bool = False,
+ batch_size: Optional[int] = None,
+) -> Iterable[SerializedSampleBlock]:
+ """Run prediction on a sample
+
+ Args:
+ input_sample: Input sample as a sequence of serialized sample blocks.
+ Use bioimageio.core.backends.gradio_backend.GradioModelAdapter.serialize_sample to create this from a Sample object.
+ model: A model source: URL, nickname or base64 encoded model package (if len(model) > 2083).
+ sha256: Sha256 hash of the model's bioimageio.yaml file at the model source or of the encoded model package.
+ blocksize:
+ - None (default): run non-blockwise, full-sample prediction.
+ - integer: run blockwise prediction with a block size derived from the model and this blocksize parameter.
+ - "blockwise_as_serialized": run blockwise prediction with the same blocking as the serialized input sample.
+ (Non-blockwise pre- and postprocessing steps will be ignored.)
+ - PerMember[PerAxis[int]]: run blockwise prediction with a fixed block shape given for each sample member.
+ skip_preprocessing: If True, skip preprocessing steps defined in the model.
+ skip_postprocessing: If True, skip postprocessing steps defined in the model.
+ skip_input_padding: If True, skip input padding for non-blockwise prediction.
+ Set this flag when predicting an (overlapping) sample block rather than a full sample.
+ skip_output_cropping: If True, skip output cropping for non-blockwise prediction.
+ Set this flag when predicting an (overlapping) sample block rather than a full sample.
+ batch_size: Optional batch size only applicable to predicting input samples with batch dimension.
+ """
+
+ def setup(stat: Stat):
+ model_adapter = _get_model_adapter(model, sha256=sha256)
+ return bioimageio.core.create_prediction_pipeline(
+ model_adapter.model_descr, fixed_dataset_statistics=stat
+ )
+
+ if blocksize == "blockwise_as_serialized":
+ sample_block_iterator = iter(input_sample)
+ deserialized_input_block = GradioSampleSerializer.deserialize_sample_block(
+ next(sample_block_iterator)
+ )
+ pp = setup(deserialized_input_block.stat)
+ for block in chain(
+ [deserialized_input_block],
+ (
+ GradioSampleSerializer.deserialize_sample_block(b)
+ for b in sample_block_iterator
+ ),
+ ):
+ output_block = pp.predict_sample_block(
+ block,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ )
+ yield GradioSampleSerializer.serialize_sample_block(output_block)
+ else:
+ deserialized_input_sample = GradioSampleSerializer.deserialize_sample(
+ input_sample
+ )
+ pp = setup(deserialized_input_sample.stat)
+
+ output_sample = None
+ if isinstance(blocksize, int):
+ try:
+ if pp.has_non_blockwise_postprocessing and not skip_postprocessing:
+ output_sample = pp.predict_sample_with_blocking(
+ deserialized_input_sample,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ ns=blocksize,
+ batch_size=batch_size,
+ )
+ else:
+ for output in pp.predict_sample_with_blocking_yield_intermediates(
+ deserialized_input_sample,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ ns=blocksize,
+ batch_size=batch_size,
+ )[1]:
+ # with purely blockwise postprocesssing or with postprocessing skipped,
+ # predicted blocks are part of the final result, so we yield them immediately.
+ yield GradioSampleSerializer.serialize_sample_block(
+ output.last_block
+ )
+
+ return
+
+ except Exception as e:
+ logger.warning(
+ "Falling back to full-sample prediction for model {}: {}",
+ pp.model_descr.id or pp.model_descr.name,
+ e,
+ )
+ if output_sample is None:
+ output_sample = pp.predict_sample_without_blocking(
+ deserialized_input_sample,
+ skip_preprocessing=skip_preprocessing,
+ skip_postprocessing=skip_postprocessing,
+ skip_input_padding=skip_input_padding,
+ skip_output_cropping=skip_output_cropping,
+ )
+
+ if all(
+ axes.get(AxisId("batch"), 1) > 1 for axes in output_sample.shape.values()
+ ):
+ # yield batches
+ yield from GradioSampleSerializer.serialize_sample_with_fixed_blocking(
+ output_sample,
+ block_shapes={
+ m: {AxisId("batch"): batch_size or 1} for m in output_sample.shape
+ },
+ halo={},
+ )
+ else:
+ yield from GradioSampleSerializer.serialize_sample(output_sample)
+
+
+@app.api(name="load_model") # pyright: ignore[reportUntypedFunctionDecorator]
+def load_model(
+ model: str,
+ sha256: str,
+) -> dict[Literal["message"], str]:
+ """Load a model into the server's model cache. This can be used to pre-load a model before running predictions to avoid the overhead of loading the model during the first prediction request."""
+ _ = _get_model_adapter(model, sha256=sha256)
+ return {"message": "Model loaded successfully"}
+
+
+@app.api(name="test_model") # pyright: ignore[reportUntypedFunctionDecorator]
+def test_model(
+ model: str,
+ sha256: str,
+) -> str:
+ """Run the bioimageio model test and return the validation summary. Returns None if testing failed."""
+ model_adapter = _get_model_adapter(model, sha256=sha256)
+ summary = bioimageio.core.test_model(model_adapter.model_descr)
+ return summary.model_dump_json()
+
+
+def _cache_key(kwargs: Dict[str, Any]) -> str:
+ return kwargs["sha256"]
+
+
+@gr.cache( # pyright: ignore[reportUntypedFunctionDecorator]
+ key=_cache_key,
+ max_size=bioimageio.core.settings.gradio_server_model_cache_max_size,
+ max_memory=bioimageio.core.settings.gradio_server_model_cache_max_memory,
+ per_session=False,
+)
+def _get_model_adapter(
+ model: str,
+ *,
+ sha256: str,
+):
+ """Get a model adapter for the given model
+
+ Args:
+ model: A model source: URL (len(model) <= 2083)) or model base64 encoded package bytes (len(model) > 2083).
+ sha256: Sha256 hash of the model source at model URL or of the encoded model package bytes.
+ """
+ if not model:
+ raise ValueError("Model source cannot be empty")
+
+ model_descr = _get_model(model, sha256=sha256)
+ return create_model_adapter(model_description=model_descr)
+
+
+def _get_model(
+ model: str,
+ *,
+ sha256: str,
+) -> AnyModelDescr:
+ if len(model) > 2083:
+ ret = DescriptionSerializer.deserialize_from_string(model)
+ if not isinstance(ret, (v0_4.ModelDescr, v0_5.ModelDescr)):
+ raise ValueError(
+ f"Deserialized model description is not a valid model description: got {ret.type} {ret.format_version}"
+ )
+ return ret
+ else:
+ return load_model_description(model, sha256=Sha256(sha256) if sha256 else None)
+
+
+@app.get("/")
+def root():
+ return {
+ "message": f"Running bioimageio.core {bioimageio.core.__version__} gradio server."
+ }
+
+
+def main(port: Optional[int] = None) -> str:
+ _app, local_url, _share_url = app.launch(
+ mcp_server=True, show_error=True, server_port=port
+ )
+ return local_url
+
+
+if __name__ == "__main__":
+ _ = main()
diff --git a/src/bioimageio/core/sample.py b/src/bioimageio/core/sample.py
index 8d35c7b2a..8f0b771f9 100644
--- a/src/bioimageio/core/sample.py
+++ b/src/bioimageio/core/sample.py
@@ -3,6 +3,7 @@
import collections.abc
from dataclasses import dataclass
from math import ceil, floor
+from types import MappingProxyType
from typing import (
Any,
Callable,
@@ -17,10 +18,12 @@
)
import numpy as np
+import pydantic
import xarray as xr
from numpy.typing import NDArray
from typing_extensions import Self
+from ._common_annotations import PerMemberAnno
from .axis import AxisId, PerAxis
from .block import Block
from .block_meta import (
@@ -84,6 +87,29 @@ def __getitem__(
id=self.id,
)
+ def set_block(self, block: SampleBlock) -> None:
+ """Set values of `block`.
+
+ Note:
+ - Updates only existing sample members (extra block members are ignored)
+ - Ignores missing block members (i.e. members in the sample but not in the block are not modified)
+
+ Raises:
+ ValueError if block and sample members do not overlap at all.
+ """
+ no_overlap = True
+ for m in self.members:
+ if m not in block.blocks:
+ continue
+ b = block.blocks[m]
+ self.members[m][b.inner_slice] = b.inner_data
+ no_overlap = False
+
+ if no_overlap:
+ raise ValueError(
+ f"block with members {list(block.blocks)} does not overlap with sample members {list(self.members)}"
+ )
+
@property
def shape(self) -> PerMember[PerAxis[int]]:
return {tid: t.sizes for tid, t in self.members.items()}
@@ -146,21 +172,58 @@ def from_blocks(
*,
fill_value: float = float("nan"),
) -> Self:
- members: PerMember[Tensor] = {}
- stat: Stat = {}
- sample_id = None
+ """Create a `Sample` from an iterable of `SampleBlock`s.
+
+ Note:
+ All sample blocks must have the same `sample_id`.
+
+ Args:
+ sample_blocks: The blocks to create the sample from.
+ fill_value: The value to fill missing values with (default: `nan`).
+ """
+ output = None
+ for output in cls.from_blocks_yield_intermediates(
+ sample_blocks, fill_value=fill_value
+ ):
+ pass
+
+ if output is None:
+ raise ValueError("no sample blocks provided")
+
+ return output
+
+ @classmethod
+ def from_blocks_yield_intermediates(
+ cls,
+ sample_blocks: Iterable[SampleBlock],
+ *,
+ fill_value: float = float("nan"),
+ ):
+ """Create a `Sample` from an iterable of `SampleBlock`s, yielding the intermediate sample after each block.
+
+ Args:
+ sample_blocks: The blocks to create the sample from.
+ fill_value: The value to fill missing values with (default: `nan`).
+ """
+ output = cls(members={}, stat={}, id=None)
for sample_block in sample_blocks:
- assert sample_id is None or sample_id == sample_block.sample_id
- sample_id = sample_block.sample_id
- stat = sample_block.stat
+ if output.id is None:
+ output.id = sample_block.sample_id
+ else:
+ assert output.id == sample_block.sample_id, (
+ "sample id changed between sample blocks"
+ )
+
+ output.stat = sample_block.stat
+
for m, block in sample_block.blocks.items():
- if m not in members:
+ if m not in output.members:
if -1 in block.sample_shape.values():
raise NotImplementedError(
"merging blocks with data dependent axis not yet implemented"
)
- members[m] = Tensor(
+ output.members[m] = Tensor(
np.full(
tuple(block.sample_shape[a] for a in block.data.dims),
fill_value,
@@ -169,9 +232,10 @@ def from_blocks(
dims=block.data.dims,
)
- members[m][block.inner_slice] = block.inner_data
+ output.members[m][block.inner_slice] = block.inner_data
+ yield output
- return cls(members=members, stat=stat, id=sample_id)
+ yield output
def pad(
self,
@@ -199,20 +263,20 @@ def pad(
)
-BlockT = TypeVar("BlockT", Block, BlockMeta)
+BlockT = TypeVar("BlockT", bound=BlockMeta)
-@dataclass
+@pydantic.dataclasses.dataclass(frozen=True)
class SampleBlockBase(Generic[BlockT]):
"""base class for `SampleBlockMeta` and `SampleBlock`"""
- sample_shape: PerMember[PerAxis[int]]
+ sample_shape: PerMemberAnno[PerAxis[int]]
"""the sample shape this block represents a part of"""
sample_id: SampleId
"""identifier for the sample within its dataset"""
- blocks: Dict[MemberId, BlockT]
+ blocks: PerMemberAnno[BlockT]
"""Individual tensor blocks comprising this sample block"""
block_index: BlockIndex
@@ -223,11 +287,11 @@ class SampleBlockBase(Generic[BlockT]):
@property
def shape(self) -> PerMember[PerAxis[int]]:
- return {mid: b.shape for mid, b in self.blocks.items()}
+ return MappingProxyType({mid: b.shape for mid, b in self.blocks.items()})
@property
def inner_shape(self) -> PerMember[PerAxis[int]]:
- return {mid: b.inner_shape for mid, b in self.blocks.items()}
+ return MappingProxyType({mid: b.inner_shape for mid, b in self.blocks.items()})
@dataclass
@@ -235,7 +299,7 @@ class LinearSampleAxisTransform(LinearAxisTransform):
member: MemberId
-@dataclass
+@pydantic.dataclasses.dataclass(frozen=True)
class SampleBlockMeta(SampleBlockBase[BlockMeta]):
"""Meta data of a dataset sample block"""
@@ -329,10 +393,13 @@ def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock:
)
-@dataclass
+@dataclass(frozen=True)
class SampleBlock(SampleBlockBase[Block]):
"""A block of a dataset sample"""
+ blocks: Dict[MemberId, Block]
+ """Individual tensor blocks comprising this sample block"""
+
stat: Stat
"""computed statistics"""
@@ -352,8 +419,45 @@ def get_transformed_meta(
blocks_in_sample=self.blocks_in_sample,
).get_transformed(new_axes)
+ @classmethod
+ def from_meta(
+ cls, meta: SampleBlockMeta, data: PerMember[Tensor], stat: Stat
+ ) -> Self:
+ return cls(
+ sample_shape=meta.sample_shape,
+ sample_id=meta.sample_id,
+ blocks={
+ m: Block.from_meta(b, data=data[m]) for m, b in meta.blocks.items()
+ },
+ stat=stat,
+ block_index=meta.block_index,
+ blocks_in_sample=meta.blocks_in_sample,
+ )
-@dataclass
+ def get_meta(self) -> SampleBlockMeta:
+ return SampleBlockMeta(
+ sample_id=self.sample_id,
+ blocks={m: b.get_meta() for m, b in self.blocks.items()},
+ sample_shape=self.sample_shape,
+ block_index=self.block_index,
+ blocks_in_sample=self.blocks_in_sample,
+ )
+
+ def as_sample(self) -> Sample:
+ """Convert this sample block to a `Sample` with the shape of this block.
+
+ Note:
+ If you want to convert one or more sample block to a sample with the shape of the original, whole sample,
+ use `Sample.from_blocks()` instead.
+ """
+ return Sample(
+ members=dict(self.members),
+ stat=dict(self.stat),
+ id=self.sample_id,
+ )
+
+
+@dataclass(frozen=True)
class SampleBlockWithOrigin(SampleBlock):
"""A `SampleBlock` with a reference (`origin`) to the whole `Sample`"""
diff --git a/src/bioimageio/core/tensor.py b/src/bioimageio/core/tensor.py
index f89e2eb70..63d545b51 100644
--- a/src/bioimageio/core/tensor.py
+++ b/src/bioimageio/core/tensor.py
@@ -511,8 +511,8 @@ def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
ndim = array.ndim
if ndim == 2:
current_axes = (
- v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
- v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("y"), size=array.shape[0]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("x"), size=array.shape[1]),
)
elif ndim == 3 and any(s <= 3 for s in array.shape):
current_axes = (
@@ -521,14 +521,14 @@ def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
]
),
- v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
- v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("y"), size=array.shape[1]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("x"), size=array.shape[2]),
)
elif ndim == 3:
current_axes = (
- v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
- v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
- v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("z"), size=array.shape[0]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("y"), size=array.shape[1]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("x"), size=array.shape[2]),
)
elif ndim == 4:
current_axes = (
@@ -537,9 +537,9 @@ def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
]
),
- v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
- v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
- v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("z"), size=array.shape[1]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("y"), size=array.shape[2]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("x"), size=array.shape[3]),
)
elif ndim == 5:
current_axes = (
@@ -549,9 +549,9 @@ def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
]
),
- v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
- v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
- v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("z"), size=array.shape[2]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("y"), size=array.shape[3]),
+ v0_5.SpaceInputAxis(id=v0_5.AxisId("x"), size=array.shape[4]),
)
else:
raise ValueError(f"Could not guess an axis mapping for {array.shape}")
diff --git a/src/bioimageio/core/utils/_type_guards.py b/src/bioimageio/core/utils/_type_guards.py
index 0a33b8084..ad61c6911 100644
--- a/src/bioimageio/core/utils/_type_guards.py
+++ b/src/bioimageio/core/utils/_type_guards.py
@@ -6,3 +6,5 @@
is_list = type_guards.is_list
is_ndarray = type_guards.is_ndarray
is_tuple = type_guards.is_tuple
+is_dict = type_guards.is_dict
+is_kwargs = type_guards.is_kwargs
diff --git a/src/bioimageio/core/weight_converters/torchscript_to_onnx.py b/src/bioimageio/core/weight_converters/torchscript_to_onnx.py
index 26a479dd1..3da2cdbeb 100644
--- a/src/bioimageio/core/weight_converters/torchscript_to_onnx.py
+++ b/src/bioimageio/core/weight_converters/torchscript_to_onnx.py
@@ -2,6 +2,7 @@
from typing import Optional, Sequence, Union
import torch.jit
+from loguru import logger
from torch._export.converter import TS2EPConverter
from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
@@ -57,10 +58,16 @@ def convert(
model, # pyright: ignore[reportUnknownArgumentType]
torch_sample_inputs,
).convert()
+ exported_module = exported_program.module()
+
+ try:
+ exported_module = exported_module.eval()
+ except Exception as e:
+ logger.warning("Failed to set TS2EPConverter program to evaluation mode: {}", e)
return export_to_onnx(
model_descr,
- exported_program.module(),
+ exported_module,
output_path,
verbose,
opset_version,
diff --git a/tests/conftest.py b/tests/conftest.py
index fe55efd06..7675c93e9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,6 +2,7 @@
import subprocess
from itertools import chain
+from pathlib import Path
from typing import Dict, List
from dotenv import load_dotenv
@@ -50,7 +51,14 @@
logger.warning("testing with bioimageio.spec {}", bioimageio_spec_version)
-EXAMPLE_DESCRIPTIONS = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/"
+EXAMPLE_DESCRIPTIONS = (
+ LOCAL_EXAMPLE_DESCRIPTIONS.as_posix() + "/"
+ if (
+ LOCAL_EXAMPLE_DESCRIPTIONS := Path(__file__).parent
+ / "../../spec-bioimage-io/example_descriptions/"
+ ).exists()
+ else "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/"
+)
# TODO: use models from new collection on S3
MODEL_SOURCES: Dict[str, str] = {
@@ -116,7 +124,7 @@
"unet2d_nuclei_broad_model",
]
)
-ONNX_MODELS = [] if onnxruntime is None else ["hpa_densenet"]
+ONNX_MODELS = [] if onnxruntime is None else ["unet2d_nuclei_broad_model"]
TENSORFLOW_MODELS = (
[]
if tensorflow is None
@@ -138,9 +146,8 @@
TENSORFLOW_JS_MODELS: List[str] = [] # TODO: add a tensorflow_js example model
ALL_MODELS = sorted(
- {
- m
- for m in chain(
+ set(
+ chain(
TORCH_MODELS,
TORCHSCRIPT_MODELS,
ONNX_MODELS,
@@ -148,7 +155,7 @@
KERAS_MODELS,
TENSORFLOW_JS_MODELS,
)
- }
+ )
)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 7c7545ff3..ddfd92b23 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -6,7 +6,7 @@
import pytest
from pydantic import FilePath
-from bioimageio.spec import load_description, settings
+from bioimageio.spec import settings
def run_subprocess(
@@ -72,13 +72,13 @@ def test_cli(
def test_empty_cache(tmp_path: Path, unet2d_nuclei_broad_model: str):
- from bioimageio.spec.utils import empty_cache
+ from bioimageio.spec.utils import empty_cache, get_reader
origingal_cache_path = settings.cache_path
try:
settings.cache_path = tmp_path / "cache"
assert not settings.cache_path.exists()
- _ = load_description(unet2d_nuclei_broad_model, perform_io_checks=False)
+ _ = get_reader("https://example.com")
assert (
len([fn for fn in settings.cache_path.iterdir() if fn.suffix != ".lock"])
== 1
diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py
index 8a595cf8d..fd7077abb 100644
--- a/tests/test_prediction_pipeline.py
+++ b/tests/test_prediction_pipeline.py
@@ -1,13 +1,24 @@
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
from pathlib import Path
-from numpy.testing import assert_array_almost_equal
-
+from bioimageio.core import Sample
+from bioimageio.core._resource_tests import evaluate_mismatched_elements
from bioimageio.core.common import SupportedWeightsFormat
from bioimageio.spec import load_description
from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04
from bioimageio.spec.model.v0_5 import ModelDescr
+def _alter_sample(sample: Sample, offset: float) -> Sample:
+ # add 1 to all values to get a different sample with the same shape and axes
+ return Sample(
+ id=f"{sample.id}_altered",
+ members={m: t + offset for m, t in sample.members.items()},
+ stat=sample.stat,
+ )
+
+
def _test_prediction_pipeline(
model_package: Path, weights_format: SupportedWeightsFormat
):
@@ -21,22 +32,42 @@ def _test_prediction_pipeline(
assert isinstance(bio_model, (ModelDescr, ModelDescr04)), (
bio_model.validation_summary.format()
)
+
pp = create_prediction_pipeline(
- bioimageio_model=bio_model, weight_format=weights_format
+ bioimageio_model=bio_model, weight_format=weights_format, devices=["cpu", "cpu"]
)
inputs = get_test_input_sample(bio_model)
- outputs = pp.predict_sample_without_blocking(
- inputs, skip_input_padding=True, skip_output_cropping=True
- )
+
+ # test in a multi-threaded setting
+ multiple_inputs = [inputs, _alter_sample(inputs, offset=100.0)]
+ with ThreadPoolExecutor(max_workers=3) as executor:
+ multiple_outputs = list(
+ executor.map(
+ partial(
+ pp.predict_sample_without_blocking,
+ skip_input_padding=True,
+ skip_output_cropping=True,
+ ),
+ multiple_inputs,
+ )
+ )
+
+ outputs = multiple_outputs[0]
expected_outputs = get_test_output_sample(bio_model)
assert len(outputs.shape) == len(expected_outputs.shape)
for m in expected_outputs.members:
- out = outputs.members[m].data
+ out = outputs.members[m]
assert out is not None
- exp = expected_outputs.members[m].data
- assert_array_almost_equal(out, exp, decimal=4)
+ exp = expected_outputs.members[m]
+ mismatched_ppm, msg, error_msg = evaluate_mismatched_elements(
+ out, exp, rtol=0.01, atol=0.1, name=m
+ )
+ if error_msg is not None:
+ raise AssertionError(error_msg)
+ elif mismatched_ppm > 50_000:
+ raise AssertionError(msg)
def test_prediction_pipeline_torch(any_torch_model: Path):
diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py
index 7c06bcb46..ad84e97b4 100644
--- a/tests/test_prediction_pipeline_device_management.py
+++ b/tests/test_prediction_pipeline_device_management.py
@@ -8,7 +8,9 @@
from bioimageio.spec.model.v0_5 import ModelDescr
-def _test_device_management(model_package: Path, weight_format: SupportedWeightsFormat):
+def _test_device_management(
+ model_package: Path, weight_format: SupportedWeightsFormat, device: str
+):
import torch
from bioimageio.core import load_description
@@ -18,18 +20,22 @@ def _test_device_management(model_package: Path, weight_format: SupportedWeights
get_test_output_sample,
)
- if not hasattr(torch, "cuda") or torch.cuda.device_count() == 0:
+ if device == "cuda" and (
+ not hasattr(torch, "cuda") or torch.cuda.device_count() == 0
+ ):
pytest.skip("Need at least one cuda device for this test")
bio_model = load_description(model_package)
assert isinstance(bio_model, (ModelDescr, ModelDescr04))
pred_pipe = create_prediction_pipeline(
- bioimageio_model=bio_model, weight_format=weight_format, devices=["cuda:0"]
+ bioimageio_model=bio_model, weight_format=weight_format, devices=[device]
)
inputs = get_test_input_sample(bio_model)
with pred_pipe as pp:
- outputs = pp.predict_sample_without_blocking(inputs)
+ outputs = pp.predict_sample_without_blocking(
+ inputs, skip_input_padding=True, skip_output_cropping=True
+ )
expected_outputs = get_test_output_sample(bio_model)
@@ -38,35 +44,44 @@ def _test_device_management(model_package: Path, weight_format: SupportedWeights
out = outputs.members[m].data
assert out is not None
exp = expected_outputs.members[m].data
- assert_array_almost_equal(out, exp, decimal=4)
+ assert_array_almost_equal(out, exp, decimal=2)
# repeat inference with context manager to test load/predict/unload/load/predict
with pred_pipe as pp:
- outputs = pp.predict_sample_without_blocking(inputs)
+ outputs = pp.predict_sample_without_blocking(
+ inputs, skip_input_padding=True, skip_output_cropping=True
+ )
assert len(outputs.shape) == len(expected_outputs.shape)
for m in expected_outputs.members:
out = outputs.members[m].data
assert out is not None
exp = expected_outputs.members[m].data
- assert_array_almost_equal(out, exp, decimal=4)
+ assert_array_almost_equal(out, exp, decimal=2)
-def test_device_management_torch(any_torch_model: Path):
- _test_device_management(any_torch_model, "pytorch_state_dict")
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_device_management_torch(any_torch_model: Path, device: str):
+ _test_device_management(any_torch_model, "pytorch_state_dict", device=device)
-def test_device_management_torchscript(any_torchscript_model: Path):
- _test_device_management(any_torchscript_model, "torchscript")
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_device_management_torchscript(any_torchscript_model: Path, device: str):
+ _test_device_management(any_torchscript_model, "torchscript", device=device)
-def test_device_management_onnx(any_onnx_model: Path):
- _test_device_management(any_onnx_model, "onnx")
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_device_management_onnx(any_onnx_model: Path, device: str):
+ _test_device_management(any_onnx_model, "onnx", device=device)
-def test_device_management_tensorflow(any_tensorflow_model: Path):
- _test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle")
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_device_management_tensorflow(any_tensorflow_model: Path, device: str):
+ _test_device_management(
+ any_tensorflow_model, "tensorflow_saved_model_bundle", device=device
+ )
-def test_device_management_keras(any_keras_model: Path):
- _test_device_management(any_keras_model, "keras_hdf5")
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_device_management_keras(any_keras_model: Path, device: str):
+ _test_device_management(any_keras_model, "keras_hdf5", device=device)
diff --git a/tests/test_remote_backends/test_gradio.py b/tests/test_remote_backends/test_gradio.py
new file mode 100644
index 000000000..48fee4d4c
--- /dev/null
+++ b/tests/test_remote_backends/test_gradio.py
@@ -0,0 +1,77 @@
+import socket
+import sys
+import time
+from concurrent.futures import ThreadPoolExecutor
+from multiprocessing import Process
+from typing import List, Tuple
+
+import pytest
+from loguru import logger
+
+from bioimageio.core import Tensor
+from bioimageio.core.common import PerMember
+
+
+def _is_port_in_use(port: int) -> bool:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ return s.connect_ex(("localhost", port)) == 0
+
+
+@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python 3.10 or higher")
+def test_gradio_backend():
+ from bioimageio.core import load_model
+ from bioimageio.core.digest_spec import get_test_input_sample
+ from bioimageio.core.remote_backends.gradio.client import GradioModelAdapter
+ from bioimageio.core.remote_backends.gradio.server import main as gradio_server_main
+
+ port = 7860
+ if _is_port_in_use(port):
+ sock = socket.socket()
+ sock.bind(("", 0))
+ _host, port = sock.getsockname()
+
+ server_process = Process(target=gradio_server_main, kwargs={"port": port})
+ server_process.start()
+
+ try:
+ deadline = time.monotonic() + 30
+ while time.monotonic() < deadline:
+ try:
+ with socket.create_connection(("localhost", port), timeout=1):
+ break
+ except OSError:
+ pass
+ time.sleep(0.2)
+ else:
+ raise TimeoutError(f"gradio server did not become ready on port {port}")
+
+ server_url = f"http://localhost:{port}/"
+ prepared: List[Tuple[str, GradioModelAdapter, PerMember[Tensor]]] = []
+ for model_id in ("affable-shark", "ambitious-sloth"):
+ model = load_model(
+ model_id, format_version="latest", perform_io_checks=False
+ )
+ sample = get_test_input_sample(model)
+
+ logger.debug("connecting adapter to {} for {}", server_url, model_id)
+ adapter = GradioModelAdapter(model, server=server_url)
+ prepared.append((model_id, adapter, sample.members))
+
+ # Exercise concurrent requests pooled across both loaded models.
+ with ThreadPoolExecutor(max_workers=4) as executor:
+ future_to_model_id = {
+ executor.submit(adapter.forward, sample_members): model_id
+ for model_id, adapter, sample_members in prepared
+ for _ in range(2)
+ }
+
+ for future, model_id in future_to_model_id.items():
+ assert future.result() is not None, model_id
+
+ for model_id, adapter, _sample_members in prepared:
+ summary = adapter.test()
+ assert summary is not None
+ assert summary.status == "passed", f"{model_id}: {summary.display()}"
+ finally:
+ server_process.terminate()
+ server_process.join()
diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py
index 5b09176b7..892fbe77b 100644
--- a/tests/test_resource_tests.py
+++ b/tests/test_resource_tests.py
@@ -1,7 +1,9 @@
from pathlib import Path
import numpy as np
+import pytest
+from bioimageio.core import __version__
from bioimageio.spec import InvalidDescr, ValidationContext
@@ -48,6 +50,10 @@ def test_loading_description_multiple_times(unet2d_nuclei_broad_model: str):
assert not isinstance(model_descr, InvalidDescr)
+@pytest.mark.skipif(
+ __version__ == "0.11.0",
+ reason="Previously released bioimageio.core 0.10.4 is incompatible with updated unet2d nuclei broad model description 0.5.11",
+)
def test_test_description_runtime_env(unet2d_nuclei_broad_model: str):
from bioimageio.core._resource_tests import test_description
diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py
index e4cd9d926..5437dd53f 100644
--- a/tests/test_stat_measures.py
+++ b/tests/test_stat_measures.py
@@ -19,9 +19,11 @@
@pytest.mark.parametrize(
"name,axes",
- product(
- ["mean", "var", "std"],
- [None, (AxisId("c"),), (AxisId("x"), AxisId("y"))],
+ list(
+ product(
+ ["mean", "var", "std"],
+ [None, (AxisId("c"),), (AxisId("x"), AxisId("y"))],
+ )
),
)
def test_individual_normal_measure(
diff --git a/tests/test_weight_converters.py b/tests/test_weight_converters.py
index fa7cabdd3..42647e98c 100644
--- a/tests/test_weight_converters.py
+++ b/tests/test_weight_converters.py
@@ -18,7 +18,7 @@ def test_pytorch_to_torchscript(any_torch_model, tmp_path):
pytest.skip("cannot convert to old 0.4 format")
out_path = tmp_path / "weights.pt"
- ret_val = convert(model_descr, out_path)
+ ret_val = convert(model_descr, out_path, devices=["cpu"])
assert out_path.exists()
assert isinstance(ret_val, v0_5.TorchscriptWeightsDescr)
assert ret_val.source == out_path