diff --git a/micro_sam/_test_util.py b/micro_sam/_test_util.py index 03a29e4a..618f0cd1 100644 --- a/micro_sam/_test_util.py +++ b/micro_sam/_test_util.py @@ -6,15 +6,15 @@ def check_layer_initialization(viewer, expected_shape): assert len(viewer.layers) == 6 expected_layer_names = [ - "image", "auto_segmentation", "committed_objects", "current_object", "point_prompts", "prompts" + "image", "auto_segmentation", "committed_objects", "current_object", "points", "geometry" ] for layer_name in expected_layer_names: assert layer_name in viewer.layers # Check prompt layers - assert viewer.layers["prompts"].data == [] # shape data is list, not numpy array - np.testing.assert_equal(viewer.layers["point_prompts"].data, 0) + assert viewer.layers["geometry"].data == [] # shape data is list, not numpy array + np.testing.assert_equal(viewer.layers["points"].data, 0) # Check segmentation layers. for layer_name in ["auto_segmentation", "committed_objects", "current_object"]: diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py index b92b04c0..9260de0c 100644 --- a/micro_sam/sam_annotator/_annotator.py +++ b/micro_sam/sam_annotator/_annotator.py @@ -73,11 +73,11 @@ def _require_layers(self, layer_choices: Optional[List[str]] = None): # Add the point layer for point prompts. self._point_labels = ["positive", "negative"] - if "point_prompts" in self._viewer.layers: - self._point_prompt_layer = self._viewer.layers["point_prompts"] + if "points" in self._viewer.layers: + self._point_prompt_layer = self._viewer.layers["points"] else: self._point_prompt_layer = self._viewer.add_points( - name="point_prompts", + name="points", property_choices={"label": self._point_labels}, border_color="label", border_color_cycle=vutil.LABEL_COLOR_CYCLE, @@ -89,13 +89,13 @@ def _require_layers(self, layer_choices: Optional[List[str]] = None): ) self._point_prompt_layer.border_color_mode = "cycle" - if "prompts" not in self._viewer.layers: + if "geometry" not in self._viewer.layers: # Add the shape layer for box and other shape prompts. self._viewer.add_shapes( face_color="transparent", edge_color="green", edge_width=4, - name="prompts", + name="geometry", ndim=self._ndim, ) @@ -106,7 +106,7 @@ def _get_widgets(self): ) def _create_embedding_widget(self): - return widgets.EmbeddingWidget() + return widgets.EmbeddingWidget(viewer=self._viewer, roi_selection=True) def _create_widgets(self): # Create the embedding widget and connect all events related to it. @@ -140,8 +140,8 @@ def _segment(viewer): # Note: we also need to over-write the keybindings for specific layers. # See https://github.com/napari/napari/issues/7302 for details. # Here, we need to over-write the 's' keybinding for both of the prompt layers. - prompt_layer = self._viewer.layers["prompts"] - point_prompt_layer = self._viewer.layers["point_prompts"] + prompt_layer = self._viewer.layers["geometry"] + point_prompt_layer = self._viewer.layers["points"] @prompt_layer.bind_key("s", overwrite=True) def _segment_prompts(event): @@ -274,7 +274,7 @@ def _rebuild_for_ndim(self, ndim, force=False): self._shape = PLACEHOLDER_SHAPE[ndim] # Remove the existing micro_sam layers so they are recreated with the new ndim and shape. - layer_names = ("current_object", "auto_segmentation", "committed_objects", "point_prompts", "prompts") + layer_names = ("current_object", "auto_segmentation", "committed_objects", "points", "geometry") for layer_name in layer_names: if layer_name in self._viewer.layers: del self._viewer.layers[layer_name] @@ -319,6 +319,7 @@ def _update_image(self, segmentation_result=None): # Update the image scale. scale = state.image_scale + translate = state.image_translate # Reset all layers. self._viewer.layers["current_object"].data = np.zeros( @@ -339,8 +340,16 @@ def _update_image(self, segmentation_result=None): self._viewer.layers["committed_objects"].data = segmentation_result self._viewer.layers["committed_objects"].scale = scale - self._viewer.layers["point_prompts"].scale = scale - self._viewer.layers["prompts"].scale = scale + self._viewer.layers["points"].scale = scale + self._viewer.layers["geometry"].scale = scale + + # Keep cropped annotation layers aligned with their location in the full image. The image + # itself remains unchanged and visible; only the embeddings and result arrays are cropped. + if translate is not None: + for layer_name in ( + "current_object", "auto_segmentation", "committed_objects", "points", "geometry" + ): + self._viewer.layers[layer_name].translate = translate vutil.clear_annotations(self._viewer, clear_segmentations=False) diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index 59971590..3779a5b9 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -89,6 +89,7 @@ class AnnotatorState(metaclass=Singleton): predictor: Optional[SamPredictor] = None image_shape: Optional[Tuple[int, int]] = None image_scale: Optional[Tuple[float, ...]] = None + image_translate: Optional[Tuple[float, ...]] = None ndim: Optional[int] = None image_name: Optional[str] = None embedding_path: Optional[str] = None @@ -363,6 +364,7 @@ def reset_state(self): self.predictor = None self.image_shape = None self.image_scale = None + self.image_translate = None self.ndim = None self.image_name = None self.embedding_path = None diff --git a/micro_sam/sam_annotator/_tooltips.py b/micro_sam/sam_annotator/_tooltips.py index dd159aa8..2de3d99d 100644 --- a/micro_sam/sam_annotator/_tooltips.py +++ b/micro_sam/sam_annotator/_tooltips.py @@ -10,6 +10,7 @@ "model_family": "Select the segment anything 2 model family.", "model_family_advanced": "Select the advanced (non-SAM2) model family, e.g. a SAM1 family. Switched on via 'Advanced Models' in the embedding settings.", # noqa "model_size": "Select the image encoder size of the segment anything 2 model.", + "region": "Compute embeddings for the full image or for one rectangle selected in the 'geometry' layer. The ROI becomes a new selected image layer while the source stays open. For 3D volumes and timeseries, the rectangle crops Y/X across all slices or frames.", # noqa "advanced_model": "Switch the model list above to advanced models beyond the default SAM2 models (currently SAM1). Only available for the classification tools.", # noqa "automatic_segmentation_mode": "Select the automatic segmentation mode.", "run_button": "Compute embeddings or load embeddings if embedding_save_path is specified.", diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index f935a375..aa2bf05b 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -588,8 +588,8 @@ def _reset_tracking_state(viewer): state.lineage = {1: []} # Reset the layer properties. - viewer.layers["point_prompts"].property_choices["track_id"] = ["1"] - viewer.layers["prompts"].property_choices["track_id"] = ["1"] + viewer.layers["points"].property_choices["track_id"] = ["1"] + viewer.layers["geometry"].property_choices["track_id"] = ["1"] # Reset the choices in the track_id menu (index 2: prompt, track_state, track_id). state.annotator._tracking_widget[2].value = "1" @@ -914,8 +914,8 @@ def write_prompts( ds.attrs["track_state"] = track_state.tolist() # Get the prompts from the layers. - prompts = viewer.layers["prompts"].data - point_layer = viewer.layers["point_prompts"] + prompts = viewer.layers["geometry"].data + point_layer = viewer.layers["points"] point_prompts = point_layer.data point_labels = point_layer.properties["label"] if len(point_prompts) > 0: @@ -935,7 +935,7 @@ def write_prompts( ): # We have multiple objects from tracking a lineage with divisions. track_ids_points = np.array(point_layer.properties["track_id"]) track_ids_prompts = np.array( - viewer.layers["prompts"].properties["track_id"] + viewer.layers["geometry"].properties["track_id"] ) unique_track_ids = np.unique(track_ids_points) @@ -1345,8 +1345,8 @@ def _validate_layers( if not automatic_segmentation: # Check prompts layer. if ( - len(viewer.layers["prompts"].data) == 0 - and len(viewer.layers["point_prompts"].data) == 0 + len(viewer.layers["geometry"].data) == 0 + and len(viewer.layers["points"].data) == 0 ): msg = "No prompts were given. Please provide prompts to run interactive segmentation." return _generate_message("error", msg) @@ -1387,10 +1387,10 @@ def _segment_object_2d(viewer, batched=False): # get the current box and point prompts boxes, masks = vutil.shape_layer_to_prompts( - viewer.layers["prompts"], shape + viewer.layers["geometry"], shape ) points, labels = vutil.point_layer_to_prompts( - viewer.layers["point_prompts"], with_stop_annotation=False + viewer.layers["points"], with_stop_annotation=False ) state = AnnotatorState() @@ -1496,9 +1496,18 @@ def _process_tiling_inputs(tile_shape_x, tile_shape_y, halo_x, halo_y): class EmbeddingWidget(_WidgetBase): - def __init__(self, parent=None, sam2_only=False, ndim_choice=False, is_timeseries=False): + def __init__( + self, parent=None, sam2_only=False, ndim_choice=False, is_timeseries=False, + viewer=None, roi_selection=False, + ): super().__init__(parent=parent) self.sam2_only = sam2_only + self._viewer = viewer + # ROI selection is enabled for segmentation and tracking, which own the 'geometry' layer. + # Classification and image-series launchers reuse this widget without exposing this option. + self.roi_selection = roi_selection + self._last_roi = None + self._last_roi_image = None # Whether to expose the 'image dimensions' (ndim) override dropdown. Only the segmentation # annotator wires it into image normalization, so it is off by default (hidden for tracking # and the classifiers, which do not use it). @@ -1654,6 +1663,14 @@ def _create_settings_widget(self): ) setting_values.layout().addLayout(ndim_layout) + if self.roi_selection: + self.embedding_region = "full image" + self.embedding_region_dropdown, region_layout = self._add_choice_param( + "embedding_region", self.embedding_region, ["full image", "selected ROI"], + title="embedding region:", tooltip=get_tooltip("embedding", "region"), + ) + setting_values.layout().addLayout(region_layout) + # Create UI for tiling. A dropdown toggles whether tiling is used; when enabled, # the tile shape and halo fields are revealed with sensible defaults. self.tiling = "no" @@ -1785,6 +1802,147 @@ def _ndim_override(self): mode = dropdown.currentText() if dropdown is not None else "auto" return {"auto": None, "2d": 2, "3d": 3}.get(mode) + @staticmethod + def _crop_image(image, roi): + if roi is None: + return image.data + index = roi + ((slice(None),) if image.rgb else ()) + return image.data[index] + + @staticmethod + def _roi_from_bounds(bounds, spatial_shape): + """Convert serialized ``[[start, stop], ...]`` bounds to validated slices.""" + if bounds is None or len(bounds) != len(spatial_shape): + return None + roi = [] + for bound, size in zip(bounds, spatial_shape): + if len(bound) != 2: + return None + start, stop = int(bound[0]), int(bound[1]) + if start < 0 or stop > size or start >= stop: + return None + roi.append(slice(start, stop)) + return tuple(roi) + + @staticmethod + def _serialize_roi(roi): + if roi is None: + return None + return [[int(s.start or 0), int(s.stop)] for s in roi] + + def _saved_roi(self, spatial_shape): + if not self.embeddings_save_path or not os.path.isdir(self.embeddings_save_path): + return None + try: + f = zarr.open(self.embeddings_save_path, mode="r") + return self._roi_from_bounds(f.attrs.get("image_roi"), spatial_shape) + except (KeyError, RuntimeError, ValueError): + return None + + def _resolve_roi(self, image): + """Resolve the selected rectangle to an image-data ROI. + + A 3D shape is a rectangle in the current Y/X plane. We intentionally keep the leading + z/time axis complete, so one selection works consistently for volumes and timeseries. + """ + if not self.roi_selection or self.embedding_region_dropdown.currentText() == "full image": + return None + + spatial_shape = tuple(image.data.shape[:-1] if image.rgb else image.data.shape) + geometry = ( + None + if self._viewer is None or "geometry" not in self._viewer.layers + else self._viewer.layers["geometry"] + ) + selected = [] if geometry is None else sorted(geometry.selected_data) + if geometry is not None and len(selected) != 1 and len(geometry.data) == 1: + selected = [0] + + if geometry is not None and len(selected) == 1: + index = selected[0] + if str(geometry.shape_type[index]) != "rectangle": + raise ValueError("The embedding ROI must be a rectangle in the 'geometry' layer.") + + # Convert via world coordinates so non-default layer scale / translation are respected. + vertices = np.asarray(geometry.data[index]) + image_vertices = np.asarray([ + image.world_to_data(geometry.data_to_world(vertex)) for vertex in vertices + ]) + # Snap transform round-off close to integer pixel coordinates before floor / ceil; + # otherwise e.g. 70.00000000001 would add an unintended extra column. + lower = np.floor(np.round(image_vertices[:, -2:].min(axis=0), decimals=6)).astype(int) + upper = np.ceil(np.round(image_vertices[:, -2:].max(axis=0), decimals=6)).astype(int) + lower = np.maximum(lower, 0) + upper = np.minimum(upper, spatial_shape[-2:]) + if np.any(upper <= lower): + raise ValueError("The selected embedding ROI does not overlap the image.") + + roi = [slice(0, size) for size in spatial_shape] + roi[-2:] = [slice(int(lower[0]), int(upper[0])), slice(int(lower[1]), int(upper[1]))] + return tuple(roi) + + # Re-use the last crop after its defining rectangle was cleared on a successful compute. + if self._last_roi_image is image and self._last_roi is not None: + return self._last_roi + + # This also makes reopening a cached ROI embedding possible without redrawing its rectangle. + saved_roi = self._saved_roi(spatial_shape) + if saved_roi is not None: + return saved_roi + + raise ValueError( + "Select one rectangle in the 'geometry' layer before computing embeddings for an ROI." + ) + + def _create_roi_image_layer(self, image, roi): + """Create and select a napari image layer for an embedding ROI.""" + if self._viewer is None: + raise RuntimeError("Creating an ROI image layer requires a napari viewer.") + + data = self._crop_image(image, roi) + offset = np.array([s.start or 0 for s in roi], dtype=float) + translate = tuple(image.data_to_world(offset)) + + base_name = f"{image.name} ROI" + name, index = base_name, 2 + while name in self._viewer.layers: + name = f"{base_name} {index}" + index += 1 + + metadata = { + "micro_sam_roi_source": image.name, + "micro_sam_roi": self._serialize_roi(roi), + } + annotator = AnnotatorState().annotator + suppress = annotator is not None and getattr(annotator, "_viewer", None) is self._viewer + if suppress: + annotator._suppress_selection_rebuild = True + try: + roi_layer = self._viewer.add_image( + data, name=name, rgb=image.rgb, scale=tuple(image.scale), translate=translate, + metadata=metadata, + ) + roi_layer.contrast_limits = image.contrast_limits + roi_layer.gamma = image.gamma + if not image.rgb: + roi_layer.colormap = image.colormap + # Hide the source so the crop is the only image shown; it stays in the viewer. + image.visible = False + # Move the crop to the bottom so the prompt / segmentation layers render on top of it. + self._viewer.layers.move(self._viewer.layers.index(roi_layer), 0) + self.image_selection.value = roi_layer + self._viewer.layers.selection.active = roi_layer + if suppress: + # The selection-change callback was intentionally suppressed while materializing the + # crop. Record it as current so selecting another source/crop later performs a reset. + annotator._last_image_layer = roi_layer + finally: + if suppress: + annotator._suppress_selection_rebuild = False + + self.embedding_region_dropdown.setCurrentText("full image") + return roi_layer + def _update_tiling_visibility(self, index=None): # Show the in-plane tile shape and halo fields only when tiling is enabled. self.tiling = self.tiling_dropdown.currentText() @@ -1849,6 +2007,11 @@ def _reset_inputs_to_defaults(self): self.image_ndim_mode = "auto" self.image_ndim_dropdown.blockSignals(False) + if self.roi_selection: + self.embedding_region_dropdown.setCurrentText("full image") + self._last_roi = None + self._last_roi_image = None + self._set_default_tiling() def _validate_inputs(self): @@ -1877,6 +2040,11 @@ def _validate_inputs(self): if image is None: return _generate_message("error", "No image has been selected.") + try: + roi = self._resolve_roi(image) + except ValueError as e: + return _generate_message("error", str(e)) + # Check if we have an existing embedding path. # If yes we check the data signature of these embeddings against the selected image # and we ask the user if they want to load these embeddings. @@ -1899,8 +2067,7 @@ def _validate_inputs(self): # Validate image data signature. if "data_signature" in f.attrs: - image = self.image_selection.get_value() - img_signature = util._compute_data_signature(image.data) + img_signature = util._compute_data_signature(self._crop_image(image, roi)) if img_signature != f.attrs["data_signature"]: msg = f"The embeddings don't match with the image: {img_signature} {f.attrs['data_signature']}" return _generate_message("error", msg) @@ -1992,10 +2159,18 @@ def __call__(self, skip_validate=False): # Validate user inputs. if not skip_validate and self._validate_inputs(): + # The annotator's layer-update slot is connected to the same button. Keep it from + # clearing the ROI / annotations when embedding validation aborted the computation. + AnnotatorState().skip_recomputing_embeddings = True return # Get the image. image = self.image_selection.get_value() + try: + roi = self._resolve_roi(image) + except ValueError as e: + AnnotatorState().skip_recomputing_embeddings = True + return _generate_message("error", str(e)) # Update the image embeddings: state = AnnotatorState() @@ -2008,17 +2183,25 @@ def __call__(self, skip_validate=False): # Reset the state. state.reset_state() - # Get image dimensions. + source_image, source_roi = image, roi + if roi is not None: + image = self._create_roi_image_layer(image, roi) + image_data = image.data + + # Get image dimensions. The selected ROI only crops Y/X and keeps z/time complete. if image.rgb: - ndim = image.data.ndim - 1 - state.image_shape = image.data.shape[:-1] + ndim = image_data.ndim - 1 + state.image_shape = image_data.shape[:-1] else: - ndim = image.data.ndim - state.image_shape = image.data.shape + ndim = image_data.ndim + state.image_shape = image_data.shape state.ndim = ndim - # Set layer scale + # Annotation arrays use local coordinates of the selected image layer. ROI layers carry a + # world-space translation so results overlay the corresponding location in the source image. state.image_scale = tuple(image.scale) + state.image_translate = tuple(image.translate) + state.image_name = image.name # Process tile_shape and halo, set other data. Tiling is only applied when enabled. if self.tiling == "yes": @@ -2041,8 +2224,6 @@ def __call__(self, skip_validate=False): if self.embeddings_save_path == "" else self.embeddings_save_path ) - image_data = image.data - # Set up progress bar and signals for using it within a threadworker. pbar, pbar_signals = _create_pbar_for_threadworker() @@ -2079,6 +2260,19 @@ def pbar_update(update): pbar_signals.pbar_stop.emit() compute_image_embedding() + self._last_roi = source_roi + self._last_roi_image = source_image + + # Record the crop location alongside cached embeddings. The embedding backend already stores + # the crop's data signature and shape; this attribute restores its placement in the full image. + if isinstance(save_path, str): + f = zarr.open(save_path, mode="a") + if source_roi is None: + f.attrs.pop("image_roi", None) + f.attrs.pop("image_roi_source", None) + else: + f.attrs["image_roi"] = self._serialize_roi(source_roi) + f.attrs["image_roi_source"] = source_image.name self._update_model(state) # worker = compute_image_embedding() # worker.returned.connect(self._update_model) @@ -2347,8 +2541,8 @@ def _update_lineage(viewer, mother=None): track_ids = list(map(str, state.lineage.keys())) tracking_widget[2].choices = track_ids - viewer.layers["point_prompts"].property_choices["track_id"] = list(track_ids) - viewer.layers["prompts"].property_choices["track_id"] = list(track_ids) + viewer.layers["points"].property_choices["track_id"] = list(track_ids) + viewer.layers["geometry"].property_choices["track_id"] = list(track_ids) class UnifiedSegmentWidget(_WidgetBase): @@ -2543,20 +2737,20 @@ def _run_slice_segmentation(self): shape = self._viewer.layers["current_object"].data.shape[1:] position_world = self._viewer.dims.point - position = self._viewer.layers["point_prompts"].world_to_data( + position = self._viewer.layers["points"].world_to_data( position_world ) z = int(position[0]) point_prompts = vutil.point_layer_to_prompts( - self._viewer.layers["point_prompts"], z + self._viewer.layers["points"], z ) # this is a stop prompt, we do nothing if not point_prompts: return boxes, masks = vutil.shape_layer_to_prompts( - self._viewer.layers["prompts"], shape, i=z + self._viewer.layers["geometry"], shape, i=z ) points, labels = point_prompts @@ -2644,14 +2838,14 @@ def _segment_slice_batched(self, z, points, labels, boxes, shape): def _segment_track_on_frame(self, state, t, track_id, shape): """Segment a single track's object on frame 't'. Returns the binary mask or None.""" point_prompts = vutil.point_layer_to_prompts( - self._viewer.layers["point_prompts"], i=t, track_id=track_id, + self._viewer.layers["points"], i=t, track_id=track_id, ) # A single negative point is a stop prompt: nothing to segment for this track here. if not point_prompts: return None boxes, masks = vutil.shape_layer_to_prompts( - self._viewer.layers["prompts"], shape, i=t, track_id=track_id, + self._viewer.layers["geometry"], shape, i=t, track_id=track_id, ) points, labels = point_prompts @@ -2708,8 +2902,8 @@ def volumetric_segmentation_impl(): if state.is_sam2: # Prepare the prompts - point_prompts = self._viewer.layers["point_prompts"] - box_prompts = self._viewer.layers["prompts"] + point_prompts = self._viewer.layers["points"] + box_prompts = self._viewer.layers["geometry"] z_values_points = np.round(point_prompts.data[:, 0]) z_values_boxes = ( np.concatenate([box[:1, 0] for box in box_prompts.data]) @@ -2784,8 +2978,8 @@ def volumetric_segmentation_impl(): seg, slices, stop_lower, stop_upper = ( vutil.segment_slices_with_prompts( state.predictor, - self._viewer.layers["point_prompts"], - self._viewer.layers["prompts"], + self._viewer.layers["points"], + self._viewer.layers["geometry"], state.image_embeddings, shape, update_progress=emit_progress, @@ -2848,8 +3042,8 @@ def propagate_track(track_id, division_frame): # frames. A frame whose only prompt for this track is a single negative point is a # 'stop' annotation; a stop on the highest annotated frame bounds propagation above. shape = state.image_shape - point_layer = self._viewer.layers["point_prompts"] - box_layer = self._viewer.layers["prompts"] + point_layer = self._viewer.layers["points"] + box_layer = self._viewer.layers["geometry"] # Reset so a re-run does not accumulate prompts from a previous propagation. state.interactive_segmenter.reset_predictor() @@ -2920,7 +3114,7 @@ def propagate_track(track_id, division_frame): def tracking_impl(): # Propagate the current track. Its propagated mask is labelled with its track id. track_ids = [state.current_track_id] - point_layer = self._viewer.layers["point_prompts"] + point_layer = self._viewer.layers["points"] seg_layer = self._viewer.layers["current_object"] results = {} for track_id in track_ids: diff --git a/micro_sam/sam_annotator/annotator.py b/micro_sam/sam_annotator/annotator.py index ffbe54e5..19077aac 100644 --- a/micro_sam/sam_annotator/annotator.py +++ b/micro_sam/sam_annotator/annotator.py @@ -87,7 +87,7 @@ class Annotator(_AnnotatorBase): def _create_embedding_widget(self): # Expose the 'image dimensions' (ndim) override here: the segmentation annotator is the only # one that wires it into image normalization (it handles both 2d and 3d data). - return widgets.EmbeddingWidget(ndim_choice=True) + return widgets.EmbeddingWidget(viewer=self._viewer, ndim_choice=True, roi_selection=True) def _get_widgets(self): """Create the widgets for the segmentation annotator. @@ -122,8 +122,8 @@ def _segment(viewer): # We also need to over-write the keybindings for the prompt layers. # See https://github.com/napari/napari/issues/7302 for details. - prompt_layer = self._viewer.layers["prompts"] - point_prompt_layer = self._viewer.layers["point_prompts"] + prompt_layer = self._viewer.layers["geometry"] + point_prompt_layer = self._viewer.layers["points"] @prompt_layer.bind_key("s", overwrite=True) def _segment_prompts(event): @@ -214,6 +214,9 @@ def _on_image_selection_changed(self, *args): show_info(str(e)) return + # Re-show the selected image: an ROI crop hides its source, so re-selecting it unhides it. + image_layer.visible = True + # Detect an actual change of the selected image, tracked by layer identity (the state's # 'image_name' is not reliably set on every code path, so we don't depend on it). The first # call (during setup) just records the image and does not reset; a later switch to a diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py index 09f76c8a..696865a0 100644 --- a/micro_sam/sam_annotator/annotator_tracking.py +++ b/micro_sam/sam_annotator/annotator_tracking.py @@ -155,7 +155,9 @@ def track_id_changed_boxes(new_track_id): class AnnotatorTracking(_AnnotatorBase): def _create_embedding_widget(self): - return widgets.EmbeddingWidget(sam2_only=True, is_timeseries=True) + return widgets.EmbeddingWidget( + viewer=getattr(self, "_viewer", None), sam2_only=True, is_timeseries=True, roi_selection=True + ) # The tracking annotator needs different settings for the prompt layers # to support the additional tracking state. @@ -214,18 +216,18 @@ def _require_layers(self, layer_choices: Optional[List[str]] = None): } point_layer_mismatch = True - if "point_prompts" in self._viewer.layers: + if "points" in self._viewer.layers: # Check whether the 'property_choices' match or not. curr_property_choices = self._viewer.layers[ - "point_prompts" + "points" ].property_choices point_layer_mismatch = set(curr_property_choices.keys()) != set( _point_prompt_property_choices.keys() ) - if point_layer_mismatch and "point_prompts" not in self._viewer.layers: + if point_layer_mismatch and "points" not in self._viewer.layers: self._point_prompt_layer = self._viewer.add_points( - name="point_prompts", + name="points", property_choices=_point_prompt_property_choices, border_color="label", border_color_cycle=vutil.LABEL_COLOR_CYCLE, @@ -240,23 +242,23 @@ def _require_layers(self, layer_choices: Optional[List[str]] = None): self._point_prompt_layer.face_color_mode = "cycle" _new_point_layer = True else: - self._point_prompt_layer = self._viewer.layers["point_prompts"] + self._point_prompt_layer = self._viewer.layers["points"] _new_point_layer = False # Add the point prompts layer. _box_prompt_property_choices = {"track_id": ["1"]} box_layer_mismatch = True - if "prompts" in self._viewer.layers: + if "geometry" in self._viewer.layers: # Check whether the 'property_choices' match or not. curr_property_choices = self._viewer.layers[ - "prompts" + "geometry" ].property_choices box_layer_mismatch = set(curr_property_choices.keys()) != set( _box_prompt_property_choices.keys() ) - if box_layer_mismatch and "prompts" not in self._viewer.layers: + if box_layer_mismatch and "geometry" not in self._viewer.layers: # Using the box layer to set divisions currently doesn't work. # That's why some of the code below is commented out. self._box_prompt_layer = self._viewer.add_shapes( @@ -264,7 +266,7 @@ def _require_layers(self, layer_choices: Optional[List[str]] = None): edge_width=4, ndim=self._ndim, face_color="transparent", - name="prompts", + name="geometry", edge_color="green", property_choices=_box_prompt_property_choices, # property_choices={"track_id": ["1"], "state": self._track_state_labels}, @@ -273,7 +275,7 @@ def _require_layers(self, layer_choices: Optional[List[str]] = None): # self._box_prompt_layer.edge_color_mode = "cycle" _new_box_layer = True else: - self._box_prompt_layer = self._viewer.layers["prompts"] + self._box_prompt_layer = self._viewer.layers["geometry"] _new_box_layer = False # Trigger a new connection for the tracking state menu only when a new layer is (re)created. @@ -325,8 +327,8 @@ def _segment(viewer): # We also need to over-write the keybindings for the prompt layers. # See https://github.com/napari/napari/issues/7302 for details. - prompt_layer = self._viewer.layers["prompts"] - point_prompt_layer = self._viewer.layers["point_prompts"] + prompt_layer = self._viewer.layers["geometry"] + point_prompt_layer = self._viewer.layers["points"] @prompt_layer.bind_key("s", overwrite=True) def _segment_prompts(event): diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index d23eaa33..e471239e 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -211,15 +211,15 @@ def _initialize_parser(description, with_segmentation_result=True, with_instance def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None: """@private""" - viewer.layers["point_prompts"].data = [] - viewer.layers["point_prompts"].refresh() - if "prompts" in viewer.layers: + viewer.layers["points"].data = [] + viewer.layers["points"].refresh() + if "geometry" in viewer.layers: # Select all prompts and then remove them. # This is how it worked before napari 0.5. - # viewer.layers["prompts"].data = [] - viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data))) - viewer.layers["prompts"].remove_selected() - viewer.layers["prompts"].refresh() + # viewer.layers["geometry"].data = [] + viewer.layers["geometry"].selected_data = set(range(len(viewer.layers["geometry"].data))) + viewer.layers["geometry"].remove_selected() + viewer.layers["geometry"].refresh() if not clear_segmentations: return viewer.layers["current_object"].data = np.zeros(viewer.layers["current_object"].data.shape, dtype="uint32") @@ -228,15 +228,15 @@ def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None: def clear_annotations_slice(viewer: napari.Viewer, i: int, clear_segmentations=True) -> None: """@private""" - point_prompts = viewer.layers["point_prompts"].data + point_prompts = viewer.layers["points"].data point_prompts = point_prompts[point_prompts[:, 0] != i] - viewer.layers["point_prompts"].data = point_prompts - viewer.layers["point_prompts"].refresh() - if "prompts" in viewer.layers: - prompts = viewer.layers["prompts"].data + viewer.layers["points"].data = point_prompts + viewer.layers["points"].refresh() + if "geometry" in viewer.layers: + prompts = viewer.layers["geometry"].data prompts = [prompt for prompt in prompts if not (prompt[:, 0] == i).all()] - viewer.layers["prompts"].data = prompts - viewer.layers["prompts"].refresh() + viewer.layers["geometry"].data = prompts + viewer.layers["geometry"].refresh() if not clear_segmentations: return viewer.layers["current_object"].data[i] = 0 diff --git a/test/test_sam_annotator/test_annotator.py b/test/test_sam_annotator/test_annotator.py index 40f6b477..3c426216 100644 --- a/test/test_sam_annotator/test_annotator.py +++ b/test/test_sam_annotator/test_annotator.py @@ -121,6 +121,15 @@ def test_widget_no_image_defaults_to_2d(self, make_napari_viewer_proxy): assert widget._ndim == 2 viewer.close() + def test_shape_prompt_layer_is_named_geometry(self, make_napari_viewer_proxy): + viewer = make_napari_viewer_proxy() + widget = Annotator(viewer) + + assert "geometry" in viewer.layers + assert "prompts" not in viewer.layers + assert widget._embedding_widget.roi_selection + viewer.close() + def test_widget_detects_ndim_from_loaded_image(self, make_napari_viewer_proxy): # When an image is loaded before opening the widget, ndim is detected from it. viewer = make_napari_viewer_proxy() @@ -140,7 +149,7 @@ def test_widget_rebuilds_when_3d_image_loaded_after_open(self, make_napari_viewe widget._embedding_widget.image_selection.reset_choices() assert widget._ndim == 3 # The prompt layers must be recreated with the new dimensionality. - assert viewer.layers["point_prompts"].ndim == 3 + assert viewer.layers["points"].ndim == 3 viewer.close() def test_annotator_3d(self, make_napari_viewer_proxy): @@ -263,7 +272,7 @@ def test_channels_first_forced_2d(self, make_napari_viewer_proxy): assert widget._ndim == 2 assert viewer.layers["image"].rgb is True assert tuple(viewer.layers["image"].data.shape) == (64, 64, 3) - assert viewer.layers["point_prompts"].ndim == 2 + assert viewer.layers["points"].ndim == 2 viewer.close() def test_channels_last_two_channel_auto(self, make_napari_viewer_proxy): @@ -316,6 +325,111 @@ def test_force_3d_on_2d_image_warns_and_reverts_to_auto(self, make_napari_viewer viewer.close() +@pytest.mark.gui +@pytest.mark.skipif(platform.system() in ("Windows",), reason="Gui test is not working on windows.") +class TestEmbeddingROI: + + def test_compute_2d_roi_updates_state_and_layer_alignment(self, make_napari_viewer_proxy, monkeypatch): + from micro_sam.sam_annotator._state import AnnotatorState + + viewer = make_napari_viewer_proxy() + image = viewer.add_image( + np.zeros((100, 120), dtype="uint8"), name="image", scale=(2, 3), translate=(5, 7) + ) + widget = Annotator(viewer) + geometry = viewer.layers["geometry"] + geometry.scale = image.scale + geometry.translate = image.translate + geometry.add_rectangles(np.array([[10, 20], [10, 80], [60, 80], [60, 20]])) + geometry.selected_data = {0} + + embedding_widget = widget._embedding_widget + embedding_widget.embedding_region_dropdown.setCurrentText("selected ROI") + embedding_widget._update_model = lambda state: None + captured = {} + + def fake_initialize(state, image_data, **kwargs): + captured["shape"] = image_data.shape + state.image_embeddings = {"features": np.zeros((1, 1, 1, 1)), "input_size": image_data.shape} + state.data_signature = "roi" + + monkeypatch.setattr(AnnotatorState, "initialize_predictor", fake_initialize) + embedding_widget(skip_validate=True) + + state = AnnotatorState() + assert captured["shape"] == (50, 60) + assert state.image_shape == (50, 60) + assert state.image_translate == tuple(image.data_to_world((10, 20))) + roi_layer = embedding_widget.image_selection.value + assert roi_layer.name == "image ROI" + assert roi_layer.data.shape == (50, 60) + assert roi_layer.metadata["micro_sam_roi_source"] == "image" + assert roi_layer.metadata["micro_sam_roi"] == [[10, 60], [20, 80]] + assert state.image_name == roi_layer.name + assert viewer.layers.selection.active.name == roi_layer.name + assert "image" in viewer.layers # the source layer stays open + assert not viewer.layers["image"].visible # but is hidden so only the crop shows + assert viewer.layers.index(roi_layer) == 0 # crop sits at the bottom, under the annotation layers + + widget._update_image() + assert viewer.layers["current_object"].data.shape == (50, 60) + assert tuple(viewer.layers["current_object"].translate) == state.image_translate + assert viewer.layers["geometry"].data == [] + + # Select the source again and create another independent crop. + embedding_widget.image_selection.value = image + assert viewer.layers["image"].visible # re-selecting the source unhides it + geometry = viewer.layers["geometry"] + source_vertices = np.array([[20, 30], [20, 70], [50, 70], [50, 30]]) + geometry_vertices = np.array([ + geometry.world_to_data(image.data_to_world(vertex)) for vertex in source_vertices + ]) + geometry.add_rectangles(geometry_vertices) + geometry.selected_data = {0} + embedding_widget.embedding_region_dropdown.setCurrentText("selected ROI") + embedding_widget(skip_validate=True) + + second_roi = embedding_widget.image_selection.value + assert second_roi.name == "image ROI 2" + assert second_roi.data.shape == (30, 40) + assert all(name in viewer.layers for name in ("image", "image ROI", "image ROI 2")) + viewer.close() + + def test_resolve_2d_roi(self, make_napari_viewer_proxy): + viewer = make_napari_viewer_proxy() + image = viewer.add_image(np.zeros((100, 120), dtype="uint8"), name="image") + widget = Annotator(viewer) + geometry = viewer.layers["geometry"] + geometry.add_rectangles(np.array([[10, 20], [10, 80], [60, 80], [60, 20]])) + geometry.selected_data = {0} + + embedding_widget = widget._embedding_widget + embedding_widget.embedding_region_dropdown.setCurrentText("selected ROI") + roi = embedding_widget._resolve_roi(image) + + assert [(s.start, s.stop) for s in roi] == [(10, 60), (20, 80)] + assert embedding_widget._crop_image(image, roi).shape == (50, 60) + viewer.close() + + def test_resolve_3d_roi_keeps_all_slices(self, make_napari_viewer_proxy): + viewer = make_napari_viewer_proxy() + image = viewer.add_image(np.zeros((8, 100, 120), dtype="uint8"), name="image") + widget = Annotator(viewer) + geometry = viewer.layers["geometry"] + geometry.add_rectangles( + np.array([[3, 10, 20], [3, 10, 80], [3, 60, 80], [3, 60, 20]]) + ) + geometry.selected_data = {0} + + embedding_widget = widget._embedding_widget + embedding_widget.embedding_region_dropdown.setCurrentText("selected ROI") + roi = embedding_widget._resolve_roi(image) + + assert [(s.start, s.stop) for s in roi] == [(0, 8), (10, 60), (20, 80)] + assert embedding_widget._crop_image(image, roi).shape == (8, 50, 60) + viewer.close() + + @pytest.mark.gui @pytest.mark.skipif(platform.system() in ("Windows",), reason="Gui test is not working on windows.") class TestZTilingControls: