Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion panseg/functionals/prediction/utils/size_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def find_batch_size(
model = model.to(device)
model.eval()
with torch.no_grad():
for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
x = None
try:
x = torch.randn((batch_size, in_channels) + actual_patch_shape).to(
Expand Down
2 changes: 2 additions & 0 deletions panseg/functionals/training/biio.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def make_model_description(
feature_maps: int | list[int] | tuple[int, ...],
patch_size: tuple[int, int, int],
dimensionality: Literal["2D", "3D"],
layer_order: str,
modality: str,
output_type: str,
description: str,
Expand Down Expand Up @@ -165,6 +166,7 @@ def make_model_description(
"in_channels": in_channels,
"out_channels": out_channels,
"f_maps": feature_maps,
"layer_order": layer_order,
},
)

Expand Down
21 changes: 18 additions & 3 deletions panseg/functionals/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
PATH_TRAIN_TEMPLATE,
)
from panseg.core.zoo import model_zoo
from panseg.functionals.prediction.utils.size_finder import find_batch_size
from panseg.functionals.training.augs import Augmenter
from panseg.functionals.training.biio import make_model_description
from panseg.functionals.training.h5dataset import HDF5Dataset
Expand All @@ -33,6 +34,7 @@ def create_model_config(
out_channels,
patch_size,
dimensionality: Literal["2D", "3D"],
layer_order: str,
sparse,
f_maps,
max_num_iters,
Expand All @@ -46,6 +48,7 @@ def create_model_config(

train_template["model"]["in_channels"] = in_channels
train_template["model"]["out_channels"] = out_channels
train_template["model"]["layer_order"] = layer_order
train_template["model"]["f_maps"] = f_maps
if dimensionality in ["2D", "2d", "2"]:
train_template["model"]["name"] = "UNet2D"
Expand Down Expand Up @@ -87,6 +90,7 @@ def unet_training(
description: str = "",
resolution: tuple[float, float, float] = (1.0, 1.0, 1.0),
pre_trained: Optional[Path] = None,
layer_order: str = "bcr",
) -> None:
"""
Main entry point for training a new UNet model. Called when `panseg --train` is run from the CLI.
Expand All @@ -99,20 +103,29 @@ def unet_training(
out_channels=out_channels,
f_maps=feature_maps,
final_sigmoid=final_sigmoid,
layer_order=layer_order,
)
elif dimensionality in ["3D", "3d", "3"]:
model = UNet3D(
in_channels=in_channels,
out_channels=out_channels,
f_maps=feature_maps,
final_sigmoid=final_sigmoid,
layer_order=layer_order,
)
else:
raise ValueError(f"Unknown dimensionality {dimensionality}")
logger.info(f"Using {model.__class__.__name__} model for training.")

# Device configuration
batch_size = 1
batch_size = find_batch_size(
model=model,
in_channels=in_channels,
patch_shape=patch_size,
patch_halo=(4, 4, 4), # some slack
device=device,
)

if torch.cuda.device_count() > 1 and device != "cpu":
model = nn.DataParallel(model)
logger.info(f"Using {torch.cuda.device_count()} GPUs for prediction.")
Expand All @@ -132,7 +145,7 @@ def unet_training(
batch_size=batch_size,
shuffle=True,
pin_memory=True,
num_workers=1,
num_workers=4,
)
}
if len(val_datasets) > 0:
Expand All @@ -141,7 +154,7 @@ def unet_training(
batch_size=batch_size,
shuffle=False,
pin_memory=True,
num_workers=1,
num_workers=4,
)
else:
loaders["val"] = []
Expand All @@ -160,6 +173,7 @@ def unet_training(
out_channels,
patch_size,
dimensionality,
layer_order,
sparse,
feature_maps,
max_num_iters,
Expand Down Expand Up @@ -219,6 +233,7 @@ def unet_training(
feature_maps=feature_maps,
patch_size=patch_size,
dimensionality=dimensionality,
layer_order=layer_order,
modality=modality,
output_type=output_type,
description=description,
Expand Down
1 change: 1 addition & 0 deletions tests/functionals/training/test_biio.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_make_model_description(tmp_path):
feature_maps=64,
patch_size=(16, 32, 64),
dimensionality="3D",
layer_order="bcr",
modality="mod",
output_type="boundaries",
description="dummy model",
Expand Down
6 changes: 6 additions & 0 deletions tests/functionals/training/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test_create_model_config_2d(self):
out_channels=2,
patch_size=[64, 64],
dimensionality="2D",
layer_order="bcr",
sparse=False,
f_maps=[16, 32, 64],
max_num_iters=1000,
Expand Down Expand Up @@ -114,6 +115,7 @@ def test_create_model_config_3d(self):
out_channels=3,
patch_size=[32, 64, 64],
dimensionality="3D",
layer_order="bcr",
sparse=True,
f_maps=[8, 16, 32],
max_num_iters=2000,
Expand Down Expand Up @@ -494,8 +496,10 @@ def test_unet_training_with_existing_checkpoint_dir(
@patch("torch.nn.DataParallel")
@patch("panseg.functionals.training.train.Adam")
@patch("panseg.functionals.training.train.ReduceLROnPlateau")
@patch("panseg.functionals.training.train.find_batch_size")
def test_unet_training_multi_gpu(
self,
mock_find_batch_size,
mock_reduce_lr,
mock_adam,
mock_data_parallel,
Expand Down Expand Up @@ -541,6 +545,8 @@ def test_unet_training_multi_gpu(

mock_isinstance.return_value = False

mock_find_batch_size.return_value = 1

# Create a temporary dataset directory
dataset_dir = tmp_path / "dataset"
dataset_dir.mkdir()
Expand Down
2 changes: 1 addition & 1 deletion tests/functionals/training/test_training_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_training_integration_3d_gpu(self, mocker, tmp_path):
model_name=model_name,
in_channels=1,
out_channels=1,
feature_maps=16,
feature_maps=[16, 32, 64],
patch_size=(16, 64, 64),
max_num_iters=100,
dimensionality="3D",
Expand Down
Loading