From 37dd4272b27d510c5d9e1dbc4c10a6c0d7bb2af0 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Thu, 5 Feb 2026 14:37:26 -0800 Subject: [PATCH 1/8] get metadata from geff and save as metadata.toml when converting geff to db --- pixi.lock | 4 ++-- trackedit/utils/geff.py | 20 ++++++++++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/pixi.lock b/pixi.lock index a942b64..da1f8b4 100644 --- a/pixi.lock +++ b/pixi.lock @@ -10440,7 +10440,7 @@ packages: requires_python: '>=3.7' - pypi: . name: trackedit - version: 0.0.10 + version: 0.0.12 sha256: 53d7ebc64418c7e2c6c3b9b83ee5341fa024c3fae629304a75d31deb78d8bff3 requires_dist: - click>=8.0 @@ -10450,7 +10450,7 @@ packages: editable: true - pypi: . name: trackedit - version: 0.0.10 + version: 0.0.12 sha256: 53d7ebc64418c7e2c6c3b9b83ee5341fa024c3fae629304a75d31deb78d8bff3 requires_dist: - click>=8.0 diff --git a/trackedit/utils/geff.py b/trackedit/utils/geff.py index 311a6dc..c485964 100644 --- a/trackedit/utils/geff.py +++ b/trackedit/utils/geff.py @@ -197,8 +197,24 @@ def convert_geff_to_db(geff_path: Path, output_path: Path = None) -> None: print(f"✓ Inserted {len(edge_records)} edges") print(f"✓ Database saved to: {database_path}") - print("Don't forget to create a metadata.toml file with the following content: ") - print("shape = [ 5, 128, 128, 128] ([t, (z, ), y, x])") + + # Extract and save shape metadata to TOML file + if geff_metadata.extra and "tracksdata" in geff_metadata.extra: + shape = geff_metadata.extra["tracksdata"].get("shape") + if shape: + metadata_path = database_path.parent / "metadata.toml" + shape_str = ", ".join(str(s) for s in shape) + toml_content = f"shape = [ {shape_str},]\n" + metadata_path.write_text(toml_content) + print(f"✓ Saved metadata to: {metadata_path}") + else: + print("No shape found in GEFF metadata extra/tracksdata") + else: + print("⚠ No extra/tracksdata metadata found in GEFF file") + print( + "Don't forget to create a metadata.toml file with the following content: " + ) + print("shape = [ 5, 128, 128, 128] ([t, (z, ), y, x])") @click.command() From a4e4642cbdbd1b0e8285b5c0de66bb58e57b34bd Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Tue, 10 Feb 2026 15:49:32 -0800 Subject: [PATCH 2/8] fix 3D selection of labels by clicking --- trackedit/motile_overwrites.py | 66 ++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/trackedit/motile_overwrites.py b/trackedit/motile_overwrites.py index 7f6d40b..5ee577c 100644 --- a/trackedit/motile_overwrites.py +++ b/trackedit/motile_overwrites.py @@ -284,6 +284,72 @@ def patched_create_pyqtgraph_content(self, track_df, feature): TreePlot._create_pyqtgraph_content = patched_create_pyqtgraph_content + +# Patch TrackLabels click handler to fix DatabaseArray lazy loading issue +def patch_track_labels_click_handler(): + """Monkey patch TrackLabels to add DatabaseArray loading workaround. + + After colormap updates, napari's get_value() fails to trigger DatabaseArray + loading. This patch pre-accesses the array to force loading before get_value(). + See: .claude/click-selection-bug-fix.md for details. + """ + from motile_tracker.data_views.views.layers.track_labels import TrackLabels + + # Store original __init__ + _original_init = TrackLabels.__init__ + + def patched_init(self, viewer, data, name, opacity, scale, tracks_viewer): + # Call original __init__ which sets up the original click callback + _original_init(self, viewer, data, name, opacity, scale, tracks_viewer) + + # Remove the original click callback (it's the last one added) + if self.mouse_drag_callbacks: + self.mouse_drag_callbacks.pop() + + # Add our fixed click callback + @self.mouse_drag_callbacks.append + def fixed_click(layer, event): + if ( + event.type == "mouse_press" + and layer.mode == "pan_zoom" + and not ( + layer.tracks_viewer.mode == "lineage" + and layer.viewer.dims.ndisplay == 3 + ) + ): + # WORKAROUND: Pre-access array to trigger DatabaseArray loading + # Without this, get_value() fails after colormap updates + data_coords = layer.world_to_data(event.position) + try: + t_idx = int(data_coords[0]) + # Access time slice to ensure DatabaseArray.fill_array() is called + _ = layer.data[t_idx] + except Exception: + pass # If this fails, get_value() will also fail + + label = layer.get_value( + event.position, + view_direction=event.view_direction, + dims_displayed=event.dims_displayed, + world=True, + ) + + if ( + label is not None + and label != 0 + and layer.colormap.map(label)[-1] != 0 + ): + append = "Shift" in event.modifiers + layer.tracks_viewer.selected_nodes.add(label, append) + + # Replace TrackLabels.__init__ with patched version + TrackLabels.__init__ = patched_init + + +# Apply the patch +patch_track_labels_click_handler() + + # def get_status(self, position, view_direction=None, dims_displayed=None, world=True): # return "True" #works to allow napari grid view, but not for cursor position/value display # TrackLabels.get_status = get_status From faa79aeef9927da4e66aa2ac1ed6369be83b31d7 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Tue, 10 Feb 2026 15:55:42 -0800 Subject: [PATCH 3/8] add claude to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index dc80cdd..bdaf5ab 100644 --- a/.gitignore +++ b/.gitignore @@ -168,6 +168,7 @@ cython_debug/ .history .vscode/ CLAUDE.md +.claude # Ignore examples files examples/*/* From 49cb28701ddd56478b7791910869b066d5cb5e40 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Wed, 11 Feb 2026 10:11:53 -0800 Subject: [PATCH 4/8] instanseg inference on the spot --- scripts/script_instanseg.py | 68 ++++ trackedit/TrackEditClass.py | 352 ++++++++++++++++++- trackedit/instanseg_inference.py | 418 +++++++++++++++++++++++ trackedit/run.py | 30 ++ trackedit/utils/red_flag_funcs.py | 3 +- trackedit/widgets/CustomEditingWidget.py | 36 +- 6 files changed, 893 insertions(+), 14 deletions(-) create mode 100644 scripts/script_instanseg.py create mode 100644 trackedit/instanseg_inference.py diff --git a/scripts/script_instanseg.py b/scripts/script_instanseg.py new file mode 100644 index 0000000..0e01e14 --- /dev/null +++ b/scripts/script_instanseg.py @@ -0,0 +1,68 @@ +import sys +import warnings +from pathlib import Path + +import numpy as np + +from trackedit.run import run_trackedit + +# Databases saved with numpy>2 need np._core.numeric, which is not available in numpy<2, hence the following hack +sys.modules["numpy._core.numeric"] = np.core.numeric + +warnings.filterwarnings("ignore", category=FutureWarning, message=".*qt_viewer.*") + +# **********INPUTS********* +# path to the working directory that contains the database file AND metadata.toml: +working_directory = Path("/home/teun.huijben/Documents/data/Thibaut/masks_on_geff/") +# name of the database file to start from, or "latest" to start from the latest version, defaults to "data.db" +db_filename_start = "latest" +# maximum number of frames display, defaults to None (use all frames) +tmax = 100 +# (Z),Y,X, defaults to (1, 1, 1) +scale = (1.625, 0.40625, 0.40625) +# overwrite existing database/changelog, defaults to False (not used when db_filename_start is "latest") +allow_overwrite = False + +# OPTIONAL: imaging data +imaging_zarr_file = ( + "/hpc/projects/group.royer/people/teun.huijben/data/Thibault/4th_exp/first_fov.zarr" +) +imaging_channel = "0" +imaging_layer_names = ["dense", "sparse"] + +# OPTIONAL: annotation mapping (default is neuromast cell types) +# annotation_mapping = { +# 1: {"name": "hair", "color": [0.0, 1.0, 0.0, 1.0]}, # green +# 2: {"name": "support", "color": [1.0, 0.1, 0.6, 1.0]}, # pink +# 3: {"name": "mantle", "color": [0.0, 0.0, 0.9, 1.0]}, # blue +# } +annotation_mapping = None + +# OPTIONAL: InstanSeg model for interactive cell segmentation +# Enable this to add cells via InstanSeg inference instead of spherical masks +flag_allow_adding_instanseg_cell = True +instanseg_model_path = ( + "/hpc/projects/group.royer/people/teun.huijben/data/Thibault/model_96.pt" +) +instanseg_device = None # 'cuda', 'cpu', or None for auto-detect +# ************************* + +if __name__ == "__main__": + run_trackedit( + working_directory=working_directory, + db_filename=db_filename_start, + tmax=tmax, + scale=scale, + allow_overwrite=allow_overwrite, + imaging_zarr_file=imaging_zarr_file, + imaging_channel=imaging_channel, + imaging_layer_names=imaging_layer_names, + annotation_mapping=annotation_mapping, + flag_allow_adding_spherical_cell=True, + adding_spherical_cell_radius=10, + flag_remove_red_flags_at_edge=True, + remove_red_flags_at_edge_threshold=10, + flag_allow_adding_instanseg_cell=flag_allow_adding_instanseg_cell, + instanseg_model_path=instanseg_model_path, + instanseg_device=instanseg_device, + ) diff --git a/trackedit/TrackEditClass.py b/trackedit/TrackEditClass.py index f08f222..3714373 100644 --- a/trackedit/TrackEditClass.py +++ b/trackedit/TrackEditClass.py @@ -25,11 +25,14 @@ def __init__( flag_show_hierarchy: bool = True, flag_allow_adding_spherical_cell: bool = False, adding_spherical_cell_radius: int = 10, + flag_allow_adding_instanseg_cell: bool = False, + instanseg_inference=None, ): self.viewer = viewer self.viewer.layers.clear() # Remove all existing layers self.databasehandler = databasehandler self.flag_show_hierarchy = flag_show_hierarchy + self.instanseg_inference = instanseg_inference self.tracksviewer = TracksViewer.get_instance(self.viewer) @@ -41,6 +44,7 @@ def __init__( self.databasehandler, allow_adding_spherical_cell=flag_allow_adding_spherical_cell, adding_spherical_cell_radius=adding_spherical_cell_radius, + allow_adding_instanseg_cell=flag_allow_adding_instanseg_cell, ) tabwidget_right = QTabWidget() @@ -98,7 +102,9 @@ def __init__( colormap=colormap, opacity=opacity, scale=self.databasehandler.scale, - visible=False, + visible=True + if i == 0 + else False, # Show only first channel by default translate=self.databasehandler.image_translate, ) layer.reset_contrast_limits() @@ -130,11 +136,19 @@ def __init__( # Connect spherical cell signal if feature is enabled if flag_allow_adding_spherical_cell: self.EditingMenu.add_spherical_cell_toggled.connect( - self._toggle_add_cell_mode + self._toggle_add_spherical_cell_mode ) # Initialize spherical cell mode flag self._add_cell_mode_active = False + # Connect InstanSeg cell signal if feature is enabled + if flag_allow_adding_instanseg_cell: + self.EditingMenu.add_instanseg_cell_toggled.connect( + self._toggle_add_instanseg_mode + ) + # Initialize InstanSeg cell mode flag + self._add_instanseg_mode_active = False + self.add_tracks() self.NavigationWidget.time_box.update_chunk_label() self.NavigationWidget.red_flag_box.update_red_flag_counter_and_info() @@ -409,7 +423,6 @@ def duplicate_cell_from_database(self, node_id: int, time: int): bbox=pickle.bbox, include_overlaps=True, ) - print("added node to db:", new_id, "at time", time) max_track_id = max( self.NavigationWidget.tracks_viewer.tracks_controller.tracks.track_id_to_node.keys() @@ -511,31 +524,39 @@ def add_spherical_cell_at_position(self, position_scaled, radius_pixels=10): 0, lambda: ( self.EditingMenu.add_spherical_cell_btn.setChecked(False), - self._toggle_add_cell_mode(False), + self._toggle_add_spherical_cell_mode(False), ), ) return new_node_id - def _toggle_add_cell_mode(self, checked): + def _toggle_add_spherical_cell_mode(self, checked): """Toggle the add spherical cell mode on/off.""" self._add_cell_mode_active = checked if checked: # Only add callback if not already present - if self._on_mouse_click_add_cell not in self.viewer.mouse_drag_callbacks: - self.viewer.mouse_drag_callbacks.append(self._on_mouse_click_add_cell) + if ( + self._on_mouse_click_add_spherical_cell + not in self.viewer.mouse_drag_callbacks + ): + self.viewer.mouse_drag_callbacks.append( + self._on_mouse_click_add_spherical_cell + ) else: # Remove ALL instances of the callback (in case of duplicates) - while self._on_mouse_click_add_cell in self.viewer.mouse_drag_callbacks: + while ( + self._on_mouse_click_add_spherical_cell + in self.viewer.mouse_drag_callbacks + ): try: self.viewer.mouse_drag_callbacks.remove( - self._on_mouse_click_add_cell + self._on_mouse_click_add_spherical_cell ) except ValueError: break - def _on_mouse_click_add_cell(self, viewer, event): + def _on_mouse_click_add_spherical_cell(self, viewer, event): """Handle mouse click when add cell mode is active. Called when user clicks in viewer while add cell mode is on. @@ -563,3 +584,314 @@ def _on_mouse_click_add_cell(self, viewer, event): # Yield to prevent further event propagation yield + + def _toggle_add_instanseg_mode(self, checked): + """Toggle the add InstanSeg cell mode on/off.""" + self._add_instanseg_mode_active = checked + + if checked: + # Mutual exclusion: disable spherical mode if active + if hasattr(self, "_add_cell_mode_active") and self._add_cell_mode_active: + if hasattr(self.EditingMenu, "add_spherical_cell_btn"): + self.EditingMenu.add_spherical_cell_btn.setChecked(False) + + # Only add callback if not already present + if ( + self._on_mouse_click_add_instanseg + not in self.viewer.mouse_drag_callbacks + ): + self.viewer.mouse_drag_callbacks.append( + self._on_mouse_click_add_instanseg + ) + else: + # Remove ALL instances of the callback (in case of duplicates) + while ( + self._on_mouse_click_add_instanseg in self.viewer.mouse_drag_callbacks + ): + try: + self.viewer.mouse_drag_callbacks.remove( + self._on_mouse_click_add_instanseg + ) + except ValueError: + break + + def _on_mouse_click_add_instanseg(self, viewer, event): + """Handle mouse click when InstanSeg add cell mode is active. + + Called when user clicks in viewer while InstanSeg mode is on. + """ + # Guard: only proceed if mode is actually active + if not self._add_instanseg_mode_active: + return + + # Only trigger on click (not drag) + if event.type == "mouse_press": + # Get click position and time + position = viewer.cursor.position + current_time = int(position[0]) + + # Add InstanSeg cell at clicked position (pass viewer for ray casting) + self.add_instanseg_cell_at_position(viewer, current_time) + + # Yield to prevent further event propagation + yield + + def add_instanseg_cell_at_position(self, viewer, current_time): + """Add a new cell using InstanSeg segmentation at the clicked position. + + Uses ray casting to find the brightest pixel along the camera ray, + then runs InstanSeg inference with that point as a seed. + + Parameters + ---------- + viewer : napari.Viewer + The napari viewer instance + current_time : int + Current time frame + + Returns + ------- + new_node_id : int or None + Database ID of the newly created node, or None if failed + """ + # Verify imaging data is available + if not self.databasehandler.imaging_flag: + show_warning("InstanSeg requires imaging data. No imaging data loaded.") + return None + + # Verify InstanSeg inference is available + if self.instanseg_inference is None: + show_warning("InstanSeg inference engine not initialized.") + return None + + # Update chunk if needed + self.update_chunk_from_frame(current_time) + + # Get image at current time - use sparse channel (index 1) + image_volume = self.databasehandler.imagingArray.get_channel_data(1)[ + current_time + ] # (Z, Y, X) or (Y, X) + + # Convert dask array to numpy array if needed + if hasattr(image_volume, "compute"): + image_volume = image_volume.compute() + image_volume = np.asarray(image_volume) + + # Get view direction for ray casting + vd_world = getattr(viewer.cursor, "_view_direction", None) + if vd_world is None or np.allclose(vd_world, 0): + show_warning( + "Could not get view direction from napari cursor. Try clicking again." + ) + return None + + # Get cursor position in world coordinates + cursor_pos_world = np.asarray(viewer.cursor.position, dtype=float) + + # Extract spatial position (remove time dimension) + if self.databasehandler.ndim == 4: + cursor_spatial_world = cursor_pos_world[1:] # (z, y, x) + scale_array = np.array( + [ + self.databasehandler.z_scale, + self.databasehandler.y_scale, + self.databasehandler.x_scale, + ] + ) + else: + cursor_spatial_world = cursor_pos_world[1:] # (y, x) + scale_array = np.array( + [self.databasehandler.y_scale, self.databasehandler.x_scale] + ) + + # Convert view direction and cursor position to data coordinates + vd_world = np.asarray(vd_world, dtype=float) + # View direction is also 4D (time + spatial), extract only spatial components + vd_spatial_world = vd_world[1:] if len(vd_world) == 4 else vd_world + vd_data = vd_spatial_world / scale_array + norm = np.linalg.norm(vd_data) + if norm == 0: + show_warning("Invalid view direction. Try clicking again.") + return None + vd_data /= norm + + origin_data = cursor_spatial_world / scale_array + + # Cast ray through volume in both directions + diag = int(np.sqrt(sum(s**2 for s in image_volume.shape))) + t_values = np.arange(-diag, diag + 1, dtype=float) + ray_points = origin_data[None, :] + t_values[:, None] * vd_data[None, :] + ray_voxels = np.round(ray_points).astype(int) + + # Keep only voxels inside the volume + if self.databasehandler.ndim == 4: + valid = ( + (ray_voxels[:, 0] >= 0) + & (ray_voxels[:, 0] < image_volume.shape[0]) + & (ray_voxels[:, 1] >= 0) + & (ray_voxels[:, 1] < image_volume.shape[1]) + & (ray_voxels[:, 2] >= 0) + & (ray_voxels[:, 2] < image_volume.shape[2]) + ) + else: + valid = ( + (ray_voxels[:, 0] >= 0) + & (ray_voxels[:, 0] < image_volume.shape[0]) + & (ray_voxels[:, 1] >= 0) + & (ray_voxels[:, 1] < image_volume.shape[1]) + ) + + ray_voxels = ray_voxels[valid] + if len(ray_voxels) == 0: + show_warning("Ray does not intersect volume. Try clicking inside the data.") + return None + + # Deduplicate and find voxel with maximum intensity (brightest pixel) + ray_voxels = np.unique(ray_voxels, axis=0) + if self.databasehandler.ndim == 4: + intensities = image_volume[ + ray_voxels[:, 0], ray_voxels[:, 1], ray_voxels[:, 2] + ] + else: + intensities = image_volume[ray_voxels[:, 0], ray_voxels[:, 1]] + + best_idx = int(np.argmax(intensities)) + best_voxel = ray_voxels[best_idx] + position_data = tuple(best_voxel.astype(float)) + + # Prepare scale tuple and handle 2D vs 3D + if self.databasehandler.ndim == 4: + # 3D data: need (z_scale, y_scale, x_scale) in microns + scale = ( + self.databasehandler.z_scale, + self.databasehandler.y_scale, + self.databasehandler.x_scale, + ) + else: + # 2D data: InstanSeg expects 3D, add singleton z dimension + image_volume = image_volume[np.newaxis, ...] # (1, Y, X) + position_data = (0,) + position_data # (0, y, x) + # Use y_scale as z_scale for isotropic rescaling + scale = ( + self.databasehandler.y_scale, + self.databasehandler.y_scale, + self.databasehandler.x_scale, + ) + + # Run InstanSeg inference + try: + mask, bbox = self.instanseg_inference.run_inference_at_position( + image_volume=image_volume, + time_frame=current_time, + position=position_data, + scale=scale, + ) + except RuntimeError as e: + if "CUDA out of memory" in str(e): + show_warning( + "GPU out of memory! Volume is too large after rescaling to isotropic. " + "Solutions: (1) Close other GPU apps, (2) Set instanseg_device='cpu' " + "in script (slower but uses less memory), or (3) Process smaller regions." + ) + else: + show_warning(f"InstanSeg inference failed: {str(e)}") + import traceback + + traceback.print_exc() + return None + except Exception as e: + show_warning(f"InstanSeg inference failed: {str(e)}") + import traceback + + traceback.print_exc() + return None + + if mask is None: + show_warning("InstanSeg could not segment a cell at this position.") + return None + + # Remove singleton z dimension for 2D data + if self.databasehandler.ndim == 3: + mask = mask.squeeze(0) # (Y, X) + bbox = (bbox[0][1:], bbox[1][1:]) # Remove z from bbox + + # Add to database + try: + new_node_id = add_new_node( + self.databasehandler.config_adjusted, + time=current_time, + mask=mask, + bbox=bbox, + include_overlaps=True, + ) + fix_overlap_ancestor_ids( + database_path=self.databasehandler.config_adjusted.data_config.database_path, + new_node_id=new_node_id, + current_time=current_time, + ) + except Exception as e: + show_warning(f"Failed to add node to database: {e}") + import traceback + + traceback.print_exc() + return None + + # Add to tracking system + track_ids = ( + self.NavigationWidget.tracks_viewer.tracks_controller.tracks.track_id_to_node.keys() + ) + max_track_id = max(track_ids) if track_ids else 0 + time_in_chunk = current_time - self.databasehandler.time_window[0] + + attributes = { + NodeAttr.TIME.value: [time_in_chunk], + NodeAttr.TRACK_ID.value: [max_track_id + 1], + "node_id": [new_node_id], + } + self.tracksviewer.tracks_controller.add_nodes( + attributes, [(np.array([0, 0, 0]))] + ) + + # Refresh and auto-disable + show_info(f"Added InstanSeg cell with ID {new_node_id}") + self.databasehandler.segments.force_refill() + self.viewer.layers[self.databasehandler.name + "_seg"].refresh() + + # Auto-disable only if button exists + if hasattr(self.EditingMenu, "add_instanseg_cell_btn"): + from qtpy.QtCore import QTimer + + QTimer.singleShot( + 0, + lambda: ( + self.EditingMenu.add_instanseg_cell_btn.setChecked(False), + self._toggle_add_instanseg_mode(False), + ), + ) + + return new_node_id + + def _convert_viewer_to_data_coords(self, position_spatial: tuple) -> tuple: + """Convert viewer-scaled coordinates to database data coordinates. + + Parameters + ---------- + position_spatial : tuple + Position in viewer coordinates (z, y, x) or (y, x), scaled + + Returns + ------- + tuple + Position in data coordinates (z, y, x) or (y, x), unscaled + """ + if self.databasehandler.ndim == 4: + # 4D: (z, y, x) viewer -> (z, y, x) data + z_unscaled = position_spatial[0] / self.databasehandler.z_scale + y_unscaled = position_spatial[1] / self.databasehandler.y_scale + x_unscaled = position_spatial[2] / self.databasehandler.x_scale + return (z_unscaled, y_unscaled, x_unscaled) + else: + # 3D: (y, x) viewer -> (y, x) data + y_unscaled = position_spatial[0] / self.databasehandler.y_scale + x_unscaled = position_spatial[1] / self.databasehandler.x_scale + return (y_unscaled, x_unscaled) diff --git a/trackedit/instanseg_inference.py b/trackedit/instanseg_inference.py new file mode 100644 index 0000000..6c0561a --- /dev/null +++ b/trackedit/instanseg_inference.py @@ -0,0 +1,418 @@ +"""InstanSeg inference integration for TrackEdit. + +This module provides interactive cell segmentation using InstanSeg models. +It manages model loading, embedding caching, and inference execution for +click-based cell addition. +""" + +from pathlib import Path +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from napari.utils.notifications import show_warning + +# Target pixel size for isotropic rescaling (must match training) +TARGET_PIXEL_SIZE = 0.5 + + +def normalize(image: np.ndarray) -> np.ndarray: + """Percentile-normalize a volume to [0, 1].""" + data = image.astype(np.float32) + low = np.percentile(data, 0.1) + high = np.percentile(data, 99.9) + if high > low: + normalized = (data - low) / (high - low) + # Clamp to [0, 1] to handle outliers above 99.9th percentile + return np.clip(normalized, 0, 1) + return data / (data.max() + 1e-8) + + +def rescale_to_isotropic( + image: np.ndarray, + current_scale: tuple, + target_pixel_size: float = TARGET_PIXEL_SIZE, + device: str = "cpu", +) -> Tuple[np.ndarray, tuple]: + """Rescale a (Z, Y, X) volume to isotropic resolution using trilinear interpolation. + + Args: + image: Input volume (Z, Y, X) + current_scale: Current voxel sizes (z_size, y_size, x_size) in microns + target_pixel_size: Target isotropic pixel size in microns + device: Device to run interpolation on ('cuda' or 'cpu') + + Returns: + (rescaled_image_np, scale_factors) where scale_factors = (sz, sy, sx) + """ + sz = current_scale[0] / target_pixel_size + sy = current_scale[1] / target_pixel_size + sx = current_scale[2] / target_pixel_size + + d, h, w = image.shape + new_d = int(round(d * sz)) + new_h = int(round(h * sy)) + new_w = int(round(w * sx)) + + tensor = torch.from_numpy(image).float().unsqueeze(0).unsqueeze(0) # (1,1,Z,Y,X) + if device == "cuda" and torch.cuda.is_available(): + tensor = tensor.to(device) + rescaled = F.interpolate( + tensor, size=(new_d, new_h, new_w), mode="trilinear", align_corners=False + ) + if device == "cuda": + rescaled = rescaled.cpu() + return rescaled.squeeze().numpy(), (sz, sy, sx) + + +def rescale_labels_back(labels: np.ndarray, original_shape: tuple) -> np.ndarray: + """Rescale labels from isotropic back to original shape using nearest interpolation.""" + tensor = torch.from_numpy(labels.astype(np.float32)).unsqueeze(0).unsqueeze(0) + rescaled = F.interpolate(tensor, size=original_shape, mode="nearest") + return rescaled.squeeze().numpy().astype(np.int32) + + +def run_backbone(model, image_normalized: np.ndarray, device: str) -> torch.Tensor: + """Run only the backbone. Returns (C, D, H, W) prediction tensor on GPU.""" + tensor = torch.from_numpy(image_normalized).float() + if tensor.ndim == 3: + tensor = tensor.unsqueeze(0).unsqueeze(0) # (1, 1, Z, Y, X) + tensor = tensor.to(device) + with torch.no_grad(): + pred = model.backbone(tensor) # (1, C, D, H, W) + return pred[0] # (C, D, H, W), stays on device + + +def run_postprocessing( + model, + cached_pred: torch.Tensor, + precomputed_seeds: Optional[torch.Tensor] = None, + seed_threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_threshold: float = 1.0, + pairwise_filter_threshold: float = 0.0, + peak_distance: int = 4, + window_size: int = 64, + use_iom: bool = False, + use_pairwise_seed_filter: bool = False, + cleanup_fragments: bool = True, +) -> np.ndarray: + """Run only the postprocessing on a cached prediction. + + Args: + model: InstanSeg TorchScript model + cached_pred: Cached backbone embeddings (C, D, H, W) on GPU + precomputed_seeds: Optional seed points in isotropic coordinates (N, 3) + seed_threshold: Threshold for seed detection + mask_threshold: Threshold for mask generation + overlap_threshold: Threshold for overlap detection + pairwise_filter_threshold: Threshold for pairwise filtering + peak_distance: Distance for peak detection + window_size: Size of processing window + use_iom: Use intersection over minimum instead of IOU + use_pairwise_seed_filter: Enable pairwise seed filtering + cleanup_fragments: Remove small fragments + + Returns: + Instance labels (Z, Y, X) as int32 numpy array + """ + # Prepare seeds if provided + seeds = None + if precomputed_seeds is not None and precomputed_seeds.shape[0] > 0: + seeds = precomputed_seeds.to(cached_pred.device) + + with torch.no_grad(): + labels = model.postprocessing( + cached_pred, + seed_threshold=seed_threshold, + mask_threshold=mask_threshold, + overlap_threshold=overlap_threshold, + pairwise_filter_threshold=pairwise_filter_threshold, + peak_distance=peak_distance, + window_size=window_size, + use_iom=use_iom, + use_pairwise_seed_filter=use_pairwise_seed_filter, + cleanup_fragments=cleanup_fragments, + precomputed_seeds=seeds if seeds is not None else torch.empty(0), + ) # (1, D, H, W) + + return labels.squeeze().cpu().numpy().astype(np.int32) + + +def extract_mask_at_seed( + labels: np.ndarray, seed_position: tuple +) -> Tuple[Optional[np.ndarray], Optional[tuple]]: + """Extract the instance mask at seed position and compute its bounding box. + + Args: + labels: Instance segmentation (Z, Y, X) with unique IDs per cell + seed_position: (z, y, x) position in label coordinates + + Returns: + (mask, bbox) where: + - mask: Binary mask (Z, Y, X) of the cell at seed, or None if no cell + - bbox: Bounding box ((z_min, y_min, x_min), (z_max, y_max, x_max)), or None + """ + z, y, x = seed_position + z, y, x = int(round(z)), int(round(y)), int(round(x)) + + # Check bounds + if not ( + 0 <= z < labels.shape[0] + and 0 <= y < labels.shape[1] + and 0 <= x < labels.shape[2] + ): + return None, None + + # Get label ID at seed position + label_id = labels[z, y, x] + + if label_id == 0: + # No cell at this position + return None, None + + # Extract binary mask for this instance + mask = (labels == label_id).astype(bool) + + # Return the full mask (will be cropped after rescaling to original resolution) + # Don't crop here because we need to rescale the full mask first + return mask, None + + +class InstanSegInference: + """Manages InstanSeg model and inference for TrackEdit. + + This class provides interactive cell segmentation by: + 1. Loading a TorchScript InstanSeg model + 2. Caching embeddings per time frame for fast repeated inference + 3. Running seed-based postprocessing on cached embeddings + 4. Extracting single-cell masks from instance segmentation + """ + + def __init__( + self, + model_path: str, + device: Optional[str] = None, + cache_size: int = 3, + ): + """Initialize InstanSeg inference engine. + + Args: + model_path: Path to TorchScript (.pt) model file + device: Device for inference ('cuda', 'cpu', or None for auto) + cache_size: Maximum number of frames to cache embeddings for + """ + model_path = Path(model_path) + if not model_path.exists(): + raise FileNotFoundError(f"Model not found: {model_path}") + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + self.device = device + self.cache_size = cache_size + + # Load model + self.model = torch.jit.load(str(model_path), map_location=device) + self.model.eval() + + # Cache structure: {time: (embeddings_tensor, iso_shape, scale_factors, original_shape)} + self.cached_embeddings = {} + self.cache_order = [] # Track insertion order for LRU eviction + + # Verify GPU setup + if device == "cuda": + if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name(0) + gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 + print(f"✓ InstanSeg using GPU: {gpu_name} ({gpu_memory:.1f} GB)") + else: + print( + "⚠ Warning: CUDA requested but not available, falling back to CPU" + ) + self.device = "cpu" + else: + print(f"InstanSeg using device: {device}") + + def run_inference_at_position( + self, + image_volume: np.ndarray, + time_frame: int, + position: tuple, + scale: tuple, + ) -> Tuple[Optional[np.ndarray], Optional[tuple]]: + """Run InstanSeg inference with seed at position. + + Args: + image_volume: Raw image volume (Z, Y, X) at original resolution + time_frame: Time frame index (for caching) + position: (z, y, x) seed position in original data coordinates + scale: (z_scale, y_scale, x_scale) voxel sizes in microns + + Returns: + (mask, bbox) tuple: + - mask: Binary mask (Z, Y, X) at original resolution, or None if no cell + - bbox: Bounding box ((z_min, y_min, x_min), (z_max, y_max, x_max)), or None + """ + try: + # Normalize image + image_normalized = normalize(image_volume) + + # Rescale to isotropic resolution (on GPU if available) + image_isotropic, scale_factors = rescale_to_isotropic( + image_normalized, scale, TARGET_PIXEL_SIZE, device=self.device + ) + iso_shape = image_isotropic.shape + original_shape = image_volume.shape + + # Check cache for embeddings + if time_frame in self.cached_embeddings: + embeddings, cached_iso_shape, _, _ = self.cached_embeddings[time_frame] + + # Verify shape matches (in case scale changed) + if cached_iso_shape != iso_shape: + embeddings = self._compute_and_cache_embeddings( + time_frame, + image_isotropic, + iso_shape, + scale_factors, + original_shape, + ) + else: + # Run backbone and cache + embeddings = self._compute_and_cache_embeddings( + time_frame, + image_isotropic, + iso_shape, + scale_factors, + original_shape, + ) + + # Convert position to isotropic coordinates + position_iso = self._convert_to_isotropic_coords( + position, original_shape, iso_shape + ) + + # Run postprocessing with seed + seed_tensor = torch.from_numpy(np.array([position_iso])).float() # (1, 3) + + labels_isotropic = run_postprocessing( + self.model, + embeddings, + precomputed_seeds=seed_tensor, + seed_threshold=0.5, + mask_threshold=0.5, + overlap_threshold=1.0, + pairwise_filter_threshold=0.0, + peak_distance=4, + window_size=64, + use_iom=False, + use_pairwise_seed_filter=False, + cleanup_fragments=True, + ) + + # Extract single-instance mask at seed + mask_isotropic, _ = extract_mask_at_seed(labels_isotropic, position_iso) + + if mask_isotropic is None: + show_warning("No cell detected at clicked position") + return None, None + + # Rescale mask back to original resolution + mask_original = rescale_labels_back(mask_isotropic, original_shape) + mask_binary = (mask_original > 0).astype(bool) + + # Compute bounding box in original coordinates + coords = np.argwhere(mask_binary > 0) + if len(coords) == 0: + return None, None + + bbox_min = coords.min(axis=0) + bbox_max = coords.max(axis=0) + 1 + + # Crop mask to bbox region (mask must match bbox shape) + mask_cropped = mask_binary[ + bbox_min[0] : bbox_max[0], + bbox_min[1] : bbox_max[1], + bbox_min[2] : bbox_max[2], + ] + + # Flatten to [z_min, y_min, x_min, z_max, y_max, x_max] + bbox = np.concatenate([bbox_min, bbox_max]).astype(np.int32) + + return mask_cropped, bbox + + except Exception as e: + show_warning(f"InstanSeg inference failed: {str(e)}") + import traceback + + traceback.print_exc() + return None, None + + def _compute_and_cache_embeddings( + self, + time_frame: int, + image_isotropic: np.ndarray, + iso_shape: tuple, + scale_factors: tuple, + original_shape: tuple, + ) -> torch.Tensor: + """Compute backbone embeddings and add to cache. + + Args: + time_frame: Time frame index + image_isotropic: Normalized isotropic image (Z, Y, X) + iso_shape: Shape of isotropic volume + scale_factors: Scale factors used for rescaling + original_shape: Original image shape + + Returns: + Embeddings tensor (C, D, H, W) on GPU + """ + # Run backbone inference + embeddings = run_backbone(self.model, image_isotropic, self.device) + + # Report GPU memory usage if available + if self.device == "cuda" and torch.cuda.is_available(): + gpu_mem_mb = torch.cuda.memory_allocated() / 1e6 + print(f"GPU memory: {gpu_mem_mb:.0f} MB") + + # Add to cache + self.cached_embeddings[time_frame] = ( + embeddings, + iso_shape, + scale_factors, + original_shape, + ) + self.cache_order.append(time_frame) + + # Evict oldest if cache full + if len(self.cache_order) > self.cache_size: + oldest_time = self.cache_order.pop(0) + if oldest_time in self.cached_embeddings: + del self.cached_embeddings[oldest_time] + + return embeddings + + def _convert_to_isotropic_coords( + self, position: tuple, original_shape: tuple, iso_shape: tuple + ) -> tuple: + """Convert position from original to isotropic coordinates. + + Args: + position: (z, y, x) in original coordinates + original_shape: Original volume shape + iso_shape: Isotropic volume shape + + Returns: + (z_iso, y_iso, x_iso) in isotropic coordinates + """ + scale = np.array(iso_shape) / np.array(original_shape) + position_iso = np.array(position) * scale + return tuple(position_iso) + + def clear_cache(self): + """Clear all cached embeddings.""" + self.cached_embeddings.clear() + self.cache_order.clear() diff --git a/trackedit/run.py b/trackedit/run.py index 529b238..7b2f4bb 100644 --- a/trackedit/run.py +++ b/trackedit/run.py @@ -41,6 +41,9 @@ def run_trackedit( flag_show_hierarchy: bool = True, flag_allow_adding_spherical_cell: bool = False, adding_spherical_cell_radius: int = 10, + flag_allow_adding_instanseg_cell: bool = False, + instanseg_model_path: Optional[str] = None, + instanseg_device: Optional[str] = None, flag_remove_red_flags_at_edge: bool = False, remove_red_flags_at_edge_threshold: int = 10, annotation_mapping: Optional[dict] = None, @@ -66,6 +69,11 @@ def run_trackedit( viewer: Optional existing napari viewer flag_show_hierarchy: Show hierarchy in the viewer flag_allow_adding_spherical_cell: Allow adding spherical cells via button (default: False) + adding_spherical_cell_radius: Radius of spherical cells in pixels (default: 10) + flag_allow_adding_instanseg_cell: Allow adding InstanSeg-segmented cells via button (default: False) + instanseg_model_path: Path to InstanSeg TorchScript model file + (required if flag_allow_adding_instanseg_cell=True) + instanseg_device: Device for InstanSeg inference ('cuda', 'cpu', or None for auto-detect) annotation_mapping: Mapping of annotation ids to names and colors imaging_layer_names: Names for imaging layers. If None, defaults to ['nuclear', 'membrane'] for 2 channels @@ -99,6 +107,26 @@ def run_trackedit( remove_red_flags_at_edge_threshold=remove_red_flags_at_edge_threshold, ) + # Load InstanSeg model if enabled + instanseg_inference = None + if flag_allow_adding_instanseg_cell: + if instanseg_model_path is None: + raise ValueError( + "flag_allow_adding_instanseg_cell=True requires instanseg_model_path" + ) + + import torch + + from trackedit.instanseg_inference import InstanSegInference + + device = instanseg_device or ("cuda" if torch.cuda.is_available() else "cpu") + + instanseg_inference = InstanSegInference( + model_path=instanseg_model_path, + device=device, + cache_size=1, # Reduce memory usage - only cache current frame + ) + # overwrite some motile functions DeleteNodes._apply = create_db_delete_nodes(DB_handler) DeleteEdges._apply = create_db_delete_edges(DB_handler) @@ -114,6 +142,8 @@ def run_trackedit( flag_show_hierarchy=flag_show_hierarchy, flag_allow_adding_spherical_cell=flag_allow_adding_spherical_cell, adding_spherical_cell_radius=adding_spherical_cell_radius, + flag_allow_adding_instanseg_cell=flag_allow_adding_instanseg_cell, + instanseg_inference=instanseg_inference, ) if DB_handler.ndim == 4: viewer.dims.ndisplay = 3 # 3D view diff --git a/trackedit/utils/red_flag_funcs.py b/trackedit/utils/red_flag_funcs.py index 21e38c8..aa7a341 100644 --- a/trackedit/utils/red_flag_funcs.py +++ b/trackedit/utils/red_flag_funcs.py @@ -307,7 +307,8 @@ def filter_red_flags_at_edge( # Keep only 'added'/'removed' red flags that are NOT at the edge at_edge_mask = min_distance_to_edge <= edge_threshold - filtered_non_overlap = non_overlap_red_flags[~at_edge_mask] + # Use .values to avoid index alignment issues between non_overlap_red_flags and at_edge_mask + filtered_non_overlap = non_overlap_red_flags[~at_edge_mask.values] # Combine filtered non-overlap events with all overlap events filtered_red_flags = pd.concat( diff --git a/trackedit/widgets/CustomEditingWidget.py b/trackedit/widgets/CustomEditingWidget.py index 95017f3..04d7018 100644 --- a/trackedit/widgets/CustomEditingWidget.py +++ b/trackedit/widgets/CustomEditingWidget.py @@ -12,6 +12,7 @@ class CustomEditingMenu(EditingMenu): add_cell_button_pressed = Signal(int) duplicate_cell_button_pressed = Signal(int, int) add_spherical_cell_toggled = Signal(bool) # Signal for spherical cell mode toggle + add_instanseg_cell_toggled = Signal(bool) # Signal for InstanSeg cell mode toggle def __init__( self, @@ -19,11 +20,13 @@ def __init__( databasehandler: DatabaseHandler, allow_adding_spherical_cell: bool = False, adding_spherical_cell_radius: int = 10, + allow_adding_instanseg_cell: bool = False, ): super().__init__(viewer) # Call the original init method self.databasehandler = databasehandler self.allow_adding_spherical_cell = allow_adding_spherical_cell self.adding_spherical_cell_radius = adding_spherical_cell_radius + self.allow_adding_instanseg_cell = allow_adding_instanseg_cell main_layout = self.layout() # This retrieves the QVBoxLayout from EditingMenu main_layout.insertWidget(0, QLabel(r"""

Edit tracks

""")) @@ -82,11 +85,34 @@ def __init__( spherical_cell_layout.addWidget(self.add_spherical_cell_btn) node_box.layout().addLayout(spherical_cell_layout) - node_box.setMaximumHeight(200) # Increased to fit spherical cell button - self.setMaximumHeight(480) # Increased to fit spherical cell button - else: + + # Conditionally add InstanSeg cell button + if self.allow_adding_instanseg_cell: + self.add_instanseg_cell_btn = QPushButton("Add InstanSeg Cell") + self.add_instanseg_cell_btn.setCheckable(True) # Toggle on/off + self.add_instanseg_cell_btn.setStyleSheet( + "QPushButton:checked { background-color: #FF5722; color: white; }" + ) + self.add_instanseg_cell_btn.clicked.connect(self._on_instanseg_cell_clicked) + + instanseg_cell_layout = QHBoxLayout() + instanseg_cell_layout.addWidget(self.add_instanseg_cell_btn) + + node_box.layout().addLayout(instanseg_cell_layout) + + # Adjust heights based on which buttons are present + num_extra_buttons = sum( + [self.allow_adding_spherical_cell, self.allow_adding_instanseg_cell] + ) + if num_extra_buttons == 0: node_box.setMaximumHeight(150) # Original height self.setMaximumHeight(430) # Original height + elif num_extra_buttons == 1: + node_box.setMaximumHeight(200) # One extra button + self.setMaximumHeight(480) # One extra button + else: # num_extra_buttons == 2 + node_box.setMaximumHeight(250) # Two extra buttons + self.setMaximumHeight(530) # Two extra buttons def update_add_cell_btn_state(self, text): state, _, _ = self.add_cell_input.validator().validate(text, 0) @@ -123,3 +149,7 @@ def click_on_hierarchy_cell(self, label: int): def _on_spherical_cell_clicked(self, checked): """Emit signal when spherical cell button is toggled.""" self.add_spherical_cell_toggled.emit(checked) + + def _on_instanseg_cell_clicked(self, checked): + """Emit signal when InstanSeg cell button is toggled.""" + self.add_instanseg_cell_toggled.emit(checked) From 693f54f4ebfa52995d7277f26d1f30f5a40ebf80 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Fri, 13 Feb 2026 16:50:42 -0800 Subject: [PATCH 5/8] added red flags for trajectory and area changes --- scripts/script_instanseg.py | 4 +- trackedit/DatabaseHandler.py | 13 ++- trackedit/utils/red_flag_funcs.py | 152 ++++++++++++++++++++++++++++++ 3 files changed, 167 insertions(+), 2 deletions(-) diff --git a/scripts/script_instanseg.py b/scripts/script_instanseg.py index 0e01e14..abfab2f 100644 --- a/scripts/script_instanseg.py +++ b/scripts/script_instanseg.py @@ -13,7 +13,9 @@ # **********INPUTS********* # path to the working directory that contains the database file AND metadata.toml: -working_directory = Path("/home/teun.huijben/Documents/data/Thibaut/masks_on_geff/") +working_directory = Path( + "/hpc/projects/group.royer/people/teun.huijben/data/Thibault/4th_exp//masks_on_geff/" +) # name of the database file to start from, or "latest" to start from the latest version, defaults to "data.db" db_filename_start = "latest" # maximum number of frames display, defaults to None (use all frames) diff --git a/trackedit/DatabaseHandler.py b/trackedit/DatabaseHandler.py index b3dbcb7..71e9cce 100644 --- a/trackedit/DatabaseHandler.py +++ b/trackedit/DatabaseHandler.py @@ -28,7 +28,9 @@ combine_red_flags, filter_red_flags_at_edge, find_all_starts_and_ends, + find_area_changes, find_overlapping_cells, + find_trajectory_changes, ) from trackedit.utils.utils import ( annotations_to_zarr, @@ -605,6 +607,7 @@ def db_to_df( NodeDB.z, NodeDB.y, NodeDB.x, + NodeDB.area, NodeDB.generic, ), ) @@ -714,8 +717,16 @@ def find_all_red_flags(self) -> pd.DataFrame: df, self.db_path_new ) + # Trajectory changes (jumps and direction changes) + rfs_trajectory = find_trajectory_changes(df, self.scale) + + # Area/volume changes + rfs_area = find_area_changes(df) + # Combine all red flag detection results - result_df = combine_red_flags(rfs_starts_and_ends, rfs_overlap) + result_df = combine_red_flags( + rfs_starts_and_ends, rfs_overlap, rfs_trajectory, rfs_area + ) # ToDo: make option to filter redflags in the first two timepoints # (useful for neuromast with suboptimal beginning) diff --git a/trackedit/utils/red_flag_funcs.py b/trackedit/utils/red_flag_funcs.py index aa7a341..4dca677 100644 --- a/trackedit/utils/red_flag_funcs.py +++ b/trackedit/utils/red_flag_funcs.py @@ -1,3 +1,4 @@ +import numpy as np import pandas as pd import sqlalchemy as sqla from ultrack.core.database import OverlapDB, Session @@ -322,3 +323,154 @@ def filter_red_flags_at_edge( ) return filtered_red_flags + + +def find_trajectory_changes( + df: pd.DataFrame, + scale: tuple, + displacement_threshold_multiplier: float = 2.5, + angle_threshold_deg: float = 120, + min_displacement_for_angle: float = None, +) -> pd.DataFrame: + """ + Detect trajectory jumps and direction changes. + + Binary classification: 'good' or 'change' for each movement. + A movement is 'change' if EITHER: + - Large displacement (> threshold), OR + - Sharp direction change (> angle_threshold) AND displacement is significant + + Parameters + ---------- + df : pd.DataFrame + DataFrame with columns: track_id, t, z, y, x (indexed by id) + scale : tuple + (z_scale, y_scale, x_scale) for anisotropic scaling + displacement_threshold_multiplier : float + Multiplier for displacement threshold (default 2.5× the 95th percentile) + angle_threshold_deg : float + Angle change threshold in degrees (default 120) + min_displacement_for_angle : float or None + Min displacement to check angles (default: 75th percentile) + + Returns + ------- + pd.DataFrame + Red flags with columns: t, id, event='trajectory_change' + """ + + # Pass 1: Calculate all displacements for adaptive thresholds + all_displacements = [] + for track_id, track_df in df.groupby("track_id"): + if len(track_df) < 2: + continue + track_df = track_df.sort_values("t") + dz = np.diff(track_df["z"].values) * scale[0] + dy = np.diff(track_df["y"].values) * scale[1] + dx = np.diff(track_df["x"].values) * scale[2] + displacements = np.sqrt(dz**2 + dy**2 + dx**2) + all_displacements.extend(displacements) + + if len(all_displacements) == 0: + return pd.DataFrame(columns=["t", "id", "event"]) + + # Set thresholds + disp_threshold = ( + np.percentile(all_displacements, 95) * displacement_threshold_multiplier + ) + if min_displacement_for_angle is None: + min_displacement_for_angle = np.percentile(all_displacements, 75) + + # Pass 2: Classify all movements + changes = [] + for track_id, track_df in df.groupby("track_id"): + if len(track_df) < 2: + continue + + track_df = track_df.sort_values("t") + z = track_df["z"].values + y = track_df["y"].values + x = track_df["x"].values + t = track_df["t"].values + ids = track_df.index.values + + dz = np.diff(z) * scale[0] + dy = np.diff(y) * scale[1] + dx = np.diff(x) * scale[2] + displacements = np.sqrt(dz**2 + dy**2 + dx**2) + + for i, disp in enumerate(displacements): + is_change = False + + # Check 1: Large displacement? + if disp > disp_threshold: + is_change = True + + # Check 2: Sharp direction change (only if displacement is significant)? + if disp > min_displacement_for_angle and i > 0: + v1 = np.array([dz[i - 1], dy[i - 1], dx[i - 1]]) + v2 = np.array([dz[i], dy[i], dx[i]]) + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + + if norm1 > 0 and norm2 > 0: + cos_angle = np.dot(v1, v2) / (norm1 * norm2) + cos_angle = np.clip(cos_angle, -1, 1) + angle_rad = np.arccos(cos_angle) + angle_deg = np.degrees(angle_rad) + + if angle_deg > angle_threshold_deg: + is_change = True + + if is_change: + changes.append( + {"t": t[i + 1], "id": ids[i + 1], "event": "trajectory_change"} + ) + + return pd.DataFrame(changes) + + +def find_area_changes(df: pd.DataFrame, threshold: float = 0.5) -> pd.DataFrame: + """ + Detect drastic area/volume changes. + + Flags cells with >50% (default) area increase or decrease between timesteps. + + Parameters + ---------- + df : pd.DataFrame + DataFrame with columns: track_id, t, area (indexed by id) + threshold : float + Relative change threshold (default 0.5 = 50%) + + Returns + ------- + pd.DataFrame + Red flags with columns: t, id, event='area_change' + """ + + if "area" not in df.columns: + return pd.DataFrame(columns=["t", "id", "event"]) + + changes = [] + + for track_id, track_df in df.groupby("track_id"): + if len(track_df) < 2: + continue + + track_df = track_df.sort_values("t") + areas = track_df["area"].values + t = track_df["t"].values + ids = track_df.index.values + + # Relative change: (new - old) / old + relative_changes = np.diff(areas) / areas[:-1] + + for i, rel_change in enumerate(relative_changes): + # Flag if absolute change > threshold (50% by default) + if np.abs(rel_change) > threshold: + changes.append( + {"t": t[i + 1], "id": ids[i + 1], "event": "area_change"} + ) + + return pd.DataFrame(changes) From 40f59e2d860538cb632681d3cd9b66d9047de955 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Wed, 18 Feb 2026 12:51:55 -0800 Subject: [PATCH 6/8] scale the edge_threshold when filtering red flags --- trackedit/DatabaseHandler.py | 1 + trackedit/utils/red_flag_funcs.py | 54 +++++++++++++++---------------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/trackedit/DatabaseHandler.py b/trackedit/DatabaseHandler.py index 71e9cce..5dcea8e 100644 --- a/trackedit/DatabaseHandler.py +++ b/trackedit/DatabaseHandler.py @@ -740,6 +740,7 @@ def find_all_red_flags(self) -> pd.DataFrame: data_shape=self.data_shape_full[1:], edge_threshold=self.remove_red_flags_at_edge_threshold, ndim=self.ndim, + scale=self.scale, ) return result_df diff --git a/trackedit/utils/red_flag_funcs.py b/trackedit/utils/red_flag_funcs.py index 4dca677..4fd989c 100644 --- a/trackedit/utils/red_flag_funcs.py +++ b/trackedit/utils/red_flag_funcs.py @@ -211,6 +211,7 @@ def filter_red_flags_at_edge( data_shape: tuple, edge_threshold: int, ndim: int, + scale: tuple = None, ) -> pd.DataFrame: """ Filter out 'added' and 'removed' red flags near the edge of the field of view (FOV). @@ -232,10 +233,14 @@ def filter_red_flags_at_edge( - For 3D data: (z_max, y_max, x_max) - For 2D data: (y_max, x_max) edge_threshold : int - Distance threshold in pixels - cells within this distance from any - edge are considered "at edge" + Distance threshold in XY pixels. The Z threshold is derived by converting + this to a physical distance and back: z_threshold = edge_threshold * xy_scale / z_scale ndim : int Number of dimensions: 4 for 3D+time, 3 for 2D+time + scale : tuple or None + (z_scale, y_scale, x_scale) for 3D or (y_scale, x_scale) for 2D. + Used to make the Z threshold equivalent in physical units to the XY threshold. + If None, the same pixel threshold is used for all dimensions. Returns ------- @@ -275,39 +280,32 @@ def filter_red_flags_at_edge( df_full[["id", "z", "y", "x"]], on="id", how="left" ) - # Calculate distance to edges for each red flag + # Compute per-dimension pixel thresholds + xy_threshold = edge_threshold + if ndim == 4 and scale is not None: + z_scale, y_scale, x_scale = scale + # Convert xy_threshold to physical distance, then to z pixels + z_threshold = edge_threshold * y_scale / z_scale + else: + z_threshold = edge_threshold # Fall back to same threshold if no scale given + + # Check each dimension against its own threshold and flag cells at any edge + pos = red_flags_with_pos if ndim == 4: - # 3D data: check z, y, x z_max, y_max, x_max = data_shape - - distances_to_edges = pd.DataFrame( - { - "z_min": red_flags_with_pos["z"], - "z_max": z_max - red_flags_with_pos["z"] - 1, - "y_min": red_flags_with_pos["y"], - "y_max": y_max - red_flags_with_pos["y"] - 1, - "x_min": red_flags_with_pos["x"], - "x_max": x_max - red_flags_with_pos["x"] - 1, - } + at_edge_mask = ( + (pos["z"] <= z_threshold) | (z_max - pos["z"] - 1 <= z_threshold) | + (pos["y"] <= xy_threshold) | (y_max - pos["y"] - 1 <= xy_threshold) | + (pos["x"] <= xy_threshold) | (x_max - pos["x"] - 1 <= xy_threshold) ) else: - # 2D data: check y, x only y_max, x_max = data_shape - - distances_to_edges = pd.DataFrame( - { - "y_min": red_flags_with_pos["y"], - "y_max": y_max - red_flags_with_pos["y"] - 1, - "x_min": red_flags_with_pos["x"], - "x_max": x_max - red_flags_with_pos["x"] - 1, - } + at_edge_mask = ( + (pos["y"] <= xy_threshold) | (y_max - pos["y"] - 1 <= xy_threshold) | + (pos["x"] <= xy_threshold) | (x_max - pos["x"] - 1 <= xy_threshold) ) - # Find minimum distance to any edge for each red flag - min_distance_to_edge = distances_to_edges.min(axis=1) - - # Keep only 'added'/'removed' red flags that are NOT at the edge - at_edge_mask = min_distance_to_edge <= edge_threshold + # Keep only red flags that are NOT at the edge # Use .values to avoid index alignment issues between non_overlap_red_flags and at_edge_mask filtered_non_overlap = non_overlap_red_flags[~at_edge_mask.values] From 826e73337924aae8aeab238c5fb924ce96608c01 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Wed, 18 Feb 2026 12:52:34 -0800 Subject: [PATCH 7/8] precommit --- trackedit/utils/red_flag_funcs.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/trackedit/utils/red_flag_funcs.py b/trackedit/utils/red_flag_funcs.py index 4fd989c..df9830f 100644 --- a/trackedit/utils/red_flag_funcs.py +++ b/trackedit/utils/red_flag_funcs.py @@ -294,15 +294,20 @@ def filter_red_flags_at_edge( if ndim == 4: z_max, y_max, x_max = data_shape at_edge_mask = ( - (pos["z"] <= z_threshold) | (z_max - pos["z"] - 1 <= z_threshold) | - (pos["y"] <= xy_threshold) | (y_max - pos["y"] - 1 <= xy_threshold) | - (pos["x"] <= xy_threshold) | (x_max - pos["x"] - 1 <= xy_threshold) + (pos["z"] <= z_threshold) + | (z_max - pos["z"] - 1 <= z_threshold) + | (pos["y"] <= xy_threshold) + | (y_max - pos["y"] - 1 <= xy_threshold) + | (pos["x"] <= xy_threshold) + | (x_max - pos["x"] - 1 <= xy_threshold) ) else: y_max, x_max = data_shape at_edge_mask = ( - (pos["y"] <= xy_threshold) | (y_max - pos["y"] - 1 <= xy_threshold) | - (pos["x"] <= xy_threshold) | (x_max - pos["x"] - 1 <= xy_threshold) + (pos["y"] <= xy_threshold) + | (y_max - pos["y"] - 1 <= xy_threshold) + | (pos["x"] <= xy_threshold) + | (x_max - pos["x"] - 1 <= xy_threshold) ) # Keep only red flags that are NOT at the edge From c1100771abb45f404b2566968004e5da2216a5c3 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Wed, 18 Feb 2026 14:33:52 -0800 Subject: [PATCH 8/8] script to mass-convert geff to db for challenge --- scripts/convert_geff_to_db.py | 95 +++++++++++++++++++++++++++++++++++ trackedit/cli.py | 2 +- trackedit/utils/geff.py | 4 +- 3 files changed, 98 insertions(+), 3 deletions(-) create mode 100644 scripts/convert_geff_to_db.py diff --git a/scripts/convert_geff_to_db.py b/scripts/convert_geff_to_db.py new file mode 100644 index 0000000..455d819 --- /dev/null +++ b/scripts/convert_geff_to_db.py @@ -0,0 +1,95 @@ +""" +Script to convert GEFF files to Ultrack SQLite databases for the cellmotv1 tracking challenge. + +For each experiment folder: + - Runs: trackedit convert geff-to-db .geff -o data.db + - Writes: metadata.toml with shape = [ 100, 64, 256, 256,] + +Usage: + python convert_geff_to_db.py # test mode: processes one experiment, output to Desktop + python convert_geff_to_db.py --all # processes all experiments, output next to geff/zarr +""" + +import subprocess +import sys +from pathlib import Path + +BASE_DIR = Path( + "/hpc/projects/group.royer/people/thibaut.goldsborough/tracking-challenge/cellmotv1" +) + +# Used only in test mode (no write access to BASE_DIR) +TEST_OUTPUT_DIR = Path("/home/teun.huijben/Desktop/tracking-challenge/cellmotv1") + +SHAPE = [100, 64, 256, 256] + +METADATA_CONTENT = f"shape = {SHAPE}\n" + + +def process_experiment(experiment_dir: Path, out_dir: Path) -> bool: + geff_files = list(experiment_dir.glob("*.geff")) + if not geff_files: + print(f" [SKIP] No .geff file found in {experiment_dir}") + return False + + geff_path = geff_files[0] + out_dir.mkdir(parents=True, exist_ok=True) + db_output = out_dir / "data.db" + metadata_path = out_dir / "metadata.toml" + + print(f" Converting: {geff_path.name}") + result = subprocess.run( + ["trackedit", "convert", "geff-to-db", str(geff_path), "-o", str(db_output)], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + print(f" [ERROR] Conversion failed:\n{result.stderr}") + return False + + print(f" Written: {db_output}") + + metadata_path.write_text(METADATA_CONTENT) + print(f" Written: {metadata_path}") + + return True + + +def main(): + run_all = "--all" in sys.argv + + date_folders = sorted(BASE_DIR.iterdir()) + + if run_all: + experiments = [ + exp + for date_folder in date_folders + for exp in sorted(date_folder.iterdir()) + if exp.is_dir() + ] + print(f"Processing all {len(experiments)} experiments...\n") + pairs = [(exp, exp) for exp in experiments] + else: + # Test mode: first experiment of the first folder, write to Desktop + first_folder = date_folders[0] + exp = sorted(first_folder.iterdir())[0] + experiments = [exp] + relative = exp.relative_to(BASE_DIR) + out = TEST_OUTPUT_DIR / relative + pairs = [(exp, out)] + print("TEST MODE: processing 1 experiment") + print(f" Input: {exp}") + print(f" Output: {out}\n") + + success = 0 + for exp, out in pairs: + print(f"[{exp.parent.name}/{exp.name}]") + if process_experiment(exp, out): + success += 1 + + print(f"\nDone: {success}/{len(experiments)} experiments converted successfully.") + + +if __name__ == "__main__": + main() diff --git a/trackedit/cli.py b/trackedit/cli.py index 080e77c..e8131b5 100644 --- a/trackedit/cli.py +++ b/trackedit/cli.py @@ -27,7 +27,7 @@ def convert(): help="Output database path (default: _from_geff.db)", ) def geff_to_db(geff_path: Path, output: Path = None): - """Convert GEFF file to ULTrack SQLite database. + """Convert GEFF file to Ultrack SQLite database. Args: geff_path: Path to the input GEFF file diff --git a/trackedit/utils/geff.py b/trackedit/utils/geff.py index c485964..d24d4f5 100644 --- a/trackedit/utils/geff.py +++ b/trackedit/utils/geff.py @@ -14,7 +14,7 @@ def convert_geff_to_db(geff_path: Path, output_path: Path = None) -> None: - """Convert GEFF file to ULTrack SQLite database. + """Convert GEFF file to Ultrack SQLite database. Args: geff_path: Path to the input GEFF file @@ -226,7 +226,7 @@ def convert_geff_to_db(geff_path: Path, output_path: Path = None) -> None: help="Output database path (default: _to_db.db)", ) def convert_geff_to_db_cli(geff_path: Path, output: Path = None) -> None: - """Convert GEFF file to ULTrack SQLite database (CLI version). + """Convert GEFF file to Ultrack SQLite database (CLI version). Args: geff_path: Path to the input GEFF file