From 5c94fe8c84fdd55f180b3660a36c7099b492db4a Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Fri, 20 Mar 2026 16:45:53 -0700 Subject: [PATCH 1/3] fix centering on cell in 2D + fix 2D instanseg inference ray tracing --- trackedit/TrackEditClass.py | 118 +++++++++++++++++---------------- trackedit/motile_overwrites.py | 63 ++++++++++++++++++ trackedit/run.py | 3 + 3 files changed, 126 insertions(+), 58 deletions(-) diff --git a/trackedit/TrackEditClass.py b/trackedit/TrackEditClass.py index 43e6ba7..66a3e2b 100644 --- a/trackedit/TrackEditClass.py +++ b/trackedit/TrackEditClass.py @@ -686,13 +686,8 @@ def add_instanseg_cell_at_position(self, viewer, current_time): volumes.append(np.asarray(ch_data)) image_volume = np.stack(volumes, axis=0) # (C, Z, Y, X) or (C, Y, X) - # Get view direction for ray casting + # Get view direction for ray casting (unavailable in napari's 2D canvas mode) vd_world = getattr(viewer.cursor, "_view_direction", None) - if vd_world is None or np.allclose(vd_world, 0): - show_warning( - "Could not get view direction from napari cursor. Try clicking again." - ) - return None # Get cursor position in world coordinates cursor_pos_world = np.asarray(viewer.cursor.position, dtype=float) @@ -713,64 +708,71 @@ def add_instanseg_cell_at_position(self, viewer, current_time): [self.databasehandler.y_scale, self.databasehandler.x_scale] ) - # Convert view direction and cursor position to data coordinates - vd_world = np.asarray(vd_world, dtype=float) - # View direction is also 4D (time + spatial), extract only spatial components - vd_spatial_world = vd_world[1:] if len(vd_world) == 4 else vd_world - vd_data = vd_spatial_world / scale_array - norm = np.linalg.norm(vd_data) - if norm == 0: - show_warning("Invalid view direction. Try clicking again.") - return None - vd_data /= norm - origin_data = cursor_spatial_world / scale_array - # image_volume is (C, Z, Y, X) or (C, Y, X); spatial shape excludes channel dim - spatial_shape = image_volume.shape[1:] - nuclear_volume = image_volume[0] # use nuclear channel (idx 0) for ray casting - - # Cast ray through volume in both directions - diag = int(np.sqrt(sum(s**2 for s in spatial_shape))) - t_values = np.arange(-diag, diag + 1, dtype=float) - ray_points = origin_data[None, :] + t_values[:, None] * vd_data[None, :] - ray_voxels = np.round(ray_points).astype(int) - - # Keep only voxels inside the volume - if self.databasehandler.ndim == 4: - valid = ( - (ray_voxels[:, 0] >= 0) - & (ray_voxels[:, 0] < spatial_shape[0]) - & (ray_voxels[:, 1] >= 0) - & (ray_voxels[:, 1] < spatial_shape[1]) - & (ray_voxels[:, 2] >= 0) - & (ray_voxels[:, 2] < spatial_shape[2]) - ) + if vd_world is None or np.allclose(vd_world, 0): + # 2D canvas mode: no view direction available, use cursor position directly + position_data = tuple(np.round(origin_data).astype(float)) else: - valid = ( - (ray_voxels[:, 0] >= 0) - & (ray_voxels[:, 0] < spatial_shape[0]) - & (ray_voxels[:, 1] >= 0) - & (ray_voxels[:, 1] < spatial_shape[1]) - ) + # 3D canvas mode: cast ray and find brightest voxel along it + # Convert view direction to data coordinates + vd_world = np.asarray(vd_world, dtype=float) + # View direction is also 4D (time + spatial), extract only spatial components + vd_spatial_world = vd_world[1:] if len(vd_world) == 4 else vd_world + vd_data = vd_spatial_world / scale_array + norm = np.linalg.norm(vd_data) + if norm == 0: + show_warning("Invalid view direction. Try clicking again.") + return None + vd_data /= norm + + # image_volume is (C, Z, Y, X) or (C, Y, X); spatial shape excludes channel dim + spatial_shape = image_volume.shape[1:] + nuclear_volume = image_volume[0] # use nuclear channel (idx 0) for ray casting + + # Cast ray through volume in both directions + diag = int(np.sqrt(sum(s**2 for s in spatial_shape))) + t_values = np.arange(-diag, diag + 1, dtype=float) + ray_points = origin_data[None, :] + t_values[:, None] * vd_data[None, :] + ray_voxels = np.round(ray_points).astype(int) + + # Keep only voxels inside the volume + if self.databasehandler.ndim == 4: + valid = ( + (ray_voxels[:, 0] >= 0) + & (ray_voxels[:, 0] < spatial_shape[0]) + & (ray_voxels[:, 1] >= 0) + & (ray_voxels[:, 1] < spatial_shape[1]) + & (ray_voxels[:, 2] >= 0) + & (ray_voxels[:, 2] < spatial_shape[2]) + ) + else: + valid = ( + (ray_voxels[:, 0] >= 0) + & (ray_voxels[:, 0] < spatial_shape[0]) + & (ray_voxels[:, 1] >= 0) + & (ray_voxels[:, 1] < spatial_shape[1]) + ) - ray_voxels = ray_voxels[valid] - if len(ray_voxels) == 0: - show_warning("Ray does not intersect volume. Try clicking inside the data.") - return None + ray_voxels = ray_voxels[valid] + if len(ray_voxels) == 0: + show_warning( + "Ray does not intersect volume. Try clicking inside the data." + ) + return None - # Deduplicate and find voxel with maximum intensity in nuclear channel - ray_voxels = np.unique(ray_voxels, axis=0) - if self.databasehandler.ndim == 4: - intensities = nuclear_volume[ - ray_voxels[:, 0], ray_voxels[:, 1], ray_voxels[:, 2] - ] - else: - intensities = nuclear_volume[ray_voxels[:, 0], ray_voxels[:, 1]] + # Deduplicate and find voxel with maximum intensity in nuclear channel + ray_voxels = np.unique(ray_voxels, axis=0) + if self.databasehandler.ndim == 4: + intensities = nuclear_volume[ + ray_voxels[:, 0], ray_voxels[:, 1], ray_voxels[:, 2] + ] + else: + intensities = nuclear_volume[ray_voxels[:, 0], ray_voxels[:, 1]] - best_idx = int(np.argmax(intensities)) - best_voxel = ray_voxels[best_idx] - position_data = tuple(best_voxel.astype(float)) + best_idx = int(np.argmax(intensities)) + best_voxel = ray_voxels[best_idx] + position_data = tuple(best_voxel.astype(float)) # Prepare scale tuple and handle 2D vs 3D if self.databasehandler.ndim == 4: diff --git a/trackedit/motile_overwrites.py b/trackedit/motile_overwrites.py index 83977fb..64e029e 100644 --- a/trackedit/motile_overwrites.py +++ b/trackedit/motile_overwrites.py @@ -17,6 +17,7 @@ from motile_tracker.data_model.solution_tracks import SolutionTracks from motile_tracker.data_model.tracks_controller import TracksController from motile_tracker.data_views import TracksViewer +from motile_tracker.data_views.views.layers.tracks_layer_group import TracksLayerGroup from motile_tracker.data_views.views.tree_view.tree_widget import TreePlot, TreeWidget AttrValue: TypeAlias = Any @@ -355,6 +356,68 @@ def fixed_click(layer, event): # TrackLabels.get_status = get_status +def create_center_view(DB_handler): + """Override TracksLayerGroup.center_view to convert world positions to data + coordinates when setting viewer.dims.current_step. + + Positions in the trackgraph are stored in world (scaled) units, but + current_step expects voxel indices, so we divide by the layer scale. + """ + + def center_view_with_scale(self, node): + if self.seg_layer is None or self.seg_layer.mode == "pan_zoom": + location = self.tracks.get_positions([node], incl_time=True)[0].tolist() + assert ( + len(location) == self.viewer.dims.ndim + ), f"Location {location} does not match viewer number of dims {self.viewer.dims.ndim}" + + # Build per-dimension scale: dim 0 is time (unscaled) + if DB_handler.ndim == 4: # (t, z, y, x) + scale_by_dim = { + 0: 1.0, + 1: DB_handler.z_scale, + 2: DB_handler.y_scale, + 3: DB_handler.x_scale, + } + else: # ndim == 3: (t, y, x) + scale_by_dim = { + 0: 1.0, + 1: DB_handler.y_scale, + 2: DB_handler.x_scale, + } + + step = list(self.viewer.dims.current_step) + for dim in self.viewer.dims.not_displayed: + scale = scale_by_dim.get(dim, 1.0) + step[dim] = int(location[dim] / scale + 0.5) + self.viewer.dims.current_step = step + + # Camera centering uses world coordinates — no scale conversion needed + example_layer = self.points_layer + corner_coordinates = example_layer.corner_pixels + dims_displayed = self.viewer.dims.displayed + x_dim = dims_displayed[-1] + y_dim = dims_displayed[-2] + + _min_x = corner_coordinates[0][x_dim] + _max_x = corner_coordinates[1][x_dim] + _min_y = corner_coordinates[0][y_dim] + _max_y = corner_coordinates[1][y_dim] + + if not ( + (location[x_dim] > _min_x and location[x_dim] < _max_x) + and (location[y_dim] > _min_y and location[y_dim] < _max_y) + ): + camera_center = self.viewer.camera.center + self.viewer.camera.center = ( + camera_center[0], + location[y_dim], + location[x_dim], + ) + + return center_view_with_scale + + # --- Custom keybindings --- diff --git a/trackedit/run.py b/trackedit/run.py index 7b2f4bb..13df903 100644 --- a/trackedit/run.py +++ b/trackedit/run.py @@ -12,6 +12,8 @@ from motile_tracker.data_views import TracksViewer from trackedit.DatabaseHandler import DatabaseHandler from trackedit.motile_overwrites import ( + TracksLayerGroup, + create_center_view, create_db_add_edges, create_db_add_nodes, create_db_delete_edges, @@ -132,6 +134,7 @@ def run_trackedit( DeleteEdges._apply = create_db_delete_edges(DB_handler) AddEdges._apply = create_db_add_edges(DB_handler) AddNodes._apply = create_db_add_nodes(DB_handler) + TracksLayerGroup.center_view = create_center_view(DB_handler) TracksViewer._refresh = create_tracks_viewer_and_segments_refresh( layer_name=layer_name ) From b22fedb65910043905c62cc8dc0c436ac4b1729c Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Fri, 20 Mar 2026 17:05:09 -0700 Subject: [PATCH 2/3] precommit fixes --- trackedit/TrackEditClass.py | 4 +++- trackedit/motile_overwrites.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/trackedit/TrackEditClass.py b/trackedit/TrackEditClass.py index 66a3e2b..65428ce 100644 --- a/trackedit/TrackEditClass.py +++ b/trackedit/TrackEditClass.py @@ -728,7 +728,9 @@ def add_instanseg_cell_at_position(self, viewer, current_time): # image_volume is (C, Z, Y, X) or (C, Y, X); spatial shape excludes channel dim spatial_shape = image_volume.shape[1:] - nuclear_volume = image_volume[0] # use nuclear channel (idx 0) for ray casting + nuclear_volume = image_volume[ + 0 + ] # use nuclear channel (idx 0) for ray casting # Cast ray through volume in both directions diag = int(np.sqrt(sum(s**2 for s in spatial_shape))) diff --git a/trackedit/motile_overwrites.py b/trackedit/motile_overwrites.py index 64e029e..c24791f 100644 --- a/trackedit/motile_overwrites.py +++ b/trackedit/motile_overwrites.py @@ -17,7 +17,6 @@ from motile_tracker.data_model.solution_tracks import SolutionTracks from motile_tracker.data_model.tracks_controller import TracksController from motile_tracker.data_views import TracksViewer -from motile_tracker.data_views.views.layers.tracks_layer_group import TracksLayerGroup from motile_tracker.data_views.views.tree_view.tree_widget import TreePlot, TreeWidget AttrValue: TypeAlias = Any From c73b17384d9c841fcfa2c891a94b42520ed5163d Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Fri, 20 Mar 2026 17:11:40 -0700 Subject: [PATCH 3/3] missing import --- trackedit/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trackedit/run.py b/trackedit/run.py index 13df903..bee4383 100644 --- a/trackedit/run.py +++ b/trackedit/run.py @@ -10,9 +10,9 @@ DeleteNodes, ) from motile_tracker.data_views import TracksViewer +from motile_tracker.data_views.views.layers.tracks_layer_group import TracksLayerGroup from trackedit.DatabaseHandler import DatabaseHandler from trackedit.motile_overwrites import ( - TracksLayerGroup, create_center_view, create_db_add_edges, create_db_add_nodes,