From 6fbb762c19aebfce5a0bbc2927de13e4f801f9c3 Mon Sep 17 00:00:00 2001 From: Kai Riedmiller Date: Tue, 16 Jun 2026 17:11:44 +0200 Subject: [PATCH 1/2] fix: flexible training batch size --- .../functionals/prediction/utils/size_finder.py | 2 +- panseg/functionals/training/train.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/panseg/functionals/prediction/utils/size_finder.py b/panseg/functionals/prediction/utils/size_finder.py index c7a0e6cb..0d19bf92 100644 --- a/panseg/functionals/prediction/utils/size_finder.py +++ b/panseg/functionals/prediction/utils/size_finder.py @@ -209,7 +209,7 @@ def find_batch_size( model = model.to(device) model.eval() with torch.no_grad(): - for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: + for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]: x = None try: x = torch.randn((batch_size, in_channels) + actual_patch_shape).to( diff --git a/panseg/functionals/training/train.py b/panseg/functionals/training/train.py index 5b412557..c7dadda9 100644 --- a/panseg/functionals/training/train.py +++ b/panseg/functionals/training/train.py @@ -17,6 +17,7 @@ PATH_TRAIN_TEMPLATE, ) from panseg.core.zoo import model_zoo +from panseg.functionals.prediction.utils.size_finder import find_batch_size from panseg.functionals.training.augs import Augmenter from panseg.functionals.training.biio import make_model_description from panseg.functionals.training.h5dataset import HDF5Dataset @@ -99,6 +100,7 @@ def unet_training( out_channels=out_channels, f_maps=feature_maps, final_sigmoid=final_sigmoid, + layer_order="gcr", ) elif dimensionality in ["3D", "3d", "3"]: model = UNet3D( @@ -106,13 +108,21 @@ def unet_training( out_channels=out_channels, f_maps=feature_maps, final_sigmoid=final_sigmoid, + layer_order="gcr", ) else: raise ValueError(f"Unknown dimensionality {dimensionality}") logger.info(f"Using {model.__class__.__name__} model for training.") # Device configuration - batch_size = 1 + batch_size = find_batch_size( + model=model, + in_channels=in_channels, + patch_shape=patch_size, + patch_halo=(4, 4, 4), # some slack + device=device, + ) + if torch.cuda.device_count() > 1 and device != "cpu": model = nn.DataParallel(model) logger.info(f"Using {torch.cuda.device_count()} GPUs for prediction.") @@ -132,7 +142,7 @@ def unet_training( batch_size=batch_size, shuffle=True, pin_memory=True, - num_workers=1, + num_workers=4, ) } if len(val_datasets) > 0: @@ -141,7 +151,7 @@ def unet_training( batch_size=batch_size, shuffle=False, pin_memory=True, - num_workers=1, + num_workers=4, ) else: loaders["val"] = [] From 912433a9ff8b26268c9cb91de8741ed448afd641 Mon Sep 17 00:00:00 2001 From: KRiedmiller Date: Thu, 18 Jun 2026 16:11:07 +0200 Subject: [PATCH 2/2] feat: layer order arg --- panseg/functionals/training/biio.py | 2 ++ panseg/functionals/training/train.py | 9 +++++++-- tests/functionals/training/test_biio.py | 1 + tests/functionals/training/test_training.py | 6 ++++++ tests/functionals/training/test_training_integration.py | 2 +- 5 files changed, 17 insertions(+), 3 deletions(-) diff --git a/panseg/functionals/training/biio.py b/panseg/functionals/training/biio.py index 967ed246..9debea87 100644 --- a/panseg/functionals/training/biio.py +++ b/panseg/functionals/training/biio.py @@ -33,6 +33,7 @@ def make_model_description( feature_maps: int | list[int] | tuple[int, ...], patch_size: tuple[int, int, int], dimensionality: Literal["2D", "3D"], + layer_order: str, modality: str, output_type: str, description: str, @@ -165,6 +166,7 @@ def make_model_description( "in_channels": in_channels, "out_channels": out_channels, "f_maps": feature_maps, + "layer_order": layer_order, }, ) diff --git a/panseg/functionals/training/train.py b/panseg/functionals/training/train.py index c7dadda9..b7e99806 100644 --- a/panseg/functionals/training/train.py +++ b/panseg/functionals/training/train.py @@ -34,6 +34,7 @@ def create_model_config( out_channels, patch_size, dimensionality: Literal["2D", "3D"], + layer_order: str, sparse, f_maps, max_num_iters, @@ -47,6 +48,7 @@ def create_model_config( train_template["model"]["in_channels"] = in_channels train_template["model"]["out_channels"] = out_channels + train_template["model"]["layer_order"] = layer_order train_template["model"]["f_maps"] = f_maps if dimensionality in ["2D", "2d", "2"]: train_template["model"]["name"] = "UNet2D" @@ -88,6 +90,7 @@ def unet_training( description: str = "", resolution: tuple[float, float, float] = (1.0, 1.0, 1.0), pre_trained: Optional[Path] = None, + layer_order: str = "bcr", ) -> None: """ Main entrypoint for training a new unet model. Gets called when calling `panseg --train` from cli. @@ -100,7 +103,7 @@ def unet_training( out_channels=out_channels, f_maps=feature_maps, final_sigmoid=final_sigmoid, - layer_order="gcr", + layer_order=layer_order, ) elif dimensionality in ["3D", "3d", "3"]: model = UNet3D( @@ -108,7 +111,7 @@ def unet_training( out_channels=out_channels, f_maps=feature_maps, final_sigmoid=final_sigmoid, - layer_order="gcr", + layer_order=layer_order, ) else: raise ValueError(f"Unknown dimensionality {dimensionality}") @@ -170,6 +173,7 @@ def unet_training( out_channels, patch_size, dimensionality, + layer_order, sparse, feature_maps, max_num_iters, @@ -229,6 +233,7 @@ def unet_training( feature_maps=feature_maps, patch_size=patch_size, dimensionality=dimensionality, + layer_order=layer_order, modality=modality, output_type=output_type, description=description, diff --git a/tests/functionals/training/test_biio.py b/tests/functionals/training/test_biio.py index cd299dab..0791adfc 100644 --- a/tests/functionals/training/test_biio.py +++ b/tests/functionals/training/test_biio.py @@ -30,6 +30,7 @@ def test_make_model_description(tmp_path): feature_maps=64, patch_size=(16, 32, 64), dimensionality="3D", + layer_order="bcr", modality="mod", output_type="boundaries", description="dummy model", diff --git a/tests/functionals/training/test_training.py b/tests/functionals/training/test_training.py index 9f63aedd..0ee1dcff 100644 --- a/tests/functionals/training/test_training.py +++ b/tests/functionals/training/test_training.py @@ -79,6 +79,7 @@ def test_create_model_config_2d(self): out_channels=2, patch_size=[64, 64], dimensionality="2D", + layer_order="bcr", sparse=False, f_maps=[16, 32, 64], max_num_iters=1000, @@ -114,6 +115,7 @@ def test_create_model_config_3d(self): out_channels=3, patch_size=[32, 64, 64], dimensionality="3D", + layer_order="bcr", sparse=True, f_maps=[8, 16, 32], max_num_iters=2000, @@ -494,8 +496,10 @@ def test_unet_training_with_existing_checkpoint_dir( @patch("torch.nn.DataParallel") @patch("panseg.functionals.training.train.Adam") @patch("panseg.functionals.training.train.ReduceLROnPlateau") + @patch("panseg.functionals.training.train.find_batch_size") def test_unet_training_multi_gpu( self, + mock_find_batch_size, mock_reduce_lr, mock_adam, mock_data_parallel, @@ -541,6 +545,8 @@ def test_unet_training_multi_gpu( mock_isinstance.return_value = False + mock_find_batch_size.return_value = 1 + # Create a temporary dataset directory dataset_dir = tmp_path / "dataset" dataset_dir.mkdir() diff --git a/tests/functionals/training/test_training_integration.py b/tests/functionals/training/test_training_integration.py index 605ad5c1..6c150d7a 100644 --- a/tests/functionals/training/test_training_integration.py +++ b/tests/functionals/training/test_training_integration.py @@ -110,7 +110,7 @@ def test_training_integration_3d_gpu(self, mocker, tmp_path): model_name=model_name, in_channels=1, out_channels=1, - feature_maps=16, + feature_maps=[16, 32, 64], patch_size=(16, 64, 64), max_num_iters=100, dimensionality="3D",