From d5e5a96ac0cb7770a3625857b498f9a2786d3465 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 9 Jun 2026 00:01:39 -0400 Subject: [PATCH 01/39] renamed tracks.graph to tracks.solution_graph, and graph_full looks at _root --- src/funtracks/actions/add_delete_edge.py | 14 +-- src/funtracks/actions/add_delete_node.py | 4 +- src/funtracks/actions/update_segmentation.py | 6 +- src/funtracks/annotators/_edge_annotator.py | 16 +-- src/funtracks/annotators/_graph_annotator.py | 7 +- .../annotators/_regionprops_annotator.py | 12 +-- src/funtracks/annotators/_track_annotator.py | 32 +++--- src/funtracks/data_model/solution_tracks.py | 8 +- src/funtracks/data_model/tracks.py | 97 +++++++++++-------- .../import_export/_export_segmentation.py | 6 +- src/funtracks/import_export/csv/_export.py | 6 +- src/funtracks/import_export/geff/_export.py | 10 +- .../user_actions/_user_swap_predecessors.py | 2 +- src/funtracks/user_actions/user_add_edge.py | 12 +-- src/funtracks/user_actions/user_add_node.py | 8 +- .../user_actions/user_delete_edge.py | 6 +- .../user_actions/user_update_segmentation.py | 4 +- tests/actions/test_action_history.py | 8 +- tests/actions/test_add_delete_edge.py | 42 ++++---- tests/actions/test_add_delete_nodes.py | 45 +++++---- tests/actions/test_update_node_segs.py | 22 ++--- tests/annotators/test_annotator_registry.py | 12 +-- tests/annotators/test_edge_annotator.py | 11 ++- .../annotators/test_regionprops_annotator.py | 20 ++-- tests/annotators/test_track_annotator.py | 10 +- tests/candidate_graph/test_compute_graph.py | 22 +++-- tests/candidate_graph/test_iou.py | 8 +- .../test_relabel_segmentation.py | 2 +- tests/data_model/test_solution_tracks.py | 8 +- tests/data_model/test_tracks.py | 32 +++--- tests/import_export/test_csv_export.py | 16 +-- tests/import_export/test_csv_import.py | 66 +++++++------ tests/import_export/test_export_to_geff.py | 12 ++- tests/import_export/test_import_from_geff.py | 49 +++++----- tests/import_export/test_internal_format.py | 18 ++-- tests/user_actions/test_user_actions_force.py | 14 +-- .../user_actions/test_user_add_delete_edge.py | 42 ++++---- .../user_actions/test_user_add_delete_node.py | 8 +- .../test_user_swap_predecessors.py | 36 +++---- .../test_user_update_node_attrs.py | 24 +++-- .../test_user_update_nodes_attrs.py | 22 +++-- .../test_user_update_segmentation.py | 38 ++++---- 42 files changed, 473 insertions(+), 364 deletions(-) diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index 8db36ee2..e4fc8a18 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -47,25 +47,25 @@ def _apply(self) -> None: """ # Check that both endpoints exist before computing edge attributes for node in self.edge: - if not self.tracks.graph.has_node(node): + if not self.tracks.graph_solution.has_node(node): raise ValueError( f"Cannot add edge {self.edge}: endpoint {node} not in graph yet" ) - if self.tracks.graph.has_edge(*self.edge): + if self.tracks.graph_solution.has_edge(*self.edge): raise ValueError(f"Edge {self.edge} already exists in the graph") attrs = dict(self.attributes) # Fill in missing edge attributes with schema defaults (includes # solution and any other registered edge attrs). - schemas = self.tracks.graph._edge_attr_schemas() - for attr in self.tracks.graph.edge_attr_keys(): + schemas = self.tracks.graph_solution._edge_attr_schemas() + for attr in self.tracks.graph_solution.edge_attr_keys(): if attr not in attrs: attrs[attr] = schemas[attr].default_value # Create edge attributes for this specific edge - self.tracks.graph.add_edge( + self.tracks.graph_solution.add_edge( source_id=self.edge[0], target_id=self.edge[1], attrs=attrs, @@ -89,7 +89,7 @@ def __init__(self, tracks: Tracks, edge: Edge): """ super().__init__(tracks) self.edge = edge - if not self.tracks.graph.has_edge(*self.edge): + if not self.tracks.graph_solution.has_edge(*self.edge): raise ValueError(f"Edge {self.edge} not in the graph, and cannot be removed") # Save all edge feature values from the features dict @@ -106,5 +106,5 @@ def inverse(self) -> BasicAction: return AddEdge(self.tracks, self.edge, attributes=self.attributes) def _apply(self) -> None: - self.tracks.graph.remove_edge(*self.edge) + self.tracks.graph_solution.remove_edge(*self.edge) self.tracks.notify_annotators(self) diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index d7812a1b..5744bed5 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -77,7 +77,7 @@ def inverse(self) -> BasicAction: def _apply(self) -> None: """Add the node with all attributes from self.attributes.""" - self.tracks.graph.add_node( + self.tracks.graph_solution.add_node( attrs=dict(self.attributes), index=self.node, validate_keys=False ) @@ -117,5 +117,5 @@ def inverse(self) -> BasicAction: def _apply(self) -> None: """Remove the node from the graph.""" - self.tracks.graph.remove_node(self.node) + self.tracks.graph_solution.remove_node(self.node) self.tracks.notify_annotators(self) diff --git a/src/funtracks/actions/update_segmentation.py b/src/funtracks/actions/update_segmentation.py index 4d83a158..7c2cdc58 100644 --- a/src/funtracks/actions/update_segmentation.py +++ b/src/funtracks/actions/update_segmentation.py @@ -61,13 +61,13 @@ def _apply(self) -> None: if value == 0: # val=0 means deleting (part of) the mask - mask_old = self.tracks.graph.nodes[self.node][self.mask_key] + mask_old = self.tracks.graph_solution.nodes[self.node][self.mask_key] mask_subtracted = mask_old.__isub__(mask_new) self.tracks.update_mask(self.node, mask_subtracted, mask_key=self.mask_key) - elif self.tracks.graph.has_node(value): + elif self.tracks.graph_solution.has_node(value): # if node already exists: - mask_old = self.tracks.graph.nodes[value][self.mask_key] + mask_old = self.tracks.graph_solution.nodes[value][self.mask_key] mask_combined = mask_old.__or__(mask_new) self.tracks.update_mask(value, mask_combined, mask_key=self.mask_key) diff --git a/src/funtracks/annotators/_edge_annotator.py b/src/funtracks/annotators/_edge_annotator.py index b612f431..df9202fc 100644 --- a/src/funtracks/annotators/_edge_annotator.py +++ b/src/funtracks/annotators/_edge_annotator.py @@ -82,14 +82,14 @@ def compute(self, feature_keys: list[str] | None = None) -> None: # TODO: add skip edges if self.iou_key in keys_to_compute: nodes_by_frame = defaultdict(list) - for n in self.tracks.graph.node_ids(): + for n in self.tracks.graph_solution.node_ids(): nodes_by_frame[self.tracks.get_time(n)].append(n) for t in range(self.tracks.segmentation.shape[0] - 1): nodes_in_t = nodes_by_frame[t] edges = [] for node in nodes_in_t: - for succ in self.tracks.graph.successors(node): + for succ in self.tracks.graph_solution.successors(node): edges.append((node, succ)) self._iou_update(edges) @@ -105,8 +105,8 @@ def _iou_update( """ for edge in edges: source, target = edge - mask1 = self.tracks.graph.nodes[source]["mask"] - mask2 = self.tracks.graph.nodes[target]["mask"] + mask1 = self.tracks.graph_solution.nodes[source]["mask"] + mask2 = self.tracks.graph_solution.nodes[target]["mask"] iou = mask1.iou(mask2) self.tracks._set_edge_attr(edge, self.iou_key, iou) @@ -136,16 +136,16 @@ def update(self, action: BasicAction): # Get all incident edges to the modified node modified_node = action.node edges_to_update = [] - for pred in self.tracks.graph.predecessors(modified_node): + for pred in self.tracks.graph_solution.predecessors(modified_node): edges_to_update.append((pred, modified_node)) - for succ in self.tracks.graph.successors(modified_node): + for succ in self.tracks.graph_solution.successors(modified_node): edges_to_update.append((modified_node, succ)) # Update IoU for each edge for edge in edges_to_update: source, target = edge - mask1 = self.tracks.graph.nodes[source]["mask"] - mask2 = self.tracks.graph.nodes[target]["mask"] + mask1 = self.tracks.graph_solution.nodes[source]["mask"] + mask2 = self.tracks.graph_solution.nodes[target]["mask"] if mask1.mask.sum() == 0 or mask2.mask.sum() == 0: empty_node = source if mask1.mask.sum() == 0 else target frame = self.tracks.get_time(empty_node) diff --git a/src/funtracks/annotators/_graph_annotator.py b/src/funtracks/annotators/_graph_annotator.py index 7fb8583c..43dc879d 100644 --- a/src/funtracks/annotators/_graph_annotator.py +++ b/src/funtracks/annotators/_graph_annotator.py @@ -106,7 +106,7 @@ def _filter_feature_keys(self, feature_keys: list[str] | None) -> list[str]: def compute(self, feature_keys: list[str] | None = None) -> None: """Compute a set of features and add them to the tracks. - This involves both updating the node/edge attributes on the tracks.graph + This involves both updating the node/edge attributes on the tracks.graph_solution and adding the features to the FeatureDict, if necessary. This is distinct from `update` to allow more efficient bulk computation of features. @@ -123,8 +123,9 @@ def compute(self, feature_keys: list[str] | None = None) -> None: def update(self, action: BasicAction) -> None: """Update a set of features based on the given action. - This involves both updating the node or edge attributes on the tracks.graph - and adding the features to the FeatureDict, if necessary. This is distinct + This involves both updating the node or edge attributes on the + tracks.graph_solution and adding the features to the FeatureDict, if + necessary. This is distinct from `compute` to allow more efficient computation of features for single elements. diff --git a/src/funtracks/annotators/_regionprops_annotator.py b/src/funtracks/annotators/_regionprops_annotator.py index 81ec42db..61953fcc 100644 --- a/src/funtracks/annotators/_regionprops_annotator.py +++ b/src/funtracks/annotators/_regionprops_annotator.py @@ -170,10 +170,10 @@ def compute(self, feature_keys: list[str] | None = None) -> None: all_node_ids = [] all_values: dict[str, list] = {key: [] for key in keys_to_compute} - for node_id in self.tracks.graph.node_ids(): - if not self.tracks.graph.has_node(node_id): + for node_id in self.tracks.graph_solution.node_ids(): + if not self.tracks.graph_solution.has_node(node_id): continue - mask = self.tracks.graph.nodes[node_id]["mask"] + mask = self.tracks.graph_solution.nodes[node_id]["mask"] for region in regionprops_extended(mask, spacing=spacing): all_node_ids.append(node_id) for key in keys_to_compute: @@ -203,7 +203,7 @@ def _regionprops_update( spacing = None if self.tracks.scale is None else tuple(self.tracks.scale[1:]) for region in regionprops_extended(mask, spacing=spacing): # Skip labels that aren't nodes in the graph (e.g., unselected detections) - if not self.tracks.graph.has_node(node_id): + if not self.tracks.graph_solution.has_node(node_id): continue for key in feature_keys: value = getattr(region, self.regionprops_names[key]) @@ -240,7 +240,7 @@ def update(self, action: BasicAction): time = self.tracks.get_time(node) - if self.tracks.graph.nodes[node]["mask"].mask.sum() == 0: + if self.tracks.graph_solution.nodes[node]["mask"].mask.sum() == 0: warnings.warn( f"Cannot find label {node} in frame {time}: " "updating regionprops values to None", @@ -250,7 +250,7 @@ def update(self, action: BasicAction): value = None self.tracks._set_node_attr(node, key, value) else: - mask = self.tracks.graph.nodes[node]["mask"] + mask = self.tracks.graph_solution.nodes[node]["mask"] self._regionprops_update(node, mask, keys_to_compute) def change_key(self, old_key: str, new_key: str) -> None: diff --git a/src/funtracks/annotators/_track_annotator.py b/src/funtracks/annotators/_track_annotator.py index ae49758d..caf3d748 100644 --- a/src/funtracks/annotators/_track_annotator.py +++ b/src/funtracks/annotators/_track_annotator.py @@ -112,13 +112,13 @@ def __init__( self.max_lineage_id = 0 # Initialize tracklet bookkeeping if track IDs already exist in the graph - if tracks.graph.num_nodes() > 0: + if tracks.graph_solution.num_nodes() > 0: max_id, id_to_nodes = self._get_max_id_and_map(self.tracklet_key) self.max_tracklet_id = max_id self.tracklet_id_to_nodes = id_to_nodes # Initialize lineage bookkeeping if lineage IDs already exist - if lineage_key is not None and tracks.graph.num_nodes() > 0: + if lineage_key is not None and tracks.graph_solution.num_nodes() > 0: max_id, id_to_nodes = self._get_max_id_and_map(self.lineage_key) self.max_lineage_id = max_id self.lineage_id_to_nodes = id_to_nodes @@ -133,9 +133,11 @@ def _get_max_id_and_map(self, key: str) -> tuple[int, dict[int, list[int]]]: tuple[int, dict[int, list[int]]]: The maximum id value, and a mapping from ids to a list of nodes with that id. """ - if key not in self.tracks.graph.node_attr_keys(): + if key not in self.tracks.graph_solution.node_attr_keys(): return 0, {} - df = self.tracks.graph.node_attrs(attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, key]) + df = self.tracks.graph_solution.node_attrs( + attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, key] + ) id_to_nodes = defaultdict(list) for node, _id in zip(df[td.DEFAULT_ATTR_KEYS.NODE_ID], df[key], strict=True): if _id is None: @@ -199,12 +201,14 @@ def _assign_lineage_ids(self) -> None: attributes will be updated. """ - lineages_internal = rx.weakly_connected_components(self.tracks.graph.rx_graph) + lineages_internal = rx.weakly_connected_components( + self.tracks.graph_solution.rx_graph + ) lineages_external = [] for lin in lineages_internal: node_ids_internal = list(lin) node_ids_external = [ - self.tracks.graph.node_ids()[nid] for nid in node_ids_internal + self.tracks.graph_solution.node_ids()[nid] for nid in node_ids_internal ] lineages_external.append(node_ids_external) @@ -218,19 +222,21 @@ def _assign_tracklet_ids(self) -> None: After removing division edges, each connected component will get a unique ID, and the relevant class attributes will be updated. """ - graph_copy = self.tracks.graph.detach().filter().subgraph() + graph_copy = self.tracks.graph_solution.detach().filter().subgraph() parents = [ node for node, degree in zip( - self.tracks.graph.node_ids(), self.tracks.graph.out_degree(), strict=True + self.tracks.graph_solution.node_ids(), + self.tracks.graph_solution.out_degree(), + strict=True, ) if degree >= 2 ] # Remove all intertrack edges from a copy of the original graph for parent in parents: - all_edges = self.tracks.graph.edge_list() + all_edges = self.tracks.graph_solution.edge_list() daughters = [edge[1] for edge in all_edges if edge[0] == parent] for daughter in daughters: @@ -247,7 +253,7 @@ def _assign_tracklet_ids(self) -> None: self.tracklet_id_to_nodes[track_id] = node_ids_external track_id += 1 if all_node_ids: - self.tracks.graph.update_node_attrs( + self.tracks.graph_solution.update_node_attrs( attrs={self.tracks.features.tracklet_key: all_track_ids}, node_ids=all_node_ids, ) @@ -317,18 +323,18 @@ def _handle_update_track_ids(self, action: UpdateTrackIDs) -> None: still_in_tracklet = False # Continue to all successors - next_nodes.extend(self.tracks.graph.successors(node)) + next_nodes.extend(self.tracks.graph_solution.successors(node)) curr_nodes = next_nodes # Bulk-write all collected node attribute changes in one call each if update_tracklet and tracklet_nodes: - self.tracks.graph.update_node_attrs( + self.tracks.graph_solution.update_node_attrs( attrs={self.tracklet_key: [new_tracklet_id] * len(tracklet_nodes)}, node_ids=tracklet_nodes, ) if update_lineage and lineage_nodes: - self.tracks.graph.update_node_attrs( + self.tracks.graph_solution.update_node_attrs( attrs={self.lineage_key: [new_lineage_id] * len(lineage_nodes)}, node_ids=lineage_nodes, ) diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 9b3eaaee..bb060425 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -101,7 +101,7 @@ def from_tracks(cls, tracks: Tracks): if ( tracks.features.tracklet_key is not None and ( - tracks.graph.node_attrs(attr_keys=tracks.features.tracklet_key)[ + tracks.graph_solution.node_attrs(attr_keys=tracks.features.tracklet_key)[ tracks.features.tracklet_key ] == -1 @@ -111,7 +111,7 @@ def from_tracks(cls, tracks: Tracks): force_recompute = True soln_tracks = cls( - tracks.graph, + tracks.graph_solution, scale=tracks.scale, ndim=tracks.ndim, features=tracks.features, @@ -164,7 +164,9 @@ def get_track_ids(self, nodes) -> list[int]: if self.features.tracklet_key is None: raise ValueError("Tracklet key not initialized in features") tracklet_key = self.features.tracklet_key - df = self.graph.node_attrs(attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, tracklet_key]) + df = self.graph_solution.node_attrs( + attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, tracklet_key] + ) id_to_val = dict( zip( df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list(), diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 0cc16dc8..6669c605 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -52,8 +52,10 @@ class Tracks: position attribute. Edges in the graph represent links across time. Attributes: - graph (td.graph.GraphView): A graph with nodes representing detections and - and edges representing links across time. + graph_solution (td.graph.GraphView): A solution=True view of the full graph, + with nodes representing detections and edges representing links across time. + graph_full (td.graph.GraphView): The full graph: every node/edge ever known, + including soft-deleted (solution=False) ones. Backs graph_solution. features (FeatureDict): Dictionary of features tracked on graph nodes/edges. annotators (AnnotatorRegistry): List of annotators that compute features. scale (list[float] | None): How much to scale each dimension by, including time. @@ -103,7 +105,7 @@ def __init__( _segmentation (GraphArrayView | None): Internal parameter for reusing an existing GraphArrayView instance. Not intended for public use. """ - self.graph = graph + self.graph_solution = graph if _segmentation is not None: # Reuse provided segmentation instance (internal use only) self.segmentation = _segmentation @@ -173,6 +175,16 @@ def __init__( else: self._setup_core_computed_features() + @property + def graph_full(self) -> td.graph.GraphView: + """The full graph: every node/edge ever known, including soft-deleted + (solution=False) ones. `graph_solution` is a solution=True view of this. + + Backed by the solution view's root, so the two can never drift: rebuilding + graph_solution via graph_full.filter(...).subgraph() keeps the same root. + """ + return self.graph_solution._root + def _get_feature_set( self, time_attr: str | None, @@ -247,7 +259,7 @@ def _get_feature_set( # else: single pos_attr with segmentation - RegionpropsAnnotator will handle it # Register solution feature when present on the graph - if "solution" in self.graph.node_attr_keys(): + if "solution" in self.graph_solution.node_attr_keys(): feature_dict["solution"] = Solution() # Register mask and bbox features if segmentation exists @@ -326,11 +338,11 @@ def _check_existing_feature(self, key: str) -> bool: bool: True if the key is on the first sampled node or there are no nodes, and False if missing from the first node. """ - if self.graph.num_nodes() == 0: + if self.graph_solution.num_nodes() == 0: return True # Check which attributes exist - node_attrs = set(self.graph.node_attr_keys()) + node_attrs = set(self.graph_solution.node_attr_keys()) return key in node_attrs def _setup_core_computed_features(self) -> None: @@ -368,10 +380,10 @@ def _setup_core_computed_features(self) -> None: self.enable_features([key]) def nodes(self): - return np.array(self.graph.node_ids()) + return np.array(self.graph_solution.node_ids()) def edges(self): - return np.array(self.graph.edge_ids()) + return np.array(self.graph_solution.edge_ids()) def in_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: """Get the in-degree edge_ids of the nodes in the graph.""" @@ -380,9 +392,11 @@ def in_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: if not isinstance(nodes, np.ndarray): nodes = np.array(nodes) - return np.array([self.graph.in_degree(node.item()) for node in nodes]) + return np.array( + [self.graph_solution.in_degree(node.item()) for node in nodes] + ) else: - return np.array(self.graph.in_degree()) + return np.array(self.graph_solution.in_degree()) def out_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: if nodes is not None: @@ -390,15 +404,17 @@ def out_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: if not isinstance(nodes, np.ndarray): nodes = np.array(nodes) - return np.array([self.graph.out_degree(node.item()) for node in nodes]) + return np.array( + [self.graph_solution.out_degree(node.item()) for node in nodes] + ) else: - return np.array(self.graph.out_degree()) + return np.array(self.graph_solution.out_degree()) def predecessors(self, node: int) -> list[int]: - return list(self.graph.predecessors(node)) + return list(self.graph_solution.predecessors(node)) def successors(self, node: int) -> list[int]: - return list(self.graph.successors(node)) + return list(self.graph_solution.successors(node)) def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.ndarray: """Get the positions of nodes in the graph. Optionally include the @@ -430,7 +446,7 @@ def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.nd + ([self.features.time_key] if incl_time else []) ) - df = self.graph.node_attrs(attr_keys=attr_keys) + df = self.graph_solution.node_attrs(attr_keys=attr_keys) id_to_row = { nid: i for i, nid in enumerate(df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list()) } @@ -505,7 +521,7 @@ def get_times(self, nodes: Iterable[Node]) -> Sequence[int]: For a single node use get_time() instead. """ nodes = list(nodes) - df = self.graph.node_attrs( + df = self.graph_solution.node_attrs( attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, self.features.time_key] ) id_to_val = dict( @@ -546,7 +562,7 @@ def get_mask( if self.segmentation is None: return None - mask = self.graph.nodes[node][mask_key] + mask = self.graph_solution.nodes[node][mask_key] return mask def update_mask( @@ -563,13 +579,13 @@ def update_mask( mask_key: The feature key for the mask column. Defaults to the standard mask key. """ - self.graph.nodes[node][mask_key] = mask + self.graph_solution.nodes[node][mask_key] = mask mask_feature = self.features.get(mask_key) if mask_feature is not None: # NOTE: all derived features of a mask are currently assumed to be # its bounding box. Revisit if non-bbox derived features are added. for derived_key in mask_feature.get("derived_features", []): - self.graph.nodes[node][derived_key] = mask.bbox + self.graph_solution.nodes[node][derived_key] = mask.bbox def undo(self) -> bool: """Undo the last performed action from the action history. @@ -603,10 +619,13 @@ def _get_new_node_ids(self, n: int) -> list[Node]: Returns: list[Node]: A list of new node ids. """ + # Check against graph_full, not the solution view: a soft-deleted + # (solution=False) node still occupies its id in the full graph, so reissuing + # it to a genuinely new node would collide with the still-present root node. ids = [self.node_id_counter + i for i in range(n)] self.node_id_counter += n for idx, _id in enumerate(ids): - while self.graph.has_node(_id): + while self.graph_full.has_node(_id): _id = self.node_id_counter self.node_id_counter += 1 ids[idx] = _id @@ -638,34 +657,36 @@ def _compute_ndim( def _set_node_attr(self, node: Node, attr: str, value: Any): if isinstance(value, np.ndarray): value = list(value) - self.graph.nodes[node][attr] = value + self.graph_solution.nodes[node][attr] = value def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any]): nodes_list = list(nodes) values_list = list(values) if nodes_list: - self.graph.update_node_attrs(attrs={attr: values_list}, node_ids=nodes_list) + self.graph_solution.update_node_attrs( + attrs={attr: values_list}, node_ids=nodes_list + ) def get_node_attr(self, node: Node, attr: str): - return self.graph.nodes[int(node)][attr] + return self.graph_solution.nodes[int(node)][attr] def get_nodes_attr(self, nodes: Iterable[Node], attr: str): return [self.get_node_attr(node, attr) for node in nodes] def _set_edge_attr(self, edge: Edge, attr: str, value: Any): - edge_id = self.graph.edge_id(edge[0], edge[1]) - self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) + edge_id = self.graph_solution.edge_id(edge[0], edge[1]) + self.graph_solution.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def _set_edges_attr(self, edges: Iterable[Edge], attr: str, values: Iterable[Any]): for edge, value in zip(edges, values, strict=False): - edge_id = self.graph.edge_id(edge[0], edge[1]) - self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) + edge_id = self.graph_solution.edge_id(edge[0], edge[1]) + self.graph_solution.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def get_edge_attr(self, edge: Edge, attr: str): - if attr not in self.graph.edge_attr_keys(): + if attr not in self.graph_solution.edge_attr_keys(): return None - edge_id = self.graph.edge_id(edge[0], edge[1]) - return self.graph.edges[edge_id][attr] + edge_id = self.graph_solution.edge_id(edge[0], edge[1]) + return self.graph_solution.edges[edge_id][attr] def get_edges_attr(self, edges: Iterable[Edge], attr: str): return [self.get_edge_attr(edge, attr) for edge in edges] @@ -752,19 +773,19 @@ def add_feature(self, key: str, feature: Feature) -> None: # Perform custom graph operations when a feature is added ft = feature["feature_type"] - if "node" in ft and key not in self.graph.node_attr_keys(): + if "node" in ft and key not in self.graph_solution.node_attr_keys(): # "mask" value_type maps to pl.Object via to_polars_dtype dtype = to_polars_dtype(feature["value_type"]) num_values = feature.get("num_values") if num_values is not None and num_values > 1: dtype = pl.Array(dtype, num_values) - self.graph.add_node_attr_key( + self.graph_solution.add_node_attr_key( key, default_value=feature["default_value"], dtype=dtype, ) - if "edge" in ft and key not in self.graph.edge_attr_keys(): - self.graph.add_edge_attr_key( + if "edge" in ft and key not in self.graph_solution.edge_attr_keys(): + self.graph_solution.add_edge_attr_key( key, default_value=feature["default_value"], dtype=to_polars_dtype(feature["value_type"]), @@ -800,7 +821,7 @@ def delete_feature(self, key: str) -> None: return # Perform custom graph operations when a feature is deleted - if "node" in feature_type and key in self.graph.node_attr_keys(): - self.graph.remove_node_attr_key(key) - if "edge" in feature_type and key in self.graph.edge_attr_keys(): - self.graph.remove_edge_attr_key(key) + if "node" in feature_type and key in self.graph_solution.node_attr_keys(): + self.graph_solution.remove_node_attr_key(key) + if "edge" in feature_type and key in self.graph_solution.edge_attr_keys(): + self.graph_solution.remove_edge_attr_key(key) diff --git a/src/funtracks/import_export/_export_segmentation.py b/src/funtracks/import_export/_export_segmentation.py index 0b4f5ddb..070d87e5 100644 --- a/src/funtracks/import_export/_export_segmentation.py +++ b/src/funtracks/import_export/_export_segmentation.py @@ -45,7 +45,7 @@ def resolve_relabel_attr( else: return None - existing_attrs = tracks.graph.node_attr_keys() + existing_attrs = tracks.graph_solution.node_attr_keys() if label_attr not in existing_attrs: raise ValueError( f"relabel='{relabel}' resolved to attribute '{label_attr}', " @@ -97,9 +97,9 @@ def export_segmentation( if label_attr is not None: graph = ( - tracks.graph.filter(node_ids=list(node_ids)).subgraph() + tracks.graph_solution.filter(node_ids=list(node_ids)).subgraph() if node_ids is not None - else tracks.graph + else tracks.graph_solution ) view = GraphArrayView(graph, label_attr, shape=shape) diff --git a/src/funtracks/import_export/csv/_export.py b/src/funtracks/import_export/csv/_export.py index 39b79525..e155845b 100644 --- a/src/funtracks/import_export/csv/_export.py +++ b/src/funtracks/import_export/csv/_export.py @@ -155,15 +155,15 @@ def convert_numpy_to_python(value): # Determine which nodes to export if node_ids is None: - node_to_keep = tracks.graph.node_ids() + node_to_keep = tracks.graph_solution.node_ids() else: - node_to_keep = filter_graph_with_ancestors(tracks.graph, node_ids) + node_to_keep = filter_graph_with_ancestors(tracks.graph_solution, node_ids) # Write CSV file rows: list[dict[str, Any]] = [] for node_id in node_to_keep: - parents = list(tracks.graph.predecessors(node_id)) + parents = list(tracks.graph_solution.predecessors(node_id)) parent_id = "" if len(parents) == 0 else parents[0] row: dict[str, Any] diff --git a/src/funtracks/import_export/geff/_export.py b/src/funtracks/import_export/geff/_export.py index 31b0fbfa..774f173f 100644 --- a/src/funtracks/import_export/geff/_export.py +++ b/src/funtracks/import_export/geff/_export.py @@ -56,7 +56,7 @@ def export_to_geff( if node_ids is not None: nodes_to_keep = filter_graph_with_ancestors( - tracks.graph, node_ids + tracks.graph_solution, node_ids ) # include the ancestors to make sure the graph is valid and has no missing # parent nodes. @@ -139,7 +139,7 @@ def export_to_geff( # GeffMetadata has no segmentation_shape field, so it must be stored separately. # This allows import_from_geff to reconstruct the segmentation (GraphArrayView) # without requiring an external segmentation file. - seg_shape = tracks.graph.metadata.get("segmentation_shape") + seg_shape = tracks.graph_solution.metadata.get("segmentation_shape") if seg_shape is not None: import zarr as _zarr @@ -168,7 +168,7 @@ def split_position_attr(tracks: Tracks) -> tuple[td.graph.GraphView, list[str] | if isinstance(pos_key, str): # Position is stored as a single attribute, need to split - new_graph = tracks.graph.detach() + new_graph = tracks.graph_solution.detach() new_graph = new_graph.filter().subgraph() # Register new attribute keys @@ -200,6 +200,6 @@ def split_position_attr(tracks: Tracks) -> tuple[td.graph.GraphView, list[str] | return new_graph, new_keys elif pos_key is not None: # Position is already split into separate attributes - return tracks.graph, list(pos_key) + return tracks.graph_solution, list(pos_key) else: - return tracks.graph, None + return tracks.graph_solution, None diff --git a/src/funtracks/user_actions/_user_swap_predecessors.py b/src/funtracks/user_actions/_user_swap_predecessors.py index 2353a4f0..da88bfec 100644 --- a/src/funtracks/user_actions/_user_swap_predecessors.py +++ b/src/funtracks/user_actions/_user_swap_predecessors.py @@ -40,7 +40,7 @@ def __init__( raise InvalidActionError("You can only swap a pair of two nodes.") node1, node2 = nodes - graph = tracks.graph + graph = tracks.graph_solution # Find predecessors pred1 = graph.predecessors(node1)[0] if graph.predecessors(node1) else None diff --git a/src/funtracks/user_actions/user_add_edge.py b/src/funtracks/user_actions/user_add_edge.py index cf05f002..d281c505 100644 --- a/src/funtracks/user_actions/user_add_edge.py +++ b/src/funtracks/user_actions/user_add_edge.py @@ -37,18 +37,18 @@ def __init__( super().__init__(tracks, actions=[]) self.tracks: SolutionTracks # Narrow type from base class source, target = edge - if not tracks.graph.has_node(source): + if not tracks.graph_solution.has_node(source): raise InvalidActionError( f"Source node {source} not in solution yet - must be added before edge" ) - if not tracks.graph.has_node(target): + if not tracks.graph_solution.has_node(target): raise InvalidActionError( f"Target node {target} not in solution yet - must be added before edge" ) # Check if making a merge. If yes and force, remove the other edge and update # track ids. - in_degree_target = self.tracks.graph.in_degree(target) + in_degree_target = self.tracks.graph_solution.in_degree(target) if in_degree_target > 0: if not force: raise InvalidActionError( @@ -57,7 +57,7 @@ def __init__( forceable=True, ) else: - pred = next(iter(self.tracks.graph.predecessors(target))) + pred = next(iter(self.tracks.graph_solution.predecessors(target))) merge_edge = (pred, target) warnings.warn( f"Removing edge {merge_edge} to add new edge without merging.", @@ -68,7 +68,7 @@ def __init__( ) # update track ids if needed - out_degree_source = self.tracks.graph.out_degree(source) + out_degree_source = self.tracks.graph_solution.out_degree(source) if out_degree_source == 0: # joining two segments # assign the track id and lineage id of the source node to the target # and all downstream nodes @@ -79,7 +79,7 @@ def __init__( ) elif out_degree_source == 1: # creating a division # assign a new track id to existing child (lineage stays the same) - successor = next(iter(self.tracks.graph.successors(source))) + successor = next(iter(self.tracks.graph_solution.successors(source))) self.actions.append( UpdateTrackIDs(self.tracks, successor, self.tracks.get_next_track_id()) ) diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py index d1160e7a..308eb15c 100644 --- a/src/funtracks/user_actions/user_add_node.py +++ b/src/funtracks/user_actions/user_add_node.py @@ -77,7 +77,7 @@ def __init__( raise InvalidActionError( f"Cannot add node without track id. Please add {track_id_key} attribute" ) - if self.tracks.graph.has_node(node): + if self.tracks.graph_solution.has_node(node): raise InvalidActionError( f"Node {node} already exists in the tracks, cannot add." ) @@ -98,7 +98,7 @@ def __init__( pred, succ = self.tracks.get_track_neighbors(track_id, time) # check if you are adding a node to a track that divided previously - if pred is not None and self.tracks.graph.out_degree(int(pred)) == 2: + if pred is not None and self.tracks.graph_solution.out_degree(int(pred)) == 2: if not force: raise InvalidActionError( "Cannot add node here - upstream division event detected.", @@ -118,11 +118,11 @@ def __init__( # downstream elif succ is not None: # check pred of succ - preds = self.tracks.graph.predecessors(succ) + preds = self.tracks.graph_solution.predecessors(succ) pred_of_succ = preds[0] if preds else None if ( pred_of_succ is not None - and self.tracks.graph.out_degree(pred_of_succ) == 2 + and self.tracks.graph_solution.out_degree(pred_of_succ) == 2 ): if not force: raise InvalidActionError( diff --git a/src/funtracks/user_actions/user_delete_edge.py b/src/funtracks/user_actions/user_delete_edge.py index 8af17c80..84bc0a01 100644 --- a/src/funtracks/user_actions/user_delete_edge.py +++ b/src/funtracks/user_actions/user_delete_edge.py @@ -32,11 +32,11 @@ def __init__( """ super().__init__(tracks, actions=[]) self.tracks: SolutionTracks # Narrow type from base class - if not self.tracks.graph.has_edge(*edge): + if not self.tracks.graph_solution.has_edge(*edge): raise InvalidActionError(f"Edge {edge} not in solution, can't remove") self.actions.append(DeleteEdge(tracks, edge)) - out_degree = self.tracks.graph.out_degree(edge[0]) + out_degree = self.tracks.graph_solution.out_degree(edge[0]) if out_degree == 0: # removed a normal (non division) edge # orphaned segment gets new track id and new lineage id new_track_id = self.tracks.get_next_track_id() @@ -46,7 +46,7 @@ def __init__( ) elif out_degree == 1: # removed a division edge # sibling gets parent's track id (lineage stays the same) - sibling = next(iter(self.tracks.graph.successors(edge[0]))) + sibling = next(iter(self.tracks.graph_solution.successors(edge[0]))) new_track_id = self.tracks.get_track_id(edge[0]) self.actions.append(UpdateTrackIDs(self.tracks, sibling, new_track_id)) # orphaned child gets a new lineage id (now a separate component) diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py index 6a312265..74b7e8ac 100644 --- a/src/funtracks/user_actions/user_update_segmentation.py +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -62,7 +62,7 @@ def __init__( "Can only update one time point at a time" ) time = int(all_pixels[0][0]) - if self.tracks.graph.has_node(new_value): + if self.tracks.graph_solution.has_node(new_value): mask_pixels = pixels_to_td_mask(all_pixels, self.tracks.ndim) self.actions.append( UpdateNodeSeg(tracks, new_value, mask_pixels, added=True) @@ -96,7 +96,7 @@ def __init__( time = pixels[0][0] # check if all pixels of old_value are removed mask_pixels = pixels_to_td_mask(pixels, self.tracks.ndim) - mask_old_value = self.tracks.graph.nodes[old_value]["mask"] + mask_old_value = self.tracks.graph_solution.nodes[old_value]["mask"] # If pixels fully overlaps with old_value mask, delete node if mask_pixels.intersection(mask_old_value) == mask_old_value.mask.sum(): self.actions.append( diff --git a/tests/actions/test_action_history.py b/tests/actions/test_action_history.py index 4365f22c..5afed0d7 100644 --- a/tests/actions/test_action_history.py +++ b/tests/actions/test_action_history.py @@ -24,7 +24,7 @@ def test_action_history(): history.add_new_action(action1) # undo the action assert history.undo() - assert tracks.graph.num_nodes() == 0 + assert tracks.graph_solution.num_nodes() == 0 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 1 assert history._undo_pointer == -1 @@ -34,7 +34,7 @@ def test_action_history(): # redo the action assert history.redo() - assert tracks.graph.num_nodes() == 1 + assert tracks.graph_solution.num_nodes() == 1 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 0 assert history._undo_pointer == 0 @@ -46,7 +46,7 @@ def test_action_history(): assert history.undo() action2 = AddNode(tracks, node=10, attributes={"t": 10, "pos": pos, "track_id": 2}) history.add_new_action(action2) - assert tracks.graph.num_nodes() == 1 + assert tracks.graph_solution.num_nodes() == 1 # there are 3 things on the stack: action1, action1's inverse, and action 2 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 0 @@ -55,7 +55,7 @@ def test_action_history(): # undo back to after action 1 assert history.undo() assert history.undo() - assert tracks.graph.num_nodes() == 1 + assert tracks.graph_solution.num_nodes() == 1 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 2 diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py index de856974..14af4d3f 100644 --- a/tests/actions/test_add_delete_edge.py +++ b/tests/actions/test_add_delete_edge.py @@ -19,12 +19,12 @@ @pytest.mark.parametrize("with_seg", [True, False]) def test_add_delete_edges(get_tracks, ndim, with_seg): tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - reference_graph = tracks.graph + reference_graph = tracks.graph_solution reference_seg = np.asarray(tracks.segmentation).copy() # Create an empty tracks with just nodes (no edges) - for edge in tracks.graph.edge_list(): - tracks.graph.remove_edge(*edge) + for edge in tracks.graph_solution.edge_list(): + tracks.graph_solution.remove_edge(*edge) edges = [(1, 2), (1, 3), (3, 4), (4, 5)] @@ -35,9 +35,9 @@ def test_add_delete_edges(get_tracks, ndim, with_seg): # TODO: What if adding an edge that already exists? # TODO: test all the edge cases, invalid operations, etc. for all actions - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) assert_frame_equal( - tracks.graph.edge_attrs(), + tracks.graph_solution.edge_attrs(), reference_graph.edge_attrs(), check_row_order=False, check_column_order=False, @@ -47,16 +47,18 @@ def test_add_delete_edges(get_tracks, ndim, with_seg): inverse = action.inverse() - assert set(tracks.graph.edge_ids()) == set() + assert set(tracks.graph_solution.edge_ids()) == set() if tracks.segmentation is not None: assert_array_almost_equal(tracks.segmentation, reference_seg) re_added = inverse.inverse() - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) - assert set(tracks.graph.edge_ids()) == set(reference_graph.edge_ids()) - assert sorted(tracks.graph.edge_list()) == sorted(reference_graph.edge_list()) + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) + assert set(tracks.graph_solution.edge_ids()) == set(reference_graph.edge_ids()) + assert sorted(tracks.graph_solution.edge_list()) == sorted( + reference_graph.edge_list() + ) assert_frame_equal( - tracks.graph.edge_attrs(), + tracks.graph_solution.edge_attrs(), reference_graph.edge_attrs(), check_row_order=False, check_column_order=False, @@ -74,7 +76,7 @@ def test_add_delete_edges(get_tracks, ndim, with_seg): # objects again — that's where the corruption surfaces. re_added.inverse() # reset state: edges absent (fresh objects, no bug here) inverse.inverse() # same DeleteEdge objects called again — must not crash - assert set(tracks.graph.edge_ids()) == set(reference_graph.edge_ids()) + assert set(tracks.graph_solution.edge_ids()) == set(reference_graph.edge_ids()) def test_add_edge_missing_endpoint(get_tracks): @@ -138,25 +140,25 @@ def test_custom_edge_attributes_preserved(get_tracks, ndim, with_seg): action = AddEdge(tracks, edge, attributes=custom_attrs) # Verify all attributes are present after adding - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) for key, value in custom_attrs.items(): - edge_id = tracks.graph.edge_id(*edge) - assert tracks.graph.edges[edge_id][key] == value, ( + edge_id = tracks.graph_solution.edge_id(*edge) + assert tracks.graph_solution.edges[edge_id][key] == value, ( f"Attribute {key} not set correctly on edge" ) # Delete the edge delete_action = action.inverse() - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) # Re-add the edge by inverting the delete delete_action.inverse() - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) # Verify all custom attributes are still present after re-adding for key, value in custom_attrs.items(): - edge_id = tracks.graph.edge_id(*edge) - assert tracks.graph.edges[edge_id][key] == value, ( + edge_id = tracks.graph_solution.edge_id(*edge) + assert tracks.graph_solution.edges[edge_id][key] == value, ( f"Attribute {key} not preserved after delete/re-add cycle" ) @@ -220,10 +222,10 @@ def test_add_edge_with_unregistered_edge_attr(tmp_path): tracks = SolutionTracks(graph, ndim=3, features=features) # Sanity: "custom_score" is in the graph schema but NOT in tracks.features. - assert "custom_score" in tracks.graph.edge_attr_keys() + assert "custom_score" in tracks.graph_solution.edge_attr_keys() assert "custom_score" not in tracks.features # Before the fix this raises: KeyError: 'custom_score' AddEdge(tracks, (1, 2)) - assert tracks.graph.has_edge(1, 2) + assert tracks.graph_solution.has_edge(1, 2) diff --git a/tests/actions/test_add_delete_nodes.py b/tests/actions/test_add_delete_nodes.py index 9d6b99c4..f6983e8c 100644 --- a/tests/actions/test_add_delete_nodes.py +++ b/tests/actions/test_add_delete_nodes.py @@ -21,7 +21,7 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): # Get a tracks instance tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - reference_graph = tracks.graph + reference_graph = tracks.graph_solution reference_seg = np.asarray(tracks.segmentation).copy() if with_seg else None # Start with an empty Tracks @@ -38,11 +38,14 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): ndim=ndim, ) empty_seg = np.zeros_like(tracks.segmentation) if with_seg else None - tracks.graph = empty_graph + tracks.graph_solution = empty_graph segmentation_shape = (5, 100, 100) if ndim == 3 else (5, 100, 100, 100) tracks.segmentation = ( GraphArrayView( - graph=tracks.graph, shape=segmentation_shape, attr_key="node_id", offset=0 + graph=tracks.graph_solution, + shape=segmentation_shape, + attr_key="node_id", + offset=0, ) if with_seg else None @@ -78,8 +81,8 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): actions.append(AddNode(tracks, node, attributes=attrs)) action = ActionGroup(tracks=tracks, actions=actions) - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) - data_tracks = tracks.graph.node_attrs() + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) + data_tracks = tracks.graph_solution.node_attrs() data_reference = reference_graph.node_attrs() if with_seg: assert_array_almost_equal(tracks.segmentation, reference_seg) @@ -95,14 +98,14 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): # Invert the action to delete all the nodes del_nodes = action.inverse() - assert set(tracks.graph.node_ids()) == set(empty_graph.node_ids()) + assert set(tracks.graph_solution.node_ids()) == set(empty_graph.node_ids()) if with_seg: assert_array_almost_equal(tracks.segmentation, empty_seg) # Re-invert the action to add back all the nodes and their attributes del_nodes.inverse() - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) - data_tracks = tracks.graph.node_attrs() + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) + data_tracks = tracks.graph_solution.node_attrs() data_reference = reference_graph.node_attrs() if with_seg: assert_array_almost_equal(tracks.segmentation, reference_seg) @@ -197,36 +200,44 @@ def test_custom_attributes_preserved(get_tracks, ndim, with_seg): node_id = 100 action = AddNode(tracks, node_id, custom_attrs.copy()) # Verify all attributes are present after adding - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) for key, value in custom_attrs.items(): if key == "pos": - assert_array_almost_equal(tracks.graph.nodes[node_id][key], np.array(value)) + assert_array_almost_equal( + tracks.graph_solution.nodes[node_id][key], np.array(value) + ) elif key == "mask": continue elif key == "bbox": - assert_array_equal(np.asarray(tracks.graph.nodes[node_id][key]), value) + assert_array_equal( + np.asarray(tracks.graph_solution.nodes[node_id][key]), value + ) else: - assert tracks.graph.nodes[node_id][key] == value, ( + assert tracks.graph_solution.nodes[node_id][key] == value, ( f"Attribute {key} not preserved after add" ) # Delete the node delete_action = action.inverse() - assert node_id not in tracks.graph.node_ids() + assert node_id not in tracks.graph_solution.node_ids() # Re-add the node by inverting the delete delete_action.inverse() - assert node_id in tracks.graph.node_ids() + assert node_id in tracks.graph_solution.node_ids() # Verify all custom attributes are still present after re-adding for key, value in custom_attrs.items(): if key == "pos": - assert_array_almost_equal(tracks.graph.nodes[node_id][key], np.array(value)) + assert_array_almost_equal( + tracks.graph_solution.nodes[node_id][key], np.array(value) + ) elif key == "mask": continue elif key == "bbox": - assert_array_equal(np.asarray(tracks.graph.nodes[node_id][key]), value) + assert_array_equal( + np.asarray(tracks.graph_solution.nodes[node_id][key]), value + ) else: - assert tracks.graph.nodes[node_id][key] == value, ( + assert tracks.graph_solution.nodes[node_id][key] == value, ( f"Attribute {key} not preserved after delete/re-add cycle" ) diff --git a/tests/actions/test_update_node_segs.py b/tests/actions/test_update_node_segs.py index 73a98af4..8c7a66d8 100644 --- a/tests/actions/test_update_node_segs.py +++ b/tests/actions/test_update_node_segs.py @@ -11,14 +11,14 @@ def test_update_node_segs(get_tracks, ndim): # Get tracks with segmentation tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) - reference_graph = tracks.graph.detach().filter().subgraph() + reference_graph = tracks.graph_solution.detach().filter().subgraph() node = 1 time = tracks.get_time(node) original_seg = np.asarray(tracks.segmentation).copy() - original_area = tracks.graph.nodes[1]["area"] - original_pos = tracks.graph.nodes[1]["pos"] + original_area = tracks.graph_solution.nodes[1]["area"] + original_pos = tracks.graph_solution.nodes[1]["pos"] # Add a couple pixels to the first node new_seg = np.asarray(tracks.segmentation).copy() @@ -31,22 +31,22 @@ def test_update_node_segs(get_tracks, ndim): action = UpdateNodeSeg(tracks, node, mask=mask, added=True) - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) - assert tracks.graph.nodes[1]["area"] == original_area + 1 - assert not np.allclose(tracks.graph.nodes[1]["pos"], original_pos) + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) + assert tracks.graph_solution.nodes[1]["area"] == original_area + 1 + assert not np.allclose(tracks.graph_solution.nodes[1]["pos"], original_pos) assert_array_almost_equal(tracks.segmentation, new_seg) inverse = action.inverse() - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) assert_series_equal( reference_graph.nodes[1]["pos"], - tracks.graph.nodes[1]["pos"], + tracks.graph_solution.nodes[1]["pos"], ) assert_array_almost_equal(tracks.segmentation, original_seg) inverse.inverse() - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) - assert tracks.graph.nodes[1]["area"] == original_area + 1 - assert not np.allclose(tracks.graph.nodes[1]["pos"], original_pos) + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) + assert tracks.graph_solution.nodes[1]["area"] == original_area + 1 + assert not np.allclose(tracks.graph_solution.nodes[1]["pos"], original_pos) assert_array_almost_equal(tracks.segmentation, new_seg) diff --git a/tests/annotators/test_annotator_registry.py b/tests/annotators/test_annotator_registry.py index 6a7e8e3d..f8d541c7 100644 --- a/tests/annotators/test_annotator_registry.py +++ b/tests/annotators/test_annotator_registry.py @@ -57,13 +57,13 @@ def test_enable_disable_features(graph_2d_with_segmentation): **track_attrs, ) - nodes = list(tracks.graph.node_ids()) - edges = list(tracks.graph.edge_ids()) + nodes = list(tracks.graph_solution.node_ids()) + edges = list(tracks.graph_solution.edge_ids()) # Core features (time, pos) should be in tracks.features and computed assert "pos" in tracks.features assert "t" in tracks.features - assert tracks.graph.nodes[nodes[0]]["pos"] is not None + assert tracks.graph_solution.nodes[nodes[0]]["pos"] is not None # area and other features should NOT be in tracks.features initially assert "area" not in tracks.features @@ -78,9 +78,9 @@ def test_enable_disable_features(graph_2d_with_segmentation): assert "circularity" in tracks.features # Verify values are actually computed on the graph - assert tracks.graph.nodes[nodes[0]]["circularity"] is not None + assert tracks.graph_solution.nodes[nodes[0]]["circularity"] is not None if edges: - assert None not in tracks.graph.edge_attrs()["iou"].to_list() + assert None not in tracks.graph_solution.edge_attrs()["iou"].to_list() # Disable one feature tracks.disable_features(["area"]) @@ -92,7 +92,7 @@ def test_enable_disable_features(graph_2d_with_segmentation): assert "circularity" in tracks.features # Values no longer exist in the graph for tracksdata - # assert tracks.graph.nodes[1]["area"] is not None + # assert tracks.graph_solution.nodes[1]["area"] is not None # Disable the remaining enabled features tracks.disable_features(["pos", "iou", "circularity"]) diff --git a/tests/annotators/test_edge_annotator.py b/tests/annotators/test_edge_annotator.py index cbef21c7..b4e98acb 100644 --- a/tests/annotators/test_edge_annotator.py +++ b/tests/annotators/test_edge_annotator.py @@ -42,7 +42,7 @@ def test_compute_all(self, get_graph, ndim): # Compute values ann.compute() for key in all_features: - assert key in tracks.graph.edge_attr_keys() + assert key in tracks.graph_solution.edge_attr_keys() def test_update_all(self, get_graph, ndim) -> None: graph = get_graph(ndim, with_seg=True) @@ -78,7 +78,10 @@ def test_update_all(self, get_graph, ndim) -> None: with pytest.warns(match="Cannot find label 1 in frame .*"): UpdateNodeSeg(tracks, node_id, mask, added=False) - assert tracks.graph.edges[tracks.graph.edge_id(*edge_id)]["iou"] == 0 + assert ( + tracks.graph_solution.edges[tracks.graph_solution.edge_id(*edge_id)]["iou"] + == 0 + ) def test_add_remove_feature(self, get_graph, ndim): graph = get_graph(ndim, with_seg=True) @@ -140,7 +143,7 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim): node_id = 3 edge = (1, 3) - edge_id = tracks.graph.edge_id(*edge) + edge_id = tracks.graph_solution.edge_id(*edge) initial_iou = tracks.get_edge_attr(edge, "iou") # If we recomputed IoU now, it would be different @@ -155,6 +158,6 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim): UpdateTrackIDs(tracks, node_id, new_track_id) # IoU should remain unchanged (no recomputation happened despite seg change) - assert tracks.graph.edges[edge_id]["iou"] == initial_iou + assert tracks.graph_solution.edges[edge_id]["iou"] == initial_iou # But track_id should be updated assert tracks.get_track_id(node_id) == new_track_id diff --git a/tests/annotators/test_regionprops_annotator.py b/tests/annotators/test_regionprops_annotator.py index 3193dd06..4cb34f53 100644 --- a/tests/annotators/test_regionprops_annotator.py +++ b/tests/annotators/test_regionprops_annotator.py @@ -39,9 +39,9 @@ def test_compute_all(self, get_graph, ndim): tracks.enable_features(list(rp_ann.all_features.keys())) for key in rp_ann.all_features: - assert key in tracks.graph.node_attr_keys() - for node_id in tracks.graph.node_ids(): - value = tracks.graph.nodes[node_id][key] + assert key in tracks.graph_solution.node_attr_keys() + for node_id in tracks.graph_solution.node_ids(): + value = tracks.graph_solution.nodes[node_id][key] assert value is not None def test_update_all(self, get_graph, ndim): @@ -69,7 +69,7 @@ def test_update_all(self, get_graph, ndim): UpdateNodeSeg(tracks, node_id, removal, added=False) assert tracks.get_node_attr(node_id, "area") == expected_area for key in rp_ann.features: - assert key in tracks.graph.node_attr_keys() + assert key in tracks.graph_solution.node_attr_keys() # segmentation is fully erased and you try to update node_id = 1 @@ -80,8 +80,8 @@ def test_update_all(self, get_graph, ndim): UpdateNodeSeg(tracks, node_id, mask, added=False) # all regionprops features should be the defaults, because seg doesn't exist for key in rp_ann.features: - actual = tracks.graph.nodes[node_id][key] - expected = tracks.graph._node_attr_schemas()[key].default_value + actual = tracks.graph_solution.nodes[node_id][key] + expected = tracks.graph_solution._node_attr_schemas()[key].default_value # Convert to numpy arrays for comparison (handles both scalar and array types) actual_np = np.asarray(actual) expected_np = np.asarray(expected) @@ -105,7 +105,7 @@ def test_add_remove_feature(self, get_graph, ndim): tracks.disable_features([to_remove_key]) rp_ann.compute() - assert to_remove_key not in tracks.graph.node_attr_keys() + assert to_remove_key not in tracks.graph_solution.node_attr_keys() # add it back in tracks.enable_features([to_remove_key]) @@ -155,7 +155,7 @@ def test_centroid_world_coords_with_scale(self, get_graph, ndim): # Force recomputation so regionprops runs with the given scale as spacing tracks.enable_features(["pos"]) - pos = np.array(tracks.graph.nodes[6]["pos"]) + pos = np.array(tracks.graph_solution.nodes[6]["pos"]) expected = pixel_centroid * np.array(scale[1:]) bug_value = np.array([1.5] * len(pixel_centroid)) * np.array( @@ -192,7 +192,9 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim): # RegionpropsAnnotator, it would recompute area back to initial_area and # the assertion below would fail. fake_area = initial_area + 999 - tracks.graph.update_node_attrs(attrs={"area": [fake_area]}, node_ids=[node_id]) + tracks.graph_solution.update_node_attrs( + attrs={"area": [fake_area]}, node_ids=[node_id] + ) original_track_id = tracks.get_track_id(node_id) new_track_id = original_track_id + 100 diff --git a/tests/annotators/test_track_annotator.py b/tests/annotators/test_track_annotator.py index 2f7fa0d2..803ac885 100644 --- a/tests/annotators/test_track_annotator.py +++ b/tests/annotators/test_track_annotator.py @@ -41,9 +41,9 @@ def test_compute_all(self, get_tracks, ndim, with_seg) -> None: # Compute values ann.compute() - for node in tracks.graph.node_ids(): + for node in tracks.graph_solution.node_ids(): for key in all_features: - assert tracks.graph.nodes[node][key] is not None + assert tracks.graph_solution.nodes[node][key] is not None lineages = [ [1, 2, 3, 4, 5], @@ -79,7 +79,9 @@ def test_add_remove_feature(self, get_tracks, ndim, with_seg): node_id = 6 edge_id = (4, 6) attrs = {"iou": 0, "solution": True} if with_seg else {"solution": True} - tracks.graph.add_edge(source_id=edge_id[0], target_id=edge_id[1], attrs=attrs) + tracks.graph_solution.add_edge( + source_id=edge_id[0], target_id=edge_id[1], attrs=attrs + ) to_remove_key = ann.lineage_key orig_lin = tracks.get_node_attr(node_id, ann.lineage_key) orig_tra = tracks.get_node_attr(node_id, ann.tracklet_key) @@ -155,7 +157,7 @@ def test_lineage_id_updated_on_add_and_delete_edge( source_node = 3 target_node = 4 - edge = next(e for e in tracks.graph.edge_list() if set(e) == {3, 4}) + edge = next(e for e in tracks.graph_solution.edge_list() if set(e) == {3, 4}) expected_lineage_id = ann.max_lineage_id + 1 UserDeleteEdge(tracks, edge=edge) diff --git a/tests/candidate_graph/test_compute_graph.py b/tests/candidate_graph/test_compute_graph.py index 6abfff97..18c0c996 100644 --- a/tests/candidate_graph/test_compute_graph.py +++ b/tests/candidate_graph/test_compute_graph.py @@ -18,13 +18,13 @@ def test_graph_from_segmentation_2d(get_tracks): ) # Same node IDs as the segmentation labels - assert set(cand_graph.node_ids()) == set(tracks.graph.node_ids()) + assert set(cand_graph.node_ids()) == set(tracks.graph_solution.node_ids()) # t, pos, area must match the source graph for every node for node in cand_graph.node_ids(): for key in ["t", "pos", "area"]: assert np.array(cand_graph.nodes[node][key]) == pytest.approx( - np.array(tracks.graph.nodes[node][key]), abs=0.01 + np.array(tracks.graph_solution.nodes[node][key]), abs=0.01 ) # mask and bbox must be present on every node @@ -39,12 +39,14 @@ def test_graph_from_segmentation_2d(get_tracks): # because t=3 has no nodes (add_cand_edges only links frame → frame+1) assert sorted(cand_graph.edge_list()) == [[1, 2], [1, 3], [2, 4], [3, 4]] - # For edges shared with tracks.graph, iou must agree + # For edges shared with tracks.graph_solution, iou must agree cand_edges = {tuple(e) for e in cand_graph.edge_list()} - ref_edges = {tuple(e) for e in tracks.graph.edge_list()} + ref_edges = {tuple(e) for e in tracks.graph_solution.edge_list()} for src, tgt in cand_edges & ref_edges: cand_iou = cand_graph.edges[cand_graph.edge_id(src, tgt)]["iou"] - ref_iou = tracks.graph.edges[tracks.graph.edge_id(src, tgt)]["iou"] + ref_iou = tracks.graph_solution.edges[tracks.graph_solution.edge_id(src, tgt)][ + "iou" + ] assert cand_iou == pytest.approx(ref_iou, abs=0.01) # lower edge distance: only (1, 3) is within 15 pixels (~11.2), (1, 2) is ~42 away @@ -52,7 +54,7 @@ def test_graph_from_segmentation_2d(get_tracks): segmentation=segmentation_2d, max_edge_distance=15, ) - assert set(cand_graph.node_ids()) == set(tracks.graph.node_ids()) + assert set(cand_graph.node_ids()) == set(tracks.graph_solution.node_ids()) assert sorted(cand_graph.edge_list()) == [[1, 3]] @@ -65,12 +67,12 @@ def test_graph_from_segmentation_3d(get_tracks): max_edge_distance=100, ) - assert set(cand_graph.node_ids()) == set(tracks.graph.node_ids()) + assert set(cand_graph.node_ids()) == set(tracks.graph_solution.node_ids()) for node in cand_graph.node_ids(): for key in ["t", "pos", "area"]: assert np.array(cand_graph.nodes[node][key]) == pytest.approx( - np.array(tracks.graph.nodes[node][key]), abs=0.01 + np.array(tracks.graph_solution.nodes[node][key]), abs=0.01 ) # mask and bbox must be present on every node @@ -115,8 +117,8 @@ def test_graph_from_segmentation_t_start(get_tracks): # t values should match the original tracks graph for shared nodes for node in cand_graph.node_ids(): - if node in set(tracks.graph.node_ids()): - assert cand_graph.nodes[node]["t"] == tracks.graph.nodes[node]["t"] + if node in set(tracks.graph_solution.node_ids()): + assert cand_graph.nodes[node]["t"] == tracks.graph_solution.nodes[node]["t"] # Edges should still be formed between adjacent (shifted) frames assert cand_graph.num_edges() > 0 diff --git a/tests/candidate_graph/test_iou.py b/tests/candidate_graph/test_iou.py index 0fc3e28a..21c7a464 100644 --- a/tests/candidate_graph/test_iou.py +++ b/tests/candidate_graph/test_iou.py @@ -43,10 +43,12 @@ def test_add_iou_2d(get_tracks): add_cand_edges(cand_graph, max_edge_distance=100, node_frame_dict=node_frame_dict) add_iou(cand_graph, segmentation_2d, node_frame_dict=node_frame_dict) - # For edges shared with tracks.graph, iou must agree + # For edges shared with tracks.graph_solution, iou must agree cand_edges = {tuple(e) for e in cand_graph.edge_list()} - ref_edges = {tuple(e) for e in tracks.graph.edge_list()} + ref_edges = {tuple(e) for e in tracks.graph_solution.edge_list()} for src, tgt in cand_edges & ref_edges: cand_iou = cand_graph.edges[cand_graph.edge_id(src, tgt)]["iou"] - ref_iou = tracks.graph.edges[tracks.graph.edge_id(src, tgt)]["iou"] + ref_iou = tracks.graph_solution.edges[tracks.graph_solution.edge_id(src, tgt)][ + "iou" + ] assert cand_iou == pytest.approx(ref_iou, abs=0.01) diff --git a/tests/candidate_graph/test_relabel_segmentation.py b/tests/candidate_graph/test_relabel_segmentation.py index 9ea20e96..0e32f28d 100644 --- a/tests/candidate_graph/test_relabel_segmentation.py +++ b/tests/candidate_graph/test_relabel_segmentation.py @@ -30,7 +30,7 @@ def test_relabel_segmentation(get_tracks): segmentation = np.asarray(tracks.segmentation) # Use only nodes 1 and 2 (single tracklet: node 1 at t=0, node 2 at t=1) - subgraph = tracks.graph.filter(node_ids=[1, 2]).subgraph() + subgraph = tracks.graph_solution.filter(node_ids=[1, 2]).subgraph() relabeled = relabel_segmentation_with_track_id(subgraph, segmentation) # Nodes 1 and 2 form one tracklet → both get label 1 diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 063fba80..4bc7dbc8 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -43,7 +43,7 @@ def test_from_tracks_cls(graph_2d_with_segmentation): scale=(2, 2, 2), ) solution_tracks = SolutionTracks.from_tracks(tracks) - assert solution_tracks.graph == tracks.graph + assert solution_tracks.graph_solution == tracks.graph_solution assert solution_tracks.segmentation == tracks.segmentation assert solution_tracks.features.time_key == tracks.features.time_key assert solution_tracks.features.position_key == tracks.features.position_key @@ -63,7 +63,7 @@ def test_from_tracks_cls_recompute(graph_2d_with_segmentation): ) # delete track id (default value -1) on one node triggers reassignment of # track_ids even when recompute is False. - tracks.graph.nodes[1][tracks.features.tracklet_key] = -1 + tracks.graph_solution.nodes[1][tracks.features.tracklet_key] = -1 solution_tracks = SolutionTracks.from_tracks(tracks) # should have reassigned new track_id to node 6 assert solution_tracks.get_node_attr(6, solution_tracks.features.tracklet_key) == 4 @@ -125,7 +125,7 @@ def test_export_to_csv_with_display_names( with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.num_nodes() + 1 # add header + assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header # With display names: ID, Parent ID, Time, y, x, Tracklet ID, # Lineage ID, Area @@ -149,7 +149,7 @@ def test_export_to_csv_with_display_names( with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.num_nodes() + 1 # add header + assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header # With display names: ID, Parent ID, Time, z, y, x, # Tracklet ID, Lineage ID, Volume diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 74dd717b..bd0eaa65 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -96,7 +96,7 @@ def test_nodes_edges(graph_2d_with_segmentation): assert set(tracks.nodes()) == {1, 2, 3, 4, 5, 6} assert len(tracks.edges()) == 4 # rx graph starts from 0, sql from 1, # so direct comparison of edges depends on backend - assert set(map(tuple, tracks.graph.edge_list())) == { + assert set(map(tuple, tracks.graph_solution.edge_list())) == { (1, 2), (1, 3), (3, 4), @@ -225,7 +225,7 @@ def test_get_new_node_ids(graph_2d_with_position): assert 1 not in ids # existing nodes skipped assert 2 not in ids for node_id in ids: - assert not tracks.graph.has_node(node_id) + assert not tracks.graph_solution.has_node(node_id) # second call must not overlap with first ids2 = tracks._get_new_node_ids(2) @@ -241,7 +241,9 @@ def test_undo_redo(graph_2d_with_segmentation): assert tracks.redo() is False # Perform an action - add a custom attribute - tracks.graph.add_node_attr_key("custom_label", default_value=None, dtype=pl.Object) + tracks.graph_solution.add_node_attr_key( + "custom_label", default_value=None, dtype=pl.Object + ) action1 = UpdateNodeAttrs(tracks, node=1, attrs={"custom_label": "test_value"}) tracks.action_history.add_new_action(action1) @@ -262,7 +264,9 @@ def test_undo_redo(graph_2d_with_segmentation): assert tracks.redo() is False # Perform another action - tracks.graph.add_node_attr_key("another_label", default_value=None, dtype=pl.Object) + tracks.graph_solution.add_node_attr_key( + "another_label", default_value=None, dtype=pl.Object + ) action2 = UpdateNodeAttrs(tracks, node=2, attrs={"another_label": "second_value"}) tracks.action_history.add_new_action(action2) assert tracks.get_node_attr(2, "another_label") == "second_value" @@ -313,14 +317,14 @@ def test_add_feature_mask_creates_both_columns(): graph = create_empty_graphview_graph(ndim=3) tracks = Tracks(graph, ndim=3, **track_attrs) - assert "nuc_mask" not in tracks.graph.node_attr_keys() - assert "nuc_bbox" not in tracks.graph.node_attr_keys() + assert "nuc_mask" not in tracks.graph_solution.node_attr_keys() + assert "nuc_bbox" not in tracks.graph_solution.node_attr_keys() tracks.add_feature("nuc_mask", SegMask(ndim=3, bbox_key="nuc_bbox")) tracks.add_feature("nuc_bbox", SegBbox(ndim=3)) - assert "nuc_mask" in tracks.graph.node_attr_keys() - assert "nuc_bbox" in tracks.graph.node_attr_keys() + assert "nuc_mask" in tracks.graph_solution.node_attr_keys() + assert "nuc_bbox" in tracks.graph_solution.node_attr_keys() assert "nuc_mask" in tracks.features assert "nuc_bbox" in tracks.features @@ -333,15 +337,15 @@ def test_delete_feature_mask_removes_both_columns( assert "mask" in tracks.features assert "bbox" in tracks.features - assert "mask" in tracks.graph.node_attr_keys() - assert "bbox" in tracks.graph.node_attr_keys() + assert "mask" in tracks.graph_solution.node_attr_keys() + assert "bbox" in tracks.graph_solution.node_attr_keys() tracks.delete_feature("mask") assert "mask" not in tracks.features assert "bbox" not in tracks.features - assert "mask" not in tracks.graph.node_attr_keys() - assert "bbox" not in tracks.graph.node_attr_keys() + assert "mask" not in tracks.graph_solution.node_attr_keys() + assert "bbox" not in tracks.graph_solution.node_attr_keys() def test_update_mask_syncs_bbox(graph_2d_with_segmentation): @@ -353,8 +357,8 @@ def test_update_mask_syncs_bbox(graph_2d_with_segmentation): new_mask = make_2d_disk_mask(center=(30, 30), radius=10) tracks.update_mask(1, new_mask) - stored_mask = tracks.graph.nodes[1][td.DEFAULT_ATTR_KEYS.MASK] - stored_bbox = tracks.graph.nodes[1][td.DEFAULT_ATTR_KEYS.BBOX] + stored_mask = tracks.graph_solution.nodes[1][td.DEFAULT_ATTR_KEYS.MASK] + stored_bbox = tracks.graph_solution.nodes[1][td.DEFAULT_ATTR_KEYS.BBOX] assert stored_mask is new_mask assert np.array_equal(stored_bbox, new_mask.bbox) diff --git a/tests/import_export/test_csv_export.py b/tests/import_export/test_csv_export.py index e344d50d..26bfb938 100644 --- a/tests/import_export/test_csv_export.py +++ b/tests/import_export/test_csv_export.py @@ -22,7 +22,7 @@ def test_export_solution_to_csv(get_tracks, tmp_path, ndim, expected_header): with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.num_nodes() + 1 # add header + assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header assert lines[0].strip().split(",") == expected_header # Check first data line (node 1: t=0, pos=[50, 50] or [50, 50, 50], track_id=1) @@ -53,7 +53,7 @@ def test_export_solution_to_csv_with_seg_zarr( with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.num_nodes() + 1 # add header + assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header assert lines[0].strip().split(",") == expected_header # check the segmentation zarr @@ -65,7 +65,9 @@ def test_export_solution_to_csv_with_seg_zarr( seg_arr = seg_zarr[:] unique_vals = set(seg_arr.flatten()) - {0} label_key = tracks.features.tracklet_key - track_ids = set(tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list()) + track_ids = set( + tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list() + ) assert unique_vals == track_ids @@ -90,7 +92,9 @@ def test_export_solution_to_csv_with_seg_tiff(get_tracks, tmp_path, ndim): # values should be tracklet_ids (not node_ids) — default seg_relabel="tracklet" unique_vals = set(seg_arr.flatten()) - {0} label_key = tracks.features.tracklet_key - track_ids = set(tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list()) + track_ids = set( + tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list() + ) assert unique_vals == track_ids @@ -115,7 +119,7 @@ def test_export_solution_to_csv_with_seg_original_labels(get_tracks, tmp_path, n # values should be node_ids (original labels), not track_ids seg_arr = seg_zarr[:] unique_vals = set(seg_arr.flatten()) - {0} - node_ids = set(tracks.graph.node_ids()) + node_ids = set(tracks.graph_solution.node_ids()) assert unique_vals == node_ids @@ -127,7 +131,7 @@ def test_export_with_color_dict(get_tracks, tmp_path): temp_file = tmp_path / "test_export_colors.csv" # Build a color dict: node_id → [R, G, B] floats in [0, 1] - node_ids = list(tracks.graph.node_ids()) + node_ids = list(tracks.graph_solution.node_ids()) color_dict = { node_id: np.array([0.1 * (i % 10), 0.5, 0.9]) for i, node_id in enumerate(node_ids) diff --git a/tests/import_export/test_csv_import.py b/tests/import_export/test_csv_import.py index 613dc6b5..3273a9a8 100644 --- a/tests/import_export/test_csv_import.py +++ b/tests/import_export/test_csv_import.py @@ -43,8 +43,8 @@ def test_import_2d(self, simple_df_2d): tracks = tracks_from_df(simple_df_2d) assert isinstance(tracks, SolutionTracks) - assert tracks.graph.num_nodes() == 4 - assert tracks.graph.num_edges() == 3 + assert tracks.graph_solution.num_nodes() == 4 + assert tracks.graph_solution.num_edges() == 3 assert tracks.ndim == 3 def test_import_3d(self, df_3d): @@ -52,7 +52,7 @@ def test_import_3d(self, df_3d): tracks = tracks_from_df(df_3d) assert tracks.ndim == 4 - assert tracks.graph.num_nodes() == 3 + assert tracks.graph_solution.num_nodes() == 3 # Check z coordinate pos = tracks.get_position(1) assert len(pos) == 3 # z, y, x @@ -80,12 +80,12 @@ def test_edges_created(self, simple_df_2d): tracks = tracks_from_df(simple_df_2d) # Check specific edges exist - assert tracks.graph.has_edge(1, 2) - assert tracks.graph.has_edge(1, 3) - assert tracks.graph.has_edge(2, 4) + assert tracks.graph_solution.has_edge(1, 2) + assert tracks.graph_solution.has_edge(1, 3) + assert tracks.graph_solution.has_edge(2, 4) # Check node 1 has two children (division) - assert len(list(tracks.graph.successors(1))) == 2 + assert len(list(tracks.graph_solution.successors(1))) == 2 class TestSegmentationHandling: @@ -162,8 +162,8 @@ def test_single_node(self): tracks = tracks_from_df(df) - assert tracks.graph.num_nodes() == 1 - assert tracks.graph.num_edges() == 0 + assert tracks.graph_solution.num_nodes() == 1 + assert tracks.graph_solution.num_edges() == 0 def test_multiple_roots(self): """Test multiple independent lineages.""" @@ -179,11 +179,15 @@ def test_multiple_roots(self): tracks = tracks_from_df(df) - assert tracks.graph.num_nodes() == 4 - assert tracks.graph.num_edges() == 2 + assert tracks.graph_solution.num_nodes() == 4 + assert tracks.graph_solution.num_edges() == 2 # Should have two root nodes - roots = [n for n in tracks.graph.node_ids() if tracks.graph.in_degree(n) == 0] + roots = [ + n + for n in tracks.graph_solution.node_ids() + if tracks.graph_solution.in_degree(n) == 0 + ] assert len(roots) == 2 def test_division_nan_parent(self): @@ -207,10 +211,10 @@ def test_division_nan_parent(self): tracks = tracks_from_df(df) - assert tracks.graph.num_nodes() == 3 - assert tracks.graph.num_edges() == 2 + assert tracks.graph_solution.num_nodes() == 3 + assert tracks.graph_solution.num_edges() == 2 - children = list(tracks.graph.successors(1)) + children = list(tracks.graph_solution.successors(1)) assert set(children) == {2, 3} def test_division_event(self): @@ -227,11 +231,11 @@ def test_division_event(self): tracks = tracks_from_df(df) - assert tracks.graph.num_nodes() == 3 - assert tracks.graph.num_edges() == 2 + assert tracks.graph_solution.num_nodes() == 3 + assert tracks.graph_solution.num_edges() == 2 # Node 1 should have two children - children = list(tracks.graph.successors(1)) + children = list(tracks.graph_solution.successors(1)) assert len(children) == 2 assert set(children) == {2, 3} @@ -249,19 +253,25 @@ def test_long_track(self): tracks = tracks_from_df(df) - assert tracks.graph.num_nodes() == 10 - assert tracks.graph.num_edges() == 9 + assert tracks.graph_solution.num_nodes() == 10 + assert tracks.graph_solution.num_edges() == 9 # Should form a single linear chain - roots = [n for n in tracks.graph.node_ids() if tracks.graph.in_degree(n) == 0] + roots = [ + n + for n in tracks.graph_solution.node_ids() + if tracks.graph_solution.in_degree(n) == 0 + ] assert len(roots) == 1 # Each non-leaf node should have exactly one child non_leaves = [ - n for n in tracks.graph.node_ids() if tracks.graph.out_degree(n) > 0 + n + for n in tracks.graph_solution.node_ids() + if tracks.graph_solution.out_degree(n) > 0 ] for node in non_leaves: - assert tracks.graph.out_degree(node) == 1 + assert tracks.graph_solution.out_degree(node) == 1 def test_orphaned_node_raises_error(self): """Test that node with invalid parent_id raises error.""" @@ -361,8 +371,8 @@ def test_seg_id_same_as_id(self, simple_df_2d): tracks = tracks_from_df(simple_df_2d, node_name_map=name_map) # Both id and seg_id should be present with same values - assert tracks.graph.num_nodes() == 4 - for node_id in tracks.graph.node_ids(): + assert tracks.graph_solution.num_nodes() == 4 + for node_id in tracks.graph_solution.node_ids(): assert tracks.get_node_attr(node_id, "seg_id") == node_id def test_duplicate_mapping_with_segmentation(self, simple_df_2d): @@ -654,7 +664,7 @@ def test_empty_list_in_name_map_removed(self): tracks = tracks_from_df(df, node_name_map=name_map) assert tracks is not None # The empty mapping should not result in a feature being added - assert "ellipse_axis_radii" not in tracks.graph.node_attr_keys() + assert "ellipse_axis_radii" not in tracks.graph_solution.node_attr_keys() def test_import_without_position_with_segmentation(self): """Test that position can be omitted when segmentation is provided. @@ -687,8 +697,8 @@ def test_import_without_position_with_segmentation(self): assert tracks is not None # Position should be computed from segmentation centroids - assert "pos" in tracks.graph.node_attr_keys() - pos_1 = tracks.graph.nodes[1]["pos"] + assert "pos" in tracks.graph_solution.node_attr_keys() + pos_1 = tracks.graph_solution.nodes[1]["pos"] # Centroid of 3x3 region at [2:5, 2:5] is approximately [3, 3] np.testing.assert_array_almost_equal(pos_1, [3.0, 3.0], decimal=0) diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py index 9429939b..1360d6cb 100644 --- a/tests/import_export/test_export_to_geff.py +++ b/tests/import_export/test_export_to_geff.py @@ -108,12 +108,14 @@ def test_export_to_geff( else: label_key = tracks.features.tracklet_key label_vals = set( - tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list() + tracks.graph_solution.node_attrs(attr_keys=[label_key])[ + label_key + ].to_list() ) assert unique_vals == label_vals else: # values should be original node_ids - node_ids_set = set(tracks.graph.node_ids()) + node_ids_set = set(tracks.graph_solution.node_ids()) assert unique_vals == node_ids_set else: assert not seg_path.exists() @@ -198,7 +200,7 @@ def test_export_to_geff( else: label_key = tracks.features.tracklet_key kept_vals = set( - tracks.graph.filter(node_ids=[1, 3, 4, 6]) + tracks.graph_solution.filter(node_ids=[1, 3, 4, 6]) .node_attrs(attr_keys=[label_key])[label_key] .to_list() ) @@ -228,7 +230,9 @@ def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path): # values should be tracklet_ids (default seg_relabel="tracklet") unique_vals = set(seg_arr.flatten()) - {0} label_key = tracks.features.tracklet_key - track_ids = set(tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list()) + track_ids = set( + tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list() + ) assert unique_vals == track_ids # Check metadata references the tiff path with ../../ prefix (sibling of geff dir) diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index fbc268b5..593106eb 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -262,7 +262,7 @@ def test_duplicate_values_in_name_map(valid_geff): tracks = import_from_geff(store, node_name_map) # Both time and seg_id should be present with same values - for node_id in tracks.graph.node_ids(): + for node_id in tracks.graph_solution.node_ids(): assert tracks.get_node_attr(node_id, "seg_id") == tracks.get_node_attr( node_id, "t" ) @@ -316,11 +316,11 @@ def test_tracks_with_segmentation(valid_geff, invalid_geff, valid_segmentation, assert hasattr(tracks, "segmentation") assert tracks.segmentation.shape == valid_segmentation.shape # Get last node by ID (don't rely on iteration order) - last_node = max(tracks.graph.node_ids()) + last_node = max(tracks.graph_solution.node_ids()) # With composite pos, position is stored as an array - pos = tracks.graph.nodes[last_node]["pos"] + pos = tracks.graph_solution.nodes[last_node]["pos"] coords = [ - tracks.graph.nodes[last_node]["t"], + tracks.graph_solution.nodes[last_node]["t"], pos[0], # y pos[1], # x ] @@ -333,10 +333,10 @@ def test_tracks_with_segmentation(valid_geff, invalid_geff, valid_segmentation, ) # test that the seg id has been relabeled # Check that only requested features are present and area is loaded from geff - data = tracks.graph.nodes[last_node] - assert "random_feature" in tracks.graph.node_attr_keys() - assert "random_feature2" not in tracks.graph.node_attr_keys() - assert "area" in tracks.graph.node_attr_keys() + data = tracks.graph_solution.nodes[last_node] + assert "random_feature" in tracks.graph_solution.node_attr_keys() + assert "random_feature2" not in tracks.graph_solution.node_attr_keys() + assert "area" in tracks.graph_solution.node_attr_keys() assert data["area"] == 21 # loaded directly from geff, not recomputed # Test that import fails with ValueError when invalid seg_ids are provided. @@ -418,8 +418,8 @@ def test_features_loaded_from_name_map(valid_geff, valid_segmentation, tmp_path) assert key in tracks.features # Get last node by ID (don't rely on iteration order) - max_node_id = max(tracks.graph.node_ids()) - data = tracks.graph.nodes[max_node_id] + max_node_id = max(tracks.graph_solution.node_ids()) + data = tracks.graph_solution.nodes[max_node_id] # All requested features should be present and loaded from geff for key in feature_keys: @@ -541,13 +541,13 @@ def test_import_from_geff_roundtrip_auto_axes(tmp_path): # import_from_geff must read segmentation_shape back from zarr attrs and # reconstruct a segmentation (GraphArrayView) — not return segmentation=None. tracks = import_from_geff(tracks_path) - assert tracks.graph.num_nodes() == 1 + assert tracks.graph_solution.num_nodes() == 1 assert tracks.segmentation is not None, ( "segmentation should be reconstructed from masks after round-trip" ) assert tracks.segmentation.shape == (5, 100, 100) - node1 = tracks.graph.nodes[1] + node1 = tracks.graph_solution.nodes[1] assert node1["pos"] is not None np.testing.assert_array_almost_equal(node1["pos"], [50.0, 50.0]) @@ -664,13 +664,13 @@ def test_get_time_works_after_import(valid_geff): name_map = {"time": "t", "pos": ["y", "x"]} tracks = import_from_geff(store, name_map) - for node_id in tracks.graph.node_ids(): + for node_id in tracks.graph_solution.node_ids(): # This must not raise KeyError: 'time' t = tracks.get_time(node_id) assert isinstance(t, int), f"get_time() should return int, got {type(t)}" # get_times() on all nodes must also work - all_node_ids = list(tracks.graph.node_ids()) + all_node_ids = list(tracks.graph_solution.node_ids()) times = tracks.get_times(all_node_ids) assert len(times) == len(all_node_ids) @@ -710,7 +710,7 @@ def test_bool_node_property_schema(geff_with_bool_prop): name_map = {"time": "t", "pos": ["y", "x"], "is_dividing": "is_dividing"} tracks = import_from_geff(geff_with_bool_prop, name_map) - df = tracks.graph.node_attrs(attr_keys=["is_dividing"]) + df = tracks.graph_solution.node_attrs(attr_keys=["is_dividing"]) assert df["is_dividing"].dtype == pl.Boolean, ( f"Expected pl.Boolean schema for 'is_dividing', got {df['is_dividing'].dtype}. " "Likely cause: np.bool_ default_value fell through to int in construct_graph()." @@ -729,10 +729,10 @@ def test_bool_node_property_values(geff_with_bool_prop): name_map = {"time": "t", "pos": ["y", "x"], "is_dividing": "is_dividing"} tracks = import_from_geff(geff_with_bool_prop, name_map) - node_ids = sorted(tracks.graph.node_ids()) + node_ids = sorted(tracks.graph_solution.node_ids()) expected = [True, False, True, False, True] for node_id, exp in zip(node_ids, expected, strict=True): - val = tracks.graph.nodes[node_id]["is_dividing"] + val = tracks.graph_solution.nodes[node_id]["is_dividing"] assert type(val) is bool, ( f"Expected Python bool for 'is_dividing', got {type(val)} for node {node_id}" "Likely cause: np.bool_ value not cast to bool in construct_graph()." @@ -768,14 +768,16 @@ def test_3d_pos_survives_sql_roundtrip(tmp_path): tracks = import_from_geff(store, name_map) # Verify the RX graph has correct Array dtype for 3D pos - df_rx = tracks.graph.node_attrs(attr_keys=["pos"]) + df_rx = tracks.graph_solution.node_attrs(attr_keys=["pos"]) assert df_rx["pos"].dtype == pl.Array(pl.Float64, 3), ( f"RX graph pos should be Array(Float64, 3), got {df_rx['pos'].dtype}" ) # Convert to SQL and reload — this is where the schema mismatch used to surface db_path = str(tmp_path / "test.db") - td.graph.SQLGraph.from_other(tracks.graph, drivername="sqlite", database=db_path) + td.graph.SQLGraph.from_other( + tracks.graph_solution, drivername="sqlite", database=db_path + ) sql_graph2 = td.graph.SQLGraph("sqlite", db_path) df_sql = sql_graph2.node_attrs(attr_keys=["pos"]) @@ -826,7 +828,8 @@ def test_geff_roundtrip_preserves_tracklet_ids(get_tracks, tmp_path): """End-to-end round-trip: export then import should preserve tracklet IDs.""" tracks_in = get_tracks(ndim=3, with_seg=False, is_solution=True) expected = { - int(nid): tracks_in.get_track_id(int(nid)) for nid in tracks_in.graph.node_ids() + int(nid): tracks_in.get_track_id(int(nid)) + for nid in tracks_in.graph_solution.node_ids() } export_dir = tmp_path / "export" @@ -1075,7 +1078,9 @@ def test_invalid_featuredict_in_geff_falls_back_to_autodetect(get_tracks, tmp_pa assert imported.features.time_key is not None assert imported.features.position_key is not None assert imported.features.tracklet_key is not None - assert set(tracks.graph.node_ids()) == set(imported.graph.node_ids()) + assert set(tracks.graph_solution.node_ids()) == set( + imported.graph_solution.node_ids() + ) def test_import_from_geff_respects_external_solution_column(tmp_path): @@ -1112,7 +1117,7 @@ def test_import_from_geff_respects_external_solution_column(tmp_path): tracks = import_from_geff(tracks_path) - node_ids = set(tracks.graph.node_ids()) + node_ids = set(tracks.graph_solution.node_ids()) assert node_ids == {1, 3}, ( "Node 2 has solution=False in the geff file and should be filtered out " f"by import_from_geff. Got node_ids={node_ids}." diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py index 8112cc60..c170021c 100644 --- a/tests/import_export/test_internal_format.py +++ b/tests/import_export/test_internal_format.py @@ -73,13 +73,19 @@ def test_save_load( assert loaded.segmentation is None # graphs_equal doesn't exist for TracksData, so we check properties - assert set(loaded.graph.node_attr_keys()) == set(tracks.graph.node_attr_keys()) - assert set(loaded.graph.edge_attr_keys()) == set(tracks.graph.edge_attr_keys()) - assert loaded.graph.num_nodes() == tracks.graph.num_nodes() - assert loaded.graph.num_edges() == tracks.graph.num_edges() - assert set(loaded.graph.node_ids()) == set(tracks.graph.node_ids()) + assert set(loaded.graph_solution.node_attr_keys()) == set( + tracks.graph_solution.node_attr_keys() + ) + assert set(loaded.graph_solution.edge_attr_keys()) == set( + tracks.graph_solution.edge_attr_keys() + ) + assert loaded.graph_solution.num_nodes() == tracks.graph_solution.num_nodes() + assert loaded.graph_solution.num_edges() == tracks.graph_solution.num_edges() + assert set(loaded.graph_solution.node_ids()) == set(tracks.graph_solution.node_ids()) # edge_ids dont matter, only the actual edges: - assert sorted(loaded.graph.edge_list()) == sorted(tracks.graph.edge_list()) + assert sorted(loaded.graph_solution.edge_list()) == sorted( + tracks.graph_solution.edge_list() + ) @pytest.mark.parametrize("with_seg", [True, False]) diff --git a/tests/user_actions/test_user_actions_force.py b/tests/user_actions/test_user_actions_force.py index 1096b958..538ed5de 100644 --- a/tests/user_actions/test_user_actions_force.py +++ b/tests/user_actions/test_user_actions_force.py @@ -13,9 +13,9 @@ def test_user_force_add_downstream(get_tracks): attrs = {"t": 2, "track_id": 1, "pos": [3, 4]} UserAddNode(tracks, node=7, attributes=attrs, force=True) assert tracks.get_track_id(7) == 1 - assert [1, 2] not in tracks.graph.edge_list() - assert [1, 3] not in tracks.graph.edge_list() - assert [1, 7] in tracks.graph.edge_list() + assert [1, 2] not in tracks.graph_solution.edge_list() + assert [1, 3] not in tracks.graph_solution.edge_list() + assert [1, 7] in tracks.graph_solution.edge_list() def test_user_force_add_upstream(get_tracks): @@ -28,9 +28,9 @@ def test_user_force_add_upstream(get_tracks): attrs = {"t": 0, "track_id": 3, "pos": [3, 4]} UserAddNode(tracks, node=7, attributes=attrs, force=True) assert tracks.get_track_id(7) == 3 - assert [1, 2] in tracks.graph.edge_list() # still there - assert [1, 3] not in tracks.graph.edge_list() # should be removed - assert [7, 3] in tracks.graph.edge_list() # new forced edge + assert [1, 2] in tracks.graph_solution.edge_list() # still there + assert [1, 3] not in tracks.graph_solution.edge_list() # should be removed + assert [7, 3] in tracks.graph_solution.edge_list() # new forced edge def test_auto_assign_new_track_id(get_tracks): @@ -44,5 +44,5 @@ def test_auto_assign_new_track_id(get_tracks): attrs = {"t": 1, "track_id": 2, "pos": [3, 4]} # combination exists already UserAddNode(tracks, node=7, attributes=attrs) - assert tracks.graph.has_node(7) + assert tracks.graph_solution.has_node(7) assert tracks.get_track_id(7) == 6 # new assigned track id diff --git a/tests/user_actions/test_user_add_delete_edge.py b/tests/user_actions/test_user_add_delete_edge.py index 921f3e1d..76738ce8 100644 --- a/tests/user_actions/test_user_add_delete_edge.py +++ b/tests/user_actions/test_user_add_delete_edge.py @@ -14,17 +14,17 @@ def test_user_add_edge(self, get_tracks, ndim, with_seg): edge = (4, 6) old_child = 5 old_track_id = tracks.get_track_id(old_child) - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) action = UserAddEdge(tracks, edge) - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) != old_track_id inverse = action.inverse() - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) == old_track_id inverse.inverse() - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) != old_track_id def test_user_add_merge_edge(self, get_tracks, ndim, with_seg): @@ -32,8 +32,8 @@ def test_user_add_merge_edge(self, get_tracks, ndim, with_seg): # add an edge from 2 to 4 (there is already an edge from 3 to 4) edge = (2, 4) old_edge = (3, 4) - assert not tracks.graph.has_edge(*edge) - assert tracks.graph.has_edge(*old_edge) + assert not tracks.graph_solution.has_edge(*edge) + assert tracks.graph_solution.has_edge(*old_edge) with pytest.raises( InvalidActionError, match="Cannot make a merge edge in a tracking solution" ): @@ -43,16 +43,16 @@ def test_user_add_merge_edge(self, get_tracks, ndim, with_seg): match="Removing edge .* to add new edge without merging.", ): action = UserAddEdge(tracks, edge, force=True) - assert tracks.graph.has_edge(*edge) - assert not tracks.graph.has_edge(*old_edge) + assert tracks.graph_solution.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*old_edge) inverse = action.inverse() - assert not tracks.graph.has_edge(*edge) - assert tracks.graph.has_edge(*old_edge) + assert not tracks.graph_solution.has_edge(*edge) + assert tracks.graph_solution.has_edge(*old_edge) inverse.inverse() - assert tracks.graph.has_edge(*edge) - assert not tracks.graph.has_edge(*old_edge) + assert tracks.graph_solution.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*old_edge) def test_user_delete_edge(self, get_tracks, ndim, with_seg): tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) @@ -62,18 +62,18 @@ def test_user_delete_edge(self, get_tracks, ndim, with_seg): old_track_id = tracks.get_track_id(old_child) new_track_id = tracks.get_track_id(1) - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) action = UserDeleteEdge(tracks, edge) - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) == new_track_id inverse = action.inverse() - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) == old_track_id double_inv = inverse.inverse() - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) == new_track_id # TODO: error if edge doesn't exist? @@ -84,18 +84,18 @@ def test_user_delete_edge(self, get_tracks, ndim, with_seg): old_child = 5 old_track_id = tracks.get_track_id(old_child) - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) action = UserDeleteEdge(tracks, edge) - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) != old_track_id inverse = action.inverse() - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) == old_track_id inverse.inverse() - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) != old_track_id @@ -127,7 +127,7 @@ def test_delete_edge_triple_div(get_tracks): attrs["solution"] = True attrs["iou"] = 0.9 - tracks.graph.add_edge( + tracks.graph_solution.add_edge( source_id=1, target_id=6, attrs=attrs, diff --git a/tests/user_actions/test_user_add_delete_node.py b/tests/user_actions/test_user_add_delete_node.py index 5524f021..c1e5b01b 100644 --- a/tests/user_actions/test_user_add_delete_node.py +++ b/tests/user_actions/test_user_add_delete_node.py @@ -55,7 +55,7 @@ def test_user_add_node(self, get_tracks, ndim, with_seg): del attributes["pos"] else: pixels = None - graph = tracks.graph + graph = tracks.graph_solution assert not graph.has_node(node_id) assert graph.has_edge(4, 5) action = UserAddNode(tracks, node_id, attributes, pixels=pixels) @@ -88,7 +88,7 @@ def test_user_delete_node(self, get_tracks, ndim, with_seg): # delete node in middle of track. Should skip-connect 3 and 5 with span 3 node_id = 4 - graph = tracks.graph + graph = tracks.graph_solution assert graph.has_node(node_id) assert graph.has_edge(3, node_id) assert graph.has_edge(node_id, 5) @@ -128,7 +128,7 @@ def test_user_delete_node_after_division(self, get_tracks, ndim, with_seg): node_id = 2 sib = 3 - graph = tracks.graph + graph = tracks.graph_solution assert graph.has_node(node_id) assert graph.has_edge(parent_node, node_id) parent_track_id = tracks.get_track_id(parent_node) @@ -159,7 +159,7 @@ def test_user_delete_nodes(self, get_tracks, ndim, with_seg): """Test bulk deletion of multiple nodes in a single action.""" # Graph structure: 1 → 2, 1 → 3 → 4 → 5, and 6 (separate) tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - graph = tracks.graph + graph = tracks.graph_solution # Save original state original_nodes = set(graph.node_ids()) diff --git a/tests/user_actions/test_user_swap_predecessors.py b/tests/user_actions/test_user_swap_predecessors.py index d33faeb3..320bbfd7 100644 --- a/tests/user_actions/test_user_swap_predecessors.py +++ b/tests/user_actions/test_user_swap_predecessors.py @@ -13,21 +13,21 @@ def test_one_predecessor(self, get_tracks, ndim, with_seg, order): tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) # Node 5 (t=4) has pred 4, node 6 (t=4) has no pred - assert tracks.graph.has_edge(4, 5) - assert list(tracks.graph.predecessors(6)) == [] + assert tracks.graph_solution.has_edge(4, 5) + assert list(tracks.graph_solution.predecessors(6)) == [] old_track_id_5 = tracks.get_track_id(5) old_track_id_6 = tracks.get_track_id(6) action = UserSwapPredecessors(tracks, order) - assert tracks.graph.has_edge(4, 6) - assert not tracks.graph.has_edge(4, 5) + assert tracks.graph_solution.has_edge(4, 6) + assert not tracks.graph_solution.has_edge(4, 5) assert tracks.get_track_id(6) == old_track_id_5 assert tracks.get_track_id(5) != old_track_id_5 action.inverse() - assert tracks.graph.has_edge(4, 5) - assert not tracks.graph.has_edge(4, 6) + assert tracks.graph_solution.has_edge(4, 5) + assert not tracks.graph_solution.has_edge(4, 6) assert tracks.get_track_id(5) == old_track_id_5 assert tracks.get_track_id(6) == old_track_id_6 @@ -51,14 +51,14 @@ def test_different_predecessors(self, get_tracks, ndim, with_seg): action = UserSwapPredecessors(tracks, (5, 6)) - assert tracks.graph.has_edge(4, 6) - assert tracks.graph.has_edge(2, 5) - assert not tracks.graph.has_edge(4, 5) - assert not tracks.graph.has_edge(2, 6) + assert tracks.graph_solution.has_edge(4, 6) + assert tracks.graph_solution.has_edge(2, 5) + assert not tracks.graph_solution.has_edge(4, 5) + assert not tracks.graph_solution.has_edge(2, 6) action.inverse() - assert tracks.graph.has_edge(4, 5) - assert tracks.graph.has_edge(2, 6) + assert tracks.graph_solution.has_edge(4, 5) + assert tracks.graph_solution.has_edge(2, 6) assert tracks.get_track_id(5) == old_track_id_5 assert tracks.get_track_id(6) == old_track_id_6 @@ -73,14 +73,14 @@ def test_different_times_valid(self, get_tracks, ndim, with_seg): action = UserSwapPredecessors(tracks, (4, 6)) - assert tracks.graph.has_edge(3, 6) - assert tracks.graph.has_edge(2, 4) - assert not tracks.graph.has_edge(3, 4) - assert not tracks.graph.has_edge(2, 6) + assert tracks.graph_solution.has_edge(3, 6) + assert tracks.graph_solution.has_edge(2, 4) + assert not tracks.graph_solution.has_edge(3, 4) + assert not tracks.graph_solution.has_edge(2, 6) action.inverse() - assert tracks.graph.has_edge(3, 4) - assert tracks.graph.has_edge(2, 6) + assert tracks.graph_solution.has_edge(3, 4) + assert tracks.graph_solution.has_edge(2, 6) def test_different_times_invalid_raises(self, get_tracks, ndim, with_seg): """Test error when predecessor would not be before swapped node.""" diff --git a/tests/user_actions/test_user_update_node_attrs.py b/tests/user_actions/test_user_update_node_attrs.py index 5e85a03b..3dc33f11 100644 --- a/tests/user_actions/test_user_update_node_attrs.py +++ b/tests/user_actions/test_user_update_node_attrs.py @@ -11,9 +11,15 @@ def test_user_update_node_attrs(self, get_tracks, ndim, with_seg): """Test basic node attribute update functionality.""" tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object) - tracks.graph.add_node_attr_key("confidence", default_value=0, dtype=pl.Float64) - tracks.graph.add_node_attr_key("validated", default_value=False, dtype=pl.Boolean) + tracks.graph_solution.add_node_attr_key( + "label", default_value=None, dtype=pl.Object + ) + tracks.graph_solution.add_node_attr_key( + "confidence", default_value=0, dtype=pl.Float64 + ) + tracks.graph_solution.add_node_attr_key( + "validated", default_value=False, dtype=pl.Boolean + ) # Add custom attributes to update custom_attrs = {"label": "my_label", "confidence": 0.95, "validated": True} @@ -45,8 +51,12 @@ def test_user_update_existing_attrs(self, get_tracks, ndim, with_seg): """Test updating attributes that already exist.""" tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object) - tracks.graph.add_node_attr_key("score", default_value=None, dtype=pl.Float64) + tracks.graph_solution.add_node_attr_key( + "label", default_value=None, dtype=pl.Object + ) + tracks.graph_solution.add_node_attr_key( + "score", default_value=None, dtype=pl.Float64 + ) # Set initial custom attributes tracks._set_node_attr(1, "label", "old_label") @@ -99,7 +109,9 @@ def test_protected_pos_attr(self, get_tracks, ndim, with_seg): def test_action_history_integration(self, get_tracks, ndim, with_seg): """Test that action integrates properly with action history.""" tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object) + tracks.graph_solution.add_node_attr_key( + "label", default_value=None, dtype=pl.Object + ) # Initially empty assert len(tracks.action_history.undo_stack) == 0 diff --git a/tests/user_actions/test_user_update_nodes_attrs.py b/tests/user_actions/test_user_update_nodes_attrs.py index bb853887..00890c23 100644 --- a/tests/user_actions/test_user_update_nodes_attrs.py +++ b/tests/user_actions/test_user_update_nodes_attrs.py @@ -12,8 +12,12 @@ def test_user_update_nodes_attrs(self, get_tracks, ndim, with_seg): """Test basic bulk node attribute update functionality.""" tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object) - tracks.graph.add_node_attr_key("confidence", default_value=0, dtype=pl.Float64) + tracks.graph_solution.add_node_attr_key( + "label", default_value=None, dtype=pl.Object + ) + tracks.graph_solution.add_node_attr_key( + "confidence", default_value=0, dtype=pl.Float64 + ) attrs = {"label": ["my_label", "my_label"], "confidence": [0.95, 0.95]} UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs=attrs) @@ -25,7 +29,9 @@ def test_user_update_nodes_attrs(self, get_tracks, ndim, with_seg): def test_single_history_entry(self, get_tracks, ndim, with_seg): """Updating multiple nodes creates only one history entry.""" tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object) + tracks.graph_solution.add_node_attr_key( + "label", default_value=None, dtype=pl.Object + ) action = UserUpdateNodesAttrs( tracks, nodes=[1, 2, 3], attrs={"label": ["x", "x", "x"]} @@ -37,7 +43,9 @@ def test_single_history_entry(self, get_tracks, ndim, with_seg): def test_undo_redo(self, get_tracks, ndim, with_seg): """Undo restores all nodes' attrs to defaults; redo re-applies them.""" tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("score", default_value=0, dtype=pl.Float64) + tracks.graph_solution.add_node_attr_key( + "score", default_value=0, dtype=pl.Float64 + ) action = UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": [0.9, 0.9]}) @@ -57,7 +65,9 @@ def test_undo_redo(self, get_tracks, ndim, with_seg): def test_per_node_attrs(self, get_tracks, ndim, with_seg): """Test bulk update with different values per node.""" tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("score", default_value=0, dtype=pl.Float64) + tracks.graph_solution.add_node_attr_key( + "score", default_value=0, dtype=pl.Float64 + ) UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": [0.1, 0.9]}) @@ -69,7 +79,7 @@ def test_array_attr(self, get_tracks, ndim, with_seg): tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) spatial_dims = ndim - 1 - tracks.graph.add_node_attr_key( + tracks.graph_solution.add_node_attr_key( "custom_pos", default_value=None, dtype=pl.Array(pl.Float64, spatial_dims) ) diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py index d2d67a12..c502a921 100644 --- a/tests/user_actions/test_user_update_segmentation.py +++ b/tests/user_actions/test_user_update_segmentation.py @@ -51,14 +51,14 @@ def test_user_update_seg_smaller(self, get_tracks, ndim): updated_pixels=[(pixels_to_remove, node_id)], current_track_id=1, ) - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert self.pixels_equal_mask(remaining_pixels, tracks, node_id) assert tracks.get_position(node_id) == new_position assert tracks.get_node_attr(node_id, "area") == 1 assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(0.0, abs=0.01) inverse = action.inverse() - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert self.pixels_equal_mask(orig_pixels, tracks, node_id) assert tracks.get_position(node_id) == orig_position assert tracks.get_node_attr(node_id, "area") == orig_area @@ -95,20 +95,20 @@ def test_user_update_seg_bigger(self, get_tracks, ndim): action = UserUpdateSegmentation( tracks, new_value=3, updated_pixels=[(pixels_to_add, 0)], current_track_id=1 ) - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert self.pixels_equal_mask(all_pixels, tracks, node_id) assert tracks.get_node_attr(node_id, "area") == orig_area + 1 assert tracks.get_edge_attr(edge, iou_key) != orig_iou inverse = action.inverse() - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert self.pixels_equal_mask(orig_pixels, tracks, node_id) assert tracks.get_position(node_id) == orig_position assert tracks.get_node_attr(node_id, "area") == orig_area assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(orig_iou, abs=0.01) inverse.inverse() - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert self.pixels_equal_mask(all_pixels, tracks, node_id) assert tracks.get_node_attr(node_id, "area") == orig_area + 1 assert tracks.get_edge_attr(edge, iou_key) != orig_iou @@ -157,7 +157,7 @@ def test_invalid_action_with_segmentation(self, get_tracks, ndim): # assert that the segmentation now has the new value assert np.asarray(tracks.segmentation[t, y, x]) == new_value - assert tracks.graph.has_node(new_value) + assert tracks.graph_solution.has_node(new_value) assert len(update_seg_action.actions) == 2 # one for adding a node, # and one for updating existing node 1 @@ -182,17 +182,17 @@ def test_user_erase_seg(self, get_tracks, ndim): updated_pixels=[(pixels_to_remove, node_id)], current_track_id=1, ) - assert not tracks.graph.has_node(node_id) + assert not tracks.graph_solution.has_node(node_id) inverse = action.inverse() - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) self.pixels_equal_mask(orig_pixels, tracks, node_id) assert tracks.get_position(node_id) == orig_position assert tracks.get_node_attr(node_id, "area") == orig_area assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(orig_iou, abs=0.01) inverse.inverse() - assert not tracks.graph.has_node(node_id) + assert not tracks.graph_solution.has_node(node_id) def test_user_erase_seg_history_size(self, get_tracks, ndim): """An erase via UserUpdateSegmentation must add exactly one history @@ -228,20 +228,20 @@ def test_user_two_erases_then_two_undos(self, get_tracks, ndim): UserUpdateSegmentation( tracks, new_value=0, updated_pixels=[(pixels_5, 5)], current_track_id=1 ) - assert not tracks.graph.has_node(5) + assert not tracks.graph_solution.has_node(5) UserUpdateSegmentation( tracks, new_value=0, updated_pixels=[(pixels_6, 6)], current_track_id=1 ) - assert not tracks.graph.has_node(6) + assert not tracks.graph_solution.has_node(6) assert tracks.action_history.undo() is True - assert tracks.graph.has_node(6) - assert not tracks.graph.has_node(5) + assert tracks.graph_solution.has_node(6) + assert not tracks.graph_solution.has_node(5) assert tracks.action_history.undo() is True - assert tracks.graph.has_node(5) - assert tracks.graph.has_node(6) + assert tracks.graph_solution.has_node(5) + assert tracks.graph_solution.has_node(6) def test_user_add_seg(self, get_tracks, ndim): tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) @@ -260,7 +260,7 @@ def test_user_add_seg(self, get_tracks, ndim): position = tracks.get_position(old_node_id) area = tracks.get_node_attr(old_node_id, "area") - assert not tracks.graph.has_node(node_id) + assert not tracks.graph_solution.has_node(node_id) assert np.sum(tracks.segmentation == node_id) == 0 action = UserUpdateSegmentation( @@ -270,16 +270,16 @@ def test_user_add_seg(self, get_tracks, ndim): current_track_id=10, ) assert np.sum(np.asarray(tracks.segmentation) == node_id) == len(pixels_to_add[0]) - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert tracks.get_position(node_id) == position assert tracks.get_node_attr(node_id, "area") == area assert tracks.get_track_id(node_id) == 10 inverse = action.inverse() - assert not tracks.graph.has_node(node_id) + assert not tracks.graph_solution.has_node(node_id) inverse.inverse() - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert tracks.get_position(node_id) == position assert tracks.get_node_attr(node_id, "area") == area assert tracks.get_track_id(node_id) == 10 From d006c2e00532dbbb51428af5a3e7ddf139068174 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 9 Jun 2026 08:47:37 -0400 Subject: [PATCH 02/39] SolutionTracks is gone :o --- src/funtracks/actions/add_delete_node.py | 11 +- src/funtracks/actions/update_track_id.py | 6 +- src/funtracks/annotators/_graph_annotator.py | 2 +- src/funtracks/annotators/_track_annotator.py | 22 +- src/funtracks/data_model/__init__.py | 1 - src/funtracks/data_model/solution_tracks.py | 246 ------------------ src/funtracks/data_model/tracks.py | 197 +++++++++++++- .../import_export/_tracks_builder.py | 34 ++- src/funtracks/import_export/_v1_format.py | 21 +- src/funtracks/import_export/csv/_export.py | 6 +- src/funtracks/import_export/csv/_import.py | 8 +- src/funtracks/import_export/geff/_import.py | 8 +- .../user_actions/_user_swap_predecessors.py | 8 +- src/funtracks/user_actions/user_add_edge.py | 8 +- src/funtracks/user_actions/user_add_node.py | 8 +- .../user_actions/user_delete_edge.py | 8 +- .../user_actions/user_delete_node.py | 8 +- .../user_actions/user_delete_nodes.py | 6 +- .../user_actions/user_update_node_attrs.py | 8 +- .../user_actions/user_update_nodes_attrs.py | 6 +- .../user_actions/user_update_segmentation.py | 8 +- tests/actions/test_action_history.py | 4 +- tests/actions/test_add_delete_edge.py | 6 +- tests/annotators/test_annotator_registry.py | 16 +- tests/annotators/test_edge_annotator.py | 4 +- .../annotators/test_regionprops_annotator.py | 4 +- tests/annotators/test_track_annotator.py | 10 +- tests/conftest.py | 18 +- tests/data_model/test_solution_tracks.py | 26 +- tests/import_export/test_csv_import.py | 4 +- tests/import_export/test_export_to_geff.py | 5 +- tests/import_export/test_import_from_geff.py | 26 +- 32 files changed, 353 insertions(+), 400 deletions(-) delete mode 100644 src/funtracks/data_model/solution_tracks.py diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index 5744bed5..7ddc388c 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -9,8 +9,7 @@ if TYPE_CHECKING: from typing import Any - from funtracks.data_model.solution_tracks import SolutionTracks - from funtracks.data_model.tracks import Node + from funtracks.data_model.tracks import Node, Tracks class AddNode(BasicAction): @@ -22,7 +21,7 @@ class AddNode(BasicAction): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, node: Node, attributes: dict[str, Any], ): @@ -40,7 +39,7 @@ def __init__( ValueError: If neither position nor a mask feature is in attributes. """ super().__init__(tracks) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class self.node = int(node) # Get keys from tracks features @@ -93,11 +92,11 @@ class DeleteNode(BasicAction): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, node: Node, ): super().__init__(tracks) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class self.node = int(node) # Save all node feature values from the features dict diff --git a/src/funtracks/actions/update_track_id.py b/src/funtracks/actions/update_track_id.py index a09db079..edaabedd 100644 --- a/src/funtracks/actions/update_track_id.py +++ b/src/funtracks/actions/update_track_id.py @@ -5,7 +5,7 @@ from ._base import BasicAction if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks from funtracks.data_model.tracks import Node @@ -31,13 +31,13 @@ class UpdateTrackIDs(BasicAction): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, start_node: Node, tracklet_id: int | None = None, lineage_id: int | None = None, ): super().__init__(tracks) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class self.start_node = start_node # Capture old tracklet ID diff --git a/src/funtracks/annotators/_graph_annotator.py b/src/funtracks/annotators/_graph_annotator.py index 43dc879d..0b1ba083 100644 --- a/src/funtracks/annotators/_graph_annotator.py +++ b/src/funtracks/annotators/_graph_annotator.py @@ -35,7 +35,7 @@ def can_annotate(cls, tracks: Tracks) -> bool: """Check if this annotator can annotate the given tracks. Subclasses should override this method to specify their requirements - (e.g., segmentation, SolutionTracks, etc.). + (e.g., segmentation, Tracks, etc.). Args: tracks: The tracks to check compatibility with diff --git a/src/funtracks/annotators/_track_annotator.py b/src/funtracks/annotators/_track_annotator.py index caf3d748..e728d4a6 100644 --- a/src/funtracks/annotators/_track_annotator.py +++ b/src/funtracks/annotators/_track_annotator.py @@ -8,7 +8,6 @@ import tracksdata as td from funtracks.actions import AddNode, DeleteNode, UpdateTrackIDs -from funtracks.data_model import SolutionTracks from funtracks.features import LineageID, TrackletID from ._graph_annotator import GraphAnnotator @@ -17,6 +16,7 @@ from collections.abc import Iterable from funtracks.actions import BasicAction + from funtracks.data_model import Tracks from funtracks.features import Feature @@ -25,7 +25,7 @@ class TrackAnnotator(GraphAnnotator): - """A graph annotator to compute tracklet and lineage IDs for SolutionTracks only. + """A graph annotator to compute tracklet and lineage IDs for Tracks only. Currently, updating the tracklet and lineage IDs is left to Actions. @@ -38,7 +38,7 @@ class TrackAnnotator(GraphAnnotator): max_lineage_id (int): the maximum lineage id used in the tracks Args: - tracks (SolutionTracks): The tracks to be annotated. Must be a solution. + tracks (Tracks): The tracks to be annotated. Must be a solution. tracklet_key (str | None, optional): A key that already holds the tracklet ids on the graph. If provided, must be there for every node and already hold valid tracklet ids. Defaults to None. @@ -48,7 +48,7 @@ class TrackAnnotator(GraphAnnotator): Raises: - ValueError: if the provided Tracks are not SolutionTracks (not a binary lineage + ValueError: if the provided Tracks are not Tracks (not a binary lineage tree) """ @@ -56,15 +56,16 @@ class TrackAnnotator(GraphAnnotator): def can_annotate(cls, tracks) -> bool: """Check if this annotator can annotate the given tracks. - Requires tracks to be a SolutionTracks instance. + Track ids are only meaningful when the tracks declares a tracklet_key (i.e. + it represents a solution). A None tracklet_key means a plain candidate graph. Args: tracks: The tracks to check compatibility with Returns: - True if tracks is a SolutionTracks instance, False otherwise + True if tracks.features.tracklet_key is set, False otherwise """ - return isinstance(tracks, SolutionTracks) + return tracks.features.tracklet_key is not None @classmethod def get_available_features(cls, ndim: int = 3) -> dict[str, Feature]: @@ -87,14 +88,11 @@ def get_available_features(cls, ndim: int = 3) -> dict[str, Feature]: def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, tracklet_key: str | None = DEFAULT_TRACKLET_KEY, lineage_key: str | None = DEFAULT_LINEAGE_KEY, ): - if not isinstance(tracks, SolutionTracks): - raise ValueError("Currently the TrackAnnotator only works on SolutionTracks") - - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class self.tracklet_key = ( tracklet_key if tracklet_key is not None else DEFAULT_TRACKLET_KEY ) diff --git a/src/funtracks/data_model/__init__.py b/src/funtracks/data_model/__init__.py index ef0f4187..8ff09431 100644 --- a/src/funtracks/data_model/__init__.py +++ b/src/funtracks/data_model/__init__.py @@ -1,2 +1 @@ from .tracks import Tracks # noqa -from .solution_tracks import SolutionTracks # noqa diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py deleted file mode 100644 index bb060425..00000000 --- a/src/funtracks/data_model/solution_tracks.py +++ /dev/null @@ -1,246 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import tracksdata as td - -from funtracks.features import FeatureDict - -from .tracks import Tracks - -if TYPE_CHECKING: - from funtracks.annotators import TrackAnnotator - - from .tracks import Node - - -class SolutionTracks(Tracks): - """Difference from Tracks: every node must have a tracklet id""" - - def __init__( - self, - graph: td.graph.GraphView, - time_attr: str | None = None, - pos_attr: str | tuple[str] | list[str] | None = None, - tracklet_attr: str | None = None, - lineage_attr: str | None = None, - scale: list[float] | None = None, - ndim: int | None = None, - features: FeatureDict | None = None, - _segmentation: td.array.GraphArrayView | None = None, - ): - """Initialize a SolutionTracks object. - - SolutionTracks extends Tracks to ensure every node has a tracklet id. A - TrackAnnotator is automatically added to manage track IDs. - - Args: - graph (td.graph.GraphView): Tracksdata graph with nodes as detections - and edges as links. - time_attr (str | None): Graph attribute name for time. Defaults to "time" - if None. - pos_attr (str | tuple[str, ...] | list[str] | None): Graph attribute - name(s) for position. Can be: - - Single string for one attribute containing position array - - List/tuple of strings for multi-axis (one attribute per axis) - Defaults to "pos" if None. - tracklet_attr (str | None): Graph attribute name for tracklet/track IDs. - Defaults to "tracklet_id" if None. - lineage_attr (str | None): Graph attribute name for lineage IDs. - Defaults to "lineage_id" if None. - scale (list[float] | None): Scaling factors for each dimension (including - time). If None, all dimensions scaled by 1.0. - ndim (int | None): Number of dimensions (3 for 2D+time, 4 for 3D+time). - If None, inferred from segmentation or scale. - features (FeatureDict | None): Pre-built FeatureDict with feature - definitions. If provided, time_attr/pos_attr/tracklet_attr are ignored. - Assumes that all features in the dict already exist on the graph (will - be activated but not recomputed). If None, core computed features (pos, - area, tracklet_id) are auto-detected by checking if they exist on the - graph. - _segmentation (GraphArrayView | None): Internal parameter for reusing an - existing GraphArrayView instance. Not intended for public use. - """ - super().__init__( - graph, - time_attr=time_attr, - pos_attr=pos_attr, - tracklet_attr=tracklet_attr, - lineage_attr=lineage_attr, - scale=scale, - ndim=ndim, - features=features, - _segmentation=_segmentation, - ) - - self.track_annotator = self._get_track_annotator() - - def _get_track_annotator(self) -> TrackAnnotator: - """Get the TrackAnnotator instance from the annotator registry. - - Returns: - TrackAnnotator: The track annotator instance - - Raises: - RuntimeError: If no TrackAnnotator is registered - """ - from funtracks.annotators import TrackAnnotator - - for annotator in self.annotators: - if isinstance(annotator, TrackAnnotator): - return annotator - raise RuntimeError( - "No TrackAnnotator registered for this SolutionTracks instance" - ) - - @classmethod - def from_tracks(cls, tracks: Tracks): - force_recompute = False - # Check if all nodes have a value at features.tracklet_key before trusting - # existing track IDs - if ( - tracks.features.tracklet_key is not None - and ( - tracks.graph_solution.node_attrs(attr_keys=tracks.features.tracklet_key)[ - tracks.features.tracklet_key - ] - == -1 - ).any() - # Attributes are no longer None, so 0 now means non-computed - ): - force_recompute = True - - soln_tracks = cls( - tracks.graph_solution, - scale=tracks.scale, - ndim=tracks.ndim, - features=tracks.features, - _segmentation=tracks.segmentation, - ) - if force_recompute: - soln_tracks.enable_features( - [ - soln_tracks.features.tracklet_key, # type: ignore[list-item] - soln_tracks.features.lineage_key, # type: ignore[list-item] - ] - ) - return soln_tracks - - @property - def max_track_id(self) -> int: - return self.track_annotator.max_tracklet_id - - @property - def track_id_to_node(self) -> dict[int, list[int]]: - return self.track_annotator.tracklet_id_to_nodes - - def get_next_track_id(self) -> int: - """Return the next available track_id. - - The max_tracklet_id in TrackAnnotator is updated automatically when - a node is added or track IDs are updated via UpdateTrackIDs. - """ - return self.track_annotator.max_tracklet_id + 1 - - def get_next_lineage_id(self) -> int: - """Return the next available lineage_id. - - The max_lineage_id in TrackAnnotator is updated automatically when - a node is added or lineage IDs are updated via UpdateTrackIDs. - """ - return self.track_annotator.max_lineage_id + 1 - - def get_track_id(self, node) -> int: - if self.features.tracklet_key is None: - raise ValueError("Tracklet key not initialized in features") - track_id = self.get_node_attr(node, self.features.tracklet_key) - return track_id - - def get_track_ids(self, nodes) -> list[int]: - """Batch version of get_track_id — one SQL query fetching all nodes in the graph. - NOTE: always fetches the entire graph internally. Optimised for bulk (all-node) - calls. For small subsets or single nodes use get_track_id() instead.""" - - if self.features.tracklet_key is None: - raise ValueError("Tracklet key not initialized in features") - tracklet_key = self.features.tracklet_key - df = self.graph_solution.node_attrs( - attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, tracklet_key] - ) - id_to_val = dict( - zip( - df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list(), - df[tracklet_key].to_list(), - strict=True, - ) - ) - return [id_to_val[node] for node in nodes] - - def get_lineage_id(self, node) -> int | None: - """Get the lineage ID for a node. - - Args: - node: The node to get lineage ID for - - Returns: - The lineage ID, or None if lineage feature is not enabled - """ - if self.features.lineage_key is None: - return None - return self.get_node_attr(node, self.features.lineage_key) - - def get_track_neighbors( - self, track_id: int, time: int - ) -> tuple[Node | None, Node | None]: - """Get the last node with the given track id before time, and the first node - with the track id after time, if any. Does not assume that a node with - the given track_id and time is already in tracks, but it can be. - - Args: - track_id (int): The track id to search for - time (int): The time point to find the immediate predecessor and successor - for - - Returns: - tuple[Node | None, Node | None]: The last node before time with the given - track id, and the first node after time with the given track id, - or Nones if there are no such nodes. - """ - annotator = self.track_annotator - if ( - track_id not in annotator.tracklet_id_to_nodes - or len(annotator.tracklet_id_to_nodes[track_id]) == 0 - ): - return None, None - candidates = annotator.tracklet_id_to_nodes[track_id] - candidates.sort(key=lambda n: self.get_time(n)) - - pred = None - succ = None - for cand in candidates: - if self.get_time(cand) < time: - pred = cand - elif self.get_time(cand) > time: - succ = cand - break - return ( - int(pred) if pred is not None else None, - int(succ) if succ is not None else None, - ) - - def has_track_id_at_time(self, track_id: int, time: int) -> bool: - """Function to check if a node with given track id exists at given time point. - - Args: - track_id (int): The track id to search for. - time (int): The time point to check. - - Returns: - True if a node with given track id exists at given time point. - """ - - nodes = self.track_id_to_node.get(track_id) - if not nodes: - return False - - return time in [self.get_time(node) for node in nodes] diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 6669c605..027c902a 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -90,9 +90,11 @@ def __init__( - List/tuple of strings for multi-axis (one attribute per axis) Defaults to "pos" if None. tracklet_attr (str | None): Graph attribute name for tracklet/track IDs. - Defaults to "tracklet_id" if None. - lineage_attr (str | None): Graph attribute name for lineage IDs. - Defaults to "lineage_id" if None. + If set (non-None), a TrackAnnotator is registered and track ids are + computed/maintained. If None, this is a plain (candidate) Tracks with + no track ids. + lineage_attr (str | None): Graph attribute name for lineage IDs. Only used + when tracklet_attr is set. scale (list[float] | None): Scaling factors for each dimension (including time). If None, all dimensions scaled by 1.0. ndim (int | None): Number of dimensions (3 for 2D+time, 4 for 3D+time). @@ -218,10 +220,9 @@ def _get_feature_set( time_key = time_attr if time_attr is not None else "time" if pos_attr is None: pos_attr = "pos" - if tracklet_key is None: - tracklet_key = "tracklet_id" - if lineage_key is None: - lineage_key = "lineage_id" + # tracklet_key / lineage_key are left None unless explicitly provided: a + # non-None tracklet_key is the signal that this Tracks wants track ids, which + # is what triggers TrackAnnotator registration (see TrackAnnotator.can_annotate). # Build static features dict - always include time features: dict[str, Feature] = {time_key: Time()} @@ -277,7 +278,7 @@ def _get_annotators(self) -> AnnotatorRegistry: Creates annotators conditionally: - RegionpropsAnnotator: Only if segmentation is provided - EdgeAnnotator: Only if segmentation is provided - - TrackAnnotator: Only if this is a SolutionTracks instance + - TrackAnnotator: Only if this is a Tracks instance Each annotator is configured with appropriate keys from self.features. @@ -308,11 +309,11 @@ def _get_annotators(self) -> AnnotatorRegistry: if EdgeAnnotator.can_annotate(self): annotator_list.append(EdgeAnnotator(self)) - # TrackAnnotator: requires SolutionTracks (checked in can_annotate) + # TrackAnnotator: registered when a tracklet_key is set (checked in can_annotate) if TrackAnnotator.can_annotate(self): annotator_list.append( TrackAnnotator( - self, # type: ignore[arg-type] + self, tracklet_key=self.features.tracklet_key, lineage_key=self.features.lineage_key, ) @@ -825,3 +826,179 @@ def delete_feature(self, key: str) -> None: self.graph_solution.remove_node_attr_key(key) if "edge" in feature_type and key in self.graph_solution.edge_attr_keys(): self.graph_solution.remove_edge_attr_key(key) + + # ========== Track ID management (solution view) ========== + # These operate on the solution view via the TrackAnnotator. They are only + # meaningful when this Tracks has track ids (i.e. a tracklet_key is set and a + # TrackAnnotator is registered). + + @property + def track_annotator(self): + """The registered TrackAnnotator, or None if this Tracks has no track ids.""" + from funtracks.annotators import TrackAnnotator + + for annotator in self.annotators: + if isinstance(annotator, TrackAnnotator): + return annotator + return None + + def _require_track_annotator(self): + annotator = self.track_annotator + if annotator is None: + raise ValueError( + "This Tracks has no TrackAnnotator (no tracklet_key set); track id " + "operations are unavailable." + ) + return annotator + + @classmethod + def from_tracks(cls, tracks: Tracks) -> Tracks: + """Return a Tracks with track ids, recomputing them if any are missing.""" + force_recompute = False + # Check if all nodes have a value at features.tracklet_key before trusting + # existing track IDs + if ( + tracks.features.tracklet_key is not None + and ( + tracks.graph_solution.node_attrs(attr_keys=tracks.features.tracklet_key)[ + tracks.features.tracklet_key + ] + == -1 + ).any() + ): + # Attributes are no longer None, so -1 now means non-computed + force_recompute = True + + soln_tracks = cls( + tracks.graph_solution, + scale=tracks.scale, + ndim=tracks.ndim, + features=tracks.features, + _segmentation=tracks.segmentation, + ) + if force_recompute: + soln_tracks.enable_features( + [ + soln_tracks.features.tracklet_key, # type: ignore[list-item] + soln_tracks.features.lineage_key, # type: ignore[list-item] + ] + ) + return soln_tracks + + @property + def max_track_id(self) -> int: + return self._require_track_annotator().max_tracklet_id + + @property + def track_id_to_node(self) -> dict[int, list[int]]: + return self._require_track_annotator().tracklet_id_to_nodes + + def get_next_track_id(self) -> int: + """Return the next available track_id. + + The max_tracklet_id in TrackAnnotator is updated automatically when + a node is added or track IDs are updated via UpdateTrackIDs. + """ + return self._require_track_annotator().max_tracklet_id + 1 + + def get_next_lineage_id(self) -> int: + """Return the next available lineage_id. + + The max_lineage_id in TrackAnnotator is updated automatically when + a node is added or lineage IDs are updated via UpdateTrackIDs. + """ + return self._require_track_annotator().max_lineage_id + 1 + + def get_track_id(self, node) -> int: + if self.features.tracklet_key is None: + raise ValueError("Tracklet key not initialized in features") + track_id = self.get_node_attr(node, self.features.tracklet_key) + return track_id + + def get_track_ids(self, nodes) -> list[int]: + """Batch version of get_track_id — one query fetching all nodes in the graph. + NOTE: always fetches the entire graph internally. Optimised for bulk (all-node) + calls. For small subsets or single nodes use get_track_id() instead.""" + + if self.features.tracklet_key is None: + raise ValueError("Tracklet key not initialized in features") + tracklet_key = self.features.tracklet_key + df = self.graph_solution.node_attrs( + attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, tracklet_key] + ) + id_to_val = dict( + zip( + df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list(), + df[tracklet_key].to_list(), + strict=True, + ) + ) + return [id_to_val[node] for node in nodes] + + def get_lineage_id(self, node) -> int | None: + """Get the lineage ID for a node. + + Args: + node: The node to get lineage ID for + + Returns: + The lineage ID, or None if lineage feature is not enabled + """ + if self.features.lineage_key is None: + return None + return self.get_node_attr(node, self.features.lineage_key) + + def get_track_neighbors( + self, track_id: int, time: int + ) -> tuple[Node | None, Node | None]: + """Get the last node with the given track id before time, and the first node + with the track id after time, if any. Does not assume that a node with + the given track_id and time is already in tracks, but it can be. + + Args: + track_id (int): The track id to search for + time (int): The time point to find the immediate predecessor and successor + for + + Returns: + tuple[Node | None, Node | None]: The last node before time with the given + track id, and the first node after time with the given track id, + or Nones if there are no such nodes. + """ + annotator = self._require_track_annotator() + if ( + track_id not in annotator.tracklet_id_to_nodes + or len(annotator.tracklet_id_to_nodes[track_id]) == 0 + ): + return None, None + candidates = annotator.tracklet_id_to_nodes[track_id] + candidates.sort(key=lambda n: self.get_time(n)) + + pred = None + succ = None + for cand in candidates: + if self.get_time(cand) < time: + pred = cand + elif self.get_time(cand) > time: + succ = cand + break + return ( + int(pred) if pred is not None else None, + int(succ) if succ is not None else None, + ) + + def has_track_id_at_time(self, track_id: int, time: int) -> bool: + """Function to check if a node with given track id exists at given time point. + + Args: + track_id (int): The track id to search for. + time (int): The time point to check. + + Returns: + True if a node with given track id exists at given time point. + """ + nodes = self.track_id_to_node.get(track_id) + if not nodes: + return False + + return time in [self.get_time(node) for node in nodes] diff --git a/src/funtracks/import_export/_tracks_builder.py b/src/funtracks/import_export/_tracks_builder.py index 8f0240a3..77b67563 100644 --- a/src/funtracks/import_export/_tracks_builder.py +++ b/src/funtracks/import_export/_tracks_builder.py @@ -1,6 +1,6 @@ """Builder pattern for importing tracks from various formats. -This module provides a unified interface for constructing SolutionTracks objects +This module provides a unified interface for constructing Tracks objects from different data sources (GEFF, CSV, etc.) while sharing common validation and construction logic. """ @@ -15,7 +15,7 @@ import tracksdata as td from geff._typing import InMemoryGeff -from funtracks.data_model.solution_tracks import SolutionTracks +from funtracks.data_model.tracks import Tracks from funtracks.features import Feature from funtracks.import_export._import_segmentation import ( load_segmentation, @@ -610,13 +610,13 @@ def handle_segmentation( return new_segmentation, scale, graph - # Structural keys that are handled by graph construction / SolutionTracks.__init__ + # Structural keys that are handled by graph construction / Tracks.__init__ # and should not be registered as features by enable_features(). STRUCTURAL_KEYS = frozenset({"time", "id", "parent_id", "seg_id"}) def enable_features( self, - tracks: SolutionTracks, + tracks: Tracks, name_map: dict[str, str | list[str]], feature_type: Literal["node", "edge"] = "node", ) -> None: @@ -629,7 +629,7 @@ def enable_features( - Otherwise, if data was loaded for it, register it as a static feature. Args: - tracks: SolutionTracks object to add features to + tracks: Tracks object to add features to name_map: Mapping from standard funtracks keys to source property names (same format as node_name_map / edge_name_map). feature_type: Type of features ("node" or "edge") @@ -676,7 +676,7 @@ def build( scale: list[float] | None = None, node_name_map: dict[str, str | list[str]] | None = None, database: str | None = None, - ) -> SolutionTracks: + ) -> Tracks: """Orchestrate the full construction process. Args: @@ -688,7 +688,7 @@ def build( If None (default), an in-memory/temp graph is used. Returns: - Fully constructed SolutionTracks object + Fully constructed Tracks object Raises: ValueError: If self.node_name_map is not set or validation fails @@ -760,22 +760,36 @@ def build( if segmentation_array is not None: graph = add_masks_and_bboxes_to_graph(graph, segmentation_array) - # 7. Create SolutionTracks + # 7. Create Tracks # construct_graph() always stores time as "t" (tracksdata convention), # regardless of TIME_ATTR, so we pass "t" here explicitly. # If a FeatureDict was loaded (e.g., from GEFF metadata), use it directly if hasattr(self, "features") and self.features is not None: - tracks = SolutionTracks( + tracks = Tracks( graph=graph, ndim=self.ndim, scale=scale, features=self.features, ) else: - tracks = SolutionTracks( + # The builder always produces a solution, so declare tracklet/lineage + # intent to register a TrackAnnotator. Use the attr name present on the + # constructed graph; fall back to the default keys (computed from scratch) + # when the source carried no track ids. + node_keys = graph.node_attr_keys() + tracklet_attr = next( + (k for k in ("tracklet_id", "track_id") if k in node_keys), + "tracklet_id", + ) + lineage_attr = next( + (k for k in ("lineage_id",) if k in node_keys), "lineage_id" + ) + tracks = Tracks( graph=graph, pos_attr="pos", time_attr="t", + tracklet_attr=tracklet_attr, + lineage_attr=lineage_attr, ndim=self.ndim, scale=scale, ) diff --git a/src/funtracks/import_export/_v1_format.py b/src/funtracks/import_export/_v1_format.py index 34856ff1..d2ca64e8 100644 --- a/src/funtracks/import_export/_v1_format.py +++ b/src/funtracks/import_export/_v1_format.py @@ -16,7 +16,7 @@ ) if TYPE_CHECKING: - from ..data_model import SolutionTracks, Tracks + from ..data_model import Tracks GRAPH_FILE = "graph.json" SEG_FILE = "seg.npy" @@ -25,7 +25,7 @@ def load_v1_tracks( directory: Path, seg_required: bool = False, solution: bool = False -) -> Tracks | SolutionTracks: +) -> Tracks: """Load a Tracks object from the given directory. Looks for files in the format generated by Tracks.save. @@ -35,7 +35,7 @@ def load_v1_tracks( directory (Path): The directory containing tracks to load seg_required (bool, optional): If true, raises a FileNotFoundError if the segmentation file is not present in the directory. Defaults to False. - solution (bool, optional): If true, returns a SolutionTracks object, otherwise + solution (bool, optional): If true, returns a Tracks object, otherwise returns a normal Tracks object. Defaults to False. Returns: @@ -114,17 +114,20 @@ def load_v1_tracks( # is breaking, and manually setting the attrs to None if features is present will # break if the attrs are changed/removed in the future. Can remove in v2.0. # Import at runtime to avoid circular dependency - from ..data_model import SolutionTracks, Tracks + from ..data_model import Tracks + + # A solution save must come back with track ids. When the stored attrs don't carry + # a FeatureDict (older saves), declare tracklet/lineage intent explicitly so a + # TrackAnnotator is registered. + if solution and "features" not in attrs: + attrs.setdefault("tracklet_attr", "tracklet_id") + attrs.setdefault("lineage_attr", "lineage_id") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message="Provided both FeatureDict and pos_attr or time_attr" ) - tracks: Tracks - if solution: - tracks = SolutionTracks(graph_td, **attrs) - else: - tracks = Tracks(graph_td, **attrs) + tracks = Tracks(graph_td, **attrs) return tracks diff --git a/src/funtracks/import_export/csv/_export.py b/src/funtracks/import_export/csv/_export.py index e155845b..9490903d 100644 --- a/src/funtracks/import_export/csv/_export.py +++ b/src/funtracks/import_export/csv/_export.py @@ -12,11 +12,11 @@ from .._utils import filter_graph_with_ancestors if TYPE_CHECKING: - from funtracks.data_model.solution_tracks import SolutionTracks + from funtracks.data_model.tracks import Tracks def export_to_csv( - tracks: SolutionTracks, + tracks: Tracks, outfile: Path | str, color_dict: dict[int, np.ndarray] | None = None, node_ids: set[int] | None = None, @@ -36,7 +36,7 @@ def export_to_csv( tiff. If a color dictionary is provided, it will also export the tracklet colors. Args: - tracks: SolutionTracks object containing the tracking data to export + tracks: Tracks object containing the tracking data to export outfile: Path to output CSV file color_dict: dict[int, np.ndarray], optional. If provided, will be used to save the hex colors. diff --git a/src/funtracks/import_export/csv/_import.py b/src/funtracks/import_export/csv/_import.py index 983a4efd..7aeae87c 100644 --- a/src/funtracks/import_export/csv/_import.py +++ b/src/funtracks/import_export/csv/_import.py @@ -17,7 +17,7 @@ from .._tracks_builder import TracksBuilder, flatten_name_map if TYPE_CHECKING: - from funtracks.data_model.solution_tracks import SolutionTracks + from funtracks.data_model.tracks import Tracks def _ensure_integer_ids(df: pd.DataFrame) -> pd.DataFrame: @@ -169,12 +169,12 @@ def tracks_from_df( segmentation: np.ndarray | None = None, scale: list[float] | None = None, node_name_map: dict[str, str | list[str]] | None = None, -) -> SolutionTracks: +) -> Tracks: """Import tracks from pandas DataFrame. Turns a pandas DataFrame with columns: time, [z], y, x, id, parent_id, [seg_id], [optional custom attr 1], ... - into a SolutionTracks object. + into a Tracks object. Cells without a parent_id will have an empty string or a -1 for the parent_id. @@ -195,7 +195,7 @@ def tracks_from_df( If None, column names are auto-inferred using fuzzy matching. Returns: - SolutionTracks: a solution tracks object + Tracks: a solution tracks object Raises: ValueError: if the segmentation IDs in the dataframe do not match the provided diff --git a/src/funtracks/import_export/geff/_import.py b/src/funtracks/import_export/geff/_import.py index 060a481e..4a7a4da1 100644 --- a/src/funtracks/import_export/geff/_import.py +++ b/src/funtracks/import_export/geff/_import.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from pathlib import Path - from funtracks.data_model.solution_tracks import SolutionTracks + from funtracks.data_model.tracks import Tracks # defining constants here because they are only used in the context of import @@ -170,7 +170,7 @@ def read_header(self, source_path: Path) -> None: self._geff_axes = metadata.axes or [] # Read funtracks FeatureDict from GEFF extra metadata if present - # This will be passed to SolutionTracks via the base build() method + # This will be passed to Tracks via the base build() method if metadata.extra and "funtracks" in metadata.extra: funtracks_extra = metadata.extra["funtracks"] if "features" in funtracks_extra: @@ -321,7 +321,7 @@ def import_from_geff( scale: list[float] | None = None, edge_name_map: dict[str, str | list[str]] | None = None, database: str | None = None, -) -> SolutionTracks: +) -> Tracks: """Import tracks from GEFF format. Args: @@ -341,7 +341,7 @@ def import_from_geff( If None (default), an in-memory/temp graph is used. Returns: - SolutionTracks object + Tracks object """ # Filter out None values and "None" strings from node_name_map # (e.g., {"lineage_id": None} or {"lineage_id": "None"}) diff --git a/src/funtracks/user_actions/_user_swap_predecessors.py b/src/funtracks/user_actions/_user_swap_predecessors.py index da88bfec..baefaaa7 100644 --- a/src/funtracks/user_actions/_user_swap_predecessors.py +++ b/src/funtracks/user_actions/_user_swap_predecessors.py @@ -9,7 +9,7 @@ from .user_delete_edge import UserDeleteEdge if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserSwapPredecessors(ActionGroup): @@ -19,7 +19,7 @@ class UserSwapPredecessors(ActionGroup): must be earlier in time than both nodes for the swap to be valid. Args: - tracks (SolutionTracks): The tracks to perform the swap on. + tracks (Tracks): The tracks to perform the swap on. nodes (tuple[Node, Node]): A tuple with two nodes. Raises: @@ -30,11 +30,11 @@ class UserSwapPredecessors(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, nodes: tuple[int, int], ): super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # narrow type + self.tracks: Tracks # narrow type if len(nodes) != 2: raise InvalidActionError("You can only swap a pair of two nodes.") diff --git a/src/funtracks/user_actions/user_add_edge.py b/src/funtracks/user_actions/user_add_edge.py index d281c505..5075f1c3 100644 --- a/src/funtracks/user_actions/user_add_edge.py +++ b/src/funtracks/user_actions/user_add_edge.py @@ -11,14 +11,14 @@ from .user_delete_edge import UserDeleteEdge if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserAddEdge(ActionGroup): """Assumes that the endpoints already exist and have track ids. Args: - tracks (SolutionTracks): the tracks to add the edge to + tracks (Tracks): the tracks to add the edge to edge (tuple[int, int]): The edge to add force (bool, optional): Whether to force the action by removing any conflicting edges. Defaults to False. @@ -29,13 +29,13 @@ class UserAddEdge(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, edge: tuple[int, int], force: bool = False, _top_level: bool = True, ): super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class source, target = edge if not tracks.graph_solution.has_node(source): raise InvalidActionError( diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py index 308eb15c..902ff4b8 100644 --- a/src/funtracks/user_actions/user_add_node.py +++ b/src/funtracks/user_actions/user_add_node.py @@ -15,7 +15,7 @@ from .user_delete_edge import UserDeleteEdge if TYPE_CHECKING: - from funtracks.data_model.solution_tracks import SolutionTracks + from funtracks.data_model.tracks import Tracks class UserAddNode(ActionGroup): @@ -30,7 +30,7 @@ class UserAddNode(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, node: int, attributes: dict[str, Any], pixels: tuple[np.ndarray, ...] | None = None, @@ -39,7 +39,7 @@ def __init__( ): """ Args: - tracks (SolutionTracks): the tracks to add the node to + tracks (Tracks): the tracks to add the node to node (int): The node id of the new node to add attributes (dict[str, Any]): A dictionary from attribute strings to values. Must contain "time" and tracks.features.tracklet_key. @@ -63,7 +63,7 @@ def __init__( time point (forceable). """ super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class # Get keys from tracks features time_key = tracks.features.time_key diff --git a/src/funtracks/user_actions/user_delete_edge.py b/src/funtracks/user_actions/user_delete_edge.py index 84bc0a01..1a11fed6 100644 --- a/src/funtracks/user_actions/user_delete_edge.py +++ b/src/funtracks/user_actions/user_delete_edge.py @@ -9,19 +9,19 @@ from ..actions.update_track_id import UpdateTrackIDs if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserDeleteEdge(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, edge: tuple[int, int], _top_level: bool = True, ): """ Args: - tracks (SolutionTracks): The tracks to delete the edge from. + tracks (Tracks): The tracks to delete the edge from. edge (tuple[int, int]): The edge to delete. _top_level (bool): If True, add this action to the history and emit refresh. Set to False when used as a sub-action inside a compound @@ -31,7 +31,7 @@ def __init__( InvalidActionError: If the edge does not exist in the graph. """ super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class if not self.tracks.graph_solution.has_edge(*edge): raise InvalidActionError(f"Edge {edge} not in solution, can't remove") diff --git a/src/funtracks/user_actions/user_delete_node.py b/src/funtracks/user_actions/user_delete_node.py index ef588560..86416587 100644 --- a/src/funtracks/user_actions/user_delete_node.py +++ b/src/funtracks/user_actions/user_delete_node.py @@ -10,20 +10,20 @@ from ..actions.update_track_id import UpdateTrackIDs if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserDeleteNode(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, node: int, pixels: None | tuple[np.ndarray, ...] = None, _top_level: bool = True, ): """ Args: - tracks (SolutionTracks): The tracks to delete the node from. + tracks (Tracks): The tracks to delete the node from. node (int): The node id to delete. pixels (tuple[np.ndarray, ...] | None): The pixels of the node in the segmentation, if known. Will be computed if not provided. @@ -33,7 +33,7 @@ def __init__( action. Defaults to True. """ super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class # delete adjacent edges for pred in self.tracks.predecessors(node): siblings = self.tracks.successors(pred) diff --git a/src/funtracks/user_actions/user_delete_nodes.py b/src/funtracks/user_actions/user_delete_nodes.py index 05f30bb2..279351b1 100644 --- a/src/funtracks/user_actions/user_delete_nodes.py +++ b/src/funtracks/user_actions/user_delete_nodes.py @@ -8,7 +8,7 @@ from .user_delete_node import UserDeleteNode if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserDeleteNodes(ActionGroup): @@ -26,12 +26,12 @@ class UserDeleteNodes(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, nodes: list[int], pixels: None | list[tuple[np.ndarray, ...]] = None, ): super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class for i, node in enumerate(nodes): self.actions.append( UserDeleteNode( diff --git a/src/funtracks/user_actions/user_update_node_attrs.py b/src/funtracks/user_actions/user_update_node_attrs.py index 7cb8cfe2..27616200 100644 --- a/src/funtracks/user_actions/user_update_node_attrs.py +++ b/src/funtracks/user_actions/user_update_node_attrs.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from typing import Any - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserUpdateNodeAttrs(ActionGroup): @@ -21,14 +21,14 @@ class UserUpdateNodeAttrs(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, node: int, attrs: dict[str, Any], _top_level: bool = True, ): """ Args: - tracks (SolutionTracks): The tracks to update the node attributes for + tracks (Tracks): The tracks to update the node attributes for node (int): The node to update the attributes for attrs (dict[str, Any]): A mapping from attribute name to new attribute values for the given node. @@ -40,7 +40,7 @@ def __init__( ValueError: If a protected attribute is in the given attribute mapping. """ super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class # Call the basic UpdateNodeAttrs action self.actions.append(UpdateNodeAttrs(tracks, node, attrs)) diff --git a/src/funtracks/user_actions/user_update_nodes_attrs.py b/src/funtracks/user_actions/user_update_nodes_attrs.py index e5cae4fd..40e088ec 100644 --- a/src/funtracks/user_actions/user_update_nodes_attrs.py +++ b/src/funtracks/user_actions/user_update_nodes_attrs.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from typing import Any - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserUpdateNodesAttrs(ActionGroup): @@ -28,12 +28,12 @@ class UserUpdateNodesAttrs(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, nodes: list[int], attrs: dict[str, list[Any]], ): super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class for key, values in attrs.items(): if not isinstance(values, list): raise ValueError( diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py index 74b7e8ac..3c355b3d 100644 --- a/src/funtracks/user_actions/user_update_segmentation.py +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -12,13 +12,13 @@ from .user_delete_node import UserDeleteNode if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserUpdateSegmentation(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, new_value: int, updated_pixels: list[tuple[tuple[np.ndarray, ...], int]], current_track_id: int, @@ -30,7 +30,7 @@ def __init__( add_node action doesn't have anything with pixels. Args: - tracks (SolutionTracks): The solution tracks that the user is updating. + tracks (Tracks): The solution tracks that the user is updating. new_value (int): The new value that the user painted with updated_pixels (list[tuple[tuple[np.ndarray, ...], int]]): A list of node update actions, consisting of a numpy multi-index, pointing to the array @@ -42,7 +42,7 @@ def __init__( Defaults to False. """ super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class node_to_select = None if self.tracks.segmentation is None: raise ValueError("Cannot update non-existing segmentation.") diff --git a/tests/actions/test_action_history.py b/tests/actions/test_action_history.py index 5afed0d7..7502a095 100644 --- a/tests/actions/test_action_history.py +++ b/tests/actions/test_action_history.py @@ -1,6 +1,6 @@ from funtracks.actions import AddNode from funtracks.actions.action_history import ActionHistory -from funtracks.data_model import SolutionTracks +from funtracks.data_model import Tracks from funtracks.utils.tracksdata_utils import create_empty_graphview_graph # https://github.com/zaboople/klonk/blob/master/TheGURQ.md @@ -12,7 +12,7 @@ def test_action_history(): node_attributes=["track_id", "pos"], edge_attributes=[], ) - tracks = SolutionTracks(empty_graph, ndim=3, tracklet_attr="track_id", time_attr="t") + tracks = Tracks(empty_graph, ndim=3, tracklet_attr="track_id", time_attr="t") pos = [0, 1] action1 = AddNode(tracks, node=0, attributes={"t": 0, "pos": pos, "track_id": 1}) diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py index 14af4d3f..3ed0a01b 100644 --- a/tests/actions/test_add_delete_edge.py +++ b/tests/actions/test_add_delete_edge.py @@ -8,7 +8,7 @@ AddEdge, DeleteEdge, ) -from funtracks.data_model import SolutionTracks +from funtracks.data_model import Tracks from funtracks.features import FeatureDict, LineageID, Position, Time, TrackletID from funtracks.utils.tracksdata_utils import create_empty_graphview_graph @@ -205,7 +205,7 @@ def test_add_edge_with_unregistered_edge_attr(tmp_path): indices=[1, 2], ) - # Wrap in SolutionTracks without registering "custom_score" in features — + # Wrap in Tracks without registering "custom_score" in features — # this is the scenario that triggers the bug. features = FeatureDict( features={ @@ -219,7 +219,7 @@ def test_add_edge_with_unregistered_edge_attr(tmp_path): tracklet_key="track_id", lineage_key="lineage_id", ) - tracks = SolutionTracks(graph, ndim=3, features=features) + tracks = Tracks(graph, ndim=3, features=features) # Sanity: "custom_score" is in the graph schema but NOT in tracks.features. assert "custom_score" in tracks.graph_solution.edge_attr_keys() diff --git a/tests/annotators/test_annotator_registry.py b/tests/annotators/test_annotator_registry.py index f8d541c7..5efb88b4 100644 --- a/tests/annotators/test_annotator_registry.py +++ b/tests/annotators/test_annotator_registry.py @@ -1,9 +1,11 @@ import pytest from funtracks.annotators import EdgeAnnotator, RegionpropsAnnotator, TrackAnnotator -from funtracks.data_model import SolutionTracks, Tracks +from funtracks.data_model import Tracks track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} +# A plain (non-solution) Tracks declares no tracklet_attr, so no TrackAnnotator. +plain_attrs = {"time_attr": "t"} def test_annotator_registry_init_with_segmentation( @@ -14,18 +16,18 @@ def test_annotator_registry_init_with_segmentation( tracks = Tracks( graph_2d_with_segmentation, ndim=3, - **track_attrs, + **plain_attrs, ) annotator_types = [type(ann) for ann in tracks.annotators] assert RegionpropsAnnotator in annotator_types assert EdgeAnnotator in annotator_types - assert TrackAnnotator not in annotator_types # Not a SolutionTracks + assert TrackAnnotator not in annotator_types # No tracklet_attr -> no track ids def test_annotator_registry_init_without_segmentation(graph_2d_with_position): """Test AnnotatorRegistry doesn't create annotators without segmentation.""" - tracks = Tracks(graph_2d_with_position, ndim=3, **track_attrs) + tracks = Tracks(graph_2d_with_position, ndim=3, **plain_attrs) annotator_types = [type(ann) for ann in tracks.annotators] assert RegionpropsAnnotator not in annotator_types @@ -36,9 +38,9 @@ def test_annotator_registry_init_without_segmentation(graph_2d_with_position): def test_annotator_registry_init_solution_tracks( graph_2d_with_segmentation, ): - """Test AnnotatorRegistry creates all annotators for SolutionTracks with + """Test AnnotatorRegistry creates all annotators for Tracks with segmentation.""" - tracks = SolutionTracks( + tracks = Tracks( graph_2d_with_segmentation, ndim=3, **track_attrs, @@ -122,7 +124,7 @@ def test_area_on_graph_not_auto_activated(graph_2d_with_segmentation): def test_get_available_features(graph_2d_with_segmentation): """Test get_available_features returns all features from all annotators.""" - tracks = SolutionTracks( + tracks = Tracks( graph_2d_with_segmentation, ndim=3, **track_attrs, diff --git a/tests/annotators/test_edge_annotator.py b/tests/annotators/test_edge_annotator.py index b4e98acb..4e912d50 100644 --- a/tests/annotators/test_edge_annotator.py +++ b/tests/annotators/test_edge_annotator.py @@ -4,7 +4,7 @@ from funtracks.actions import UpdateNodeSeg, UpdateTrackIDs from funtracks.annotators import EdgeAnnotator -from funtracks.data_model import SolutionTracks, Tracks +from funtracks.data_model import Tracks track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} @@ -134,7 +134,7 @@ def test_missing_seg(self, get_graph, ndim) -> None: def test_ignores_irrelevant_actions(self, get_graph, ndim): """Test that EdgeAnnotator ignores actions that don't affect edges.""" graph = get_graph(ndim, is_solution=True, with_seg=True) - tracks = SolutionTracks( + tracks = Tracks( graph, ndim=ndim, **track_attrs, diff --git a/tests/annotators/test_regionprops_annotator.py b/tests/annotators/test_regionprops_annotator.py index 4cb34f53..db692a02 100644 --- a/tests/annotators/test_regionprops_annotator.py +++ b/tests/annotators/test_regionprops_annotator.py @@ -4,7 +4,7 @@ from funtracks.actions import UpdateNodeSeg, UpdateTrackIDs from funtracks.annotators import RegionpropsAnnotator -from funtracks.data_model import SolutionTracks, Tracks +from funtracks.data_model import Tracks track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} @@ -177,7 +177,7 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim): segmentation. """ graph = get_graph(ndim, is_solution=True, with_seg=True) - tracks = SolutionTracks( + tracks = Tracks( graph, ndim=ndim, **track_attrs, diff --git a/tests/annotators/test_track_annotator.py b/tests/annotators/test_track_annotator.py index 803ac885..f55afb83 100644 --- a/tests/annotators/test_track_annotator.py +++ b/tests/annotators/test_track_annotator.py @@ -101,12 +101,12 @@ def test_add_remove_feature(self, get_tracks, ndim, with_seg): assert tracks.get_node_attr(node_id, ann.tracklet_key) != orig_tra def test_invalid(self, get_tracks, ndim, with_seg) -> None: - # Create regular Tracks (not SolutionTracks) to test error handling + # A plain Tracks (no tracklet_key) is not a track-id candidate, so no + # TrackAnnotator is registered. tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=False) - with pytest.raises( - ValueError, match="Currently the TrackAnnotator only works on SolutionTracks" - ): - TrackAnnotator(tracks) # type: ignore + assert tracks.features.tracklet_key is None + assert not TrackAnnotator.can_annotate(tracks) + assert tracks.track_annotator is None def test_ignores_irrelevant_actions(self, get_tracks, ndim, with_seg): """Test that TrackAnnotator ignores actions that don't affect track IDs.""" diff --git a/tests/conftest.py b/tests/conftest.py index 29477e6e..8f345634 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from typing import Any - from funtracks.data_model import SolutionTracks, Tracks + from funtracks.data_model import Tracks def make_2d_disk_mask(center=(50, 50), radius=20) -> Mask: @@ -303,7 +303,7 @@ def graph_2d_with_position(tmp_path) -> td.graph.GraphView: @pytest.fixture def graph_2d_with_track_id(tmp_path) -> td.graph.GraphView: - """Graph with 2D positions and track_id - for SolutionTracks without segmentation.""" + """Graph with 2D positions and track_id - for Tracks without segmentation.""" db_path = str(tmp_path / "graph_2d_track_id.db") return _make_graph(ndim=3, with_pos=True, with_track_id=True, database=db_path) @@ -332,7 +332,7 @@ def graph_3d_with_position(tmp_path) -> td.graph.GraphView: @pytest.fixture def graph_3d_with_track_id(tmp_path) -> td.graph.GraphView: - """Graph with 3D positions and track_id - for SolutionTracks without segmentation.""" + """Graph with 3D positions and track_id - for Tracks without segmentation.""" db_path = str(tmp_path / "graph_3d_track_id.db") return _make_graph(ndim=4, with_pos=True, with_track_id=True, database=db_path) @@ -353,13 +353,13 @@ def graph_3d_with_segmentation(tmp_path) -> td.graph.GraphView: @pytest.fixture -def get_tracks(get_graph) -> Callable[..., "Tracks | SolutionTracks"]: - """Factory fixture to create Tracks or SolutionTracks instances. +def get_tracks(get_graph) -> Callable[..., "Tracks"]: + """Factory fixture to create Tracks or Tracks instances. Returns a factory function that can be called with: ndim: 3 for 2D spatial + time, 4 for 3D spatial + time with_seg: Whether to include segmentation (mask/bbox as node attributes) - is_solution: Whether to return SolutionTracks instead of Tracks + is_solution: Whether to return Tracks instead of Tracks Example: tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) @@ -368,7 +368,7 @@ def get_tracks(get_graph) -> Callable[..., "Tracks | SolutionTracks"]: Uses a pre-built FeatureDict to avoid recomputing features that already exist in the test graph fixtures. """ - from funtracks.data_model import SolutionTracks, Tracks + from funtracks.data_model import Tracks from funtracks.features import ( Area, FeatureDict, @@ -386,7 +386,7 @@ def _make_tracks( ndim: int, with_seg: bool = True, is_solution: bool = False, - ) -> Tracks | SolutionTracks: + ) -> Tracks: # Determine axis names based on ndim axis_names = ["z", "y", "x"] if ndim == 4 else ["y", "x"] @@ -418,7 +418,7 @@ def _make_tracks( # Create the appropriate Tracks type with pre-built FeatureDict if is_solution: - return SolutionTracks( + return Tracks( graph, ndim=ndim, features=feature_dict, diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 4bc7dbc8..029b465f 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -2,7 +2,7 @@ import polars as pl from funtracks.actions import AddNode -from funtracks.data_model import SolutionTracks, Tracks +from funtracks.data_model import Tracks from funtracks.import_export import export_to_csv from funtracks.user_actions import UserUpdateSegmentation from funtracks.utils.tracksdata_utils import ( @@ -14,7 +14,7 @@ def test_recompute_track_ids(graph_2d_with_track_id): - tracks = SolutionTracks( + tracks = Tracks( graph_2d_with_track_id, ndim=3, **track_attrs, @@ -23,7 +23,7 @@ def test_recompute_track_ids(graph_2d_with_track_id): def test_next_track_id(graph_2d_with_track_id): - tracks = SolutionTracks(graph_2d_with_track_id, ndim=3, **track_attrs) + tracks = Tracks(graph_2d_with_track_id, ndim=3, **track_attrs) assert tracks.get_next_track_id() == 6 AddNode( tracks, @@ -42,7 +42,7 @@ def test_from_tracks_cls(graph_2d_with_segmentation): tracklet_attr=track_attrs["tracklet_attr"], scale=(2, 2, 2), ) - solution_tracks = SolutionTracks.from_tracks(tracks) + solution_tracks = Tracks.from_tracks(tracks) assert solution_tracks.graph_solution == tracks.graph_solution assert solution_tracks.segmentation == tracks.segmentation assert solution_tracks.features.time_key == tracks.features.time_key @@ -64,7 +64,7 @@ def test_from_tracks_cls_recompute(graph_2d_with_segmentation): # delete track id (default value -1) on one node triggers reassignment of # track_ids even when recompute is False. tracks.graph_solution.nodes[1][tracks.features.tracklet_key] = -1 - solution_tracks = SolutionTracks.from_tracks(tracks) + solution_tracks = Tracks.from_tracks(tracks) # should have reassigned new track_id to node 6 assert solution_tracks.get_node_attr(6, solution_tracks.features.tracklet_key) == 4 assert ( @@ -73,7 +73,7 @@ def test_from_tracks_cls_recompute(graph_2d_with_segmentation): def test_update_segmentation(graph_2d_with_segmentation): - tracks = SolutionTracks( + tracks = Tracks( graph_2d_with_segmentation, ndim=3, **track_attrs, @@ -94,7 +94,7 @@ def test_next_track_id_empty(): node_attributes=["pos", "track_id"], edge_attributes=[], ) - tracks = SolutionTracks(graph, ndim=4, **track_attrs) + tracks = Tracks(graph, ndim=4, **track_attrs) assert tracks.get_next_track_id() == 1 @@ -104,7 +104,7 @@ def test_get_lineage_id_without_lineage_key(graph_2d_with_track_id): graph.add_node( attrs={"t": 1, "pos": [0, 0], "track_id": 1}, index=7, validate_keys=False ) - tracks = SolutionTracks(graph, ndim=3, **track_attrs) + tracks = Tracks(graph, ndim=3, **track_attrs) # Unset lineage_key to test the None path tracks.features.lineage_key = None @@ -118,7 +118,7 @@ def test_export_to_csv_with_display_names( ): """Test CSV export with use_display_names=True option.""" # Test 2D with display names - tracks = SolutionTracks(graph_2d_with_segmentation, **track_attrs, ndim=3) + tracks = Tracks(graph_2d_with_segmentation, **track_attrs, ndim=3) tracks.enable_features(["area"]) temp_file = tmp_path / "test_export_2d_display.csv" export_to_csv(tracks, temp_file, use_display_names=True) @@ -142,7 +142,7 @@ def test_export_to_csv_with_display_names( assert lines[0].strip().split(",") == header # Test 3D with display names (area display name is "Volume" in 3D) - tracks = SolutionTracks(graph_3d_with_segmentation, **track_attrs, ndim=4) + tracks = Tracks(graph_3d_with_segmentation, **track_attrs, ndim=4) tracks.enable_features(["area"]) temp_file = tmp_path / "test_export_3d_display.csv" export_to_csv(tracks, temp_file, use_display_names=True) @@ -171,7 +171,7 @@ def test_multi_axis_pos_attr_with_segmentation(graph_3d_with_segmentation): """pos_attr as list should be respected even when segmentation is present. Scenario: graph has both a "pos" column AND individual z/y/x columns with - distinct values. SolutionTracks(pos_attr=['z','y','x']) should use z/y/x + distinct values. Tracks(pos_attr=['z','y','x']) should use z/y/x as the position_key, not fall back to "pos". """ graph = graph_3d_with_segmentation @@ -187,7 +187,7 @@ def test_multi_axis_pos_attr_with_segmentation(graph_3d_with_segmentation): graph.nodes[node]["y"] = float(pos[1]) + offset graph.nodes[node]["x"] = float(pos[2]) + offset - tracks = SolutionTracks( + tracks = Tracks( graph=graph, pos_attr=["z", "y", "x"], ndim=4, @@ -204,5 +204,5 @@ def test_multi_axis_pos_attr_with_segmentation(graph_3d_with_segmentation): expected = [float(original_pos[i]) + offset for i in range(3)] assert list(pos_from_tracks) == expected, ( f"Expected positions from z/y/x ({expected}), " - f"got {list(pos_from_tracks)} — SolutionTracks used 'pos' instead" + f"got {list(pos_from_tracks)} — Tracks used 'pos' instead" ) diff --git a/tests/import_export/test_csv_import.py b/tests/import_export/test_csv_import.py index 3273a9a8..3add3c99 100644 --- a/tests/import_export/test_csv_import.py +++ b/tests/import_export/test_csv_import.py @@ -2,7 +2,7 @@ import pandas as pd import pytest -from funtracks.data_model import SolutionTracks +from funtracks.data_model import Tracks from funtracks.import_export import tracks_from_df @@ -42,7 +42,7 @@ def test_import_2d(self, simple_df_2d): """Test importing 2D DataFrame.""" tracks = tracks_from_df(simple_df_2d) - assert isinstance(tracks, SolutionTracks) + assert isinstance(tracks, Tracks) assert tracks.graph_solution.num_nodes() == 4 assert tracks.graph_solution.num_edges() == 3 assert tracks.ndim == 3 diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py index 1360d6cb..876d566b 100644 --- a/tests/import_export/test_export_to_geff.py +++ b/tests/import_export/test_export_to_geff.py @@ -4,7 +4,7 @@ import tifffile import zarr -from funtracks.data_model import SolutionTracks, Tracks +from funtracks.data_model import Tracks from funtracks.import_export import export_to_geff @@ -51,8 +51,7 @@ def test_export_to_geff( graph.nodes[node][key] = pos[i] graph.remove_node_attr_key("pos") # Create Tracks with split position attributes - tracks_cls = SolutionTracks if is_solution else Tracks - tracks = tracks_cls( + tracks = Tracks( graph, time_attr="t", pos_attr=pos_keys, diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index 593106eb..f441c1cd 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -5,7 +5,7 @@ import zarr from geff.testing.data import create_mock_geff -from funtracks.data_model import SolutionTracks +from funtracks.data_model import Tracks from funtracks.import_export import export_to_geff, import_from_geff from funtracks.import_export.geff._import import GeffTracksBuilder, import_graph_from_geff from funtracks.utils.tracksdata_utils import create_empty_graphview_graph @@ -501,13 +501,15 @@ def test_import_from_geff_roundtrip_auto_axes(tmp_path): indices=[1], ) # The graph carries segmentation_shape in its metadata (set by motile-tracker), - # but no dense segmentation array is attached to the SolutionTracks object. + # but no dense segmentation array is attached to the Tracks object. graph._update_metadata(segmentation_shape=(5, 100, 100)) run_dir = tmp_path / "run" run_dir.mkdir() - st = SolutionTracks(graph, ndim=3, time_attr="t") + st = Tracks( + graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id" + ) export_to_geff(st, run_dir) tracks_path = run_dir / "tracks.geff" @@ -625,7 +627,9 @@ def test_import_from_geff_warns_missing_segmentation_shape(tmp_path): run_dir = tmp_path / "run" run_dir.mkdir() - st = SolutionTracks(graph, ndim=3, time_attr="t") + st = Tracks( + graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id" + ) export_to_geff(st, run_dir) tracks_path = run_dir / "tracks.geff" @@ -652,11 +656,11 @@ def test_import_from_geff_warns_missing_segmentation_shape(tmp_path): def test_get_time_works_after_import(valid_geff): - """Regression test: tracks.get_time() must work on a SolutionTracks returned by + """Regression test: tracks.get_time() must work on a Tracks returned by import_from_geff(). Previously, TracksBuilder.build() stored time as "t" in the graph (tracksdata - convention) but created SolutionTracks(time_attr=TIME_ATTR) where TIME_ATTR="time". + convention) but created Tracks(time_attr=TIME_ATTR) where TIME_ATTR="time". This caused features.time_key="time" while the graph only had attribute "t", making get_time() raise KeyError: 'time'. """ @@ -914,7 +918,9 @@ def test_embedded_seg_ellipse_axis_radii_feature_metadata(tmp_path): run_dir = tmp_path / "run" run_dir.mkdir() - st = SolutionTracks(graph, ndim=3, time_attr="t") + st = Tracks( + graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id" + ) export_to_geff(st, run_dir) # Remove FeatureDict from GEFF metadata to simulate old/external GEFF @@ -1027,7 +1033,9 @@ def test_featuredict_survives_geff_roundtrip(tmp_path): run_dir = tmp_path / "run" run_dir.mkdir() - st = SolutionTracks(graph, ndim=3, time_attr="t") + st = Tracks( + graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id" + ) # Customize metadata that auto-detection cannot reproduce. pos_key = st.features.position_key assert isinstance(pos_key, str) @@ -1074,7 +1082,7 @@ def test_invalid_featuredict_in_geff_falls_back_to_autodetect(get_tracks, tmp_pa # Should not raise — falls back to auto-detection imported = import_from_geff(geff_path) - # Verify the import produced a working SolutionTracks with auto-detected features + # Verify the import produced a working Tracks with auto-detected features assert imported.features.time_key is not None assert imported.features.position_key is not None assert imported.features.tracklet_key is not None From db7c3838f86d02f4623b83d0049f736898c5d50a Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Thu, 11 Jun 2026 11:08:25 -0400 Subject: [PATCH 03/39] soft delete|add actions on view/_root + td changes --- pyproject.toml | 4 ++ src/funtracks/actions/add_delete_edge.py | 56 +++++++++++++++--------- src/funtracks/actions/add_delete_node.py | 33 +++++++++++--- src/funtracks/data_model/tracks.py | 13 ++++++ tests/actions/test_add_delete_edge.py | 6 ++- 5 files changed, 84 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 68378565..322765bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,3 +117,7 @@ explicit_package_bases = true exclude_also = [ "if TYPE_CHECKING:", ] + +#remove this after tracksdata PR is merged: +[tool.uv.sources] +tracksdata = { git = "https://github.com/TeunHuijben/tracksdata.git", rev = "040ed9f03c0f64c587355e9fb6c7f9f08d09855c" } diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index e4fc8a18..ed36cd52 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -45,31 +45,40 @@ def _apply(self) -> None: Raises: ValueError if an endpoint of the edge does not exist """ - # Check that both endpoints exist before computing edge attributes + # Check that both endpoints exist in the solution before adding the edge for node in self.edge: if not self.tracks.graph_solution.has_node(node): raise ValueError( - f"Cannot add edge {self.edge}: endpoint {node} not in graph yet" + f"Cannot add edge {self.edge}: endpoint {node} not in solution yet" ) if self.tracks.graph_solution.has_edge(*self.edge): - raise ValueError(f"Edge {self.edge} already exists in the graph") - - attrs = dict(self.attributes) - - # Fill in missing edge attributes with schema defaults (includes - # solution and any other registered edge attrs). - schemas = self.tracks.graph_solution._edge_attr_schemas() - for attr in self.tracks.graph_solution.edge_attr_keys(): - if attr not in attrs: - attrs[attr] = schemas[attr].default_value - - # Create edge attributes for this specific edge - self.tracks.graph_solution.add_edge( - source_id=self.edge[0], - target_id=self.edge[1], - attrs=attrs, - ) + raise ValueError(f"Edge {self.edge} already exists in the solution") + + if self.tracks.graph_full.has_edge(*self.edge): + # Revive a soft-deleted edge (already present in the full graph as a + # candidate): flip solution=True and re-surface it in the solution view. + edge_id = self.tracks.graph_full.edge_id(self.edge[0], self.edge[1]) + self.tracks.graph_full.update_edge_attrs( + attrs={"solution": True}, edge_ids=[edge_id] + ) + self.tracks.graph_solution.add_edge_to_view(self.edge[0], self.edge[1]) + else: + attrs = dict(self.attributes) + + # Fill in missing edge attributes with schema defaults (includes + # solution and any other registered edge attrs). + schemas = self.tracks.graph_solution._edge_attr_schemas() + for attr in self.tracks.graph_solution.edge_attr_keys(): + if attr not in attrs: + attrs[attr] = schemas[attr].default_value + + # Create edge attributes for this specific edge + self.tracks.graph_solution.add_edge( + source_id=self.edge[0], + target_id=self.edge[1], + attrs=attrs, + ) # Notify annotators to recompute features (will overwrite computed ones) self.tracks.notify_annotators(self) @@ -106,5 +115,12 @@ def inverse(self) -> BasicAction: return AddEdge(self.tracks, self.edge, attributes=self.attributes) def _apply(self) -> None: - self.tracks.graph_solution.remove_edge(*self.edge) + """Soft-delete the edge: flag solution=False in the full graph and remove it + from the solution view only. The edge is preserved in graph_full (as a + candidate) so the delete is reversible.""" + edge_id = self.tracks.graph_full.edge_id(self.edge[0], self.edge[1]) + self.tracks.graph_full.update_edge_attrs( + attrs={"solution": False}, edge_ids=[edge_id] + ) + self.tracks.graph_solution.remove_edge_from_view(self.edge[0], self.edge[1]) self.tracks.notify_annotators(self) diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index 7ddc388c..ab7efa8f 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -75,10 +75,25 @@ def inverse(self) -> BasicAction: return DeleteNode(self.tracks, self.node) def _apply(self) -> None: - """Add the node with all attributes from self.attributes.""" - self.tracks.graph_solution.add_node( - attrs=dict(self.attributes), index=self.node, validate_keys=False - ) + """Add the node, or revive a soft-deleted one. + + If the node still exists in the full graph (it was soft-deleted, so its + topology was preserved), revive it: flip solution=True and re-surface it in the + solution view. Otherwise add a genuinely new node. + """ + if self.tracks.graph_full.has_node(self.node): + # Revive: same node id, topology preserved in graph_full. Flip it back into + # the solution and re-surface it in the view in place (incident edges are + # revived separately by AddEdge). + self.tracks.graph_full.update_node_attrs( + attrs={"solution": True}, node_ids=[self.node] + ) + self.tracks.graph_solution.add_node_to_view(self.node) + else: + # Genuinely new node (solution defaults to True via the schema). + self.tracks.graph_solution.add_node( + attrs=dict(self.attributes), index=self.node, validate_keys=False + ) # Always notify annotators - they will check their own preconditions self.tracks.notify_annotators(self) @@ -115,6 +130,12 @@ def inverse(self) -> BasicAction: return AddNode(self.tracks, self.node, self.attributes) def _apply(self) -> None: - """Remove the node from the graph.""" - self.tracks.graph_solution.remove_node(self.node) + """Soft-delete the node: flag solution=False in the full graph and remove it + from the solution view only. The node (and its topology) is preserved in + graph_full so the delete is reversible and the node remains a candidate. + """ + self.tracks.graph_full.update_node_attrs( + attrs={"solution": False}, node_ids=[self.node] + ) + self.tracks.graph_solution.remove_node_from_view(self.node) self.tracks.notify_annotators(self) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 027c902a..1430b1e1 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -108,6 +108,19 @@ def __init__( existing GraphArrayView instance. Not intended for public use. """ self.graph_solution = graph + # Depth-1 invariant: graph_full is defined as graph_solution._root (one hop), + # so the root must be a base graph, NOT itself a GraphView. A nested view + # (base -> crop -> solution) would silently make graph_full mean "the crop" + # instead of "every node ever known", breaking the AddNode revive-vs-new check + # (graph_full.has_node) and any annotator registered on graph_full. Fail loudly + # rather than corrupt data if cropping is ever introduced upstream. + if isinstance(graph._root, td.graph.GraphView): + raise ValueError( + "Tracks requires graph_solution to be a direct view of a base graph " + "(graph_solution._root must not itself be a GraphView). A nested view " + "chain (e.g. a crop of a crop) violates the depth-1 assumption that " + "graph_full = graph_solution._root is the full graph." + ) if _segmentation is not None: # Reuse provided segmentation instance (internal use only) self.segmentation = _segmentation diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py index 3ed0a01b..2f2f73b9 100644 --- a/tests/actions/test_add_delete_edge.py +++ b/tests/actions/test_add_delete_edge.py @@ -30,7 +30,7 @@ def test_add_delete_edges(get_tracks, ndim, with_seg): action = ActionGroup(tracks=tracks, actions=[AddEdge(tracks, edge) for edge in edges]) - with pytest.raises(ValueError, match="Edge .* already exists in the graph"): + with pytest.raises(ValueError, match="Edge .* already exists in the solution"): AddEdge(tracks, (1, 2)) # TODO: What if adding an edge that already exists? @@ -81,7 +81,9 @@ def test_add_delete_edges(get_tracks, ndim, with_seg): def test_add_edge_missing_endpoint(get_tracks): tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) - with pytest.raises(ValueError, match="Cannot add edge .*: endpoint .* not in graph"): + with pytest.raises( + ValueError, match="Cannot add edge .*: endpoint .* not in solution" + ): AddEdge(tracks, (10, 11)) From c9b8c894bd4ff61303e7c9d7163eee9bfd8dd784 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Thu, 11 Jun 2026 11:53:39 -0400 Subject: [PATCH 04/39] regionprops and edgeannotator act on graph_full, trackannotator on graph_solution --- src/funtracks/annotators/_edge_annotator.py | 22 +++--- src/funtracks/annotators/_graph_annotator.py | 16 +++++ .../annotators/_regionprops_annotator.py | 17 +++-- src/funtracks/data_model/tracks.py | 28 +++++--- .../annotators/test_features_on_full_graph.py | 69 +++++++++++++++++++ 5 files changed, 127 insertions(+), 25 deletions(-) create mode 100644 tests/annotators/test_features_on_full_graph.py diff --git a/src/funtracks/annotators/_edge_annotator.py b/src/funtracks/annotators/_edge_annotator.py index df9202fc..4c35811d 100644 --- a/src/funtracks/annotators/_edge_annotator.py +++ b/src/funtracks/annotators/_edge_annotator.py @@ -41,6 +41,12 @@ def can_annotate(cls, tracks) -> bool: """ return tracks.segmentation is not None + @property + def graph(self): + """IoU is an intrinsic link feature → computed on the full graph (all edges, + including soft-deleted/candidate ones, so they stay ready for re-solving).""" + return self.tracks.graph_full + @classmethod def get_available_features(cls, ndim: int = 3) -> dict[str, Feature]: """Get all features that can be computed by this annotator. @@ -82,14 +88,14 @@ def compute(self, feature_keys: list[str] | None = None) -> None: # TODO: add skip edges if self.iou_key in keys_to_compute: nodes_by_frame = defaultdict(list) - for n in self.tracks.graph_solution.node_ids(): + for n in self.graph.node_ids(): nodes_by_frame[self.tracks.get_time(n)].append(n) for t in range(self.tracks.segmentation.shape[0] - 1): nodes_in_t = nodes_by_frame[t] edges = [] for node in nodes_in_t: - for succ in self.tracks.graph_solution.successors(node): + for succ in self.graph.successors(node): edges.append((node, succ)) self._iou_update(edges) @@ -105,8 +111,8 @@ def _iou_update( """ for edge in edges: source, target = edge - mask1 = self.tracks.graph_solution.nodes[source]["mask"] - mask2 = self.tracks.graph_solution.nodes[target]["mask"] + mask1 = self.graph.nodes[source]["mask"] + mask2 = self.graph.nodes[target]["mask"] iou = mask1.iou(mask2) self.tracks._set_edge_attr(edge, self.iou_key, iou) @@ -136,16 +142,16 @@ def update(self, action: BasicAction): # Get all incident edges to the modified node modified_node = action.node edges_to_update = [] - for pred in self.tracks.graph_solution.predecessors(modified_node): + for pred in self.graph.predecessors(modified_node): edges_to_update.append((pred, modified_node)) - for succ in self.tracks.graph_solution.successors(modified_node): + for succ in self.graph.successors(modified_node): edges_to_update.append((modified_node, succ)) # Update IoU for each edge for edge in edges_to_update: source, target = edge - mask1 = self.tracks.graph_solution.nodes[source]["mask"] - mask2 = self.tracks.graph_solution.nodes[target]["mask"] + mask1 = self.graph.nodes[source]["mask"] + mask2 = self.graph.nodes[target]["mask"] if mask1.mask.sum() == 0 or mask2.mask.sum() == 0: empty_node = source if mask1.mask.sum() == 0 else target frame = self.tracks.get_time(empty_node) diff --git a/src/funtracks/annotators/_graph_annotator.py b/src/funtracks/annotators/_graph_annotator.py index 0b1ba083..f7fca8b6 100644 --- a/src/funtracks/annotators/_graph_annotator.py +++ b/src/funtracks/annotators/_graph_annotator.py @@ -52,6 +52,22 @@ def __init__(self, tracks: Tracks, features: dict[str, Feature]): key: (feat, False) for key, feat in features.items() } + @property + def graph(self): + """The graph this annotator iterates over and reads topology/masks from. + + Defaults to the solution view. Detection-feature annotators + (`RegionpropsAnnotator`, `EdgeAnnotator`) override this to return + `graph_full`, so intrinsic features (`pos`, `area`, `iou`, ...) are computed + for *every* node/edge — including soft-deleted (`solution=False`) candidates — + keeping the full and solution graphs in sync and ready for re-solving. Track-id + features (`tracklet_id`, `lineage_id`) are solution-only, so `TrackAnnotator` + keeps the default. Note that attribute *writes* go through the `tracks` helpers, + which target `graph_full`; because attr dicts are shared by reference, in-solution + nodes see those writes through the view automatically. + """ + return self.tracks.graph_solution + def activate_features(self, keys: list[str]) -> None: """Activate computation of the given features in the annotation process. diff --git a/src/funtracks/annotators/_regionprops_annotator.py b/src/funtracks/annotators/_regionprops_annotator.py index 61953fcc..9fd58d83 100644 --- a/src/funtracks/annotators/_regionprops_annotator.py +++ b/src/funtracks/annotators/_regionprops_annotator.py @@ -73,6 +73,11 @@ def can_annotate(cls, tracks) -> bool: """ return tracks.segmentation is not None + @property + def graph(self): + """Regionprops features are intrinsic detections → computed on the full graph.""" + return self.tracks.graph_full + def __init__( self, tracks: Tracks, @@ -170,10 +175,10 @@ def compute(self, feature_keys: list[str] | None = None) -> None: all_node_ids = [] all_values: dict[str, list] = {key: [] for key in keys_to_compute} - for node_id in self.tracks.graph_solution.node_ids(): - if not self.tracks.graph_solution.has_node(node_id): + for node_id in self.graph.node_ids(): + if not self.graph.has_node(node_id): continue - mask = self.tracks.graph_solution.nodes[node_id]["mask"] + mask = self.graph.nodes[node_id]["mask"] for region in regionprops_extended(mask, spacing=spacing): all_node_ids.append(node_id) for key in keys_to_compute: @@ -203,7 +208,7 @@ def _regionprops_update( spacing = None if self.tracks.scale is None else tuple(self.tracks.scale[1:]) for region in regionprops_extended(mask, spacing=spacing): # Skip labels that aren't nodes in the graph (e.g., unselected detections) - if not self.tracks.graph_solution.has_node(node_id): + if not self.graph.has_node(node_id): continue for key in feature_keys: value = getattr(region, self.regionprops_names[key]) @@ -240,7 +245,7 @@ def update(self, action: BasicAction): time = self.tracks.get_time(node) - if self.tracks.graph_solution.nodes[node]["mask"].mask.sum() == 0: + if self.graph.nodes[node]["mask"].mask.sum() == 0: warnings.warn( f"Cannot find label {node} in frame {time}: " "updating regionprops values to None", @@ -250,7 +255,7 @@ def update(self, action: BasicAction): value = None self.tracks._set_node_attr(node, key, value) else: - mask = self.tracks.graph_solution.nodes[node]["mask"] + mask = self.graph.nodes[node]["mask"] self._regionprops_update(node, mask, keys_to_compute) def change_key(self, old_key: str, new_key: str) -> None: diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 1430b1e1..f7b1eb4c 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -535,7 +535,7 @@ def get_times(self, nodes: Iterable[Node]) -> Sequence[int]: For a single node use get_time() instead. """ nodes = list(nodes) - df = self.graph_solution.node_attrs( + df = self.graph_full.node_attrs( attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, self.features.time_key] ) id_to_val = dict( @@ -668,39 +668,45 @@ def _compute_ndim( ) return ndim + # NOTE: the low-level attribute get/set helpers below target `graph_full`, not the + # solution view. Attribute *values* are intrinsic to a node/edge and live on the full + # graph; the solution view is a membership filter over it. Because graph_full ⊇ + # graph_solution and their attr dicts are shared by reference, writing via graph_full + # is identical to writing via the view for any in-solution node (the view sees it + # automatically) and additionally works for soft-deleted (solution=False) candidates. def _set_node_attr(self, node: Node, attr: str, value: Any): if isinstance(value, np.ndarray): value = list(value) - self.graph_solution.nodes[node][attr] = value + self.graph_full.nodes[node][attr] = value def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any]): nodes_list = list(nodes) values_list = list(values) if nodes_list: - self.graph_solution.update_node_attrs( + self.graph_full.update_node_attrs( attrs={attr: values_list}, node_ids=nodes_list ) def get_node_attr(self, node: Node, attr: str): - return self.graph_solution.nodes[int(node)][attr] + return self.graph_full.nodes[int(node)][attr] def get_nodes_attr(self, nodes: Iterable[Node], attr: str): return [self.get_node_attr(node, attr) for node in nodes] def _set_edge_attr(self, edge: Edge, attr: str, value: Any): - edge_id = self.graph_solution.edge_id(edge[0], edge[1]) - self.graph_solution.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) + edge_id = self.graph_full.edge_id(edge[0], edge[1]) + self.graph_full.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def _set_edges_attr(self, edges: Iterable[Edge], attr: str, values: Iterable[Any]): for edge, value in zip(edges, values, strict=False): - edge_id = self.graph_solution.edge_id(edge[0], edge[1]) - self.graph_solution.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) + edge_id = self.graph_full.edge_id(edge[0], edge[1]) + self.graph_full.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def get_edge_attr(self, edge: Edge, attr: str): - if attr not in self.graph_solution.edge_attr_keys(): + if attr not in self.graph_full.edge_attr_keys(): return None - edge_id = self.graph_solution.edge_id(edge[0], edge[1]) - return self.graph_solution.edges[edge_id][attr] + edge_id = self.graph_full.edge_id(edge[0], edge[1]) + return self.graph_full.edges[edge_id][attr] def get_edges_attr(self, edges: Iterable[Edge], attr: str): return [self.get_edge_attr(edge, attr) for edge in edges] diff --git a/tests/annotators/test_features_on_full_graph.py b/tests/annotators/test_features_on_full_graph.py new file mode 100644 index 00000000..85cbcb5f --- /dev/null +++ b/tests/annotators/test_features_on_full_graph.py @@ -0,0 +1,69 @@ +"""Step 5 regression: detection features (regionprops, iou) live on graph_full. + +These features are intrinsic to a detection/link and must stay computed for *all* +nodes/edges — including soft-deleted (solution=False) candidates — so the full and +solution graphs never drift and candidates are ready for re-solving. Track-id features +remain solution-only and are covered elsewhere. +""" + +import pytest + +from funtracks.actions import DeleteEdge, DeleteNode +from funtracks.annotators import EdgeAnnotator, RegionpropsAnnotator + + +def _annotator(tracks, cls): + return next(ann for ann in tracks.annotators if isinstance(ann, cls)) + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_regionprops_persist_and_recompute_on_soft_deleted_node(get_tracks, ndim): + tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + rp_ann = _annotator(tracks, RegionpropsAnnotator) + tracks.enable_features(["pos", "area"]) + rp_ann.compute(["area"]) + + node = 5 # leaf node of the fixture graph + area_before = tracks.get_node_attr(node, "area") + assert area_before is not None + + # Soft-delete: leaves the solution view, stays in graph_full as solution=False. + DeleteNode(tracks, node) + assert node not in tracks.graph_solution.node_ids() + assert node in tracks.graph_full.node_ids() + assert tracks.graph_full.nodes[node]["solution"] is False + + # NEW: the detection feature is still readable (helpers read graph_full, so this no + # longer KeyErrors on an out-of-solution node) and its value is preserved. + assert tracks.get_node_attr(node, "area") == area_before + + # NEW: a bulk recompute covers the soft-deleted candidate (iterates graph_full). + tracks._set_node_attr(node, "area", None) # wipe to prove recompute reaches it + rp_ann.compute(["area"]) + assert tracks.get_node_attr(node, "area") == area_before + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_iou_computed_on_soft_deleted_candidate_edge(get_tracks, ndim): + tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + edge_ann = _annotator(tracks, EdgeAnnotator) + tracks.enable_features(["iou"]) + edge_ann.compute(["iou"]) + + edge = (4, 5) + iou_before = tracks.get_edge_attr(edge, "iou") + assert iou_before is not None + + # Soft-delete the edge: gone from the solution, kept in graph_full as a candidate. + DeleteEdge(tracks, edge) + assert not tracks.graph_solution.has_edge(*edge) + assert tracks.graph_full.has_edge(*edge) + + # NEW: iou is still readable on the candidate edge (get_edge_attr reads graph_full). + assert tracks.get_edge_attr(edge, "iou") == iou_before + + # NEW: a bulk recompute reaches the solution=False edge (compute iterates graph_full + # successors, which include candidate edges). + tracks._set_edge_attr(edge, "iou", 0.0) # wipe to prove recompute reaches it + edge_ann.compute(["iou"]) + assert tracks.get_edge_attr(edge, "iou") == iou_before From a2c81223438eda62986a576615b21026a1a625b7 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Thu, 11 Jun 2026 12:11:05 -0400 Subject: [PATCH 05/39] properly test soft delete roundtrip --- .../data_model/test_soft_delete_roundtrip.py | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 tests/data_model/test_soft_delete_roundtrip.py diff --git a/tests/data_model/test_soft_delete_roundtrip.py b/tests/data_model/test_soft_delete_roundtrip.py new file mode 100644 index 00000000..b9570563 --- /dev/null +++ b/tests/data_model/test_soft_delete_roundtrip.py @@ -0,0 +1,136 @@ +"""Step 6: soft-delete round-trip + repeated delete<->undo stability. + +Soft-delete keeps the node/edge in ``graph_full`` (flag ``solution=False``) and only drops +it from ``graph_solution``. These tests pin the core Phase-1 guarantees: +- the full graph's topology is preserved across a leaf-node soft-delete (only flags flip); +- delete -> undo restores the solution view exactly (ids + attributes); +- repeated undo/redo through the real ``ActionHistory`` never drifts the view (invariant + #4 in the persistence plan) — the R2 in-place revive must be a true inverse of remove. + +A leaf node (5, edge 4->5) is used so the delete introduces no reconnection skip-edge; the +skip-edge case is exercised in ``test_mid_track_delete_*`` to document that graph_full +accumulates the skip edge as a candidate. +""" + +import pytest + +from funtracks.user_actions import UserDeleteNode + + +def _solution_state(tracks): + """A hashable snapshot of the solution view's topology.""" + return ( + tuple(sorted(tracks.graph_solution.node_ids())), + tuple(sorted(tracks.graph_solution.edge_list())), + ) + + +def _full_state(tracks): + return ( + tuple(sorted(tracks.graph_full.node_ids())), + tuple(sorted(tracks.graph_full.edge_list())), + ) + + +def _positions(tracks): + """pos per solution node as plain float tuples (array-safe for == comparison).""" + return { + n: tuple(float(x) for x in tracks.get_node_attr(n, "pos")) + for n in tracks.graph_solution.node_ids() + } + + +@pytest.mark.parametrize("ndim", [3, 4]) +@pytest.mark.parametrize("with_seg", [True, False]) +def test_soft_delete_keeps_leaf_node_in_full_graph(get_tracks, ndim, with_seg): + tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + full_before = _full_state(tracks) + sol_nodes_before = set(tracks.graph_solution.node_ids()) + + UserDeleteNode(tracks, 5) + + # Gone from the solution view ... + assert 5 not in tracks.graph_solution.node_ids() + assert set(tracks.graph_solution.node_ids()) == sol_nodes_before - {5} + assert not tracks.graph_solution.has_edge(4, 5) + + # ... but preserved in the full graph as a soft-deleted candidate. Topology of the + # full graph is unchanged: only the solution flag flipped. + assert 5 in tracks.graph_full.node_ids() + assert tracks.graph_full.has_edge(4, 5) + assert tracks.graph_full.nodes[5]["solution"] is False + assert _full_state(tracks) == full_before + + +@pytest.mark.parametrize("ndim", [3, 4]) +@pytest.mark.parametrize("with_seg", [True, False]) +def test_delete_undo_redo_roundtrip_identity(get_tracks, ndim, with_seg): + tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + sol_ref = _solution_state(tracks) + full_ref = _full_state(tracks) + pos_ref = _positions(tracks) + + UserDeleteNode(tracks, 5) + assert _solution_state(tracks) != sol_ref # actually changed + + # Undo restores the solution view exactly, and the full graph is untouched. + assert tracks.undo() + assert _solution_state(tracks) == sol_ref + assert _full_state(tracks) == full_ref + assert _positions(tracks) == pos_ref + + # Redo re-deletes; undo restores again — same states. + assert tracks.redo() + assert 5 not in tracks.graph_solution.node_ids() + assert tracks.undo() + assert _solution_state(tracks) == sol_ref + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_repeated_delete_undo_is_stable(get_tracks, ndim): + """Invariant #4: N undo/redo cycles must not drift the solution view or full graph.""" + tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + sol_ref = _solution_state(tracks) + iou_ref = tracks.get_edge_attr((4, 5), "iou") + + UserDeleteNode(tracks, 5) + deleted_sol = _solution_state(tracks) + full_after_delete = _full_state(tracks) + + for _ in range(5): + assert tracks.undo() + assert _solution_state(tracks) == sol_ref + # Revived edge keeps its computed attribute. + assert tracks.get_edge_attr((4, 5), "iou") == iou_ref + + assert tracks.redo() + assert _solution_state(tracks) == deleted_sol + # Full graph topology is identical every redo — no candidate accumulation. + assert _full_state(tracks) == full_after_delete + + # Leave it restored and confirm a clean final identity. + assert tracks.undo() + assert _solution_state(tracks) == sol_ref + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_mid_track_delete_leaves_skip_edge_candidate_in_full(get_tracks, ndim): + """Deleting a mid-track node adds a reconnection skip-edge (3->5). On undo it is + soft-deleted, so it persists in graph_full as a solution=False candidate while the + solution view returns to its original topology.""" + tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + sol_ref = _solution_state(tracks) + assert not tracks.graph_full.has_edge(3, 5) + + UserDeleteNode(tracks, 4) + assert tracks.graph_solution.has_edge(3, 5) + assert not tracks.graph_solution.has_edge(3, 4) + + assert tracks.undo() + # Solution view is back to the original topology (skip edge removed from the view) ... + assert _solution_state(tracks) == sol_ref + assert not tracks.graph_solution.has_edge(3, 5) + # ... but the skip edge now lives in graph_full as a candidate; node 4 is retained. + assert tracks.graph_full.has_edge(3, 5) + assert tracks.graph_full.edge_id(3, 5) is not None + assert 4 in tracks.graph_full.node_ids() From 48b537a8343a6016a90d4ba3b4c2be03b3ba4bfb Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Thu, 11 Jun 2026 12:47:26 -0400 Subject: [PATCH 06/39] stale warning --- src/funtracks/import_export/_v1_format.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/funtracks/import_export/_v1_format.py b/src/funtracks/import_export/_v1_format.py index d2ca64e8..430643d5 100644 --- a/src/funtracks/import_export/_v1_format.py +++ b/src/funtracks/import_export/_v1_format.py @@ -124,9 +124,7 @@ def load_v1_tracks( attrs.setdefault("lineage_attr", "lineage_id") with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="Provided both FeatureDict and pos_attr or time_attr" - ) + warnings.filterwarnings("ignore", message="Provided both FeatureDict and pos") tracks = Tracks(graph_td, **attrs) return tracks From 5ed7dad8eb0c9d135c7d172a401b774eba1098cc Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Thu, 11 Jun 2026 13:01:07 -0400 Subject: [PATCH 07/39] get_positions read wrong graph --- src/funtracks/data_model/tracks.py | 12 ++++++--- .../data_model/test_soft_delete_roundtrip.py | 27 +++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index f7b1eb4c..d777be47 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -221,10 +221,11 @@ def _get_feature_set( - List/tuple: multiple attributes, one per axis (e.g., ["y", "x"]) - None: defaults to "pos" tracklet_key: Graph attribute name for tracklet/track IDs - (e.g., "tracklet_id"). - If None, defaults to "tracklet_id" + (e.g., "tracklet_id"). Left None unless explicitly provided — a non-None + tracklet_key is the signal that this Tracks wants track ids (registers a + TrackAnnotator). No default is applied. lineage_key: Graph attribute name for lineage IDs (e.g., "lineage_id"). - if None, defaults to "lineage_id" + Left None unless explicitly provided (see tracklet_key). No default. Returns: FeatureDict initialized with time feature and position if no segmentation @@ -460,7 +461,10 @@ def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.nd + ([self.features.time_key] if incl_time else []) ) - df = self.graph_solution.node_attrs(attr_keys=attr_keys) + # Read from graph_full (consistent with get_position / the attr-helper policy): + # positions are intrinsic node attrs, so this also resolves soft-deleted + # (solution=False) nodes instead of KeyError-ing like a graph_solution query. + df = self.graph_full.node_attrs(attr_keys=attr_keys) id_to_row = { nid: i for i, nid in enumerate(df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list()) } diff --git a/tests/data_model/test_soft_delete_roundtrip.py b/tests/data_model/test_soft_delete_roundtrip.py index b9570563..691d16d2 100644 --- a/tests/data_model/test_soft_delete_roundtrip.py +++ b/tests/data_model/test_soft_delete_roundtrip.py @@ -134,3 +134,30 @@ def test_mid_track_delete_leaves_skip_edge_candidate_in_full(get_tracks, ndim): assert tracks.graph_full.has_edge(3, 5) assert tracks.graph_full.edge_id(3, 5) is not None assert 4 in tracks.graph_full.node_ids() + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_attr_reads_resolve_for_soft_deleted_node(get_tracks, ndim): + """Regression: attribute reads must resolve for soft-deleted (solution=False) nodes, + and the bulk `get_positions` must agree with the single-node `get_position`. The bulk + path previously queried `graph_solution` and KeyError'd on a soft-deleted node while + `get_position` (graph_full) succeeded — a latent inconsistency invisible to tests that + only ever query in-solution nodes. + """ + tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + pos_single_before = tracks.get_position(5) + pos_bulk_before = tracks.get_positions([5])[0].tolist() + assert pos_bulk_before == pytest.approx(pos_single_before) + + UserDeleteNode(tracks, 5) + assert 5 not in tracks.graph_solution.node_ids() # soft-deleted + + # Both single and bulk position reads still resolve (graph_full) and agree. + pos_single_after = tracks.get_position(5) + pos_bulk_after = tracks.get_positions([5])[0].tolist() + assert pos_single_after == pytest.approx(pos_single_before) + assert pos_bulk_after == pytest.approx(pos_single_before) + + # Other intrinsic attrs resolve too. + assert tracks.get_node_attr(5, "area") is not None + assert tracks.get_time(5) is not None From 6319dac460115735688218128cee89962cba7391 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jun 2026 17:12:04 +0000 Subject: [PATCH 08/39] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.15.17 → v0.15.18](https://github.com/astral-sh/ruff-pre-commit/compare/v0.15.17...v0.15.18) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7759b1e7..dca2cdd1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: check-yaml # checks for correct yaml syntax for github actions ex. args: [--unsafe] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.17 + rev: v0.15.18 hooks: - id: ruff args: [--fix] From cccac040f1a294f8e8c81ab26c677a3931f1f2ab Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 18 Jun 2026 11:39:43 -0400 Subject: [PATCH 09/39] Provide explicit node_name_map in test for subgroup export/import --- tests/import_export/test_import_from_geff.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index 3ec7bcbb..47a5b010 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -1124,8 +1124,20 @@ def test_subgroup_export_omits_featuredict_and_recomputes_on_import(get_tracks, "Subgroup export should not include a FeatureDict in GEFF metadata" ) - # Import should succeed — takes the auto-detect path and recomputes IDs - imported = import_from_geff(geff_path) + # Import should succeed and recompute IDs for the new subgraph topology. + # The source graph used "track_id" (not "tracklet_id") as the column name; + # the axes-based auto-inference identity-maps it, so we provide the correct + # mapping explicitly. + imported = import_from_geff( + geff_path, + node_name_map={ + "time": "t", + "pos": ["y", "x"], + "tracklet_id": "track_id", + "lineage_id": "lineage_id", + "solution": "solution", + }, + ) assert isinstance(imported, Tracks) assert imported.features.tracklet_key is not None From 8c5ebe8d4da7b4e2cdde05641ff0587a526a6bb4 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 24 Jun 2026 15:02:06 -0400 Subject: [PATCH 10/39] Refactor geff export to work for internal save as well --- docs/features.md | 2 +- src/funtracks/import_export/__init__.py | 3 +- src/funtracks/import_export/geff/_export.py | 128 ++++++++++++++------ 3 files changed, 93 insertions(+), 40 deletions(-) diff --git a/docs/features.md b/docs/features.md index 8022a26e..0faae01d 100644 --- a/docs/features.md +++ b/docs/features.md @@ -353,7 +353,7 @@ tracks.disable_features(["area"]) def compute(self, feature_keys=None): # Compute feature values in bulk if "custom" in self.features: - for node in self.tracks.graph.node_ids(): + for node in self.tracks.graph_solution.node_ids(): value = self._compute_custom(node) self.tracks[node]["custom"] = value diff --git a/src/funtracks/import_export/__init__.py b/src/funtracks/import_export/__init__.py index 28313e6a..587a2a73 100644 --- a/src/funtracks/import_export/__init__.py +++ b/src/funtracks/import_export/__init__.py @@ -2,7 +2,7 @@ from ._v1_format import load_v1_tracks from .csv._export import export_to_csv from .csv._import import CSVTracksBuilder, tracks_from_df -from .geff._export import export_to_geff +from .geff._export import export_to_geff, write_to_geff from .geff._import import GeffTracksBuilder, import_from_geff from .magic_imread import magic_imread @@ -14,6 +14,7 @@ "tracks_from_df", "export_to_csv", "export_to_geff", + "write_to_geff", "load_v1_tracks", "magic_imread", ] diff --git a/src/funtracks/import_export/geff/_export.py b/src/funtracks/import_export/geff/_export.py index 5ff7d6d6..6bd49161 100644 --- a/src/funtracks/import_export/geff/_export.py +++ b/src/funtracks/import_export/geff/_export.py @@ -21,6 +21,38 @@ from funtracks.data_model.tracks import Tracks +def write_to_geff( + tracks: Tracks, + path: Path, + overwrite: bool = False, + zarr_format: Literal[2, 3] = 2, +): + """Write tracks directly to a geff store at the given path. + + Unlike :func:`export_to_geff` (which creates a parent zarr container with + a ``tracks.geff`` subfolder and optional segmentation), this writes the + geff store directly to *path*. Intended for internal save/load workflows + where the user picks the ``.geff`` path. + + Args: + tracks: Tracks object containing a graph to save. + path: Destination path for the geff store. + overwrite: If True, overwrites an existing store at *path*. + zarr_format: Zarr format version to use. Defaults to 2. + """ + path = remove_tilde(path) + path = path.resolve(strict=False) + + graph, metadata = _build_geff_metadata(tracks, include_features=True) + graph.to_geff( + geff_store=path, + geff_metadata=metadata, + zarr_format=zarr_format, + overwrite=overwrite, + ) + _write_segmentation_shape(path, tracks) + + def export_to_geff( tracks: Tracks, directory: Path, @@ -64,6 +96,55 @@ def export_to_geff( mode: Literal["w", "w-"] = "w" if overwrite else "w-" setup_zarr_group(directory, zarr_format=zarr_format, mode=mode) + # Include the FeatureDict in metadata only for full exports. + # Subgroup exports do not necessarily have valid tracklet/lineage IDs + # and thus are not valid SolutionTracks + graph, metadata = _build_geff_metadata(tracks, include_features=(node_ids is None)) + + # Save segmentation if present and requested + if save_segmentation and tracks.segmentation is not None: + seg_name = "segmentation.tif" if seg_file_format == "tiff" else "segmentation" + # Tiff is saved next to (sibling of) the geff directory to avoid napari + # misidentifying it as zarr when the geff directory has a .zarr extension. + seg_parent = directory.parent if seg_file_format == "tiff" else directory + rel_prefix = "../.." if seg_file_format == "tiff" else ".." + export_segmentation( + tracks, + seg_parent / seg_name, + file_format=seg_file_format, + relabel=seg_relabel, + zarr_format=zarr_format, + node_ids=set(nodes_to_keep) if node_ids is not None else None, + ) + label_prop = resolve_relabel_attr(tracks, seg_relabel) or "node_id" + metadata.related_objects = [ + { + "path": f"{rel_prefix}/{seg_name}", + "type": "labels", + "label_prop": label_prop, + } + ] + + # Filter the graph if node_ids is provided + if node_ids is not None: + graph = graph.filter(node_ids=nodes_to_keep).subgraph() + + # Save the graph in a 'tracks.geff' folder + tracks_path = directory / "tracks.geff" + graph.to_geff(geff_store=tracks_path, geff_metadata=metadata, zarr_format=zarr_format) + + _write_segmentation_shape(tracks_path, tracks) + + +def _build_geff_metadata( + tracks: Tracks, + include_features: bool = True, +) -> tuple[td.graph.GraphView, GeffMetadata]: + """Build the geff metadata and prepare the graph for writing. + + Returns the (possibly modified) graph with split position attributes + and the GeffMetadata object. + """ # update the graph to split the position into separate attrs, if they are currently # together in a list graph, axis_names = split_position_attr(tracks) @@ -90,11 +171,8 @@ def export_to_geff( } ) - # Include the FeatureDict in metadata only for full exports. - # Subgroup exports do not necessarily have valid tracklet/lineage IDs are no - # and thus are not valid SolutionTracks extra: dict = {} - if node_ids is None: + if include_features: extra["funtracks"] = {"features": tracks.features.dump_json()} metadata = GeffMetadata( @@ -106,47 +184,21 @@ def export_to_geff( extra=extra, ) - # Save segmentation if present and requested - if save_segmentation and tracks.segmentation is not None: - seg_name = "segmentation.tif" if seg_file_format == "tiff" else "segmentation" - # Tiff is saved next to (sibling of) the geff directory to avoid napari - # misidentifying it as zarr when the geff directory has a .zarr extension. - seg_parent = directory.parent if seg_file_format == "tiff" else directory - rel_prefix = "../.." if seg_file_format == "tiff" else ".." - export_segmentation( - tracks, - seg_parent / seg_name, - file_format=seg_file_format, - relabel=seg_relabel, - zarr_format=zarr_format, - node_ids=set(nodes_to_keep) if node_ids is not None else None, - ) - label_prop = resolve_relabel_attr(tracks, seg_relabel) or "node_id" - metadata.related_objects = [ - { - "path": f"{rel_prefix}/{seg_name}", - "type": "labels", - "label_prop": label_prop, - } - ] + return graph, metadata - # Filter the graph if node_ids is provided - if node_ids is not None: - graph = graph.filter(node_ids=nodes_to_keep).subgraph() - # Save the graph in a 'tracks.geff' folder - tracks_path = directory / "tracks.geff" - graph.to_geff(geff_store=tracks_path, geff_metadata=metadata, zarr_format=zarr_format) +def _write_segmentation_shape(geff_path: Path, tracks: Tracks) -> None: + """Write segmentation_shape as an extra zarr attribute when masks are present. - # Write segmentation_shape as an extra zarr attribute when masks are present. - # GeffMetadata has no segmentation_shape field, so it must be stored separately. - # This allows import_from_geff to reconstruct the segmentation (GraphArrayView) - # without requiring an external segmentation file. + GeffMetadata has no segmentation_shape field, so it must be stored separately. + This allows import_from_geff to reconstruct the segmentation (GraphArrayView) + without requiring an external segmentation file. + """ seg_shape = tracks.graph_solution.metadata.get("segmentation_shape") if seg_shape is not None: import zarr as _zarr - z = _zarr.open(str(tracks_path), mode="a") + z = _zarr.open(str(geff_path), mode="a") attrs = dict(z.attrs) attrs["segmentation_shape"] = list(seg_shape) z.attrs.update(attrs) From 58cfb334716e558177d31d96277e336cf3d7a6c2 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 24 Jun 2026 15:22:44 -0400 Subject: [PATCH 11/39] Add write_to_geff tests --- tests/import_export/test_export_to_geff.py | 93 +++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py index 0b0307ec..543e18a8 100644 --- a/tests/import_export/test_export_to_geff.py +++ b/tests/import_export/test_export_to_geff.py @@ -5,7 +5,7 @@ import zarr from funtracks.data_model import Tracks -from funtracks.import_export import export_to_geff +from funtracks.import_export import export_to_geff, import_from_geff, write_to_geff def _assert_valid_geff_export(export_dir, expected_num_nodes=None): @@ -356,3 +356,94 @@ def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path): z = zarr.open((export_dir / "tracks.geff").as_posix(), mode="r") related = dict(z.attrs)["geff"].get("related_objects", []) assert any(obj["path"] == "../../segmentation.tif" for obj in related) + + +# --- write_to_geff --- + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_write_to_geff_roundtrip(get_tracks, ndim, tmp_path): + """write_to_geff then import_from_geff recovers the tracks.""" + tracks = get_tracks(ndim=ndim, with_seg=False, is_solution=True) + + geff_path = tmp_path / "my_tracks.geff" + write_to_geff(tracks, geff_path) + + loaded = import_from_geff(geff_path) + + assert loaded.graph_solution.num_nodes() == tracks.graph_solution.num_nodes() + assert loaded.graph_solution.num_edges() == tracks.graph_solution.num_edges() + assert set(loaded.graph_solution.node_ids()) == set(tracks.graph_solution.node_ids()) + assert loaded.features.dump_json() == tracks.features.dump_json() + + +def test_write_to_geff_no_parent_container(get_tracks, tmp_path): + """write_to_geff writes directly to the path — no parent .zgroup or + tracks.geff subfolder.""" + tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + + geff_path = tmp_path / "my_tracks.geff" + write_to_geff(tracks, geff_path) + + # The store is at the given path, not nested inside it + z = zarr.open(str(geff_path), mode="r") + assert "geff" in dict(z.attrs) + assert "nodes" in z + + # No .zgroup/.zattrs at the parent level (tmp_path is not a zarr group) + assert not (tmp_path / ".zgroup").exists() + assert not (tmp_path / ".zattrs").exists() + + +def test_write_to_geff_overwrite(get_tracks, tmp_path): + """write_to_geff with overwrite=True replaces existing store.""" + tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + + geff_path = tmp_path / "my_tracks.geff" + write_to_geff(tracks, geff_path) + write_to_geff(tracks, geff_path, overwrite=True) + + loaded = import_from_geff(geff_path) + assert loaded.graph_solution.num_nodes() == tracks.graph_solution.num_nodes() + + +def test_write_to_geff_no_overwrite_raises(get_tracks, tmp_path): + """write_to_geff with overwrite=False raises on existing store.""" + tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + + geff_path = tmp_path / "my_tracks.geff" + write_to_geff(tracks, geff_path) + + with pytest.raises(FileExistsError): + write_to_geff(tracks, geff_path, overwrite=False) + + +def test_write_to_geff_metadata(get_tracks, tmp_path): + """write_to_geff stores axes and FeatureDict metadata.""" + tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + + geff_path = tmp_path / "my_tracks.geff" + write_to_geff(tracks, geff_path) + + z = zarr.open(str(geff_path), mode="r") + attrs = dict(z.attrs) + + axes = attrs["geff"]["axes"] + assert len(axes) == 3 + assert [ax["type"] for ax in axes] == ["time", "space", "space"] + + assert "funtracks" in attrs["geff"].get("extra", {}) + assert "features" in attrs["geff"]["extra"]["funtracks"] + + +def test_write_to_geff_segmentation_shape(get_tracks, tmp_path): + """write_to_geff writes segmentation_shape when masks are present.""" + tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + + geff_path = tmp_path / "my_tracks.geff" + write_to_geff(tracks, geff_path) + + z = zarr.open(str(geff_path), mode="r") + attrs = dict(z.attrs) + assert "segmentation_shape" in attrs + assert tuple(attrs["segmentation_shape"]) == tracks.segmentation.shape From 99c9ddd60a20eee67c23d22b4d674cdca98ead7b Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Thu, 25 Jun 2026 10:30:02 -0700 Subject: [PATCH 12/39] add write_to_geff test with segmentation + removed stale if-statement --- src/funtracks/import_export/geff/_export.py | 11 +++++------ tests/import_export/test_export_to_geff.py | 12 ++++++++++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/funtracks/import_export/geff/_export.py b/src/funtracks/import_export/geff/_export.py index 6bd49161..f1c33ae2 100644 --- a/src/funtracks/import_export/geff/_export.py +++ b/src/funtracks/import_export/geff/_export.py @@ -151,12 +151,11 @@ def _build_geff_metadata( if axis_names is None: axis_names = [] axis_names.insert(0, tracks.features.time_key) - if axis_names is not None: - axis_types = ( - ["time", "space", "space"] - if tracks.ndim == 3 - else ["time", "space", "space", "space"] - ) + axis_types = ( + ["time", "space", "space"] + if tracks.ndim == 3 + else ["time", "space", "space", "space"] + ) if tracks.scale is None: tracks.scale = (1.0,) * tracks.ndim diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py index 543e18a8..61f3aa19 100644 --- a/tests/import_export/test_export_to_geff.py +++ b/tests/import_export/test_export_to_geff.py @@ -362,9 +362,10 @@ def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path): @pytest.mark.parametrize("ndim", [3, 4]) -def test_write_to_geff_roundtrip(get_tracks, ndim, tmp_path): +@pytest.mark.parametrize("with_seg", [False, True]) +def test_write_to_geff_roundtrip(get_tracks, ndim, with_seg, tmp_path): """write_to_geff then import_from_geff recovers the tracks.""" - tracks = get_tracks(ndim=ndim, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) @@ -376,6 +377,13 @@ def test_write_to_geff_roundtrip(get_tracks, ndim, tmp_path): assert set(loaded.graph_solution.node_ids()) == set(tracks.graph_solution.node_ids()) assert loaded.features.dump_json() == tracks.features.dump_json() + if with_seg: + assert loaded.segmentation is not None + assert loaded.segmentation.shape == tracks.segmentation.shape + np.testing.assert_array_equal( + np.asarray(loaded.segmentation[:]), np.asarray(tracks.segmentation[:]) + ) + def test_write_to_geff_no_parent_container(get_tracks, tmp_path): """write_to_geff writes directly to the path — no parent .zgroup or From c490b2b9582d2a87839c62d4da3df00257d52b66 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Fri, 26 Jun 2026 10:09:49 -0700 Subject: [PATCH 13/39] add solutiontracks benchmark --- tests/benchmarks/bench_candidate_graph.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/benchmarks/bench_candidate_graph.py b/tests/benchmarks/bench_candidate_graph.py index 0f7ee293..7d9d63db 100644 --- a/tests/benchmarks/bench_candidate_graph.py +++ b/tests/benchmarks/bench_candidate_graph.py @@ -9,6 +9,7 @@ from skimage.draw import disk from funtracks.candidate_graph.compute_graph import compute_graph_from_seg +from funtracks.data_model import SolutionTracks NUM_FRAMES = 50 FRAME_SHAPE = (700, 1100) @@ -50,3 +51,20 @@ def test_compute_graph_from_seg(benchmark, seg_data): rounds=1, iterations=1, ) + + +def test_graph_to_solution(benchmark, seg_data): + """Benchmark candidate graph -> SolutionTracks (tracklet/lineage assignment). + + Candidate-graph construction is benchmarked separately above and is built here in + (untimed) setup, so only the SolutionTracks construction -- dominated by + TrackAnnotator._assign_tracklet_ids and _assign_lineage_ids -- is measured. + """ + + def setup(): + # Fresh graph per round: SolutionTracks construction mutates the graph + # (adds tracklet/lineage IDs), so each measured call must start unannotated. + graph = compute_graph_from_seg(seg_data, MAX_EDGE_DISTANCE, iou=True) + return (graph,), {"time_attr": "t", "pos_attr": "pos"} + + benchmark.pedantic(SolutionTracks, setup=setup, rounds=1, iterations=1) From 7f3d82e0579a6bb81eb34bf116669e6b3783d647 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Fri, 26 Jun 2026 10:42:31 -0700 Subject: [PATCH 14/39] speed up assign_lineage_ids and assign_tracklet_ids --- src/funtracks/annotators/_track_annotator.py | 49 ++++++++------------ 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/src/funtracks/annotators/_track_annotator.py b/src/funtracks/annotators/_track_annotator.py index e728d4a6..5d0d4ce9 100644 --- a/src/funtracks/annotators/_track_annotator.py +++ b/src/funtracks/annotators/_track_annotator.py @@ -202,13 +202,13 @@ def _assign_lineage_ids(self) -> None: lineages_internal = rx.weakly_connected_components( self.tracks.graph_solution.rx_graph ) - lineages_external = [] - for lin in lineages_internal: - node_ids_internal = list(lin) - node_ids_external = [ - self.tracks.graph_solution.node_ids()[nid] for nid in node_ids_internal - ] - lineages_external.append(node_ids_external) + # Map each component's internal node indices to external ids in one batched + # call. node_ids() rebuilds the full external-id list on every call, so calling + # it per node (as before) was O(N^2). + lineages_external = [ + self.tracks.graph_solution._map_to_external(list(lin)) + for lin in lineages_internal + ] max_id, ids_to_nodes = self._assign_ids(lineages_external, self.lineage_key) self.max_lineage_id = max_id @@ -220,32 +220,23 @@ def _assign_tracklet_ids(self) -> None: After removing division edges, each connected component will get a unique ID, and the relevant class attributes will be updated. """ - graph_copy = self.tracks.graph_solution.detach().filter().subgraph() - - parents = [ - node - for node, degree in zip( - self.tracks.graph_solution.node_ids(), - self.tracks.graph_solution.out_degree(), - strict=True, - ) - if degree >= 2 - ] - - # Remove all intertrack edges from a copy of the original graph - for parent in parents: - all_edges = self.tracks.graph_solution.edge_list() - daughters = [edge[1] for edge in all_edges if edge[0] == parent] - - for daughter in daughters: - graph_copy.remove_edge(parent, daughter) + # Work on a plain copy of the underlying rustworkx graph and strip the + # intertrack (division) edges there. Removing edges through the GraphView is + # slow because each remove_edge syncs the view's bidirectional edge maps, + # whereas rustworkx edge removal is in-memory. copy() preserves node indices, + # so components map back through the original graph's id mapping. + rx_copy = self.tracks.graph_solution.rx_graph.copy() + for node in rx_copy.node_indices(): + if rx_copy.out_degree(node) >= 2: + for _, daughter, _ in list(rx_copy.out_edges(node)): + rx_copy.remove_edge(node, daughter) track_id = 1 all_node_ids = [] all_track_ids = [] - for tracklet in rx.weakly_connected_components(graph_copy.rx_graph): - node_ids_internal = list(tracklet) - node_ids_external = [graph_copy.node_ids()[nid] for nid in node_ids_internal] + for tracklet in rx.weakly_connected_components(rx_copy): + # Batched internal -> external mapping (see _assign_lineage_ids). + node_ids_external = self.tracks.graph_solution._map_to_external(list(tracklet)) all_node_ids.extend(node_ids_external) all_track_ids.extend([track_id] * len(node_ids_external)) self.tracklet_id_to_nodes[track_id] = node_ids_external From 576a739e8ca189bb0e5e2c6298909f2d02263316 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Mon, 29 Jun 2026 17:10:43 -0700 Subject: [PATCH 15/39] replace solutiontracks in benchmark --- tests/benchmarks/bench_candidate_graph.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/benchmarks/bench_candidate_graph.py b/tests/benchmarks/bench_candidate_graph.py index 7d9d63db..f74f65a2 100644 --- a/tests/benchmarks/bench_candidate_graph.py +++ b/tests/benchmarks/bench_candidate_graph.py @@ -9,7 +9,7 @@ from skimage.draw import disk from funtracks.candidate_graph.compute_graph import compute_graph_from_seg -from funtracks.data_model import SolutionTracks +from funtracks.data_model import Tracks NUM_FRAMES = 50 FRAME_SHAPE = (700, 1100) @@ -54,17 +54,24 @@ def test_compute_graph_from_seg(benchmark, seg_data): def test_graph_to_solution(benchmark, seg_data): - """Benchmark candidate graph -> SolutionTracks (tracklet/lineage assignment). + """Benchmark candidate graph -> Tracks with track ids (tracklet/lineage assignment). Candidate-graph construction is benchmarked separately above and is built here in - (untimed) setup, so only the SolutionTracks construction -- dominated by + (untimed) setup, so only the Tracks construction -- dominated by TrackAnnotator._assign_tracklet_ids and _assign_lineage_ids -- is measured. + Setting tracklet_attr/lineage_attr registers a TrackAnnotator and triggers the + track-id computation on construction. """ def setup(): - # Fresh graph per round: SolutionTracks construction mutates the graph + # Fresh graph per round: Tracks construction mutates the graph # (adds tracklet/lineage IDs), so each measured call must start unannotated. graph = compute_graph_from_seg(seg_data, MAX_EDGE_DISTANCE, iou=True) - return (graph,), {"time_attr": "t", "pos_attr": "pos"} + return (graph,), { + "time_attr": "t", + "pos_attr": "pos", + "tracklet_attr": "tracklet_id", + "lineage_attr": "lineage_id", + } - benchmark.pedantic(SolutionTracks, setup=setup, rounds=1, iterations=1) + benchmark.pedantic(Tracks, setup=setup, rounds=1, iterations=1) From 15425a2589068187b8c2872d864a81c66de925c3 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 11:11:50 -0700 Subject: [PATCH 16/39] replace graph_solution with graph_full + removed in_degree and out_degree + graph_annotator default is graph_full --- src/funtracks/annotators/_edge_annotator.py | 6 ----- src/funtracks/annotators/_graph_annotator.py | 19 +++++++------- .../annotators/_regionprops_annotator.py | 5 ---- src/funtracks/annotators/_track_annotator.py | 7 ++++++ src/funtracks/data_model/tracks.py | 25 ------------------- src/funtracks/user_actions/user_add_edge.py | 4 +-- src/funtracks/user_actions/user_add_node.py | 6 ++--- .../user_actions/user_delete_edge.py | 2 +- .../user_actions/user_update_segmentation.py | 16 +++++++++++- src/funtracks/utils/_segmentation_utils.py | 2 +- tests/data_model/test_tracks.py | 12 --------- tests/import_export/test_csv_import.py | 10 +++----- 12 files changed, 42 insertions(+), 72 deletions(-) diff --git a/src/funtracks/annotators/_edge_annotator.py b/src/funtracks/annotators/_edge_annotator.py index 4c35811d..ee71b95d 100644 --- a/src/funtracks/annotators/_edge_annotator.py +++ b/src/funtracks/annotators/_edge_annotator.py @@ -41,12 +41,6 @@ def can_annotate(cls, tracks) -> bool: """ return tracks.segmentation is not None - @property - def graph(self): - """IoU is an intrinsic link feature → computed on the full graph (all edges, - including soft-deleted/candidate ones, so they stay ready for re-solving).""" - return self.tracks.graph_full - @classmethod def get_available_features(cls, ndim: int = 3) -> dict[str, Feature]: """Get all features that can be computed by this annotator. diff --git a/src/funtracks/annotators/_graph_annotator.py b/src/funtracks/annotators/_graph_annotator.py index f7fca8b6..10b3e99c 100644 --- a/src/funtracks/annotators/_graph_annotator.py +++ b/src/funtracks/annotators/_graph_annotator.py @@ -56,17 +56,16 @@ def __init__(self, tracks: Tracks, features: dict[str, Feature]): def graph(self): """The graph this annotator iterates over and reads topology/masks from. - Defaults to the solution view. Detection-feature annotators - (`RegionpropsAnnotator`, `EdgeAnnotator`) override this to return - `graph_full`, so intrinsic features (`pos`, `area`, `iou`, ...) are computed - for *every* node/edge — including soft-deleted (`solution=False`) candidates — - keeping the full and solution graphs in sync and ready for re-solving. Track-id - features (`tracklet_id`, `lineage_id`) are solution-only, so `TrackAnnotator` - keeps the default. Note that attribute *writes* go through the `tracks` helpers, - which target `graph_full`; because attr dicts are shared by reference, in-solution - nodes see those writes through the view automatically. + Defaults to the full graph. Detection features (`pos`, `area`, `iou`, ...) are + intrinsic to a node/edge — independent of solution membership — so they are + computed for *every* node/edge, including soft-deleted (`solution=False`) + candidates, keeping them valid for revive and ready for re-solving. + `TrackAnnotator` overrides this to `graph_solution`, since track ids + (`tracklet_id`, `lineage_id`) are derived from the solution topology. Root and + view share attribute storage, so writes made here are seen through the solution + view automatically. """ - return self.tracks.graph_solution + return self.tracks.graph_full def activate_features(self, keys: list[str]) -> None: """Activate computation of the given features in the annotation process. diff --git a/src/funtracks/annotators/_regionprops_annotator.py b/src/funtracks/annotators/_regionprops_annotator.py index 9fd58d83..e35295a9 100644 --- a/src/funtracks/annotators/_regionprops_annotator.py +++ b/src/funtracks/annotators/_regionprops_annotator.py @@ -73,11 +73,6 @@ def can_annotate(cls, tracks) -> bool: """ return tracks.segmentation is not None - @property - def graph(self): - """Regionprops features are intrinsic detections → computed on the full graph.""" - return self.tracks.graph_full - def __init__( self, tracks: Tracks, diff --git a/src/funtracks/annotators/_track_annotator.py b/src/funtracks/annotators/_track_annotator.py index cfc598c4..d0805b7a 100644 --- a/src/funtracks/annotators/_track_annotator.py +++ b/src/funtracks/annotators/_track_annotator.py @@ -67,6 +67,13 @@ def can_annotate(cls, tracks) -> bool: """ return tracks.features.tracklet_key is not None + @property + def graph(self): + """Track ids (`tracklet_id`, `lineage_id`) are derived from the solution + topology, so this annotator reads/iterates the solution view, not the full + graph (overriding the base default of `graph_full`).""" + return self.tracks.graph_solution + @classmethod def get_available_features(cls, ndim: int = 3) -> dict[str, Feature]: """Get all features that can be computed by this annotator. diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index d777be47..ff0eec6e 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -400,31 +400,6 @@ def nodes(self): def edges(self): return np.array(self.graph_solution.edge_ids()) - def in_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: - """Get the in-degree edge_ids of the nodes in the graph.""" - if nodes is not None: - # make sure nodes is a numpy array - if not isinstance(nodes, np.ndarray): - nodes = np.array(nodes) - - return np.array( - [self.graph_solution.in_degree(node.item()) for node in nodes] - ) - else: - return np.array(self.graph_solution.in_degree()) - - def out_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: - if nodes is not None: - # make sure nodes is a numpy array - if not isinstance(nodes, np.ndarray): - nodes = np.array(nodes) - - return np.array( - [self.graph_solution.out_degree(node.item()) for node in nodes] - ) - else: - return np.array(self.graph_solution.out_degree()) - def predecessors(self, node: int) -> list[int]: return list(self.graph_solution.predecessors(node)) diff --git a/src/funtracks/user_actions/user_add_edge.py b/src/funtracks/user_actions/user_add_edge.py index 5075f1c3..caf8f3c8 100644 --- a/src/funtracks/user_actions/user_add_edge.py +++ b/src/funtracks/user_actions/user_add_edge.py @@ -48,7 +48,7 @@ def __init__( # Check if making a merge. If yes and force, remove the other edge and update # track ids. - in_degree_target = self.tracks.graph_solution.in_degree(target) + in_degree_target = len(self.tracks.predecessors(target)) # type: ignore if in_degree_target > 0: if not force: raise InvalidActionError( @@ -68,7 +68,7 @@ def __init__( ) # update track ids if needed - out_degree_source = self.tracks.graph_solution.out_degree(source) + out_degree_source = len(self.tracks.successors(source)) if out_degree_source == 0: # joining two segments # assign the track id and lineage id of the source node to the target # and all downstream nodes diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py index 902ff4b8..dbb2940c 100644 --- a/src/funtracks/user_actions/user_add_node.py +++ b/src/funtracks/user_actions/user_add_node.py @@ -77,7 +77,7 @@ def __init__( raise InvalidActionError( f"Cannot add node without track id. Please add {track_id_key} attribute" ) - if self.tracks.graph_solution.has_node(node): + if self.tracks.graph_full.has_node(node): raise InvalidActionError( f"Node {node} already exists in the tracks, cannot add." ) @@ -98,7 +98,7 @@ def __init__( pred, succ = self.tracks.get_track_neighbors(track_id, time) # check if you are adding a node to a track that divided previously - if pred is not None and self.tracks.graph_solution.out_degree(int(pred)) == 2: + if pred is not None and len(self.tracks.successors(pred)) == 2: if not force: raise InvalidActionError( "Cannot add node here - upstream division event detected.", @@ -122,7 +122,7 @@ def __init__( pred_of_succ = preds[0] if preds else None if ( pred_of_succ is not None - and self.tracks.graph_solution.out_degree(pred_of_succ) == 2 + and len(self.tracks.successors(pred_of_succ)) == 2 ): if not force: raise InvalidActionError( diff --git a/src/funtracks/user_actions/user_delete_edge.py b/src/funtracks/user_actions/user_delete_edge.py index 1a11fed6..098a73ac 100644 --- a/src/funtracks/user_actions/user_delete_edge.py +++ b/src/funtracks/user_actions/user_delete_edge.py @@ -36,7 +36,7 @@ def __init__( raise InvalidActionError(f"Edge {edge} not in solution, can't remove") self.actions.append(DeleteEdge(tracks, edge)) - out_degree = self.tracks.graph_solution.out_degree(edge[0]) + out_degree = len(self.tracks.successors(edge[0])) if out_degree == 0: # removed a normal (non division) edge # orphaned segment gets new track id and new lineage id new_track_id = self.tracks.get_next_track_id() diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py index 3c355b3d..c750ecbb 100644 --- a/src/funtracks/user_actions/user_update_segmentation.py +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -62,7 +62,21 @@ def __init__( "Can only update one time point at a time" ) time = int(all_pixels[0][0]) - if self.tracks.graph_solution.has_node(new_value): + if self.tracks.graph_full.has_node(new_value): + # An id that already names a node must take the update path: you cannot + # create a *new* node on a taken id. Ids are globally unique across + # graph_full (full ⊇ solution, never reused), so graph_full is the + # correct "does this id name a node?" check; using graph_solution here + # would misroute a soft-deleted id to the add path and silently revive + # the old node with stale attributes. + # NOTE: reviving a soft-deleted (solution=False) node by painting it + # back is not supported yet — UpdateNodeSeg reads the solution view and + # would fail. Guard explicitly until revive-by-paint is implemented. + if not self.tracks.graph_solution.has_node(new_value): + raise NotImplementedError( + f"Cannot paint onto node {new_value}: it is soft-deleted (not " + "in the solution). Revive-by-paint is not supported yet." + ) mask_pixels = pixels_to_td_mask(all_pixels, self.tracks.ndim) self.actions.append( UpdateNodeSeg(tracks, new_value, mask_pixels, added=True) diff --git a/src/funtracks/utils/_segmentation_utils.py b/src/funtracks/utils/_segmentation_utils.py index 73edf470..4c13e443 100644 --- a/src/funtracks/utils/_segmentation_utils.py +++ b/src/funtracks/utils/_segmentation_utils.py @@ -25,7 +25,7 @@ def relabel_segmentation_with_track_id( # Division nodes have out_degree > 1; their outgoing edges are cut so that # each daughter cell starts a new tracklet division_nodes = { - n for n in solution_graph.node_ids() if solution_graph.out_degree(n) > 1 + n for n in solution_graph.node_ids() if len(solution_graph.successors(n)) > 1 } visited: set = set() diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index bd0eaa65..18828272 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -104,18 +104,6 @@ def test_nodes_edges(graph_2d_with_segmentation): } -def test_degrees(graph_2d_with_segmentation): - tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs) - assert tracks.in_degree(np.array([1])) == 0 - assert tracks.in_degree(np.array([4])) == 1 - assert np.array_equal(tracks.in_degree(None), np.array([0, 1, 1, 1, 1, 0])) - assert np.array_equal(tracks.out_degree(np.array([1, 4])), np.array([2, 1])) - assert np.array_equal( - tracks.out_degree(None), - np.array([2, 0, 1, 1, 0, 0]), - ) - - def test_predecessors_successors(graph_2d_with_segmentation): tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs) assert tracks.predecessors(2) == [1] diff --git a/tests/import_export/test_csv_import.py b/tests/import_export/test_csv_import.py index 3add3c99..18fcd66f 100644 --- a/tests/import_export/test_csv_import.py +++ b/tests/import_export/test_csv_import.py @@ -186,7 +186,7 @@ def test_multiple_roots(self): roots = [ n for n in tracks.graph_solution.node_ids() - if tracks.graph_solution.in_degree(n) == 0 + if len(tracks.predecessors(n)) == 0 ] assert len(roots) == 2 @@ -260,18 +260,16 @@ def test_long_track(self): roots = [ n for n in tracks.graph_solution.node_ids() - if tracks.graph_solution.in_degree(n) == 0 + if len(tracks.predecessors(n)) == 0 ] assert len(roots) == 1 # Each non-leaf node should have exactly one child non_leaves = [ - n - for n in tracks.graph_solution.node_ids() - if tracks.graph_solution.out_degree(n) > 0 + n for n in tracks.graph_solution.node_ids() if len(tracks.successors(n)) > 0 ] for node in non_leaves: - assert tracks.graph_solution.out_degree(node) == 1 + assert len(tracks.successors(node)) == 1 def test_orphaned_node_raises_error(self): """Test that node with invalid parent_id raises error.""" From 9942055b4d8d9cd96c366a1dac94aa3170ad4cda Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 12:57:49 -0700 Subject: [PATCH 17/39] update tracksdata pin --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0f0dadaa..392125b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies =[ "dask>=2025.5.0", "pandas>=2.3.3", "zarr>=2.18,<4", - "tracksdata[spatial]==0.1.0rc4", + "tracksdata[spatial]>=0.1.0rc5", "tqdm>=4.66.1", # zarr 2.x's util.py imports cbuffer_sizes/cbuffer_metainfo from # numcodecs.blosc, which numcodecs >= 0.16 removed. Pin numcodecs per @@ -126,4 +126,4 @@ exclude_also = [ #remove this after tracksdata PR is merged: [tool.uv.sources] -tracksdata = { git = "https://github.com/TeunHuijben/tracksdata.git", rev = "040ed9f03c0f64c587355e9fb6c7f9f08d09855c" } +tracksdata = { git = "https://github.com/TeunHuijben/tracksdata.git", rev = "06465dcabb772622aed2ae66417c0b857b4ed35c" } From c4bf231dbcd3a2f104418163085fd305080de145 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 13:35:45 -0700 Subject: [PATCH 18/39] make sure that tracks always has track_annotator and tracklet_key --- src/funtracks/annotators/_track_annotator.py | 40 ++---- src/funtracks/data_model/tracks.py | 119 +++++++++--------- src/funtracks/import_export/csv/_export.py | 2 +- .../user_actions/user_update_segmentation.py | 2 - tests/annotators/test_annotator_registry.py | 10 +- tests/annotators/test_track_annotator.py | 12 +- tests/import_export/test_export_to_geff.py | 11 -- 7 files changed, 83 insertions(+), 113 deletions(-) diff --git a/src/funtracks/annotators/_track_annotator.py b/src/funtracks/annotators/_track_annotator.py index d0805b7a..5ab86296 100644 --- a/src/funtracks/annotators/_track_annotator.py +++ b/src/funtracks/annotators/_track_annotator.py @@ -25,9 +25,12 @@ class TrackAnnotator(GraphAnnotator): - """A graph annotator to compute tracklet and lineage IDs for Tracks only. + """A graph annotator that computes tracklet and lineage IDs on the solution view. - Currently, updating the tracklet and lineage IDs is left to Actions. + Registered on every Tracks — track ids are a core feature, not a separate "type" + of tracks. It reads/iterates `graph_solution`; on an empty solution view it simply + computes nothing until nodes are added. Updating the ids after construction is left + to Actions. Attributes: tracklet_id_to_nodes (dict[int, list[int]]): A mapping from tracklet ids to @@ -38,35 +41,14 @@ class TrackAnnotator(GraphAnnotator): max_lineage_id (int): the maximum lineage id used in the tracks Args: - tracks (Tracks): The tracks to be annotated. Must be a solution. - tracklet_key (str | None, optional): A key that already holds the tracklet ids - on the graph. If provided, must be there for every node and already hold - valid tracklet ids. Defaults to None. - lineage_key (str | None, optional): A key that already holds the lineage ids - on the graph. If provided, must be there for every node and already hold - valid lineage ids. Defaults to None. - - - Raises: - ValueError: if the provided Tracks are not Tracks (not a binary lineage - tree) + tracks (Tracks): The tracks to annotate. + tracklet_key (str | None, optional): The node attribute holding tracklet ids. + If the graph already holds valid ids under this key they are reused; + otherwise they are computed. Defaults to DEFAULT_TRACKLET_KEY. + lineage_key (str | None, optional): The node attribute holding lineage ids. + Same semantics as tracklet_key. Defaults to DEFAULT_LINEAGE_KEY. """ - @classmethod - def can_annotate(cls, tracks) -> bool: - """Check if this annotator can annotate the given tracks. - - Track ids are only meaningful when the tracks declares a tracklet_key (i.e. - it represents a solution). A None tracklet_key means a plain candidate graph. - - Args: - tracks: The tracks to check compatibility with - - Returns: - True if tracks.features.tracklet_key is set, False otherwise - """ - return tracks.features.tracklet_key is not None - @property def graph(self): """Track ids (`tracklet_id`, `lineage_id`) are derived from the solution diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index ff0eec6e..e88fe857 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -17,6 +17,7 @@ from tracksdata.nodes import Mask from funtracks.actions.action_history import ActionHistory +from funtracks.annotators import TrackAnnotator from funtracks.features import ( Feature, FeatureDict, @@ -90,11 +91,11 @@ def __init__( - List/tuple of strings for multi-axis (one attribute per axis) Defaults to "pos" if None. tracklet_attr (str | None): Graph attribute name for tracklet/track IDs. - If set (non-None), a TrackAnnotator is registered and track ids are - computed/maintained. If None, this is a plain (candidate) Tracks with - no track ids. - lineage_attr (str | None): Graph attribute name for lineage IDs. Only used - when tracklet_attr is set. + Defaults to "tracklet_id" if None. Every Tracks gets a TrackAnnotator + and track ids (no "plain" vs "solution" distinction); existing ids on + the graph are reused, otherwise they are computed. + lineage_attr (str | None): Graph attribute name for lineage IDs. + Defaults to "lineage_id" if None. scale (list[float] | None): Scaling factors for each dimension (including time). If None, all dimensions scaled by 1.0. ndim (int | None): Number of dimensions (3 for 2D+time, 4 for 3D+time). @@ -190,6 +191,11 @@ def __init__( else: self._setup_core_computed_features() + # 4. Enforce the track-id invariant on BOTH paths: every Tracks has a tracklet + # key and a registered TrackAnnotator, with tracklet_id/lineage_id registered + # and computed. A provided FeatureDict that omitted them is completed here. + self._ensure_track_features() + @property def graph_full(self) -> td.graph.GraphView: """The full graph: every node/edge ever known, including soft-deleted @@ -220,12 +226,10 @@ def _get_feature_set( - Single string: one attribute containing position array (e.g., "pos") - List/tuple: multiple attributes, one per axis (e.g., ["y", "x"]) - None: defaults to "pos" - tracklet_key: Graph attribute name for tracklet/track IDs - (e.g., "tracklet_id"). Left None unless explicitly provided — a non-None - tracklet_key is the signal that this Tracks wants track ids (registers a - TrackAnnotator). No default is applied. - lineage_key: Graph attribute name for lineage IDs (e.g., "lineage_id"). - Left None unless explicitly provided (see tracklet_key). No default. + tracklet_key: Graph attribute name for tracklet/track IDs. + Defaults to "tracklet_id" if None (every Tracks gets track ids). + lineage_key: Graph attribute name for lineage IDs. + Defaults to "lineage_id" if None. Returns: FeatureDict initialized with time feature and position if no segmentation @@ -234,9 +238,11 @@ def _get_feature_set( time_key = time_attr if time_attr is not None else "time" if pos_attr is None: pos_attr = "pos" - # tracklet_key / lineage_key are left None unless explicitly provided: a - # non-None tracklet_key is the signal that this Tracks wants track ids, which - # is what triggers TrackAnnotator registration (see TrackAnnotator.can_annotate). + # Every Tracks has a tracklet/lineage key (no "plain" vs "solution" split): + # default them like time/pos. _ensure_track_features() then registers and + # computes the ids; on an empty solution view that is a no-op. + tracklet_key = tracklet_key if tracklet_key is not None else "tracklet_id" + lineage_key = lineage_key if lineage_key is not None else "lineage_id" # Build static features dict - always include time features: dict[str, Feature] = {time_key: Time()} @@ -304,7 +310,6 @@ def _get_annotators(self) -> AnnotatorRegistry: AnnotatorRegistry, EdgeAnnotator, RegionpropsAnnotator, - TrackAnnotator, ) annotator_list: list[GraphAnnotator] = [] @@ -361,15 +366,14 @@ def _check_existing_feature(self, key: str) -> bool: return key in node_attrs def _setup_core_computed_features(self) -> None: - """Sets up core computed features (position, tracklet, lineage). + """Sets up core computed position features. - Registers position/tracklet/lineage features from annotators into - FeatureDict. For each core feature: - - Activates any features that are detected to already exist on the graph - - Enables (computes) any features that don't exist yet + Registers the position feature from the RegionpropsAnnotator into the + FeatureDict, activating it if it already exists on the graph or computing it + otherwise. Track-id features are handled separately by _ensure_track_features. """ # Import here to avoid circular dependency - from funtracks.annotators import RegionpropsAnnotator, TrackAnnotator + from funtracks.annotators import RegionpropsAnnotator core_features: list[str] = [] for annotator in self.annotators: @@ -378,14 +382,12 @@ def _setup_core_computed_features(self) -> None: if self.features.position_key is None: self.features.position_key = pos_key core_features.append(pos_key) - elif isinstance(annotator, TrackAnnotator): - tracklet_key = annotator.tracklet_key - self.features.tracklet_key = tracklet_key - core_features.append(tracklet_key) - lineage_key = annotator.lineage_key - self.features.lineage_key = lineage_key - core_features.append(lineage_key) - for key in core_features: + self._register_core_features(core_features) + + def _register_core_features(self, keys: list[str]) -> None: + """Register each key as a feature: activate it if it already exists on the + graph, otherwise enable (compute) it.""" + for key in keys: if self._check_existing_feature(key): if key not in self.features: feature, _ = self.annotators.all_features[key] @@ -394,6 +396,22 @@ def _setup_core_computed_features(self) -> None: else: self.enable_features([key]) + def _ensure_track_features(self) -> None: + """Ensure the track-id core features exist on this Tracks. + + Every Tracks has a registered TrackAnnotator and a tracklet key (no "plain" + vs "solution" split). This syncs features.tracklet_key/lineage_key from the + annotator and registers + computes (or activates, if already present) + tracklet_id/lineage_id. Runs on both the provided-FeatureDict and the + auto-detect construction paths; a no-op on an empty solution view. + """ + annotator = self.track_annotator + if annotator is None: + return + self.features.tracklet_key = annotator.tracklet_key + self.features.lineage_key = annotator.lineage_key + self._register_core_features([annotator.tracklet_key, annotator.lineage_key]) + def nodes(self): return np.array(self.graph_solution.node_ids()) @@ -826,29 +844,19 @@ def delete_feature(self, key: str) -> None: self.graph_solution.remove_edge_attr_key(key) # ========== Track ID management (solution view) ========== - # These operate on the solution view via the TrackAnnotator. They are only - # meaningful when this Tracks has track ids (i.e. a tracklet_key is set and a - # TrackAnnotator is registered). + # These operate on the solution view via the TrackAnnotator, which every Tracks + # has (track ids are a core feature). On an empty solution view they are no-ops. @property def track_annotator(self): - """The registered TrackAnnotator, or None if this Tracks has no track ids.""" - from funtracks.annotators import TrackAnnotator - + """The registered TrackAnnotator. Always present (every Tracks has track ids); + returns None only in the degenerate case where no TrackAnnotator was + registered.""" for annotator in self.annotators: if isinstance(annotator, TrackAnnotator): return annotator return None - def _require_track_annotator(self): - annotator = self.track_annotator - if annotator is None: - raise ValueError( - "This Tracks has no TrackAnnotator (no tracklet_key set); track id " - "operations are unavailable." - ) - return annotator - @classmethod def from_tracks(cls, tracks: Tracks) -> Tracks: """Return a Tracks with track ids, recomputing them if any are missing.""" @@ -885,11 +893,7 @@ def from_tracks(cls, tracks: Tracks) -> Tracks: @property def max_track_id(self) -> int: - return self._require_track_annotator().max_tracklet_id - - @property - def track_id_to_node(self) -> dict[int, list[int]]: - return self._require_track_annotator().tracklet_id_to_nodes + return self.track_annotator.max_tracklet_id def get_next_track_id(self) -> int: """Return the next available track_id. @@ -897,7 +901,7 @@ def get_next_track_id(self) -> int: The max_tracklet_id in TrackAnnotator is updated automatically when a node is added or track IDs are updated via UpdateTrackIDs. """ - return self._require_track_annotator().max_tracklet_id + 1 + return self.track_annotator.max_tracklet_id + 1 def get_next_lineage_id(self) -> int: """Return the next available lineage_id. @@ -905,11 +909,9 @@ def get_next_lineage_id(self) -> int: The max_lineage_id in TrackAnnotator is updated automatically when a node is added or lineage IDs are updated via UpdateTrackIDs. """ - return self._require_track_annotator().max_lineage_id + 1 + return self.track_annotator.max_lineage_id + 1 def get_track_id(self, node) -> int: - if self.features.tracklet_key is None: - raise ValueError("Tracklet key not initialized in features") track_id = self.get_node_attr(node, self.features.tracklet_key) return track_id @@ -918,8 +920,6 @@ def get_track_ids(self, nodes) -> list[int]: NOTE: always fetches the entire graph internally. Optimised for bulk (all-node) calls. For small subsets or single nodes use get_track_id() instead.""" - if self.features.tracklet_key is None: - raise ValueError("Tracklet key not initialized in features") tracklet_key = self.features.tracklet_key df = self.graph_solution.node_attrs( attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, tracklet_key] @@ -963,13 +963,12 @@ def get_track_neighbors( track id, and the first node after time with the given track id, or Nones if there are no such nodes. """ - annotator = self._require_track_annotator() if ( - track_id not in annotator.tracklet_id_to_nodes - or len(annotator.tracklet_id_to_nodes[track_id]) == 0 + track_id not in self.track_annotator.tracklet_id_to_nodes + or len(self.track_annotator.tracklet_id_to_nodes[track_id]) == 0 ): return None, None - candidates = annotator.tracklet_id_to_nodes[track_id] + candidates = self.track_annotator.tracklet_id_to_nodes[track_id] candidates.sort(key=lambda n: self.get_time(n)) pred = None @@ -995,7 +994,7 @@ def has_track_id_at_time(self, track_id: int, time: int) -> bool: Returns: True if a node with given track id exists at given time point. """ - nodes = self.track_id_to_node.get(track_id) + nodes = self.track_annotator.tracklet_id_to_nodes.get(track_id) if not nodes: return False diff --git a/src/funtracks/import_export/csv/_export.py b/src/funtracks/import_export/csv/_export.py index 58286ce7..f76bea25 100644 --- a/src/funtracks/import_export/csv/_export.py +++ b/src/funtracks/import_export/csv/_export.py @@ -210,7 +210,7 @@ def rgb_to_hex(rgb): track_id_to_hex = {} - for track_id, nodes in tracks.track_id_to_node.items(): + for track_id, nodes in tracks.track_annotator.tracklet_id_to_nodes.items(): if not nodes: continue first_node = nodes[0] diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py index c750ecbb..7576a744 100644 --- a/src/funtracks/user_actions/user_update_segmentation.py +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -84,8 +84,6 @@ def __init__( else: time_key = tracks.features.time_key tracklet_key = tracks.features.tracklet_key - if tracklet_key is None: - raise ValueError("Track ID key is not set in tracks features") attrs: dict[str, int] = { time_key: time, tracklet_key: current_track_id, diff --git a/tests/annotators/test_annotator_registry.py b/tests/annotators/test_annotator_registry.py index 5efb88b4..30996ff3 100644 --- a/tests/annotators/test_annotator_registry.py +++ b/tests/annotators/test_annotator_registry.py @@ -4,7 +4,8 @@ from funtracks.data_model import Tracks track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} -# A plain (non-solution) Tracks declares no tracklet_attr, so no TrackAnnotator. +# Even without an explicit tracklet_attr, every Tracks gets a TrackAnnotator and a +# default tracklet key (no "plain" vs "solution" split). plain_attrs = {"time_attr": "t"} @@ -22,17 +23,18 @@ def test_annotator_registry_init_with_segmentation( annotator_types = [type(ann) for ann in tracks.annotators] assert RegionpropsAnnotator in annotator_types assert EdgeAnnotator in annotator_types - assert TrackAnnotator not in annotator_types # No tracklet_attr -> no track ids + assert TrackAnnotator in annotator_types # every Tracks has track ids def test_annotator_registry_init_without_segmentation(graph_2d_with_position): - """Test AnnotatorRegistry doesn't create annotators without segmentation.""" + """Without segmentation: no regionprops/edge annotators, but a TrackAnnotator is + still registered (track ids are a core feature of every Tracks).""" tracks = Tracks(graph_2d_with_position, ndim=3, **plain_attrs) annotator_types = [type(ann) for ann in tracks.annotators] assert RegionpropsAnnotator not in annotator_types assert EdgeAnnotator not in annotator_types - assert TrackAnnotator not in annotator_types + assert TrackAnnotator in annotator_types def test_annotator_registry_init_solution_tracks( diff --git a/tests/annotators/test_track_annotator.py b/tests/annotators/test_track_annotator.py index f55afb83..05666cb3 100644 --- a/tests/annotators/test_track_annotator.py +++ b/tests/annotators/test_track_annotator.py @@ -100,13 +100,13 @@ def test_add_remove_feature(self, get_tracks, ndim, with_seg): assert tracks.get_node_attr(node_id, ann.lineage_key) != orig_lin assert tracks.get_node_attr(node_id, ann.tracklet_key) != orig_tra - def test_invalid(self, get_tracks, ndim, with_seg) -> None: - # A plain Tracks (no tracklet_key) is not a track-id candidate, so no - # TrackAnnotator is registered. + def test_always_has_track_annotator(self, get_tracks, ndim, with_seg) -> None: + # Every Tracks has a tracklet key and a registered TrackAnnotator, even when + # built without explicit track attributes (no "plain" vs "solution" split). tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=False) - assert tracks.features.tracklet_key is None - assert not TrackAnnotator.can_annotate(tracks) - assert tracks.track_annotator is None + assert tracks.features.tracklet_key is not None + assert tracks.features.lineage_key is not None + assert isinstance(tracks.track_annotator, TrackAnnotator) def test_ignores_irrelevant_actions(self, get_tracks, ndim, with_seg): """Test that TrackAnnotator ignores actions that don't affect track IDs.""" diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py index 61f3aa19..157fd142 100644 --- a/tests/import_export/test_export_to_geff.py +++ b/tests/import_export/test_export_to_geff.py @@ -101,17 +101,6 @@ def test_export_without_seg_on_tracks(get_tracks, tmp_path): assert "segmentation_shape" not in attrs -@pytest.mark.parametrize("seg_relabel", ["tracklet", "lineage"]) -def test_export_seg_relabel_non_solution_raises(get_tracks, seg_relabel, tmp_path): - """Relabeling by tracklet/lineage on non-solution tracks raises ValueError.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=False) - - export_dir = tmp_path / "export" - export_dir.mkdir() - with pytest.raises(ValueError): - export_to_geff(tracks, export_dir, seg_relabel=seg_relabel) - - def test_export_segmentation_non_solution(get_tracks, tmp_path): """Non-solution tracks export segmentation fine with seg_relabel=None.""" tracks = get_tracks(ndim=3, with_seg=True, is_solution=False) From 84eac14fd1636213ed5821e7f08ea616c0389403 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 13:37:31 -0700 Subject: [PATCH 19/39] remove from_tracks --- src/funtracks/data_model/tracks.py | 34 --------------------- tests/data_model/test_solution_tracks.py | 39 ------------------------ 2 files changed, 73 deletions(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index e88fe857..ce95844a 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -857,40 +857,6 @@ def track_annotator(self): return annotator return None - @classmethod - def from_tracks(cls, tracks: Tracks) -> Tracks: - """Return a Tracks with track ids, recomputing them if any are missing.""" - force_recompute = False - # Check if all nodes have a value at features.tracklet_key before trusting - # existing track IDs - if ( - tracks.features.tracklet_key is not None - and ( - tracks.graph_solution.node_attrs(attr_keys=tracks.features.tracklet_key)[ - tracks.features.tracklet_key - ] - == -1 - ).any() - ): - # Attributes are no longer None, so -1 now means non-computed - force_recompute = True - - soln_tracks = cls( - tracks.graph_solution, - scale=tracks.scale, - ndim=tracks.ndim, - features=tracks.features, - _segmentation=tracks.segmentation, - ) - if force_recompute: - soln_tracks.enable_features( - [ - soln_tracks.features.tracklet_key, # type: ignore[list-item] - soln_tracks.features.lineage_key, # type: ignore[list-item] - ] - ) - return soln_tracks - @property def max_track_id(self) -> int: return self.track_annotator.max_tracklet_id diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 029b465f..a90691ca 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -33,45 +33,6 @@ def test_next_track_id(graph_2d_with_track_id): assert tracks.get_next_track_id() == 11 -def test_from_tracks_cls(graph_2d_with_segmentation): - tracks = Tracks( - graph_2d_with_segmentation, - ndim=3, - pos_attr="POSITION", - time_attr="TIME", - tracklet_attr=track_attrs["tracklet_attr"], - scale=(2, 2, 2), - ) - solution_tracks = Tracks.from_tracks(tracks) - assert solution_tracks.graph_solution == tracks.graph_solution - assert solution_tracks.segmentation == tracks.segmentation - assert solution_tracks.features.time_key == tracks.features.time_key - assert solution_tracks.features.position_key == tracks.features.position_key - assert solution_tracks.scale == tracks.scale - assert solution_tracks.ndim == tracks.ndim - assert solution_tracks.get_node_attr(6, tracks.features.tracklet_key) == 5 - - -def test_from_tracks_cls_recompute(graph_2d_with_segmentation): - tracks = Tracks( - graph_2d_with_segmentation, - ndim=3, - pos_attr="POSITION", - time_attr="TIME", - tracklet_attr=track_attrs["tracklet_attr"], - scale=(2, 2, 2), - ) - # delete track id (default value -1) on one node triggers reassignment of - # track_ids even when recompute is False. - tracks.graph_solution.nodes[1][tracks.features.tracklet_key] = -1 - solution_tracks = Tracks.from_tracks(tracks) - # should have reassigned new track_id to node 6 - assert solution_tracks.get_node_attr(6, solution_tracks.features.tracklet_key) == 4 - assert ( - solution_tracks.get_node_attr(1, solution_tracks.features.tracklet_key) == 1 - ) # still 1 - - def test_update_segmentation(graph_2d_with_segmentation): tracks = Tracks( graph_2d_with_segmentation, From 249df295d54e2426d8b6ccd50688f31d3620525d Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 14:12:54 -0700 Subject: [PATCH 20/39] rename is_solution in conftest to prefill_track_ids, as every Tracks now automatically has tracklet_ids --- src/funtracks/data_model/tracks.py | 30 ++++++------ src/funtracks/import_export/geff/_export.py | 2 +- tests/actions/test_add_delete_edge.py | 8 ++-- tests/actions/test_add_delete_nodes.py | 10 ++-- tests/actions/test_base_action.py | 2 +- tests/actions/test_update_node_attrs.py | 4 +- tests/actions/test_update_node_segs.py | 2 +- tests/annotators/test_edge_annotator.py | 2 +- .../annotators/test_features_on_full_graph.py | 4 +- .../annotators/test_regionprops_annotator.py | 2 +- tests/annotators/test_track_annotator.py | 18 ++++---- tests/conftest.py | 46 +++++++++---------- .../data_model/test_soft_delete_roundtrip.py | 10 ++-- tests/data_model/test_solution_tracks.py | 15 ------ tests/import_export/test_csv_export.py | 18 ++++---- tests/import_export/test_export_to_geff.py | 42 ++++++++--------- tests/import_export/test_import_from_geff.py | 6 +-- tests/import_export/test_internal_format.py | 25 +++++----- .../import_export/test_solution_roundtrip.py | 6 +-- tests/import_export/test_utils.py | 6 +-- tests/user_actions/test_user_actions_force.py | 6 +-- .../user_actions/test_user_add_delete_edge.py | 14 +++--- .../user_actions/test_user_add_delete_node.py | 10 ++-- .../test_user_swap_predecessors.py | 14 +++--- .../test_user_update_node_attrs.py | 14 +++--- .../test_user_update_nodes_attrs.py | 16 +++---- .../test_user_update_segmentation.py | 16 +++---- 27 files changed, 162 insertions(+), 186 deletions(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index ce95844a..7a94e599 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -328,15 +328,16 @@ def _get_annotators(self) -> AnnotatorRegistry: if EdgeAnnotator.can_annotate(self): annotator_list.append(EdgeAnnotator(self)) - # TrackAnnotator: registered when a tracklet_key is set (checked in can_annotate) - if TrackAnnotator.can_annotate(self): - annotator_list.append( - TrackAnnotator( - self, - tracklet_key=self.features.tracklet_key, - lineage_key=self.features.lineage_key, - ) + # TrackAnnotator is registered on every Tracks — track ids are a core feature, + # not a separate "type" of tracks. On an empty solution view it simply computes + # nothing until nodes are added. + annotator_list.append( + TrackAnnotator( + self, + tracklet_key=self.features.tracklet_key, + lineage_key=self.features.lineage_key, ) + ) return AnnotatorRegistry(annotator_list) def _activate_features_from_dict(self) -> None: @@ -406,8 +407,6 @@ def _ensure_track_features(self) -> None: auto-detect construction paths; a no-op on an empty solution view. """ annotator = self.track_annotator - if annotator is None: - return self.features.tracklet_key = annotator.tracklet_key self.features.lineage_key = annotator.lineage_key self._register_core_features([annotator.tracklet_key, annotator.lineage_key]) @@ -849,9 +848,8 @@ def delete_feature(self, key: str) -> None: @property def track_annotator(self): - """The registered TrackAnnotator. Always present (every Tracks has track ids); - returns None only in the degenerate case where no TrackAnnotator was - registered.""" + """The registered TrackAnnotator — always present, since track ids are a core + feature of every Tracks.""" for annotator in self.annotators: if isinstance(annotator, TrackAnnotator): return annotator @@ -899,17 +897,15 @@ def get_track_ids(self, nodes) -> list[int]: ) return [id_to_val[node] for node in nodes] - def get_lineage_id(self, node) -> int | None: + def get_lineage_id(self, node) -> int: """Get the lineage ID for a node. Args: node: The node to get lineage ID for Returns: - The lineage ID, or None if lineage feature is not enabled + The lineage ID. """ - if self.features.lineage_key is None: - return None return self.get_node_attr(node, self.features.lineage_key) def get_track_neighbors( diff --git a/src/funtracks/import_export/geff/_export.py b/src/funtracks/import_export/geff/_export.py index f1c33ae2..8c1ad17a 100644 --- a/src/funtracks/import_export/geff/_export.py +++ b/src/funtracks/import_export/geff/_export.py @@ -98,7 +98,7 @@ def export_to_geff( # Include the FeatureDict in metadata only for full exports. # Subgroup exports do not necessarily have valid tracklet/lineage IDs - # and thus are not valid SolutionTracks + # and thus are not valid Tracks graph, metadata = _build_geff_metadata(tracks, include_features=(node_ids is None)) # Save segmentation if present and requested diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py index 2f2f73b9..a552dc7e 100644 --- a/tests/actions/test_add_delete_edge.py +++ b/tests/actions/test_add_delete_edge.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize("ndim", [3, 4]) @pytest.mark.parametrize("with_seg", [True, False]) def test_add_delete_edges(get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) reference_graph = tracks.graph_solution reference_seg = np.asarray(tracks.segmentation).copy() @@ -80,7 +80,7 @@ def test_add_delete_edges(get_tracks, ndim, with_seg): def test_add_edge_missing_endpoint(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises( ValueError, match="Cannot add edge .*: endpoint .* not in solution" ): @@ -88,7 +88,7 @@ def test_add_edge_missing_endpoint(get_tracks): def test_delete_missing_edge(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises( ValueError, match="Edge .* not in the graph, and cannot be removed" ): @@ -101,7 +101,7 @@ def test_custom_edge_attributes_preserved(get_tracks, ndim, with_seg): """Test custom edge attributes preserved through add/delete/re-add cycles.""" from funtracks.features import Feature - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Register custom edge features so they get saved by DeleteEdge custom_features = { diff --git a/tests/actions/test_add_delete_nodes.py b/tests/actions/test_add_delete_nodes.py index f6983e8c..9d375696 100644 --- a/tests/actions/test_add_delete_nodes.py +++ b/tests/actions/test_add_delete_nodes.py @@ -20,7 +20,7 @@ @pytest.mark.parametrize("with_seg", [True, False]) def test_add_delete_nodes(get_tracks, ndim, with_seg): # Get a tracks instance - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) reference_graph = tracks.graph_solution reference_seg = np.asarray(tracks.segmentation).copy() if with_seg else None @@ -121,19 +121,19 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): def test_add_node_missing_time(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises(ValueError, match="Must provide a time attribute for node"): AddNode(tracks, 8, {}) def test_add_node_missing_pos(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) # First test: missing track_id raises an error with pytest.raises(ValueError, match="Must provide a track_id attribute for node"): AddNode(tracks, 8, {"t": 2}) # Second test: with track_id but without segmentation, missing pos raises an error - tracks_no_seg = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks_no_seg = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) with pytest.raises( ValueError, match="Must provide position or segmentation for node" ): @@ -146,7 +146,7 @@ def test_custom_attributes_preserved(get_tracks, ndim, with_seg): """Test custom node attributes preserved through add/delete/re-add cycles.""" from funtracks.features import Feature - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Register custom features so they get saved by DeleteNode custom_features = { diff --git a/tests/actions/test_base_action.py b/tests/actions/test_base_action.py index 9e4c0047..2b9e3f56 100644 --- a/tests/actions/test_base_action.py +++ b/tests/actions/test_base_action.py @@ -6,7 +6,7 @@ def test_initialize_base_class(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) action = Action(tracks) with pytest.raises(NotImplementedError): action.inverse() diff --git a/tests/actions/test_update_node_attrs.py b/tests/actions/test_update_node_attrs.py index f027ead1..5171e9b4 100644 --- a/tests/actions/test_update_node_attrs.py +++ b/tests/actions/test_update_node_attrs.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize("ndim", [3, 4]) def test_update_node_attrs(get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node = 1 new_feature = Feature( @@ -32,6 +32,6 @@ def test_update_node_attrs(get_tracks, ndim): @pytest.mark.parametrize("attr", ["t", "area", "track_id"]) def test_update_protected_attr(get_tracks, attr): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises(ValueError, match="Cannot update attribute .* manually"): UpdateNodeAttrs(tracks, 1, {attr: 2}) diff --git a/tests/actions/test_update_node_segs.py b/tests/actions/test_update_node_segs.py index 8c7a66d8..7115d019 100644 --- a/tests/actions/test_update_node_segs.py +++ b/tests/actions/test_update_node_segs.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize("ndim", [3, 4]) def test_update_node_segs(get_tracks, ndim): # Get tracks with segmentation - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) reference_graph = tracks.graph_solution.detach().filter().subgraph() node = 1 diff --git a/tests/annotators/test_edge_annotator.py b/tests/annotators/test_edge_annotator.py index 4e912d50..49fc0d8a 100644 --- a/tests/annotators/test_edge_annotator.py +++ b/tests/annotators/test_edge_annotator.py @@ -133,7 +133,7 @@ def test_missing_seg(self, get_graph, ndim) -> None: def test_ignores_irrelevant_actions(self, get_graph, ndim): """Test that EdgeAnnotator ignores actions that don't affect edges.""" - graph = get_graph(ndim, is_solution=True, with_seg=True) + graph = get_graph(ndim, prefill_track_ids=True, with_seg=True) tracks = Tracks( graph, ndim=ndim, diff --git a/tests/annotators/test_features_on_full_graph.py b/tests/annotators/test_features_on_full_graph.py index 85cbcb5f..49321289 100644 --- a/tests/annotators/test_features_on_full_graph.py +++ b/tests/annotators/test_features_on_full_graph.py @@ -18,7 +18,7 @@ def _annotator(tracks, cls): @pytest.mark.parametrize("ndim", [3, 4]) def test_regionprops_persist_and_recompute_on_soft_deleted_node(get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) rp_ann = _annotator(tracks, RegionpropsAnnotator) tracks.enable_features(["pos", "area"]) rp_ann.compute(["area"]) @@ -45,7 +45,7 @@ def test_regionprops_persist_and_recompute_on_soft_deleted_node(get_tracks, ndim @pytest.mark.parametrize("ndim", [3, 4]) def test_iou_computed_on_soft_deleted_candidate_edge(get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) edge_ann = _annotator(tracks, EdgeAnnotator) tracks.enable_features(["iou"]) edge_ann.compute(["iou"]) diff --git a/tests/annotators/test_regionprops_annotator.py b/tests/annotators/test_regionprops_annotator.py index db692a02..7ee98ee3 100644 --- a/tests/annotators/test_regionprops_annotator.py +++ b/tests/annotators/test_regionprops_annotator.py @@ -176,7 +176,7 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim): """Test that RegionpropsAnnotator ignores actions that don't affect segmentation. """ - graph = get_graph(ndim, is_solution=True, with_seg=True) + graph = get_graph(ndim, prefill_track_ids=True, with_seg=True) tracks = Tracks( graph, ndim=ndim, diff --git a/tests/annotators/test_track_annotator.py b/tests/annotators/test_track_annotator.py index 05666cb3..8987428f 100644 --- a/tests/annotators/test_track_annotator.py +++ b/tests/annotators/test_track_annotator.py @@ -16,7 +16,7 @@ @pytest.mark.parametrize("with_seg", [True, False]) class TestTrackAnnotator: def test_init(self, get_tracks, ndim, with_seg) -> None: - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) ann = TrackAnnotator(tracks) # Features start disabled by default assert len(ann.all_features) == 2 @@ -32,7 +32,7 @@ def test_init(self, get_tracks, ndim, with_seg) -> None: assert ann.max_tracklet_id == 5 def test_compute_all(self, get_tracks, ndim, with_seg) -> None: - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) ann = TrackAnnotator(tracks, tracklet_key=tracks.features.tracklet_key) # Enable features @@ -69,7 +69,7 @@ def test_compute_all(self, get_tracks, ndim, with_seg) -> None: assert len({id_set[0] for id_set in id_sets}) == len(id_sets) def test_add_remove_feature(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) ann = TrackAnnotator(tracks, tracklet_key=tracks.features.tracklet_key) # Enable features ann.activate_features(list(ann.all_features.keys())) @@ -103,7 +103,7 @@ def test_add_remove_feature(self, get_tracks, ndim, with_seg): def test_always_has_track_annotator(self, get_tracks, ndim, with_seg) -> None: # Every Tracks has a tracklet key and a registered TrackAnnotator, even when # built without explicit track attributes (no "plain" vs "solution" split). - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=False) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=False) assert tracks.features.tracklet_key is not None assert tracks.features.lineage_key is not None assert isinstance(tracks.track_annotator, TrackAnnotator) @@ -113,7 +113,7 @@ def test_ignores_irrelevant_actions(self, get_tracks, ndim, with_seg): if not with_seg: pytest.skip("Test requires segmentation") - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) tracks.enable_features(["area", tracks.features.tracklet_key]) node_id = 3 @@ -135,7 +135,7 @@ def test_ignores_irrelevant_actions(self, get_tracks, ndim, with_seg): def test_lineage_id_updated_on_add_and_delete_edge( self, get_tracks, ndim, with_seg ) -> None: - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) tracks.enable_features(["lineage_id"]) # get the existing TrackAnnotator @@ -208,7 +208,7 @@ def test_lineage_id_updated_on_division(self, get_tracks, ndim, with_seg) -> Non - New child (6): keeps same track_id, gets source's lineage_id """ # Graph structure: 1 → 2, 1 → 3 → 4 → 5, and 6 (separate) - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) tracks.enable_features(["lineage_id"]) ann = next(a for a in tracks.annotators if isinstance(a, TrackAnnotator)) @@ -257,7 +257,7 @@ def test_lineage_id_updated_on_delete_division_edge( - Orphaned child (2): keeps same track_id, gets new lineage_id """ # Graph structure: 1 → 2, 1 → 3 → 4 → 5, and 6 (separate) - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) tracks.enable_features(["lineage_id"]) ann = next(a for a in tracks.annotators if isinstance(a, TrackAnnotator)) @@ -289,7 +289,7 @@ def test_lineage_id_updated_on_delete_division_edge( def test_disabled_tracklet_key_does_nothing(self, get_tracks, ndim, with_seg) -> None: """Test that TrackAnnotator does nothing when tracklet_key is disabled.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) ann = TrackAnnotator(tracks) # Don't activate any features - they should all be disabled diff --git a/tests/conftest.py b/tests/conftest.py index 8f345634..5e3a1cfc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -359,10 +359,12 @@ def get_tracks(get_graph) -> Callable[..., "Tracks"]: Returns a factory function that can be called with: ndim: 3 for 2D spatial + time, 4 for 3D spatial + time with_seg: Whether to include segmentation (mask/bbox as node attributes) - is_solution: Whether to return Tracks instead of Tracks + prefill_track_ids: If True, the fixture graph ships with track_id/lineage_id + values already set (activated as-is); if False, they are computed from the + graph topology. Either way the resulting Tracks has track ids. Example: - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) Note: Uses a pre-built FeatureDict to avoid recomputing features that already @@ -385,12 +387,14 @@ def get_tracks(get_graph) -> Callable[..., "Tracks"]: def _make_tracks( ndim: int, with_seg: bool = True, - is_solution: bool = False, + prefill_track_ids: bool = False, ) -> Tracks: # Determine axis names based on ndim axis_names = ["z", "y", "x"] if ndim == 4 else ["y", "x"] - graph = get_graph(ndim=ndim, is_solution=is_solution, with_seg=with_seg) + graph = get_graph( + ndim=ndim, prefill_track_ids=prefill_track_ids, with_seg=with_seg + ) # Build FeatureDict based on what exists in the graph features_dict: dict[str, Any] = { @@ -404,7 +408,7 @@ def _make_tracks( features_dict["bbox"] = SegBbox(ndim) features_dict["area"] = Area(ndim=ndim) features_dict["iou"] = IoU() - if is_solution: + if prefill_track_ids: features_dict["track_id"] = TrackletID() features_dict["lineage_id"] = LineageID() @@ -412,23 +416,16 @@ def _make_tracks( features=features_dict, time_key="t", position_key="pos", - tracklet_key="track_id" if is_solution else None, - lineage_key="lineage_id" if is_solution else None, + tracklet_key="track_id" if prefill_track_ids else None, + lineage_key="lineage_id" if prefill_track_ids else None, ) - # Create the appropriate Tracks type with pre-built FeatureDict - if is_solution: - return Tracks( - graph, - ndim=ndim, - features=feature_dict, - ) - else: - return Tracks( - graph, - ndim=ndim, - features=feature_dict, - ) + # Create the Tracks with the pre-built FeatureDict. + return Tracks( + graph, + ndim=ndim, + features=feature_dict, + ) return _make_tracks @@ -481,7 +478,8 @@ def get_graph(tmp_path) -> Callable[..., td.graph.GraphView]: Args: ndim: 3 for 2D spatial + time, 4 for 3D spatial + time with_pos: Include position attribute (default True) - is_solution: Include track_id and lineage_id (default False) + prefill_track_ids: Include track_id and lineage_id columns on the graph + (default False) with_seg: Include mask, bbox, area, and iou (default False) Returns: @@ -489,14 +487,14 @@ def get_graph(tmp_path) -> Callable[..., td.graph.GraphView]: Example: graph = get_graph(ndim=3, with_seg=True) - graph = get_graph(ndim=4, is_solution=True, with_seg=True) + graph = get_graph(ndim=4, prefill_track_ids=True, with_seg=True) """ counter = [0] def _get_graph( ndim: int = 3, with_pos: bool = True, - is_solution: bool = False, + prefill_track_ids: bool = False, with_seg: bool = False, ) -> td.graph.GraphView: counter[0] += 1 @@ -504,7 +502,7 @@ def _get_graph( return _make_graph( ndim=ndim, with_pos=with_pos, - with_track_id=is_solution, + with_track_id=prefill_track_ids, with_area=with_seg, with_iou=with_seg, with_masks=with_seg, diff --git a/tests/data_model/test_soft_delete_roundtrip.py b/tests/data_model/test_soft_delete_roundtrip.py index 691d16d2..852130d9 100644 --- a/tests/data_model/test_soft_delete_roundtrip.py +++ b/tests/data_model/test_soft_delete_roundtrip.py @@ -43,7 +43,7 @@ def _positions(tracks): @pytest.mark.parametrize("ndim", [3, 4]) @pytest.mark.parametrize("with_seg", [True, False]) def test_soft_delete_keeps_leaf_node_in_full_graph(get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) full_before = _full_state(tracks) sol_nodes_before = set(tracks.graph_solution.node_ids()) @@ -65,7 +65,7 @@ def test_soft_delete_keeps_leaf_node_in_full_graph(get_tracks, ndim, with_seg): @pytest.mark.parametrize("ndim", [3, 4]) @pytest.mark.parametrize("with_seg", [True, False]) def test_delete_undo_redo_roundtrip_identity(get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) sol_ref = _solution_state(tracks) full_ref = _full_state(tracks) pos_ref = _positions(tracks) @@ -89,7 +89,7 @@ def test_delete_undo_redo_roundtrip_identity(get_tracks, ndim, with_seg): @pytest.mark.parametrize("ndim", [3, 4]) def test_repeated_delete_undo_is_stable(get_tracks, ndim): """Invariant #4: N undo/redo cycles must not drift the solution view or full graph.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) sol_ref = _solution_state(tracks) iou_ref = tracks.get_edge_attr((4, 5), "iou") @@ -118,7 +118,7 @@ def test_mid_track_delete_leaves_skip_edge_candidate_in_full(get_tracks, ndim): """Deleting a mid-track node adds a reconnection skip-edge (3->5). On undo it is soft-deleted, so it persists in graph_full as a solution=False candidate while the solution view returns to its original topology.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) sol_ref = _solution_state(tracks) assert not tracks.graph_full.has_edge(3, 5) @@ -144,7 +144,7 @@ def test_attr_reads_resolve_for_soft_deleted_node(get_tracks, ndim): `get_position` (graph_full) succeeded — a latent inconsistency invisible to tests that only ever query in-solution nodes. """ - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) pos_single_before = tracks.get_position(5) pos_bulk_before = tracks.get_positions([5])[0].tolist() assert pos_bulk_before == pytest.approx(pos_single_before) diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index a90691ca..b7c94a1c 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -59,21 +59,6 @@ def test_next_track_id_empty(): assert tracks.get_next_track_id() == 1 -def test_get_lineage_id_without_lineage_key(graph_2d_with_track_id): - """Test that get_lineage_id returns None when lineage_key is not set.""" - graph = graph_2d_with_track_id - graph.add_node( - attrs={"t": 1, "pos": [0, 0], "track_id": 1}, index=7, validate_keys=False - ) - tracks = Tracks(graph, ndim=3, **track_attrs) - - # Unset lineage_key to test the None path - tracks.features.lineage_key = None - - # get_lineage_id should return None when lineage_key is not set - assert tracks.get_lineage_id(1) is None - - def test_export_to_csv_with_display_names( graph_2d_with_segmentation, graph_3d_with_segmentation, tmp_path ): diff --git a/tests/import_export/test_csv_export.py b/tests/import_export/test_csv_export.py index 75c35080..9bfa6cc4 100644 --- a/tests/import_export/test_csv_export.py +++ b/tests/import_export/test_csv_export.py @@ -16,7 +16,7 @@ ) def test_export_solution_to_csv(get_tracks, tmp_path, ndim, expected_header): """Test exporting tracks to CSV.""" - tracks = get_tracks(ndim=ndim, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=False, prefill_track_ids=True) temp_file = tmp_path / "test_export.csv" export_to_csv(tracks, temp_file) @@ -46,7 +46,7 @@ def test_export_solution_to_csv_with_seg_zarr( get_tracks, tmp_path, ndim, expected_header ): """Test exporting tracks to CSV + segmentation painted by track_id as zarr.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) temp_file = tmp_path / "test_export.csv" seg_dir = tmp_path / "test_export_seg" export_to_csv(tracks, temp_file, export_seg=True, seg_path=seg_dir) @@ -75,7 +75,7 @@ def test_export_solution_to_csv_with_seg_zarr( @pytest.mark.parametrize("ndim", [3, 4], ids=["2d", "3d"]) def test_export_solution_to_csv_with_seg_tiff(get_tracks, tmp_path, ndim): """Test exporting tracks to CSV + segmentation as tiff painted by tracklet ID.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) temp_file = tmp_path / "test_export.csv" seg_file = tmp_path / "test_export_seg.tif" export_to_csv( @@ -102,7 +102,7 @@ def test_export_solution_to_csv_with_seg_tiff(get_tracks, tmp_path, ndim): @pytest.mark.parametrize("ndim", [3, 4], ids=["2d", "3d"]) def test_export_solution_to_csv_with_seg_original_labels(get_tracks, tmp_path, ndim): """Test exporting tracks to CSV + segmentation with original (node_id) labels.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) temp_file = tmp_path / "test_export.csv" seg_dir = tmp_path / "test_export_seg" export_to_csv( @@ -128,7 +128,7 @@ def test_export_with_color_dict(get_tracks, tmp_path): """Test exporting with a color_dict adds a Tracklet ID Color column.""" import numpy as np - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) temp_file = tmp_path / "test_export_colors.csv" # Build a color dict: node_id → [R, G, B] floats in [0, 1] @@ -150,7 +150,7 @@ def test_export_with_color_dict(get_tracks, tmp_path): def test_export_with_display_names(get_tracks, tmp_path): """Test exporting with display names.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) temp_file = tmp_path / "test_export_display.csv" export_to_csv(tracks, temp_file, use_display_names=True) @@ -165,7 +165,7 @@ def test_export_with_display_names(get_tracks, tmp_path): def test_export_filtered_nodes(get_tracks, tmp_path): """Test exporting only specific nodes.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) temp_file = tmp_path / "test_export_filtered.csv" # Export only nodes 1 and 2 (and their ancestors) @@ -181,7 +181,7 @@ def test_export_filtered_nodes(get_tracks, tmp_path): def test_ignore_edge_features_at_export(get_tracks, tmp_path): """Test that edge features are ignored when exporting to csv""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) temp_file = tmp_path / "test_export_node_features_only.csv" # enable node and edge features @@ -213,7 +213,7 @@ def test_export_solution_to_csv_with_seg_and_node_subset( and segmentation must only include the resulting graph nodes, and nothing else. """ - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) csv_file = tmp_path / "export.csv" seg_path = tmp_path / "seg" diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py index 157fd142..05c8814d 100644 --- a/tests/import_export/test_export_to_geff.py +++ b/tests/import_export/test_export_to_geff.py @@ -32,7 +32,7 @@ def _assert_valid_geff_export(export_dir, expected_num_nodes=None): @pytest.mark.parametrize("seg_relabel", ["tracklet", "lineage", None]) def test_export_segmentation_relabel(get_tracks, ndim, seg_relabel, tmp_path): """Test segmentation export with each relabel strategy.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -70,7 +70,7 @@ def test_export_segmentation_relabel(get_tracks, ndim, seg_relabel, tmp_path): def test_export_no_segmentation_saved(get_tracks, tmp_path): """Test that save_segmentation=False suppresses segmentation file.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -87,7 +87,7 @@ def test_export_no_segmentation_saved(get_tracks, tmp_path): def test_export_without_seg_on_tracks(get_tracks, tmp_path): """Test export when tracks have no segmentation at all.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -103,7 +103,7 @@ def test_export_without_seg_on_tracks(get_tracks, tmp_path): def test_export_segmentation_non_solution(get_tracks, tmp_path): """Non-solution tracks export segmentation fine with seg_relabel=None.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=False) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=False) export_dir = tmp_path / "export" export_dir.mkdir() @@ -126,10 +126,10 @@ def test_export_segmentation_non_solution(get_tracks, tmp_path): @pytest.mark.parametrize("ndim", [3, 4]) -@pytest.mark.parametrize("is_solution", [True, False]) -def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path): +@pytest.mark.parametrize("prefill_track_ids", [True, False]) +def test_export_split_position_attrs(get_graph, ndim, prefill_track_ids, tmp_path): """Test export with split (list) position attributes.""" - graph = get_graph(ndim, is_solution=is_solution, with_seg=False) + graph = get_graph(ndim, prefill_track_ids=prefill_track_ids, with_seg=False) pos_keys = ["y", "x"] if ndim == 3 else ["z", "y", "x"] for key in pos_keys: @@ -144,7 +144,7 @@ def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path): graph, time_attr="t", pos_attr=pos_keys, - tracklet_attr="track_id" if is_solution else None, + tracklet_attr="track_id" if prefill_track_ids else None, ndim=ndim, ) @@ -167,7 +167,7 @@ def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path): @pytest.mark.parametrize("ndim", [3, 4]) def test_export_node_subset(get_tracks, ndim, tmp_path): """Test exporting a subset of nodes includes ancestors.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -198,7 +198,7 @@ def test_export_node_subset(get_tracks, ndim, tmp_path): def test_export_node_subset_seg_relabel(get_tracks, tmp_path): """Test subset export with relabeled segmentation.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -232,7 +232,7 @@ def test_export_node_subset_seg_relabel(get_tracks, tmp_path): def test_export_node_subset_without_seg(get_tracks, tmp_path): """Test subset export when tracks have no segmentation.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -250,7 +250,7 @@ def test_export_node_subset_without_seg(get_tracks, tmp_path): def test_export_overwrite(get_tracks, tmp_path): """Test export with overwrite=True into non-empty directory.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -268,7 +268,7 @@ def test_export_overwrite(get_tracks, tmp_path): def test_export_non_directory_raises(get_tracks, tmp_path): """Test that exporting to a file path (not a directory) raises an error.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) file_path = tmp_path / "not_a_dir" file_path.write_text("test") @@ -284,7 +284,7 @@ def test_export_non_directory_raises(get_tracks, tmp_path): @pytest.mark.parametrize("with_seg", [True, False]) def test_export_metadata(get_tracks, ndim, with_seg, tmp_path): """Test axes structure, segmentation_shape, and FeatureDict in metadata.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -319,7 +319,7 @@ def test_export_metadata(get_tracks, ndim, with_seg, tmp_path): @pytest.mark.parametrize("ndim", [3, 4], ids=["2d", "3d"]) def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path): """Test that segmentation can be exported as tiff alongside the geff graph.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -354,7 +354,7 @@ def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path): @pytest.mark.parametrize("with_seg", [False, True]) def test_write_to_geff_roundtrip(get_tracks, ndim, with_seg, tmp_path): """write_to_geff then import_from_geff recovers the tracks.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) @@ -377,7 +377,7 @@ def test_write_to_geff_roundtrip(get_tracks, ndim, with_seg, tmp_path): def test_write_to_geff_no_parent_container(get_tracks, tmp_path): """write_to_geff writes directly to the path — no parent .zgroup or tracks.geff subfolder.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) @@ -394,7 +394,7 @@ def test_write_to_geff_no_parent_container(get_tracks, tmp_path): def test_write_to_geff_overwrite(get_tracks, tmp_path): """write_to_geff with overwrite=True replaces existing store.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) @@ -406,7 +406,7 @@ def test_write_to_geff_overwrite(get_tracks, tmp_path): def test_write_to_geff_no_overwrite_raises(get_tracks, tmp_path): """write_to_geff with overwrite=False raises on existing store.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) @@ -417,7 +417,7 @@ def test_write_to_geff_no_overwrite_raises(get_tracks, tmp_path): def test_write_to_geff_metadata(get_tracks, tmp_path): """write_to_geff stores axes and FeatureDict metadata.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) @@ -435,7 +435,7 @@ def test_write_to_geff_metadata(get_tracks, tmp_path): def test_write_to_geff_segmentation_shape(get_tracks, tmp_path): """write_to_geff writes segmentation_shape when masks are present.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index 47a5b010..7826a848 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -830,7 +830,7 @@ def test_geff_legacy_track_id_preserves_tracklet_ids(): def test_geff_roundtrip_preserves_tracklet_ids(get_tracks, tmp_path): """End-to-end round-trip: export then import should preserve tracklet IDs.""" - tracks_in = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks_in = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) expected = { int(nid): tracks_in.get_track_id(int(nid)) for nid in tracks_in.graph_solution.node_ids() @@ -1061,7 +1061,7 @@ def test_invalid_featuredict_in_geff_falls_back_to_autodetect(get_tracks, tmp_pa import_from_geff should silently fall back to auto-detection instead of raising an exception. """ - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) run_dir = tmp_path / "run" run_dir.mkdir() export_to_geff(tracks, run_dir, save_segmentation=False) @@ -1101,7 +1101,7 @@ def test_subgroup_export_omits_featuredict_and_recomputes_on_import(get_tracks, On reimport, validation would strip the now-invalid tracklet_id but the FeatureDict still referenced it, causing KeyError: 'tracklet_id'. """ - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) # Export only a subset: nodes 1, 3, 4, 5 (one branch of the division). # filter_graph_with_ancestors will include node 1 as ancestor of 3. diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py index c170021c..bbba2246 100644 --- a/tests/import_export/test_internal_format.py +++ b/tests/import_export/test_internal_format.py @@ -14,20 +14,20 @@ @pytest.mark.parametrize("with_seg", [True, False]) @pytest.mark.parametrize("ndim", [3, 4]) -@pytest.mark.parametrize("is_solution", [True, False]) +@pytest.mark.parametrize("prefill_track_ids", [True, False]) def test_save_load( get_tracks, with_seg, ndim, - is_solution, + prefill_track_ids, ): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=is_solution) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=prefill_track_ids) data_path = Path( - f"tests/data/format_v1/test_save_load_{is_solution}_{ndim}_{with_seg}_0" + f"tests/data/format_v1/test_save_load_{prefill_track_ids}_{ndim}_{with_seg}_0" ) - loaded = load_v1_tracks(data_path, solution=is_solution) + loaded = load_v1_tracks(data_path, solution=prefill_track_ids) assert loaded.ndim == ndim # Check feature keys and important properties match (allow tuple vs list diff) assert loaded.features.time_key == tracks.features.time_key @@ -60,12 +60,9 @@ def test_save_load( assert loaded.scale == tracks.scale assert loaded.ndim == tracks.ndim - if is_solution: - loaded_annotator = loaded.track_annotator - tracks_annotator = tracks.track_annotator - assert ( - loaded_annotator.tracklet_id_to_nodes == tracks_annotator.tracklet_id_to_nodes - ) + loaded_annotator = loaded.track_annotator + tracks_annotator = tracks.track_annotator + assert loaded_annotator.tracklet_id_to_nodes == tracks_annotator.tracklet_id_to_nodes if with_seg: assert_array_almost_equal(loaded.segmentation, tracks.segmentation) @@ -90,16 +87,16 @@ def test_save_load( @pytest.mark.parametrize("with_seg", [True, False]) @pytest.mark.parametrize("ndim", [3, 4]) -@pytest.mark.parametrize("is_solution", [True, False]) +@pytest.mark.parametrize("prefill_track_ids", [True, False]) def test_delete( get_tracks, with_seg, ndim, - is_solution, + prefill_track_ids, tmp_path, ): reference_path = Path( - f"tests/data/format_v1/test_save_load_{is_solution}_{ndim}_{with_seg}_0" + f"tests/data/format_v1/test_save_load_{prefill_track_ids}_{ndim}_{with_seg}_0" ) # Copy reference data to temporary location diff --git a/tests/import_export/test_solution_roundtrip.py b/tests/import_export/test_solution_roundtrip.py index 60056b6c..a716fa1c 100644 --- a/tests/import_export/test_solution_roundtrip.py +++ b/tests/import_export/test_solution_roundtrip.py @@ -20,7 +20,7 @@ def _roundtrip(tracks, tmp_path, name="rt.geff"): def test_geff_roundtrip_preserves_solution_schema(get_tracks, tmp_path): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) loaded = _roundtrip(tracks, tmp_path) edge_schema = loaded.graph_solution._edge_attr_schemas()["solution"] @@ -33,7 +33,7 @@ def test_geff_roundtrip_preserves_solution_schema(get_tracks, tmp_path): def test_add_edge_is_solution_true_after_geff_roundtrip(get_tracks, tmp_path): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) loaded = _roundtrip(tracks, tmp_path) g = loaded.graph_solution @@ -61,7 +61,7 @@ def test_add_edge_is_solution_true_after_geff_roundtrip(get_tracks, tmp_path): def test_add_node_is_solution_true_after_geff_roundtrip(get_tracks, tmp_path): - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) loaded = _roundtrip(tracks, tmp_path) new_id = max(loaded.graph_solution.node_ids()) + 1 diff --git a/tests/import_export/test_utils.py b/tests/import_export/test_utils.py index cbfa1bcc..b9e75400 100644 --- a/tests/import_export/test_utils.py +++ b/tests/import_export/test_utils.py @@ -3,7 +3,7 @@ def test_rename_feature_basic(get_tracks): """Test that rename_feature renames a feature in annotators and features dict.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=False) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=False) # Rename area feature to custom name rename_feature(tracks, "area", "my_area") @@ -19,7 +19,7 @@ def test_rename_feature_basic(get_tracks): def test_rename_feature_updates_position_key(get_tracks): """Test that renaming position feature updates position_key in FeatureDict.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) original_pos_key = tracks.features.position_key new_key = "custom_position" @@ -32,7 +32,7 @@ def test_rename_feature_updates_position_key(get_tracks): def test_rename_feature_updates_tracklet_key(get_tracks): """Test that renaming tracklet feature updates tracklet_key in FeatureDict.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) original_track_key = tracks.features.tracklet_key new_key = "custom_track" diff --git a/tests/user_actions/test_user_actions_force.py b/tests/user_actions/test_user_actions_force.py index 538ed5de..045842f1 100644 --- a/tests/user_actions/test_user_actions_force.py +++ b/tests/user_actions/test_user_actions_force.py @@ -7,7 +7,7 @@ def test_user_force_add_downstream(get_tracks): """Test force adding a node of which the track id has an upstream division event. Should break the edges of the division event to allow this new edge.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) # upstream division, with force attrs = {"t": 2, "track_id": 1, "pos": [3, 4]} @@ -22,7 +22,7 @@ def test_user_force_add_upstream(get_tracks): """Test force adding a node upstream, of which the track id co-exists with the parent track id. Should break the edge with the parent track to allow this new edge.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) # downstream parent division, with force attrs = {"t": 0, "track_id": 3, "pos": [3, 4]} @@ -37,7 +37,7 @@ def test_auto_assign_new_track_id(get_tracks): """Test that adding a node with a track id that already exists at the current time point raises a warning and auto-assigns a new track id instead.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) # existing track id at current time --> allowed, with warning with pytest.warns(UserWarning, match="Starting a new track, because track id"): diff --git a/tests/user_actions/test_user_add_delete_edge.py b/tests/user_actions/test_user_add_delete_edge.py index 76738ce8..9364b78f 100644 --- a/tests/user_actions/test_user_add_delete_edge.py +++ b/tests/user_actions/test_user_add_delete_edge.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize("with_seg", [True, False]) class TestUserAddDeleteEdge: def test_user_add_edge(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # add an edge from 4 to 6 (will make 4 a division and 5 will need to relabel # track id) edge = (4, 6) @@ -28,7 +28,7 @@ def test_user_add_edge(self, get_tracks, ndim, with_seg): assert tracks.get_track_id(old_child) != old_track_id def test_user_add_merge_edge(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # add an edge from 2 to 4 (there is already an edge from 3 to 4) edge = (2, 4) old_edge = (3, 4) @@ -55,7 +55,7 @@ def test_user_add_merge_edge(self, get_tracks, ndim, with_seg): assert not tracks.graph_solution.has_edge(*old_edge) def test_user_delete_edge(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # delete edge (1, 3). (1,2) is now not a division anymore edge = (1, 3) old_child = 2 @@ -100,7 +100,7 @@ def test_user_delete_edge(self, get_tracks, ndim, with_seg): def test_add_edge_missing_node(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises(InvalidActionError, match="Source node .* not in solution yet"): UserAddEdge(tracks, (10, 11)) with pytest.raises(InvalidActionError, match="Target node .* not in solution yet"): @@ -108,7 +108,7 @@ def test_add_edge_missing_node(get_tracks): def test_add_edge_triple_div(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises( InvalidActionError, match="Expected degree of 0 or 1 before adding edge" ): @@ -116,13 +116,13 @@ def test_add_edge_triple_div(get_tracks): def test_delete_missing_edge(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises(InvalidActionError, match="Edge .* not in solution"): UserDeleteEdge(tracks, (10, 11)) def test_delete_edge_triple_div(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) attrs = {} attrs["solution"] = True attrs["iou"] = 0.9 diff --git a/tests/user_actions/test_user_add_delete_node.py b/tests/user_actions/test_user_add_delete_node.py index c1e5b01b..8b21d4b1 100644 --- a/tests/user_actions/test_user_add_delete_node.py +++ b/tests/user_actions/test_user_add_delete_node.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("with_seg", [True, False]) class TestUserAddDeleteNode: def test_user_add_invalid_node(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # duplicate node with pytest.raises(InvalidActionError, match="Node .* already exists"): attrs = {"t": 5, "track_id": 1} @@ -34,7 +34,7 @@ def test_user_add_invalid_node(self, get_tracks, ndim, with_seg): UserAddNode(tracks, node=7, attributes=attrs) def test_user_add_node(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # add a node to replace a skip edge between node 4 in time 2 and node 5 in time 4 node_id = 7 track_id = 3 @@ -84,7 +84,7 @@ def test_user_add_node(self, get_tracks, ndim, with_seg): # TODO: error if node already exists? def test_user_delete_node(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # delete node in middle of track. Should skip-connect 3 and 5 with span 3 node_id = 4 @@ -121,7 +121,7 @@ def test_user_delete_node(self, get_tracks, ndim, with_seg): # TODO: error if node doesn't exist? def test_user_delete_node_after_division(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # delete first node after division. Should relabel the other child # to be the same track as parent parent_node = 1 @@ -158,7 +158,7 @@ def test_user_delete_node_after_division(self, get_tracks, ndim, with_seg): def test_user_delete_nodes(self, get_tracks, ndim, with_seg): """Test bulk deletion of multiple nodes in a single action.""" # Graph structure: 1 → 2, 1 → 3 → 4 → 5, and 6 (separate) - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) graph = tracks.graph_solution # Save original state diff --git a/tests/user_actions/test_user_swap_predecessors.py b/tests/user_actions/test_user_swap_predecessors.py index 320bbfd7..82a60c88 100644 --- a/tests/user_actions/test_user_swap_predecessors.py +++ b/tests/user_actions/test_user_swap_predecessors.py @@ -10,7 +10,7 @@ class TestUserSwapPredecessors: @pytest.mark.parametrize("order", [(5, 6), (6, 5)]) def test_one_predecessor(self, get_tracks, ndim, with_seg, order): """Test swapping when one node has a predecessor and one doesn't.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Node 5 (t=4) has pred 4, node 6 (t=4) has no pred assert tracks.graph_solution.has_edge(4, 5) @@ -33,7 +33,7 @@ def test_one_predecessor(self, get_tracks, ndim, with_seg, order): def test_same_predecessor_raises(self, get_tracks, ndim, with_seg): """Test error when both nodes have the same predecessor.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Nodes 2 and 3 both have predecessor 1 with pytest.raises(InvalidActionError, match="same predecessor"): @@ -41,7 +41,7 @@ def test_same_predecessor_raises(self, get_tracks, ndim, with_seg): def test_different_predecessors(self, get_tracks, ndim, with_seg): """Test swapping when both nodes have different predecessors.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) UserAddEdge(tracks, (2, 6)) @@ -64,7 +64,7 @@ def test_different_predecessors(self, get_tracks, ndim, with_seg): def test_different_times_valid(self, get_tracks, ndim, with_seg): """Test swapping nodes at different times when predecessors are valid.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Add edge 2->6 so node 6 (t=4) has pred 2 (t=1) # Node 4 (t=2) has pred 3 (t=1) @@ -84,7 +84,7 @@ def test_different_times_valid(self, get_tracks, ndim, with_seg): def test_different_times_invalid_raises(self, get_tracks, ndim, with_seg): """Test error when predecessor would not be before swapped node.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Node 3 (t=1) has pred 1 (t=0), node 4 (t=2) has pred 3 (t=1) # pred of 4 (t=1) is not before node 3 (t=1) @@ -93,7 +93,7 @@ def test_different_times_invalid_raises(self, get_tracks, ndim, with_seg): def test_wrong_count_raises(self, get_tracks, ndim, with_seg): """Test error when not exactly two nodes provided.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) with pytest.raises( InvalidActionError, match="You can only swap a pair of two nodes" @@ -107,7 +107,7 @@ def test_wrong_count_raises(self, get_tracks, ndim, with_seg): def test_no_predecessors_raises(self, get_tracks, ndim, with_seg): """Test error when neither node has a predecessor.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Delete edge so node 5 has no predecessor like node 6 UserDeleteEdge(tracks, (4, 5)) diff --git a/tests/user_actions/test_user_update_node_attrs.py b/tests/user_actions/test_user_update_node_attrs.py index 3dc33f11..3b3c98b4 100644 --- a/tests/user_actions/test_user_update_node_attrs.py +++ b/tests/user_actions/test_user_update_node_attrs.py @@ -9,7 +9,7 @@ class TestUserUpdateNodeAttrs: def test_user_update_node_attrs(self, get_tracks, ndim, with_seg): """Test basic node attribute update functionality.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) tracks.graph_solution.add_node_attr_key( "label", default_value=None, dtype=pl.Object @@ -49,7 +49,7 @@ def test_user_update_node_attrs(self, get_tracks, ndim, with_seg): def test_user_update_existing_attrs(self, get_tracks, ndim, with_seg): """Test updating attributes that already exist.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) tracks.graph_solution.add_node_attr_key( "label", default_value=None, dtype=pl.Object @@ -77,7 +77,7 @@ def test_user_update_existing_attrs(self, get_tracks, ndim, with_seg): def test_protected_time_attr(self, get_tracks, ndim, with_seg): """Test that time attribute cannot be updated.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) time_key = tracks.features.time_key with pytest.raises(ValueError, match="Cannot update attribute"): @@ -85,14 +85,14 @@ def test_protected_time_attr(self, get_tracks, ndim, with_seg): def test_protected_track_id_attr(self, get_tracks, ndim, with_seg): """Test that track_id attribute cannot be updated.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) with pytest.raises(ValueError, match="Cannot update attribute"): UserUpdateNodeAttrs(tracks, node=1, attrs={"track_id": 999}) def test_protected_area_attr(self, get_tracks, ndim, with_seg): """Test that area attribute (managed by annotator) cannot be updated.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) if with_seg: # area only exists when segmentation is present with pytest.raises(ValueError, match="Cannot update attribute"): @@ -100,7 +100,7 @@ def test_protected_area_attr(self, get_tracks, ndim, with_seg): def test_protected_pos_attr(self, get_tracks, ndim, with_seg): """Test that position attribute (managed by annotator) cannot be updated.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) if with_seg: # pos is managed by RegionpropsAnnotator when seg exists with pytest.raises(ValueError, match="Cannot update attribute"): @@ -108,7 +108,7 @@ def test_protected_pos_attr(self, get_tracks, ndim, with_seg): def test_action_history_integration(self, get_tracks, ndim, with_seg): """Test that action integrates properly with action history.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) tracks.graph_solution.add_node_attr_key( "label", default_value=None, dtype=pl.Object ) diff --git a/tests/user_actions/test_user_update_nodes_attrs.py b/tests/user_actions/test_user_update_nodes_attrs.py index 00890c23..b327b5a0 100644 --- a/tests/user_actions/test_user_update_nodes_attrs.py +++ b/tests/user_actions/test_user_update_nodes_attrs.py @@ -10,7 +10,7 @@ class TestUserUpdateNodesAttrs: def test_user_update_nodes_attrs(self, get_tracks, ndim, with_seg): """Test basic bulk node attribute update functionality.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) tracks.graph_solution.add_node_attr_key( "label", default_value=None, dtype=pl.Object @@ -28,7 +28,7 @@ def test_user_update_nodes_attrs(self, get_tracks, ndim, with_seg): def test_single_history_entry(self, get_tracks, ndim, with_seg): """Updating multiple nodes creates only one history entry.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) tracks.graph_solution.add_node_attr_key( "label", default_value=None, dtype=pl.Object ) @@ -42,7 +42,7 @@ def test_single_history_entry(self, get_tracks, ndim, with_seg): def test_undo_redo(self, get_tracks, ndim, with_seg): """Undo restores all nodes' attrs to defaults; redo re-applies them.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) tracks.graph_solution.add_node_attr_key( "score", default_value=0, dtype=pl.Float64 ) @@ -64,7 +64,7 @@ def test_undo_redo(self, get_tracks, ndim, with_seg): def test_per_node_attrs(self, get_tracks, ndim, with_seg): """Test bulk update with different values per node.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) tracks.graph_solution.add_node_attr_key( "score", default_value=0, dtype=pl.Float64 ) @@ -76,7 +76,7 @@ def test_per_node_attrs(self, get_tracks, ndim, with_seg): def test_array_attr(self, get_tracks, ndim, with_seg): """Test bulk update with array-valued attributes.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) spatial_dims = ndim - 1 tracks.graph_solution.add_node_attr_key( @@ -91,21 +91,21 @@ def test_array_attr(self, get_tracks, ndim, with_seg): def test_values_not_list_raises(self, get_tracks, ndim, with_seg): """Non-list values raise ValueError.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) with pytest.raises(ValueError, match="must be a list"): UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": 0.9}) def test_values_length_mismatch_raises(self, get_tracks, ndim, with_seg): """List length not matching nodes length raises ValueError.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) with pytest.raises(ValueError, match="length"): UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": [0.1]}) def test_protected_attr_raises(self, get_tracks, ndim, with_seg): """Passing a protected attribute raises ValueError.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) time_key = tracks.features.time_key with pytest.raises(ValueError, match="Cannot update attribute"): diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py index c502a921..1a9a2f3b 100644 --- a/tests/user_actions/test_user_update_segmentation.py +++ b/tests/user_actions/test_user_update_segmentation.py @@ -26,7 +26,7 @@ def pixels_equal_mask(self, pixels, tracks, node_id): ) def test_user_update_seg_smaller(self, get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node_id = 3 edge = (1, 3) @@ -71,7 +71,7 @@ def test_user_update_seg_smaller(self, get_tracks, ndim): assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(0.0, abs=0.01) def test_user_update_seg_bigger(self, get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node_id = 3 edge = (1, 3) @@ -114,7 +114,7 @@ def test_user_update_seg_bigger(self, get_tracks, ndim): assert tracks.get_edge_attr(edge, iou_key) != orig_iou def test_invalid_action_with_segmentation(self, get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node_id = 1 # Paint on top of node 1 with track id 3: because of the downstream division, this @@ -162,7 +162,7 @@ def test_invalid_action_with_segmentation(self, get_tracks, ndim): # and one for updating existing node 1 def test_user_erase_seg(self, get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node_id = 3 edge = (1, 3) @@ -199,7 +199,7 @@ def test_user_erase_seg_history_size(self, get_tracks, ndim): entry. Regression test for a bug where the nested UserDeleteNode also registered itself, leaving two entries per fill and corrupting undo behavior.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node_id = 6 pixels = td_mask_to_pixels( tracks.get_mask(node_id), tracks.get_time(node_id), ndim=tracks.ndim @@ -217,7 +217,7 @@ def test_user_two_erases_then_two_undos(self, get_tracks, ndim): tracks.action_history.undo(). Reproduces bug_paint_undo: the second undo crashed because the buggy history had a duplicate UserDeleteNode entry that tried to re-add an already-restored node.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) pixels_5 = td_mask_to_pixels( tracks.get_mask(5), tracks.get_time(5), ndim=tracks.ndim ) @@ -244,7 +244,7 @@ def test_user_two_erases_then_two_undos(self, get_tracks, ndim): assert tracks.graph_solution.has_node(6) def test_user_add_seg(self, get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) # draw a new node just like node 6 but in time 3 (instead of 4) old_node_id = 6 node_id = 7 @@ -286,6 +286,6 @@ def test_user_add_seg(self, get_tracks, ndim): def test_missing_seg(get_tracks): - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) with pytest.raises(ValueError, match="Cannot update non-existing segmentation"): UserUpdateSegmentation(tracks, 0, [], 1) From baf0125026b4ae1830da262654690ceba96f71b0 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 14:52:50 -0700 Subject: [PATCH 21/39] give export_to_csv to option to export either full or solution graph --- src/funtracks/import_export/csv/_export.py | 21 +++++++++--- tests/import_export/test_csv_export.py | 38 ++++++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/src/funtracks/import_export/csv/_export.py b/src/funtracks/import_export/csv/_export.py index f76bea25..f4d42c35 100644 --- a/src/funtracks/import_export/csv/_export.py +++ b/src/funtracks/import_export/csv/_export.py @@ -20,6 +20,7 @@ def export_to_csv( outfile: Path | str, color_dict: dict[int, np.ndarray] | None = None, node_ids: set[int] | None = None, + export_full: bool = False, use_display_names: bool = False, export_seg: bool = False, seg_path: Path | str | None = None, @@ -42,6 +43,10 @@ def export_to_csv( hex colors. node_ids: Optional set of node IDs to include. If provided, only these nodes and their ancestors will be included in the output. + export_full: If True, export the full graph (every node/edge, including + soft-deleted/candidate ones with solution=False); the "solution" column is + then included so the two can be distinguished. If False (default), export + only the solution view. use_display_names: If True, use feature display names as column headers. If False (default), use raw feature keys for backward compatibility. export_seg: Whether to export the segmentation alongside the CSV. @@ -70,6 +75,10 @@ def export_to_csv( """ tracklet_key = tracks.features.tracklet_key + # Which graph to export: the full graph (incl. solution=False candidates) or just + # the solution view. Topology reads (node ids, ancestors, predecessors) go through + # this; attribute reads use the tracks helpers, which already target graph_full. + graph = tracks.graph_full if export_full else tracks.graph_solution def convert_numpy_to_python(value): """Convert numpy types to native Python types.""" @@ -122,8 +131,10 @@ def convert_numpy_to_python(value): # Skip derived features (e.g. bbox managed by mask) if feature_name in derived_keys: continue - # Skip solution — graph is already filtered to solution=True - if feature_name == "solution": + # Skip solution unless exporting the full graph: in the solution-only + # export it is always True (uninformative); in a full export it + # distinguishes solution nodes from soft-deleted/candidate ones. + if feature_name == "solution" and not export_full: continue feature_names.append(feature_name) num_values = feature_dict.get("num_values", 1) @@ -155,15 +166,15 @@ def convert_numpy_to_python(value): # Determine which nodes to export if node_ids is None: - nodes_to_keep = tracks.graph_solution.node_ids() + nodes_to_keep = graph.node_ids() else: - nodes_to_keep = filter_graph_with_ancestors(tracks.graph_solution, node_ids) + nodes_to_keep = filter_graph_with_ancestors(graph, node_ids) # Write CSV file rows: list[dict[str, Any]] = [] for node_id in nodes_to_keep: - parents = list(tracks.graph_solution.predecessors(node_id)) + parents = list(graph.predecessors(node_id)) parent_id = "" if len(parents) == 0 else parents[0] row: dict[str, Any] diff --git a/tests/import_export/test_csv_export.py b/tests/import_export/test_csv_export.py index 9bfa6cc4..ae13a721 100644 --- a/tests/import_export/test_csv_export.py +++ b/tests/import_export/test_csv_export.py @@ -284,3 +284,41 @@ def test_export_solution_to_csv_with_seg_and_node_subset( expected_vals = {node_to_label[n] for n in graph_nodes} assert unique_vals == expected_vals + + +def test_export_full_vs_solution(get_tracks, tmp_path): + """export_full=True includes soft-deleted (solution=False) nodes and surfaces the + 'Solution' column; the default exports only the solution view.""" + import pandas as pd + + from funtracks.user_actions import UserDeleteNode + + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) + # Soft-delete a leaf node: dropped from the solution view, kept in graph_full. + UserDeleteNode(tracks, 5) + sol_n = tracks.graph_solution.num_nodes() + full_n = tracks.graph_full.num_nodes() + assert full_n == sol_n + 1 + + # Default: solution view only — node 5 absent. + sol_file = tmp_path / "sol.csv" + export_to_csv(tracks, sol_file) + sol = pd.read_csv(sol_file) + assert len(sol) == sol_n + assert 5 not in set(sol["id"]) + + # export_full=True: full graph — node 5 present. + full_file = tmp_path / "full.csv" + export_to_csv(tracks, full_file, export_full=True) + full = pd.read_csv(full_file) + assert len(full) == full_n + assert 5 in set(full["id"]) + + # With display names, the full export carries a 'Solution' column distinguishing + # the soft-deleted node (False) from the rest (True). + disp_file = tmp_path / "full_disp.csv" + export_to_csv(tracks, disp_file, export_full=True, use_display_names=True) + disp = pd.read_csv(disp_file) + assert "Solution" in disp.columns + assert not bool(disp.loc[disp["ID"] == 5, "Solution"].iloc[0]) + assert bool(disp.loc[disp["ID"] == 1, "Solution"].iloc[0]) From 3279124767ca232b8e716925b84a36a14794a329 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 14:59:51 -0700 Subject: [PATCH 22/39] make TracksBuilder.build easier --- src/funtracks/import_export/_tracks_builder.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/funtracks/import_export/_tracks_builder.py b/src/funtracks/import_export/_tracks_builder.py index 5566bc58..8c22276a 100644 --- a/src/funtracks/import_export/_tracks_builder.py +++ b/src/funtracks/import_export/_tracks_builder.py @@ -782,23 +782,14 @@ def build( ) else: # The builder always produces a solution, so declare tracklet/lineage - # intent to register a TrackAnnotator. Use the attr name present on the - # constructed graph; fall back to the default keys (computed from scratch) - # when the source carried no track ids. - node_keys = graph.node_attr_keys() - tracklet_attr = next( - (k for k in ("tracklet_id", "track_id") if k in node_keys), - "tracklet_id", - ) - lineage_attr = next( - (k for k in ("lineage_id",) if k in node_keys), "lineage_id" - ) + # intent to register a TrackAnnotator. Tracks.__init__ auto-detects whether + # these attrs already exist on the graph (activate) or need computing. tracks = Tracks( graph=graph, pos_attr="pos", time_attr="t", - tracklet_attr=tracklet_attr, - lineage_attr=lineage_attr, + tracklet_attr="tracklet_id", + lineage_attr="lineage_id", ndim=self.ndim, scale=scale, ) From f9fa4a6e6b9965d8209384d3d2534c3a9432d428 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 15:06:36 -0700 Subject: [PATCH 23/39] load_v1_tracks no longer makes distinction between Tracks and SolutionTracks --- src/funtracks/import_export/_v1_format.py | 15 ++++++--------- tests/import_export/test_internal_format.py | 4 ++-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/funtracks/import_export/_v1_format.py b/src/funtracks/import_export/_v1_format.py index 15af0a8f..1bff2c12 100644 --- a/src/funtracks/import_export/_v1_format.py +++ b/src/funtracks/import_export/_v1_format.py @@ -23,9 +23,7 @@ ATTRS_FILE = "attrs.json" -def load_v1_tracks( - directory: Path, seg_required: bool = False, solution: bool = False -) -> Tracks: +def load_v1_tracks(directory: Path, seg_required: bool = False) -> Tracks: """Load a Tracks object from the given directory. Looks for files in the format generated by Tracks.save. @@ -35,8 +33,6 @@ def load_v1_tracks( directory (Path): The directory containing tracks to load seg_required (bool, optional): If true, raises a FileNotFoundError if the segmentation file is not present in the directory. Defaults to False. - solution (bool, optional): If true, returns a Tracks object, otherwise - returns a normal Tracks object. Defaults to False. Returns: Tracks: A tracks object loaded from the given directory @@ -115,10 +111,11 @@ def load_v1_tracks( # in v2.0. Import at runtime to avoid circular dependency. from ..data_model import Tracks - # A solution save must come back with track ids. When the stored attrs don't carry - # a FeatureDict (older saves), declare tracklet/lineage intent explicitly so a - # TrackAnnotator is registered. - if solution and "features" not in attrs: + # v1 saves are always solutions, so they must come back with track ids. When the + # stored attrs don't carry a FeatureDict (older saves), declare tracklet/lineage + # intent explicitly so a TrackAnnotator is registered. Newer saves carry a + # FeatureDict that already encodes the tracklet/lineage keys. + if "features" not in attrs: attrs.setdefault("tracklet_attr", "tracklet_id") attrs.setdefault("lineage_attr", "lineage_id") diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py index bbba2246..636de386 100644 --- a/tests/import_export/test_internal_format.py +++ b/tests/import_export/test_internal_format.py @@ -27,7 +27,7 @@ def test_save_load( f"tests/data/format_v1/test_save_load_{prefill_track_ids}_{ndim}_{with_seg}_0" ) - loaded = load_v1_tracks(data_path, solution=prefill_track_ids) + loaded = load_v1_tracks(data_path) assert loaded.ndim == ndim # Check feature keys and important properties match (allow tuple vs list diff) assert loaded.features.time_key == tracks.features.time_key @@ -118,7 +118,7 @@ def test_load_without_features(tmp_path, graph_2d_with_segmentation): shutil.copytree(reference_path, tracks_path) # Load the original data first to verify it loads correctly - load_v1_tracks(tracks_path, solution=True) + load_v1_tracks(tracks_path) # Modify the copy to test backward compatibility attrs_path = tracks_path / "attrs.json" From e70d9462ab421d782c6cb94f78472b9fdfaa6e96 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 16:45:16 -0700 Subject: [PATCH 24/39] make graph_full the MAIN graph and graph_solution merely a view of it, full invert --- docs/features.md | 17 +++--- docs/import-flow.md | 2 +- .../candidate_graph/compute_graph.py | 8 +-- src/funtracks/candidate_graph/iou.py | 4 +- src/funtracks/candidate_graph/utils.py | 26 ++++----- src/funtracks/data_model/tracks.py | 57 +++++++++---------- .../import_export/_import_segmentation.py | 6 +- .../import_export/_tracks_builder.py | 24 +++----- src/funtracks/import_export/_utils.py | 2 +- src/funtracks/import_export/_validation.py | 2 +- src/funtracks/import_export/geff/_import.py | 2 +- src/funtracks/utils/__init__.py | 4 +- src/funtracks/utils/tracksdata_utils.py | 49 +++++++--------- tests/actions/test_action_history.py | 4 +- tests/actions/test_add_delete_edge.py | 4 +- tests/actions/test_add_delete_nodes.py | 18 ++++-- tests/conftest.py | 28 ++++----- tests/data_model/test_solution_tracks.py | 4 +- tests/data_model/test_tracks.py | 10 ++-- tests/import_export/test_import_from_geff.py | 23 ++++---- .../import_export/test_import_segmentation.py | 10 ++-- tests/utils/test_tracksdata_utils.py | 17 +++--- 22 files changed, 155 insertions(+), 166 deletions(-) diff --git a/docs/features.md b/docs/features.md index 0faae01d..ddfeb217 100644 --- a/docs/features.md +++ b/docs/features.md @@ -54,7 +54,7 @@ An abstract base class for components that compute and update features on a grap |-----------|---------|--------------|-------------------|---------------| | **RegionpropsAnnotator** | Extracts node features from segmentation using scikit-image's `regionprops` | `segmentation` must not be `None` | `pos`, `area`, `ellipse_axis_radii`, `circularity`, `perimeter` | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.RegionpropsAnnotator) | | **EdgeAnnotator** | Computes edge features based on segmentation overlap between consecutive time frames | `segmentation` must not be `None` | `iou` (Intersection over Union) | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.EdgeAnnotator) | -| **TrackAnnotator** | Computes tracklet and lineage IDs for SolutionTracks | Must be used with `SolutionTracks` (binary tree structure) | `tracklet_id`, `lineage_id` | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.TrackAnnotator) | +| **TrackAnnotator** | Computes tracklet and lineage IDs | Requires `tracks.features.tracklet_key` to be set (i.e. a tracking solution with a binary-tree structure) | `tracklet_id`, `lineage_id` | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.TrackAnnotator) | ### 5. AnnotatorRegistry @@ -149,7 +149,8 @@ classDiagram } class Tracks { - +graph: td.graph.GraphView + +graph_full: td.graph.BaseGraph + +graph_solution: td.graph.GraphView +segmentation: ndarray|None +features: FeatureDict +annotators: AnnotatorRegistry @@ -199,10 +200,10 @@ provided: ```python # Uses default: tracklet_key="tracklet_id" -tracks = SolutionTracks(graph=graph) +tracks = Tracks(graph=graph, tracklet_attr="tracklet_id") # Uses custom attribute name -tracks = SolutionTracks(graph=graph, tracklet_attr="my_track_col") +tracks = Tracks(graph=graph, tracklet_attr="my_track_col") ``` Custom attribute names are also supported through `FeatureDict`: @@ -213,7 +214,7 @@ fd = FeatureDict( tracklet_key="my_track_col", lineage_key="my_lineage", ) -tracks = SolutionTracks(graph=graph, features=fd) +tracks = Tracks(graph=graph, features=fd) ``` When a `FeatureDict` is provided, the `tracklet_attr`/`pos_attr`/`time_attr` arguments @@ -229,7 +230,7 @@ TrackAnnotator(tracks, tracklet_key="my_track", lineage_key="my_lineage") RegionpropsAnnotator(tracks, pos_key="coordinates") ``` -When constructing `Tracks` or `SolutionTracks` directly, you have full control over +When constructing `Tracks` directly, you have full control over which attribute names are used. **Through the import path** (`tracks_from_df`, `import_from_geff`), computed features @@ -286,11 +287,11 @@ tracks = tracks_from_df(df, segmentation=seg) **Scenario 2: Creating tracks from raw segmentation** ```python -from funtracks.utils import create_empty_graphview_graph +from funtracks.utils import create_empty_graph from funtracks.data_model import Tracks # Create empty graph and add nodes -graph = create_empty_graphview_graph() +graph = create_empty_graph() graph.add_node(index=1, attrs={"t": 0}) tracks = Tracks(graph, segmentation=seg) # Auto-detection: pos, area don't exist → compute them from segmentation diff --git a/docs/import-flow.md b/docs/import-flow.md index 1cec7319..5710248b 100644 --- a/docs/import-flow.md +++ b/docs/import-flow.md @@ -25,7 +25,7 @@ graph LR Validate["validate
check graph structure"] ConstructGraph["construct_graph
build tracksdata graph"] HandleSeg["handle_segmentation
load & relabel if needed"] - CreateTracks[Create SolutionTracks] + CreateTracks[Create Tracks] EnableFeatures["enable_features
register & compute"] ValidateMap --> LoadSource diff --git a/src/funtracks/candidate_graph/compute_graph.py b/src/funtracks/candidate_graph/compute_graph.py index a0efa378..e77ee461 100644 --- a/src/funtracks/candidate_graph/compute_graph.py +++ b/src/funtracks/candidate_graph/compute_graph.py @@ -15,7 +15,7 @@ def compute_graph_from_seg( iou: bool = False, scale: list[float] | None = None, t_start: int = 0, -) -> td.graph.GraphView: +) -> td.graph.BaseGraph: """Construct a candidate graph from a segmentation array. Nodes are placed at the centroid of each segmentation and edges are added for all nodes in adjacent frames within max_edge_distance. @@ -37,7 +37,7 @@ def compute_graph_from_seg( time values. Defaults to 0. Returns: - td.graph.GraphView: A candidate graph that can be passed to the motile solver + td.graph.BaseGraph: A candidate graph that can be passed to the motile solver """ # add nodes (including mask and bbox in the same bulk_add_nodes call) cand_graph, node_frame_dict = nodes_from_segmentation( @@ -73,7 +73,7 @@ def compute_graph_from_points_list( points_list: np.ndarray, max_edge_distance: float, scale: list[float] | None = None, -) -> td.graph.GraphView: +) -> td.graph.BaseGraph: """Construct a candidate graph from a points list. Args: @@ -88,7 +88,7 @@ def compute_graph_from_points_list( isotropic. Returns: - td.graph.GraphView: A candidate graph that can be passed to the motile solver. + td.graph.BaseGraph: A candidate graph that can be passed to the motile solver. """ # add nodes cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale) diff --git a/src/funtracks/candidate_graph/iou.py b/src/funtracks/candidate_graph/iou.py index 0ca2d315..97dd6e71 100644 --- a/src/funtracks/candidate_graph/iou.py +++ b/src/funtracks/candidate_graph/iou.py @@ -80,7 +80,7 @@ def _get_iou_dict(segmentation, multiseg=False) -> dict[int, dict[int, float]]: def add_iou( - cand_graph: td.graph.GraphView, + cand_graph: td.graph.BaseGraph, segmentation: np.ndarray, node_frame_dict: dict[int, list[int]] | None = None, multiseg=False, @@ -92,7 +92,7 @@ def add_iou( add IOU to an existing graph after the fact. Args: - cand_graph (td.graph.GraphView): Candidate graph with nodes and edges already + cand_graph (td.graph.BaseGraph): Candidate graph with nodes and edges already populated. segmentation (np.ndarray): segmentation that was used to create cand_graph. Has shape ([h], t, [z], y, x), where h is the number of hypotheses if diff --git a/src/funtracks/candidate_graph/utils.py b/src/funtracks/candidate_graph/utils.py index 68948515..f1fee38c 100644 --- a/src/funtracks/candidate_graph/utils.py +++ b/src/funtracks/candidate_graph/utils.py @@ -9,7 +9,7 @@ from tqdm import tqdm from tracksdata.nodes import Mask -from ..utils.tracksdata_utils import create_empty_graphview_graph +from ..utils.tracksdata_utils import create_empty_graph logger = logging.getLogger(__name__) @@ -19,7 +19,7 @@ def nodes_from_segmentation( scale: list[float] | None = None, mask: bool = True, t_start: int = 0, -) -> tuple[td.graph.GraphView, dict[int, list[Any]]]: +) -> tuple[td.graph.BaseGraph, dict[int, list[Any]]]: """Extract candidate nodes from a segmentation. Returns a tracksdata graph with only nodes, and also a dictionary from frames to node_ids for efficient edge adding. @@ -50,7 +50,7 @@ def nodes_from_segmentation( time values. Defaults to 0. Returns: - tuple[td.graph.GraphView, dict[int, list[Any]]]: A candidate graph with only + tuple[td.graph.BaseGraph, dict[int, list[Any]]]: A candidate graph with only nodes, and a mapping from time frames to node ids. """ logger.debug("Extracting nodes from segmentation") @@ -66,7 +66,7 @@ def nodes_from_segmentation( ) node_attributes = ["pos", "area", "mask", "bbox"] if mask else ["pos", "area"] - cand_graph = create_empty_graphview_graph( + cand_graph = create_empty_graph( node_attributes=node_attributes, position_attrs=["pos"], ndim=segmentation.ndim, @@ -115,7 +115,7 @@ def nodes_from_segmentation( def nodes_from_points_list( points_list: np.ndarray, scale: list[float] | None = None, -) -> tuple[td.graph.GraphView, dict[int, list[Any]]]: +) -> tuple[td.graph.BaseGraph, dict[int, list[Any]]]: """Extract candidate nodes from a list of points. Uses the index of the point in the list as its unique id. Returns a tracksdata graph with only nodes, and also a dictionary from frames to @@ -130,7 +130,7 @@ def nodes_from_points_list( implies the data is isotropic. Returns: - tuple[td.graph.GraphView, dict[int, list[Any]]]: A candidate graph with only + tuple[td.graph.BaseGraph, dict[int, list[Any]]]: A candidate graph with only nodes, and a mapping from time frames to node ids. """ logger.info("Extracting nodes from points list") @@ -143,7 +143,7 @@ def nodes_from_points_list( ) points_list = points_list * np.array(scale) - cand_graph = create_empty_graphview_graph( + cand_graph = create_empty_graph( node_attributes=["pos"], position_attrs=["pos"], ndim=ndim, @@ -168,11 +168,11 @@ def nodes_from_points_list( return cand_graph, node_frame_dict -def _compute_node_frame_dict(cand_graph: td.graph.GraphView) -> dict[int, list[Any]]: +def _compute_node_frame_dict(cand_graph: td.graph.BaseGraph) -> dict[int, list[Any]]: """Compute dictionary from time frames to node ids for candidate graph. Args: - cand_graph (td.graph.GraphView): A tracksdata graph + cand_graph (td.graph.BaseGraph): A tracksdata graph Returns: dict[int, list[Any]]: A mapping from time frames to lists of node ids. @@ -187,12 +187,12 @@ def _compute_node_frame_dict(cand_graph: td.graph.GraphView) -> dict[int, list[A return node_frame_dict -def create_kdtree(cand_graph: td.graph.GraphView, node_ids: list[Any]) -> KDTree: +def create_kdtree(cand_graph: td.graph.BaseGraph, node_ids: list[Any]) -> KDTree: """Create a kdtree with the given nodes from the candidate graph. Will fail if provided node ids are not in the candidate graph. Args: - cand_graph (td.graph.GraphView): A candidate graph + cand_graph (td.graph.BaseGraph): A candidate graph node_ids (list[Any]): The nodes within the candidate graph to include in the KDTree. Useful for limiting to one time frame. Must be a list (not a generic iterable) to preserve order for @@ -212,7 +212,7 @@ def create_kdtree(cand_graph: td.graph.GraphView, node_ids: list[Any]) -> KDTree def add_cand_edges( - cand_graph: td.graph.GraphView, + cand_graph: td.graph.BaseGraph, max_edge_distance: float, node_frame_dict: None | dict[int, list[Any]] = None, iou_dict: dict[int, dict[int, float]] | None = None, @@ -221,7 +221,7 @@ def add_cand_edges( frames that are closer than max_edge_distance. Args: - cand_graph (td.graph.GraphView): Candidate graph with only nodes populated. + cand_graph (td.graph.BaseGraph): Candidate graph with only nodes populated. Will be modified in-place to add edges. max_edge_distance (float): Maximum distance that objects can travel between frames. All nodes within this distance in adjacent frames will by connected diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 7a94e599..ad4f6cc8 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -53,10 +53,11 @@ class Tracks: position attribute. Edges in the graph represent links across time. Attributes: - graph_solution (td.graph.GraphView): A solution=True view of the full graph, - with nodes representing detections and edges representing links across time. - graph_full (td.graph.GraphView): The full graph: every node/edge ever known, - including soft-deleted (solution=False) ones. Backs graph_solution. + graph_full (td.graph.BaseGraph): The full graph (first-class): every node/edge + ever known, including soft-deleted (solution=False) candidates. Nodes + represent detections, edges represent links across time. + graph_solution (td.graph.GraphView): A solution==True view derived from + graph_full; the user-visible tracking solution. features (FeatureDict): Dictionary of features tracked on graph nodes/edges. annotators (AnnotatorRegistry): List of annotators that compute features. scale (list[float] | None): How much to scale each dimension by, including time. @@ -68,7 +69,7 @@ class Tracks: def __init__( self, - graph: td.graph.GraphView, + graph: td.graph.BaseGraph, time_attr: str | None = None, pos_attr: str | tuple[str, ...] | list[str] | None = None, tracklet_attr: str | None = None, @@ -81,8 +82,9 @@ def __init__( """Initialize a Tracks object. Args: - graph (td.graph.GraphView): tracksdata directed graph with nodes as detections - and edges as links. + graph (td.graph.BaseGraph): the full base graph (graph_full) with nodes as + detections and edges as links. Must be a base graph, not a view; + graph_solution is built internally as its solution==True view. time_attr (str | None): Graph attribute name for time. Defaults to "time" if None. pos_attr (str | tuple[str, ...] | list[str] | None): Graph attribute @@ -108,20 +110,23 @@ def __init__( _segmentation (GraphArrayView | None): Internal parameter for reusing an existing GraphArrayView instance. Not intended for public use. """ - self.graph_solution = graph - # Depth-1 invariant: graph_full is defined as graph_solution._root (one hop), - # so the root must be a base graph, NOT itself a GraphView. A nested view - # (base -> crop -> solution) would silently make graph_full mean "the crop" - # instead of "every node ever known", breaking the AddNode revive-vs-new check - # (graph_full.has_node) and any annotator registered on graph_full. Fail loudly - # rather than corrupt data if cropping is ever introduced upstream. - if isinstance(graph._root, td.graph.GraphView): + # graph_full is the first-class object: the base graph holding every node/edge + # ever known, including soft-deleted (solution=False) candidates. graph_solution + # is derived from it as a solution==True view. + if isinstance(graph, td.graph.GraphView): raise ValueError( - "Tracks requires graph_solution to be a direct view of a base graph " - "(graph_solution._root must not itself be a GraphView). A nested view " - "chain (e.g. a crop of a crop) violates the depth-1 assumption that " - "graph_full = graph_solution._root is the full graph." + "Tracks requires the full base graph (graph_full), not a GraphView. " + "graph_solution is built internally as a solution==True view of it." ) + if "solution" not in graph.node_attr_keys(): + graph.add_node_attr_key("solution", default_value=True, dtype=pl.Boolean) + if "solution" not in graph.edge_attr_keys(): + graph.add_edge_attr_key("solution", default_value=True, dtype=pl.Boolean) + self.graph_full = graph + self.graph_solution = graph.filter( + td.NodeAttr("solution") == True, # noqa: E712 + td.EdgeAttr("solution") == True, # noqa: E712 + ).subgraph() if _segmentation is not None: # Reuse provided segmentation instance (internal use only) self.segmentation = _segmentation @@ -133,8 +138,10 @@ def __init__( seg_shape = graph.metadata.get("segmentation_shape") if seg_shape is not None: try: + # Render the segmentation from the solution view so soft-deleted + # nodes drop out of the array, mirroring the user-visible graph. array_view = GraphArrayView( - graph=graph, + graph=self.graph_solution, shape=seg_shape, attr_key="node_id", offset=0, @@ -196,16 +203,6 @@ def __init__( # and computed. A provided FeatureDict that omitted them is completed here. self._ensure_track_features() - @property - def graph_full(self) -> td.graph.GraphView: - """The full graph: every node/edge ever known, including soft-deleted - (solution=False) ones. `graph_solution` is a solution=True view of this. - - Backed by the solution view's root, so the two can never drift: rebuilding - graph_solution via graph_full.filter(...).subgraph() keeps the same root. - """ - return self.graph_solution._root - def _get_feature_set( self, time_attr: str | None, diff --git a/src/funtracks/import_export/_import_segmentation.py b/src/funtracks/import_export/_import_segmentation.py index e1545d94..db04c8a3 100644 --- a/src/funtracks/import_export/_import_segmentation.py +++ b/src/funtracks/import_export/_import_segmentation.py @@ -46,11 +46,11 @@ def load_segmentation(segmentation: Path | np.ndarray | da.Array) -> da.Array: def relabel_segmentation( seg_array: da.Array | np.ndarray, - graph: td.graph.GraphView, + graph: td.graph.BaseGraph, node_ids: ArrayLike, seg_ids: ArrayLike, time_values: ArrayLike, -) -> tuple[np.ndarray, td.graph.GraphView]: +) -> tuple[np.ndarray, td.graph.BaseGraph]: """Relabel segmentation from seg_id to node_id. Handles the case where node_id 0 exists by offsetting all node IDs by 1, @@ -58,7 +58,7 @@ def relabel_segmentation( Args: seg_array: Segmentation array (dask or numpy) - graph: tracksdata GraphView (will be relabeled if node_id 0 exists) + graph: tracksdata base graph (will be relabeled if node_id 0 exists) node_ids: Array of node IDs seg_ids: Array of segmentation IDs corresponding to each node time_values: Array of time values for each node diff --git a/src/funtracks/import_export/_tracks_builder.py b/src/funtracks/import_export/_tracks_builder.py index 8c22276a..2bfdd0f8 100644 --- a/src/funtracks/import_export/_tracks_builder.py +++ b/src/funtracks/import_export/_tracks_builder.py @@ -39,7 +39,7 @@ ) from funtracks.utils.tracksdata_utils import ( add_masks_and_bboxes_to_graph, - create_empty_graphview_graph, + create_empty_graph, ) if TYPE_CHECKING: @@ -415,7 +415,7 @@ def construct_graph( self, node_name_map: dict[str, str | list[str]] | None = None, database: str | None = None, - ) -> td.graph.GraphView: + ) -> td.graph.BaseGraph: """Construct Tracksdata graph from validated InMemoryGeff data. Common logic shared across all formats. @@ -427,7 +427,7 @@ def construct_graph( If None (default), an in-memory/temp graph is used. Returns: - Tracksdata GraphView with standard keys + Tracksdata base graph with standard keys Raises: ValueError: If data not loaded or validated @@ -481,7 +481,7 @@ def construct_graph( default_value = 0 node_default_values.append(default_value) - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=list(self.in_memory_geff["node_props"].keys()), edge_attributes=list(self.in_memory_geff["edge_props"].keys()), node_default_values=node_default_values, @@ -536,23 +536,17 @@ def construct_graph( if self.TIME_ATTR != "t": graph.remove_node_attr_key(self.TIME_ATTR) - # create_empty_graphview_graph returns a filtered view, but that view is - # a snapshot at filter time; nodes/edges added afterwards (potentially with - # solution=False) bypass the filter. Re-filter the populated root so - # solution=False rows are actually excluded. - graph = graph._root.filter( - td.NodeAttr("solution") == True, # noqa: E712 - td.EdgeAttr("solution") == True, # noqa: E712 - ).subgraph() - + # Return the full base graph. Tracks builds the solution==True view internally, + # so solution=False candidates added during population stay in graph_full and + # are excluded from graph_solution by Tracks, not here. return graph def handle_segmentation( self, - graph: td.graph.GraphView, + graph: td.graph.BaseGraph, segmentation: Path | np.ndarray | None, scale: list[float] | None, - ) -> tuple[np.ndarray | None, list[float] | None, td.graph.GraphView]: + ) -> tuple[np.ndarray | None, list[float] | None, td.graph.BaseGraph]: """Load, validate, and optionally relabel segmentation. Common logic shared across all formats. diff --git a/src/funtracks/import_export/_utils.py b/src/funtracks/import_export/_utils.py index 334cf42d..2f8d2feb 100644 --- a/src/funtracks/import_export/_utils.py +++ b/src/funtracks/import_export/_utils.py @@ -71,7 +71,7 @@ def infer_dtype_from_array(arr: ArrayLike) -> ValueType: def filter_graph_with_ancestors( - graph: td.graph.GraphView, nodes_to_keep: set[int] + graph: td.graph.BaseGraph, nodes_to_keep: set[int] ) -> list[int]: """Filter a graph to keep only the nodes in `nodes_to_keep` and their ancestors. diff --git a/src/funtracks/import_export/_validation.py b/src/funtracks/import_export/_validation.py index 415cf90c..27cc028c 100644 --- a/src/funtracks/import_export/_validation.py +++ b/src/funtracks/import_export/_validation.py @@ -22,7 +22,7 @@ def validate_graph_seg_match( - graph: td.graph.GraphView, + graph: td.graph.BaseGraph, segmentation: da.Array, scale: list[float], position_attr: list[str], diff --git a/src/funtracks/import_export/geff/_import.py b/src/funtracks/import_export/geff/_import.py index 4a7a4da1..8cd6ce9d 100644 --- a/src/funtracks/import_export/geff/_import.py +++ b/src/funtracks/import_export/geff/_import.py @@ -243,7 +243,7 @@ def construct_graph( self, node_name_map: dict[str, str | list[str]] | None = None, database: str | None = None, - ) -> td.graph.GraphView: + ) -> td.graph.BaseGraph: """Construct graph and prepare embedded segmentation data. The GEFF format serialises mask data as plain numeric arrays (zarr diff --git a/src/funtracks/utils/__init__.py b/src/funtracks/utils/__init__.py index c0ff3cd2..8c7867f4 100644 --- a/src/funtracks/utils/__init__.py +++ b/src/funtracks/utils/__init__.py @@ -10,10 +10,10 @@ setup_zarr_array, setup_zarr_group, ) -from .tracksdata_utils import create_empty_graphview_graph +from .tracksdata_utils import create_empty_graph __all__ = [ - "create_empty_graphview_graph", + "create_empty_graph", "detect_zarr_spec_version", "get_store_path", "is_zarr_v3", diff --git a/src/funtracks/utils/tracksdata_utils.py b/src/funtracks/utils/tracksdata_utils.py index 9291ec55..3b1f3887 100644 --- a/src/funtracks/utils/tracksdata_utils.py +++ b/src/funtracks/utils/tracksdata_utils.py @@ -67,7 +67,7 @@ def to_polars_dtype(dtype_or_value: str | Any) -> pl.DataType: raise ValueError(f"Unsupported type: {type(dtype_or_value)}") -def create_empty_graphview_graph( +def create_empty_graph( node_attributes: list[str] | None = None, edge_attributes: list[str] | None = None, node_default_values: list[Any] | None = None, @@ -75,9 +75,9 @@ def create_empty_graphview_graph( database: str | None = None, position_attrs: list[str] | None = None, ndim: int = 3, -) -> td.graph.GraphView: +) -> td.graph.BaseGraph: """ - Create an empty tracksdata GraphView with standard node and edge attributes. + Create an empty tracksdata base graph with standard node and edge attributes. Parameters ---------- node_attributes : list[str] | None @@ -102,8 +102,9 @@ def create_empty_graphview_graph( Returns ------- - td.graph.GraphView - An empty tracksdata GraphView with standard node and edge attributes. + td.graph.BaseGraph + An empty tracksdata base graph with standard node and edge attributes + (including a `solution` flag). Tracks builds the solution==True view from it. """ if position_attrs is None: position_attrs = ["pos"] @@ -187,12 +188,8 @@ def create_empty_graphview_graph( if "solution" not in graph_td.edge_attr_keys(): graph_td.add_edge_attr_key("solution", default_value=True, dtype=pl.Boolean) - graph_td_sub = graph_td.filter( - td.NodeAttr("solution") == True, # noqa: E712 - td.EdgeAttr("solution") == True, # noqa: E712 - ).subgraph() - - return graph_td_sub + # Return the full base graph; Tracks builds the solution==True view internally. + return graph_td def assert_node_attrs_equal_with_masks( @@ -202,8 +199,8 @@ def assert_node_attrs_equal_with_masks( Fully compare the content of two graphs (node attributes and Masks) """ - if isinstance(object1, td.graph.GraphView) and ( - isinstance(object2, td.graph.GraphView) + if isinstance(object1, td.graph.BaseGraph) and ( + isinstance(object2, td.graph.BaseGraph) ): node_attrs1 = object1.node_attrs() node_attrs2 = object2.node_attrs() @@ -385,21 +382,21 @@ def segmentation_to_masks( def add_masks_and_bboxes_to_graph( - graph: td.graph.GraphView, + graph: td.graph.BaseGraph, segmentation: np.ndarray, -) -> td.graph.GraphView: +) -> td.graph.BaseGraph: """Add mask and bbox attributes to graph nodes from segmentation. Parameters ---------- - graph : td.graph.GraphView + graph : td.graph.BaseGraph Graph to add attributes to segmentation : np.ndarray Segmentation array of shape (T, Z, Y, X) or (T, Y, X) Returns ------- - td.graph.GraphView + td.graph.BaseGraph Graph with 'mask' and 'bbox' attributes added to nodes """ @@ -480,14 +477,11 @@ def td_relabel_nodes(graph, mapping: dict[int, int]) -> td.graph.IndexedRXGraph: } new_graph.add_edge(source_id, target_id, attrs) - new_graph_sub = new_graph.filter( - td.NodeAttr("solution") == True, # noqa: E712 - td.EdgeAttr("solution") == True, # noqa: E712 - ).subgraph() - return new_graph_sub + # Return the full base graph; Tracks builds the solution==True view internally. + return new_graph -def convert_graph_nx_to_td(graph_nx: nx.DiGraph) -> td.graph.GraphView: +def convert_graph_nx_to_td(graph_nx: nx.DiGraph) -> td.graph.BaseGraph: """Convert a NetworkX DiGraph to a tracksdata graph. Args: @@ -584,10 +578,5 @@ def convert_graph_nx_to_td(graph_nx: nx.DiGraph) -> td.graph.GraphView: attrs_copy["solution"] = True graph_td.add_edge(source_id, target_id, attrs_copy) - # Create subgraph (GraphView) with only solution nodes and edges - graph_td_sub = graph_td.filter( - td.NodeAttr("solution") == True, # noqa: E712 - td.EdgeAttr("solution") == True, # noqa: E712 - ).subgraph() - - return graph_td_sub + # Return the full base graph; Tracks builds the solution==True view internally. + return graph_td diff --git a/tests/actions/test_action_history.py b/tests/actions/test_action_history.py index 7502a095..96fd5218 100644 --- a/tests/actions/test_action_history.py +++ b/tests/actions/test_action_history.py @@ -1,14 +1,14 @@ from funtracks.actions import AddNode from funtracks.actions.action_history import ActionHistory from funtracks.data_model import Tracks -from funtracks.utils.tracksdata_utils import create_empty_graphview_graph +from funtracks.utils.tracksdata_utils import create_empty_graph # https://github.com/zaboople/klonk/blob/master/TheGURQ.md def test_action_history(): history = ActionHistory() - empty_graph = create_empty_graphview_graph( + empty_graph = create_empty_graph( node_attributes=["track_id", "pos"], edge_attributes=[], ) diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py index a552dc7e..212790ae 100644 --- a/tests/actions/test_add_delete_edge.py +++ b/tests/actions/test_add_delete_edge.py @@ -10,7 +10,7 @@ ) from funtracks.data_model import Tracks from funtracks.features import FeatureDict, LineageID, Position, Time, TrackletID -from funtracks.utils.tracksdata_utils import create_empty_graphview_graph +from funtracks.utils.tracksdata_utils import create_empty_graph iou_key = "iou" @@ -179,7 +179,7 @@ def test_add_edge_with_unregistered_edge_attr(tmp_path): # Build a graph with "custom_score" on every edge. # This mirrors what the motile solver does: it writes edge attributes directly # to the graph without going through tracks.add_feature(). - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=["pos", "track_id", "lineage_id"], edge_attributes=["custom_score"], database=db_path, diff --git a/tests/actions/test_add_delete_nodes.py b/tests/actions/test_add_delete_nodes.py index 9d375696..64990ebe 100644 --- a/tests/actions/test_add_delete_nodes.py +++ b/tests/actions/test_add_delete_nodes.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import tracksdata as td from numpy.testing import assert_array_almost_equal, assert_array_equal from polars.testing import assert_frame_equal from tracksdata.array import GraphArrayView @@ -10,7 +11,7 @@ ) from funtracks.utils.tracksdata_utils import ( assert_node_attrs_equal_with_masks, - create_empty_graphview_graph, + create_empty_graph, ) from ..conftest import make_2d_disk_mask, make_3d_sphere_mask @@ -32,13 +33,19 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): tracks.features.position_key, ] edge_attributes = ["iou"] if with_seg else [] - empty_graph = create_empty_graphview_graph( + empty_graph = create_empty_graph( node_attributes=node_attributes + (["area", "bbox", "mask"] if with_seg else []), edge_attributes=edge_attributes, ndim=ndim, ) empty_seg = np.zeros_like(tracks.segmentation) if with_seg else None - tracks.graph_solution = empty_graph + # Reset the tracks onto the empty base graph, mirroring Tracks.__init__: graph_full + # is the base graph, graph_solution its solution==True view. + tracks.graph_full = empty_graph + tracks.graph_solution = empty_graph.filter( + td.NodeAttr("solution") == True, # noqa: E712 + td.EdgeAttr("solution") == True, # noqa: E712 + ).subgraph() segmentation_shape = (5, 100, 100) if ndim == 3 else (5, 100, 100, 100) tracks.segmentation = ( GraphArrayView( @@ -96,9 +103,10 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): check_dtypes=False, ) - # Invert the action to delete all the nodes + # Invert the action to delete all the nodes. They are soft-deleted, so they remain + # in graph_full (empty_graph) with solution=False but drop out of the solution view. del_nodes = action.inverse() - assert set(tracks.graph_solution.node_ids()) == set(empty_graph.node_ids()) + assert set(tracks.graph_solution.node_ids()) == set() if with_seg: assert_array_almost_equal(tracks.segmentation, empty_seg) diff --git a/tests/conftest.py b/tests/conftest.py index 5e3a1cfc..1224486a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ from tracksdata.nodes import Mask from funtracks.utils.tracksdata_utils import ( - create_empty_graphview_graph, + create_empty_graph, ) if TYPE_CHECKING: @@ -134,7 +134,7 @@ def _make_graph( with_iou: bool = False, with_masks: bool = False, database: str | None = None, -) -> td.graph.GraphView: +) -> td.graph.BaseGraph: """Generate a test graph with configurable features. Args: @@ -172,7 +172,7 @@ def _make_graph( node_attributes.append(td.DEFAULT_ATTR_KEYS.BBOX) node_default_values.append(0.0) - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=node_attributes, node_default_values=node_default_values, edge_attributes=edge_attributes, @@ -288,28 +288,28 @@ def _make_graph( @pytest.fixture -def graph_clean(tmp_path) -> td.graph.GraphView: +def graph_clean(tmp_path) -> td.graph.BaseGraph: """Base graph with only time - no positions or computed features.""" db_path = str(tmp_path / "graph_clean.db") return _make_graph(ndim=3, database=db_path) @pytest.fixture -def graph_2d_with_position(tmp_path) -> td.graph.GraphView: +def graph_2d_with_position(tmp_path) -> td.graph.BaseGraph: """Graph with 2D positions - for Tracks without segmentation.""" db_path = str(tmp_path / "graph_2d_position.db") return _make_graph(ndim=3, with_pos=True, database=db_path) @pytest.fixture -def graph_2d_with_track_id(tmp_path) -> td.graph.GraphView: +def graph_2d_with_track_id(tmp_path) -> td.graph.BaseGraph: """Graph with 2D positions and track_id - for Tracks without segmentation.""" db_path = str(tmp_path / "graph_2d_track_id.db") return _make_graph(ndim=3, with_pos=True, with_track_id=True, database=db_path) @pytest.fixture -def graph_2d_with_segmentation(tmp_path) -> td.graph.GraphView: +def graph_2d_with_segmentation(tmp_path) -> td.graph.BaseGraph: """Graph with segmentation (masks/bboxes) and all computed features.""" db_path = str(tmp_path / "graph_2d_segmentation.db") return _make_graph( @@ -324,21 +324,21 @@ def graph_2d_with_segmentation(tmp_path) -> td.graph.GraphView: @pytest.fixture -def graph_3d_with_position(tmp_path) -> td.graph.GraphView: +def graph_3d_with_position(tmp_path) -> td.graph.BaseGraph: """Graph with 3D positions - for Tracks without segmentation.""" db_path = str(tmp_path / "graph_3d_position.db") return _make_graph(ndim=4, with_pos=True, database=db_path) @pytest.fixture -def graph_3d_with_track_id(tmp_path) -> td.graph.GraphView: +def graph_3d_with_track_id(tmp_path) -> td.graph.BaseGraph: """Graph with 3D positions and track_id - for Tracks without segmentation.""" db_path = str(tmp_path / "graph_3d_track_id.db") return _make_graph(ndim=4, with_pos=True, with_track_id=True, database=db_path) @pytest.fixture -def graph_3d_with_segmentation(tmp_path) -> td.graph.GraphView: +def graph_3d_with_segmentation(tmp_path) -> td.graph.BaseGraph: """Graph with segmentation (masks/bboxes) and all computed features.""" db_path = str(tmp_path / "graph_3d_segmentation.db") return _make_graph( @@ -431,9 +431,9 @@ def _make_tracks( @pytest.fixture -def graph_2d_list(tmp_path) -> td.graph.GraphView: +def graph_2d_list(tmp_path) -> td.graph.BaseGraph: db_path = str(tmp_path / "graph_2d_list.db") - graph = create_empty_graphview_graph(database=db_path) + graph = create_empty_graph(database=db_path) nodes = [ { @@ -472,7 +472,7 @@ def sphere(center, radius, shape): @pytest.fixture -def get_graph(tmp_path) -> Callable[..., td.graph.GraphView]: +def get_graph(tmp_path) -> Callable[..., td.graph.BaseGraph]: """Factory fixture to create a graph with configurable features. Args: @@ -496,7 +496,7 @@ def _get_graph( with_pos: bool = True, prefill_track_ids: bool = False, with_seg: bool = False, - ) -> td.graph.GraphView: + ) -> td.graph.BaseGraph: counter[0] += 1 db_path = str(tmp_path / f"graph_{counter[0]}.db") return _make_graph( diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index b7c94a1c..10cdb0e0 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -6,7 +6,7 @@ from funtracks.import_export import export_to_csv from funtracks.user_actions import UserUpdateSegmentation from funtracks.utils.tracksdata_utils import ( - create_empty_graphview_graph, + create_empty_graph, td_mask_to_pixels, ) @@ -51,7 +51,7 @@ def test_update_segmentation(graph_2d_with_segmentation): def test_next_track_id_empty(): - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=["pos", "track_id"], edge_attributes=[], ) diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 18828272..06f35b81 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -8,16 +8,16 @@ from funtracks.features import SegBbox, SegMask from funtracks.user_actions import UserUpdateSegmentation from funtracks.utils.tracksdata_utils import ( - create_empty_graphview_graph, + create_empty_graph, to_polars_dtype, ) track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} -def test_create_tracks(graph_3d_with_segmentation: td.graph.GraphView): +def test_create_tracks(graph_3d_with_segmentation: td.graph.BaseGraph): # create empty tracks - empty_graph = create_empty_graphview_graph() + empty_graph = create_empty_graph() tracks = Tracks(graph=empty_graph, ndim=3, **track_attrs) # type: ignore[arg-type] assert tracks.features.position_key == "pos" assert isinstance(tracks.features["pos"], dict) @@ -190,7 +190,7 @@ def test_set_pixels_no_segmentation(graph_2d_with_track_id): def test_compute_ndim_errors(): - g = create_empty_graphview_graph() + g = create_empty_graph() g.add_node_attr_key("pos", default_value=[0, 0], dtype=pl.List(pl.Int64)) g.add_node(index=1, attrs={"t": 0, "pos": [0, 0, 0], "solution": True}) @@ -302,7 +302,7 @@ def test_to_polars_dtype_mask(): def test_add_feature_mask_creates_both_columns(): """add_feature with mask and bbox Features creates both columns.""" - graph = create_empty_graphview_graph(ndim=3) + graph = create_empty_graph(ndim=3) tracks = Tracks(graph, ndim=3, **track_attrs) assert "nuc_mask" not in tracks.graph_solution.node_attr_keys() diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index 7826a848..75cc3984 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -8,7 +8,7 @@ from funtracks.data_model import Tracks from funtracks.import_export import export_to_geff, import_from_geff from funtracks.import_export.geff._import import GeffTracksBuilder, import_graph_from_geff -from funtracks.utils.tracksdata_utils import create_empty_graphview_graph +from funtracks.utils.tracksdata_utils import create_empty_graph @pytest.fixture @@ -472,7 +472,7 @@ def test_import_from_geff_roundtrip_auto_axes(tmp_path): import tracksdata as td from tracksdata.nodes import Mask - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=[ "pos", "area", @@ -602,7 +602,7 @@ def test_import_from_geff_warns_missing_segmentation_shape(tmp_path): import tracksdata as td import zarr as _zarr - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=[ "pos", td.DEFAULT_ATTR_KEYS.MASK, @@ -748,7 +748,7 @@ def test_3d_pos_survives_sql_roundtrip(tmp_path): """Regression test: 3D pos (z, y, x) must keep Array dtype through SQL roundtrip. The construct_graph() method must pass the correct ndim to - create_empty_graphview_graph() so the pos schema is Array(Float64, 3) not + create_empty_graph() so the pos schema is Array(Float64, 3) not Array(Float64, 2). A schema mismatch causes SQLGraph.from_other() to downgrade the column to List(Float64), which breaks downstream callers that rely on to_numpy() returning a 2D float64 array. @@ -878,7 +878,7 @@ def test_embedded_seg_ellipse_axis_radii_feature_metadata(tmp_path): td.DEFAULT_ATTR_KEYS.BBOX, ] # node_default_values must align with node_attributes by index. - # pos/mask/bbox are handled by special cases in create_empty_graphview_graph + # pos/mask/bbox are handled by special cases in create_empty_graph # and their slot values here are never accessed; only area, ellipse_axis_radii, # track_id, and lineage_id go through the general loop. node_default_values = [ @@ -890,7 +890,7 @@ def test_embedded_seg_ellipse_axis_radii_feature_metadata(tmp_path): None, # mask — special-cased, slot unused None, # bbox — special-cased, slot unused ] - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=node_attributes, node_default_values=node_default_values, edge_attributes=[], @@ -1005,7 +1005,7 @@ def test_featuredict_survives_geff_roundtrip(tmp_path): None, # mask — special-cased, slot unused None, # bbox — special-cased, slot unused ] - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=node_attributes, node_default_values=node_default_values, edge_attributes=[], @@ -1170,7 +1170,7 @@ def test_import_from_geff_respects_external_solution_column(tmp_path): # Separate y/x columns (not a 2-D 'pos' array) so tracksdata's to_geff # registers them as space-typed axes — required to exercise the # axes-branch of GeffTracksBuilder.infer_node_name_map. - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=["y", "x"], position_attrs=["y", "x"], ndim=3 ) graph.bulk_add_nodes( @@ -1182,11 +1182,10 @@ def test_import_from_geff_respects_external_solution_column(tmp_path): indices=[1, 2, 3], ) - # Export the root graph, not the filtered solution-only view, so the - # solution=False row survives into the geff (mimicking a solver-produced - # geff with rejected nodes). + # Export the full base graph so the solution=False row survives into the geff + # (mimicking a solver-produced geff with rejected nodes). tracks_path = tmp_path / "tracks.geff" - graph._root.to_geff(geff_store=tracks_path, zarr_format=2) + graph.to_geff(geff_store=tracks_path, zarr_format=2) tracks = import_from_geff(tracks_path) diff --git a/tests/import_export/test_import_segmentation.py b/tests/import_export/test_import_segmentation.py index fcbe8b5f..088de86c 100644 --- a/tests/import_export/test_import_segmentation.py +++ b/tests/import_export/test_import_segmentation.py @@ -7,7 +7,7 @@ load_segmentation, relabel_segmentation, ) -from funtracks.utils.tracksdata_utils import create_empty_graphview_graph +from funtracks.utils.tracksdata_utils import create_empty_graph class TestLoadSegmentation: @@ -46,7 +46,7 @@ def test_basic_relabeling(self): seg[1, 2, 2] = 20 # seg_id 20 at t=1 # Create graph with node_ids 1, 2 - graph = create_empty_graphview_graph() + graph = create_empty_graph() graph.add_node(index=1, attrs={"t": 0, "solution": True}) graph.add_node(index=2, attrs={"t": 1, "solution": True}) @@ -70,7 +70,7 @@ def test_relabeling_with_node_id_zero(self): seg[1, 2, 2] = 20 # seg_id 20 at t=1 # Create graph with node_ids 0, 1 (includes 0!) - graph = create_empty_graphview_graph() + graph = create_empty_graph() graph.add_node(index=0, attrs={"t": 0, "solution": True}) graph.add_node(index=1, attrs={"t": 1, "solution": True}) @@ -97,7 +97,7 @@ def test_no_relabeling_needed_same_ids(self): seg[0, 1, 1] = 1 seg[1, 2, 2] = 2 - graph = create_empty_graphview_graph() + graph = create_empty_graph() graph.add_node(index=1, attrs={"t": 0, "solution": True}) graph.add_node(index=2, attrs={"t": 1, "solution": True}) @@ -119,7 +119,7 @@ def test_multiple_nodes_same_timepoint(self): seg[0, 2, 2] = 20 seg[0, 3, 3] = 30 - graph = create_empty_graphview_graph() + graph = create_empty_graph() graph.add_node(index=1, attrs={"t": 0, "solution": True}) graph.add_node(index=2, attrs={"t": 0, "solution": True}) graph.add_node(index=3, attrs={"t": 0, "solution": True}) diff --git a/tests/utils/test_tracksdata_utils.py b/tests/utils/test_tracksdata_utils.py index 32231579..e4eb0793 100644 --- a/tests/utils/test_tracksdata_utils.py +++ b/tests/utils/test_tracksdata_utils.py @@ -6,7 +6,7 @@ import pytest from funtracks.utils.tracksdata_utils import ( - create_empty_graphview_graph, + create_empty_graph, pixels_to_td_mask, td_mask_to_pixels, ) @@ -142,7 +142,7 @@ def test_pixels_coordinate_offset(ndim): def test_memory_graph_survives_thread_boundary(): - """A GraphView created in a worker thread must remain accessible from the main thread. + """A base graph created in a worker thread must stay accessible from the main thread. Regression test: nodes_from_segmentation previously used database=':memory:', which caused 'no such table: Metadata' when the graph crossed a thread boundary @@ -152,7 +152,7 @@ def test_memory_graph_survives_thread_boundary(): result = {} def worker(): - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=["pos"], ndim=3, ) @@ -165,22 +165,23 @@ def worker(): graph = result["graph"] - # This calls graph.metadata internally via BaseGraph.from_other(). + # create_empty_graph now returns the base graph; build a view to exercise + # detach(), which calls graph.metadata internally via BaseGraph.from_other(). # With :memory: + default connection pool it opens a new empty DB → crash. - detached = graph.detach() + detached = graph.filter().subgraph().detach() assert detached.num_nodes() == 1 -def test_create_empty_graphview_graph_with_solution_attr(): +def test_create_empty_graph_with_solution_attr(): """Test that passing solution as a node/edge attribute does not raise. - Regression test: create_empty_graphview_graph unconditionally added the + Regression test: create_empty_graph unconditionally added the solution attribute at the end, even when it was already added via the node_attributes / edge_attributes loop, causing a ValueError. """ # Should not raise ValueError even though solution is listed explicitly - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=["solution"], edge_attributes=["solution"], node_default_values=[True], From 0d368da4916e2f1dc0b8d57b158714a75b06620c Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 17:46:40 -0700 Subject: [PATCH 25/39] all attribute I/O on graph_full, topology/track_id stuff on graph_solution --- src/funtracks/actions/update_segmentation.py | 6 +++--- src/funtracks/annotators/_graph_annotator.py | 6 +++--- src/funtracks/data_model/tracks.py | 8 ++++---- src/funtracks/user_actions/user_update_segmentation.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/funtracks/actions/update_segmentation.py b/src/funtracks/actions/update_segmentation.py index 7c2cdc58..ffcc6b8d 100644 --- a/src/funtracks/actions/update_segmentation.py +++ b/src/funtracks/actions/update_segmentation.py @@ -61,13 +61,13 @@ def _apply(self) -> None: if value == 0: # val=0 means deleting (part of) the mask - mask_old = self.tracks.graph_solution.nodes[self.node][self.mask_key] + mask_old = self.tracks.graph_full.nodes[self.node][self.mask_key] mask_subtracted = mask_old.__isub__(mask_new) self.tracks.update_mask(self.node, mask_subtracted, mask_key=self.mask_key) - elif self.tracks.graph_solution.has_node(value): + elif self.tracks.graph_full.has_node(value): # if node already exists: - mask_old = self.tracks.graph_solution.nodes[value][self.mask_key] + mask_old = self.tracks.graph_full.nodes[value][self.mask_key] mask_combined = mask_old.__or__(mask_new) self.tracks.update_mask(value, mask_combined, mask_key=self.mask_key) diff --git a/src/funtracks/annotators/_graph_annotator.py b/src/funtracks/annotators/_graph_annotator.py index 10b3e99c..dd193fc5 100644 --- a/src/funtracks/annotators/_graph_annotator.py +++ b/src/funtracks/annotators/_graph_annotator.py @@ -121,7 +121,7 @@ def _filter_feature_keys(self, feature_keys: list[str] | None) -> list[str]: def compute(self, feature_keys: list[str] | None = None) -> None: """Compute a set of features and add them to the tracks. - This involves both updating the node/edge attributes on the tracks.graph_solution + This involves both updating the node/edge attributes on `self.graph` and adding the features to the FeatureDict, if necessary. This is distinct from `update` to allow more efficient bulk computation of features. @@ -138,8 +138,8 @@ def compute(self, feature_keys: list[str] | None = None) -> None: def update(self, action: BasicAction) -> None: """Update a set of features based on the given action. - This involves both updating the node or edge attributes on the - tracks.graph_solution and adding the features to the FeatureDict, if + This involves both updating the node or edge attributes on `self.graph` + and adding the features to the FeatureDict, if necessary. This is distinct from `compute` to allow more efficient computation of features for single elements. diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index ad4f6cc8..a19b4b14 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -569,7 +569,7 @@ def get_mask( if self.segmentation is None: return None - mask = self.graph_solution.nodes[node][mask_key] + mask = self.graph_full.nodes[node][mask_key] return mask def update_mask( @@ -586,13 +586,13 @@ def update_mask( mask_key: The feature key for the mask column. Defaults to the standard mask key. """ - self.graph_solution.nodes[node][mask_key] = mask + self.graph_full.nodes[node][mask_key] = mask mask_feature = self.features.get(mask_key) if mask_feature is not None: # NOTE: all derived features of a mask are currently assumed to be # its bounding box. Revisit if non-bbox derived features are added. for derived_key in mask_feature.get("derived_features", []): - self.graph_solution.nodes[node][derived_key] = mask.bbox + self.graph_full.nodes[node][derived_key] = mask.bbox def undo(self) -> bool: """Undo the last performed action from the action history. @@ -882,7 +882,7 @@ def get_track_ids(self, nodes) -> list[int]: calls. For small subsets or single nodes use get_track_id() instead.""" tracklet_key = self.features.tracklet_key - df = self.graph_solution.node_attrs( + df = self.graph_full.node_attrs( attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, tracklet_key] ) id_to_val = dict( diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py index 7576a744..04941e52 100644 --- a/src/funtracks/user_actions/user_update_segmentation.py +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -108,7 +108,7 @@ def __init__( time = pixels[0][0] # check if all pixels of old_value are removed mask_pixels = pixels_to_td_mask(pixels, self.tracks.ndim) - mask_old_value = self.tracks.graph_solution.nodes[old_value]["mask"] + mask_old_value = self.tracks.graph_full.nodes[old_value]["mask"] # If pixels fully overlaps with old_value mask, delete node if mask_pixels.intersection(mask_old_value) == mask_old_value.mask.sum(): self.actions.append( From 7c90f5b4c0edb83a6724fe2c665b3e99ac4ad474 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 10:47:21 -0700 Subject: [PATCH 26/39] since tracklet_id column now always exist, check if it contains any -1, if so, recompute --- src/funtracks/data_model/tracks.py | 20 +++++++++++++++++++- tests/data_model/test_solution_tracks.py | 12 ++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index a19b4b14..0bd38e1a 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -406,7 +406,25 @@ def _ensure_track_features(self) -> None: annotator = self.track_annotator self.features.tracklet_key = annotator.tracklet_key self.features.lineage_key = annotator.lineage_key - self._register_core_features([annotator.tracklet_key, annotator.lineage_key]) + # A tracklet column can exist yet still hold the -1 sentinel ("not computed", + # the column default). Trusting it would activate stale ids and seed a phantom + # tracklet -1 in the annotator bookkeeping, so force a recompute from topology. + if self._has_uncomputed_track_ids(annotator.tracklet_key): + self.enable_features([annotator.tracklet_key, annotator.lineage_key]) + else: + self._register_core_features([annotator.tracklet_key, annotator.lineage_key]) + + def _has_uncomputed_track_ids(self, tracklet_key: str) -> bool: + """True if the tracklet column exists but any node still holds the -1 sentinel. + + A missing column returns False: _register_core_features computes it from scratch. + """ + if self.graph_solution.num_nodes() == 0: + return False + if tracklet_key not in self.graph_solution.node_attr_keys(): + return False + values = self.graph_solution.node_attrs(attr_keys=[tracklet_key])[tracklet_key] + return bool((values == -1).any()) def nodes(self): return np.array(self.graph_solution.node_ids()) diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 10cdb0e0..4f1afa81 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -22,6 +22,18 @@ def test_recompute_track_ids(graph_2d_with_track_id): assert tracks.get_next_track_id() == 6 +def test_recompute_track_ids_when_sentinel_present(graph_2d_with_track_id): + """A tracklet column that exists but still holds the -1 sentinel means "not + computed": Tracks must recompute track ids from topology, not trust the sentinels.""" + graph = graph_2d_with_track_id + node_ids = list(graph.node_ids()) + graph.update_node_attrs(attrs={"track_id": [-1] * len(node_ids)}, node_ids=node_ids) + + tracks = Tracks(graph, ndim=3, **track_attrs) + + assert -1 not in tracks.get_track_ids(tracks.nodes()) + + def test_next_track_id(graph_2d_with_track_id): tracks = Tracks(graph_2d_with_track_id, ndim=3, **track_attrs) assert tracks.get_next_track_id() == 6 From 412dd60dfe8a97ae681ae97eee6b1be0887d05db Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 10:59:31 -0700 Subject: [PATCH 27/39] fix: revive an edd didn't use provided attrs --- src/funtracks/actions/add_delete_edge.py | 7 ++++-- tests/actions/test_add_delete_edge.py | 32 ++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index d47f0648..3bb0b64d 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -57,10 +57,13 @@ def _apply(self) -> None: if self.tracks.graph_full.has_edge(*self.edge): # Revive a soft-deleted edge (already present in the full graph as a - # candidate): flip solution=True and re-surface it in the solution view. + # candidate): flip solution=True, apply any caller-provided attributes (so + # revive matches the add-new branch), and re-surface it in the solution view. edge_id = self.tracks.graph_full.edge_id(self.edge[0], self.edge[1]) + revive_attrs = {k: v for k, v in self.attributes.items() if k != "solution"} + revive_attrs["solution"] = True self.tracks.graph_full.update_edge_attrs( - attrs={"solution": True}, edge_ids=[edge_id] + attrs=revive_attrs, edge_ids=[edge_id] ) self.tracks.graph_solution.add_edge_to_view(self.edge[0], self.edge[1]) else: diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py index 212790ae..13dc1884 100644 --- a/tests/actions/test_add_delete_edge.py +++ b/tests/actions/test_add_delete_edge.py @@ -165,6 +165,38 @@ def test_custom_edge_attributes_preserved(get_tracks, ndim, with_seg): ) +def test_add_edge_revive_applies_new_attributes(get_tracks): + """Re-adding a soft-deleted edge with NEW attributes must apply them, not silently + keep the values preserved from before the delete. + + Existing tests only re-add via inverse(), which reuses the SAME attrs that + soft-delete preserved, so the revive-vs-add-new asymmetry is invisible to them. + """ + from funtracks.features import Feature + + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) + tracks.add_feature( + "weight", + Feature( + feature_type="edge", + value_type="float", + num_values=1, + display_name="Weight", + default_value=None, + ), + ) + + edge = (1, 5) + AddEdge(tracks, edge, attributes={"weight": 1.5}) + DeleteEdge(tracks, edge) # soft-delete: weight=1.5 preserved in graph_full + + # Fresh add of the (now soft-deleted) edge with a DIFFERENT weight. + AddEdge(tracks, edge, attributes={"weight": 9.9}) + + edge_id = tracks.graph_solution.edge_id(*edge) + assert tracks.graph_solution.edges[edge_id]["weight"] == 9.9 + + def test_add_edge_with_unregistered_edge_attr(tmp_path): """AddEdge must not crash when the graph has edge attrs absent from tracks.features. From 0a8287b36bb5aa7ac3c3b45bf91b636de005557b Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 11:07:41 -0700 Subject: [PATCH 28/39] failsave for not having a TrackAnnotator --- src/funtracks/data_model/tracks.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 0bd38e1a..5826641a 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -862,13 +862,16 @@ def delete_feature(self, key: str) -> None: # has (track ids are a core feature). On an empty solution view they are no-ops. @property - def track_annotator(self): - """The registered TrackAnnotator — always present, since track ids are a core - feature of every Tracks.""" + def track_annotator(self) -> TrackAnnotator: + """The registered TrackAnnotator. Always present, since track ids are a core + feature of every Tracks (_get_annotators registers one unconditionally).""" for annotator in self.annotators: if isinstance(annotator, TrackAnnotator): return annotator - return None + raise RuntimeError( + "No TrackAnnotator registered on this Tracks — this should be unreachable " + "(_get_annotators always registers one)." + ) @property def max_track_id(self) -> int: From 86588bd8029ba9d8c43cc05b9def529eb095f494 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 11:10:07 -0700 Subject: [PATCH 29/39] tracks.get_track_neighbors should not permanently sort the tracklet_ids --- src/funtracks/data_model/tracks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 5826641a..bcb5fd3c 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -948,8 +948,9 @@ def get_track_neighbors( or len(self.track_annotator.tracklet_id_to_nodes[track_id]) == 0 ): return None, None - candidates = self.track_annotator.tracklet_id_to_nodes[track_id] - candidates.sort(key=lambda n: self.get_time(n)) + candidates = sorted( + self.track_annotator.tracklet_id_to_nodes[track_id], key=self.get_time + ) pred = None succ = None From 5e4ad28afb5166268d36589657f672a391bff227 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 11:19:12 -0700 Subject: [PATCH 30/39] note about tracksdata --- src/funtracks/data_model/tracks.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index bcb5fd3c..7febe5bd 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -802,7 +802,15 @@ def add_feature(self, key: str, feature: Feature) -> None: # Add to the features dictionary self.features[key] = feature - # Perform custom graph operations when a feature is added + # Perform custom graph operations when a feature is added. + # + # Schema (attr-key) registration is done on graph_solution (the view), NOT on + # graph_full, even though annotators write the VALUES to graph_full. This relies + # on a tracksdata invariant: adding an attr key to a view propagates up to its + # root, so the column ends up on both. The reverse does NOT hold today — adding + # a key directly to the root is not propagated down into an existing view — so + # registering on graph_full would leave graph_solution without the column. + # If tracksdata ever makes view attr-key additions local, revisit this. ft = feature["feature_type"] if "node" in ft and key not in self.graph_solution.node_attr_keys(): # "mask" value_type maps to pl.Object via to_polars_dtype @@ -851,7 +859,9 @@ def delete_feature(self, key: str) -> None: else: return - # Perform custom graph operations when a feature is deleted + # Perform custom graph operations when a feature is deleted. Schema ops go + # through graph_solution (the view) and propagate to the root — same tracksdata + # invariant as add_feature (see the note there). if "node" in feature_type and key in self.graph_solution.node_attr_keys(): self.graph_solution.remove_node_attr_key(key) if "edge" in feature_type and key in self.graph_solution.edge_attr_keys(): From 0703f41fbd7d0d254a716abc56ee72b0c18285b1 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 11:20:33 -0700 Subject: [PATCH 31/39] TrackAnnotator didn't use the defined self.graph --- src/funtracks/annotators/_track_annotator.py | 27 ++++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/src/funtracks/annotators/_track_annotator.py b/src/funtracks/annotators/_track_annotator.py index 5ab86296..bad62fb2 100644 --- a/src/funtracks/annotators/_track_annotator.py +++ b/src/funtracks/annotators/_track_annotator.py @@ -120,11 +120,9 @@ def _get_max_id_and_map(self, key: str) -> tuple[int, dict[int, list[int]]]: tuple[int, dict[int, list[int]]]: The maximum id value, and a mapping from ids to a list of nodes with that id. """ - if key not in self.tracks.graph_solution.node_attr_keys(): + if key not in self.graph.node_attr_keys(): return 0, {} - df = self.tracks.graph_solution.node_attrs( - attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, key] - ) + df = self.graph.node_attrs(attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, key]) id_to_nodes = defaultdict(list) for node, _id in zip(df[td.DEFAULT_ATTR_KEYS.NODE_ID], df[key], strict=True): if _id is None: @@ -188,15 +186,12 @@ def _assign_lineage_ids(self) -> None: attributes will be updated. """ - lineages_internal = rx.weakly_connected_components( - self.tracks.graph_solution.rx_graph - ) + lineages_internal = rx.weakly_connected_components(self.graph.rx_graph) # Map each component's internal node indices to external ids in one batched # call. node_ids() rebuilds the full external-id list on every call, so calling # it per node (as before) was O(N^2). lineages_external = [ - self.tracks.graph_solution._map_to_external(list(lin)) - for lin in lineages_internal + self.graph._map_to_external(list(lin)) for lin in lineages_internal ] max_id, ids_to_nodes = self._assign_ids(lineages_external, self.lineage_key) @@ -214,7 +209,7 @@ def _assign_tracklet_ids(self) -> None: # slow because each remove_edge syncs the view's bidirectional edge maps, # whereas rustworkx edge removal is in-memory. copy() preserves node indices, # so components map back through the original graph's id mapping. - rx_copy = self.tracks.graph_solution.rx_graph.copy() + rx_copy = self.graph.rx_graph.copy() for node in rx_copy.node_indices(): if rx_copy.out_degree(node) >= 2: for _, daughter, _ in list(rx_copy.out_edges(node)): @@ -225,15 +220,13 @@ def _assign_tracklet_ids(self) -> None: all_track_ids = [] for tracklet in rx.weakly_connected_components(rx_copy): # Batched internal -> external mapping (see _assign_lineage_ids). - node_ids_external = self.tracks.graph_solution._map_to_external( - list(tracklet) - ) + node_ids_external = self.graph._map_to_external(list(tracklet)) all_node_ids.extend(node_ids_external) all_track_ids.extend([track_id] * len(node_ids_external)) self.tracklet_id_to_nodes[track_id] = node_ids_external track_id += 1 if all_node_ids: - self.tracks.graph_solution.update_node_attrs( + self.graph.update_node_attrs( attrs={self.tracks.features.tracklet_key: all_track_ids}, node_ids=all_node_ids, ) @@ -303,18 +296,18 @@ def _handle_update_track_ids(self, action: UpdateTrackIDs) -> None: still_in_tracklet = False # Continue to all successors - next_nodes.extend(self.tracks.graph_solution.successors(node)) + next_nodes.extend(self.graph.successors(node)) curr_nodes = next_nodes # Bulk-write all collected node attribute changes in one call each if update_tracklet and tracklet_nodes: - self.tracks.graph_solution.update_node_attrs( + self.graph.update_node_attrs( attrs={self.tracklet_key: [new_tracklet_id] * len(tracklet_nodes)}, node_ids=tracklet_nodes, ) if update_lineage and lineage_nodes: - self.tracks.graph_solution.update_node_attrs( + self.graph.update_node_attrs( attrs={self.lineage_key: [new_lineage_id] * len(lineage_nodes)}, node_ids=lineage_nodes, ) From b0f4429139d26c8db48bc23abe29537fcfb2a1fc Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 11:32:48 -0700 Subject: [PATCH 32/39] warn that graph annotators compute on every possible candidate edge (can be expensive) --- src/funtracks/annotators/_graph_annotator.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/funtracks/annotators/_graph_annotator.py b/src/funtracks/annotators/_graph_annotator.py index dd193fc5..ba681508 100644 --- a/src/funtracks/annotators/_graph_annotator.py +++ b/src/funtracks/annotators/_graph_annotator.py @@ -64,6 +64,14 @@ def graph(self): (`tracklet_id`, `lineage_id`) are derived from the solution topology. Root and view share attribute storage, so writes made here are seen through the solution view automatically. + + KNOWN COST: computing over `graph_full` means `compute()`/`update()` run + regionprops / mask.iou for every candidate (solution=False) node/edge too. On + candidate graphs with many unselected detections (often 10-100x the solution + size) this is O(candidates) work that is mostly never read, since revive is rare. + This is a deliberate eagerness-for-readiness trade; if it becomes a bottleneck, + the lever is lazy compute-on-revive (default to graph_solution and compute a + candidate's intrinsic features only when it enters the solution). """ return self.tracks.graph_full From d9c9735c6d60115b072d3c0b82bcca6a9a091358 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 14:13:49 -0700 Subject: [PATCH 33/39] update tracksdata to v0.1.0rc6 --- pyproject.toml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 392125b6..bc7e72c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies =[ "dask>=2025.5.0", "pandas>=2.3.3", "zarr>=2.18,<4", - "tracksdata[spatial]>=0.1.0rc5", + "tracksdata[spatial]>=0.1.0rc6", "tqdm>=4.66.1", # zarr 2.x's util.py imports cbuffer_sizes/cbuffer_metainfo from # numcodecs.blosc, which numcodecs >= 0.16 removed. Pin numcodecs per @@ -123,7 +123,3 @@ explicit_package_bases = true exclude_also = [ "if TYPE_CHECKING:", ] - -#remove this after tracksdata PR is merged: -[tool.uv.sources] -tracksdata = { git = "https://github.com/TeunHuijben/tracksdata.git", rev = "06465dcabb772622aed2ae66417c0b857b4ed35c" } From 588e57a3e3cafc716026d5705ab01e3d2bcce5a6 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 14:24:36 -0700 Subject: [PATCH 34/39] revive node (AddNode) didnt assign provided attrs to revived node, simply took the old ones --- src/funtracks/actions/add_delete_node.py | 9 ++++++--- tests/actions/test_add_delete_nodes.py | 25 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index 978c29ed..1d1b7f1c 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -83,10 +83,13 @@ def _apply(self) -> None: """ if self.tracks.graph_full.has_node(self.node): # Revive: same node id, topology preserved in graph_full. Flip it back into - # the solution and re-surface it in the view in place (incident edges are - # revived separately by AddEdge). + # the solution, apply any caller-provided attributes (so revive matches the + # add-new branch), and re-surface it in the view in place (incident edges + # are revived separately by AddEdge). + revive_attrs = {k: v for k, v in self.attributes.items() if k != "solution"} + revive_attrs["solution"] = True self.tracks.graph_full.update_node_attrs( - attrs={"solution": True}, node_ids=[self.node] + attrs=revive_attrs, node_ids=[self.node] ) self.tracks.graph_solution.add_node_to_view(self.node) else: diff --git a/tests/actions/test_add_delete_nodes.py b/tests/actions/test_add_delete_nodes.py index 64990ebe..b716355c 100644 --- a/tests/actions/test_add_delete_nodes.py +++ b/tests/actions/test_add_delete_nodes.py @@ -128,6 +128,31 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): ) +def test_add_node_revive_applies_new_attributes(get_tracks): + """Re-adding a soft-deleted node with NEW attributes must apply them, not silently + keep the values preserved from before the delete. + + Existing tests only re-add via inverse(), which reuses the SAME attrs that + soft-delete preserved, so the revive-vs-add-new asymmetry is invisible to them. + Mirrors test_add_edge_revive_applies_new_attributes. + """ + from funtracks.actions import DeleteNode + + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) + + node_id = 100 + AddNode(tracks, node_id, {"t": 2, "track_id": 10, "pos": [50.0, 50.0]}) + DeleteNode(tracks, node_id) # soft-delete: attrs preserved in graph_full + + # Fresh add of the (now soft-deleted) node with DIFFERENT attributes. + AddNode(tracks, node_id, {"t": 2, "track_id": 11, "pos": [60.0, 60.0]}) + + assert tracks.graph_solution.nodes[node_id]["track_id"] == 11 + assert_array_almost_equal( + tracks.graph_solution.nodes[node_id]["pos"], np.array([60.0, 60.0]) + ) + + def test_add_node_missing_time(get_tracks): tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises(ValueError, match="Must provide a time attribute for node"): From efb1034d1a2cc5b6b3d797c3f5dbcad2406b0ae6 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 14:36:42 -0700 Subject: [PATCH 35/39] fix add node/edge revive regarding attrs --- src/funtracks/actions/add_delete_edge.py | 6 +++-- src/funtracks/actions/add_delete_node.py | 6 +++-- src/funtracks/data_model/tracks.py | 10 ++++++-- tests/actions/test_add_delete_edge.py | 32 ++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index 3bb0b64d..3a9fb28a 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -60,8 +60,10 @@ def _apply(self) -> None: # candidate): flip solution=True, apply any caller-provided attributes (so # revive matches the add-new branch), and re-surface it in the solution view. edge_id = self.tracks.graph_full.edge_id(self.edge[0], self.edge[1]) - revive_attrs = {k: v for k, v in self.attributes.items() if k != "solution"} - revive_attrs["solution"] = True + # Values are wrapped in single-element lists because update_edge_attrs + # reads a bare list value (e.g. a vector feature) as one-value-per-edge. + revive_attrs = {k: [v] for k, v in self.attributes.items() if k != "solution"} + revive_attrs["solution"] = [True] self.tracks.graph_full.update_edge_attrs( attrs=revive_attrs, edge_ids=[edge_id] ) diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index 1d1b7f1c..ed3df040 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -86,8 +86,10 @@ def _apply(self) -> None: # the solution, apply any caller-provided attributes (so revive matches the # add-new branch), and re-surface it in the view in place (incident edges # are revived separately by AddEdge). - revive_attrs = {k: v for k, v in self.attributes.items() if k != "solution"} - revive_attrs["solution"] = True + # Values are wrapped in single-element lists because update_node_attrs + # reads a bare list value (pos, bbox, mask) as one-value-per-node. + revive_attrs = {k: [v] for k, v in self.attributes.items() if k != "solution"} + revive_attrs["solution"] = [True] self.tracks.graph_full.update_node_attrs( attrs=revive_attrs, node_ids=[self.node] ) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 7febe5bd..0caf0cb3 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -706,7 +706,9 @@ def get_nodes_attr(self, nodes: Iterable[Node], attr: str): def _set_edge_attr(self, edge: Edge, attr: str, value: Any): edge_id = self.graph_full.edge_id(edge[0], edge[1]) - self.graph_full.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) + # Wrap in a single-element list: update_edge_attrs reads a bare list value + # (e.g. a vector feature) as one-value-per-edge. + self.graph_full.update_edge_attrs(attrs={attr: [value]}, edge_ids=[edge_id]) def _set_edges_attr(self, edges: Iterable[Edge], attr: str, values: Iterable[Any]): for edge, value in zip(edges, values, strict=False): @@ -824,10 +826,14 @@ def add_feature(self, key: str, feature: Feature) -> None: dtype=dtype, ) if "edge" in ft and key not in self.graph_solution.edge_attr_keys(): + dtype = to_polars_dtype(feature["value_type"]) + num_values = feature.get("num_values") + if num_values is not None and num_values > 1: + dtype = pl.Array(dtype, num_values) self.graph_solution.add_edge_attr_key( key, default_value=feature["default_value"], - dtype=to_polars_dtype(feature["value_type"]), + dtype=dtype, ) def delete_feature(self, key: str) -> None: diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py index 13dc1884..13fb1213 100644 --- a/tests/actions/test_add_delete_edge.py +++ b/tests/actions/test_add_delete_edge.py @@ -197,6 +197,38 @@ def test_add_edge_revive_applies_new_attributes(get_tracks): assert tracks.graph_solution.edges[edge_id]["weight"] == 9.9 +def test_add_edge_revive_applies_new_vector_attributes(get_tracks): + """Reviving an edge with a non-scalar (vector) attribute must apply it correctly. + + update_edge_attrs reads a bare list value as one-value-per-edge, so passing a + vector attr unwrapped for a single edge raises a size mismatch (or, for a + length-1 list, silently unwraps it to its element). + """ + from funtracks.features import Feature + + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) + tracks.add_feature( + "flow", + Feature( + feature_type="edge", + value_type="float", + num_values=2, + display_name=["Flow Y", "Flow X"], + default_value=None, + ), + ) + + edge = (1, 5) + AddEdge(tracks, edge, attributes={"flow": [1.0, 2.0]}) + DeleteEdge(tracks, edge) # soft-delete: flow preserved in graph_full + + # Fresh add of the (now soft-deleted) edge with a DIFFERENT flow vector. + AddEdge(tracks, edge, attributes={"flow": [3.0, 4.0]}) + + edge_id = tracks.graph_solution.edge_id(*edge) + assert list(tracks.graph_solution.edges[edge_id]["flow"]) == [3.0, 4.0] + + def test_add_edge_with_unregistered_edge_attr(tmp_path): """AddEdge must not crash when the graph has edge attrs absent from tracks.features. From fca98a0ada023165174cfd2eb548a58eb74b507d Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 14:47:34 -0700 Subject: [PATCH 36/39] tracklet_id_to_nodes inconsistent after recompute --- src/funtracks/annotators/_track_annotator.py | 7 +++++- tests/annotators/test_track_annotator.py | 24 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/funtracks/annotators/_track_annotator.py b/src/funtracks/annotators/_track_annotator.py index bad62fb2..38c81c7a 100644 --- a/src/funtracks/annotators/_track_annotator.py +++ b/src/funtracks/annotators/_track_annotator.py @@ -218,18 +218,23 @@ def _assign_tracklet_ids(self) -> None: track_id = 1 all_node_ids = [] all_track_ids = [] + id_to_nodes = {} for tracklet in rx.weakly_connected_components(rx_copy): # Batched internal -> external mapping (see _assign_lineage_ids). node_ids_external = self.graph._map_to_external(list(tracklet)) all_node_ids.extend(node_ids_external) all_track_ids.extend([track_id] * len(node_ids_external)) - self.tracklet_id_to_nodes[track_id] = node_ids_external + id_to_nodes[track_id] = node_ids_external track_id += 1 if all_node_ids: self.graph.update_node_attrs( attrs={self.tracks.features.tracklet_key: all_track_ids}, node_ids=all_node_ids, ) + # Replace the bookkeeping wholesale (like _assign_lineage_ids) so stale + # entries — e.g. a phantom tracklet -1 seeded at init from a graph whose + # tracklet column still held the -1 sentinel — don't survive a recompute. + self.tracklet_id_to_nodes = id_to_nodes self.max_tracklet_id = track_id - 1 def update(self, action: BasicAction) -> None: diff --git a/tests/annotators/test_track_annotator.py b/tests/annotators/test_track_annotator.py index 8987428f..a2106e19 100644 --- a/tests/annotators/test_track_annotator.py +++ b/tests/annotators/test_track_annotator.py @@ -309,3 +309,27 @@ def test_disabled_tracklet_key_does_nothing(self, get_tracks, ndim, with_seg) -> assert ann.lineage_id_to_nodes == original_lineage_map assert ann.max_tracklet_id == original_max_tracklet assert ann.max_lineage_id == original_max_lineage + + +def test_recompute_replaces_stale_bookkeeping(get_tracks): + """A tracklet column still holding the -1 sentinel seeds a phantom tracklet -1 in + the bookkeeping at annotator init (_get_max_id_and_map only skips None). A full + compute() must replace tracklet_id_to_nodes wholesale (like _assign_lineage_ids + does), not merge into it — otherwise the phantom entry survives with its nodes + duplicated under their real tracklet ids. + """ + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) + node_ids = list(tracks.graph_full.node_ids()) + tracks.graph_full.update_node_attrs( + attrs={"track_id": [-1] * len(node_ids)}, node_ids=node_ids + ) + + ann = TrackAnnotator(tracks, tracklet_key="track_id") + ann.activate_features(["track_id", "lineage_id"]) + ann.compute() + + assert -1 not in ann.tracklet_id_to_nodes + # Every solution node appears exactly once across the bookkeeping. + all_nodes = sorted(n for ns in ann.tracklet_id_to_nodes.values() for n in ns) + assert all_nodes == sorted(tracks.graph_solution.node_ids()) + assert ann.max_tracklet_id == max(ann.tracklet_id_to_nodes.keys()) From d5717628503e12fb9e9b9c5031835c8ab6b1fd37 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 14:55:16 -0700 Subject: [PATCH 37/39] warm up windows cache for fair benchmark --- tests/benchmarks/bench_candidate_graph.py | 27 +++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/tests/benchmarks/bench_candidate_graph.py b/tests/benchmarks/bench_candidate_graph.py index f74f65a2..33078e87 100644 --- a/tests/benchmarks/bench_candidate_graph.py +++ b/tests/benchmarks/bench_candidate_graph.py @@ -43,6 +43,28 @@ def seg_data(): return _generate_segmentation() +@pytest.fixture(scope="module") +def _warm_spatial_compile(): + """Warm the spatial_graph rtree JIT compile before the timed benchmark. + + Building a Tracks with segmentation creates a GraphArrayView, which builds a + spatial_graph rtree whose Cython module is JIT-compiled by witty on first use and + cached process-wide. That compile is a one-time cost (seconds-to-tens-of-seconds on + a cold Windows runner) that would otherwise land inside the timed region of + test_graph_to_solution. Trigger it here (untimed) on a tiny graph with the same + dimensionality so the benchmark measures tracking work, not compilation. + """ + tiny = _generate_segmentation(num_frames=2, frame_shape=(64, 64), cells_per_frame=2) + graph = compute_graph_from_seg(tiny, MAX_EDGE_DISTANCE, iou=True) + Tracks( + graph, + time_attr="t", + pos_attr="pos", + tracklet_attr="tracklet_id", + lineage_attr="lineage_id", + ) + + def test_compute_graph_from_seg(benchmark, seg_data): benchmark.pedantic( compute_graph_from_seg, @@ -53,14 +75,15 @@ def test_compute_graph_from_seg(benchmark, seg_data): ) -def test_graph_to_solution(benchmark, seg_data): +def test_graph_to_solution(benchmark, seg_data, _warm_spatial_compile): """Benchmark candidate graph -> Tracks with track ids (tracklet/lineage assignment). Candidate-graph construction is benchmarked separately above and is built here in (untimed) setup, so only the Tracks construction -- dominated by TrackAnnotator._assign_tracklet_ids and _assign_lineage_ids -- is measured. Setting tracklet_attr/lineage_attr registers a TrackAnnotator and triggers the - track-id computation on construction. + track-id computation on construction. The _warm_spatial_compile fixture ensures the + one-time spatial_graph rtree JIT compile is not measured here. """ def setup(): From 902a0ee8f6b80600bbd40f9acf27bde5a758fbbb Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 14:56:03 -0700 Subject: [PATCH 38/39] ruff fixes --- pyproject.toml | 2 -- src/funtracks/data_model/tracks.py | 16 +++++++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bc7e72c3..581810d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,8 +107,6 @@ unfixable = [ ] [tool.ruff.lint.per-file-ignores] "tests/*" = ["D"] # no docstrings in tests -"src/funtracks/data_model/tracks.py" = ["D"] # Remove this when refactoring tracks -"src/funtracks/data_model/solution_tracks.py" = ["D"] # Remove this when refactoring tracks "__init__.py" = ["F401"] # unused imports allowed in __init__.py # https://docs.astral.sh/ruff/formatter/ diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 0caf0cb3..0d1c1096 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -427,15 +427,19 @@ def _has_uncomputed_track_ids(self, tracklet_key: str) -> bool: return bool((values == -1).any()) def nodes(self): + """Return the node ids of the solution graph as a numpy array.""" return np.array(self.graph_solution.node_ids()) def edges(self): + """Return the edge ids of the solution graph as a numpy array.""" return np.array(self.graph_solution.edge_ids()) def predecessors(self, node: int) -> list[int]: + """Return the predecessors of a node in the solution graph.""" return list(self.graph_solution.predecessors(node)) def successors(self, node: int) -> list[int]: + """Return the successors of a node in the solution graph.""" return list(self.graph_solution.successors(node)) def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.ndarray: @@ -447,7 +451,7 @@ def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.nd For a single node use get_position() instead. Args: - node (Iterable[Node]): The node ids in the graph to get the positions of + nodes (Iterable[Node]): The node ids in the graph to get the positions of incl_time (bool, optional): If true, include the time as the first element of each position array. Defaults to False. @@ -538,6 +542,7 @@ def set_positions( self._set_nodes_attr(nodes, self.features.position_key, positions) def set_position(self, node: Node, position: list | np.ndarray): + """Set the position of a single node.""" self.set_positions([node], np.expand_dims(np.array(position), axis=0)) def get_times(self, nodes: Iterable[Node]) -> Sequence[int]: @@ -699,9 +704,11 @@ def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any ) def get_node_attr(self, node: Node, attr: str): + """Get an attribute value for a single node (resolved on graph_full).""" return self.graph_full.nodes[int(node)][attr] def get_nodes_attr(self, nodes: Iterable[Node], attr: str): + """Get an attribute value for each of the given nodes.""" return [self.get_node_attr(node, attr) for node in nodes] def _set_edge_attr(self, edge: Edge, attr: str, value: Any): @@ -716,12 +723,17 @@ def _set_edges_attr(self, edges: Iterable[Edge], attr: str, values: Iterable[Any self.graph_full.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def get_edge_attr(self, edge: Edge, attr: str): + """Get an attribute value for a single edge (resolved on graph_full). + + Returns None if the attribute is not registered on the graph. + """ if attr not in self.graph_full.edge_attr_keys(): return None edge_id = self.graph_full.edge_id(edge[0], edge[1]) return self.graph_full.edges[edge_id][attr] def get_edges_attr(self, edges: Iterable[Edge], attr: str): + """Get an attribute value for each of the given edges.""" return [self.get_edge_attr(edge, attr) for edge in edges] # ========== Feature Management ========== @@ -891,6 +903,7 @@ def track_annotator(self) -> TrackAnnotator: @property def max_track_id(self) -> int: + """The maximum tracklet id currently in use.""" return self.track_annotator.max_tracklet_id def get_next_track_id(self) -> int: @@ -910,6 +923,7 @@ def get_next_lineage_id(self) -> int: return self.track_annotator.max_lineage_id + 1 def get_track_id(self, node) -> int: + """Get the tracklet id of a single node.""" track_id = self.get_node_attr(node, self.features.tracklet_key) return track_id From 66aa3521eb851c3f2f812af4fecdadb8f966eaa9 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 1 Jul 2026 15:09:30 -0700 Subject: [PATCH 39/39] docstring updates --- src/funtracks/data_model/tracks.py | 2 +- src/funtracks/import_export/geff/_export.py | 10 +++++++++- src/funtracks/user_actions/user_add_node.py | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 0d1c1096..6131dcb6 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -295,7 +295,7 @@ def _get_annotators(self) -> AnnotatorRegistry: Creates annotators conditionally: - RegionpropsAnnotator: Only if segmentation is provided - EdgeAnnotator: Only if segmentation is provided - - TrackAnnotator: Only if this is a Tracks instance + - TrackAnnotator: Always (every Tracks has track ids) Each annotator is configured with appropriate keys from self.features. diff --git a/src/funtracks/import_export/geff/_export.py b/src/funtracks/import_export/geff/_export.py index 8c1ad17a..7ef00d1d 100644 --- a/src/funtracks/import_export/geff/_export.py +++ b/src/funtracks/import_export/geff/_export.py @@ -34,6 +34,11 @@ def write_to_geff( geff store directly to *path*. Intended for internal save/load workflows where the user picks the ``.geff`` path. + Note: only ``graph_solution`` is written. Soft-deleted (``solution=False``) + candidates are dropped, and reimport marks everything ``solution=True`` — so + undo-ability of past deletes does not survive a save/load round-trip (Phase-1 + design; candidate persistence is deferred to the candidate/solver phase). + Args: tracks: Tracks object containing a graph to save. path: Destination path for the geff store. @@ -65,6 +70,9 @@ def export_to_geff( ): """Export the Tracks graph to geff. + Only the solution graph is exported; soft-deleted (``solution=False``) + candidates are not included. + Args: tracks (Tracks): Tracks object containing a graph to save. directory (Path): Destination directory for saving the Zarr. @@ -193,7 +201,7 @@ def _write_segmentation_shape(geff_path: Path, tracks: Tracks) -> None: This allows import_from_geff to reconstruct the segmentation (GraphArrayView) without requiring an external segmentation file. """ - seg_shape = tracks.graph_solution.metadata.get("segmentation_shape") + seg_shape = tracks.graph_full.metadata.get("segmentation_shape") if seg_shape is not None: import zarr as _zarr diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py index dbb2940c..af4ad0f9 100644 --- a/src/funtracks/user_actions/user_add_node.py +++ b/src/funtracks/user_actions/user_add_node.py @@ -118,7 +118,7 @@ def __init__( # downstream elif succ is not None: # check pred of succ - preds = self.tracks.graph_solution.predecessors(succ) + preds = self.tracks.predecessors(succ) pred_of_succ = preds[0] if preds else None if ( pred_of_succ is not None