diff --git a/pyproject.toml b/pyproject.toml index 24f4a55..70f034c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,7 @@ pyqt = ">=5.15.9,<6" numpy = "<2.2" pre-commit = ">=4.1.0,<5" dask = ">=2025.2.0,<2026" +scikit-learn = ">=1.8.0,<2" [tool.pixi.feature.test.dependencies] pytest = "*" diff --git a/scripts/script_neuromast.py b/scripts/script_neuromast.py index c4ee6da..142bbe1 100644 --- a/scripts/script_neuromast.py +++ b/scripts/script_neuromast.py @@ -49,7 +49,6 @@ "/hpc/projects/group.royer/people/teun.huijben/data/Thibault/" "neuromast_model/instanseg_27527750.pt" ) -instanseg_device = None # 'cuda', 'cpu', or None for auto-detect # ************************* if __name__ == "__main__": @@ -64,5 +63,4 @@ annotation_mapping=annotation_mapping, flag_allow_adding_instanseg_cell=flag_allow_adding_instanseg_cell, instanseg_model_path=instanseg_model_path, - instanseg_device=instanseg_device, ) diff --git a/trackedit/TrackEditClass.py b/trackedit/TrackEditClass.py index 65428ce..35fdd93 100644 --- a/trackedit/TrackEditClass.py +++ b/trackedit/TrackEditClass.py @@ -1,12 +1,18 @@ import napari import numpy as np from motile_toolbox.candidate_graph import NodeAttr +from napari.utils import progress from napari.utils.colormaps import DirectLabelColormap from napari.utils.notifications import show_info, show_warning from qtpy.QtWidgets import QTabWidget +from scipy import ndimage +from skimage.feature import peak_local_max +from skimage.segmentation import watershed +from sklearn.cluster import KMeans from ultrack.core.database import NodeDB, get_node_values from ultrack.core.interactive import add_new_node +from motile_tracker.data_model.actions import ActionGroup, AddEdges from motile_tracker.data_model.solution_tracks import SolutionTracks from motile_tracker.data_views import TracksViewer, TreeWidget from trackedit.DatabaseHandler import DatabaseHandler @@ -132,6 +138,12 @@ def __init__( self.EditingMenu.duplicate_cell_button_pressed.connect( self.duplicate_cell_from_database ) + self.EditingMenu.split_cell_button_pressed.connect(self.split_cell) + + def _split_cell_shortcut(viewer=None): + self.split_cell(self.EditingMenu.split_method_combo.currentText()) + + self.viewer.bind_key("s", overwrite=True)(_split_cell_shortcut) # Connect spherical cell signal if feature is enabled if flag_allow_adding_spherical_cell: @@ -140,6 +152,8 @@ def __init__( ) # Initialize spherical cell mode flag self._add_cell_mode_active = False + self._last_added_cell = None # (node_id, time, track_id) for auto-linking + self._spherical_wants_link = False # only True after a double-click advance # Connect InstanSeg cell signal if feature is enabled if flag_allow_adding_instanseg_cell: @@ -148,6 +162,8 @@ def __init__( ) # Initialize InstanSeg cell mode flag self._add_instanseg_mode_active = False + self._last_added_cell = None # (node_id, time, track_id) for auto-linking + self._instanseg_wants_link = False # only True after a double-click advance self.add_tracks() self.NavigationWidget.time_box.update_chunk_label() @@ -446,15 +462,32 @@ def duplicate_cell_from_database(self, node_id: int, time: int): self.EditingMenu.duplicate_cell_id_input.setText("") self.EditingMenu.duplicate_time_input.setText("") - def add_spherical_cell_at_position(self, position_scaled, radius_pixels=10): + def _mask_without_overlap(self, mask: np.ndarray, bbox: np.ndarray) -> np.ndarray: + """Remove voxels already occupied by other cells from a new cell mask. + + Reads the current segmentation array and clears any voxels that are + non-zero (i.e. belong to an existing cell). Always active — new cells + never overlap existing segmentation. + """ + seg = self.databasehandler.segments.array + ndim = mask.ndim + if ndim == 3: + z0, y0, x0, z1, y1, x1 = bbox.astype(int) + occupied = seg[z0:z1, y0:y1, x0:x1] > 0 + else: + y0, x0, y1, x1 = bbox.astype(int) + occupied = seg[y0:y1, x0:x1] > 0 + return mask & ~occupied + + def add_spherical_cell_at_position(self, position_scaled, radius_physical=10): """Add a new cell with spherical segmentation at the given position. Parameters ---------- position_scaled : array-like Position in viewer coordinates (scaled) - radius_pixels : float - Radius of the sphere in pixels (default: 10) + radius_physical : float + Radius of the sphere in physical units (same units as scale, default: 10) Returns ------- @@ -467,7 +500,7 @@ def add_spherical_cell_at_position(self, position_scaled, radius_pixels=10): # Create mask and bbox mask, bbox = create_cell_mask_and_bbox( position_scaled=position_scaled, - radius_pixels=radius_pixels, + radius_physical=radius_physical, ndim=self.databasehandler.ndim, scale=( self.databasehandler.z_scale, @@ -481,6 +514,11 @@ def add_spherical_cell_at_position(self, position_scaled, radius_pixels=10): if mask is None: return None + mask = self._mask_without_overlap(mask, np.array(bbox)) + if not mask.any(): + show_warning("New cell fully overlaps existing segmentation — skipped.") + return None + # Add to database try: new_node_id = add_new_node( @@ -503,46 +541,73 @@ def add_spherical_cell_at_position(self, position_scaled, radius_pixels=10): return None # Add to tracking system - track_ids = ( - self.NavigationWidget.tracks_viewer.tracks_controller.tracks.track_id_to_node.keys() - ) + tc = self.tracksviewer.tracks_controller + track_ids = tc.tracks.track_id_to_node.keys() max_track_id = max(track_ids) if track_ids else 0 time_in_chunk = current_time - self.databasehandler.time_window[0] + # Auto-link to previous cell only if the user double-clicked to advance here + link_to_prev = ( + self._spherical_wants_link + and self._last_added_cell is not None + and current_time == self._last_added_cell[1] + 1 + ) + self._spherical_wants_link = False # consume the flag + use_track_id = ( + max_track_id + 1 + ) # always fresh; edge creation handles track continuity + attributes = { NodeAttr.TIME.value: [time_in_chunk], - NodeAttr.TRACK_ID.value: [max_track_id + 1], + NodeAttr.TRACK_ID.value: [use_track_id], "node_id": [new_node_id], } - self.tracksviewer.tracks_controller.add_nodes( - attributes, [(np.array([0, 0, 0]))] - ) + nodes_action, _ = tc._add_nodes(attributes, [(np.array([0, 0, 0]))]) + + if link_to_prev: + edge = np.array([[self._last_added_cell[0], new_node_id]]) + is_valid, valid_action = tc.is_valid( + [self._last_added_cell[0], new_node_id] + ) + if is_valid: + edge_action = tc._add_edges(edge) + extra = [valid_action, edge_action] if valid_action else [edge_action] + action = ActionGroup(tc.tracks, [nodes_action] + extra) + else: + action = nodes_action + else: + action = nodes_action + + tc.action_history.add_new_action(action) + tc.tracks.refresh.emit(new_node_id) + + self._last_added_cell = (new_node_id, current_time, use_track_id) - # Refresh and auto-disable + # Refresh segmentation layer show_info(f"Added spherical cell with ID {new_node_id}") self.databasehandler.segments.force_refill() self.viewer.layers[self.databasehandler.name + "_seg"].refresh() - # Auto-disable only if button exists - if hasattr(self.EditingMenu, "add_spherical_cell_btn"): - from qtpy.QtCore import QTimer - - QTimer.singleShot( - 0, - lambda: ( - self.EditingMenu.add_spherical_cell_btn.setChecked(False), - self._toggle_add_spherical_cell_mode(False), - ), - ) - return new_node_id def _toggle_add_spherical_cell_mode(self, checked): """Toggle the add spherical cell mode on/off.""" self._add_cell_mode_active = checked + self._last_added_cell = None # reset auto-link chain on every toggle + self._spherical_wants_link = False if checked: - # Only add callback if not already present + # Mutual exclusion: disable instanseg mode if active + if ( + hasattr(self, "_add_instanseg_mode_active") + and self._add_instanseg_mode_active + ): + self._toggle_add_instanseg_mode(False) + if hasattr(self.EditingMenu, "add_instanseg_cell_btn"): + self.EditingMenu.add_instanseg_cell_btn.setChecked(False) + + self._suppress_double_click_zoom() + if ( self._on_mouse_click_add_spherical_cell not in self.viewer.mouse_drag_callbacks @@ -550,8 +615,14 @@ def _toggle_add_spherical_cell_mode(self, checked): self.viewer.mouse_drag_callbacks.append( self._on_mouse_click_add_spherical_cell ) + if ( + self._on_double_click_add_spherical_cell + not in self.viewer.mouse_double_click_callbacks + ): + self.viewer.mouse_double_click_callbacks.append( + self._on_double_click_add_spherical_cell + ) else: - # Remove ALL instances of the callback (in case of duplicates) while ( self._on_mouse_click_add_spherical_cell in self.viewer.mouse_drag_callbacks @@ -562,47 +633,111 @@ def _toggle_add_spherical_cell_mode(self, checked): ) except ValueError: break + while ( + self._on_double_click_add_spherical_cell + in self.viewer.mouse_double_click_callbacks + ): + try: + self.viewer.mouse_double_click_callbacks.remove( + self._on_double_click_add_spherical_cell + ) + except ValueError: + break - def _on_mouse_click_add_spherical_cell(self, viewer, event): - """Handle mouse click when add cell mode is active. + self._restore_double_click_zoom() - Called when user clicks in viewer while add cell mode is on. - """ + def _on_mouse_click_add_spherical_cell(self, viewer, event): + """Handle mouse click when add cell mode is active.""" # Guard: only proceed if mode is actually active if not self._add_cell_mode_active: return # Only trigger on click (not drag) if event.type == "mouse_press": - # Get click position in data coordinates - # Position includes time dimension: (t, z, y, x) or (t, y, x) - position = viewer.cursor.position - - # Extract spatial coordinates (remove time) - if self.databasehandler.ndim == 4: - # 4D data: (t, z, y, x) -> (z, y, x) - position_spatial = position[1:] + current_time = int(viewer.cursor.position[0]) + radius = self.EditingMenu.adding_spherical_cell_radius + + if self.databasehandler.imaging_flag: + # Use ray casting (channel 0) to find the brightest voxel along + # the camera ray — gives accurate 3D placement in oblique views + ch0 = self.databasehandler.imagingArray.get_channel_data(0)[ + current_time + ] + if hasattr(ch0, "compute"): + ch0 = ch0.compute() + nuclear_volume = np.asarray(ch0) + + position_data = self._find_ray_cast_position(viewer, nuclear_volume) + if position_data is None: + return + + # Convert data coords -> world coords expected by add_spherical_cell_at_position + if self.databasehandler.ndim == 4: + scale_array = np.array( + [ + self.databasehandler.z_scale, + self.databasehandler.y_scale, + self.databasehandler.x_scale, + ] + ) + else: + scale_array = np.array( + [ + self.databasehandler.y_scale, + self.databasehandler.x_scale, + ] + ) + position_spatial = tuple(np.array(position_data) * scale_array) else: - # 3D data: (t, y, x) -> (y, x) - position_spatial = position[1:] + # No imaging data: fall back to raw cursor position + position_spatial = tuple(viewer.cursor.position[1:]) - # Add spherical cell at clicked position - self.add_spherical_cell_at_position(position_spatial) + self.add_spherical_cell_at_position( + position_spatial, radius_physical=radius + ) - # Yield to prevent further event propagation yield + def _on_double_click_add_spherical_cell(self, _viewer, _event): + """Double-click: advance to the next frame and arm auto-linking for the next cell.""" + if not self._add_cell_mode_active: + return + self._spherical_wants_link = True + self._advance_one_frame() + + def _suppress_double_click_zoom(self): + """Remove napari's default double-click-to-zoom while an add-cell mode is active.""" + from napari.components._viewer_mouse_bindings import double_click_to_zoom + + if double_click_to_zoom in self.viewer.mouse_double_click_callbacks: + self.viewer.mouse_double_click_callbacks.remove(double_click_to_zoom) + + def _restore_double_click_zoom(self): + """Restore napari's default double-click-to-zoom if no add-cell mode is active.""" + from napari.components._viewer_mouse_bindings import double_click_to_zoom + + spherical_active = getattr(self, "_add_cell_mode_active", False) + instanseg_active = getattr(self, "_add_instanseg_mode_active", False) + if not spherical_active and not instanseg_active: + if double_click_to_zoom not in self.viewer.mouse_double_click_callbacks: + self.viewer.mouse_double_click_callbacks.append(double_click_to_zoom) + def _toggle_add_instanseg_mode(self, checked): """Toggle the add InstanSeg cell mode on/off.""" self._add_instanseg_mode_active = checked + self._last_added_cell = None # reset auto-link chain on every toggle + self._instanseg_wants_link = False if checked: # Mutual exclusion: disable spherical mode if active if hasattr(self, "_add_cell_mode_active") and self._add_cell_mode_active: + self._toggle_add_spherical_cell_mode(False) if hasattr(self.EditingMenu, "add_spherical_cell_btn"): self.EditingMenu.add_spherical_cell_btn.setChecked(False) - # Only add callback if not already present + self._suppress_double_click_zoom() + + # Only add callbacks if not already present if ( self._on_mouse_click_add_instanseg not in self.viewer.mouse_drag_callbacks @@ -610,8 +745,15 @@ def _toggle_add_instanseg_mode(self, checked): self.viewer.mouse_drag_callbacks.append( self._on_mouse_click_add_instanseg ) + if ( + self._on_double_click_add_instanseg + not in self.viewer.mouse_double_click_callbacks + ): + self.viewer.mouse_double_click_callbacks.append( + self._on_double_click_add_instanseg + ) else: - # Remove ALL instances of the callback (in case of duplicates) + # Remove ALL instances of the callbacks (in case of duplicates) while ( self._on_mouse_click_add_instanseg in self.viewer.mouse_drag_callbacks ): @@ -621,6 +763,18 @@ def _toggle_add_instanseg_mode(self, checked): ) except ValueError: break + while ( + self._on_double_click_add_instanseg + in self.viewer.mouse_double_click_callbacks + ): + try: + self.viewer.mouse_double_click_callbacks.remove( + self._on_double_click_add_instanseg + ) + except ValueError: + break + + self._restore_double_click_zoom() def _on_mouse_click_add_instanseg(self, viewer, event): """Handle mouse click when InstanSeg add cell mode is active. @@ -643,6 +797,112 @@ def _on_mouse_click_add_instanseg(self, viewer, event): # Yield to prevent further event propagation yield + def _find_ray_cast_position(self, viewer, nuclear_volume: np.ndarray): + """Find the data-space position of the brightest voxel along the camera ray. + + In 2D canvas mode (no view direction), returns the cursor position directly. + In 3D canvas mode, casts a ray through `nuclear_volume` and returns the + brightest voxel as the seed point. + + Parameters + ---------- + viewer : napari.Viewer + nuclear_volume : np.ndarray + Single-channel spatial volume at the current time frame, shape + (Z, Y, X) for 3D data or (Y, X) for 2D data. Used only in 3D + canvas mode to find the brightest voxel along the ray. + + Returns + ------- + tuple of float (data-space, unscaled) or None on failure. + """ + if self.databasehandler.ndim == 4: + scale_array = np.array( + [ + self.databasehandler.z_scale, + self.databasehandler.y_scale, + self.databasehandler.x_scale, + ] + ) + else: + scale_array = np.array( + [ + self.databasehandler.y_scale, + self.databasehandler.x_scale, + ] + ) + + cursor_pos_world = np.asarray(viewer.cursor.position, dtype=float) + cursor_spatial_world = cursor_pos_world[1:] # drop time dimension + origin_data = cursor_spatial_world / scale_array + + vd_world = getattr(viewer.cursor, "_view_direction", None) + if vd_world is None or np.allclose(vd_world, 0): + # 2D canvas mode: use cursor position directly + return tuple(np.round(origin_data).astype(float)) + + # 3D canvas mode: cast ray through nuclear_volume + vd_world = np.asarray(vd_world, dtype=float) + vd_spatial_world = vd_world[1:] if len(vd_world) == 4 else vd_world + vd_data = vd_spatial_world / scale_array + norm = np.linalg.norm(vd_data) + if norm == 0: + show_warning("Invalid view direction. Try clicking again.") + return None + vd_data /= norm + + spatial_shape = nuclear_volume.shape + diag = int(np.sqrt(sum(s**2 for s in spatial_shape))) + t_values = np.arange(-diag, diag + 1, dtype=float) + ray_points = origin_data[None, :] + t_values[:, None] * vd_data[None, :] + ray_voxels = np.round(ray_points).astype(int) + + if self.databasehandler.ndim == 4: + valid = ( + (ray_voxels[:, 0] >= 0) + & (ray_voxels[:, 0] < spatial_shape[0]) + & (ray_voxels[:, 1] >= 0) + & (ray_voxels[:, 1] < spatial_shape[1]) + & (ray_voxels[:, 2] >= 0) + & (ray_voxels[:, 2] < spatial_shape[2]) + ) + else: + valid = ( + (ray_voxels[:, 0] >= 0) + & (ray_voxels[:, 0] < spatial_shape[0]) + & (ray_voxels[:, 1] >= 0) + & (ray_voxels[:, 1] < spatial_shape[1]) + ) + + ray_voxels = ray_voxels[valid] + if len(ray_voxels) == 0: + show_warning("Ray does not intersect volume. Try clicking inside the data.") + return None + + ray_voxels = np.unique(ray_voxels, axis=0) + if self.databasehandler.ndim == 4: + intensities = nuclear_volume[ + ray_voxels[:, 0], ray_voxels[:, 1], ray_voxels[:, 2] + ] + else: + intensities = nuclear_volume[ray_voxels[:, 0], ray_voxels[:, 1]] + + best_voxel = ray_voxels[int(np.argmax(intensities))] + return tuple(best_voxel.astype(float)) + + def _on_double_click_add_instanseg(self, _viewer, _event): + """Double-click: advance to the next frame and arm auto-linking for the next cell.""" + if not self._add_instanseg_mode_active: + return + self._instanseg_wants_link = True + self._advance_one_frame() + + def _advance_one_frame(self): + current = self.viewer.dims.current_step[0] + max_frame = int(self.viewer.dims.range[0].stop) + if current < max_frame: + self.viewer.dims.set_current_step(0, current + 1) + def add_instanseg_cell_at_position(self, viewer, current_time): """Add a new cell using InstanSeg segmentation at the clicked position. @@ -686,95 +946,10 @@ def add_instanseg_cell_at_position(self, viewer, current_time): volumes.append(np.asarray(ch_data)) image_volume = np.stack(volumes, axis=0) # (C, Z, Y, X) or (C, Y, X) - # Get view direction for ray casting (unavailable in napari's 2D canvas mode) - vd_world = getattr(viewer.cursor, "_view_direction", None) - - # Get cursor position in world coordinates - cursor_pos_world = np.asarray(viewer.cursor.position, dtype=float) - - # Extract spatial position (remove time dimension) - if self.databasehandler.ndim == 4: - cursor_spatial_world = cursor_pos_world[1:] # (z, y, x) - scale_array = np.array( - [ - self.databasehandler.z_scale, - self.databasehandler.y_scale, - self.databasehandler.x_scale, - ] - ) - else: - cursor_spatial_world = cursor_pos_world[1:] # (y, x) - scale_array = np.array( - [self.databasehandler.y_scale, self.databasehandler.x_scale] - ) - - origin_data = cursor_spatial_world / scale_array - - if vd_world is None or np.allclose(vd_world, 0): - # 2D canvas mode: no view direction available, use cursor position directly - position_data = tuple(np.round(origin_data).astype(float)) - else: - # 3D canvas mode: cast ray and find brightest voxel along it - # Convert view direction to data coordinates - vd_world = np.asarray(vd_world, dtype=float) - # View direction is also 4D (time + spatial), extract only spatial components - vd_spatial_world = vd_world[1:] if len(vd_world) == 4 else vd_world - vd_data = vd_spatial_world / scale_array - norm = np.linalg.norm(vd_data) - if norm == 0: - show_warning("Invalid view direction. Try clicking again.") - return None - vd_data /= norm - - # image_volume is (C, Z, Y, X) or (C, Y, X); spatial shape excludes channel dim - spatial_shape = image_volume.shape[1:] - nuclear_volume = image_volume[ - 0 - ] # use nuclear channel (idx 0) for ray casting - - # Cast ray through volume in both directions - diag = int(np.sqrt(sum(s**2 for s in spatial_shape))) - t_values = np.arange(-diag, diag + 1, dtype=float) - ray_points = origin_data[None, :] + t_values[:, None] * vd_data[None, :] - ray_voxels = np.round(ray_points).astype(int) - - # Keep only voxels inside the volume - if self.databasehandler.ndim == 4: - valid = ( - (ray_voxels[:, 0] >= 0) - & (ray_voxels[:, 0] < spatial_shape[0]) - & (ray_voxels[:, 1] >= 0) - & (ray_voxels[:, 1] < spatial_shape[1]) - & (ray_voxels[:, 2] >= 0) - & (ray_voxels[:, 2] < spatial_shape[2]) - ) - else: - valid = ( - (ray_voxels[:, 0] >= 0) - & (ray_voxels[:, 0] < spatial_shape[0]) - & (ray_voxels[:, 1] >= 0) - & (ray_voxels[:, 1] < spatial_shape[1]) - ) - - ray_voxels = ray_voxels[valid] - if len(ray_voxels) == 0: - show_warning( - "Ray does not intersect volume. Try clicking inside the data." - ) - return None - - # Deduplicate and find voxel with maximum intensity in nuclear channel - ray_voxels = np.unique(ray_voxels, axis=0) - if self.databasehandler.ndim == 4: - intensities = nuclear_volume[ - ray_voxels[:, 0], ray_voxels[:, 1], ray_voxels[:, 2] - ] - else: - intensities = nuclear_volume[ray_voxels[:, 0], ray_voxels[:, 1]] - - best_idx = int(np.argmax(intensities)) - best_voxel = ray_voxels[best_idx] - position_data = tuple(best_voxel.astype(float)) + # Find seed position via ray casting (channel 0 = nuclear channel) + position_data = self._find_ray_cast_position(viewer, image_volume[0]) + if position_data is None: + return None # Prepare scale tuple and handle 2D vs 3D if self.databasehandler.ndim == 4: @@ -832,6 +1007,14 @@ def add_instanseg_cell_at_position(self, viewer, current_time): mask = mask.squeeze(0) # (Y, X) bbox = (bbox[0][1:], bbox[1][1:]) # Remove z from bbox + bbox_arr = np.concatenate(bbox) if isinstance(bbox, tuple) else np.asarray(bbox) + mask = self._mask_without_overlap(mask, bbox_arr) + if not mask.any(): + show_warning( + "InstanSeg cell fully overlaps existing segmentation — skipped." + ) + return None + # Add to database try: new_node_id = add_new_node( @@ -854,38 +1037,53 @@ def add_instanseg_cell_at_position(self, viewer, current_time): return None # Add to tracking system - track_ids = ( - self.NavigationWidget.tracks_viewer.tracks_controller.tracks.track_id_to_node.keys() - ) + tc = self.tracksviewer.tracks_controller + track_ids = tc.tracks.track_id_to_node.keys() max_track_id = max(track_ids) if track_ids else 0 time_in_chunk = current_time - self.databasehandler.time_window[0] + # Auto-link to previous cell only if the user double-clicked to advance here + link_to_prev = ( + self._instanseg_wants_link + and self._last_added_cell is not None + and current_time == self._last_added_cell[1] + 1 + ) + self._instanseg_wants_link = False # consume the flag + use_track_id = ( + max_track_id + 1 + ) # always fresh; edge creation handles track continuity + attributes = { NodeAttr.TIME.value: [time_in_chunk], - NodeAttr.TRACK_ID.value: [max_track_id + 1], + NodeAttr.TRACK_ID.value: [use_track_id], "node_id": [new_node_id], } - self.tracksviewer.tracks_controller.add_nodes( - attributes, [(np.array([0, 0, 0]))] - ) + nodes_action, _ = tc._add_nodes(attributes, [(np.array([0, 0, 0]))]) - # Refresh and auto-disable + if link_to_prev: + edge = np.array([[self._last_added_cell[0], new_node_id]]) + is_valid, valid_action = tc.is_valid( + [self._last_added_cell[0], new_node_id] + ) + if is_valid: + edge_action = tc._add_edges(edge) + extra = [valid_action, edge_action] if valid_action else [edge_action] + action = ActionGroup(tc.tracks, [nodes_action] + extra) + else: + action = nodes_action + else: + action = nodes_action + + tc.action_history.add_new_action(action) + tc.tracks.refresh.emit(new_node_id) + + self._last_added_cell = (new_node_id, current_time, use_track_id) + + # Refresh segmentation layer show_info(f"Added InstanSeg cell with ID {new_node_id}") self.databasehandler.segments.force_refill() self.viewer.layers[self.databasehandler.name + "_seg"].refresh() - # Auto-disable only if button exists - if hasattr(self.EditingMenu, "add_instanseg_cell_btn"): - from qtpy.QtCore import QTimer - - QTimer.singleShot( - 0, - lambda: ( - self.EditingMenu.add_instanseg_cell_btn.setChecked(False), - self._toggle_add_instanseg_mode(False), - ), - ) - return new_node_id def _convert_viewer_to_data_coords(self, position_spatial: tuple) -> tuple: @@ -912,3 +1110,195 @@ def _convert_viewer_to_data_coords(self, position_spatial: tuple) -> tuple: y_unscaled = position_spatial[0] / self.databasehandler.y_scale x_unscaled = position_spatial[1] / self.databasehandler.x_scale return (y_unscaled, x_unscaled) + + # =============================================== + # Split cell + # =============================================== + + def split_cell(self, method="Watershed (image)"): + selected = list(self.tracksviewer.selected_nodes) + if not selected: + return + tc = self.tracksviewer.tracks_controller + all_actions = [] + with progress(selected, desc="Splitting cells") as prog: + for node_id in prog: + if method == "K-means": + actions = self._split_single_cell(node_id) + else: + seed_source = ( + "image" if method == "Watershed (image)" else "distance" + ) + actions = self._split_single_cell_watershed(node_id, seed_source) + if actions: + all_actions.extend(actions) + if all_actions: + tc.action_history.add_new_action(ActionGroup(tc.tracks, all_actions)) + tc.tracks.refresh.emit() + + def _split_single_cell(self, node_id: int): + """Split a single node into two using K-means clustering on isotropic coordinates.""" + pickle_data = get_node_values( + self.databasehandler.config_adjusted.data_config, + [int(node_id)], + NodeDB.pickle, + ) + time = get_node_values( + self.databasehandler.config_adjusted.data_config, + int(node_id), + NodeDB.t, + ) + mask = pickle_data.mask + bbox = np.array(pickle_data.bbox) + + if not mask.any(): + return None + + coords = np.argwhere(mask) + if len(coords) < 2: + show_info(f"K-means: not enough voxels in node {node_id}.") + return None + + scale = np.array(self.databasehandler.scale) + coords_iso = coords * scale + kmeans = KMeans(n_clusters=2, random_state=42, n_init=3) + labels = kmeans.fit_predict(coords_iso) + + cluster_masks = [] + for label_val in [0, 1]: + cm = np.zeros_like(mask, dtype=bool) + idx = coords[labels == label_val] + if len(idx) == 0: + show_info(f"K-means: empty partition for node {node_id}.") + return None + cm[tuple(idx.T)] = True + cluster_masks.append(cm) + + return self._finish_split(node_id, time, bbox, cluster_masks) + + def _split_single_cell_watershed(self, node_id: int, seed_source: str): + """Split a single node into two using watershed segmentation. + + Parameters + ---------- + node_id : int + The node to split. + seed_source : str + "distance" — seed from distance transform peaks. + "image" — seed from image intensity peaks. + """ + pickle_data = get_node_values( + self.databasehandler.config_adjusted.data_config, + [int(node_id)], + NodeDB.pickle, + ) + time = get_node_values( + self.databasehandler.config_adjusted.data_config, + int(node_id), + NodeDB.t, + ) + mask = pickle_data.mask + bbox = np.array(pickle_data.bbox) + + if not mask.any(): + return None + + scale = np.array(self.databasehandler.scale) + + if seed_source == "image": + img = self.databasehandler.imagingArray.get_channel_data(0)[ + time - self.databasehandler.time_window[0] + ] + ndim = mask.ndim + if ndim == 3: + z0, y0, x0, z1, y1, x1 = bbox + field = img[z0:z1, y0:y1, x0:x1].astype(float) + else: + y0, x0, y1, x1 = bbox + field = img[y0:y1, x0:x1].astype(float) + sigma = 1.0 / scale + field = ndimage.gaussian_filter(field, sigma=sigma) + field[~mask] = 0 + else: # distance + field = ndimage.distance_transform_edt(mask, sampling=scale) + + min_distance = max(3, int(min(mask.shape) // 3)) + peaks = np.empty((0, mask.ndim), dtype=int) + while min_distance >= 1: + peaks = peak_local_max( + field, min_distance=min_distance, labels=mask.astype(int) + ) + if len(peaks) >= 2: + break + min_distance -= 1 + + if len(peaks) < 2: + show_info(f"Watershed: could not find 2 peaks in node {node_id}.") + return None + + peak_values = field[tuple(peaks.T)] + top2_idx = np.argsort(peak_values)[-2:] + peaks = peaks[top2_idx] + + markers = np.zeros_like(mask, dtype=np.int32) + markers[tuple(peaks[0])] = 1 + markers[tuple(peaks[1])] = 2 + ws = watershed(-field, markers, mask=mask) + + cluster_masks = [] + for label_val in [1, 2]: + cm = (ws == label_val) & mask + if not cm.any(): + show_info(f"Watershed: empty partition for node {node_id}.") + return None + cluster_masks.append(cm) + + return self._finish_split(node_id, time, bbox, cluster_masks) + + def _finish_split( + self, node_id: int, time: int, bbox: np.ndarray, cluster_masks: list + ): + """Delete the original node and add two new nodes from the given masks. + + Returns the flat list of actions so the caller can group them with + actions from other splits into a single history entry. + """ + tc = self.tracksviewer.tracks_controller + + # Get the delete action without adding to history yet + delete_action = tc._delete_nodes([node_id]) + + # Also delete any skip edges that motile auto-added after the deletion + # (mirrors the logic in my_delete_nodes in motile_overwrites.py) + all_actions = list(delete_action.actions) + for action in delete_action.actions: + if isinstance(action, AddEdges): + skip_delete = tc._delete_edges(np.array(action.edges)) + all_actions += list(skip_delete.actions) + + track_ids = tc.tracks.track_id_to_node.keys() + max_track_id = max(track_ids) if track_ids else 0 + + time_in_chunk = time - self.databasehandler.time_window[0] + for i, cm in enumerate(cluster_masks): + new_id = add_new_node( + self.databasehandler.config_adjusted, + time=time, + mask=cm, + bbox=bbox, + include_overlaps=True, + ) + fix_overlap_ancestor_ids( + database_path=self.databasehandler.config_adjusted.data_config.database_path, + new_node_id=new_id, + current_time=time, + ) + attributes = { + NodeAttr.TIME.value: [time_in_chunk], + NodeAttr.TRACK_ID.value: [max_track_id + 1 + i], + "node_id": [new_id], + } + action, _ = tc._add_nodes(attributes, [np.array([0, 0, 0])]) + all_actions.append(action) + + return all_actions diff --git a/trackedit/_tests/test_UI_actions.py b/trackedit/_tests/test_UI_actions.py index 6acf983..9ad57a0 100644 --- a/trackedit/_tests/test_UI_actions.py +++ b/trackedit/_tests/test_UI_actions.py @@ -56,7 +56,7 @@ def test_trackedit_widgets( imaging_zarr_file="", imaging_channel="", viewer=viewer, - flag_allow_adding_spherical_cell=True, # Enable spherical cell feature for testing + flag_allow_adding_spherical_cell=True, ) # Get the NavigationWidget directly from TrackEdit instance @@ -85,6 +85,7 @@ def test_trackedit_widgets( check_time_box(time_box) check_editing(TV, editing_menu) check_add_spherical_cell(track_edit, editing_menu) + check_split_cell(track_edit, TV) check_red_flag_box(TV, red_flag_box, time_box) check_division_box(division_box) check_annotation(toAnnotateBox) @@ -209,7 +210,7 @@ def check_add_spherical_cell(track_edit, editing_menu): # Call the method directly to add a cell new_node_id = track_edit.add_spherical_cell_at_position( - position_scaled=position, radius_pixels=10 + position_scaled=position, radius_physical=10 ) # Verify a node was created @@ -244,6 +245,82 @@ def check_add_spherical_cell(track_edit, editing_menu): track_edit.tracksviewer.undo() +def check_split_cell(track_edit, TV): + """Check cell splitting functionality and single-step undo.""" + tc = TV.tracks_controller + + def graph_nodes(): + return set(tc.tracks.graph.nodes()) + + # --- K-means split --- + node_id = 2000001 + TV.selected_nodes.add(node_id, append=False) + nodes_before = graph_nodes() + + track_edit.split_cell("K-means") + + nodes_after = graph_nodes() + assert node_id not in nodes_after, "Original node should be removed after split" + new_nodes = nodes_after - nodes_before + assert len(new_nodes) == 2, f"Expected 2 new nodes, got {len(new_nodes)}" + assert ( + len(nodes_after) == len(nodes_before) + 1 + ), "Net node count should increase by 1" + + # Single undo should restore the original node and remove both new ones + TV.undo() + nodes_after_undo = graph_nodes() + assert node_id in nodes_after_undo, "Original node should be restored after undo" + assert not (new_nodes & nodes_after_undo), "New nodes should be gone after undo" + assert ( + nodes_after_undo == nodes_before + ), "Graph should match pre-split state after undo" + + # --- Watershed (distance) split --- + node_id2 = 2000002 + TV.selected_nodes.add(node_id2, append=False) + nodes_before2 = graph_nodes() + + track_edit.split_cell("Watershed (distance)") + + nodes_after2 = graph_nodes() + assert ( + node_id2 not in nodes_after2 + ), "Original node should be removed after watershed split" + new_nodes2 = nodes_after2 - nodes_before2 + assert len(new_nodes2) == 2, f"Expected 2 new nodes, got {len(new_nodes2)}" + + TV.undo() + nodes_after_undo2 = graph_nodes() + assert node_id2 in nodes_after_undo2, "Original node should be restored after undo" + assert ( + nodes_after_undo2 == nodes_before2 + ), "Graph should match pre-split state after undo" + + # --- Multi-cell split (two cells at once) --- + node_a, node_b = 2000001, 2000002 + TV.selected_nodes.add(node_a, append=False) + TV.selected_nodes.add(node_b, append=True) + nodes_before_multi = graph_nodes() + + track_edit.split_cell("K-means") + + nodes_after_multi = graph_nodes() + assert node_a not in nodes_after_multi, "Node A should be removed" + assert node_b not in nodes_after_multi, "Node B should be removed" + new_nodes_multi = nodes_after_multi - nodes_before_multi + assert ( + len(new_nodes_multi) == 4 + ), f"Expected 4 new nodes from splitting 2 cells, got {len(new_nodes_multi)}" + + # All splits are grouped into a single undo step + TV.undo() + nodes_final = graph_nodes() + assert node_a in nodes_final, "Node A should be restored" + assert node_b in nodes_final, "Node B should be restored" + assert nodes_final == nodes_before_multi, "Graph should fully match pre-split state" + + def check_red_flag_box(TV, red_flag_box, time_box): """Check red flag box functionality""" diff --git a/trackedit/cli.py b/trackedit/cli.py index e8131b5..fe4470d 100644 --- a/trackedit/cli.py +++ b/trackedit/cli.py @@ -4,6 +4,7 @@ import click +from trackedit.utils.crop import crop_database_in_time from trackedit.utils.geff import convert_geff_to_db @@ -36,5 +37,27 @@ def geff_to_db(geff_path: Path, output: Path = None): convert_geff_to_db(geff_path, output) +@cli.command("crop-db") +@click.argument("source_db", type=click.Path(exists=True, path_type=Path)) +@click.option( + "--max-t", + required=True, + type=int, + help="Maximum time frame to include (inclusive).", +) +@click.option( + "--output", + "-o", + type=click.Path(path_type=Path), + help="Output database path (default: _t0-.db)", +) +def crop_db(source_db: Path, max_t: int, output: Path = None): + """Crop an Ultrack SQLite database to the first MAX_T frames.""" + if output is None: + output = source_db.parent / f"{source_db.stem}_t0-{max_t}.db" + crop_database_in_time(source_db, output, max_t) + print(f"Cropped database written to {output}") + + if __name__ == "__main__": cli() diff --git a/trackedit/motile_overwrites.py b/trackedit/motile_overwrites.py index c24791f..ea08472 100644 --- a/trackedit/motile_overwrites.py +++ b/trackedit/motile_overwrites.py @@ -13,6 +13,7 @@ from qtpy.QtGui import QColor from ultrack.core.database import NodeDB, get_node_values +import motile_tracker.data_views.views.tree_view.tree_widget_utils as _tree_widget_utils from motile_tracker.data_model.actions import ActionGroup, AddEdges, DeleteEdges from motile_tracker.data_model.solution_tracks import SolutionTracks from motile_tracker.data_model.tracks_controller import TracksController @@ -26,6 +27,19 @@ Edge: TypeAlias = tuple[Node, Node] +# Bug fix: original uses parent_map.get(current_track) which returns None for +# missing keys, and None != 0 is always True — causing an infinite loop when +# all root cells of a division are deleted. +def _patched_find_root(track_id: int, parent_map: dict) -> int: + current_track = track_id + while parent_map.get(current_track, 0) != 0: + current_track = parent_map.get(current_track) + return current_track + + +_tree_widget_utils.find_root = _patched_find_root + + def create_db_add_nodes(DB_handler): def db_add_nodes(self): # don't use full old function, because it includes painting pixels in segmentation diff --git a/trackedit/run.py b/trackedit/run.py index bee4383..2652a97 100644 --- a/trackedit/run.py +++ b/trackedit/run.py @@ -41,8 +41,8 @@ def run_trackedit( image_translate: Optional[Tuple[float, ...]] = None, viewer: Optional[napari.Viewer] = None, flag_show_hierarchy: bool = True, - flag_allow_adding_spherical_cell: bool = False, - adding_spherical_cell_radius: int = 10, + flag_allow_adding_spherical_cell: bool = True, + adding_spherical_cell_radius: int = 5, flag_allow_adding_instanseg_cell: bool = False, instanseg_model_path: Optional[str] = None, instanseg_device: Optional[str] = None, @@ -71,7 +71,7 @@ def run_trackedit( viewer: Optional existing napari viewer flag_show_hierarchy: Show hierarchy in the viewer flag_allow_adding_spherical_cell: Allow adding spherical cells via button (default: False) - adding_spherical_cell_radius: Radius of spherical cells in pixels (default: 10) + adding_spherical_cell_radius: Radius of spherical cells in physical units (default: 10) flag_allow_adding_instanseg_cell: Allow adding InstanSeg-segmented cells via button (default: False) instanseg_model_path: Path to InstanSeg TorchScript model file (required if flag_allow_adding_instanseg_cell=True) diff --git a/trackedit/utils/crop.py b/trackedit/utils/crop.py new file mode 100644 index 0000000..87b2d80 --- /dev/null +++ b/trackedit/utils/crop.py @@ -0,0 +1,71 @@ +import shutil +import sqlite3 +from pathlib import Path + + +def crop_database_in_time(source_db: Path, output_db: Path, max_t: int) -> None: + """Create a time-cropped copy of an ultrack SQLite database. + + Copies nodes, links, and overlaps tables from the source database, + keeping only entries up to and including time frame `max_t`. + The gt_nodes and gt_links tables are copied as-is (assumed empty). + + Args: + source_db: Path to the source .db file. + output_db: Path where the cropped copy will be written. + max_t: Maximum time frame to include (inclusive). Nodes with t <= max_t + are kept; links and overlaps are filtered to only reference + surviving node IDs. + """ + shutil.copy2(source_db, output_db) + + conn = sqlite3.connect(output_db) + try: + # Remove nodes outside time range + conn.execute("DELETE FROM nodes WHERE t > ?", (max_t,)) + + # Collect surviving node IDs + surviving_ids = { + row[0] for row in conn.execute("SELECT id FROM nodes").fetchall() + } + + # Remove links where either endpoint is gone + all_links = conn.execute( + "SELECT id, source_id, target_id FROM links" + ).fetchall() + link_ids_to_delete = [ + row[0] + for row in all_links + if row[1] not in surviving_ids or row[2] not in surviving_ids + ] + if link_ids_to_delete: + conn.execute( + f"DELETE FROM links WHERE id IN ({','.join('?' * len(link_ids_to_delete))})", + link_ids_to_delete, + ) + + # Remove overlaps where either node_id or ancestor_id is gone + all_overlaps = conn.execute( + "SELECT rowid, node_id, ancestor_id FROM overlaps" + ).fetchall() + overlap_rowids_to_delete = [ + row[0] + for row in all_overlaps + if row[1] not in surviving_ids or row[2] not in surviving_ids + ] + if overlap_rowids_to_delete: + conn.execute( + f"DELETE FROM overlaps WHERE rowid IN ({','.join('?' * len(overlap_rowids_to_delete))})", + overlap_rowids_to_delete, + ) + + conn.commit() + conn.execute("VACUUM") + finally: + conn.close() + + +if __name__ == "__main__": + from trackedit.cli import cli + + cli() diff --git a/trackedit/utils/utils.py b/trackedit/utils/utils.py index 1fa6a49..6c75745 100644 --- a/trackedit/utils/utils.py +++ b/trackedit/utils/utils.py @@ -396,7 +396,7 @@ def calculate_bbox_from_mask(mask): def create_cell_mask_and_bbox( position_scaled: np.ndarray, - radius_pixels: float, + radius_physical: float, ndim: int, scale: tuple, data_shape_full: tuple, @@ -407,8 +407,8 @@ def create_cell_mask_and_bbox( ---------- position_scaled : array-like Position in viewer coordinates (scaled) - radius_pixels : float - Radius of the sphere in pixels + radius_physical : float + Radius of the sphere in physical units (same units as scale) ndim : int Number of dimensions (3 for 2D+t, 4 for 3D+t) scale : tuple @@ -435,9 +435,9 @@ def create_cell_mask_and_bbox( ] ) radii = ( - radius_pixels / z_scale, - radius_pixels / y_scale, - radius_pixels / x_scale, + radius_physical / z_scale, + radius_physical / y_scale, + radius_physical / x_scale, ) else: y_scale, x_scale = scale @@ -449,8 +449,8 @@ def create_cell_mask_and_bbox( ] ) radii = ( - radius_pixels / y_scale, - radius_pixels / x_scale, + radius_physical / y_scale, + radius_physical / x_scale, ) # Calculate bounding box diff --git a/trackedit/widgets/CustomEditingWidget.py b/trackedit/widgets/CustomEditingWidget.py index 04d7018..3902495 100644 --- a/trackedit/widgets/CustomEditingWidget.py +++ b/trackedit/widgets/CustomEditingWidget.py @@ -1,7 +1,14 @@ import napari from PyQt5.QtGui import QIntValidator, QValidator from qtpy.QtCore import Signal -from qtpy.QtWidgets import QHBoxLayout, QLabel, QLineEdit, QPushButton +from qtpy.QtWidgets import ( + QComboBox, + QHBoxLayout, + QLabel, + QLineEdit, + QPushButton, + QSpinBox, +) from motile_tracker.application_menus.editing_menu import EditingMenu from trackedit.DatabaseHandler import DatabaseHandler @@ -11,6 +18,7 @@ class CustomEditingMenu(EditingMenu): add_cell_button_pressed = Signal(int) duplicate_cell_button_pressed = Signal(int, int) + split_cell_button_pressed = Signal(str) add_spherical_cell_toggled = Signal(bool) # Signal for spherical cell mode toggle add_instanseg_cell_toggled = Signal(bool) # Signal for InstanSeg cell mode toggle @@ -65,24 +73,45 @@ def __init__( duplicate_cell_layout.addWidget(QLabel("to t=")) duplicate_cell_layout.addWidget(self.duplicate_time_input) - # Retrieve the node_box widget from the layout and insert add/duplicate cell layouts + # split cell + self.split_method_combo = QComboBox() + self.split_method_combo.addItems( + ["Watershed (image)", "Watershed (distance)", "K-means"] + ) + self.split_btn = QPushButton("Split cell") + self.split_btn.clicked.connect(self._on_split_cell_clicked) + + split_layout = QHBoxLayout() + split_layout.addWidget(self.split_btn) + split_layout.addWidget(self.split_method_combo) + + # Retrieve the node_box widget from the layout and insert add/duplicate/split cell layouts node_box = main_layout.itemAt(1).widget() node_box.layout().addLayout(add_cell_layout) node_box.layout().addLayout(duplicate_cell_layout) + node_box.layout().addLayout(split_layout) # Conditionally add spherical cell button if self.allow_adding_spherical_cell: - self.add_spherical_cell_btn = QPushButton( - f"Add Spherical Cell (R={self.adding_spherical_cell_radius}px)" - ) - self.add_spherical_cell_btn.setCheckable(True) # Toggle on/off + self.add_spherical_cell_btn = QPushButton("Add Spherical Cell") + self.add_spherical_cell_btn.setCheckable(True) self.add_spherical_cell_btn.setStyleSheet( "QPushButton:checked { background-color: #4CAF50; color: white; }" ) self.add_spherical_cell_btn.clicked.connect(self._on_spherical_cell_clicked) + self.sphere_radius_spinbox = QSpinBox() + self.sphere_radius_spinbox.setRange(1, 100) + self.sphere_radius_spinbox.setValue(self.adding_spherical_cell_radius) + self.sphere_radius_spinbox.setSuffix(" µm") + self.sphere_radius_spinbox.setFixedWidth(60) + self.sphere_radius_spinbox.valueChanged.connect( + self._on_sphere_radius_changed + ) + spherical_cell_layout = QHBoxLayout() spherical_cell_layout.addWidget(self.add_spherical_cell_btn) + spherical_cell_layout.addWidget(self.sphere_radius_spinbox) node_box.layout().addLayout(spherical_cell_layout) @@ -105,14 +134,14 @@ def __init__( [self.allow_adding_spherical_cell, self.allow_adding_instanseg_cell] ) if num_extra_buttons == 0: - node_box.setMaximumHeight(150) # Original height - self.setMaximumHeight(430) # Original height + node_box.setMaximumHeight(200) # +50 for split row + self.setMaximumHeight(480) elif num_extra_buttons == 1: - node_box.setMaximumHeight(200) # One extra button - self.setMaximumHeight(480) # One extra button + node_box.setMaximumHeight(250) + self.setMaximumHeight(530) else: # num_extra_buttons == 2 - node_box.setMaximumHeight(250) # Two extra buttons - self.setMaximumHeight(530) # Two extra buttons + node_box.setMaximumHeight(300) + self.setMaximumHeight(580) def update_add_cell_btn_state(self, text): state, _, _ = self.add_cell_input.validator().validate(text, 0) @@ -153,3 +182,9 @@ def _on_spherical_cell_clicked(self, checked): def _on_instanseg_cell_clicked(self, checked): """Emit signal when InstanSeg cell button is toggled.""" self.add_instanseg_cell_toggled.emit(checked) + + def _on_sphere_radius_changed(self, value: int): + self.adding_spherical_cell_radius = value + + def _on_split_cell_clicked(self): + self.split_cell_button_pressed.emit(self.split_method_combo.currentText())