From c032914114d639e572ac966a82031d5fb46720b6 Mon Sep 17 00:00:00 2001 From: KRiedmiller Date: Thu, 25 Jun 2026 17:32:14 +0200 Subject: [PATCH 1/4] fix: model axis names and min axis size --- panseg/core/zoo.py | 2 +- panseg/functionals/training/biio.py | 42 ++++++++++++------------- panseg/functionals/training/train.py | 11 ++++++- tests/functionals/training/test_biio.py | 2 +- 4 files changed, 33 insertions(+), 24 deletions(-) diff --git a/panseg/core/zoo.py b/panseg/core/zoo.py index adf6c2b5..d9a3dd35 100644 --- a/panseg/core/zoo.py +++ b/panseg/core/zoo.py @@ -608,7 +608,7 @@ def compute_halo(self, module: Module) -> int: return halo def compute_3D_halo_for_pytorch3dunet( - self, module: AbstractUNet + self, module: UNet2D | UNet3D ) -> tuple[int, int, int]: if isinstance(module, UNet3D): halo = self.compute_halo(module) diff --git a/panseg/functionals/training/biio.py b/panseg/functionals/training/biio.py index 9debea87..6704210a 100644 --- a/panseg/functionals/training/biio.py +++ b/panseg/functionals/training/biio.py @@ -31,7 +31,7 @@ def make_model_description( in_channels: int, out_channels: int, feature_maps: int | list[int] | tuple[int, ...], - patch_size: tuple[int, int, int], + axis_min_sizes: tuple[int, int, int], dimensionality: Literal["2D", "3D"], layer_order: str, modality: str, @@ -50,20 +50,20 @@ def make_model_description( channel_names=[Identifier(f"in_ch_{i}") for i in range(in_channels)] ), SpaceInputAxis( - id=AxisId("z_in"), - size=ParameterizedSize(min=patch_size[0], step=1), + id=AxisId("z"), + size=ParameterizedSize(min=axis_min_sizes[0], step=1), scale=resolution[0], unit="micrometer", ), SpaceInputAxis( - id=AxisId("y_in"), - size=ParameterizedSize(min=patch_size[1], step=1), + id=AxisId("y"), + size=ParameterizedSize(min=axis_min_sizes[1], step=1), scale=resolution[1], unit="micrometer", ), SpaceInputAxis( - id=AxisId("x_in"), - size=ParameterizedSize(min=patch_size[2], step=1), + id=AxisId("x"), + size=ParameterizedSize(min=axis_min_sizes[2], step=1), scale=resolution[2], unit="micrometer", ), @@ -75,14 +75,14 @@ def make_model_description( channel_names=[Identifier(f"in_ch_{i}") for i in range(in_channels)] ), SpaceInputAxis( - id=AxisId("y_in"), - size=ParameterizedSize(min=patch_size[1], step=1), + id=AxisId("y"), + size=ParameterizedSize(min=axis_min_sizes[1], step=1), scale=resolution[1], unit="micrometer", ), SpaceInputAxis( - id=AxisId("x_in"), - size=ParameterizedSize(min=patch_size[2], step=1), + id=AxisId("x"), + size=ParameterizedSize(min=axis_min_sizes[2], step=1), scale=resolution[2], unit="micrometer", ), @@ -106,20 +106,20 @@ def make_model_description( channel_names=[Identifier(f"out_ch_{i}") for i in range(out_channels)] ), SpaceOutputAxis( - id=AxisId("z_out"), - size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("z_in")), + id=AxisId("z"), + size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("z")), scale=resolution[0], unit="micrometer", ), SpaceOutputAxis( - id=AxisId("y_out"), - size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y_in")), + id=AxisId("y"), + size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y")), scale=resolution[1], unit="micrometer", ), SpaceOutputAxis( - id=AxisId("x_out"), - size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x_in")), + id=AxisId("x"), + size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x")), scale=resolution[2], unit="micrometer", ), @@ -131,14 +131,14 @@ def make_model_description( channel_names=[Identifier(f"out_ch_{i}") for i in range(out_channels)] ), SpaceOutputAxis( - id=AxisId("y_out"), - size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y_in")), + id=AxisId("y"), + size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y")), scale=resolution[1], unit="micrometer", ), SpaceOutputAxis( - id=AxisId("x_out"), - size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x_in")), + id=AxisId("x"), + size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x")), scale=resolution[2], unit="micrometer", ), diff --git a/panseg/functionals/training/train.py b/panseg/functionals/training/train.py index b7e99806..220dec78 100644 --- a/panseg/functionals/training/train.py +++ b/panseg/functionals/training/train.py @@ -6,6 +6,7 @@ import numpy as np import torch import yaml +from bioimageio.core import test_model from torch import nn from torch.optim.adam import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau @@ -224,6 +225,13 @@ def unet_training( np.save(checkpoint_dir / "test_in.npy", test_in.numpy()) np.save(checkpoint_dir / "test_out.npy", test_out.detach().cpu().numpy()) + try: + axis_min_sizes = np.min( + (model_zoo.compute_3D_halo_for_pytorch3dunet(model), patch_size), axis=0 + ) + except ValueError: + axis_min_sizes = patch_size + with chdir(checkpoint_dir): model_desc = make_model_description( weights=weights, @@ -231,7 +239,7 @@ def unet_training( in_channels=in_channels, out_channels=out_channels, feature_maps=feature_maps, - patch_size=patch_size, + axis_min_sizes=axis_min_sizes, dimensionality=dimensionality, layer_order=layer_order, modality=modality, @@ -242,6 +250,7 @@ def unet_training( test_out=Path("test_out.npy"), panseg_config=checkpoint_dir / FILE_CONFIG_TRAIN_YAML, ) + test_model(model_desc).display() model_desc.package( checkpoint_dir / f"biio_model_{''.join(c for c in model_name if c.isalnum())}.zip" diff --git a/tests/functionals/training/test_biio.py b/tests/functionals/training/test_biio.py index 0791adfc..d16fdb99 100644 --- a/tests/functionals/training/test_biio.py +++ b/tests/functionals/training/test_biio.py @@ -28,7 +28,7 @@ def test_make_model_description(tmp_path): in_channels=1, out_channels=1, feature_maps=64, - patch_size=(16, 32, 64), + axis_min_sizes=(16, 32, 64), dimensionality="3D", layer_order="bcr", modality="mod", From ff3a626ed44b9e5d97820b3c38d19e710e39fd69 Mon Sep 17 00:00:00 2001 From: Kai Riedmiller Date: Fri, 26 Jun 2026 14:06:33 +0200 Subject: [PATCH 2/4] fix: validate biio test input/output --- panseg/functionals/training/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/panseg/functionals/training/train.py b/panseg/functionals/training/train.py index 220dec78..3e9fb9f7 100644 --- a/panseg/functionals/training/train.py +++ b/panseg/functionals/training/train.py @@ -212,6 +212,9 @@ def unet_training( # Obtain and save test in- and output for biio test_in, _ = next(iter(loaders["train"])) + # normalize, as this is post-augmentations + test_in = (test_in - test_in.mean()) / (test_in.std() + 1e-6) + if isinstance(model, UNet2D): # remove the singleton z-dimension from the input # Extra dim was only added in the panseg loader for the augmentations From 0cb643d8f82761535792e31aedf0a75396b597e5 Mon Sep 17 00:00:00 2001 From: Kai Riedmiller Date: Fri, 26 Jun 2026 14:22:51 +0200 Subject: [PATCH 3/4] test: fix training test --- tests/functionals/training/test_training.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/functionals/training/test_training.py b/tests/functionals/training/test_training.py index 0ee1dcff..8ef73f97 100644 --- a/tests/functionals/training/test_training.py +++ b/tests/functionals/training/test_training.py @@ -306,6 +306,7 @@ def test_create_datasets_invalid_phase(self): class TestUnetTraining: """Tests for unet_training function.""" + @patch("panseg.functionals.training.train.test_model") @patch("panseg.functionals.training.train.make_model_description") @patch("panseg.functionals.training.train.UNetTrainer") @patch("panseg.functionals.training.train.create_datasets") @@ -320,6 +321,7 @@ def test_unet_training_2d( mock_create_datasets, mock_trainer, mock_model_desc, + mock_test_model, tmp_path, ): """Test UNet training for 2D case.""" @@ -373,6 +375,7 @@ def test_unet_training_2d( assert (tmp_path / model_name / "test_in.npy").exists() assert (tmp_path / model_name / "test_out.npy").exists() + @patch("panseg.functionals.training.train.test_model") @patch("panseg.functionals.training.train.make_model_description") @patch("panseg.functionals.training.train.UNetTrainer") @patch("panseg.functionals.training.train.create_datasets") @@ -387,6 +390,7 @@ def test_unet_training_3d( mock_create_datasets, mock_trainer, mock_model_desc, + mock_test_model, tmp_path, ): """Test UNet training for 3D case.""" @@ -484,6 +488,7 @@ def test_unet_training_with_existing_checkpoint_dir( device="cpu", ) + @patch("panseg.functionals.training.train.test_model") @patch("panseg.functionals.training.train.isinstance") @patch("panseg.functionals.training.train.make_model_description") @patch("panseg.functionals.training.train.DataLoader") @@ -512,6 +517,7 @@ def test_unet_training_multi_gpu( mock_data_loader, mock_description, mock_isinstance, + mock_test_model, tmp_path, ): """Test UNet training with multiple GPUs.""" From 02d8b79fc8ceb7014420f0b229f37b5bba490243 Mon Sep 17 00:00:00 2001 From: Kai Riedmiller Date: Fri, 26 Jun 2026 14:25:19 +0200 Subject: [PATCH 4/4] feat: add PanSeg tag to trained models --- panseg/functionals/training/biio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/panseg/functionals/training/biio.py b/panseg/functionals/training/biio.py index 6704210a..c20ad574 100644 --- a/panseg/functionals/training/biio.py +++ b/panseg/functionals/training/biio.py @@ -173,7 +173,7 @@ def make_model_description( model_desc = ModelDescr( name=model_name, description=description, - tags=["UNet", modality, output_type], + tags=["UNet", "PanSeg", modality, output_type], inputs=[input_desc], outputs=[output_desc], weights=WeightsDescr(