diff --git a/panseg/tasks/training_tasks.py b/panseg/tasks/training_tasks.py index 4494f123..8903f559 100644 --- a/panseg/tasks/training_tasks.py +++ b/panseg/tasks/training_tasks.py @@ -24,6 +24,7 @@ def unet_training_task( max_num_iters: int, dimensionality: Literal["2D", "3D"], device: str, + layer_order: str, modality: str = "", output_type: str = "", description: str = "", @@ -100,6 +101,7 @@ def unet_training_task( description=description, resolution=resolution, pre_trained=pre_trained, + layer_order=layer_order, ) except RuntimeError as e: if "Output size is too small" in str(e): diff --git a/panseg/viewer_napari/widgets/training.py b/panseg/viewer_napari/widgets/training.py index 238cf351..bde6f5b8 100644 --- a/panseg/viewer_napari/widgets/training.py +++ b/panseg/viewer_napari/widgets/training.py @@ -304,11 +304,19 @@ def factory_unet_training( ) return + layer_order = "bcr" pre_model_path = None if pretrained is not None: model, model_config, pre_model_path = model_zoo.get_model_by_name( pretrained ) + pretrained_lo = model_config["layer_order"] + if pretrained_lo != layer_order: + logger.debug( + f"Pretrained model has non-default layer_order, changing from {layer_order}" + ) + layer_order = pretrained_lo + logger.info(f"Model architecture: {layer_order}") widgets_to_reset = [ self.widget_unet_training.pretrained, @@ -342,6 +350,7 @@ def factory_unet_training( "description": description, "resolution": resolution, "pre_trained": pre_model_path, + "layer_order": layer_order, "widgets_to_reset": widgets_to_reset, "_pbar": pbar, "_to_hide": [self.widget_unet_training.call_button], diff --git a/tests/widgets/test_training_widget.py b/tests/widgets/test_training_widget.py index aee8ec9a..6796ffc3 100644 --- a/tests/widgets/test_training_widget.py +++ b/tests/widgets/test_training_widget.py @@ -417,10 +417,11 @@ def test_unet_training_feature_maps(training_tab, mocker, tmp_path): def test_unet_training_pretrained(training_tab, mocker, tmp_path): m_log = mocker.patch("panseg.viewer_napari.widgets.training.log") + m_logger = mocker.patch("panseg.viewer_napari.widgets.training.logger") m_get_models = mocker.patch( "panseg.viewer_napari.widgets.training.model_zoo.get_model_by_name" ) - m_get_models.return_value = [1, 2, mocker.sentinel] + m_get_models.return_value = ["model", {"layer_order": "gcr"}, mocker.sentinel] m_schedule = mocker.patch("panseg.viewer_napari.widgets.training.schedule_task") mocker.patch( "panseg.viewer_napari.widgets.training.PATH_PANSEG_MODELS", new=tmp_path @@ -448,6 +449,8 @@ def test_unet_training_pretrained(training_tab, mocker, tmp_path): pbar=None, ) + m_logger.info.assert_called_once() + m_logger.debug.assert_called_once() m_log.assert_called_once() m_get_models.assert_called_with("SOMETHING") m_schedule.assert_called_once()