Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions panseg/tasks/training_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def unet_training_task(
max_num_iters: int,
dimensionality: Literal["2D", "3D"],
device: str,
layer_order: str,
modality: str = "",
output_type: str = "",
description: str = "",
Expand Down Expand Up @@ -100,6 +101,7 @@ def unet_training_task(
description=description,
resolution=resolution,
pre_trained=pre_trained,
layer_order=layer_order,
)
except RuntimeError as e:
if "Output size is too small" in str(e):
Expand Down
9 changes: 9 additions & 0 deletions panseg/viewer_napari/widgets/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,19 @@ def factory_unet_training(
)
return

layer_order = "bcr"
pre_model_path = None
if pretrained is not None:
model, model_config, pre_model_path = model_zoo.get_model_by_name(
pretrained
)
pretrained_lo = model_config["layer_order"]
if pretrained_lo != layer_order:
logger.debug(
f"Pretrained model has non-default layer_order, changing from {layer_order}"
)
layer_order = pretrained_lo
logger.info(f"Model architecture: {layer_order}")

widgets_to_reset = [
self.widget_unet_training.pretrained,
Expand Down Expand Up @@ -342,6 +350,7 @@ def factory_unet_training(
"description": description,
"resolution": resolution,
"pre_trained": pre_model_path,
"layer_order": layer_order,
"widgets_to_reset": widgets_to_reset,
"_pbar": pbar,
"_to_hide": [self.widget_unet_training.call_button],
Expand Down
5 changes: 4 additions & 1 deletion tests/widgets/test_training_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,11 @@ def test_unet_training_feature_maps(training_tab, mocker, tmp_path):

def test_unet_training_pretrained(training_tab, mocker, tmp_path):
m_log = mocker.patch("panseg.viewer_napari.widgets.training.log")
m_logger = mocker.patch("panseg.viewer_napari.widgets.training.logger")
m_get_models = mocker.patch(
"panseg.viewer_napari.widgets.training.model_zoo.get_model_by_name"
)
m_get_models.return_value = [1, 2, mocker.sentinel]
m_get_models.return_value = ["model", {"layer_order": "gcr"}, mocker.sentinel]
m_schedule = mocker.patch("panseg.viewer_napari.widgets.training.schedule_task")
mocker.patch(
"panseg.viewer_napari.widgets.training.PATH_PANSEG_MODELS", new=tmp_path
Expand Down Expand Up @@ -448,6 +449,8 @@ def test_unet_training_pretrained(training_tab, mocker, tmp_path):
pbar=None,
)

m_logger.info.assert_called_once()
m_logger.debug.assert_called_once()
m_log.assert_called_once()
m_get_models.assert_called_with("SOMETHING")
m_schedule.assert_called_once()
Expand Down
Loading