Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions micro_sam/_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ def check_layer_initialization(viewer, expected_shape):

assert len(viewer.layers) == 6
expected_layer_names = [
"image", "auto_segmentation", "committed_objects", "current_object", "point_prompts", "prompts"
"image", "auto_segmentation", "committed_objects", "current_object", "points", "geometry"
]

for layer_name in expected_layer_names:
assert layer_name in viewer.layers

# Check prompt layers
assert viewer.layers["prompts"].data == [] # shape data is list, not numpy array
np.testing.assert_equal(viewer.layers["point_prompts"].data, 0)
assert viewer.layers["geometry"].data == [] # shape data is list, not numpy array
np.testing.assert_equal(viewer.layers["points"].data, 0)

# Check segmentation layers.
for layer_name in ["auto_segmentation", "committed_objects", "current_object"]:
Expand Down
31 changes: 20 additions & 11 deletions micro_sam/sam_annotator/_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def _require_layers(self, layer_choices: Optional[List[str]] = None):

# Add the point layer for point prompts.
self._point_labels = ["positive", "negative"]
if "point_prompts" in self._viewer.layers:
self._point_prompt_layer = self._viewer.layers["point_prompts"]
if "points" in self._viewer.layers:
self._point_prompt_layer = self._viewer.layers["points"]
else:
self._point_prompt_layer = self._viewer.add_points(
name="point_prompts",
name="points",
property_choices={"label": self._point_labels},
border_color="label",
border_color_cycle=vutil.LABEL_COLOR_CYCLE,
Expand All @@ -89,13 +89,13 @@ def _require_layers(self, layer_choices: Optional[List[str]] = None):
)
self._point_prompt_layer.border_color_mode = "cycle"

if "prompts" not in self._viewer.layers:
if "geometry" not in self._viewer.layers:
# Add the shape layer for box and other shape prompts.
self._viewer.add_shapes(
face_color="transparent",
edge_color="green",
edge_width=4,
name="prompts",
name="geometry",
ndim=self._ndim,
)

Expand All @@ -106,7 +106,7 @@ def _get_widgets(self):
)

def _create_embedding_widget(self):
return widgets.EmbeddingWidget()
return widgets.EmbeddingWidget(viewer=self._viewer, roi_selection=True)

def _create_widgets(self):
# Create the embedding widget and connect all events related to it.
Expand Down Expand Up @@ -140,8 +140,8 @@ def _segment(viewer):
# Note: we also need to over-write the keybindings for specific layers.
# See https://github.com/napari/napari/issues/7302 for details.
# Here, we need to over-write the 's' keybinding for both of the prompt layers.
prompt_layer = self._viewer.layers["prompts"]
point_prompt_layer = self._viewer.layers["point_prompts"]
prompt_layer = self._viewer.layers["geometry"]
point_prompt_layer = self._viewer.layers["points"]

@prompt_layer.bind_key("s", overwrite=True)
def _segment_prompts(event):
Expand Down Expand Up @@ -274,7 +274,7 @@ def _rebuild_for_ndim(self, ndim, force=False):
self._shape = PLACEHOLDER_SHAPE[ndim]

# Remove the existing micro_sam layers so they are recreated with the new ndim and shape.
layer_names = ("current_object", "auto_segmentation", "committed_objects", "point_prompts", "prompts")
layer_names = ("current_object", "auto_segmentation", "committed_objects", "points", "geometry")
for layer_name in layer_names:
if layer_name in self._viewer.layers:
del self._viewer.layers[layer_name]
Expand Down Expand Up @@ -319,6 +319,7 @@ def _update_image(self, segmentation_result=None):

# Update the image scale.
scale = state.image_scale
translate = state.image_translate

# Reset all layers.
self._viewer.layers["current_object"].data = np.zeros(
Expand All @@ -339,8 +340,16 @@ def _update_image(self, segmentation_result=None):
self._viewer.layers["committed_objects"].data = segmentation_result
self._viewer.layers["committed_objects"].scale = scale

self._viewer.layers["point_prompts"].scale = scale
self._viewer.layers["prompts"].scale = scale
self._viewer.layers["points"].scale = scale
self._viewer.layers["geometry"].scale = scale

# Keep cropped annotation layers aligned with their location in the full image. The image
# itself remains unchanged and visible; only the embeddings and result arrays are cropped.
if translate is not None:
for layer_name in (
"current_object", "auto_segmentation", "committed_objects", "points", "geometry"
):
self._viewer.layers[layer_name].translate = translate

vutil.clear_annotations(self._viewer, clear_segmentations=False)

Expand Down
2 changes: 2 additions & 0 deletions micro_sam/sam_annotator/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class AnnotatorState(metaclass=Singleton):
predictor: Optional[SamPredictor] = None
image_shape: Optional[Tuple[int, int]] = None
image_scale: Optional[Tuple[float, ...]] = None
image_translate: Optional[Tuple[float, ...]] = None
ndim: Optional[int] = None
image_name: Optional[str] = None
embedding_path: Optional[str] = None
Expand Down Expand Up @@ -363,6 +364,7 @@ def reset_state(self):
self.predictor = None
self.image_shape = None
self.image_scale = None
self.image_translate = None
self.ndim = None
self.image_name = None
self.embedding_path = None
Expand Down
1 change: 1 addition & 0 deletions micro_sam/sam_annotator/_tooltips.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"model_family": "Select the segment anything 2 model family.",
"model_family_advanced": "Select the advanced (non-SAM2) model family, e.g. a SAM1 family. Switched on via 'Advanced Models' in the embedding settings.", # noqa
"model_size": "Select the image encoder size of the segment anything 2 model.",
"region": "Compute embeddings for the full image or for one rectangle selected in the 'geometry' layer. The ROI becomes a new selected image layer while the source stays open. For 3D volumes and timeseries, the rectangle crops Y/X across all slices or frames.", # noqa
"advanced_model": "Switch the model list above to advanced models beyond the default SAM2 models (currently SAM1). Only available for the classification tools.", # noqa
"automatic_segmentation_mode": "Select the automatic segmentation mode.",
"run_button": "Compute embeddings or load embeddings if embedding_save_path is specified.",
Expand Down
Loading
Loading