diff --git a/trackedit/TrackEditClass.py b/trackedit/TrackEditClass.py index 43e6ba7..65428ce 100644 --- a/trackedit/TrackEditClass.py +++ b/trackedit/TrackEditClass.py @@ -686,13 +686,8 @@ def add_instanseg_cell_at_position(self, viewer, current_time): volumes.append(np.asarray(ch_data)) image_volume = np.stack(volumes, axis=0) # (C, Z, Y, X) or (C, Y, X) - # Get view direction for ray casting + # Get view direction for ray casting (unavailable in napari's 2D canvas mode) vd_world = getattr(viewer.cursor, "_view_direction", None) - if vd_world is None or np.allclose(vd_world, 0): - show_warning( - "Could not get view direction from napari cursor. Try clicking again." - ) - return None # Get cursor position in world coordinates cursor_pos_world = np.asarray(viewer.cursor.position, dtype=float) @@ -713,64 +708,73 @@ def add_instanseg_cell_at_position(self, viewer, current_time): [self.databasehandler.y_scale, self.databasehandler.x_scale] ) - # Convert view direction and cursor position to data coordinates - vd_world = np.asarray(vd_world, dtype=float) - # View direction is also 4D (time + spatial), extract only spatial components - vd_spatial_world = vd_world[1:] if len(vd_world) == 4 else vd_world - vd_data = vd_spatial_world / scale_array - norm = np.linalg.norm(vd_data) - if norm == 0: - show_warning("Invalid view direction. Try clicking again.") - return None - vd_data /= norm - origin_data = cursor_spatial_world / scale_array - # image_volume is (C, Z, Y, X) or (C, Y, X); spatial shape excludes channel dim - spatial_shape = image_volume.shape[1:] - nuclear_volume = image_volume[0] # use nuclear channel (idx 0) for ray casting - - # Cast ray through volume in both directions - diag = int(np.sqrt(sum(s**2 for s in spatial_shape))) - t_values = np.arange(-diag, diag + 1, dtype=float) - ray_points = origin_data[None, :] + t_values[:, None] * vd_data[None, :] - ray_voxels = np.round(ray_points).astype(int) - - # Keep only voxels inside the volume - if self.databasehandler.ndim == 4: - valid = ( - (ray_voxels[:, 0] >= 0) - & (ray_voxels[:, 0] < spatial_shape[0]) - & (ray_voxels[:, 1] >= 0) - & (ray_voxels[:, 1] < spatial_shape[1]) - & (ray_voxels[:, 2] >= 0) - & (ray_voxels[:, 2] < spatial_shape[2]) - ) + if vd_world is None or np.allclose(vd_world, 0): + # 2D canvas mode: no view direction available, use cursor position directly + position_data = tuple(np.round(origin_data).astype(float)) else: - valid = ( - (ray_voxels[:, 0] >= 0) - & (ray_voxels[:, 0] < spatial_shape[0]) - & (ray_voxels[:, 1] >= 0) - & (ray_voxels[:, 1] < spatial_shape[1]) - ) + # 3D canvas mode: cast ray and find brightest voxel along it + # Convert view direction to data coordinates + vd_world = np.asarray(vd_world, dtype=float) + # View direction is also 4D (time + spatial), extract only spatial components + vd_spatial_world = vd_world[1:] if len(vd_world) == 4 else vd_world + vd_data = vd_spatial_world / scale_array + norm = np.linalg.norm(vd_data) + if norm == 0: + show_warning("Invalid view direction. Try clicking again.") + return None + vd_data /= norm + + # image_volume is (C, Z, Y, X) or (C, Y, X); spatial shape excludes channel dim + spatial_shape = image_volume.shape[1:] + nuclear_volume = image_volume[ + 0 + ] # use nuclear channel (idx 0) for ray casting + + # Cast ray through volume in both directions + diag = int(np.sqrt(sum(s**2 for s in spatial_shape))) + t_values = np.arange(-diag, diag + 1, dtype=float) + ray_points = origin_data[None, :] + t_values[:, None] * vd_data[None, :] + ray_voxels = np.round(ray_points).astype(int) + + # Keep only voxels inside the volume + if self.databasehandler.ndim == 4: + valid = ( + (ray_voxels[:, 0] >= 0) + & (ray_voxels[:, 0] < spatial_shape[0]) + & (ray_voxels[:, 1] >= 0) + & (ray_voxels[:, 1] < spatial_shape[1]) + & (ray_voxels[:, 2] >= 0) + & (ray_voxels[:, 2] < spatial_shape[2]) + ) + else: + valid = ( + (ray_voxels[:, 0] >= 0) + & (ray_voxels[:, 0] < spatial_shape[0]) + & (ray_voxels[:, 1] >= 0) + & (ray_voxels[:, 1] < spatial_shape[1]) + ) - ray_voxels = ray_voxels[valid] - if len(ray_voxels) == 0: - show_warning("Ray does not intersect volume. Try clicking inside the data.") - return None + ray_voxels = ray_voxels[valid] + if len(ray_voxels) == 0: + show_warning( + "Ray does not intersect volume. Try clicking inside the data." + ) + return None - # Deduplicate and find voxel with maximum intensity in nuclear channel - ray_voxels = np.unique(ray_voxels, axis=0) - if self.databasehandler.ndim == 4: - intensities = nuclear_volume[ - ray_voxels[:, 0], ray_voxels[:, 1], ray_voxels[:, 2] - ] - else: - intensities = nuclear_volume[ray_voxels[:, 0], ray_voxels[:, 1]] + # Deduplicate and find voxel with maximum intensity in nuclear channel + ray_voxels = np.unique(ray_voxels, axis=0) + if self.databasehandler.ndim == 4: + intensities = nuclear_volume[ + ray_voxels[:, 0], ray_voxels[:, 1], ray_voxels[:, 2] + ] + else: + intensities = nuclear_volume[ray_voxels[:, 0], ray_voxels[:, 1]] - best_idx = int(np.argmax(intensities)) - best_voxel = ray_voxels[best_idx] - position_data = tuple(best_voxel.astype(float)) + best_idx = int(np.argmax(intensities)) + best_voxel = ray_voxels[best_idx] + position_data = tuple(best_voxel.astype(float)) # Prepare scale tuple and handle 2D vs 3D if self.databasehandler.ndim == 4: diff --git a/trackedit/motile_overwrites.py b/trackedit/motile_overwrites.py index 83977fb..c24791f 100644 --- a/trackedit/motile_overwrites.py +++ b/trackedit/motile_overwrites.py @@ -355,6 +355,68 @@ def fixed_click(layer, event): # TrackLabels.get_status = get_status +def create_center_view(DB_handler): + """Override TracksLayerGroup.center_view to convert world positions to data + coordinates when setting viewer.dims.current_step. + + Positions in the trackgraph are stored in world (scaled) units, but + current_step expects voxel indices, so we divide by the layer scale. + """ + + def center_view_with_scale(self, node): + if self.seg_layer is None or self.seg_layer.mode == "pan_zoom": + location = self.tracks.get_positions([node], incl_time=True)[0].tolist() + assert ( + len(location) == self.viewer.dims.ndim + ), f"Location {location} does not match viewer number of dims {self.viewer.dims.ndim}" + + # Build per-dimension scale: dim 0 is time (unscaled) + if DB_handler.ndim == 4: # (t, z, y, x) + scale_by_dim = { + 0: 1.0, + 1: DB_handler.z_scale, + 2: DB_handler.y_scale, + 3: DB_handler.x_scale, + } + else: # ndim == 3: (t, y, x) + scale_by_dim = { + 0: 1.0, + 1: DB_handler.y_scale, + 2: DB_handler.x_scale, + } + + step = list(self.viewer.dims.current_step) + for dim in self.viewer.dims.not_displayed: + scale = scale_by_dim.get(dim, 1.0) + step[dim] = int(location[dim] / scale + 0.5) + self.viewer.dims.current_step = step + + # Camera centering uses world coordinates — no scale conversion needed + example_layer = self.points_layer + corner_coordinates = example_layer.corner_pixels + dims_displayed = self.viewer.dims.displayed + x_dim = dims_displayed[-1] + y_dim = dims_displayed[-2] + + _min_x = corner_coordinates[0][x_dim] + _max_x = corner_coordinates[1][x_dim] + _min_y = corner_coordinates[0][y_dim] + _max_y = corner_coordinates[1][y_dim] + + if not ( + (location[x_dim] > _min_x and location[x_dim] < _max_x) + and (location[y_dim] > _min_y and location[y_dim] < _max_y) + ): + camera_center = self.viewer.camera.center + self.viewer.camera.center = ( + camera_center[0], + location[y_dim], + location[x_dim], + ) + + return center_view_with_scale + + # --- Custom keybindings --- diff --git a/trackedit/run.py b/trackedit/run.py index 7b2f4bb..bee4383 100644 --- a/trackedit/run.py +++ b/trackedit/run.py @@ -10,8 +10,10 @@ DeleteNodes, ) from motile_tracker.data_views import TracksViewer +from motile_tracker.data_views.views.layers.tracks_layer_group import TracksLayerGroup from trackedit.DatabaseHandler import DatabaseHandler from trackedit.motile_overwrites import ( + create_center_view, create_db_add_edges, create_db_add_nodes, create_db_delete_edges, @@ -132,6 +134,7 @@ def run_trackedit( DeleteEdges._apply = create_db_delete_edges(DB_handler) AddEdges._apply = create_db_add_edges(DB_handler) AddNodes._apply = create_db_add_nodes(DB_handler) + TracksLayerGroup.center_view = create_center_view(DB_handler) TracksViewer._refresh = create_tracks_viewer_and_segments_refresh( layer_name=layer_name )