Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion panseg/core/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def compute_halo(self, module: Module) -> int:
return halo

def compute_3D_halo_for_pytorch3dunet(
self, module: AbstractUNet
self, module: UNet2D | UNet3D
) -> tuple[int, int, int]:
if isinstance(module, UNet3D):
halo = self.compute_halo(module)
Expand Down
44 changes: 22 additions & 22 deletions panseg/functionals/training/biio.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def make_model_description(
in_channels: int,
out_channels: int,
feature_maps: int | list[int] | tuple[int, ...],
patch_size: tuple[int, int, int],
axis_min_sizes: tuple[int, int, int],
dimensionality: Literal["2D", "3D"],
layer_order: str,
modality: str,
Expand All @@ -50,20 +50,20 @@ def make_model_description(
channel_names=[Identifier(f"in_ch_{i}") for i in range(in_channels)]
),
SpaceInputAxis(
id=AxisId("z_in"),
size=ParameterizedSize(min=patch_size[0], step=1),
id=AxisId("z"),
size=ParameterizedSize(min=axis_min_sizes[0], step=1),
scale=resolution[0],
unit="micrometer",
),
SpaceInputAxis(
id=AxisId("y_in"),
size=ParameterizedSize(min=patch_size[1], step=1),
id=AxisId("y"),
size=ParameterizedSize(min=axis_min_sizes[1], step=1),
scale=resolution[1],
unit="micrometer",
),
SpaceInputAxis(
id=AxisId("x_in"),
size=ParameterizedSize(min=patch_size[2], step=1),
id=AxisId("x"),
size=ParameterizedSize(min=axis_min_sizes[2], step=1),
scale=resolution[2],
unit="micrometer",
),
Expand All @@ -75,14 +75,14 @@ def make_model_description(
channel_names=[Identifier(f"in_ch_{i}") for i in range(in_channels)]
),
SpaceInputAxis(
id=AxisId("y_in"),
size=ParameterizedSize(min=patch_size[1], step=1),
id=AxisId("y"),
size=ParameterizedSize(min=axis_min_sizes[1], step=1),
scale=resolution[1],
unit="micrometer",
),
SpaceInputAxis(
id=AxisId("x_in"),
size=ParameterizedSize(min=patch_size[2], step=1),
id=AxisId("x"),
size=ParameterizedSize(min=axis_min_sizes[2], step=1),
scale=resolution[2],
unit="micrometer",
),
Expand All @@ -106,20 +106,20 @@ def make_model_description(
channel_names=[Identifier(f"out_ch_{i}") for i in range(out_channels)]
),
SpaceOutputAxis(
id=AxisId("z_out"),
size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("z_in")),
id=AxisId("z"),
size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("z")),
scale=resolution[0],
unit="micrometer",
),
SpaceOutputAxis(
id=AxisId("y_out"),
size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y_in")),
id=AxisId("y"),
size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y")),
scale=resolution[1],
unit="micrometer",
),
SpaceOutputAxis(
id=AxisId("x_out"),
size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x_in")),
id=AxisId("x"),
size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x")),
scale=resolution[2],
unit="micrometer",
),
Expand All @@ -131,14 +131,14 @@ def make_model_description(
channel_names=[Identifier(f"out_ch_{i}") for i in range(out_channels)]
),
SpaceOutputAxis(
id=AxisId("y_out"),
size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y_in")),
id=AxisId("y"),
size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y")),
scale=resolution[1],
unit="micrometer",
),
SpaceOutputAxis(
id=AxisId("x_out"),
size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x_in")),
id=AxisId("x"),
size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x")),
scale=resolution[2],
unit="micrometer",
),
Expand Down Expand Up @@ -173,7 +173,7 @@ def make_model_description(
model_desc = ModelDescr(
name=model_name,
description=description,
tags=["UNet", modality, output_type],
tags=["UNet", "PanSeg", modality, output_type],
inputs=[input_desc],
outputs=[output_desc],
weights=WeightsDescr(
Expand Down
14 changes: 13 additions & 1 deletion panseg/functionals/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import torch
import yaml
from bioimageio.core import test_model
from torch import nn
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
Expand Down Expand Up @@ -211,6 +212,9 @@ def unet_training(

# Obtain and save test in- and output for biio
test_in, _ = next(iter(loaders["train"]))
# normalize, as this is post-augmentations
test_in = (test_in - test_in.mean()) / (test_in.std() + 1e-6)

if isinstance(model, UNet2D):
# remove the singleton z-dimension from the input
# Extra dim was only added in the panseg loader for the augmentations
Expand All @@ -224,14 +228,21 @@ def unet_training(
np.save(checkpoint_dir / "test_in.npy", test_in.numpy())
np.save(checkpoint_dir / "test_out.npy", test_out.detach().cpu().numpy())

try:
axis_min_sizes = np.min(
(model_zoo.compute_3D_halo_for_pytorch3dunet(model), patch_size), axis=0
)
except ValueError:
axis_min_sizes = patch_size

with chdir(checkpoint_dir):
model_desc = make_model_description(
weights=weights,
model_name=model_name,
in_channels=in_channels,
out_channels=out_channels,
feature_maps=feature_maps,
patch_size=patch_size,
axis_min_sizes=axis_min_sizes,
dimensionality=dimensionality,
layer_order=layer_order,
modality=modality,
Expand All @@ -242,6 +253,7 @@ def unet_training(
test_out=Path("test_out.npy"),
panseg_config=checkpoint_dir / FILE_CONFIG_TRAIN_YAML,
)
test_model(model_desc).display()
model_desc.package(
checkpoint_dir
/ f"biio_model_{''.join(c for c in model_name if c.isalnum())}.zip"
Expand Down
2 changes: 1 addition & 1 deletion tests/functionals/training/test_biio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_make_model_description(tmp_path):
in_channels=1,
out_channels=1,
feature_maps=64,
patch_size=(16, 32, 64),
axis_min_sizes=(16, 32, 64),
dimensionality="3D",
layer_order="bcr",
modality="mod",
Expand Down
6 changes: 6 additions & 0 deletions tests/functionals/training/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def test_create_datasets_invalid_phase(self):
class TestUnetTraining:
"""Tests for unet_training function."""

@patch("panseg.functionals.training.train.test_model")
@patch("panseg.functionals.training.train.make_model_description")
@patch("panseg.functionals.training.train.UNetTrainer")
@patch("panseg.functionals.training.train.create_datasets")
Expand All @@ -320,6 +321,7 @@ def test_unet_training_2d(
mock_create_datasets,
mock_trainer,
mock_model_desc,
mock_test_model,
tmp_path,
):
"""Test UNet training for 2D case."""
Expand Down Expand Up @@ -373,6 +375,7 @@ def test_unet_training_2d(
assert (tmp_path / model_name / "test_in.npy").exists()
assert (tmp_path / model_name / "test_out.npy").exists()

@patch("panseg.functionals.training.train.test_model")
@patch("panseg.functionals.training.train.make_model_description")
@patch("panseg.functionals.training.train.UNetTrainer")
@patch("panseg.functionals.training.train.create_datasets")
Expand All @@ -387,6 +390,7 @@ def test_unet_training_3d(
mock_create_datasets,
mock_trainer,
mock_model_desc,
mock_test_model,
tmp_path,
):
"""Test UNet training for 3D case."""
Expand Down Expand Up @@ -484,6 +488,7 @@ def test_unet_training_with_existing_checkpoint_dir(
device="cpu",
)

@patch("panseg.functionals.training.train.test_model")
@patch("panseg.functionals.training.train.isinstance")
@patch("panseg.functionals.training.train.make_model_description")
@patch("panseg.functionals.training.train.DataLoader")
Expand Down Expand Up @@ -512,6 +517,7 @@ def test_unet_training_multi_gpu(
mock_data_loader,
mock_description,
mock_isinstance,
mock_test_model,
tmp_path,
):
"""Test UNet training with multiple GPUs."""
Expand Down
Loading