diff --git a/docs/features.md b/docs/features.md index 8022a26e..ddfeb217 100644 --- a/docs/features.md +++ b/docs/features.md @@ -54,7 +54,7 @@ An abstract base class for components that compute and update features on a grap |-----------|---------|--------------|-------------------|---------------| | **RegionpropsAnnotator** | Extracts node features from segmentation using scikit-image's `regionprops` | `segmentation` must not be `None` | `pos`, `area`, `ellipse_axis_radii`, `circularity`, `perimeter` | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.RegionpropsAnnotator) | | **EdgeAnnotator** | Computes edge features based on segmentation overlap between consecutive time frames | `segmentation` must not be `None` | `iou` (Intersection over Union) | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.EdgeAnnotator) | -| **TrackAnnotator** | Computes tracklet and lineage IDs for SolutionTracks | Must be used with `SolutionTracks` (binary tree structure) | `tracklet_id`, `lineage_id` | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.TrackAnnotator) | +| **TrackAnnotator** | Computes tracklet and lineage IDs | Requires `tracks.features.tracklet_key` to be set (i.e. a tracking solution with a binary-tree structure) | `tracklet_id`, `lineage_id` | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.TrackAnnotator) | ### 5. AnnotatorRegistry @@ -149,7 +149,8 @@ classDiagram } class Tracks { - +graph: td.graph.GraphView + +graph_full: td.graph.BaseGraph + +graph_solution: td.graph.GraphView +segmentation: ndarray|None +features: FeatureDict +annotators: AnnotatorRegistry @@ -199,10 +200,10 @@ provided: ```python # Uses default: tracklet_key="tracklet_id" -tracks = SolutionTracks(graph=graph) +tracks = Tracks(graph=graph, tracklet_attr="tracklet_id") # Uses custom attribute name -tracks = SolutionTracks(graph=graph, tracklet_attr="my_track_col") +tracks = Tracks(graph=graph, tracklet_attr="my_track_col") ``` Custom attribute names are also supported through `FeatureDict`: @@ -213,7 +214,7 @@ fd = FeatureDict( tracklet_key="my_track_col", lineage_key="my_lineage", ) -tracks = SolutionTracks(graph=graph, features=fd) +tracks = Tracks(graph=graph, features=fd) ``` When a `FeatureDict` is provided, the `tracklet_attr`/`pos_attr`/`time_attr` arguments @@ -229,7 +230,7 @@ TrackAnnotator(tracks, tracklet_key="my_track", lineage_key="my_lineage") RegionpropsAnnotator(tracks, pos_key="coordinates") ``` -When constructing `Tracks` or `SolutionTracks` directly, you have full control over +When constructing `Tracks` directly, you have full control over which attribute names are used. **Through the import path** (`tracks_from_df`, `import_from_geff`), computed features @@ -286,11 +287,11 @@ tracks = tracks_from_df(df, segmentation=seg) **Scenario 2: Creating tracks from raw segmentation** ```python -from funtracks.utils import create_empty_graphview_graph +from funtracks.utils import create_empty_graph from funtracks.data_model import Tracks # Create empty graph and add nodes -graph = create_empty_graphview_graph() +graph = create_empty_graph() graph.add_node(index=1, attrs={"t": 0}) tracks = Tracks(graph, segmentation=seg) # Auto-detection: pos, area don't exist → compute them from segmentation @@ -353,7 +354,7 @@ tracks.disable_features(["area"]) def compute(self, feature_keys=None): # Compute feature values in bulk if "custom" in self.features: - for node in self.tracks.graph.node_ids(): + for node in self.tracks.graph_solution.node_ids(): value = self._compute_custom(node) self.tracks[node]["custom"] = value diff --git a/docs/import-flow.md b/docs/import-flow.md index 1cec7319..5710248b 100644 --- a/docs/import-flow.md +++ b/docs/import-flow.md @@ -25,7 +25,7 @@ graph LR Validate["validate
check graph structure"] ConstructGraph["construct_graph
build tracksdata graph"] HandleSeg["handle_segmentation
load & relabel if needed"] - CreateTracks[Create SolutionTracks] + CreateTracks[Create Tracks] EnableFeatures["enable_features
register & compute"] ValidateMap --> LoadSource diff --git a/pyproject.toml b/pyproject.toml index 0c5eb7a8..581810d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies =[ "dask>=2025.5.0", "pandas>=2.3.3", "zarr>=2.18,<4", - "tracksdata[spatial]==0.1.0rc4", + "tracksdata[spatial]>=0.1.0rc6", "tqdm>=4.66.1", # zarr 2.x's util.py imports cbuffer_sizes/cbuffer_metainfo from # numcodecs.blosc, which numcodecs >= 0.16 removed. Pin numcodecs per @@ -107,8 +107,6 @@ unfixable = [ ] [tool.ruff.lint.per-file-ignores] "tests/*" = ["D"] # no docstrings in tests -"src/funtracks/data_model/tracks.py" = ["D"] # Remove this when refactoring tracks -"src/funtracks/data_model/solution_tracks.py" = ["D"] # Remove this when refactoring tracks "__init__.py" = ["F401"] # unused imports allowed in __init__.py # https://docs.astral.sh/ruff/formatter/ diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index dda79e80..3a9fb28a 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -45,35 +45,51 @@ def _apply(self) -> None: Raises: ValueError if an endpoint of the edge does not exist """ - # Check that both endpoints exist before computing edge attributes + # Check that both endpoints exist in the solution before adding the edge for node in self.edge: - if not self.tracks.graph.has_node(node): + if not self.tracks.graph_solution.has_node(node): raise ValueError( - f"Cannot add edge {self.edge}: endpoint {node} not in graph yet" + f"Cannot add edge {self.edge}: endpoint {node} not in solution yet" ) - if self.tracks.graph.has_edge(*self.edge): - raise ValueError(f"Edge {self.edge} already exists in the graph") - - attrs = dict(self.attributes) - - # Fill in missing edge attributes with schema defaults (includes - # solution and any other registered edge attrs). - schemas = self.tracks.graph._edge_attr_schemas() - for attr in self.tracks.graph.edge_attr_keys(): - if attr not in attrs: - # An edge added to a Tracks graph is by definition part of the - # solution, so default `solution` to True rather than the schema - # default, which can be wrong (e.g. Float64/0.0) on graphs loaded - # from geff. An explicit caller-provided value still wins. - attrs[attr] = True if attr == "solution" else schemas[attr].default_value - - # Create edge attributes for this specific edge - self.tracks.graph.add_edge( - source_id=self.edge[0], - target_id=self.edge[1], - attrs=attrs, - ) + if self.tracks.graph_solution.has_edge(*self.edge): + raise ValueError(f"Edge {self.edge} already exists in the solution") + + if self.tracks.graph_full.has_edge(*self.edge): + # Revive a soft-deleted edge (already present in the full graph as a + # candidate): flip solution=True, apply any caller-provided attributes (so + # revive matches the add-new branch), and re-surface it in the solution view. + edge_id = self.tracks.graph_full.edge_id(self.edge[0], self.edge[1]) + # Values are wrapped in single-element lists because update_edge_attrs + # reads a bare list value (e.g. a vector feature) as one-value-per-edge. + revive_attrs = {k: [v] for k, v in self.attributes.items() if k != "solution"} + revive_attrs["solution"] = [True] + self.tracks.graph_full.update_edge_attrs( + attrs=revive_attrs, edge_ids=[edge_id] + ) + self.tracks.graph_solution.add_edge_to_view(self.edge[0], self.edge[1]) + else: + attrs = dict(self.attributes) + + # Fill in missing edge attributes with schema defaults (includes + # solution and any other registered edge attrs). + schemas = self.tracks.graph_solution._edge_attr_schemas() + for attr in self.tracks.graph_solution.edge_attr_keys(): + if attr not in attrs: + # An edge added to a Tracks graph is by definition part of the + # solution, so default `solution` to True rather than the schema + # default, which can be wrong (e.g. Float64/0.0) on graphs loaded + # from geff. An explicit caller-provided value still wins. + attrs[attr] = ( + True if attr == "solution" else schemas[attr].default_value + ) + + # Create edge attributes for this specific edge + self.tracks.graph_solution.add_edge( + source_id=self.edge[0], + target_id=self.edge[1], + attrs=attrs, + ) # Notify annotators to recompute features (will overwrite computed ones) self.tracks.notify_annotators(self) @@ -93,7 +109,7 @@ def __init__(self, tracks: Tracks, edge: Edge): """ super().__init__(tracks) self.edge = edge - if not self.tracks.graph.has_edge(*self.edge): + if not self.tracks.graph_solution.has_edge(*self.edge): raise ValueError(f"Edge {self.edge} not in the graph, and cannot be removed") # Save all edge feature values from the features dict @@ -110,5 +126,12 @@ def inverse(self) -> BasicAction: return AddEdge(self.tracks, self.edge, attributes=self.attributes) def _apply(self) -> None: - self.tracks.graph.remove_edge(*self.edge) + """Soft-delete the edge: flag solution=False in the full graph and remove it + from the solution view only. The edge is preserved in graph_full (as a + candidate) so the delete is reversible.""" + edge_id = self.tracks.graph_full.edge_id(self.edge[0], self.edge[1]) + self.tracks.graph_full.update_edge_attrs( + attrs={"solution": False}, edge_ids=[edge_id] + ) + self.tracks.graph_solution.remove_edge_from_view(self.edge[0], self.edge[1]) self.tracks.notify_annotators(self) diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index 83f87bfc..ed3df040 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -9,8 +9,7 @@ if TYPE_CHECKING: from typing import Any - from funtracks.data_model.solution_tracks import SolutionTracks - from funtracks.data_model.tracks import Node + from funtracks.data_model.tracks import Node, Tracks class AddNode(BasicAction): @@ -22,7 +21,7 @@ class AddNode(BasicAction): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, node: Node, attributes: dict[str, Any], ): @@ -40,7 +39,7 @@ def __init__( ValueError: If neither position nor a mask feature is in attributes. """ super().__init__(tracks) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class self.node = int(node) # Get keys from tracks features @@ -76,14 +75,34 @@ def inverse(self) -> BasicAction: return DeleteNode(self.tracks, self.node) def _apply(self) -> None: - """Add the node with all attributes from self.attributes.""" - attrs = dict(self.attributes) - # A node added to a Tracks graph is by definition part of the solution, - # so default `solution` to True rather than the column schema default, - # which can be wrong on graphs loaded from geff. An explicit - # caller-provided value still wins. - attrs.setdefault("solution", True) - self.tracks.graph.add_node(attrs=attrs, index=self.node, validate_keys=False) + """Add the node, or revive a soft-deleted one. + + If the node still exists in the full graph (it was soft-deleted, so its + topology was preserved), revive it: flip solution=True and re-surface it in the + solution view. Otherwise add a genuinely new node. + """ + if self.tracks.graph_full.has_node(self.node): + # Revive: same node id, topology preserved in graph_full. Flip it back into + # the solution, apply any caller-provided attributes (so revive matches the + # add-new branch), and re-surface it in the view in place (incident edges + # are revived separately by AddEdge). + # Values are wrapped in single-element lists because update_node_attrs + # reads a bare list value (pos, bbox, mask) as one-value-per-node. + revive_attrs = {k: [v] for k, v in self.attributes.items() if k != "solution"} + revive_attrs["solution"] = [True] + self.tracks.graph_full.update_node_attrs( + attrs=revive_attrs, node_ids=[self.node] + ) + self.tracks.graph_solution.add_node_to_view(self.node) + else: + # Genuinely new node — default `solution` to True rather than the + # column schema default, which can be wrong (e.g. Float64/0.0) on + # graphs loaded from geff. An explicit caller-provided value still wins. + attrs = dict(self.attributes) + attrs.setdefault("solution", True) + self.tracks.graph_solution.add_node( + attrs=attrs, index=self.node, validate_keys=False + ) # Always notify annotators - they will check their own preconditions self.tracks.notify_annotators(self) @@ -93,15 +112,24 @@ class DeleteNode(BasicAction): """Action of deleting an existing node. Saves all node feature values so the action can be inverted. + + Low-level action — not meant to be used directly. It soft-deletes only the + node itself (incident edges are dropped from the view by + ``remove_node_from_view`` but keep ``solution=True`` in ``graph_full``). + Managing the incident edges' solution flags is the responsibility of the + enclosing user action (``UserDeleteNode``), which soft-deletes each incident + edge with its own ``DeleteEdge`` first. Applying a bare ``DeleteNode`` to a + node that still has in-solution edges therefore leaves ``graph_full``'s edge + flags inconsistent with ``graph_solution`` — always go through the user action. """ def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, node: Node, ): super().__init__(tracks) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class self.node = int(node) # Save all node feature values from the features dict @@ -120,6 +148,12 @@ def inverse(self) -> BasicAction: return AddNode(self.tracks, self.node, self.attributes) def _apply(self) -> None: - """Remove the node from the graph.""" - self.tracks.graph.remove_node(self.node) + """Soft-delete the node: flag solution=False in the full graph and remove it + from the solution view only. The node (and its topology) is preserved in + graph_full so the delete is reversible and the node remains a candidate. + """ + self.tracks.graph_full.update_node_attrs( + attrs={"solution": False}, node_ids=[self.node] + ) + self.tracks.graph_solution.remove_node_from_view(self.node) self.tracks.notify_annotators(self) diff --git a/src/funtracks/actions/update_segmentation.py b/src/funtracks/actions/update_segmentation.py index 4d83a158..ffcc6b8d 100644 --- a/src/funtracks/actions/update_segmentation.py +++ b/src/funtracks/actions/update_segmentation.py @@ -61,13 +61,13 @@ def _apply(self) -> None: if value == 0: # val=0 means deleting (part of) the mask - mask_old = self.tracks.graph.nodes[self.node][self.mask_key] + mask_old = self.tracks.graph_full.nodes[self.node][self.mask_key] mask_subtracted = mask_old.__isub__(mask_new) self.tracks.update_mask(self.node, mask_subtracted, mask_key=self.mask_key) - elif self.tracks.graph.has_node(value): + elif self.tracks.graph_full.has_node(value): # if node already exists: - mask_old = self.tracks.graph.nodes[value][self.mask_key] + mask_old = self.tracks.graph_full.nodes[value][self.mask_key] mask_combined = mask_old.__or__(mask_new) self.tracks.update_mask(value, mask_combined, mask_key=self.mask_key) diff --git a/src/funtracks/actions/update_track_id.py b/src/funtracks/actions/update_track_id.py index a09db079..edaabedd 100644 --- a/src/funtracks/actions/update_track_id.py +++ b/src/funtracks/actions/update_track_id.py @@ -5,7 +5,7 @@ from ._base import BasicAction if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks from funtracks.data_model.tracks import Node @@ -31,13 +31,13 @@ class UpdateTrackIDs(BasicAction): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, start_node: Node, tracklet_id: int | None = None, lineage_id: int | None = None, ): super().__init__(tracks) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class self.start_node = start_node # Capture old tracklet ID diff --git a/src/funtracks/annotators/_edge_annotator.py b/src/funtracks/annotators/_edge_annotator.py index b612f431..ee71b95d 100644 --- a/src/funtracks/annotators/_edge_annotator.py +++ b/src/funtracks/annotators/_edge_annotator.py @@ -82,14 +82,14 @@ def compute(self, feature_keys: list[str] | None = None) -> None: # TODO: add skip edges if self.iou_key in keys_to_compute: nodes_by_frame = defaultdict(list) - for n in self.tracks.graph.node_ids(): + for n in self.graph.node_ids(): nodes_by_frame[self.tracks.get_time(n)].append(n) for t in range(self.tracks.segmentation.shape[0] - 1): nodes_in_t = nodes_by_frame[t] edges = [] for node in nodes_in_t: - for succ in self.tracks.graph.successors(node): + for succ in self.graph.successors(node): edges.append((node, succ)) self._iou_update(edges) @@ -105,8 +105,8 @@ def _iou_update( """ for edge in edges: source, target = edge - mask1 = self.tracks.graph.nodes[source]["mask"] - mask2 = self.tracks.graph.nodes[target]["mask"] + mask1 = self.graph.nodes[source]["mask"] + mask2 = self.graph.nodes[target]["mask"] iou = mask1.iou(mask2) self.tracks._set_edge_attr(edge, self.iou_key, iou) @@ -136,16 +136,16 @@ def update(self, action: BasicAction): # Get all incident edges to the modified node modified_node = action.node edges_to_update = [] - for pred in self.tracks.graph.predecessors(modified_node): + for pred in self.graph.predecessors(modified_node): edges_to_update.append((pred, modified_node)) - for succ in self.tracks.graph.successors(modified_node): + for succ in self.graph.successors(modified_node): edges_to_update.append((modified_node, succ)) # Update IoU for each edge for edge in edges_to_update: source, target = edge - mask1 = self.tracks.graph.nodes[source]["mask"] - mask2 = self.tracks.graph.nodes[target]["mask"] + mask1 = self.graph.nodes[source]["mask"] + mask2 = self.graph.nodes[target]["mask"] if mask1.mask.sum() == 0 or mask2.mask.sum() == 0: empty_node = source if mask1.mask.sum() == 0 else target frame = self.tracks.get_time(empty_node) diff --git a/src/funtracks/annotators/_graph_annotator.py b/src/funtracks/annotators/_graph_annotator.py index 7fb8583c..ba681508 100644 --- a/src/funtracks/annotators/_graph_annotator.py +++ b/src/funtracks/annotators/_graph_annotator.py @@ -35,7 +35,7 @@ def can_annotate(cls, tracks: Tracks) -> bool: """Check if this annotator can annotate the given tracks. Subclasses should override this method to specify their requirements - (e.g., segmentation, SolutionTracks, etc.). + (e.g., segmentation, Tracks, etc.). Args: tracks: The tracks to check compatibility with @@ -52,6 +52,29 @@ def __init__(self, tracks: Tracks, features: dict[str, Feature]): key: (feat, False) for key, feat in features.items() } + @property + def graph(self): + """The graph this annotator iterates over and reads topology/masks from. + + Defaults to the full graph. Detection features (`pos`, `area`, `iou`, ...) are + intrinsic to a node/edge — independent of solution membership — so they are + computed for *every* node/edge, including soft-deleted (`solution=False`) + candidates, keeping them valid for revive and ready for re-solving. + `TrackAnnotator` overrides this to `graph_solution`, since track ids + (`tracklet_id`, `lineage_id`) are derived from the solution topology. Root and + view share attribute storage, so writes made here are seen through the solution + view automatically. + + KNOWN COST: computing over `graph_full` means `compute()`/`update()` run + regionprops / mask.iou for every candidate (solution=False) node/edge too. On + candidate graphs with many unselected detections (often 10-100x the solution + size) this is O(candidates) work that is mostly never read, since revive is rare. + This is a deliberate eagerness-for-readiness trade; if it becomes a bottleneck, + the lever is lazy compute-on-revive (default to graph_solution and compute a + candidate's intrinsic features only when it enters the solution). + """ + return self.tracks.graph_full + def activate_features(self, keys: list[str]) -> None: """Activate computation of the given features in the annotation process. @@ -106,7 +129,7 @@ def _filter_feature_keys(self, feature_keys: list[str] | None) -> list[str]: def compute(self, feature_keys: list[str] | None = None) -> None: """Compute a set of features and add them to the tracks. - This involves both updating the node/edge attributes on the tracks.graph + This involves both updating the node/edge attributes on `self.graph` and adding the features to the FeatureDict, if necessary. This is distinct from `update` to allow more efficient bulk computation of features. @@ -123,8 +146,9 @@ def compute(self, feature_keys: list[str] | None = None) -> None: def update(self, action: BasicAction) -> None: """Update a set of features based on the given action. - This involves both updating the node or edge attributes on the tracks.graph - and adding the features to the FeatureDict, if necessary. This is distinct + This involves both updating the node or edge attributes on `self.graph` + and adding the features to the FeatureDict, if + necessary. This is distinct from `compute` to allow more efficient computation of features for single elements. diff --git a/src/funtracks/annotators/_regionprops_annotator.py b/src/funtracks/annotators/_regionprops_annotator.py index 81ec42db..e35295a9 100644 --- a/src/funtracks/annotators/_regionprops_annotator.py +++ b/src/funtracks/annotators/_regionprops_annotator.py @@ -170,10 +170,10 @@ def compute(self, feature_keys: list[str] | None = None) -> None: all_node_ids = [] all_values: dict[str, list] = {key: [] for key in keys_to_compute} - for node_id in self.tracks.graph.node_ids(): - if not self.tracks.graph.has_node(node_id): + for node_id in self.graph.node_ids(): + if not self.graph.has_node(node_id): continue - mask = self.tracks.graph.nodes[node_id]["mask"] + mask = self.graph.nodes[node_id]["mask"] for region in regionprops_extended(mask, spacing=spacing): all_node_ids.append(node_id) for key in keys_to_compute: @@ -203,7 +203,7 @@ def _regionprops_update( spacing = None if self.tracks.scale is None else tuple(self.tracks.scale[1:]) for region in regionprops_extended(mask, spacing=spacing): # Skip labels that aren't nodes in the graph (e.g., unselected detections) - if not self.tracks.graph.has_node(node_id): + if not self.graph.has_node(node_id): continue for key in feature_keys: value = getattr(region, self.regionprops_names[key]) @@ -240,7 +240,7 @@ def update(self, action: BasicAction): time = self.tracks.get_time(node) - if self.tracks.graph.nodes[node]["mask"].mask.sum() == 0: + if self.graph.nodes[node]["mask"].mask.sum() == 0: warnings.warn( f"Cannot find label {node} in frame {time}: " "updating regionprops values to None", @@ -250,7 +250,7 @@ def update(self, action: BasicAction): value = None self.tracks._set_node_attr(node, key, value) else: - mask = self.tracks.graph.nodes[node]["mask"] + mask = self.graph.nodes[node]["mask"] self._regionprops_update(node, mask, keys_to_compute) def change_key(self, old_key: str, new_key: str) -> None: diff --git a/src/funtracks/annotators/_track_annotator.py b/src/funtracks/annotators/_track_annotator.py index d8c66e42..38c81c7a 100644 --- a/src/funtracks/annotators/_track_annotator.py +++ b/src/funtracks/annotators/_track_annotator.py @@ -8,7 +8,6 @@ import tracksdata as td from funtracks.actions import AddNode, DeleteNode, UpdateTrackIDs -from funtracks.data_model import SolutionTracks from funtracks.features import LineageID, TrackletID from ._graph_annotator import GraphAnnotator @@ -17,6 +16,7 @@ from collections.abc import Iterable from funtracks.actions import BasicAction + from funtracks.data_model import Tracks from funtracks.features import Feature @@ -25,9 +25,12 @@ class TrackAnnotator(GraphAnnotator): - """A graph annotator to compute tracklet and lineage IDs for SolutionTracks only. + """A graph annotator that computes tracklet and lineage IDs on the solution view. - Currently, updating the tracklet and lineage IDs is left to Actions. + Registered on every Tracks — track ids are a core feature, not a separate "type" + of tracks. It reads/iterates `graph_solution`; on an empty solution view it simply + computes nothing until nodes are added. Updating the ids after construction is left + to Actions. Attributes: tracklet_id_to_nodes (dict[int, list[int]]): A mapping from tracklet ids to @@ -38,33 +41,20 @@ class TrackAnnotator(GraphAnnotator): max_lineage_id (int): the maximum lineage id used in the tracks Args: - tracks (SolutionTracks): The tracks to be annotated. Must be a solution. - tracklet_key (str | None, optional): A key that already holds the tracklet ids - on the graph. If provided, must be there for every node and already hold - valid tracklet ids. Defaults to None. - lineage_key (str | None, optional): A key that already holds the lineage ids - on the graph. If provided, must be there for every node and already hold - valid lineage ids. Defaults to None. - - - Raises: - ValueError: if the provided Tracks are not SolutionTracks (not a binary lineage - tree) + tracks (Tracks): The tracks to annotate. + tracklet_key (str | None, optional): The node attribute holding tracklet ids. + If the graph already holds valid ids under this key they are reused; + otherwise they are computed. Defaults to DEFAULT_TRACKLET_KEY. + lineage_key (str | None, optional): The node attribute holding lineage ids. + Same semantics as tracklet_key. Defaults to DEFAULT_LINEAGE_KEY. """ - @classmethod - def can_annotate(cls, tracks) -> bool: - """Check if this annotator can annotate the given tracks. - - Requires tracks to be a SolutionTracks instance. - - Args: - tracks: The tracks to check compatibility with - - Returns: - True if tracks is a SolutionTracks instance, False otherwise - """ - return isinstance(tracks, SolutionTracks) + @property + def graph(self): + """Track ids (`tracklet_id`, `lineage_id`) are derived from the solution + topology, so this annotator reads/iterates the solution view, not the full + graph (overriding the base default of `graph_full`).""" + return self.tracks.graph_solution @classmethod def get_available_features(cls, ndim: int = 3) -> dict[str, Feature]: @@ -87,14 +77,11 @@ def get_available_features(cls, ndim: int = 3) -> dict[str, Feature]: def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, tracklet_key: str | None = DEFAULT_TRACKLET_KEY, lineage_key: str | None = DEFAULT_LINEAGE_KEY, ): - if not isinstance(tracks, SolutionTracks): - raise ValueError("Currently the TrackAnnotator only works on SolutionTracks") - - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class self.tracklet_key = ( tracklet_key if tracklet_key is not None else DEFAULT_TRACKLET_KEY ) @@ -112,13 +99,13 @@ def __init__( self.max_lineage_id = 0 # Initialize tracklet bookkeeping if track IDs already exist in the graph - if tracks.graph.num_nodes() > 0: + if tracks.graph_solution.num_nodes() > 0: max_id, id_to_nodes = self._get_max_id_and_map(self.tracklet_key) self.max_tracklet_id = max_id self.tracklet_id_to_nodes = id_to_nodes # Initialize lineage bookkeeping if lineage IDs already exist - if lineage_key is not None and tracks.graph.num_nodes() > 0: + if lineage_key is not None and tracks.graph_solution.num_nodes() > 0: max_id, id_to_nodes = self._get_max_id_and_map(self.lineage_key) self.max_lineage_id = max_id self.lineage_id_to_nodes = id_to_nodes @@ -133,9 +120,9 @@ def _get_max_id_and_map(self, key: str) -> tuple[int, dict[int, list[int]]]: tuple[int, dict[int, list[int]]]: The maximum id value, and a mapping from ids to a list of nodes with that id. """ - if key not in self.tracks.graph.node_attr_keys(): + if key not in self.graph.node_attr_keys(): return 0, {} - df = self.tracks.graph.node_attrs(attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, key]) + df = self.graph.node_attrs(attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, key]) id_to_nodes = defaultdict(list) for node, _id in zip(df[td.DEFAULT_ATTR_KEYS.NODE_ID], df[key], strict=True): if _id is None: @@ -199,12 +186,12 @@ def _assign_lineage_ids(self) -> None: attributes will be updated. """ - lineages_internal = rx.weakly_connected_components(self.tracks.graph.rx_graph) + lineages_internal = rx.weakly_connected_components(self.graph.rx_graph) # Map each component's internal node indices to external ids in one batched # call. node_ids() rebuilds the full external-id list on every call, so calling # it per node (as before) was O(N^2). lineages_external = [ - self.tracks.graph._map_to_external(list(lin)) for lin in lineages_internal + self.graph._map_to_external(list(lin)) for lin in lineages_internal ] max_id, ids_to_nodes = self._assign_ids(lineages_external, self.lineage_key) @@ -222,7 +209,7 @@ def _assign_tracklet_ids(self) -> None: # slow because each remove_edge syncs the view's bidirectional edge maps, # whereas rustworkx edge removal is in-memory. copy() preserves node indices, # so components map back through the original graph's id mapping. - rx_copy = self.tracks.graph.rx_graph.copy() + rx_copy = self.graph.rx_graph.copy() for node in rx_copy.node_indices(): if rx_copy.out_degree(node) >= 2: for _, daughter, _ in list(rx_copy.out_edges(node)): @@ -231,18 +218,23 @@ def _assign_tracklet_ids(self) -> None: track_id = 1 all_node_ids = [] all_track_ids = [] + id_to_nodes = {} for tracklet in rx.weakly_connected_components(rx_copy): # Batched internal -> external mapping (see _assign_lineage_ids). - node_ids_external = self.tracks.graph._map_to_external(list(tracklet)) + node_ids_external = self.graph._map_to_external(list(tracklet)) all_node_ids.extend(node_ids_external) all_track_ids.extend([track_id] * len(node_ids_external)) - self.tracklet_id_to_nodes[track_id] = node_ids_external + id_to_nodes[track_id] = node_ids_external track_id += 1 if all_node_ids: - self.tracks.graph.update_node_attrs( + self.graph.update_node_attrs( attrs={self.tracks.features.tracklet_key: all_track_ids}, node_ids=all_node_ids, ) + # Replace the bookkeeping wholesale (like _assign_lineage_ids) so stale + # entries — e.g. a phantom tracklet -1 seeded at init from a graph whose + # tracklet column still held the -1 sentinel — don't survive a recompute. + self.tracklet_id_to_nodes = id_to_nodes self.max_tracklet_id = track_id - 1 def update(self, action: BasicAction) -> None: @@ -309,18 +301,18 @@ def _handle_update_track_ids(self, action: UpdateTrackIDs) -> None: still_in_tracklet = False # Continue to all successors - next_nodes.extend(self.tracks.graph.successors(node)) + next_nodes.extend(self.graph.successors(node)) curr_nodes = next_nodes # Bulk-write all collected node attribute changes in one call each if update_tracklet and tracklet_nodes: - self.tracks.graph.update_node_attrs( + self.graph.update_node_attrs( attrs={self.tracklet_key: [new_tracklet_id] * len(tracklet_nodes)}, node_ids=tracklet_nodes, ) if update_lineage and lineage_nodes: - self.tracks.graph.update_node_attrs( + self.graph.update_node_attrs( attrs={self.lineage_key: [new_lineage_id] * len(lineage_nodes)}, node_ids=lineage_nodes, ) diff --git a/src/funtracks/candidate_graph/compute_graph.py b/src/funtracks/candidate_graph/compute_graph.py index a0efa378..e77ee461 100644 --- a/src/funtracks/candidate_graph/compute_graph.py +++ b/src/funtracks/candidate_graph/compute_graph.py @@ -15,7 +15,7 @@ def compute_graph_from_seg( iou: bool = False, scale: list[float] | None = None, t_start: int = 0, -) -> td.graph.GraphView: +) -> td.graph.BaseGraph: """Construct a candidate graph from a segmentation array. Nodes are placed at the centroid of each segmentation and edges are added for all nodes in adjacent frames within max_edge_distance. @@ -37,7 +37,7 @@ def compute_graph_from_seg( time values. Defaults to 0. Returns: - td.graph.GraphView: A candidate graph that can be passed to the motile solver + td.graph.BaseGraph: A candidate graph that can be passed to the motile solver """ # add nodes (including mask and bbox in the same bulk_add_nodes call) cand_graph, node_frame_dict = nodes_from_segmentation( @@ -73,7 +73,7 @@ def compute_graph_from_points_list( points_list: np.ndarray, max_edge_distance: float, scale: list[float] | None = None, -) -> td.graph.GraphView: +) -> td.graph.BaseGraph: """Construct a candidate graph from a points list. Args: @@ -88,7 +88,7 @@ def compute_graph_from_points_list( isotropic. Returns: - td.graph.GraphView: A candidate graph that can be passed to the motile solver. + td.graph.BaseGraph: A candidate graph that can be passed to the motile solver. """ # add nodes cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale) diff --git a/src/funtracks/candidate_graph/iou.py b/src/funtracks/candidate_graph/iou.py index 0ca2d315..97dd6e71 100644 --- a/src/funtracks/candidate_graph/iou.py +++ b/src/funtracks/candidate_graph/iou.py @@ -80,7 +80,7 @@ def _get_iou_dict(segmentation, multiseg=False) -> dict[int, dict[int, float]]: def add_iou( - cand_graph: td.graph.GraphView, + cand_graph: td.graph.BaseGraph, segmentation: np.ndarray, node_frame_dict: dict[int, list[int]] | None = None, multiseg=False, @@ -92,7 +92,7 @@ def add_iou( add IOU to an existing graph after the fact. Args: - cand_graph (td.graph.GraphView): Candidate graph with nodes and edges already + cand_graph (td.graph.BaseGraph): Candidate graph with nodes and edges already populated. segmentation (np.ndarray): segmentation that was used to create cand_graph. Has shape ([h], t, [z], y, x), where h is the number of hypotheses if diff --git a/src/funtracks/candidate_graph/utils.py b/src/funtracks/candidate_graph/utils.py index 68948515..f1fee38c 100644 --- a/src/funtracks/candidate_graph/utils.py +++ b/src/funtracks/candidate_graph/utils.py @@ -9,7 +9,7 @@ from tqdm import tqdm from tracksdata.nodes import Mask -from ..utils.tracksdata_utils import create_empty_graphview_graph +from ..utils.tracksdata_utils import create_empty_graph logger = logging.getLogger(__name__) @@ -19,7 +19,7 @@ def nodes_from_segmentation( scale: list[float] | None = None, mask: bool = True, t_start: int = 0, -) -> tuple[td.graph.GraphView, dict[int, list[Any]]]: +) -> tuple[td.graph.BaseGraph, dict[int, list[Any]]]: """Extract candidate nodes from a segmentation. Returns a tracksdata graph with only nodes, and also a dictionary from frames to node_ids for efficient edge adding. @@ -50,7 +50,7 @@ def nodes_from_segmentation( time values. Defaults to 0. Returns: - tuple[td.graph.GraphView, dict[int, list[Any]]]: A candidate graph with only + tuple[td.graph.BaseGraph, dict[int, list[Any]]]: A candidate graph with only nodes, and a mapping from time frames to node ids. """ logger.debug("Extracting nodes from segmentation") @@ -66,7 +66,7 @@ def nodes_from_segmentation( ) node_attributes = ["pos", "area", "mask", "bbox"] if mask else ["pos", "area"] - cand_graph = create_empty_graphview_graph( + cand_graph = create_empty_graph( node_attributes=node_attributes, position_attrs=["pos"], ndim=segmentation.ndim, @@ -115,7 +115,7 @@ def nodes_from_segmentation( def nodes_from_points_list( points_list: np.ndarray, scale: list[float] | None = None, -) -> tuple[td.graph.GraphView, dict[int, list[Any]]]: +) -> tuple[td.graph.BaseGraph, dict[int, list[Any]]]: """Extract candidate nodes from a list of points. Uses the index of the point in the list as its unique id. Returns a tracksdata graph with only nodes, and also a dictionary from frames to @@ -130,7 +130,7 @@ def nodes_from_points_list( implies the data is isotropic. Returns: - tuple[td.graph.GraphView, dict[int, list[Any]]]: A candidate graph with only + tuple[td.graph.BaseGraph, dict[int, list[Any]]]: A candidate graph with only nodes, and a mapping from time frames to node ids. """ logger.info("Extracting nodes from points list") @@ -143,7 +143,7 @@ def nodes_from_points_list( ) points_list = points_list * np.array(scale) - cand_graph = create_empty_graphview_graph( + cand_graph = create_empty_graph( node_attributes=["pos"], position_attrs=["pos"], ndim=ndim, @@ -168,11 +168,11 @@ def nodes_from_points_list( return cand_graph, node_frame_dict -def _compute_node_frame_dict(cand_graph: td.graph.GraphView) -> dict[int, list[Any]]: +def _compute_node_frame_dict(cand_graph: td.graph.BaseGraph) -> dict[int, list[Any]]: """Compute dictionary from time frames to node ids for candidate graph. Args: - cand_graph (td.graph.GraphView): A tracksdata graph + cand_graph (td.graph.BaseGraph): A tracksdata graph Returns: dict[int, list[Any]]: A mapping from time frames to lists of node ids. @@ -187,12 +187,12 @@ def _compute_node_frame_dict(cand_graph: td.graph.GraphView) -> dict[int, list[A return node_frame_dict -def create_kdtree(cand_graph: td.graph.GraphView, node_ids: list[Any]) -> KDTree: +def create_kdtree(cand_graph: td.graph.BaseGraph, node_ids: list[Any]) -> KDTree: """Create a kdtree with the given nodes from the candidate graph. Will fail if provided node ids are not in the candidate graph. Args: - cand_graph (td.graph.GraphView): A candidate graph + cand_graph (td.graph.BaseGraph): A candidate graph node_ids (list[Any]): The nodes within the candidate graph to include in the KDTree. Useful for limiting to one time frame. Must be a list (not a generic iterable) to preserve order for @@ -212,7 +212,7 @@ def create_kdtree(cand_graph: td.graph.GraphView, node_ids: list[Any]) -> KDTree def add_cand_edges( - cand_graph: td.graph.GraphView, + cand_graph: td.graph.BaseGraph, max_edge_distance: float, node_frame_dict: None | dict[int, list[Any]] = None, iou_dict: dict[int, dict[int, float]] | None = None, @@ -221,7 +221,7 @@ def add_cand_edges( frames that are closer than max_edge_distance. Args: - cand_graph (td.graph.GraphView): Candidate graph with only nodes populated. + cand_graph (td.graph.BaseGraph): Candidate graph with only nodes populated. Will be modified in-place to add edges. max_edge_distance (float): Maximum distance that objects can travel between frames. All nodes within this distance in adjacent frames will by connected diff --git a/src/funtracks/data_model/__init__.py b/src/funtracks/data_model/__init__.py index ef0f4187..8ff09431 100644 --- a/src/funtracks/data_model/__init__.py +++ b/src/funtracks/data_model/__init__.py @@ -1,2 +1 @@ from .tracks import Tracks # noqa -from .solution_tracks import SolutionTracks # noqa diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py deleted file mode 100644 index 9b3eaaee..00000000 --- a/src/funtracks/data_model/solution_tracks.py +++ /dev/null @@ -1,244 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import tracksdata as td - -from funtracks.features import FeatureDict - -from .tracks import Tracks - -if TYPE_CHECKING: - from funtracks.annotators import TrackAnnotator - - from .tracks import Node - - -class SolutionTracks(Tracks): - """Difference from Tracks: every node must have a tracklet id""" - - def __init__( - self, - graph: td.graph.GraphView, - time_attr: str | None = None, - pos_attr: str | tuple[str] | list[str] | None = None, - tracklet_attr: str | None = None, - lineage_attr: str | None = None, - scale: list[float] | None = None, - ndim: int | None = None, - features: FeatureDict | None = None, - _segmentation: td.array.GraphArrayView | None = None, - ): - """Initialize a SolutionTracks object. - - SolutionTracks extends Tracks to ensure every node has a tracklet id. A - TrackAnnotator is automatically added to manage track IDs. - - Args: - graph (td.graph.GraphView): Tracksdata graph with nodes as detections - and edges as links. - time_attr (str | None): Graph attribute name for time. Defaults to "time" - if None. - pos_attr (str | tuple[str, ...] | list[str] | None): Graph attribute - name(s) for position. Can be: - - Single string for one attribute containing position array - - List/tuple of strings for multi-axis (one attribute per axis) - Defaults to "pos" if None. - tracklet_attr (str | None): Graph attribute name for tracklet/track IDs. - Defaults to "tracklet_id" if None. - lineage_attr (str | None): Graph attribute name for lineage IDs. - Defaults to "lineage_id" if None. - scale (list[float] | None): Scaling factors for each dimension (including - time). If None, all dimensions scaled by 1.0. - ndim (int | None): Number of dimensions (3 for 2D+time, 4 for 3D+time). - If None, inferred from segmentation or scale. - features (FeatureDict | None): Pre-built FeatureDict with feature - definitions. If provided, time_attr/pos_attr/tracklet_attr are ignored. - Assumes that all features in the dict already exist on the graph (will - be activated but not recomputed). If None, core computed features (pos, - area, tracklet_id) are auto-detected by checking if they exist on the - graph. - _segmentation (GraphArrayView | None): Internal parameter for reusing an - existing GraphArrayView instance. Not intended for public use. - """ - super().__init__( - graph, - time_attr=time_attr, - pos_attr=pos_attr, - tracklet_attr=tracklet_attr, - lineage_attr=lineage_attr, - scale=scale, - ndim=ndim, - features=features, - _segmentation=_segmentation, - ) - - self.track_annotator = self._get_track_annotator() - - def _get_track_annotator(self) -> TrackAnnotator: - """Get the TrackAnnotator instance from the annotator registry. - - Returns: - TrackAnnotator: The track annotator instance - - Raises: - RuntimeError: If no TrackAnnotator is registered - """ - from funtracks.annotators import TrackAnnotator - - for annotator in self.annotators: - if isinstance(annotator, TrackAnnotator): - return annotator - raise RuntimeError( - "No TrackAnnotator registered for this SolutionTracks instance" - ) - - @classmethod - def from_tracks(cls, tracks: Tracks): - force_recompute = False - # Check if all nodes have a value at features.tracklet_key before trusting - # existing track IDs - if ( - tracks.features.tracklet_key is not None - and ( - tracks.graph.node_attrs(attr_keys=tracks.features.tracklet_key)[ - tracks.features.tracklet_key - ] - == -1 - ).any() - # Attributes are no longer None, so 0 now means non-computed - ): - force_recompute = True - - soln_tracks = cls( - tracks.graph, - scale=tracks.scale, - ndim=tracks.ndim, - features=tracks.features, - _segmentation=tracks.segmentation, - ) - if force_recompute: - soln_tracks.enable_features( - [ - soln_tracks.features.tracklet_key, # type: ignore[list-item] - soln_tracks.features.lineage_key, # type: ignore[list-item] - ] - ) - return soln_tracks - - @property - def max_track_id(self) -> int: - return self.track_annotator.max_tracklet_id - - @property - def track_id_to_node(self) -> dict[int, list[int]]: - return self.track_annotator.tracklet_id_to_nodes - - def get_next_track_id(self) -> int: - """Return the next available track_id. - - The max_tracklet_id in TrackAnnotator is updated automatically when - a node is added or track IDs are updated via UpdateTrackIDs. - """ - return self.track_annotator.max_tracklet_id + 1 - - def get_next_lineage_id(self) -> int: - """Return the next available lineage_id. - - The max_lineage_id in TrackAnnotator is updated automatically when - a node is added or lineage IDs are updated via UpdateTrackIDs. - """ - return self.track_annotator.max_lineage_id + 1 - - def get_track_id(self, node) -> int: - if self.features.tracklet_key is None: - raise ValueError("Tracklet key not initialized in features") - track_id = self.get_node_attr(node, self.features.tracklet_key) - return track_id - - def get_track_ids(self, nodes) -> list[int]: - """Batch version of get_track_id — one SQL query fetching all nodes in the graph. - NOTE: always fetches the entire graph internally. Optimised for bulk (all-node) - calls. For small subsets or single nodes use get_track_id() instead.""" - - if self.features.tracklet_key is None: - raise ValueError("Tracklet key not initialized in features") - tracklet_key = self.features.tracklet_key - df = self.graph.node_attrs(attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, tracklet_key]) - id_to_val = dict( - zip( - df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list(), - df[tracklet_key].to_list(), - strict=True, - ) - ) - return [id_to_val[node] for node in nodes] - - def get_lineage_id(self, node) -> int | None: - """Get the lineage ID for a node. - - Args: - node: The node to get lineage ID for - - Returns: - The lineage ID, or None if lineage feature is not enabled - """ - if self.features.lineage_key is None: - return None - return self.get_node_attr(node, self.features.lineage_key) - - def get_track_neighbors( - self, track_id: int, time: int - ) -> tuple[Node | None, Node | None]: - """Get the last node with the given track id before time, and the first node - with the track id after time, if any. Does not assume that a node with - the given track_id and time is already in tracks, but it can be. - - Args: - track_id (int): The track id to search for - time (int): The time point to find the immediate predecessor and successor - for - - Returns: - tuple[Node | None, Node | None]: The last node before time with the given - track id, and the first node after time with the given track id, - or Nones if there are no such nodes. - """ - annotator = self.track_annotator - if ( - track_id not in annotator.tracklet_id_to_nodes - or len(annotator.tracklet_id_to_nodes[track_id]) == 0 - ): - return None, None - candidates = annotator.tracklet_id_to_nodes[track_id] - candidates.sort(key=lambda n: self.get_time(n)) - - pred = None - succ = None - for cand in candidates: - if self.get_time(cand) < time: - pred = cand - elif self.get_time(cand) > time: - succ = cand - break - return ( - int(pred) if pred is not None else None, - int(succ) if succ is not None else None, - ) - - def has_track_id_at_time(self, track_id: int, time: int) -> bool: - """Function to check if a node with given track id exists at given time point. - - Args: - track_id (int): The track id to search for. - time (int): The time point to check. - - Returns: - True if a node with given track id exists at given time point. - """ - - nodes = self.track_id_to_node.get(track_id) - if not nodes: - return False - - return time in [self.get_time(node) for node in nodes] diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 0cc16dc8..6131dcb6 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -17,6 +17,7 @@ from tracksdata.nodes import Mask from funtracks.actions.action_history import ActionHistory +from funtracks.annotators import TrackAnnotator from funtracks.features import ( Feature, FeatureDict, @@ -52,8 +53,11 @@ class Tracks: position attribute. Edges in the graph represent links across time. Attributes: - graph (td.graph.GraphView): A graph with nodes representing detections and - and edges representing links across time. + graph_full (td.graph.BaseGraph): The full graph (first-class): every node/edge + ever known, including soft-deleted (solution=False) candidates. Nodes + represent detections, edges represent links across time. + graph_solution (td.graph.GraphView): A solution==True view derived from + graph_full; the user-visible tracking solution. features (FeatureDict): Dictionary of features tracked on graph nodes/edges. annotators (AnnotatorRegistry): List of annotators that compute features. scale (list[float] | None): How much to scale each dimension by, including time. @@ -65,7 +69,7 @@ class Tracks: def __init__( self, - graph: td.graph.GraphView, + graph: td.graph.BaseGraph, time_attr: str | None = None, pos_attr: str | tuple[str, ...] | list[str] | None = None, tracklet_attr: str | None = None, @@ -78,8 +82,9 @@ def __init__( """Initialize a Tracks object. Args: - graph (td.graph.GraphView): tracksdata directed graph with nodes as detections - and edges as links. + graph (td.graph.BaseGraph): the full base graph (graph_full) with nodes as + detections and edges as links. Must be a base graph, not a view; + graph_solution is built internally as its solution==True view. time_attr (str | None): Graph attribute name for time. Defaults to "time" if None. pos_attr (str | tuple[str, ...] | list[str] | None): Graph attribute @@ -88,7 +93,9 @@ def __init__( - List/tuple of strings for multi-axis (one attribute per axis) Defaults to "pos" if None. tracklet_attr (str | None): Graph attribute name for tracklet/track IDs. - Defaults to "tracklet_id" if None. + Defaults to "tracklet_id" if None. Every Tracks gets a TrackAnnotator + and track ids (no "plain" vs "solution" distinction); existing ids on + the graph are reused, otherwise they are computed. lineage_attr (str | None): Graph attribute name for lineage IDs. Defaults to "lineage_id" if None. scale (list[float] | None): Scaling factors for each dimension (including @@ -103,7 +110,23 @@ def __init__( _segmentation (GraphArrayView | None): Internal parameter for reusing an existing GraphArrayView instance. Not intended for public use. """ - self.graph = graph + # graph_full is the first-class object: the base graph holding every node/edge + # ever known, including soft-deleted (solution=False) candidates. graph_solution + # is derived from it as a solution==True view. + if isinstance(graph, td.graph.GraphView): + raise ValueError( + "Tracks requires the full base graph (graph_full), not a GraphView. " + "graph_solution is built internally as a solution==True view of it." + ) + if "solution" not in graph.node_attr_keys(): + graph.add_node_attr_key("solution", default_value=True, dtype=pl.Boolean) + if "solution" not in graph.edge_attr_keys(): + graph.add_edge_attr_key("solution", default_value=True, dtype=pl.Boolean) + self.graph_full = graph + self.graph_solution = graph.filter( + td.NodeAttr("solution") == True, # noqa: E712 + td.EdgeAttr("solution") == True, # noqa: E712 + ).subgraph() if _segmentation is not None: # Reuse provided segmentation instance (internal use only) self.segmentation = _segmentation @@ -115,8 +138,10 @@ def __init__( seg_shape = graph.metadata.get("segmentation_shape") if seg_shape is not None: try: + # Render the segmentation from the solution view so soft-deleted + # nodes drop out of the array, mirroring the user-visible graph. array_view = GraphArrayView( - graph=graph, + graph=self.graph_solution, shape=seg_shape, attr_key="node_id", offset=0, @@ -173,6 +198,11 @@ def __init__( else: self._setup_core_computed_features() + # 4. Enforce the track-id invariant on BOTH paths: every Tracks has a tracklet + # key and a registered TrackAnnotator, with tracklet_id/lineage_id registered + # and computed. A provided FeatureDict that omitted them is completed here. + self._ensure_track_features() + def _get_feature_set( self, time_attr: str | None, @@ -193,11 +223,10 @@ def _get_feature_set( - Single string: one attribute containing position array (e.g., "pos") - List/tuple: multiple attributes, one per axis (e.g., ["y", "x"]) - None: defaults to "pos" - tracklet_key: Graph attribute name for tracklet/track IDs - (e.g., "tracklet_id"). - If None, defaults to "tracklet_id" - lineage_key: Graph attribute name for lineage IDs (e.g., "lineage_id"). - if None, defaults to "lineage_id" + tracklet_key: Graph attribute name for tracklet/track IDs. + Defaults to "tracklet_id" if None (every Tracks gets track ids). + lineage_key: Graph attribute name for lineage IDs. + Defaults to "lineage_id" if None. Returns: FeatureDict initialized with time feature and position if no segmentation @@ -206,10 +235,11 @@ def _get_feature_set( time_key = time_attr if time_attr is not None else "time" if pos_attr is None: pos_attr = "pos" - if tracklet_key is None: - tracklet_key = "tracklet_id" - if lineage_key is None: - lineage_key = "lineage_id" + # Every Tracks has a tracklet/lineage key (no "plain" vs "solution" split): + # default them like time/pos. _ensure_track_features() then registers and + # computes the ids; on an empty solution view that is a no-op. + tracklet_key = tracklet_key if tracklet_key is not None else "tracklet_id" + lineage_key = lineage_key if lineage_key is not None else "lineage_id" # Build static features dict - always include time features: dict[str, Feature] = {time_key: Time()} @@ -247,7 +277,7 @@ def _get_feature_set( # else: single pos_attr with segmentation - RegionpropsAnnotator will handle it # Register solution feature when present on the graph - if "solution" in self.graph.node_attr_keys(): + if "solution" in self.graph_solution.node_attr_keys(): feature_dict["solution"] = Solution() # Register mask and bbox features if segmentation exists @@ -265,7 +295,7 @@ def _get_annotators(self) -> AnnotatorRegistry: Creates annotators conditionally: - RegionpropsAnnotator: Only if segmentation is provided - EdgeAnnotator: Only if segmentation is provided - - TrackAnnotator: Only if this is a SolutionTracks instance + - TrackAnnotator: Always (every Tracks has track ids) Each annotator is configured with appropriate keys from self.features. @@ -277,7 +307,6 @@ def _get_annotators(self) -> AnnotatorRegistry: AnnotatorRegistry, EdgeAnnotator, RegionpropsAnnotator, - TrackAnnotator, ) annotator_list: list[GraphAnnotator] = [] @@ -296,15 +325,16 @@ def _get_annotators(self) -> AnnotatorRegistry: if EdgeAnnotator.can_annotate(self): annotator_list.append(EdgeAnnotator(self)) - # TrackAnnotator: requires SolutionTracks (checked in can_annotate) - if TrackAnnotator.can_annotate(self): - annotator_list.append( - TrackAnnotator( - self, # type: ignore[arg-type] - tracklet_key=self.features.tracklet_key, - lineage_key=self.features.lineage_key, - ) + # TrackAnnotator is registered on every Tracks — track ids are a core feature, + # not a separate "type" of tracks. On an empty solution view it simply computes + # nothing until nodes are added. + annotator_list.append( + TrackAnnotator( + self, + tracklet_key=self.features.tracklet_key, + lineage_key=self.features.lineage_key, ) + ) return AnnotatorRegistry(annotator_list) def _activate_features_from_dict(self) -> None: @@ -326,23 +356,22 @@ def _check_existing_feature(self, key: str) -> bool: bool: True if the key is on the first sampled node or there are no nodes, and False if missing from the first node. """ - if self.graph.num_nodes() == 0: + if self.graph_solution.num_nodes() == 0: return True # Check which attributes exist - node_attrs = set(self.graph.node_attr_keys()) + node_attrs = set(self.graph_solution.node_attr_keys()) return key in node_attrs def _setup_core_computed_features(self) -> None: - """Sets up core computed features (position, tracklet, lineage). + """Sets up core computed position features. - Registers position/tracklet/lineage features from annotators into - FeatureDict. For each core feature: - - Activates any features that are detected to already exist on the graph - - Enables (computes) any features that don't exist yet + Registers the position feature from the RegionpropsAnnotator into the + FeatureDict, activating it if it already exists on the graph or computing it + otherwise. Track-id features are handled separately by _ensure_track_features. """ # Import here to avoid circular dependency - from funtracks.annotators import RegionpropsAnnotator, TrackAnnotator + from funtracks.annotators import RegionpropsAnnotator core_features: list[str] = [] for annotator in self.annotators: @@ -351,14 +380,12 @@ def _setup_core_computed_features(self) -> None: if self.features.position_key is None: self.features.position_key = pos_key core_features.append(pos_key) - elif isinstance(annotator, TrackAnnotator): - tracklet_key = annotator.tracklet_key - self.features.tracklet_key = tracklet_key - core_features.append(tracklet_key) - lineage_key = annotator.lineage_key - self.features.lineage_key = lineage_key - core_features.append(lineage_key) - for key in core_features: + self._register_core_features(core_features) + + def _register_core_features(self, keys: list[str]) -> None: + """Register each key as a feature: activate it if it already exists on the + graph, otherwise enable (compute) it.""" + for key in keys: if self._check_existing_feature(key): if key not in self.features: feature, _ = self.annotators.all_features[key] @@ -367,38 +394,53 @@ def _setup_core_computed_features(self) -> None: else: self.enable_features([key]) - def nodes(self): - return np.array(self.graph.node_ids()) + def _ensure_track_features(self) -> None: + """Ensure the track-id core features exist on this Tracks. - def edges(self): - return np.array(self.graph.edge_ids()) + Every Tracks has a registered TrackAnnotator and a tracklet key (no "plain" + vs "solution" split). This syncs features.tracklet_key/lineage_key from the + annotator and registers + computes (or activates, if already present) + tracklet_id/lineage_id. Runs on both the provided-FeatureDict and the + auto-detect construction paths; a no-op on an empty solution view. + """ + annotator = self.track_annotator + self.features.tracklet_key = annotator.tracklet_key + self.features.lineage_key = annotator.lineage_key + # A tracklet column can exist yet still hold the -1 sentinel ("not computed", + # the column default). Trusting it would activate stale ids and seed a phantom + # tracklet -1 in the annotator bookkeeping, so force a recompute from topology. + if self._has_uncomputed_track_ids(annotator.tracklet_key): + self.enable_features([annotator.tracklet_key, annotator.lineage_key]) + else: + self._register_core_features([annotator.tracklet_key, annotator.lineage_key]) - def in_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: - """Get the in-degree edge_ids of the nodes in the graph.""" - if nodes is not None: - # make sure nodes is a numpy array - if not isinstance(nodes, np.ndarray): - nodes = np.array(nodes) + def _has_uncomputed_track_ids(self, tracklet_key: str) -> bool: + """True if the tracklet column exists but any node still holds the -1 sentinel. - return np.array([self.graph.in_degree(node.item()) for node in nodes]) - else: - return np.array(self.graph.in_degree()) + A missing column returns False: _register_core_features computes it from scratch. + """ + if self.graph_solution.num_nodes() == 0: + return False + if tracklet_key not in self.graph_solution.node_attr_keys(): + return False + values = self.graph_solution.node_attrs(attr_keys=[tracklet_key])[tracklet_key] + return bool((values == -1).any()) - def out_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: - if nodes is not None: - # make sure nodes is a numpy array - if not isinstance(nodes, np.ndarray): - nodes = np.array(nodes) + def nodes(self): + """Return the node ids of the solution graph as a numpy array.""" + return np.array(self.graph_solution.node_ids()) - return np.array([self.graph.out_degree(node.item()) for node in nodes]) - else: - return np.array(self.graph.out_degree()) + def edges(self): + """Return the edge ids of the solution graph as a numpy array.""" + return np.array(self.graph_solution.edge_ids()) def predecessors(self, node: int) -> list[int]: - return list(self.graph.predecessors(node)) + """Return the predecessors of a node in the solution graph.""" + return list(self.graph_solution.predecessors(node)) def successors(self, node: int) -> list[int]: - return list(self.graph.successors(node)) + """Return the successors of a node in the solution graph.""" + return list(self.graph_solution.successors(node)) def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.ndarray: """Get the positions of nodes in the graph. Optionally include the @@ -409,7 +451,7 @@ def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.nd For a single node use get_position() instead. Args: - node (Iterable[Node]): The node ids in the graph to get the positions of + nodes (Iterable[Node]): The node ids in the graph to get the positions of incl_time (bool, optional): If true, include the time as the first element of each position array. Defaults to False. @@ -430,7 +472,10 @@ def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.nd + ([self.features.time_key] if incl_time else []) ) - df = self.graph.node_attrs(attr_keys=attr_keys) + # Read from graph_full (consistent with get_position / the attr-helper policy): + # positions are intrinsic node attrs, so this also resolves soft-deleted + # (solution=False) nodes instead of KeyError-ing like a graph_solution query. + df = self.graph_full.node_attrs(attr_keys=attr_keys) id_to_row = { nid: i for i, nid in enumerate(df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list()) } @@ -497,6 +542,7 @@ def set_positions( self._set_nodes_attr(nodes, self.features.position_key, positions) def set_position(self, node: Node, position: list | np.ndarray): + """Set the position of a single node.""" self.set_positions([node], np.expand_dims(np.array(position), axis=0)) def get_times(self, nodes: Iterable[Node]) -> Sequence[int]: @@ -505,7 +551,7 @@ def get_times(self, nodes: Iterable[Node]) -> Sequence[int]: For a single node use get_time() instead. """ nodes = list(nodes) - df = self.graph.node_attrs( + df = self.graph_full.node_attrs( attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, self.features.time_key] ) id_to_val = dict( @@ -546,7 +592,7 @@ def get_mask( if self.segmentation is None: return None - mask = self.graph.nodes[node][mask_key] + mask = self.graph_full.nodes[node][mask_key] return mask def update_mask( @@ -563,13 +609,13 @@ def update_mask( mask_key: The feature key for the mask column. Defaults to the standard mask key. """ - self.graph.nodes[node][mask_key] = mask + self.graph_full.nodes[node][mask_key] = mask mask_feature = self.features.get(mask_key) if mask_feature is not None: # NOTE: all derived features of a mask are currently assumed to be # its bounding box. Revisit if non-bbox derived features are added. for derived_key in mask_feature.get("derived_features", []): - self.graph.nodes[node][derived_key] = mask.bbox + self.graph_full.nodes[node][derived_key] = mask.bbox def undo(self) -> bool: """Undo the last performed action from the action history. @@ -603,10 +649,13 @@ def _get_new_node_ids(self, n: int) -> list[Node]: Returns: list[Node]: A list of new node ids. """ + # Check against graph_full, not the solution view: a soft-deleted + # (solution=False) node still occupies its id in the full graph, so reissuing + # it to a genuinely new node would collide with the still-present root node. ids = [self.node_id_counter + i for i in range(n)] self.node_id_counter += n for idx, _id in enumerate(ids): - while self.graph.has_node(_id): + while self.graph_full.has_node(_id): _id = self.node_id_counter self.node_id_counter += 1 ids[idx] = _id @@ -635,39 +684,56 @@ def _compute_ndim( ) return ndim + # NOTE: the low-level attribute get/set helpers below target `graph_full`, not the + # solution view. Attribute *values* are intrinsic to a node/edge and live on the full + # graph; the solution view is a membership filter over it. Because graph_full ⊇ + # graph_solution and their attr dicts are shared by reference, writing via graph_full + # is identical to writing via the view for any in-solution node (the view sees it + # automatically) and additionally works for soft-deleted (solution=False) candidates. def _set_node_attr(self, node: Node, attr: str, value: Any): if isinstance(value, np.ndarray): value = list(value) - self.graph.nodes[node][attr] = value + self.graph_full.nodes[node][attr] = value def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any]): nodes_list = list(nodes) values_list = list(values) if nodes_list: - self.graph.update_node_attrs(attrs={attr: values_list}, node_ids=nodes_list) + self.graph_full.update_node_attrs( + attrs={attr: values_list}, node_ids=nodes_list + ) def get_node_attr(self, node: Node, attr: str): - return self.graph.nodes[int(node)][attr] + """Get an attribute value for a single node (resolved on graph_full).""" + return self.graph_full.nodes[int(node)][attr] def get_nodes_attr(self, nodes: Iterable[Node], attr: str): + """Get an attribute value for each of the given nodes.""" return [self.get_node_attr(node, attr) for node in nodes] def _set_edge_attr(self, edge: Edge, attr: str, value: Any): - edge_id = self.graph.edge_id(edge[0], edge[1]) - self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) + edge_id = self.graph_full.edge_id(edge[0], edge[1]) + # Wrap in a single-element list: update_edge_attrs reads a bare list value + # (e.g. a vector feature) as one-value-per-edge. + self.graph_full.update_edge_attrs(attrs={attr: [value]}, edge_ids=[edge_id]) def _set_edges_attr(self, edges: Iterable[Edge], attr: str, values: Iterable[Any]): for edge, value in zip(edges, values, strict=False): - edge_id = self.graph.edge_id(edge[0], edge[1]) - self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) + edge_id = self.graph_full.edge_id(edge[0], edge[1]) + self.graph_full.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def get_edge_attr(self, edge: Edge, attr: str): - if attr not in self.graph.edge_attr_keys(): + """Get an attribute value for a single edge (resolved on graph_full). + + Returns None if the attribute is not registered on the graph. + """ + if attr not in self.graph_full.edge_attr_keys(): return None - edge_id = self.graph.edge_id(edge[0], edge[1]) - return self.graph.edges[edge_id][attr] + edge_id = self.graph_full.edge_id(edge[0], edge[1]) + return self.graph_full.edges[edge_id][attr] def get_edges_attr(self, edges: Iterable[Edge], attr: str): + """Get an attribute value for each of the given edges.""" return [self.get_edge_attr(edge, attr) for edge in edges] # ========== Feature Management ========== @@ -750,24 +816,36 @@ def add_feature(self, key: str, feature: Feature) -> None: # Add to the features dictionary self.features[key] = feature - # Perform custom graph operations when a feature is added + # Perform custom graph operations when a feature is added. + # + # Schema (attr-key) registration is done on graph_solution (the view), NOT on + # graph_full, even though annotators write the VALUES to graph_full. This relies + # on a tracksdata invariant: adding an attr key to a view propagates up to its + # root, so the column ends up on both. The reverse does NOT hold today — adding + # a key directly to the root is not propagated down into an existing view — so + # registering on graph_full would leave graph_solution without the column. + # If tracksdata ever makes view attr-key additions local, revisit this. ft = feature["feature_type"] - if "node" in ft and key not in self.graph.node_attr_keys(): + if "node" in ft and key not in self.graph_solution.node_attr_keys(): # "mask" value_type maps to pl.Object via to_polars_dtype dtype = to_polars_dtype(feature["value_type"]) num_values = feature.get("num_values") if num_values is not None and num_values > 1: dtype = pl.Array(dtype, num_values) - self.graph.add_node_attr_key( + self.graph_solution.add_node_attr_key( key, default_value=feature["default_value"], dtype=dtype, ) - if "edge" in ft and key not in self.graph.edge_attr_keys(): - self.graph.add_edge_attr_key( + if "edge" in ft and key not in self.graph_solution.edge_attr_keys(): + dtype = to_polars_dtype(feature["value_type"]) + num_values = feature.get("num_values") + if num_values is not None and num_values > 1: + dtype = pl.Array(dtype, num_values) + self.graph_solution.add_edge_attr_key( key, default_value=feature["default_value"], - dtype=to_polars_dtype(feature["value_type"]), + dtype=dtype, ) def delete_feature(self, key: str) -> None: @@ -799,8 +877,136 @@ def delete_feature(self, key: str) -> None: else: return - # Perform custom graph operations when a feature is deleted - if "node" in feature_type and key in self.graph.node_attr_keys(): - self.graph.remove_node_attr_key(key) - if "edge" in feature_type and key in self.graph.edge_attr_keys(): - self.graph.remove_edge_attr_key(key) + # Perform custom graph operations when a feature is deleted. Schema ops go + # through graph_solution (the view) and propagate to the root — same tracksdata + # invariant as add_feature (see the note there). + if "node" in feature_type and key in self.graph_solution.node_attr_keys(): + self.graph_solution.remove_node_attr_key(key) + if "edge" in feature_type and key in self.graph_solution.edge_attr_keys(): + self.graph_solution.remove_edge_attr_key(key) + + # ========== Track ID management (solution view) ========== + # These operate on the solution view via the TrackAnnotator, which every Tracks + # has (track ids are a core feature). On an empty solution view they are no-ops. + + @property + def track_annotator(self) -> TrackAnnotator: + """The registered TrackAnnotator. Always present, since track ids are a core + feature of every Tracks (_get_annotators registers one unconditionally).""" + for annotator in self.annotators: + if isinstance(annotator, TrackAnnotator): + return annotator + raise RuntimeError( + "No TrackAnnotator registered on this Tracks — this should be unreachable " + "(_get_annotators always registers one)." + ) + + @property + def max_track_id(self) -> int: + """The maximum tracklet id currently in use.""" + return self.track_annotator.max_tracklet_id + + def get_next_track_id(self) -> int: + """Return the next available track_id. + + The max_tracklet_id in TrackAnnotator is updated automatically when + a node is added or track IDs are updated via UpdateTrackIDs. + """ + return self.track_annotator.max_tracklet_id + 1 + + def get_next_lineage_id(self) -> int: + """Return the next available lineage_id. + + The max_lineage_id in TrackAnnotator is updated automatically when + a node is added or lineage IDs are updated via UpdateTrackIDs. + """ + return self.track_annotator.max_lineage_id + 1 + + def get_track_id(self, node) -> int: + """Get the tracklet id of a single node.""" + track_id = self.get_node_attr(node, self.features.tracklet_key) + return track_id + + def get_track_ids(self, nodes) -> list[int]: + """Batch version of get_track_id — one query fetching all nodes in the graph. + NOTE: always fetches the entire graph internally. Optimised for bulk (all-node) + calls. For small subsets or single nodes use get_track_id() instead.""" + + tracklet_key = self.features.tracklet_key + df = self.graph_full.node_attrs( + attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, tracklet_key] + ) + id_to_val = dict( + zip( + df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list(), + df[tracklet_key].to_list(), + strict=True, + ) + ) + return [id_to_val[node] for node in nodes] + + def get_lineage_id(self, node) -> int: + """Get the lineage ID for a node. + + Args: + node: The node to get lineage ID for + + Returns: + The lineage ID. + """ + return self.get_node_attr(node, self.features.lineage_key) + + def get_track_neighbors( + self, track_id: int, time: int + ) -> tuple[Node | None, Node | None]: + """Get the last node with the given track id before time, and the first node + with the track id after time, if any. Does not assume that a node with + the given track_id and time is already in tracks, but it can be. + + Args: + track_id (int): The track id to search for + time (int): The time point to find the immediate predecessor and successor + for + + Returns: + tuple[Node | None, Node | None]: The last node before time with the given + track id, and the first node after time with the given track id, + or Nones if there are no such nodes. + """ + if ( + track_id not in self.track_annotator.tracklet_id_to_nodes + or len(self.track_annotator.tracklet_id_to_nodes[track_id]) == 0 + ): + return None, None + candidates = sorted( + self.track_annotator.tracklet_id_to_nodes[track_id], key=self.get_time + ) + + pred = None + succ = None + for cand in candidates: + if self.get_time(cand) < time: + pred = cand + elif self.get_time(cand) > time: + succ = cand + break + return ( + int(pred) if pred is not None else None, + int(succ) if succ is not None else None, + ) + + def has_track_id_at_time(self, track_id: int, time: int) -> bool: + """Function to check if a node with given track id exists at given time point. + + Args: + track_id (int): The track id to search for. + time (int): The time point to check. + + Returns: + True if a node with given track id exists at given time point. + """ + nodes = self.track_annotator.tracklet_id_to_nodes.get(track_id) + if not nodes: + return False + + return time in [self.get_time(node) for node in nodes] diff --git a/src/funtracks/import_export/_export_segmentation.py b/src/funtracks/import_export/_export_segmentation.py index cef2ccd8..97f9e146 100644 --- a/src/funtracks/import_export/_export_segmentation.py +++ b/src/funtracks/import_export/_export_segmentation.py @@ -45,7 +45,7 @@ def resolve_relabel_attr( else: return None - existing_attrs = tracks.graph.node_attr_keys() + existing_attrs = tracks.graph_solution.node_attr_keys() if label_attr not in existing_attrs: raise ValueError( f"relabel='{relabel}' resolved to attribute '{label_attr}', " @@ -95,13 +95,12 @@ def export_segmentation( shape = tracks.segmentation.shape - graph = ( - tracks.graph.filter(node_ids=list(node_ids)).subgraph() - if node_ids is not None - else tracks.graph - ) - if label_attr is not None: + graph = ( + tracks.graph_solution.filter(node_ids=list(node_ids)).subgraph() + if node_ids is not None + else tracks.graph_solution + ) view = GraphArrayView(graph, label_attr, shape=shape) def get_frame(t: int) -> np.ndarray: diff --git a/src/funtracks/import_export/_import_segmentation.py b/src/funtracks/import_export/_import_segmentation.py index e1545d94..db04c8a3 100644 --- a/src/funtracks/import_export/_import_segmentation.py +++ b/src/funtracks/import_export/_import_segmentation.py @@ -46,11 +46,11 @@ def load_segmentation(segmentation: Path | np.ndarray | da.Array) -> da.Array: def relabel_segmentation( seg_array: da.Array | np.ndarray, - graph: td.graph.GraphView, + graph: td.graph.BaseGraph, node_ids: ArrayLike, seg_ids: ArrayLike, time_values: ArrayLike, -) -> tuple[np.ndarray, td.graph.GraphView]: +) -> tuple[np.ndarray, td.graph.BaseGraph]: """Relabel segmentation from seg_id to node_id. Handles the case where node_id 0 exists by offsetting all node IDs by 1, @@ -58,7 +58,7 @@ def relabel_segmentation( Args: seg_array: Segmentation array (dask or numpy) - graph: tracksdata GraphView (will be relabeled if node_id 0 exists) + graph: tracksdata base graph (will be relabeled if node_id 0 exists) node_ids: Array of node IDs seg_ids: Array of segmentation IDs corresponding to each node time_values: Array of time values for each node diff --git a/src/funtracks/import_export/_tracks_builder.py b/src/funtracks/import_export/_tracks_builder.py index 7ed3ce75..2bfdd0f8 100644 --- a/src/funtracks/import_export/_tracks_builder.py +++ b/src/funtracks/import_export/_tracks_builder.py @@ -1,6 +1,6 @@ """Builder pattern for importing tracks from various formats. -This module provides a unified interface for constructing SolutionTracks objects +This module provides a unified interface for constructing Tracks objects from different data sources (GEFF, CSV, etc.) while sharing common validation and construction logic. """ @@ -15,7 +15,7 @@ import tracksdata as td from geff._typing import InMemoryGeff -from funtracks.data_model.solution_tracks import SolutionTracks +from funtracks.data_model.tracks import Tracks from funtracks.features import Feature from funtracks.import_export._import_segmentation import ( load_segmentation, @@ -39,7 +39,7 @@ ) from funtracks.utils.tracksdata_utils import ( add_masks_and_bboxes_to_graph, - create_empty_graphview_graph, + create_empty_graph, ) if TYPE_CHECKING: @@ -415,7 +415,7 @@ def construct_graph( self, node_name_map: dict[str, str | list[str]] | None = None, database: str | None = None, - ) -> td.graph.GraphView: + ) -> td.graph.BaseGraph: """Construct Tracksdata graph from validated InMemoryGeff data. Common logic shared across all formats. @@ -427,7 +427,7 @@ def construct_graph( If None (default), an in-memory/temp graph is used. Returns: - Tracksdata GraphView with standard keys + Tracksdata base graph with standard keys Raises: ValueError: If data not loaded or validated @@ -481,7 +481,7 @@ def construct_graph( default_value = 0 node_default_values.append(default_value) - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=list(self.in_memory_geff["node_props"].keys()), edge_attributes=list(self.in_memory_geff["edge_props"].keys()), node_default_values=node_default_values, @@ -536,23 +536,17 @@ def construct_graph( if self.TIME_ATTR != "t": graph.remove_node_attr_key(self.TIME_ATTR) - # create_empty_graphview_graph returns a filtered view, but that view is - # a snapshot at filter time; nodes/edges added afterwards (potentially with - # solution=False) bypass the filter. Re-filter the populated root so - # solution=False rows are actually excluded. - graph = graph._root.filter( - td.NodeAttr("solution") == True, # noqa: E712 - td.EdgeAttr("solution") == True, # noqa: E712 - ).subgraph() - + # Return the full base graph. Tracks builds the solution==True view internally, + # so solution=False candidates added during population stay in graph_full and + # are excluded from graph_solution by Tracks, not here. return graph def handle_segmentation( self, - graph: td.graph.GraphView, + graph: td.graph.BaseGraph, segmentation: Path | np.ndarray | None, scale: list[float] | None, - ) -> tuple[np.ndarray | None, list[float] | None, td.graph.GraphView]: + ) -> tuple[np.ndarray | None, list[float] | None, td.graph.BaseGraph]: """Load, validate, and optionally relabel segmentation. Common logic shared across all formats. @@ -619,13 +613,13 @@ def handle_segmentation( return new_segmentation, scale, graph - # Structural keys that are handled by graph construction / SolutionTracks.__init__ + # Structural keys that are handled by graph construction / Tracks.__init__ # and should not be registered as features by enable_features(). STRUCTURAL_KEYS = frozenset({"time", "id", "parent_id", "seg_id"}) def enable_features( self, - tracks: SolutionTracks, + tracks: Tracks, name_map: dict[str, str | list[str]], feature_type: Literal["node", "edge"] = "node", ) -> None: @@ -638,7 +632,7 @@ def enable_features( - Otherwise, if data was loaded for it, register it as a static feature. Args: - tracks: SolutionTracks object to add features to + tracks: Tracks object to add features to name_map: Mapping from standard funtracks keys to source property names (same format as node_name_map / edge_name_map). feature_type: Type of features ("node" or "edge") @@ -685,7 +679,7 @@ def build( scale: list[float] | None = None, node_name_map: dict[str, str | list[str]] | None = None, database: str | None = None, - ) -> SolutionTracks: + ) -> Tracks: """Orchestrate the full construction process. Args: @@ -697,7 +691,7 @@ def build( If None (default), an in-memory/temp graph is used. Returns: - Fully constructed SolutionTracks object + Fully constructed Tracks object Raises: ValueError: If self.node_name_map is not set or validation fails @@ -769,22 +763,27 @@ def build( if segmentation_array is not None: graph = add_masks_and_bboxes_to_graph(graph, segmentation_array) - # 7. Create SolutionTracks + # 7. Create Tracks # construct_graph() always stores time as "t" (tracksdata convention), # regardless of TIME_ATTR, so we pass "t" here explicitly. # If a FeatureDict was loaded (e.g., from GEFF metadata), use it directly if hasattr(self, "features") and self.features is not None: - tracks = SolutionTracks( + tracks = Tracks( graph=graph, ndim=self.ndim, scale=scale, features=self.features, ) else: - tracks = SolutionTracks( + # The builder always produces a solution, so declare tracklet/lineage + # intent to register a TrackAnnotator. Tracks.__init__ auto-detects whether + # these attrs already exist on the graph (activate) or need computing. + tracks = Tracks( graph=graph, pos_attr="pos", time_attr="t", + tracklet_attr="tracklet_id", + lineage_attr="lineage_id", ndim=self.ndim, scale=scale, ) diff --git a/src/funtracks/import_export/_utils.py b/src/funtracks/import_export/_utils.py index 334cf42d..2f8d2feb 100644 --- a/src/funtracks/import_export/_utils.py +++ b/src/funtracks/import_export/_utils.py @@ -71,7 +71,7 @@ def infer_dtype_from_array(arr: ArrayLike) -> ValueType: def filter_graph_with_ancestors( - graph: td.graph.GraphView, nodes_to_keep: set[int] + graph: td.graph.BaseGraph, nodes_to_keep: set[int] ) -> list[int]: """Filter a graph to keep only the nodes in `nodes_to_keep` and their ancestors. diff --git a/src/funtracks/import_export/_v1_format.py b/src/funtracks/import_export/_v1_format.py index 34856ff1..1bff2c12 100644 --- a/src/funtracks/import_export/_v1_format.py +++ b/src/funtracks/import_export/_v1_format.py @@ -16,16 +16,14 @@ ) if TYPE_CHECKING: - from ..data_model import SolutionTracks, Tracks + from ..data_model import Tracks GRAPH_FILE = "graph.json" SEG_FILE = "seg.npy" ATTRS_FILE = "attrs.json" -def load_v1_tracks( - directory: Path, seg_required: bool = False, solution: bool = False -) -> Tracks | SolutionTracks: +def load_v1_tracks(directory: Path, seg_required: bool = False) -> Tracks: """Load a Tracks object from the given directory. Looks for files in the format generated by Tracks.save. @@ -35,8 +33,6 @@ def load_v1_tracks( directory (Path): The directory containing tracks to load seg_required (bool, optional): If true, raises a FileNotFoundError if the segmentation file is not present in the directory. Defaults to False. - solution (bool, optional): If true, returns a SolutionTracks object, otherwise - returns a normal Tracks object. Defaults to False. Returns: Tracks: A tracks object loaded from the given directory @@ -108,23 +104,24 @@ def load_v1_tracks( ) features[td.DEFAULT_ATTR_KEYS.BBOX] = SegBbox(ndim) - # filtering the warnings because the default values of time_attr and pos_attr are - # not None. Therefore, new style Tracks attrs that have features instead of - # pos_attr and time_attr will always trigger the warning. Updating default values - # is breaking, and manually setting the attrs to None if features is present will - # break if the attrs are changed/removed in the future. Can remove in v2.0. - # Import at runtime to avoid circular dependency - from ..data_model import SolutionTracks, Tracks + # Defensive: suppress Tracks.__init__'s "provided both FeatureDict and attr" + # warning in case a save carries both a stored FeatureDict and explicit + # pos/time/tracklet attrs. With a FeatureDict present we pass no attr args, so + # this normally does not fire, but keep the guard for forward-compat. Can remove + # in v2.0. Import at runtime to avoid circular dependency. + from ..data_model import Tracks + + # v1 saves are always solutions, so they must come back with track ids. When the + # stored attrs don't carry a FeatureDict (older saves), declare tracklet/lineage + # intent explicitly so a TrackAnnotator is registered. Newer saves carry a + # FeatureDict that already encodes the tracklet/lineage keys. + if "features" not in attrs: + attrs.setdefault("tracklet_attr", "tracklet_id") + attrs.setdefault("lineage_attr", "lineage_id") with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="Provided both FeatureDict and pos_attr or time_attr" - ) - tracks: Tracks - if solution: - tracks = SolutionTracks(graph_td, **attrs) - else: - tracks = Tracks(graph_td, **attrs) + warnings.filterwarnings("ignore", message="Provided both FeatureDict and pos") + tracks = Tracks(graph_td, **attrs) return tracks diff --git a/src/funtracks/import_export/_validation.py b/src/funtracks/import_export/_validation.py index 415cf90c..27cc028c 100644 --- a/src/funtracks/import_export/_validation.py +++ b/src/funtracks/import_export/_validation.py @@ -22,7 +22,7 @@ def validate_graph_seg_match( - graph: td.graph.GraphView, + graph: td.graph.BaseGraph, segmentation: da.Array, scale: list[float], position_attr: list[str], diff --git a/src/funtracks/import_export/csv/_export.py b/src/funtracks/import_export/csv/_export.py index 6f13c546..f4d42c35 100644 --- a/src/funtracks/import_export/csv/_export.py +++ b/src/funtracks/import_export/csv/_export.py @@ -12,14 +12,15 @@ from .._utils import filter_graph_with_ancestors if TYPE_CHECKING: - from funtracks.data_model.solution_tracks import SolutionTracks + from funtracks.data_model.tracks import Tracks def export_to_csv( - tracks: SolutionTracks, + tracks: Tracks, outfile: Path | str, color_dict: dict[int, np.ndarray] | None = None, node_ids: set[int] | None = None, + export_full: bool = False, use_display_names: bool = False, export_seg: bool = False, seg_path: Path | str | None = None, @@ -36,12 +37,16 @@ def export_to_csv( tiff. If a color dictionary is provided, it will also export the tracklet colors. Args: - tracks: SolutionTracks object containing the tracking data to export + tracks: Tracks object containing the tracking data to export outfile: Path to output CSV file color_dict: dict[int, np.ndarray], optional. If provided, will be used to save the hex colors. node_ids: Optional set of node IDs to include. If provided, only these nodes and their ancestors will be included in the output. + export_full: If True, export the full graph (every node/edge, including + soft-deleted/candidate ones with solution=False); the "solution" column is + then included so the two can be distinguished. If False (default), export + only the solution view. use_display_names: If True, use feature display names as column headers. If False (default), use raw feature keys for backward compatibility. export_seg: Whether to export the segmentation alongside the CSV. @@ -70,6 +75,10 @@ def export_to_csv( """ tracklet_key = tracks.features.tracklet_key + # Which graph to export: the full graph (incl. solution=False candidates) or just + # the solution view. Topology reads (node ids, ancestors, predecessors) go through + # this; attribute reads use the tracks helpers, which already target graph_full. + graph = tracks.graph_full if export_full else tracks.graph_solution def convert_numpy_to_python(value): """Convert numpy types to native Python types.""" @@ -122,8 +131,10 @@ def convert_numpy_to_python(value): # Skip derived features (e.g. bbox managed by mask) if feature_name in derived_keys: continue - # Skip solution — graph is already filtered to solution=True - if feature_name == "solution": + # Skip solution unless exporting the full graph: in the solution-only + # export it is always True (uninformative); in a full export it + # distinguishes solution nodes from soft-deleted/candidate ones. + if feature_name == "solution" and not export_full: continue feature_names.append(feature_name) num_values = feature_dict.get("num_values", 1) @@ -155,15 +166,15 @@ def convert_numpy_to_python(value): # Determine which nodes to export if node_ids is None: - nodes_to_keep = tracks.graph.node_ids() + nodes_to_keep = graph.node_ids() else: - nodes_to_keep = filter_graph_with_ancestors(tracks.graph, node_ids) + nodes_to_keep = filter_graph_with_ancestors(graph, node_ids) # Write CSV file rows: list[dict[str, Any]] = [] for node_id in nodes_to_keep: - parents = list(tracks.graph.predecessors(node_id)) + parents = list(graph.predecessors(node_id)) parent_id = "" if len(parents) == 0 else parents[0] row: dict[str, Any] @@ -210,7 +221,7 @@ def rgb_to_hex(rgb): track_id_to_hex = {} - for track_id, nodes in tracks.track_id_to_node.items(): + for track_id, nodes in tracks.track_annotator.tracklet_id_to_nodes.items(): if not nodes: continue first_node = nodes[0] diff --git a/src/funtracks/import_export/csv/_import.py b/src/funtracks/import_export/csv/_import.py index 983a4efd..7aeae87c 100644 --- a/src/funtracks/import_export/csv/_import.py +++ b/src/funtracks/import_export/csv/_import.py @@ -17,7 +17,7 @@ from .._tracks_builder import TracksBuilder, flatten_name_map if TYPE_CHECKING: - from funtracks.data_model.solution_tracks import SolutionTracks + from funtracks.data_model.tracks import Tracks def _ensure_integer_ids(df: pd.DataFrame) -> pd.DataFrame: @@ -169,12 +169,12 @@ def tracks_from_df( segmentation: np.ndarray | None = None, scale: list[float] | None = None, node_name_map: dict[str, str | list[str]] | None = None, -) -> SolutionTracks: +) -> Tracks: """Import tracks from pandas DataFrame. Turns a pandas DataFrame with columns: time, [z], y, x, id, parent_id, [seg_id], [optional custom attr 1], ... - into a SolutionTracks object. + into a Tracks object. Cells without a parent_id will have an empty string or a -1 for the parent_id. @@ -195,7 +195,7 @@ def tracks_from_df( If None, column names are auto-inferred using fuzzy matching. Returns: - SolutionTracks: a solution tracks object + Tracks: a solution tracks object Raises: ValueError: if the segmentation IDs in the dataframe do not match the provided diff --git a/src/funtracks/import_export/geff/_export.py b/src/funtracks/import_export/geff/_export.py index ad496bcd..7ef00d1d 100644 --- a/src/funtracks/import_export/geff/_export.py +++ b/src/funtracks/import_export/geff/_export.py @@ -34,6 +34,11 @@ def write_to_geff( geff store directly to *path*. Intended for internal save/load workflows where the user picks the ``.geff`` path. + Note: only ``graph_solution`` is written. Soft-deleted (``solution=False``) + candidates are dropped, and reimport marks everything ``solution=True`` — so + undo-ability of past deletes does not survive a save/load round-trip (Phase-1 + design; candidate persistence is deferred to the candidate/solver phase). + Args: tracks: Tracks object containing a graph to save. path: Destination path for the geff store. @@ -65,6 +70,9 @@ def export_to_geff( ): """Export the Tracks graph to geff. + Only the solution graph is exported; soft-deleted (``solution=False``) + candidates are not included. + Args: tracks (Tracks): Tracks object containing a graph to save. directory (Path): Destination directory for saving the Zarr. @@ -88,7 +96,7 @@ def export_to_geff( if node_ids is not None: nodes_to_keep = filter_graph_with_ancestors( - tracks.graph, node_ids + tracks.graph_solution, node_ids ) # include the ancestors to make sure the graph is valid and has no missing # parent nodes. @@ -98,7 +106,7 @@ def export_to_geff( # Include the FeatureDict in metadata only for full exports. # Subgroup exports do not necessarily have valid tracklet/lineage IDs - # and thus are not valid SolutionTracks + # and thus are not valid Tracks graph, metadata = _build_geff_metadata(tracks, include_features=(node_ids is None)) # Save segmentation if present and requested @@ -193,7 +201,7 @@ def _write_segmentation_shape(geff_path: Path, tracks: Tracks) -> None: This allows import_from_geff to reconstruct the segmentation (GraphArrayView) without requiring an external segmentation file. """ - seg_shape = tracks.graph.metadata.get("segmentation_shape") + seg_shape = tracks.graph_full.metadata.get("segmentation_shape") if seg_shape is not None: import zarr as _zarr @@ -222,7 +230,7 @@ def split_position_attr(tracks: Tracks) -> tuple[td.graph.GraphView, list[str] | if isinstance(pos_key, str): # Position is stored as a single attribute, need to split - new_graph = tracks.graph.detach() + new_graph = tracks.graph_solution.detach() new_graph = new_graph.filter().subgraph() # Register new attribute keys @@ -254,6 +262,6 @@ def split_position_attr(tracks: Tracks) -> tuple[td.graph.GraphView, list[str] | return new_graph, new_keys elif pos_key is not None: # Position is already split into separate attributes - return tracks.graph, list(pos_key) + return tracks.graph_solution, list(pos_key) else: - return tracks.graph, None + return tracks.graph_solution, None diff --git a/src/funtracks/import_export/geff/_import.py b/src/funtracks/import_export/geff/_import.py index 060a481e..8cd6ce9d 100644 --- a/src/funtracks/import_export/geff/_import.py +++ b/src/funtracks/import_export/geff/_import.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from pathlib import Path - from funtracks.data_model.solution_tracks import SolutionTracks + from funtracks.data_model.tracks import Tracks # defining constants here because they are only used in the context of import @@ -170,7 +170,7 @@ def read_header(self, source_path: Path) -> None: self._geff_axes = metadata.axes or [] # Read funtracks FeatureDict from GEFF extra metadata if present - # This will be passed to SolutionTracks via the base build() method + # This will be passed to Tracks via the base build() method if metadata.extra and "funtracks" in metadata.extra: funtracks_extra = metadata.extra["funtracks"] if "features" in funtracks_extra: @@ -243,7 +243,7 @@ def construct_graph( self, node_name_map: dict[str, str | list[str]] | None = None, database: str | None = None, - ) -> td.graph.GraphView: + ) -> td.graph.BaseGraph: """Construct graph and prepare embedded segmentation data. The GEFF format serialises mask data as plain numeric arrays (zarr @@ -321,7 +321,7 @@ def import_from_geff( scale: list[float] | None = None, edge_name_map: dict[str, str | list[str]] | None = None, database: str | None = None, -) -> SolutionTracks: +) -> Tracks: """Import tracks from GEFF format. Args: @@ -341,7 +341,7 @@ def import_from_geff( If None (default), an in-memory/temp graph is used. Returns: - SolutionTracks object + Tracks object """ # Filter out None values and "None" strings from node_name_map # (e.g., {"lineage_id": None} or {"lineage_id": "None"}) diff --git a/src/funtracks/user_actions/_user_swap_predecessors.py b/src/funtracks/user_actions/_user_swap_predecessors.py index 2353a4f0..baefaaa7 100644 --- a/src/funtracks/user_actions/_user_swap_predecessors.py +++ b/src/funtracks/user_actions/_user_swap_predecessors.py @@ -9,7 +9,7 @@ from .user_delete_edge import UserDeleteEdge if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserSwapPredecessors(ActionGroup): @@ -19,7 +19,7 @@ class UserSwapPredecessors(ActionGroup): must be earlier in time than both nodes for the swap to be valid. Args: - tracks (SolutionTracks): The tracks to perform the swap on. + tracks (Tracks): The tracks to perform the swap on. nodes (tuple[Node, Node]): A tuple with two nodes. Raises: @@ -30,17 +30,17 @@ class UserSwapPredecessors(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, nodes: tuple[int, int], ): super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # narrow type + self.tracks: Tracks # narrow type if len(nodes) != 2: raise InvalidActionError("You can only swap a pair of two nodes.") node1, node2 = nodes - graph = tracks.graph + graph = tracks.graph_solution # Find predecessors pred1 = graph.predecessors(node1)[0] if graph.predecessors(node1) else None diff --git a/src/funtracks/user_actions/user_add_edge.py b/src/funtracks/user_actions/user_add_edge.py index cf05f002..caf8f3c8 100644 --- a/src/funtracks/user_actions/user_add_edge.py +++ b/src/funtracks/user_actions/user_add_edge.py @@ -11,14 +11,14 @@ from .user_delete_edge import UserDeleteEdge if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserAddEdge(ActionGroup): """Assumes that the endpoints already exist and have track ids. Args: - tracks (SolutionTracks): the tracks to add the edge to + tracks (Tracks): the tracks to add the edge to edge (tuple[int, int]): The edge to add force (bool, optional): Whether to force the action by removing any conflicting edges. Defaults to False. @@ -29,26 +29,26 @@ class UserAddEdge(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, edge: tuple[int, int], force: bool = False, _top_level: bool = True, ): super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class source, target = edge - if not tracks.graph.has_node(source): + if not tracks.graph_solution.has_node(source): raise InvalidActionError( f"Source node {source} not in solution yet - must be added before edge" ) - if not tracks.graph.has_node(target): + if not tracks.graph_solution.has_node(target): raise InvalidActionError( f"Target node {target} not in solution yet - must be added before edge" ) # Check if making a merge. If yes and force, remove the other edge and update # track ids. - in_degree_target = self.tracks.graph.in_degree(target) + in_degree_target = len(self.tracks.predecessors(target)) # type: ignore if in_degree_target > 0: if not force: raise InvalidActionError( @@ -57,7 +57,7 @@ def __init__( forceable=True, ) else: - pred = next(iter(self.tracks.graph.predecessors(target))) + pred = next(iter(self.tracks.graph_solution.predecessors(target))) merge_edge = (pred, target) warnings.warn( f"Removing edge {merge_edge} to add new edge without merging.", @@ -68,7 +68,7 @@ def __init__( ) # update track ids if needed - out_degree_source = self.tracks.graph.out_degree(source) + out_degree_source = len(self.tracks.successors(source)) if out_degree_source == 0: # joining two segments # assign the track id and lineage id of the source node to the target # and all downstream nodes @@ -79,7 +79,7 @@ def __init__( ) elif out_degree_source == 1: # creating a division # assign a new track id to existing child (lineage stays the same) - successor = next(iter(self.tracks.graph.successors(source))) + successor = next(iter(self.tracks.graph_solution.successors(source))) self.actions.append( UpdateTrackIDs(self.tracks, successor, self.tracks.get_next_track_id()) ) diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py index d1160e7a..af4ad0f9 100644 --- a/src/funtracks/user_actions/user_add_node.py +++ b/src/funtracks/user_actions/user_add_node.py @@ -15,7 +15,7 @@ from .user_delete_edge import UserDeleteEdge if TYPE_CHECKING: - from funtracks.data_model.solution_tracks import SolutionTracks + from funtracks.data_model.tracks import Tracks class UserAddNode(ActionGroup): @@ -30,7 +30,7 @@ class UserAddNode(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, node: int, attributes: dict[str, Any], pixels: tuple[np.ndarray, ...] | None = None, @@ -39,7 +39,7 @@ def __init__( ): """ Args: - tracks (SolutionTracks): the tracks to add the node to + tracks (Tracks): the tracks to add the node to node (int): The node id of the new node to add attributes (dict[str, Any]): A dictionary from attribute strings to values. Must contain "time" and tracks.features.tracklet_key. @@ -63,7 +63,7 @@ def __init__( time point (forceable). """ super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class # Get keys from tracks features time_key = tracks.features.time_key @@ -77,7 +77,7 @@ def __init__( raise InvalidActionError( f"Cannot add node without track id. Please add {track_id_key} attribute" ) - if self.tracks.graph.has_node(node): + if self.tracks.graph_full.has_node(node): raise InvalidActionError( f"Node {node} already exists in the tracks, cannot add." ) @@ -98,7 +98,7 @@ def __init__( pred, succ = self.tracks.get_track_neighbors(track_id, time) # check if you are adding a node to a track that divided previously - if pred is not None and self.tracks.graph.out_degree(int(pred)) == 2: + if pred is not None and len(self.tracks.successors(pred)) == 2: if not force: raise InvalidActionError( "Cannot add node here - upstream division event detected.", @@ -118,11 +118,11 @@ def __init__( # downstream elif succ is not None: # check pred of succ - preds = self.tracks.graph.predecessors(succ) + preds = self.tracks.predecessors(succ) pred_of_succ = preds[0] if preds else None if ( pred_of_succ is not None - and self.tracks.graph.out_degree(pred_of_succ) == 2 + and len(self.tracks.successors(pred_of_succ)) == 2 ): if not force: raise InvalidActionError( diff --git a/src/funtracks/user_actions/user_delete_edge.py b/src/funtracks/user_actions/user_delete_edge.py index 8af17c80..098a73ac 100644 --- a/src/funtracks/user_actions/user_delete_edge.py +++ b/src/funtracks/user_actions/user_delete_edge.py @@ -9,19 +9,19 @@ from ..actions.update_track_id import UpdateTrackIDs if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserDeleteEdge(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, edge: tuple[int, int], _top_level: bool = True, ): """ Args: - tracks (SolutionTracks): The tracks to delete the edge from. + tracks (Tracks): The tracks to delete the edge from. edge (tuple[int, int]): The edge to delete. _top_level (bool): If True, add this action to the history and emit refresh. Set to False when used as a sub-action inside a compound @@ -31,12 +31,12 @@ def __init__( InvalidActionError: If the edge does not exist in the graph. """ super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class - if not self.tracks.graph.has_edge(*edge): + self.tracks: Tracks # Narrow type from base class + if not self.tracks.graph_solution.has_edge(*edge): raise InvalidActionError(f"Edge {edge} not in solution, can't remove") self.actions.append(DeleteEdge(tracks, edge)) - out_degree = self.tracks.graph.out_degree(edge[0]) + out_degree = len(self.tracks.successors(edge[0])) if out_degree == 0: # removed a normal (non division) edge # orphaned segment gets new track id and new lineage id new_track_id = self.tracks.get_next_track_id() @@ -46,7 +46,7 @@ def __init__( ) elif out_degree == 1: # removed a division edge # sibling gets parent's track id (lineage stays the same) - sibling = next(iter(self.tracks.graph.successors(edge[0]))) + sibling = next(iter(self.tracks.graph_solution.successors(edge[0]))) new_track_id = self.tracks.get_track_id(edge[0]) self.actions.append(UpdateTrackIDs(self.tracks, sibling, new_track_id)) # orphaned child gets a new lineage id (now a separate component) diff --git a/src/funtracks/user_actions/user_delete_node.py b/src/funtracks/user_actions/user_delete_node.py index ef588560..86416587 100644 --- a/src/funtracks/user_actions/user_delete_node.py +++ b/src/funtracks/user_actions/user_delete_node.py @@ -10,20 +10,20 @@ from ..actions.update_track_id import UpdateTrackIDs if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserDeleteNode(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, node: int, pixels: None | tuple[np.ndarray, ...] = None, _top_level: bool = True, ): """ Args: - tracks (SolutionTracks): The tracks to delete the node from. + tracks (Tracks): The tracks to delete the node from. node (int): The node id to delete. pixels (tuple[np.ndarray, ...] | None): The pixels of the node in the segmentation, if known. Will be computed if not provided. @@ -33,7 +33,7 @@ def __init__( action. Defaults to True. """ super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class # delete adjacent edges for pred in self.tracks.predecessors(node): siblings = self.tracks.successors(pred) diff --git a/src/funtracks/user_actions/user_delete_nodes.py b/src/funtracks/user_actions/user_delete_nodes.py index 05f30bb2..279351b1 100644 --- a/src/funtracks/user_actions/user_delete_nodes.py +++ b/src/funtracks/user_actions/user_delete_nodes.py @@ -8,7 +8,7 @@ from .user_delete_node import UserDeleteNode if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserDeleteNodes(ActionGroup): @@ -26,12 +26,12 @@ class UserDeleteNodes(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, nodes: list[int], pixels: None | list[tuple[np.ndarray, ...]] = None, ): super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class for i, node in enumerate(nodes): self.actions.append( UserDeleteNode( diff --git a/src/funtracks/user_actions/user_update_node_attrs.py b/src/funtracks/user_actions/user_update_node_attrs.py index 7cb8cfe2..27616200 100644 --- a/src/funtracks/user_actions/user_update_node_attrs.py +++ b/src/funtracks/user_actions/user_update_node_attrs.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from typing import Any - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserUpdateNodeAttrs(ActionGroup): @@ -21,14 +21,14 @@ class UserUpdateNodeAttrs(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, node: int, attrs: dict[str, Any], _top_level: bool = True, ): """ Args: - tracks (SolutionTracks): The tracks to update the node attributes for + tracks (Tracks): The tracks to update the node attributes for node (int): The node to update the attributes for attrs (dict[str, Any]): A mapping from attribute name to new attribute values for the given node. @@ -40,7 +40,7 @@ def __init__( ValueError: If a protected attribute is in the given attribute mapping. """ super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class # Call the basic UpdateNodeAttrs action self.actions.append(UpdateNodeAttrs(tracks, node, attrs)) diff --git a/src/funtracks/user_actions/user_update_nodes_attrs.py b/src/funtracks/user_actions/user_update_nodes_attrs.py index e5cae4fd..40e088ec 100644 --- a/src/funtracks/user_actions/user_update_nodes_attrs.py +++ b/src/funtracks/user_actions/user_update_nodes_attrs.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from typing import Any - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserUpdateNodesAttrs(ActionGroup): @@ -28,12 +28,12 @@ class UserUpdateNodesAttrs(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, nodes: list[int], attrs: dict[str, list[Any]], ): super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class for key, values in attrs.items(): if not isinstance(values, list): raise ValueError( diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py index 6a312265..04941e52 100644 --- a/src/funtracks/user_actions/user_update_segmentation.py +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -12,13 +12,13 @@ from .user_delete_node import UserDeleteNode if TYPE_CHECKING: - from funtracks.data_model import SolutionTracks + from funtracks.data_model import Tracks class UserUpdateSegmentation(ActionGroup): def __init__( self, - tracks: SolutionTracks, + tracks: Tracks, new_value: int, updated_pixels: list[tuple[tuple[np.ndarray, ...], int]], current_track_id: int, @@ -30,7 +30,7 @@ def __init__( add_node action doesn't have anything with pixels. Args: - tracks (SolutionTracks): The solution tracks that the user is updating. + tracks (Tracks): The solution tracks that the user is updating. new_value (int): The new value that the user painted with updated_pixels (list[tuple[tuple[np.ndarray, ...], int]]): A list of node update actions, consisting of a numpy multi-index, pointing to the array @@ -42,7 +42,7 @@ def __init__( Defaults to False. """ super().__init__(tracks, actions=[]) - self.tracks: SolutionTracks # Narrow type from base class + self.tracks: Tracks # Narrow type from base class node_to_select = None if self.tracks.segmentation is None: raise ValueError("Cannot update non-existing segmentation.") @@ -62,7 +62,21 @@ def __init__( "Can only update one time point at a time" ) time = int(all_pixels[0][0]) - if self.tracks.graph.has_node(new_value): + if self.tracks.graph_full.has_node(new_value): + # An id that already names a node must take the update path: you cannot + # create a *new* node on a taken id. Ids are globally unique across + # graph_full (full ⊇ solution, never reused), so graph_full is the + # correct "does this id name a node?" check; using graph_solution here + # would misroute a soft-deleted id to the add path and silently revive + # the old node with stale attributes. + # NOTE: reviving a soft-deleted (solution=False) node by painting it + # back is not supported yet — UpdateNodeSeg reads the solution view and + # would fail. Guard explicitly until revive-by-paint is implemented. + if not self.tracks.graph_solution.has_node(new_value): + raise NotImplementedError( + f"Cannot paint onto node {new_value}: it is soft-deleted (not " + "in the solution). Revive-by-paint is not supported yet." + ) mask_pixels = pixels_to_td_mask(all_pixels, self.tracks.ndim) self.actions.append( UpdateNodeSeg(tracks, new_value, mask_pixels, added=True) @@ -70,8 +84,6 @@ def __init__( else: time_key = tracks.features.time_key tracklet_key = tracks.features.tracklet_key - if tracklet_key is None: - raise ValueError("Track ID key is not set in tracks features") attrs: dict[str, int] = { time_key: time, tracklet_key: current_track_id, @@ -96,7 +108,7 @@ def __init__( time = pixels[0][0] # check if all pixels of old_value are removed mask_pixels = pixels_to_td_mask(pixels, self.tracks.ndim) - mask_old_value = self.tracks.graph.nodes[old_value]["mask"] + mask_old_value = self.tracks.graph_full.nodes[old_value]["mask"] # If pixels fully overlaps with old_value mask, delete node if mask_pixels.intersection(mask_old_value) == mask_old_value.mask.sum(): self.actions.append( diff --git a/src/funtracks/utils/__init__.py b/src/funtracks/utils/__init__.py index c0ff3cd2..8c7867f4 100644 --- a/src/funtracks/utils/__init__.py +++ b/src/funtracks/utils/__init__.py @@ -10,10 +10,10 @@ setup_zarr_array, setup_zarr_group, ) -from .tracksdata_utils import create_empty_graphview_graph +from .tracksdata_utils import create_empty_graph __all__ = [ - "create_empty_graphview_graph", + "create_empty_graph", "detect_zarr_spec_version", "get_store_path", "is_zarr_v3", diff --git a/src/funtracks/utils/_segmentation_utils.py b/src/funtracks/utils/_segmentation_utils.py index 73edf470..4c13e443 100644 --- a/src/funtracks/utils/_segmentation_utils.py +++ b/src/funtracks/utils/_segmentation_utils.py @@ -25,7 +25,7 @@ def relabel_segmentation_with_track_id( # Division nodes have out_degree > 1; their outgoing edges are cut so that # each daughter cell starts a new tracklet division_nodes = { - n for n in solution_graph.node_ids() if solution_graph.out_degree(n) > 1 + n for n in solution_graph.node_ids() if len(solution_graph.successors(n)) > 1 } visited: set = set() diff --git a/src/funtracks/utils/tracksdata_utils.py b/src/funtracks/utils/tracksdata_utils.py index 9291ec55..3b1f3887 100644 --- a/src/funtracks/utils/tracksdata_utils.py +++ b/src/funtracks/utils/tracksdata_utils.py @@ -67,7 +67,7 @@ def to_polars_dtype(dtype_or_value: str | Any) -> pl.DataType: raise ValueError(f"Unsupported type: {type(dtype_or_value)}") -def create_empty_graphview_graph( +def create_empty_graph( node_attributes: list[str] | None = None, edge_attributes: list[str] | None = None, node_default_values: list[Any] | None = None, @@ -75,9 +75,9 @@ def create_empty_graphview_graph( database: str | None = None, position_attrs: list[str] | None = None, ndim: int = 3, -) -> td.graph.GraphView: +) -> td.graph.BaseGraph: """ - Create an empty tracksdata GraphView with standard node and edge attributes. + Create an empty tracksdata base graph with standard node and edge attributes. Parameters ---------- node_attributes : list[str] | None @@ -102,8 +102,9 @@ def create_empty_graphview_graph( Returns ------- - td.graph.GraphView - An empty tracksdata GraphView with standard node and edge attributes. + td.graph.BaseGraph + An empty tracksdata base graph with standard node and edge attributes + (including a `solution` flag). Tracks builds the solution==True view from it. """ if position_attrs is None: position_attrs = ["pos"] @@ -187,12 +188,8 @@ def create_empty_graphview_graph( if "solution" not in graph_td.edge_attr_keys(): graph_td.add_edge_attr_key("solution", default_value=True, dtype=pl.Boolean) - graph_td_sub = graph_td.filter( - td.NodeAttr("solution") == True, # noqa: E712 - td.EdgeAttr("solution") == True, # noqa: E712 - ).subgraph() - - return graph_td_sub + # Return the full base graph; Tracks builds the solution==True view internally. + return graph_td def assert_node_attrs_equal_with_masks( @@ -202,8 +199,8 @@ def assert_node_attrs_equal_with_masks( Fully compare the content of two graphs (node attributes and Masks) """ - if isinstance(object1, td.graph.GraphView) and ( - isinstance(object2, td.graph.GraphView) + if isinstance(object1, td.graph.BaseGraph) and ( + isinstance(object2, td.graph.BaseGraph) ): node_attrs1 = object1.node_attrs() node_attrs2 = object2.node_attrs() @@ -385,21 +382,21 @@ def segmentation_to_masks( def add_masks_and_bboxes_to_graph( - graph: td.graph.GraphView, + graph: td.graph.BaseGraph, segmentation: np.ndarray, -) -> td.graph.GraphView: +) -> td.graph.BaseGraph: """Add mask and bbox attributes to graph nodes from segmentation. Parameters ---------- - graph : td.graph.GraphView + graph : td.graph.BaseGraph Graph to add attributes to segmentation : np.ndarray Segmentation array of shape (T, Z, Y, X) or (T, Y, X) Returns ------- - td.graph.GraphView + td.graph.BaseGraph Graph with 'mask' and 'bbox' attributes added to nodes """ @@ -480,14 +477,11 @@ def td_relabel_nodes(graph, mapping: dict[int, int]) -> td.graph.IndexedRXGraph: } new_graph.add_edge(source_id, target_id, attrs) - new_graph_sub = new_graph.filter( - td.NodeAttr("solution") == True, # noqa: E712 - td.EdgeAttr("solution") == True, # noqa: E712 - ).subgraph() - return new_graph_sub + # Return the full base graph; Tracks builds the solution==True view internally. + return new_graph -def convert_graph_nx_to_td(graph_nx: nx.DiGraph) -> td.graph.GraphView: +def convert_graph_nx_to_td(graph_nx: nx.DiGraph) -> td.graph.BaseGraph: """Convert a NetworkX DiGraph to a tracksdata graph. Args: @@ -584,10 +578,5 @@ def convert_graph_nx_to_td(graph_nx: nx.DiGraph) -> td.graph.GraphView: attrs_copy["solution"] = True graph_td.add_edge(source_id, target_id, attrs_copy) - # Create subgraph (GraphView) with only solution nodes and edges - graph_td_sub = graph_td.filter( - td.NodeAttr("solution") == True, # noqa: E712 - td.EdgeAttr("solution") == True, # noqa: E712 - ).subgraph() - - return graph_td_sub + # Return the full base graph; Tracks builds the solution==True view internally. + return graph_td diff --git a/tests/actions/test_action_history.py b/tests/actions/test_action_history.py index 4365f22c..96fd5218 100644 --- a/tests/actions/test_action_history.py +++ b/tests/actions/test_action_history.py @@ -1,18 +1,18 @@ from funtracks.actions import AddNode from funtracks.actions.action_history import ActionHistory -from funtracks.data_model import SolutionTracks -from funtracks.utils.tracksdata_utils import create_empty_graphview_graph +from funtracks.data_model import Tracks +from funtracks.utils.tracksdata_utils import create_empty_graph # https://github.com/zaboople/klonk/blob/master/TheGURQ.md def test_action_history(): history = ActionHistory() - empty_graph = create_empty_graphview_graph( + empty_graph = create_empty_graph( node_attributes=["track_id", "pos"], edge_attributes=[], ) - tracks = SolutionTracks(empty_graph, ndim=3, tracklet_attr="track_id", time_attr="t") + tracks = Tracks(empty_graph, ndim=3, tracklet_attr="track_id", time_attr="t") pos = [0, 1] action1 = AddNode(tracks, node=0, attributes={"t": 0, "pos": pos, "track_id": 1}) @@ -24,7 +24,7 @@ def test_action_history(): history.add_new_action(action1) # undo the action assert history.undo() - assert tracks.graph.num_nodes() == 0 + assert tracks.graph_solution.num_nodes() == 0 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 1 assert history._undo_pointer == -1 @@ -34,7 +34,7 @@ def test_action_history(): # redo the action assert history.redo() - assert tracks.graph.num_nodes() == 1 + assert tracks.graph_solution.num_nodes() == 1 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 0 assert history._undo_pointer == 0 @@ -46,7 +46,7 @@ def test_action_history(): assert history.undo() action2 = AddNode(tracks, node=10, attributes={"t": 10, "pos": pos, "track_id": 2}) history.add_new_action(action2) - assert tracks.graph.num_nodes() == 1 + assert tracks.graph_solution.num_nodes() == 1 # there are 3 things on the stack: action1, action1's inverse, and action 2 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 0 @@ -55,7 +55,7 @@ def test_action_history(): # undo back to after action 1 assert history.undo() assert history.undo() - assert tracks.graph.num_nodes() == 1 + assert tracks.graph_solution.num_nodes() == 1 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 2 diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py index de856974..13fb1213 100644 --- a/tests/actions/test_add_delete_edge.py +++ b/tests/actions/test_add_delete_edge.py @@ -8,9 +8,9 @@ AddEdge, DeleteEdge, ) -from funtracks.data_model import SolutionTracks +from funtracks.data_model import Tracks from funtracks.features import FeatureDict, LineageID, Position, Time, TrackletID -from funtracks.utils.tracksdata_utils import create_empty_graphview_graph +from funtracks.utils.tracksdata_utils import create_empty_graph iou_key = "iou" @@ -18,26 +18,26 @@ @pytest.mark.parametrize("ndim", [3, 4]) @pytest.mark.parametrize("with_seg", [True, False]) def test_add_delete_edges(get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - reference_graph = tracks.graph + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) + reference_graph = tracks.graph_solution reference_seg = np.asarray(tracks.segmentation).copy() # Create an empty tracks with just nodes (no edges) - for edge in tracks.graph.edge_list(): - tracks.graph.remove_edge(*edge) + for edge in tracks.graph_solution.edge_list(): + tracks.graph_solution.remove_edge(*edge) edges = [(1, 2), (1, 3), (3, 4), (4, 5)] action = ActionGroup(tracks=tracks, actions=[AddEdge(tracks, edge) for edge in edges]) - with pytest.raises(ValueError, match="Edge .* already exists in the graph"): + with pytest.raises(ValueError, match="Edge .* already exists in the solution"): AddEdge(tracks, (1, 2)) # TODO: What if adding an edge that already exists? # TODO: test all the edge cases, invalid operations, etc. for all actions - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) assert_frame_equal( - tracks.graph.edge_attrs(), + tracks.graph_solution.edge_attrs(), reference_graph.edge_attrs(), check_row_order=False, check_column_order=False, @@ -47,16 +47,18 @@ def test_add_delete_edges(get_tracks, ndim, with_seg): inverse = action.inverse() - assert set(tracks.graph.edge_ids()) == set() + assert set(tracks.graph_solution.edge_ids()) == set() if tracks.segmentation is not None: assert_array_almost_equal(tracks.segmentation, reference_seg) re_added = inverse.inverse() - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) - assert set(tracks.graph.edge_ids()) == set(reference_graph.edge_ids()) - assert sorted(tracks.graph.edge_list()) == sorted(reference_graph.edge_list()) + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) + assert set(tracks.graph_solution.edge_ids()) == set(reference_graph.edge_ids()) + assert sorted(tracks.graph_solution.edge_list()) == sorted( + reference_graph.edge_list() + ) assert_frame_equal( - tracks.graph.edge_attrs(), + tracks.graph_solution.edge_attrs(), reference_graph.edge_attrs(), check_row_order=False, check_column_order=False, @@ -74,17 +76,19 @@ def test_add_delete_edges(get_tracks, ndim, with_seg): # objects again — that's where the corruption surfaces. re_added.inverse() # reset state: edges absent (fresh objects, no bug here) inverse.inverse() # same DeleteEdge objects called again — must not crash - assert set(tracks.graph.edge_ids()) == set(reference_graph.edge_ids()) + assert set(tracks.graph_solution.edge_ids()) == set(reference_graph.edge_ids()) def test_add_edge_missing_endpoint(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) - with pytest.raises(ValueError, match="Cannot add edge .*: endpoint .* not in graph"): + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) + with pytest.raises( + ValueError, match="Cannot add edge .*: endpoint .* not in solution" + ): AddEdge(tracks, (10, 11)) def test_delete_missing_edge(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises( ValueError, match="Edge .* not in the graph, and cannot be removed" ): @@ -97,7 +101,7 @@ def test_custom_edge_attributes_preserved(get_tracks, ndim, with_seg): """Test custom edge attributes preserved through add/delete/re-add cycles.""" from funtracks.features import Feature - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Register custom edge features so they get saved by DeleteEdge custom_features = { @@ -138,29 +142,93 @@ def test_custom_edge_attributes_preserved(get_tracks, ndim, with_seg): action = AddEdge(tracks, edge, attributes=custom_attrs) # Verify all attributes are present after adding - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) for key, value in custom_attrs.items(): - edge_id = tracks.graph.edge_id(*edge) - assert tracks.graph.edges[edge_id][key] == value, ( + edge_id = tracks.graph_solution.edge_id(*edge) + assert tracks.graph_solution.edges[edge_id][key] == value, ( f"Attribute {key} not set correctly on edge" ) # Delete the edge delete_action = action.inverse() - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) # Re-add the edge by inverting the delete delete_action.inverse() - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) # Verify all custom attributes are still present after re-adding for key, value in custom_attrs.items(): - edge_id = tracks.graph.edge_id(*edge) - assert tracks.graph.edges[edge_id][key] == value, ( + edge_id = tracks.graph_solution.edge_id(*edge) + assert tracks.graph_solution.edges[edge_id][key] == value, ( f"Attribute {key} not preserved after delete/re-add cycle" ) +def test_add_edge_revive_applies_new_attributes(get_tracks): + """Re-adding a soft-deleted edge with NEW attributes must apply them, not silently + keep the values preserved from before the delete. + + Existing tests only re-add via inverse(), which reuses the SAME attrs that + soft-delete preserved, so the revive-vs-add-new asymmetry is invisible to them. + """ + from funtracks.features import Feature + + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) + tracks.add_feature( + "weight", + Feature( + feature_type="edge", + value_type="float", + num_values=1, + display_name="Weight", + default_value=None, + ), + ) + + edge = (1, 5) + AddEdge(tracks, edge, attributes={"weight": 1.5}) + DeleteEdge(tracks, edge) # soft-delete: weight=1.5 preserved in graph_full + + # Fresh add of the (now soft-deleted) edge with a DIFFERENT weight. + AddEdge(tracks, edge, attributes={"weight": 9.9}) + + edge_id = tracks.graph_solution.edge_id(*edge) + assert tracks.graph_solution.edges[edge_id]["weight"] == 9.9 + + +def test_add_edge_revive_applies_new_vector_attributes(get_tracks): + """Reviving an edge with a non-scalar (vector) attribute must apply it correctly. + + update_edge_attrs reads a bare list value as one-value-per-edge, so passing a + vector attr unwrapped for a single edge raises a size mismatch (or, for a + length-1 list, silently unwraps it to its element). + """ + from funtracks.features import Feature + + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) + tracks.add_feature( + "flow", + Feature( + feature_type="edge", + value_type="float", + num_values=2, + display_name=["Flow Y", "Flow X"], + default_value=None, + ), + ) + + edge = (1, 5) + AddEdge(tracks, edge, attributes={"flow": [1.0, 2.0]}) + DeleteEdge(tracks, edge) # soft-delete: flow preserved in graph_full + + # Fresh add of the (now soft-deleted) edge with a DIFFERENT flow vector. + AddEdge(tracks, edge, attributes={"flow": [3.0, 4.0]}) + + edge_id = tracks.graph_solution.edge_id(*edge) + assert list(tracks.graph_solution.edges[edge_id]["flow"]) == [3.0, 4.0] + + def test_add_edge_with_unregistered_edge_attr(tmp_path): """AddEdge must not crash when the graph has edge attrs absent from tracks.features. @@ -175,7 +243,7 @@ def test_add_edge_with_unregistered_edge_attr(tmp_path): # Build a graph with "custom_score" on every edge. # This mirrors what the motile solver does: it writes edge attributes directly # to the graph without going through tracks.add_feature(). - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=["pos", "track_id", "lineage_id"], edge_attributes=["custom_score"], database=db_path, @@ -203,7 +271,7 @@ def test_add_edge_with_unregistered_edge_attr(tmp_path): indices=[1, 2], ) - # Wrap in SolutionTracks without registering "custom_score" in features — + # Wrap in Tracks without registering "custom_score" in features — # this is the scenario that triggers the bug. features = FeatureDict( features={ @@ -217,13 +285,13 @@ def test_add_edge_with_unregistered_edge_attr(tmp_path): tracklet_key="track_id", lineage_key="lineage_id", ) - tracks = SolutionTracks(graph, ndim=3, features=features) + tracks = Tracks(graph, ndim=3, features=features) # Sanity: "custom_score" is in the graph schema but NOT in tracks.features. - assert "custom_score" in tracks.graph.edge_attr_keys() + assert "custom_score" in tracks.graph_solution.edge_attr_keys() assert "custom_score" not in tracks.features # Before the fix this raises: KeyError: 'custom_score' AddEdge(tracks, (1, 2)) - assert tracks.graph.has_edge(1, 2) + assert tracks.graph_solution.has_edge(1, 2) diff --git a/tests/actions/test_add_delete_nodes.py b/tests/actions/test_add_delete_nodes.py index 9d6b99c4..b716355c 100644 --- a/tests/actions/test_add_delete_nodes.py +++ b/tests/actions/test_add_delete_nodes.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import tracksdata as td from numpy.testing import assert_array_almost_equal, assert_array_equal from polars.testing import assert_frame_equal from tracksdata.array import GraphArrayView @@ -10,7 +11,7 @@ ) from funtracks.utils.tracksdata_utils import ( assert_node_attrs_equal_with_masks, - create_empty_graphview_graph, + create_empty_graph, ) from ..conftest import make_2d_disk_mask, make_3d_sphere_mask @@ -20,8 +21,8 @@ @pytest.mark.parametrize("with_seg", [True, False]) def test_add_delete_nodes(get_tracks, ndim, with_seg): # Get a tracks instance - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - reference_graph = tracks.graph + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) + reference_graph = tracks.graph_solution reference_seg = np.asarray(tracks.segmentation).copy() if with_seg else None # Start with an empty Tracks @@ -32,17 +33,26 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): tracks.features.position_key, ] edge_attributes = ["iou"] if with_seg else [] - empty_graph = create_empty_graphview_graph( + empty_graph = create_empty_graph( node_attributes=node_attributes + (["area", "bbox", "mask"] if with_seg else []), edge_attributes=edge_attributes, ndim=ndim, ) empty_seg = np.zeros_like(tracks.segmentation) if with_seg else None - tracks.graph = empty_graph + # Reset the tracks onto the empty base graph, mirroring Tracks.__init__: graph_full + # is the base graph, graph_solution its solution==True view. + tracks.graph_full = empty_graph + tracks.graph_solution = empty_graph.filter( + td.NodeAttr("solution") == True, # noqa: E712 + td.EdgeAttr("solution") == True, # noqa: E712 + ).subgraph() segmentation_shape = (5, 100, 100) if ndim == 3 else (5, 100, 100, 100) tracks.segmentation = ( GraphArrayView( - graph=tracks.graph, shape=segmentation_shape, attr_key="node_id", offset=0 + graph=tracks.graph_solution, + shape=segmentation_shape, + attr_key="node_id", + offset=0, ) if with_seg else None @@ -78,8 +88,8 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): actions.append(AddNode(tracks, node, attributes=attrs)) action = ActionGroup(tracks=tracks, actions=actions) - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) - data_tracks = tracks.graph.node_attrs() + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) + data_tracks = tracks.graph_solution.node_attrs() data_reference = reference_graph.node_attrs() if with_seg: assert_array_almost_equal(tracks.segmentation, reference_seg) @@ -93,16 +103,17 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): check_dtypes=False, ) - # Invert the action to delete all the nodes + # Invert the action to delete all the nodes. They are soft-deleted, so they remain + # in graph_full (empty_graph) with solution=False but drop out of the solution view. del_nodes = action.inverse() - assert set(tracks.graph.node_ids()) == set(empty_graph.node_ids()) + assert set(tracks.graph_solution.node_ids()) == set() if with_seg: assert_array_almost_equal(tracks.segmentation, empty_seg) # Re-invert the action to add back all the nodes and their attributes del_nodes.inverse() - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) - data_tracks = tracks.graph.node_attrs() + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) + data_tracks = tracks.graph_solution.node_attrs() data_reference = reference_graph.node_attrs() if with_seg: assert_array_almost_equal(tracks.segmentation, reference_seg) @@ -117,20 +128,45 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): ) +def test_add_node_revive_applies_new_attributes(get_tracks): + """Re-adding a soft-deleted node with NEW attributes must apply them, not silently + keep the values preserved from before the delete. + + Existing tests only re-add via inverse(), which reuses the SAME attrs that + soft-delete preserved, so the revive-vs-add-new asymmetry is invisible to them. + Mirrors test_add_edge_revive_applies_new_attributes. + """ + from funtracks.actions import DeleteNode + + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) + + node_id = 100 + AddNode(tracks, node_id, {"t": 2, "track_id": 10, "pos": [50.0, 50.0]}) + DeleteNode(tracks, node_id) # soft-delete: attrs preserved in graph_full + + # Fresh add of the (now soft-deleted) node with DIFFERENT attributes. + AddNode(tracks, node_id, {"t": 2, "track_id": 11, "pos": [60.0, 60.0]}) + + assert tracks.graph_solution.nodes[node_id]["track_id"] == 11 + assert_array_almost_equal( + tracks.graph_solution.nodes[node_id]["pos"], np.array([60.0, 60.0]) + ) + + def test_add_node_missing_time(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises(ValueError, match="Must provide a time attribute for node"): AddNode(tracks, 8, {}) def test_add_node_missing_pos(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) # First test: missing track_id raises an error with pytest.raises(ValueError, match="Must provide a track_id attribute for node"): AddNode(tracks, 8, {"t": 2}) # Second test: with track_id but without segmentation, missing pos raises an error - tracks_no_seg = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks_no_seg = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) with pytest.raises( ValueError, match="Must provide position or segmentation for node" ): @@ -143,7 +179,7 @@ def test_custom_attributes_preserved(get_tracks, ndim, with_seg): """Test custom node attributes preserved through add/delete/re-add cycles.""" from funtracks.features import Feature - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Register custom features so they get saved by DeleteNode custom_features = { @@ -197,36 +233,44 @@ def test_custom_attributes_preserved(get_tracks, ndim, with_seg): node_id = 100 action = AddNode(tracks, node_id, custom_attrs.copy()) # Verify all attributes are present after adding - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) for key, value in custom_attrs.items(): if key == "pos": - assert_array_almost_equal(tracks.graph.nodes[node_id][key], np.array(value)) + assert_array_almost_equal( + tracks.graph_solution.nodes[node_id][key], np.array(value) + ) elif key == "mask": continue elif key == "bbox": - assert_array_equal(np.asarray(tracks.graph.nodes[node_id][key]), value) + assert_array_equal( + np.asarray(tracks.graph_solution.nodes[node_id][key]), value + ) else: - assert tracks.graph.nodes[node_id][key] == value, ( + assert tracks.graph_solution.nodes[node_id][key] == value, ( f"Attribute {key} not preserved after add" ) # Delete the node delete_action = action.inverse() - assert node_id not in tracks.graph.node_ids() + assert node_id not in tracks.graph_solution.node_ids() # Re-add the node by inverting the delete delete_action.inverse() - assert node_id in tracks.graph.node_ids() + assert node_id in tracks.graph_solution.node_ids() # Verify all custom attributes are still present after re-adding for key, value in custom_attrs.items(): if key == "pos": - assert_array_almost_equal(tracks.graph.nodes[node_id][key], np.array(value)) + assert_array_almost_equal( + tracks.graph_solution.nodes[node_id][key], np.array(value) + ) elif key == "mask": continue elif key == "bbox": - assert_array_equal(np.asarray(tracks.graph.nodes[node_id][key]), value) + assert_array_equal( + np.asarray(tracks.graph_solution.nodes[node_id][key]), value + ) else: - assert tracks.graph.nodes[node_id][key] == value, ( + assert tracks.graph_solution.nodes[node_id][key] == value, ( f"Attribute {key} not preserved after delete/re-add cycle" ) diff --git a/tests/actions/test_base_action.py b/tests/actions/test_base_action.py index 9e4c0047..2b9e3f56 100644 --- a/tests/actions/test_base_action.py +++ b/tests/actions/test_base_action.py @@ -6,7 +6,7 @@ def test_initialize_base_class(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) action = Action(tracks) with pytest.raises(NotImplementedError): action.inverse() diff --git a/tests/actions/test_update_node_attrs.py b/tests/actions/test_update_node_attrs.py index f027ead1..5171e9b4 100644 --- a/tests/actions/test_update_node_attrs.py +++ b/tests/actions/test_update_node_attrs.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize("ndim", [3, 4]) def test_update_node_attrs(get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node = 1 new_feature = Feature( @@ -32,6 +32,6 @@ def test_update_node_attrs(get_tracks, ndim): @pytest.mark.parametrize("attr", ["t", "area", "track_id"]) def test_update_protected_attr(get_tracks, attr): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises(ValueError, match="Cannot update attribute .* manually"): UpdateNodeAttrs(tracks, 1, {attr: 2}) diff --git a/tests/actions/test_update_node_segs.py b/tests/actions/test_update_node_segs.py index 73a98af4..7115d019 100644 --- a/tests/actions/test_update_node_segs.py +++ b/tests/actions/test_update_node_segs.py @@ -10,15 +10,15 @@ @pytest.mark.parametrize("ndim", [3, 4]) def test_update_node_segs(get_tracks, ndim): # Get tracks with segmentation - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) - reference_graph = tracks.graph.detach().filter().subgraph() + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) + reference_graph = tracks.graph_solution.detach().filter().subgraph() node = 1 time = tracks.get_time(node) original_seg = np.asarray(tracks.segmentation).copy() - original_area = tracks.graph.nodes[1]["area"] - original_pos = tracks.graph.nodes[1]["pos"] + original_area = tracks.graph_solution.nodes[1]["area"] + original_pos = tracks.graph_solution.nodes[1]["pos"] # Add a couple pixels to the first node new_seg = np.asarray(tracks.segmentation).copy() @@ -31,22 +31,22 @@ def test_update_node_segs(get_tracks, ndim): action = UpdateNodeSeg(tracks, node, mask=mask, added=True) - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) - assert tracks.graph.nodes[1]["area"] == original_area + 1 - assert not np.allclose(tracks.graph.nodes[1]["pos"], original_pos) + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) + assert tracks.graph_solution.nodes[1]["area"] == original_area + 1 + assert not np.allclose(tracks.graph_solution.nodes[1]["pos"], original_pos) assert_array_almost_equal(tracks.segmentation, new_seg) inverse = action.inverse() - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) assert_series_equal( reference_graph.nodes[1]["pos"], - tracks.graph.nodes[1]["pos"], + tracks.graph_solution.nodes[1]["pos"], ) assert_array_almost_equal(tracks.segmentation, original_seg) inverse.inverse() - assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) - assert tracks.graph.nodes[1]["area"] == original_area + 1 - assert not np.allclose(tracks.graph.nodes[1]["pos"], original_pos) + assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids()) + assert tracks.graph_solution.nodes[1]["area"] == original_area + 1 + assert not np.allclose(tracks.graph_solution.nodes[1]["pos"], original_pos) assert_array_almost_equal(tracks.segmentation, new_seg) diff --git a/tests/annotators/test_annotator_registry.py b/tests/annotators/test_annotator_registry.py index 6a7e8e3d..30996ff3 100644 --- a/tests/annotators/test_annotator_registry.py +++ b/tests/annotators/test_annotator_registry.py @@ -1,9 +1,12 @@ import pytest from funtracks.annotators import EdgeAnnotator, RegionpropsAnnotator, TrackAnnotator -from funtracks.data_model import SolutionTracks, Tracks +from funtracks.data_model import Tracks track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} +# Even without an explicit tracklet_attr, every Tracks gets a TrackAnnotator and a +# default tracklet key (no "plain" vs "solution" split). +plain_attrs = {"time_attr": "t"} def test_annotator_registry_init_with_segmentation( @@ -14,31 +17,32 @@ def test_annotator_registry_init_with_segmentation( tracks = Tracks( graph_2d_with_segmentation, ndim=3, - **track_attrs, + **plain_attrs, ) annotator_types = [type(ann) for ann in tracks.annotators] assert RegionpropsAnnotator in annotator_types assert EdgeAnnotator in annotator_types - assert TrackAnnotator not in annotator_types # Not a SolutionTracks + assert TrackAnnotator in annotator_types # every Tracks has track ids def test_annotator_registry_init_without_segmentation(graph_2d_with_position): - """Test AnnotatorRegistry doesn't create annotators without segmentation.""" - tracks = Tracks(graph_2d_with_position, ndim=3, **track_attrs) + """Without segmentation: no regionprops/edge annotators, but a TrackAnnotator is + still registered (track ids are a core feature of every Tracks).""" + tracks = Tracks(graph_2d_with_position, ndim=3, **plain_attrs) annotator_types = [type(ann) for ann in tracks.annotators] assert RegionpropsAnnotator not in annotator_types assert EdgeAnnotator not in annotator_types - assert TrackAnnotator not in annotator_types + assert TrackAnnotator in annotator_types def test_annotator_registry_init_solution_tracks( graph_2d_with_segmentation, ): - """Test AnnotatorRegistry creates all annotators for SolutionTracks with + """Test AnnotatorRegistry creates all annotators for Tracks with segmentation.""" - tracks = SolutionTracks( + tracks = Tracks( graph_2d_with_segmentation, ndim=3, **track_attrs, @@ -57,13 +61,13 @@ def test_enable_disable_features(graph_2d_with_segmentation): **track_attrs, ) - nodes = list(tracks.graph.node_ids()) - edges = list(tracks.graph.edge_ids()) + nodes = list(tracks.graph_solution.node_ids()) + edges = list(tracks.graph_solution.edge_ids()) # Core features (time, pos) should be in tracks.features and computed assert "pos" in tracks.features assert "t" in tracks.features - assert tracks.graph.nodes[nodes[0]]["pos"] is not None + assert tracks.graph_solution.nodes[nodes[0]]["pos"] is not None # area and other features should NOT be in tracks.features initially assert "area" not in tracks.features @@ -78,9 +82,9 @@ def test_enable_disable_features(graph_2d_with_segmentation): assert "circularity" in tracks.features # Verify values are actually computed on the graph - assert tracks.graph.nodes[nodes[0]]["circularity"] is not None + assert tracks.graph_solution.nodes[nodes[0]]["circularity"] is not None if edges: - assert None not in tracks.graph.edge_attrs()["iou"].to_list() + assert None not in tracks.graph_solution.edge_attrs()["iou"].to_list() # Disable one feature tracks.disable_features(["area"]) @@ -92,7 +96,7 @@ def test_enable_disable_features(graph_2d_with_segmentation): assert "circularity" in tracks.features # Values no longer exist in the graph for tracksdata - # assert tracks.graph.nodes[1]["area"] is not None + # assert tracks.graph_solution.nodes[1]["area"] is not None # Disable the remaining enabled features tracks.disable_features(["pos", "iou", "circularity"]) @@ -122,7 +126,7 @@ def test_area_on_graph_not_auto_activated(graph_2d_with_segmentation): def test_get_available_features(graph_2d_with_segmentation): """Test get_available_features returns all features from all annotators.""" - tracks = SolutionTracks( + tracks = Tracks( graph_2d_with_segmentation, ndim=3, **track_attrs, diff --git a/tests/annotators/test_edge_annotator.py b/tests/annotators/test_edge_annotator.py index cbef21c7..49fc0d8a 100644 --- a/tests/annotators/test_edge_annotator.py +++ b/tests/annotators/test_edge_annotator.py @@ -4,7 +4,7 @@ from funtracks.actions import UpdateNodeSeg, UpdateTrackIDs from funtracks.annotators import EdgeAnnotator -from funtracks.data_model import SolutionTracks, Tracks +from funtracks.data_model import Tracks track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} @@ -42,7 +42,7 @@ def test_compute_all(self, get_graph, ndim): # Compute values ann.compute() for key in all_features: - assert key in tracks.graph.edge_attr_keys() + assert key in tracks.graph_solution.edge_attr_keys() def test_update_all(self, get_graph, ndim) -> None: graph = get_graph(ndim, with_seg=True) @@ -78,7 +78,10 @@ def test_update_all(self, get_graph, ndim) -> None: with pytest.warns(match="Cannot find label 1 in frame .*"): UpdateNodeSeg(tracks, node_id, mask, added=False) - assert tracks.graph.edges[tracks.graph.edge_id(*edge_id)]["iou"] == 0 + assert ( + tracks.graph_solution.edges[tracks.graph_solution.edge_id(*edge_id)]["iou"] + == 0 + ) def test_add_remove_feature(self, get_graph, ndim): graph = get_graph(ndim, with_seg=True) @@ -130,8 +133,8 @@ def test_missing_seg(self, get_graph, ndim) -> None: def test_ignores_irrelevant_actions(self, get_graph, ndim): """Test that EdgeAnnotator ignores actions that don't affect edges.""" - graph = get_graph(ndim, is_solution=True, with_seg=True) - tracks = SolutionTracks( + graph = get_graph(ndim, prefill_track_ids=True, with_seg=True) + tracks = Tracks( graph, ndim=ndim, **track_attrs, @@ -140,7 +143,7 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim): node_id = 3 edge = (1, 3) - edge_id = tracks.graph.edge_id(*edge) + edge_id = tracks.graph_solution.edge_id(*edge) initial_iou = tracks.get_edge_attr(edge, "iou") # If we recomputed IoU now, it would be different @@ -155,6 +158,6 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim): UpdateTrackIDs(tracks, node_id, new_track_id) # IoU should remain unchanged (no recomputation happened despite seg change) - assert tracks.graph.edges[edge_id]["iou"] == initial_iou + assert tracks.graph_solution.edges[edge_id]["iou"] == initial_iou # But track_id should be updated assert tracks.get_track_id(node_id) == new_track_id diff --git a/tests/annotators/test_features_on_full_graph.py b/tests/annotators/test_features_on_full_graph.py new file mode 100644 index 00000000..49321289 --- /dev/null +++ b/tests/annotators/test_features_on_full_graph.py @@ -0,0 +1,69 @@ +"""Step 5 regression: detection features (regionprops, iou) live on graph_full. + +These features are intrinsic to a detection/link and must stay computed for *all* +nodes/edges — including soft-deleted (solution=False) candidates — so the full and +solution graphs never drift and candidates are ready for re-solving. Track-id features +remain solution-only and are covered elsewhere. +""" + +import pytest + +from funtracks.actions import DeleteEdge, DeleteNode +from funtracks.annotators import EdgeAnnotator, RegionpropsAnnotator + + +def _annotator(tracks, cls): + return next(ann for ann in tracks.annotators if isinstance(ann, cls)) + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_regionprops_persist_and_recompute_on_soft_deleted_node(get_tracks, ndim): + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) + rp_ann = _annotator(tracks, RegionpropsAnnotator) + tracks.enable_features(["pos", "area"]) + rp_ann.compute(["area"]) + + node = 5 # leaf node of the fixture graph + area_before = tracks.get_node_attr(node, "area") + assert area_before is not None + + # Soft-delete: leaves the solution view, stays in graph_full as solution=False. + DeleteNode(tracks, node) + assert node not in tracks.graph_solution.node_ids() + assert node in tracks.graph_full.node_ids() + assert tracks.graph_full.nodes[node]["solution"] is False + + # NEW: the detection feature is still readable (helpers read graph_full, so this no + # longer KeyErrors on an out-of-solution node) and its value is preserved. + assert tracks.get_node_attr(node, "area") == area_before + + # NEW: a bulk recompute covers the soft-deleted candidate (iterates graph_full). + tracks._set_node_attr(node, "area", None) # wipe to prove recompute reaches it + rp_ann.compute(["area"]) + assert tracks.get_node_attr(node, "area") == area_before + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_iou_computed_on_soft_deleted_candidate_edge(get_tracks, ndim): + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) + edge_ann = _annotator(tracks, EdgeAnnotator) + tracks.enable_features(["iou"]) + edge_ann.compute(["iou"]) + + edge = (4, 5) + iou_before = tracks.get_edge_attr(edge, "iou") + assert iou_before is not None + + # Soft-delete the edge: gone from the solution, kept in graph_full as a candidate. + DeleteEdge(tracks, edge) + assert not tracks.graph_solution.has_edge(*edge) + assert tracks.graph_full.has_edge(*edge) + + # NEW: iou is still readable on the candidate edge (get_edge_attr reads graph_full). + assert tracks.get_edge_attr(edge, "iou") == iou_before + + # NEW: a bulk recompute reaches the solution=False edge (compute iterates graph_full + # successors, which include candidate edges). + tracks._set_edge_attr(edge, "iou", 0.0) # wipe to prove recompute reaches it + edge_ann.compute(["iou"]) + assert tracks.get_edge_attr(edge, "iou") == iou_before diff --git a/tests/annotators/test_regionprops_annotator.py b/tests/annotators/test_regionprops_annotator.py index 3193dd06..7ee98ee3 100644 --- a/tests/annotators/test_regionprops_annotator.py +++ b/tests/annotators/test_regionprops_annotator.py @@ -4,7 +4,7 @@ from funtracks.actions import UpdateNodeSeg, UpdateTrackIDs from funtracks.annotators import RegionpropsAnnotator -from funtracks.data_model import SolutionTracks, Tracks +from funtracks.data_model import Tracks track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} @@ -39,9 +39,9 @@ def test_compute_all(self, get_graph, ndim): tracks.enable_features(list(rp_ann.all_features.keys())) for key in rp_ann.all_features: - assert key in tracks.graph.node_attr_keys() - for node_id in tracks.graph.node_ids(): - value = tracks.graph.nodes[node_id][key] + assert key in tracks.graph_solution.node_attr_keys() + for node_id in tracks.graph_solution.node_ids(): + value = tracks.graph_solution.nodes[node_id][key] assert value is not None def test_update_all(self, get_graph, ndim): @@ -69,7 +69,7 @@ def test_update_all(self, get_graph, ndim): UpdateNodeSeg(tracks, node_id, removal, added=False) assert tracks.get_node_attr(node_id, "area") == expected_area for key in rp_ann.features: - assert key in tracks.graph.node_attr_keys() + assert key in tracks.graph_solution.node_attr_keys() # segmentation is fully erased and you try to update node_id = 1 @@ -80,8 +80,8 @@ def test_update_all(self, get_graph, ndim): UpdateNodeSeg(tracks, node_id, mask, added=False) # all regionprops features should be the defaults, because seg doesn't exist for key in rp_ann.features: - actual = tracks.graph.nodes[node_id][key] - expected = tracks.graph._node_attr_schemas()[key].default_value + actual = tracks.graph_solution.nodes[node_id][key] + expected = tracks.graph_solution._node_attr_schemas()[key].default_value # Convert to numpy arrays for comparison (handles both scalar and array types) actual_np = np.asarray(actual) expected_np = np.asarray(expected) @@ -105,7 +105,7 @@ def test_add_remove_feature(self, get_graph, ndim): tracks.disable_features([to_remove_key]) rp_ann.compute() - assert to_remove_key not in tracks.graph.node_attr_keys() + assert to_remove_key not in tracks.graph_solution.node_attr_keys() # add it back in tracks.enable_features([to_remove_key]) @@ -155,7 +155,7 @@ def test_centroid_world_coords_with_scale(self, get_graph, ndim): # Force recomputation so regionprops runs with the given scale as spacing tracks.enable_features(["pos"]) - pos = np.array(tracks.graph.nodes[6]["pos"]) + pos = np.array(tracks.graph_solution.nodes[6]["pos"]) expected = pixel_centroid * np.array(scale[1:]) bug_value = np.array([1.5] * len(pixel_centroid)) * np.array( @@ -176,8 +176,8 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim): """Test that RegionpropsAnnotator ignores actions that don't affect segmentation. """ - graph = get_graph(ndim, is_solution=True, with_seg=True) - tracks = SolutionTracks( + graph = get_graph(ndim, prefill_track_ids=True, with_seg=True) + tracks = Tracks( graph, ndim=ndim, **track_attrs, @@ -192,7 +192,9 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim): # RegionpropsAnnotator, it would recompute area back to initial_area and # the assertion below would fail. fake_area = initial_area + 999 - tracks.graph.update_node_attrs(attrs={"area": [fake_area]}, node_ids=[node_id]) + tracks.graph_solution.update_node_attrs( + attrs={"area": [fake_area]}, node_ids=[node_id] + ) original_track_id = tracks.get_track_id(node_id) new_track_id = original_track_id + 100 diff --git a/tests/annotators/test_track_annotator.py b/tests/annotators/test_track_annotator.py index 2f7fa0d2..a2106e19 100644 --- a/tests/annotators/test_track_annotator.py +++ b/tests/annotators/test_track_annotator.py @@ -16,7 +16,7 @@ @pytest.mark.parametrize("with_seg", [True, False]) class TestTrackAnnotator: def test_init(self, get_tracks, ndim, with_seg) -> None: - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) ann = TrackAnnotator(tracks) # Features start disabled by default assert len(ann.all_features) == 2 @@ -32,7 +32,7 @@ def test_init(self, get_tracks, ndim, with_seg) -> None: assert ann.max_tracklet_id == 5 def test_compute_all(self, get_tracks, ndim, with_seg) -> None: - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) ann = TrackAnnotator(tracks, tracklet_key=tracks.features.tracklet_key) # Enable features @@ -41,9 +41,9 @@ def test_compute_all(self, get_tracks, ndim, with_seg) -> None: # Compute values ann.compute() - for node in tracks.graph.node_ids(): + for node in tracks.graph_solution.node_ids(): for key in all_features: - assert tracks.graph.nodes[node][key] is not None + assert tracks.graph_solution.nodes[node][key] is not None lineages = [ [1, 2, 3, 4, 5], @@ -69,7 +69,7 @@ def test_compute_all(self, get_tracks, ndim, with_seg) -> None: assert len({id_set[0] for id_set in id_sets}) == len(id_sets) def test_add_remove_feature(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) ann = TrackAnnotator(tracks, tracklet_key=tracks.features.tracklet_key) # Enable features ann.activate_features(list(ann.all_features.keys())) @@ -79,7 +79,9 @@ def test_add_remove_feature(self, get_tracks, ndim, with_seg): node_id = 6 edge_id = (4, 6) attrs = {"iou": 0, "solution": True} if with_seg else {"solution": True} - tracks.graph.add_edge(source_id=edge_id[0], target_id=edge_id[1], attrs=attrs) + tracks.graph_solution.add_edge( + source_id=edge_id[0], target_id=edge_id[1], attrs=attrs + ) to_remove_key = ann.lineage_key orig_lin = tracks.get_node_attr(node_id, ann.lineage_key) orig_tra = tracks.get_node_attr(node_id, ann.tracklet_key) @@ -98,20 +100,20 @@ def test_add_remove_feature(self, get_tracks, ndim, with_seg): assert tracks.get_node_attr(node_id, ann.lineage_key) != orig_lin assert tracks.get_node_attr(node_id, ann.tracklet_key) != orig_tra - def test_invalid(self, get_tracks, ndim, with_seg) -> None: - # Create regular Tracks (not SolutionTracks) to test error handling - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=False) - with pytest.raises( - ValueError, match="Currently the TrackAnnotator only works on SolutionTracks" - ): - TrackAnnotator(tracks) # type: ignore + def test_always_has_track_annotator(self, get_tracks, ndim, with_seg) -> None: + # Every Tracks has a tracklet key and a registered TrackAnnotator, even when + # built without explicit track attributes (no "plain" vs "solution" split). + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=False) + assert tracks.features.tracklet_key is not None + assert tracks.features.lineage_key is not None + assert isinstance(tracks.track_annotator, TrackAnnotator) def test_ignores_irrelevant_actions(self, get_tracks, ndim, with_seg): """Test that TrackAnnotator ignores actions that don't affect track IDs.""" if not with_seg: pytest.skip("Test requires segmentation") - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) tracks.enable_features(["area", tracks.features.tracklet_key]) node_id = 3 @@ -133,7 +135,7 @@ def test_ignores_irrelevant_actions(self, get_tracks, ndim, with_seg): def test_lineage_id_updated_on_add_and_delete_edge( self, get_tracks, ndim, with_seg ) -> None: - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) tracks.enable_features(["lineage_id"]) # get the existing TrackAnnotator @@ -155,7 +157,7 @@ def test_lineage_id_updated_on_add_and_delete_edge( source_node = 3 target_node = 4 - edge = next(e for e in tracks.graph.edge_list() if set(e) == {3, 4}) + edge = next(e for e in tracks.graph_solution.edge_list() if set(e) == {3, 4}) expected_lineage_id = ann.max_lineage_id + 1 UserDeleteEdge(tracks, edge=edge) @@ -206,7 +208,7 @@ def test_lineage_id_updated_on_division(self, get_tracks, ndim, with_seg) -> Non - New child (6): keeps same track_id, gets source's lineage_id """ # Graph structure: 1 → 2, 1 → 3 → 4 → 5, and 6 (separate) - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) tracks.enable_features(["lineage_id"]) ann = next(a for a in tracks.annotators if isinstance(a, TrackAnnotator)) @@ -255,7 +257,7 @@ def test_lineage_id_updated_on_delete_division_edge( - Orphaned child (2): keeps same track_id, gets new lineage_id """ # Graph structure: 1 → 2, 1 → 3 → 4 → 5, and 6 (separate) - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) tracks.enable_features(["lineage_id"]) ann = next(a for a in tracks.annotators if isinstance(a, TrackAnnotator)) @@ -287,7 +289,7 @@ def test_lineage_id_updated_on_delete_division_edge( def test_disabled_tracklet_key_does_nothing(self, get_tracks, ndim, with_seg) -> None: """Test that TrackAnnotator does nothing when tracklet_key is disabled.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) ann = TrackAnnotator(tracks) # Don't activate any features - they should all be disabled @@ -307,3 +309,27 @@ def test_disabled_tracklet_key_does_nothing(self, get_tracks, ndim, with_seg) -> assert ann.lineage_id_to_nodes == original_lineage_map assert ann.max_tracklet_id == original_max_tracklet assert ann.max_lineage_id == original_max_lineage + + +def test_recompute_replaces_stale_bookkeeping(get_tracks): + """A tracklet column still holding the -1 sentinel seeds a phantom tracklet -1 in + the bookkeeping at annotator init (_get_max_id_and_map only skips None). A full + compute() must replace tracklet_id_to_nodes wholesale (like _assign_lineage_ids + does), not merge into it — otherwise the phantom entry survives with its nodes + duplicated under their real tracklet ids. + """ + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) + node_ids = list(tracks.graph_full.node_ids()) + tracks.graph_full.update_node_attrs( + attrs={"track_id": [-1] * len(node_ids)}, node_ids=node_ids + ) + + ann = TrackAnnotator(tracks, tracklet_key="track_id") + ann.activate_features(["track_id", "lineage_id"]) + ann.compute() + + assert -1 not in ann.tracklet_id_to_nodes + # Every solution node appears exactly once across the bookkeeping. + all_nodes = sorted(n for ns in ann.tracklet_id_to_nodes.values() for n in ns) + assert all_nodes == sorted(tracks.graph_solution.node_ids()) + assert ann.max_tracklet_id == max(ann.tracklet_id_to_nodes.keys()) diff --git a/tests/benchmarks/bench_candidate_graph.py b/tests/benchmarks/bench_candidate_graph.py index 7d9d63db..33078e87 100644 --- a/tests/benchmarks/bench_candidate_graph.py +++ b/tests/benchmarks/bench_candidate_graph.py @@ -9,7 +9,7 @@ from skimage.draw import disk from funtracks.candidate_graph.compute_graph import compute_graph_from_seg -from funtracks.data_model import SolutionTracks +from funtracks.data_model import Tracks NUM_FRAMES = 50 FRAME_SHAPE = (700, 1100) @@ -43,6 +43,28 @@ def seg_data(): return _generate_segmentation() +@pytest.fixture(scope="module") +def _warm_spatial_compile(): + """Warm the spatial_graph rtree JIT compile before the timed benchmark. + + Building a Tracks with segmentation creates a GraphArrayView, which builds a + spatial_graph rtree whose Cython module is JIT-compiled by witty on first use and + cached process-wide. That compile is a one-time cost (seconds-to-tens-of-seconds on + a cold Windows runner) that would otherwise land inside the timed region of + test_graph_to_solution. Trigger it here (untimed) on a tiny graph with the same + dimensionality so the benchmark measures tracking work, not compilation. + """ + tiny = _generate_segmentation(num_frames=2, frame_shape=(64, 64), cells_per_frame=2) + graph = compute_graph_from_seg(tiny, MAX_EDGE_DISTANCE, iou=True) + Tracks( + graph, + time_attr="t", + pos_attr="pos", + tracklet_attr="tracklet_id", + lineage_attr="lineage_id", + ) + + def test_compute_graph_from_seg(benchmark, seg_data): benchmark.pedantic( compute_graph_from_seg, @@ -53,18 +75,26 @@ def test_compute_graph_from_seg(benchmark, seg_data): ) -def test_graph_to_solution(benchmark, seg_data): - """Benchmark candidate graph -> SolutionTracks (tracklet/lineage assignment). +def test_graph_to_solution(benchmark, seg_data, _warm_spatial_compile): + """Benchmark candidate graph -> Tracks with track ids (tracklet/lineage assignment). Candidate-graph construction is benchmarked separately above and is built here in - (untimed) setup, so only the SolutionTracks construction -- dominated by + (untimed) setup, so only the Tracks construction -- dominated by TrackAnnotator._assign_tracklet_ids and _assign_lineage_ids -- is measured. + Setting tracklet_attr/lineage_attr registers a TrackAnnotator and triggers the + track-id computation on construction. The _warm_spatial_compile fixture ensures the + one-time spatial_graph rtree JIT compile is not measured here. """ def setup(): - # Fresh graph per round: SolutionTracks construction mutates the graph + # Fresh graph per round: Tracks construction mutates the graph # (adds tracklet/lineage IDs), so each measured call must start unannotated. graph = compute_graph_from_seg(seg_data, MAX_EDGE_DISTANCE, iou=True) - return (graph,), {"time_attr": "t", "pos_attr": "pos"} + return (graph,), { + "time_attr": "t", + "pos_attr": "pos", + "tracklet_attr": "tracklet_id", + "lineage_attr": "lineage_id", + } - benchmark.pedantic(SolutionTracks, setup=setup, rounds=1, iterations=1) + benchmark.pedantic(Tracks, setup=setup, rounds=1, iterations=1) diff --git a/tests/candidate_graph/test_compute_graph.py b/tests/candidate_graph/test_compute_graph.py index 6abfff97..18c0c996 100644 --- a/tests/candidate_graph/test_compute_graph.py +++ b/tests/candidate_graph/test_compute_graph.py @@ -18,13 +18,13 @@ def test_graph_from_segmentation_2d(get_tracks): ) # Same node IDs as the segmentation labels - assert set(cand_graph.node_ids()) == set(tracks.graph.node_ids()) + assert set(cand_graph.node_ids()) == set(tracks.graph_solution.node_ids()) # t, pos, area must match the source graph for every node for node in cand_graph.node_ids(): for key in ["t", "pos", "area"]: assert np.array(cand_graph.nodes[node][key]) == pytest.approx( - np.array(tracks.graph.nodes[node][key]), abs=0.01 + np.array(tracks.graph_solution.nodes[node][key]), abs=0.01 ) # mask and bbox must be present on every node @@ -39,12 +39,14 @@ def test_graph_from_segmentation_2d(get_tracks): # because t=3 has no nodes (add_cand_edges only links frame → frame+1) assert sorted(cand_graph.edge_list()) == [[1, 2], [1, 3], [2, 4], [3, 4]] - # For edges shared with tracks.graph, iou must agree + # For edges shared with tracks.graph_solution, iou must agree cand_edges = {tuple(e) for e in cand_graph.edge_list()} - ref_edges = {tuple(e) for e in tracks.graph.edge_list()} + ref_edges = {tuple(e) for e in tracks.graph_solution.edge_list()} for src, tgt in cand_edges & ref_edges: cand_iou = cand_graph.edges[cand_graph.edge_id(src, tgt)]["iou"] - ref_iou = tracks.graph.edges[tracks.graph.edge_id(src, tgt)]["iou"] + ref_iou = tracks.graph_solution.edges[tracks.graph_solution.edge_id(src, tgt)][ + "iou" + ] assert cand_iou == pytest.approx(ref_iou, abs=0.01) # lower edge distance: only (1, 3) is within 15 pixels (~11.2), (1, 2) is ~42 away @@ -52,7 +54,7 @@ def test_graph_from_segmentation_2d(get_tracks): segmentation=segmentation_2d, max_edge_distance=15, ) - assert set(cand_graph.node_ids()) == set(tracks.graph.node_ids()) + assert set(cand_graph.node_ids()) == set(tracks.graph_solution.node_ids()) assert sorted(cand_graph.edge_list()) == [[1, 3]] @@ -65,12 +67,12 @@ def test_graph_from_segmentation_3d(get_tracks): max_edge_distance=100, ) - assert set(cand_graph.node_ids()) == set(tracks.graph.node_ids()) + assert set(cand_graph.node_ids()) == set(tracks.graph_solution.node_ids()) for node in cand_graph.node_ids(): for key in ["t", "pos", "area"]: assert np.array(cand_graph.nodes[node][key]) == pytest.approx( - np.array(tracks.graph.nodes[node][key]), abs=0.01 + np.array(tracks.graph_solution.nodes[node][key]), abs=0.01 ) # mask and bbox must be present on every node @@ -115,8 +117,8 @@ def test_graph_from_segmentation_t_start(get_tracks): # t values should match the original tracks graph for shared nodes for node in cand_graph.node_ids(): - if node in set(tracks.graph.node_ids()): - assert cand_graph.nodes[node]["t"] == tracks.graph.nodes[node]["t"] + if node in set(tracks.graph_solution.node_ids()): + assert cand_graph.nodes[node]["t"] == tracks.graph_solution.nodes[node]["t"] # Edges should still be formed between adjacent (shifted) frames assert cand_graph.num_edges() > 0 diff --git a/tests/candidate_graph/test_iou.py b/tests/candidate_graph/test_iou.py index 0fc3e28a..21c7a464 100644 --- a/tests/candidate_graph/test_iou.py +++ b/tests/candidate_graph/test_iou.py @@ -43,10 +43,12 @@ def test_add_iou_2d(get_tracks): add_cand_edges(cand_graph, max_edge_distance=100, node_frame_dict=node_frame_dict) add_iou(cand_graph, segmentation_2d, node_frame_dict=node_frame_dict) - # For edges shared with tracks.graph, iou must agree + # For edges shared with tracks.graph_solution, iou must agree cand_edges = {tuple(e) for e in cand_graph.edge_list()} - ref_edges = {tuple(e) for e in tracks.graph.edge_list()} + ref_edges = {tuple(e) for e in tracks.graph_solution.edge_list()} for src, tgt in cand_edges & ref_edges: cand_iou = cand_graph.edges[cand_graph.edge_id(src, tgt)]["iou"] - ref_iou = tracks.graph.edges[tracks.graph.edge_id(src, tgt)]["iou"] + ref_iou = tracks.graph_solution.edges[tracks.graph_solution.edge_id(src, tgt)][ + "iou" + ] assert cand_iou == pytest.approx(ref_iou, abs=0.01) diff --git a/tests/candidate_graph/test_relabel_segmentation.py b/tests/candidate_graph/test_relabel_segmentation.py index 9ea20e96..0e32f28d 100644 --- a/tests/candidate_graph/test_relabel_segmentation.py +++ b/tests/candidate_graph/test_relabel_segmentation.py @@ -30,7 +30,7 @@ def test_relabel_segmentation(get_tracks): segmentation = np.asarray(tracks.segmentation) # Use only nodes 1 and 2 (single tracklet: node 1 at t=0, node 2 at t=1) - subgraph = tracks.graph.filter(node_ids=[1, 2]).subgraph() + subgraph = tracks.graph_solution.filter(node_ids=[1, 2]).subgraph() relabeled = relabel_segmentation_with_track_id(subgraph, segmentation) # Nodes 1 and 2 form one tracklet → both get label 1 diff --git a/tests/conftest.py b/tests/conftest.py index 29477e6e..1224486a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,13 +9,13 @@ from tracksdata.nodes import Mask from funtracks.utils.tracksdata_utils import ( - create_empty_graphview_graph, + create_empty_graph, ) if TYPE_CHECKING: from typing import Any - from funtracks.data_model import SolutionTracks, Tracks + from funtracks.data_model import Tracks def make_2d_disk_mask(center=(50, 50), radius=20) -> Mask: @@ -134,7 +134,7 @@ def _make_graph( with_iou: bool = False, with_masks: bool = False, database: str | None = None, -) -> td.graph.GraphView: +) -> td.graph.BaseGraph: """Generate a test graph with configurable features. Args: @@ -172,7 +172,7 @@ def _make_graph( node_attributes.append(td.DEFAULT_ATTR_KEYS.BBOX) node_default_values.append(0.0) - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=node_attributes, node_default_values=node_default_values, edge_attributes=edge_attributes, @@ -288,28 +288,28 @@ def _make_graph( @pytest.fixture -def graph_clean(tmp_path) -> td.graph.GraphView: +def graph_clean(tmp_path) -> td.graph.BaseGraph: """Base graph with only time - no positions or computed features.""" db_path = str(tmp_path / "graph_clean.db") return _make_graph(ndim=3, database=db_path) @pytest.fixture -def graph_2d_with_position(tmp_path) -> td.graph.GraphView: +def graph_2d_with_position(tmp_path) -> td.graph.BaseGraph: """Graph with 2D positions - for Tracks without segmentation.""" db_path = str(tmp_path / "graph_2d_position.db") return _make_graph(ndim=3, with_pos=True, database=db_path) @pytest.fixture -def graph_2d_with_track_id(tmp_path) -> td.graph.GraphView: - """Graph with 2D positions and track_id - for SolutionTracks without segmentation.""" +def graph_2d_with_track_id(tmp_path) -> td.graph.BaseGraph: + """Graph with 2D positions and track_id - for Tracks without segmentation.""" db_path = str(tmp_path / "graph_2d_track_id.db") return _make_graph(ndim=3, with_pos=True, with_track_id=True, database=db_path) @pytest.fixture -def graph_2d_with_segmentation(tmp_path) -> td.graph.GraphView: +def graph_2d_with_segmentation(tmp_path) -> td.graph.BaseGraph: """Graph with segmentation (masks/bboxes) and all computed features.""" db_path = str(tmp_path / "graph_2d_segmentation.db") return _make_graph( @@ -324,21 +324,21 @@ def graph_2d_with_segmentation(tmp_path) -> td.graph.GraphView: @pytest.fixture -def graph_3d_with_position(tmp_path) -> td.graph.GraphView: +def graph_3d_with_position(tmp_path) -> td.graph.BaseGraph: """Graph with 3D positions - for Tracks without segmentation.""" db_path = str(tmp_path / "graph_3d_position.db") return _make_graph(ndim=4, with_pos=True, database=db_path) @pytest.fixture -def graph_3d_with_track_id(tmp_path) -> td.graph.GraphView: - """Graph with 3D positions and track_id - for SolutionTracks without segmentation.""" +def graph_3d_with_track_id(tmp_path) -> td.graph.BaseGraph: + """Graph with 3D positions and track_id - for Tracks without segmentation.""" db_path = str(tmp_path / "graph_3d_track_id.db") return _make_graph(ndim=4, with_pos=True, with_track_id=True, database=db_path) @pytest.fixture -def graph_3d_with_segmentation(tmp_path) -> td.graph.GraphView: +def graph_3d_with_segmentation(tmp_path) -> td.graph.BaseGraph: """Graph with segmentation (masks/bboxes) and all computed features.""" db_path = str(tmp_path / "graph_3d_segmentation.db") return _make_graph( @@ -353,22 +353,24 @@ def graph_3d_with_segmentation(tmp_path) -> td.graph.GraphView: @pytest.fixture -def get_tracks(get_graph) -> Callable[..., "Tracks | SolutionTracks"]: - """Factory fixture to create Tracks or SolutionTracks instances. +def get_tracks(get_graph) -> Callable[..., "Tracks"]: + """Factory fixture to create Tracks or Tracks instances. Returns a factory function that can be called with: ndim: 3 for 2D spatial + time, 4 for 3D spatial + time with_seg: Whether to include segmentation (mask/bbox as node attributes) - is_solution: Whether to return SolutionTracks instead of Tracks + prefill_track_ids: If True, the fixture graph ships with track_id/lineage_id + values already set (activated as-is); if False, they are computed from the + graph topology. Either way the resulting Tracks has track ids. Example: - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) Note: Uses a pre-built FeatureDict to avoid recomputing features that already exist in the test graph fixtures. """ - from funtracks.data_model import SolutionTracks, Tracks + from funtracks.data_model import Tracks from funtracks.features import ( Area, FeatureDict, @@ -385,12 +387,14 @@ def get_tracks(get_graph) -> Callable[..., "Tracks | SolutionTracks"]: def _make_tracks( ndim: int, with_seg: bool = True, - is_solution: bool = False, - ) -> Tracks | SolutionTracks: + prefill_track_ids: bool = False, + ) -> Tracks: # Determine axis names based on ndim axis_names = ["z", "y", "x"] if ndim == 4 else ["y", "x"] - graph = get_graph(ndim=ndim, is_solution=is_solution, with_seg=with_seg) + graph = get_graph( + ndim=ndim, prefill_track_ids=prefill_track_ids, with_seg=with_seg + ) # Build FeatureDict based on what exists in the graph features_dict: dict[str, Any] = { @@ -404,7 +408,7 @@ def _make_tracks( features_dict["bbox"] = SegBbox(ndim) features_dict["area"] = Area(ndim=ndim) features_dict["iou"] = IoU() - if is_solution: + if prefill_track_ids: features_dict["track_id"] = TrackletID() features_dict["lineage_id"] = LineageID() @@ -412,31 +416,24 @@ def _make_tracks( features=features_dict, time_key="t", position_key="pos", - tracklet_key="track_id" if is_solution else None, - lineage_key="lineage_id" if is_solution else None, + tracklet_key="track_id" if prefill_track_ids else None, + lineage_key="lineage_id" if prefill_track_ids else None, ) - # Create the appropriate Tracks type with pre-built FeatureDict - if is_solution: - return SolutionTracks( - graph, - ndim=ndim, - features=feature_dict, - ) - else: - return Tracks( - graph, - ndim=ndim, - features=feature_dict, - ) + # Create the Tracks with the pre-built FeatureDict. + return Tracks( + graph, + ndim=ndim, + features=feature_dict, + ) return _make_tracks @pytest.fixture -def graph_2d_list(tmp_path) -> td.graph.GraphView: +def graph_2d_list(tmp_path) -> td.graph.BaseGraph: db_path = str(tmp_path / "graph_2d_list.db") - graph = create_empty_graphview_graph(database=db_path) + graph = create_empty_graph(database=db_path) nodes = [ { @@ -475,13 +472,14 @@ def sphere(center, radius, shape): @pytest.fixture -def get_graph(tmp_path) -> Callable[..., td.graph.GraphView]: +def get_graph(tmp_path) -> Callable[..., td.graph.BaseGraph]: """Factory fixture to create a graph with configurable features. Args: ndim: 3 for 2D spatial + time, 4 for 3D spatial + time with_pos: Include position attribute (default True) - is_solution: Include track_id and lineage_id (default False) + prefill_track_ids: Include track_id and lineage_id columns on the graph + (default False) with_seg: Include mask, bbox, area, and iou (default False) Returns: @@ -489,22 +487,22 @@ def get_graph(tmp_path) -> Callable[..., td.graph.GraphView]: Example: graph = get_graph(ndim=3, with_seg=True) - graph = get_graph(ndim=4, is_solution=True, with_seg=True) + graph = get_graph(ndim=4, prefill_track_ids=True, with_seg=True) """ counter = [0] def _get_graph( ndim: int = 3, with_pos: bool = True, - is_solution: bool = False, + prefill_track_ids: bool = False, with_seg: bool = False, - ) -> td.graph.GraphView: + ) -> td.graph.BaseGraph: counter[0] += 1 db_path = str(tmp_path / f"graph_{counter[0]}.db") return _make_graph( ndim=ndim, with_pos=with_pos, - with_track_id=is_solution, + with_track_id=prefill_track_ids, with_area=with_seg, with_iou=with_seg, with_masks=with_seg, diff --git a/tests/data_model/test_soft_delete_roundtrip.py b/tests/data_model/test_soft_delete_roundtrip.py new file mode 100644 index 00000000..852130d9 --- /dev/null +++ b/tests/data_model/test_soft_delete_roundtrip.py @@ -0,0 +1,163 @@ +"""Step 6: soft-delete round-trip + repeated delete<->undo stability. + +Soft-delete keeps the node/edge in ``graph_full`` (flag ``solution=False``) and only drops +it from ``graph_solution``. These tests pin the core Phase-1 guarantees: +- the full graph's topology is preserved across a leaf-node soft-delete (only flags flip); +- delete -> undo restores the solution view exactly (ids + attributes); +- repeated undo/redo through the real ``ActionHistory`` never drifts the view (invariant + #4 in the persistence plan) — the R2 in-place revive must be a true inverse of remove. + +A leaf node (5, edge 4->5) is used so the delete introduces no reconnection skip-edge; the +skip-edge case is exercised in ``test_mid_track_delete_*`` to document that graph_full +accumulates the skip edge as a candidate. +""" + +import pytest + +from funtracks.user_actions import UserDeleteNode + + +def _solution_state(tracks): + """A hashable snapshot of the solution view's topology.""" + return ( + tuple(sorted(tracks.graph_solution.node_ids())), + tuple(sorted(tracks.graph_solution.edge_list())), + ) + + +def _full_state(tracks): + return ( + tuple(sorted(tracks.graph_full.node_ids())), + tuple(sorted(tracks.graph_full.edge_list())), + ) + + +def _positions(tracks): + """pos per solution node as plain float tuples (array-safe for == comparison).""" + return { + n: tuple(float(x) for x in tracks.get_node_attr(n, "pos")) + for n in tracks.graph_solution.node_ids() + } + + +@pytest.mark.parametrize("ndim", [3, 4]) +@pytest.mark.parametrize("with_seg", [True, False]) +def test_soft_delete_keeps_leaf_node_in_full_graph(get_tracks, ndim, with_seg): + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) + full_before = _full_state(tracks) + sol_nodes_before = set(tracks.graph_solution.node_ids()) + + UserDeleteNode(tracks, 5) + + # Gone from the solution view ... + assert 5 not in tracks.graph_solution.node_ids() + assert set(tracks.graph_solution.node_ids()) == sol_nodes_before - {5} + assert not tracks.graph_solution.has_edge(4, 5) + + # ... but preserved in the full graph as a soft-deleted candidate. Topology of the + # full graph is unchanged: only the solution flag flipped. + assert 5 in tracks.graph_full.node_ids() + assert tracks.graph_full.has_edge(4, 5) + assert tracks.graph_full.nodes[5]["solution"] is False + assert _full_state(tracks) == full_before + + +@pytest.mark.parametrize("ndim", [3, 4]) +@pytest.mark.parametrize("with_seg", [True, False]) +def test_delete_undo_redo_roundtrip_identity(get_tracks, ndim, with_seg): + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) + sol_ref = _solution_state(tracks) + full_ref = _full_state(tracks) + pos_ref = _positions(tracks) + + UserDeleteNode(tracks, 5) + assert _solution_state(tracks) != sol_ref # actually changed + + # Undo restores the solution view exactly, and the full graph is untouched. + assert tracks.undo() + assert _solution_state(tracks) == sol_ref + assert _full_state(tracks) == full_ref + assert _positions(tracks) == pos_ref + + # Redo re-deletes; undo restores again — same states. + assert tracks.redo() + assert 5 not in tracks.graph_solution.node_ids() + assert tracks.undo() + assert _solution_state(tracks) == sol_ref + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_repeated_delete_undo_is_stable(get_tracks, ndim): + """Invariant #4: N undo/redo cycles must not drift the solution view or full graph.""" + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) + sol_ref = _solution_state(tracks) + iou_ref = tracks.get_edge_attr((4, 5), "iou") + + UserDeleteNode(tracks, 5) + deleted_sol = _solution_state(tracks) + full_after_delete = _full_state(tracks) + + for _ in range(5): + assert tracks.undo() + assert _solution_state(tracks) == sol_ref + # Revived edge keeps its computed attribute. + assert tracks.get_edge_attr((4, 5), "iou") == iou_ref + + assert tracks.redo() + assert _solution_state(tracks) == deleted_sol + # Full graph topology is identical every redo — no candidate accumulation. + assert _full_state(tracks) == full_after_delete + + # Leave it restored and confirm a clean final identity. + assert tracks.undo() + assert _solution_state(tracks) == sol_ref + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_mid_track_delete_leaves_skip_edge_candidate_in_full(get_tracks, ndim): + """Deleting a mid-track node adds a reconnection skip-edge (3->5). On undo it is + soft-deleted, so it persists in graph_full as a solution=False candidate while the + solution view returns to its original topology.""" + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) + sol_ref = _solution_state(tracks) + assert not tracks.graph_full.has_edge(3, 5) + + UserDeleteNode(tracks, 4) + assert tracks.graph_solution.has_edge(3, 5) + assert not tracks.graph_solution.has_edge(3, 4) + + assert tracks.undo() + # Solution view is back to the original topology (skip edge removed from the view) ... + assert _solution_state(tracks) == sol_ref + assert not tracks.graph_solution.has_edge(3, 5) + # ... but the skip edge now lives in graph_full as a candidate; node 4 is retained. + assert tracks.graph_full.has_edge(3, 5) + assert tracks.graph_full.edge_id(3, 5) is not None + assert 4 in tracks.graph_full.node_ids() + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_attr_reads_resolve_for_soft_deleted_node(get_tracks, ndim): + """Regression: attribute reads must resolve for soft-deleted (solution=False) nodes, + and the bulk `get_positions` must agree with the single-node `get_position`. The bulk + path previously queried `graph_solution` and KeyError'd on a soft-deleted node while + `get_position` (graph_full) succeeded — a latent inconsistency invisible to tests that + only ever query in-solution nodes. + """ + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) + pos_single_before = tracks.get_position(5) + pos_bulk_before = tracks.get_positions([5])[0].tolist() + assert pos_bulk_before == pytest.approx(pos_single_before) + + UserDeleteNode(tracks, 5) + assert 5 not in tracks.graph_solution.node_ids() # soft-deleted + + # Both single and bulk position reads still resolve (graph_full) and agree. + pos_single_after = tracks.get_position(5) + pos_bulk_after = tracks.get_positions([5])[0].tolist() + assert pos_single_after == pytest.approx(pos_single_before) + assert pos_bulk_after == pytest.approx(pos_single_before) + + # Other intrinsic attrs resolve too. + assert tracks.get_node_attr(5, "area") is not None + assert tracks.get_time(5) is not None diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 063fba80..4f1afa81 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -2,11 +2,11 @@ import polars as pl from funtracks.actions import AddNode -from funtracks.data_model import SolutionTracks, Tracks +from funtracks.data_model import Tracks from funtracks.import_export import export_to_csv from funtracks.user_actions import UserUpdateSegmentation from funtracks.utils.tracksdata_utils import ( - create_empty_graphview_graph, + create_empty_graph, td_mask_to_pixels, ) @@ -14,7 +14,7 @@ def test_recompute_track_ids(graph_2d_with_track_id): - tracks = SolutionTracks( + tracks = Tracks( graph_2d_with_track_id, ndim=3, **track_attrs, @@ -22,8 +22,20 @@ def test_recompute_track_ids(graph_2d_with_track_id): assert tracks.get_next_track_id() == 6 +def test_recompute_track_ids_when_sentinel_present(graph_2d_with_track_id): + """A tracklet column that exists but still holds the -1 sentinel means "not + computed": Tracks must recompute track ids from topology, not trust the sentinels.""" + graph = graph_2d_with_track_id + node_ids = list(graph.node_ids()) + graph.update_node_attrs(attrs={"track_id": [-1] * len(node_ids)}, node_ids=node_ids) + + tracks = Tracks(graph, ndim=3, **track_attrs) + + assert -1 not in tracks.get_track_ids(tracks.nodes()) + + def test_next_track_id(graph_2d_with_track_id): - tracks = SolutionTracks(graph_2d_with_track_id, ndim=3, **track_attrs) + tracks = Tracks(graph_2d_with_track_id, ndim=3, **track_attrs) assert tracks.get_next_track_id() == 6 AddNode( tracks, @@ -33,47 +45,8 @@ def test_next_track_id(graph_2d_with_track_id): assert tracks.get_next_track_id() == 11 -def test_from_tracks_cls(graph_2d_with_segmentation): - tracks = Tracks( - graph_2d_with_segmentation, - ndim=3, - pos_attr="POSITION", - time_attr="TIME", - tracklet_attr=track_attrs["tracklet_attr"], - scale=(2, 2, 2), - ) - solution_tracks = SolutionTracks.from_tracks(tracks) - assert solution_tracks.graph == tracks.graph - assert solution_tracks.segmentation == tracks.segmentation - assert solution_tracks.features.time_key == tracks.features.time_key - assert solution_tracks.features.position_key == tracks.features.position_key - assert solution_tracks.scale == tracks.scale - assert solution_tracks.ndim == tracks.ndim - assert solution_tracks.get_node_attr(6, tracks.features.tracklet_key) == 5 - - -def test_from_tracks_cls_recompute(graph_2d_with_segmentation): - tracks = Tracks( - graph_2d_with_segmentation, - ndim=3, - pos_attr="POSITION", - time_attr="TIME", - tracklet_attr=track_attrs["tracklet_attr"], - scale=(2, 2, 2), - ) - # delete track id (default value -1) on one node triggers reassignment of - # track_ids even when recompute is False. - tracks.graph.nodes[1][tracks.features.tracklet_key] = -1 - solution_tracks = SolutionTracks.from_tracks(tracks) - # should have reassigned new track_id to node 6 - assert solution_tracks.get_node_attr(6, solution_tracks.features.tracklet_key) == 4 - assert ( - solution_tracks.get_node_attr(1, solution_tracks.features.tracklet_key) == 1 - ) # still 1 - - def test_update_segmentation(graph_2d_with_segmentation): - tracks = SolutionTracks( + tracks = Tracks( graph_2d_with_segmentation, ndim=3, **track_attrs, @@ -90,42 +63,27 @@ def test_update_segmentation(graph_2d_with_segmentation): def test_next_track_id_empty(): - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=["pos", "track_id"], edge_attributes=[], ) - tracks = SolutionTracks(graph, ndim=4, **track_attrs) + tracks = Tracks(graph, ndim=4, **track_attrs) assert tracks.get_next_track_id() == 1 -def test_get_lineage_id_without_lineage_key(graph_2d_with_track_id): - """Test that get_lineage_id returns None when lineage_key is not set.""" - graph = graph_2d_with_track_id - graph.add_node( - attrs={"t": 1, "pos": [0, 0], "track_id": 1}, index=7, validate_keys=False - ) - tracks = SolutionTracks(graph, ndim=3, **track_attrs) - - # Unset lineage_key to test the None path - tracks.features.lineage_key = None - - # get_lineage_id should return None when lineage_key is not set - assert tracks.get_lineage_id(1) is None - - def test_export_to_csv_with_display_names( graph_2d_with_segmentation, graph_3d_with_segmentation, tmp_path ): """Test CSV export with use_display_names=True option.""" # Test 2D with display names - tracks = SolutionTracks(graph_2d_with_segmentation, **track_attrs, ndim=3) + tracks = Tracks(graph_2d_with_segmentation, **track_attrs, ndim=3) tracks.enable_features(["area"]) temp_file = tmp_path / "test_export_2d_display.csv" export_to_csv(tracks, temp_file, use_display_names=True) with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.num_nodes() + 1 # add header + assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header # With display names: ID, Parent ID, Time, y, x, Tracklet ID, # Lineage ID, Area @@ -142,14 +100,14 @@ def test_export_to_csv_with_display_names( assert lines[0].strip().split(",") == header # Test 3D with display names (area display name is "Volume" in 3D) - tracks = SolutionTracks(graph_3d_with_segmentation, **track_attrs, ndim=4) + tracks = Tracks(graph_3d_with_segmentation, **track_attrs, ndim=4) tracks.enable_features(["area"]) temp_file = tmp_path / "test_export_3d_display.csv" export_to_csv(tracks, temp_file, use_display_names=True) with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.num_nodes() + 1 # add header + assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header # With display names: ID, Parent ID, Time, z, y, x, # Tracklet ID, Lineage ID, Volume @@ -171,7 +129,7 @@ def test_multi_axis_pos_attr_with_segmentation(graph_3d_with_segmentation): """pos_attr as list should be respected even when segmentation is present. Scenario: graph has both a "pos" column AND individual z/y/x columns with - distinct values. SolutionTracks(pos_attr=['z','y','x']) should use z/y/x + distinct values. Tracks(pos_attr=['z','y','x']) should use z/y/x as the position_key, not fall back to "pos". """ graph = graph_3d_with_segmentation @@ -187,7 +145,7 @@ def test_multi_axis_pos_attr_with_segmentation(graph_3d_with_segmentation): graph.nodes[node]["y"] = float(pos[1]) + offset graph.nodes[node]["x"] = float(pos[2]) + offset - tracks = SolutionTracks( + tracks = Tracks( graph=graph, pos_attr=["z", "y", "x"], ndim=4, @@ -204,5 +162,5 @@ def test_multi_axis_pos_attr_with_segmentation(graph_3d_with_segmentation): expected = [float(original_pos[i]) + offset for i in range(3)] assert list(pos_from_tracks) == expected, ( f"Expected positions from z/y/x ({expected}), " - f"got {list(pos_from_tracks)} — SolutionTracks used 'pos' instead" + f"got {list(pos_from_tracks)} — Tracks used 'pos' instead" ) diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 74dd717b..06f35b81 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -8,16 +8,16 @@ from funtracks.features import SegBbox, SegMask from funtracks.user_actions import UserUpdateSegmentation from funtracks.utils.tracksdata_utils import ( - create_empty_graphview_graph, + create_empty_graph, to_polars_dtype, ) track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} -def test_create_tracks(graph_3d_with_segmentation: td.graph.GraphView): +def test_create_tracks(graph_3d_with_segmentation: td.graph.BaseGraph): # create empty tracks - empty_graph = create_empty_graphview_graph() + empty_graph = create_empty_graph() tracks = Tracks(graph=empty_graph, ndim=3, **track_attrs) # type: ignore[arg-type] assert tracks.features.position_key == "pos" assert isinstance(tracks.features["pos"], dict) @@ -96,7 +96,7 @@ def test_nodes_edges(graph_2d_with_segmentation): assert set(tracks.nodes()) == {1, 2, 3, 4, 5, 6} assert len(tracks.edges()) == 4 # rx graph starts from 0, sql from 1, # so direct comparison of edges depends on backend - assert set(map(tuple, tracks.graph.edge_list())) == { + assert set(map(tuple, tracks.graph_solution.edge_list())) == { (1, 2), (1, 3), (3, 4), @@ -104,18 +104,6 @@ def test_nodes_edges(graph_2d_with_segmentation): } -def test_degrees(graph_2d_with_segmentation): - tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs) - assert tracks.in_degree(np.array([1])) == 0 - assert tracks.in_degree(np.array([4])) == 1 - assert np.array_equal(tracks.in_degree(None), np.array([0, 1, 1, 1, 1, 0])) - assert np.array_equal(tracks.out_degree(np.array([1, 4])), np.array([2, 1])) - assert np.array_equal( - tracks.out_degree(None), - np.array([2, 0, 1, 1, 0, 0]), - ) - - def test_predecessors_successors(graph_2d_with_segmentation): tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs) assert tracks.predecessors(2) == [1] @@ -202,7 +190,7 @@ def test_set_pixels_no_segmentation(graph_2d_with_track_id): def test_compute_ndim_errors(): - g = create_empty_graphview_graph() + g = create_empty_graph() g.add_node_attr_key("pos", default_value=[0, 0], dtype=pl.List(pl.Int64)) g.add_node(index=1, attrs={"t": 0, "pos": [0, 0, 0], "solution": True}) @@ -225,7 +213,7 @@ def test_get_new_node_ids(graph_2d_with_position): assert 1 not in ids # existing nodes skipped assert 2 not in ids for node_id in ids: - assert not tracks.graph.has_node(node_id) + assert not tracks.graph_solution.has_node(node_id) # second call must not overlap with first ids2 = tracks._get_new_node_ids(2) @@ -241,7 +229,9 @@ def test_undo_redo(graph_2d_with_segmentation): assert tracks.redo() is False # Perform an action - add a custom attribute - tracks.graph.add_node_attr_key("custom_label", default_value=None, dtype=pl.Object) + tracks.graph_solution.add_node_attr_key( + "custom_label", default_value=None, dtype=pl.Object + ) action1 = UpdateNodeAttrs(tracks, node=1, attrs={"custom_label": "test_value"}) tracks.action_history.add_new_action(action1) @@ -262,7 +252,9 @@ def test_undo_redo(graph_2d_with_segmentation): assert tracks.redo() is False # Perform another action - tracks.graph.add_node_attr_key("another_label", default_value=None, dtype=pl.Object) + tracks.graph_solution.add_node_attr_key( + "another_label", default_value=None, dtype=pl.Object + ) action2 = UpdateNodeAttrs(tracks, node=2, attrs={"another_label": "second_value"}) tracks.action_history.add_new_action(action2) assert tracks.get_node_attr(2, "another_label") == "second_value" @@ -310,17 +302,17 @@ def test_to_polars_dtype_mask(): def test_add_feature_mask_creates_both_columns(): """add_feature with mask and bbox Features creates both columns.""" - graph = create_empty_graphview_graph(ndim=3) + graph = create_empty_graph(ndim=3) tracks = Tracks(graph, ndim=3, **track_attrs) - assert "nuc_mask" not in tracks.graph.node_attr_keys() - assert "nuc_bbox" not in tracks.graph.node_attr_keys() + assert "nuc_mask" not in tracks.graph_solution.node_attr_keys() + assert "nuc_bbox" not in tracks.graph_solution.node_attr_keys() tracks.add_feature("nuc_mask", SegMask(ndim=3, bbox_key="nuc_bbox")) tracks.add_feature("nuc_bbox", SegBbox(ndim=3)) - assert "nuc_mask" in tracks.graph.node_attr_keys() - assert "nuc_bbox" in tracks.graph.node_attr_keys() + assert "nuc_mask" in tracks.graph_solution.node_attr_keys() + assert "nuc_bbox" in tracks.graph_solution.node_attr_keys() assert "nuc_mask" in tracks.features assert "nuc_bbox" in tracks.features @@ -333,15 +325,15 @@ def test_delete_feature_mask_removes_both_columns( assert "mask" in tracks.features assert "bbox" in tracks.features - assert "mask" in tracks.graph.node_attr_keys() - assert "bbox" in tracks.graph.node_attr_keys() + assert "mask" in tracks.graph_solution.node_attr_keys() + assert "bbox" in tracks.graph_solution.node_attr_keys() tracks.delete_feature("mask") assert "mask" not in tracks.features assert "bbox" not in tracks.features - assert "mask" not in tracks.graph.node_attr_keys() - assert "bbox" not in tracks.graph.node_attr_keys() + assert "mask" not in tracks.graph_solution.node_attr_keys() + assert "bbox" not in tracks.graph_solution.node_attr_keys() def test_update_mask_syncs_bbox(graph_2d_with_segmentation): @@ -353,8 +345,8 @@ def test_update_mask_syncs_bbox(graph_2d_with_segmentation): new_mask = make_2d_disk_mask(center=(30, 30), radius=10) tracks.update_mask(1, new_mask) - stored_mask = tracks.graph.nodes[1][td.DEFAULT_ATTR_KEYS.MASK] - stored_bbox = tracks.graph.nodes[1][td.DEFAULT_ATTR_KEYS.BBOX] + stored_mask = tracks.graph_solution.nodes[1][td.DEFAULT_ATTR_KEYS.MASK] + stored_bbox = tracks.graph_solution.nodes[1][td.DEFAULT_ATTR_KEYS.BBOX] assert stored_mask is new_mask assert np.array_equal(stored_bbox, new_mask.bbox) diff --git a/tests/import_export/test_csv_export.py b/tests/import_export/test_csv_export.py index d12beec3..ae13a721 100644 --- a/tests/import_export/test_csv_export.py +++ b/tests/import_export/test_csv_export.py @@ -16,14 +16,14 @@ ) def test_export_solution_to_csv(get_tracks, tmp_path, ndim, expected_header): """Test exporting tracks to CSV.""" - tracks = get_tracks(ndim=ndim, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=False, prefill_track_ids=True) temp_file = tmp_path / "test_export.csv" export_to_csv(tracks, temp_file) with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.num_nodes() + 1 # add header + assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header assert lines[0].strip().split(",") == expected_header # Check first data line (node 1: t=0, pos=[50, 50] or [50, 50, 50], track_id=1) @@ -46,7 +46,7 @@ def test_export_solution_to_csv_with_seg_zarr( get_tracks, tmp_path, ndim, expected_header ): """Test exporting tracks to CSV + segmentation painted by track_id as zarr.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) temp_file = tmp_path / "test_export.csv" seg_dir = tmp_path / "test_export_seg" export_to_csv(tracks, temp_file, export_seg=True, seg_path=seg_dir) @@ -54,7 +54,7 @@ def test_export_solution_to_csv_with_seg_zarr( with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.num_nodes() + 1 # add header + assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header assert lines[0].strip().split(",") == expected_header # check the segmentation zarr @@ -66,14 +66,16 @@ def test_export_solution_to_csv_with_seg_zarr( seg_arr = seg_zarr[:] unique_vals = set(seg_arr.flatten()) - {0} label_key = tracks.features.tracklet_key - track_ids = set(tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list()) + track_ids = set( + tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list() + ) assert unique_vals == track_ids @pytest.mark.parametrize("ndim", [3, 4], ids=["2d", "3d"]) def test_export_solution_to_csv_with_seg_tiff(get_tracks, tmp_path, ndim): """Test exporting tracks to CSV + segmentation as tiff painted by tracklet ID.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) temp_file = tmp_path / "test_export.csv" seg_file = tmp_path / "test_export_seg.tif" export_to_csv( @@ -91,14 +93,16 @@ def test_export_solution_to_csv_with_seg_tiff(get_tracks, tmp_path, ndim): # values should be tracklet_ids (not node_ids) — default seg_relabel="tracklet" unique_vals = set(seg_arr.flatten()) - {0} label_key = tracks.features.tracklet_key - track_ids = set(tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list()) + track_ids = set( + tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list() + ) assert unique_vals == track_ids @pytest.mark.parametrize("ndim", [3, 4], ids=["2d", "3d"]) def test_export_solution_to_csv_with_seg_original_labels(get_tracks, tmp_path, ndim): """Test exporting tracks to CSV + segmentation with original (node_id) labels.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) temp_file = tmp_path / "test_export.csv" seg_dir = tmp_path / "test_export_seg" export_to_csv( @@ -116,7 +120,7 @@ def test_export_solution_to_csv_with_seg_original_labels(get_tracks, tmp_path, n # values should be node_ids (original labels), not track_ids seg_arr = seg_zarr[:] unique_vals = set(seg_arr.flatten()) - {0} - node_ids = set(tracks.graph.node_ids()) + node_ids = set(tracks.graph_solution.node_ids()) assert unique_vals == node_ids @@ -124,11 +128,11 @@ def test_export_with_color_dict(get_tracks, tmp_path): """Test exporting with a color_dict adds a Tracklet ID Color column.""" import numpy as np - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) temp_file = tmp_path / "test_export_colors.csv" # Build a color dict: node_id → [R, G, B] floats in [0, 1] - node_ids = list(tracks.graph.node_ids()) + node_ids = list(tracks.graph_solution.node_ids()) color_dict = { node_id: np.array([0.1 * (i % 10), 0.5, 0.9]) for i, node_id in enumerate(node_ids) @@ -146,7 +150,7 @@ def test_export_with_color_dict(get_tracks, tmp_path): def test_export_with_display_names(get_tracks, tmp_path): """Test exporting with display names.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) temp_file = tmp_path / "test_export_display.csv" export_to_csv(tracks, temp_file, use_display_names=True) @@ -161,7 +165,7 @@ def test_export_with_display_names(get_tracks, tmp_path): def test_export_filtered_nodes(get_tracks, tmp_path): """Test exporting only specific nodes.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) temp_file = tmp_path / "test_export_filtered.csv" # Export only nodes 1 and 2 (and their ancestors) @@ -177,7 +181,7 @@ def test_export_filtered_nodes(get_tracks, tmp_path): def test_ignore_edge_features_at_export(get_tracks, tmp_path): """Test that edge features are ignored when exporting to csv""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) temp_file = tmp_path / "test_export_node_features_only.csv" # enable node and edge features @@ -209,7 +213,7 @@ def test_export_solution_to_csv_with_seg_and_node_subset( and segmentation must only include the resulting graph nodes, and nothing else. """ - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) csv_file = tmp_path / "export.csv" seg_path = tmp_path / "seg" @@ -266,8 +270,10 @@ def test_export_solution_to_csv_with_seg_and_node_subset( else: label_key = tracks.features.tracklet_key - labels = tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list() - node_to_label = dict(zip(tracks.graph.node_ids(), labels, strict=True)) + labels = tracks.graph_solution.node_attrs(attr_keys=[label_key])[ + label_key + ].to_list() + node_to_label = dict(zip(tracks.graph_solution.node_ids(), labels, strict=True)) expected = np.zeros_like(original) @@ -278,3 +284,41 @@ def test_export_solution_to_csv_with_seg_and_node_subset( expected_vals = {node_to_label[n] for n in graph_nodes} assert unique_vals == expected_vals + + +def test_export_full_vs_solution(get_tracks, tmp_path): + """export_full=True includes soft-deleted (solution=False) nodes and surfaces the + 'Solution' column; the default exports only the solution view.""" + import pandas as pd + + from funtracks.user_actions import UserDeleteNode + + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) + # Soft-delete a leaf node: dropped from the solution view, kept in graph_full. + UserDeleteNode(tracks, 5) + sol_n = tracks.graph_solution.num_nodes() + full_n = tracks.graph_full.num_nodes() + assert full_n == sol_n + 1 + + # Default: solution view only — node 5 absent. + sol_file = tmp_path / "sol.csv" + export_to_csv(tracks, sol_file) + sol = pd.read_csv(sol_file) + assert len(sol) == sol_n + assert 5 not in set(sol["id"]) + + # export_full=True: full graph — node 5 present. + full_file = tmp_path / "full.csv" + export_to_csv(tracks, full_file, export_full=True) + full = pd.read_csv(full_file) + assert len(full) == full_n + assert 5 in set(full["id"]) + + # With display names, the full export carries a 'Solution' column distinguishing + # the soft-deleted node (False) from the rest (True). + disp_file = tmp_path / "full_disp.csv" + export_to_csv(tracks, disp_file, export_full=True, use_display_names=True) + disp = pd.read_csv(disp_file) + assert "Solution" in disp.columns + assert not bool(disp.loc[disp["ID"] == 5, "Solution"].iloc[0]) + assert bool(disp.loc[disp["ID"] == 1, "Solution"].iloc[0]) diff --git a/tests/import_export/test_csv_import.py b/tests/import_export/test_csv_import.py index 613dc6b5..18fcd66f 100644 --- a/tests/import_export/test_csv_import.py +++ b/tests/import_export/test_csv_import.py @@ -2,7 +2,7 @@ import pandas as pd import pytest -from funtracks.data_model import SolutionTracks +from funtracks.data_model import Tracks from funtracks.import_export import tracks_from_df @@ -42,9 +42,9 @@ def test_import_2d(self, simple_df_2d): """Test importing 2D DataFrame.""" tracks = tracks_from_df(simple_df_2d) - assert isinstance(tracks, SolutionTracks) - assert tracks.graph.num_nodes() == 4 - assert tracks.graph.num_edges() == 3 + assert isinstance(tracks, Tracks) + assert tracks.graph_solution.num_nodes() == 4 + assert tracks.graph_solution.num_edges() == 3 assert tracks.ndim == 3 def test_import_3d(self, df_3d): @@ -52,7 +52,7 @@ def test_import_3d(self, df_3d): tracks = tracks_from_df(df_3d) assert tracks.ndim == 4 - assert tracks.graph.num_nodes() == 3 + assert tracks.graph_solution.num_nodes() == 3 # Check z coordinate pos = tracks.get_position(1) assert len(pos) == 3 # z, y, x @@ -80,12 +80,12 @@ def test_edges_created(self, simple_df_2d): tracks = tracks_from_df(simple_df_2d) # Check specific edges exist - assert tracks.graph.has_edge(1, 2) - assert tracks.graph.has_edge(1, 3) - assert tracks.graph.has_edge(2, 4) + assert tracks.graph_solution.has_edge(1, 2) + assert tracks.graph_solution.has_edge(1, 3) + assert tracks.graph_solution.has_edge(2, 4) # Check node 1 has two children (division) - assert len(list(tracks.graph.successors(1))) == 2 + assert len(list(tracks.graph_solution.successors(1))) == 2 class TestSegmentationHandling: @@ -162,8 +162,8 @@ def test_single_node(self): tracks = tracks_from_df(df) - assert tracks.graph.num_nodes() == 1 - assert tracks.graph.num_edges() == 0 + assert tracks.graph_solution.num_nodes() == 1 + assert tracks.graph_solution.num_edges() == 0 def test_multiple_roots(self): """Test multiple independent lineages.""" @@ -179,11 +179,15 @@ def test_multiple_roots(self): tracks = tracks_from_df(df) - assert tracks.graph.num_nodes() == 4 - assert tracks.graph.num_edges() == 2 + assert tracks.graph_solution.num_nodes() == 4 + assert tracks.graph_solution.num_edges() == 2 # Should have two root nodes - roots = [n for n in tracks.graph.node_ids() if tracks.graph.in_degree(n) == 0] + roots = [ + n + for n in tracks.graph_solution.node_ids() + if len(tracks.predecessors(n)) == 0 + ] assert len(roots) == 2 def test_division_nan_parent(self): @@ -207,10 +211,10 @@ def test_division_nan_parent(self): tracks = tracks_from_df(df) - assert tracks.graph.num_nodes() == 3 - assert tracks.graph.num_edges() == 2 + assert tracks.graph_solution.num_nodes() == 3 + assert tracks.graph_solution.num_edges() == 2 - children = list(tracks.graph.successors(1)) + children = list(tracks.graph_solution.successors(1)) assert set(children) == {2, 3} def test_division_event(self): @@ -227,11 +231,11 @@ def test_division_event(self): tracks = tracks_from_df(df) - assert tracks.graph.num_nodes() == 3 - assert tracks.graph.num_edges() == 2 + assert tracks.graph_solution.num_nodes() == 3 + assert tracks.graph_solution.num_edges() == 2 # Node 1 should have two children - children = list(tracks.graph.successors(1)) + children = list(tracks.graph_solution.successors(1)) assert len(children) == 2 assert set(children) == {2, 3} @@ -249,19 +253,23 @@ def test_long_track(self): tracks = tracks_from_df(df) - assert tracks.graph.num_nodes() == 10 - assert tracks.graph.num_edges() == 9 + assert tracks.graph_solution.num_nodes() == 10 + assert tracks.graph_solution.num_edges() == 9 # Should form a single linear chain - roots = [n for n in tracks.graph.node_ids() if tracks.graph.in_degree(n) == 0] + roots = [ + n + for n in tracks.graph_solution.node_ids() + if len(tracks.predecessors(n)) == 0 + ] assert len(roots) == 1 # Each non-leaf node should have exactly one child non_leaves = [ - n for n in tracks.graph.node_ids() if tracks.graph.out_degree(n) > 0 + n for n in tracks.graph_solution.node_ids() if len(tracks.successors(n)) > 0 ] for node in non_leaves: - assert tracks.graph.out_degree(node) == 1 + assert len(tracks.successors(node)) == 1 def test_orphaned_node_raises_error(self): """Test that node with invalid parent_id raises error.""" @@ -361,8 +369,8 @@ def test_seg_id_same_as_id(self, simple_df_2d): tracks = tracks_from_df(simple_df_2d, node_name_map=name_map) # Both id and seg_id should be present with same values - assert tracks.graph.num_nodes() == 4 - for node_id in tracks.graph.node_ids(): + assert tracks.graph_solution.num_nodes() == 4 + for node_id in tracks.graph_solution.node_ids(): assert tracks.get_node_attr(node_id, "seg_id") == node_id def test_duplicate_mapping_with_segmentation(self, simple_df_2d): @@ -654,7 +662,7 @@ def test_empty_list_in_name_map_removed(self): tracks = tracks_from_df(df, node_name_map=name_map) assert tracks is not None # The empty mapping should not result in a feature being added - assert "ellipse_axis_radii" not in tracks.graph.node_attr_keys() + assert "ellipse_axis_radii" not in tracks.graph_solution.node_attr_keys() def test_import_without_position_with_segmentation(self): """Test that position can be omitted when segmentation is provided. @@ -687,8 +695,8 @@ def test_import_without_position_with_segmentation(self): assert tracks is not None # Position should be computed from segmentation centroids - assert "pos" in tracks.graph.node_attr_keys() - pos_1 = tracks.graph.nodes[1]["pos"] + assert "pos" in tracks.graph_solution.node_attr_keys() + pos_1 = tracks.graph_solution.nodes[1]["pos"] # Centroid of 3x3 region at [2:5, 2:5] is approximately [3, 3] np.testing.assert_array_almost_equal(pos_1, [3.0, 3.0], decimal=0) diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py index b67d9df6..05c8814d 100644 --- a/tests/import_export/test_export_to_geff.py +++ b/tests/import_export/test_export_to_geff.py @@ -4,7 +4,7 @@ import tifffile import zarr -from funtracks.data_model import SolutionTracks, Tracks +from funtracks.data_model import Tracks from funtracks.import_export import export_to_geff, import_from_geff, write_to_geff @@ -32,13 +32,13 @@ def _assert_valid_geff_export(export_dir, expected_num_nodes=None): @pytest.mark.parametrize("seg_relabel", ["tracklet", "lineage", None]) def test_export_segmentation_relabel(get_tracks, ndim, seg_relabel, tmp_path): """Test segmentation export with each relabel strategy.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() export_to_geff(tracks, export_dir, seg_relabel=seg_relabel) - z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes()) + z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes()) # Segmentation file must exist seg_path = export_dir / "segmentation" @@ -54,11 +54,13 @@ def test_export_segmentation_relabel(get_tracks, ndim, seg_relabel, tmp_path): else: label_key = tracks.features.tracklet_key label_vals = set( - tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list() + tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list() ) assert unique_vals == label_vals else: - assert unique_vals == set(tracks.graph.node_ids()) + # values should be original node_ids + node_ids_set = set(tracks.graph_solution.node_ids()) + assert unique_vals == node_ids_set # segmentation_shape must be in metadata attrs = dict(z.attrs) @@ -68,13 +70,13 @@ def test_export_segmentation_relabel(get_tracks, ndim, seg_relabel, tmp_path): def test_export_no_segmentation_saved(get_tracks, tmp_path): """Test that save_segmentation=False suppresses segmentation file.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() export_to_geff(tracks, export_dir, save_segmentation=False) - z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes()) + z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes()) assert not (export_dir / "segmentation").exists() @@ -85,13 +87,13 @@ def test_export_no_segmentation_saved(get_tracks, tmp_path): def test_export_without_seg_on_tracks(get_tracks, tmp_path): """Test export when tracks have no segmentation at all.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() export_to_geff(tracks, export_dir) - z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes()) + z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes()) assert not (export_dir / "segmentation").exists() @@ -99,32 +101,21 @@ def test_export_without_seg_on_tracks(get_tracks, tmp_path): assert "segmentation_shape" not in attrs -@pytest.mark.parametrize("seg_relabel", ["tracklet", "lineage"]) -def test_export_seg_relabel_non_solution_raises(get_tracks, seg_relabel, tmp_path): - """Relabeling by tracklet/lineage on non-solution tracks raises ValueError.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=False) - - export_dir = tmp_path / "export" - export_dir.mkdir() - with pytest.raises(ValueError): - export_to_geff(tracks, export_dir, seg_relabel=seg_relabel) - - def test_export_segmentation_non_solution(get_tracks, tmp_path): """Non-solution tracks export segmentation fine with seg_relabel=None.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=False) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=False) export_dir = tmp_path / "export" export_dir.mkdir() export_to_geff(tracks, export_dir, seg_relabel=None) - z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes()) + z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes()) # No relabel: segmentation pixels keep original node_ids seg_zarr = zarr.open(str(export_dir / "segmentation"), mode="r") assert isinstance(seg_zarr, zarr.Array) assert seg_zarr.shape == tracks.segmentation.shape - assert set(seg_zarr[:].flatten()) - {0} == set(tracks.graph.node_ids()) + assert set(seg_zarr[:].flatten()) - {0} == set(tracks.graph_solution.node_ids()) attrs = dict(z.attrs) assert "segmentation_shape" in attrs @@ -135,10 +126,10 @@ def test_export_segmentation_non_solution(get_tracks, tmp_path): @pytest.mark.parametrize("ndim", [3, 4]) -@pytest.mark.parametrize("is_solution", [True, False]) -def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path): +@pytest.mark.parametrize("prefill_track_ids", [True, False]) +def test_export_split_position_attrs(get_graph, ndim, prefill_track_ids, tmp_path): """Test export with split (list) position attributes.""" - graph = get_graph(ndim, is_solution=is_solution, with_seg=False) + graph = get_graph(ndim, prefill_track_ids=prefill_track_ids, with_seg=False) pos_keys = ["y", "x"] if ndim == 3 else ["z", "y", "x"] for key in pos_keys: @@ -149,12 +140,11 @@ def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path): graph.nodes[node][key] = pos[i] graph.remove_node_attr_key("pos") - tracks_cls = SolutionTracks if is_solution else Tracks - tracks = tracks_cls( + tracks = Tracks( graph, time_attr="t", pos_attr=pos_keys, - tracklet_attr="track_id" if is_solution else None, + tracklet_attr="track_id" if prefill_track_ids else None, ndim=ndim, ) @@ -162,7 +152,7 @@ def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path): export_dir.mkdir() export_to_geff(tracks, export_dir, save_segmentation=False) - z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes()) + z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes()) # Verify axis names include the split position keys axes = dict(z.attrs)["geff"]["axes"] @@ -177,7 +167,7 @@ def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path): @pytest.mark.parametrize("ndim", [3, 4]) def test_export_node_subset(get_tracks, ndim, tmp_path): """Test exporting a subset of nodes includes ancestors.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -208,7 +198,7 @@ def test_export_node_subset(get_tracks, ndim, tmp_path): def test_export_node_subset_seg_relabel(get_tracks, tmp_path): """Test subset export with relabeled segmentation.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -224,8 +214,10 @@ def test_export_node_subset_seg_relabel(get_tracks, tmp_path): original = np.asarray(tracks.segmentation[:]) label_key = tracks.features.tracklet_key - labels = tracks.graph.node_attrs(attr_keys=[label_key])[label_key] - node_to_label = dict(zip(tracks.graph.node_ids(), labels.to_list(), strict=True)) + labels = tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key] + node_to_label = dict( + zip(tracks.graph_solution.node_ids(), labels.to_list(), strict=True) + ) graph_nodes = set(expected_graph_nodes) expected = np.zeros_like(original) @@ -240,7 +232,7 @@ def test_export_node_subset_seg_relabel(get_tracks, tmp_path): def test_export_node_subset_without_seg(get_tracks, tmp_path): """Test subset export when tracks have no segmentation.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -258,7 +250,7 @@ def test_export_node_subset_without_seg(get_tracks, tmp_path): def test_export_overwrite(get_tracks, tmp_path): """Test export with overwrite=True into non-empty directory.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -266,7 +258,7 @@ def test_export_overwrite(get_tracks, tmp_path): export_to_geff(tracks, export_dir, overwrite=True) - _assert_valid_geff_export(export_dir, tracks.graph.num_nodes()) + _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes()) # Segmentation is still written correctly alongside the overwritten dir seg_zarr = zarr.open(str(export_dir / "segmentation"), mode="r") @@ -276,7 +268,7 @@ def test_export_overwrite(get_tracks, tmp_path): def test_export_non_directory_raises(get_tracks, tmp_path): """Test that exporting to a file path (not a directory) raises an error.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) file_path = tmp_path / "not_a_dir" file_path.write_text("test") @@ -292,13 +284,13 @@ def test_export_non_directory_raises(get_tracks, tmp_path): @pytest.mark.parametrize("with_seg", [True, False]) def test_export_metadata(get_tracks, ndim, with_seg, tmp_path): """Test axes structure, segmentation_shape, and FeatureDict in metadata.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() export_to_geff(tracks, export_dir, save_segmentation=with_seg) - z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes()) + z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes()) attrs = dict(z.attrs) # Correct number of axes @@ -327,7 +319,7 @@ def test_export_metadata(get_tracks, ndim, with_seg, tmp_path): @pytest.mark.parametrize("ndim", [3, 4], ids=["2d", "3d"]) def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path): """Test that segmentation can be exported as tiff alongside the geff graph.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) export_dir = tmp_path / "export" export_dir.mkdir() @@ -344,7 +336,9 @@ def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path): # values should be tracklet_ids (default seg_relabel="tracklet") unique_vals = set(seg_arr.flatten()) - {0} label_key = tracks.features.tracklet_key - track_ids = set(tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list()) + track_ids = set( + tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list() + ) assert unique_vals == track_ids # Check metadata references the tiff path with ../../ prefix (sibling of geff dir) @@ -360,16 +354,16 @@ def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path): @pytest.mark.parametrize("with_seg", [False, True]) def test_write_to_geff_roundtrip(get_tracks, ndim, with_seg, tmp_path): """write_to_geff then import_from_geff recovers the tracks.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) loaded = import_from_geff(geff_path) - assert loaded.graph.num_nodes() == tracks.graph.num_nodes() - assert loaded.graph.num_edges() == tracks.graph.num_edges() - assert set(loaded.graph.node_ids()) == set(tracks.graph.node_ids()) + assert loaded.graph_solution.num_nodes() == tracks.graph_solution.num_nodes() + assert loaded.graph_solution.num_edges() == tracks.graph_solution.num_edges() + assert set(loaded.graph_solution.node_ids()) == set(tracks.graph_solution.node_ids()) assert loaded.features.dump_json() == tracks.features.dump_json() if with_seg: @@ -383,7 +377,7 @@ def test_write_to_geff_roundtrip(get_tracks, ndim, with_seg, tmp_path): def test_write_to_geff_no_parent_container(get_tracks, tmp_path): """write_to_geff writes directly to the path — no parent .zgroup or tracks.geff subfolder.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) @@ -400,19 +394,19 @@ def test_write_to_geff_no_parent_container(get_tracks, tmp_path): def test_write_to_geff_overwrite(get_tracks, tmp_path): """write_to_geff with overwrite=True replaces existing store.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) write_to_geff(tracks, geff_path, overwrite=True) loaded = import_from_geff(geff_path) - assert loaded.graph.num_nodes() == tracks.graph.num_nodes() + assert loaded.graph_solution.num_nodes() == tracks.graph_solution.num_nodes() def test_write_to_geff_no_overwrite_raises(get_tracks, tmp_path): """write_to_geff with overwrite=False raises on existing store.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) @@ -423,7 +417,7 @@ def test_write_to_geff_no_overwrite_raises(get_tracks, tmp_path): def test_write_to_geff_metadata(get_tracks, tmp_path): """write_to_geff stores axes and FeatureDict metadata.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) @@ -441,7 +435,7 @@ def test_write_to_geff_metadata(get_tracks, tmp_path): def test_write_to_geff_segmentation_shape(get_tracks, tmp_path): """write_to_geff writes segmentation_shape when masks are present.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) geff_path = tmp_path / "my_tracks.geff" write_to_geff(tracks, geff_path) diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index 5d6fdbd0..75cc3984 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -5,10 +5,10 @@ import zarr from geff.testing.data import create_mock_geff -from funtracks.data_model import SolutionTracks +from funtracks.data_model import Tracks from funtracks.import_export import export_to_geff, import_from_geff from funtracks.import_export.geff._import import GeffTracksBuilder, import_graph_from_geff -from funtracks.utils.tracksdata_utils import create_empty_graphview_graph +from funtracks.utils.tracksdata_utils import create_empty_graph @pytest.fixture @@ -262,7 +262,7 @@ def test_duplicate_values_in_name_map(valid_geff): tracks = import_from_geff(store, node_name_map) # Both time and seg_id should be present with same values - for node_id in tracks.graph.node_ids(): + for node_id in tracks.graph_solution.node_ids(): assert tracks.get_node_attr(node_id, "seg_id") == tracks.get_node_attr( node_id, "t" ) @@ -316,11 +316,11 @@ def test_tracks_with_segmentation(valid_geff, invalid_geff, valid_segmentation, assert hasattr(tracks, "segmentation") assert tracks.segmentation.shape == valid_segmentation.shape # Get last node by ID (don't rely on iteration order) - last_node = max(tracks.graph.node_ids()) + last_node = max(tracks.graph_solution.node_ids()) # With composite pos, position is stored as an array - pos = tracks.graph.nodes[last_node]["pos"] + pos = tracks.graph_solution.nodes[last_node]["pos"] coords = [ - tracks.graph.nodes[last_node]["t"], + tracks.graph_solution.nodes[last_node]["t"], pos[0], # y pos[1], # x ] @@ -333,10 +333,10 @@ def test_tracks_with_segmentation(valid_geff, invalid_geff, valid_segmentation, ) # test that the seg id has been relabeled # Check that only requested features are present and area is loaded from geff - data = tracks.graph.nodes[last_node] - assert "random_feature" in tracks.graph.node_attr_keys() - assert "random_feature2" not in tracks.graph.node_attr_keys() - assert "area" in tracks.graph.node_attr_keys() + data = tracks.graph_solution.nodes[last_node] + assert "random_feature" in tracks.graph_solution.node_attr_keys() + assert "random_feature2" not in tracks.graph_solution.node_attr_keys() + assert "area" in tracks.graph_solution.node_attr_keys() assert data["area"] == 21 # loaded directly from geff, not recomputed # Test that import fails with ValueError when invalid seg_ids are provided. @@ -418,8 +418,8 @@ def test_features_loaded_from_name_map(valid_geff, valid_segmentation, tmp_path) assert key in tracks.features # Get last node by ID (don't rely on iteration order) - max_node_id = max(tracks.graph.node_ids()) - data = tracks.graph.nodes[max_node_id] + max_node_id = max(tracks.graph_solution.node_ids()) + data = tracks.graph_solution.nodes[max_node_id] # All requested features should be present and loaded from geff for key in feature_keys: @@ -472,7 +472,7 @@ def test_import_from_geff_roundtrip_auto_axes(tmp_path): import tracksdata as td from tracksdata.nodes import Mask - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=[ "pos", "area", @@ -501,13 +501,15 @@ def test_import_from_geff_roundtrip_auto_axes(tmp_path): indices=[1], ) # The graph carries segmentation_shape in its metadata (set by motile-tracker), - # but no dense segmentation array is attached to the SolutionTracks object. + # but no dense segmentation array is attached to the Tracks object. graph._update_metadata(segmentation_shape=(5, 100, 100)) run_dir = tmp_path / "run" run_dir.mkdir() - st = SolutionTracks(graph, ndim=3, time_attr="t") + st = Tracks( + graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id" + ) export_to_geff(st, run_dir) tracks_path = run_dir / "tracks.geff" @@ -541,13 +543,13 @@ def test_import_from_geff_roundtrip_auto_axes(tmp_path): # import_from_geff must read segmentation_shape back from zarr attrs and # reconstruct a segmentation (GraphArrayView) — not return segmentation=None. tracks = import_from_geff(tracks_path) - assert tracks.graph.num_nodes() == 1 + assert tracks.graph_solution.num_nodes() == 1 assert tracks.segmentation is not None, ( "segmentation should be reconstructed from masks after round-trip" ) assert tracks.segmentation.shape == (5, 100, 100) - node1 = tracks.graph.nodes[1] + node1 = tracks.graph_solution.nodes[1] assert node1["pos"] is not None np.testing.assert_array_almost_equal(node1["pos"], [50.0, 50.0]) @@ -600,7 +602,7 @@ def test_import_from_geff_warns_missing_segmentation_shape(tmp_path): import tracksdata as td import zarr as _zarr - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=[ "pos", td.DEFAULT_ATTR_KEYS.MASK, @@ -625,7 +627,9 @@ def test_import_from_geff_warns_missing_segmentation_shape(tmp_path): run_dir = tmp_path / "run" run_dir.mkdir() - st = SolutionTracks(graph, ndim=3, time_attr="t") + st = Tracks( + graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id" + ) export_to_geff(st, run_dir) tracks_path = run_dir / "tracks.geff" @@ -652,11 +656,11 @@ def test_import_from_geff_warns_missing_segmentation_shape(tmp_path): def test_get_time_works_after_import(valid_geff): - """Regression test: tracks.get_time() must work on a SolutionTracks returned by + """Regression test: tracks.get_time() must work on a Tracks returned by import_from_geff(). Previously, TracksBuilder.build() stored time as "t" in the graph (tracksdata - convention) but created SolutionTracks(time_attr=TIME_ATTR) where TIME_ATTR="time". + convention) but created Tracks(time_attr=TIME_ATTR) where TIME_ATTR="time". This caused features.time_key="time" while the graph only had attribute "t", making get_time() raise KeyError: 'time'. """ @@ -664,13 +668,13 @@ def test_get_time_works_after_import(valid_geff): name_map = {"time": "t", "pos": ["y", "x"]} tracks = import_from_geff(store, name_map) - for node_id in tracks.graph.node_ids(): + for node_id in tracks.graph_solution.node_ids(): # This must not raise KeyError: 'time' t = tracks.get_time(node_id) assert isinstance(t, int), f"get_time() should return int, got {type(t)}" # get_times() on all nodes must also work - all_node_ids = list(tracks.graph.node_ids()) + all_node_ids = list(tracks.graph_solution.node_ids()) times = tracks.get_times(all_node_ids) assert len(times) == len(all_node_ids) @@ -710,7 +714,7 @@ def test_bool_node_property_schema(geff_with_bool_prop): name_map = {"time": "t", "pos": ["y", "x"], "is_dividing": "is_dividing"} tracks = import_from_geff(geff_with_bool_prop, name_map) - df = tracks.graph.node_attrs(attr_keys=["is_dividing"]) + df = tracks.graph_solution.node_attrs(attr_keys=["is_dividing"]) assert df["is_dividing"].dtype == pl.Boolean, ( f"Expected pl.Boolean schema for 'is_dividing', got {df['is_dividing'].dtype}. " "Likely cause: np.bool_ default_value fell through to int in construct_graph()." @@ -729,10 +733,10 @@ def test_bool_node_property_values(geff_with_bool_prop): name_map = {"time": "t", "pos": ["y", "x"], "is_dividing": "is_dividing"} tracks = import_from_geff(geff_with_bool_prop, name_map) - node_ids = sorted(tracks.graph.node_ids()) + node_ids = sorted(tracks.graph_solution.node_ids()) expected = [True, False, True, False, True] for node_id, exp in zip(node_ids, expected, strict=True): - val = tracks.graph.nodes[node_id]["is_dividing"] + val = tracks.graph_solution.nodes[node_id]["is_dividing"] assert type(val) is bool, ( f"Expected Python bool for 'is_dividing', got {type(val)} for node {node_id}" "Likely cause: np.bool_ value not cast to bool in construct_graph()." @@ -744,7 +748,7 @@ def test_3d_pos_survives_sql_roundtrip(tmp_path): """Regression test: 3D pos (z, y, x) must keep Array dtype through SQL roundtrip. The construct_graph() method must pass the correct ndim to - create_empty_graphview_graph() so the pos schema is Array(Float64, 3) not + create_empty_graph() so the pos schema is Array(Float64, 3) not Array(Float64, 2). A schema mismatch causes SQLGraph.from_other() to downgrade the column to List(Float64), which breaks downstream callers that rely on to_numpy() returning a 2D float64 array. @@ -768,14 +772,16 @@ def test_3d_pos_survives_sql_roundtrip(tmp_path): tracks = import_from_geff(store, name_map) # Verify the RX graph has correct Array dtype for 3D pos - df_rx = tracks.graph.node_attrs(attr_keys=["pos"]) + df_rx = tracks.graph_solution.node_attrs(attr_keys=["pos"]) assert df_rx["pos"].dtype == pl.Array(pl.Float64, 3), ( f"RX graph pos should be Array(Float64, 3), got {df_rx['pos'].dtype}" ) # Convert to SQL and reload — this is where the schema mismatch used to surface db_path = str(tmp_path / "test.db") - td.graph.SQLGraph.from_other(tracks.graph, drivername="sqlite", database=db_path) + td.graph.SQLGraph.from_other( + tracks.graph_solution, drivername="sqlite", database=db_path + ) sql_graph2 = td.graph.SQLGraph("sqlite", db_path) df_sql = sql_graph2.node_attrs(attr_keys=["pos"]) @@ -824,9 +830,10 @@ def test_geff_legacy_track_id_preserves_tracklet_ids(): def test_geff_roundtrip_preserves_tracklet_ids(get_tracks, tmp_path): """End-to-end round-trip: export then import should preserve tracklet IDs.""" - tracks_in = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks_in = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) expected = { - int(nid): tracks_in.get_track_id(int(nid)) for nid in tracks_in.graph.node_ids() + int(nid): tracks_in.get_track_id(int(nid)) + for nid in tracks_in.graph_solution.node_ids() } export_dir = tmp_path / "export" @@ -871,7 +878,7 @@ def test_embedded_seg_ellipse_axis_radii_feature_metadata(tmp_path): td.DEFAULT_ATTR_KEYS.BBOX, ] # node_default_values must align with node_attributes by index. - # pos/mask/bbox are handled by special cases in create_empty_graphview_graph + # pos/mask/bbox are handled by special cases in create_empty_graph # and their slot values here are never accessed; only area, ellipse_axis_radii, # track_id, and lineage_id go through the general loop. node_default_values = [ @@ -883,7 +890,7 @@ def test_embedded_seg_ellipse_axis_radii_feature_metadata(tmp_path): None, # mask — special-cased, slot unused None, # bbox — special-cased, slot unused ] - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=node_attributes, node_default_values=node_default_values, edge_attributes=[], @@ -911,7 +918,9 @@ def test_embedded_seg_ellipse_axis_radii_feature_metadata(tmp_path): run_dir = tmp_path / "run" run_dir.mkdir() - st = SolutionTracks(graph, ndim=3, time_attr="t") + st = Tracks( + graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id" + ) export_to_geff(st, run_dir) # Remove FeatureDict from GEFF metadata to simulate old/external GEFF @@ -996,7 +1005,7 @@ def test_featuredict_survives_geff_roundtrip(tmp_path): None, # mask — special-cased, slot unused None, # bbox — special-cased, slot unused ] - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=node_attributes, node_default_values=node_default_values, edge_attributes=[], @@ -1024,7 +1033,9 @@ def test_featuredict_survives_geff_roundtrip(tmp_path): run_dir = tmp_path / "run" run_dir.mkdir() - st = SolutionTracks(graph, ndim=3, time_attr="t") + st = Tracks( + graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id" + ) # Customize metadata that auto-detection cannot reproduce. pos_key = st.features.position_key assert isinstance(pos_key, str) @@ -1050,7 +1061,7 @@ def test_invalid_featuredict_in_geff_falls_back_to_autodetect(get_tracks, tmp_pa import_from_geff should silently fall back to auto-detection instead of raising an exception. """ - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) run_dir = tmp_path / "run" run_dir.mkdir() export_to_geff(tracks, run_dir, save_segmentation=False) @@ -1071,11 +1082,13 @@ def test_invalid_featuredict_in_geff_falls_back_to_autodetect(get_tracks, tmp_pa # Should not raise — falls back to auto-detection imported = import_from_geff(geff_path) - # Verify the import produced a working SolutionTracks with auto-detected features + # Verify the import produced a working Tracks with auto-detected features assert imported.features.time_key is not None assert imported.features.position_key is not None assert imported.features.tracklet_key is not None - assert set(tracks.graph.node_ids()) == set(imported.graph.node_ids()) + assert set(tracks.graph_solution.node_ids()) == set( + imported.graph_solution.node_ids() + ) def test_subgroup_export_omits_featuredict_and_recomputes_on_import(get_tracks, tmp_path): @@ -1088,7 +1101,7 @@ def test_subgroup_export_omits_featuredict_and_recomputes_on_import(get_tracks, On reimport, validation would strip the now-invalid tracklet_id but the FeatureDict still referenced it, causing KeyError: 'tracklet_id'. """ - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) # Export only a subset: nodes 1, 3, 4, 5 (one branch of the division). # filter_graph_with_ancestors will include node 1 as ancestor of 3. @@ -1126,16 +1139,18 @@ def test_subgroup_export_omits_featuredict_and_recomputes_on_import(get_tracks, }, ) - assert isinstance(imported, SolutionTracks) + assert isinstance(imported, Tracks) assert imported.features.tracklet_key is not None # The subgraph is a linear chain (1→3→4→5, no divisions), so all nodes # should share a single tracklet_id and a single lineage_id. - track_ids = {imported.get_track_id(nid) for nid in imported.graph.node_ids()} + track_ids = {imported.get_track_id(nid) for nid in imported.graph_solution.node_ids()} assert len(track_ids) == 1, ( f"Linear chain should have one tracklet_id, got {track_ids}" ) - lineage_ids = {imported.get_lineage_id(nid) for nid in imported.graph.node_ids()} + lineage_ids = { + imported.get_lineage_id(nid) for nid in imported.graph_solution.node_ids() + } assert len(lineage_ids) == 1, ( f"Linear chain should have one lineage_id, got {lineage_ids}" ) @@ -1155,7 +1170,7 @@ def test_import_from_geff_respects_external_solution_column(tmp_path): # Separate y/x columns (not a 2-D 'pos' array) so tracksdata's to_geff # registers them as space-typed axes — required to exercise the # axes-branch of GeffTracksBuilder.infer_node_name_map. - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=["y", "x"], position_attrs=["y", "x"], ndim=3 ) graph.bulk_add_nodes( @@ -1167,15 +1182,14 @@ def test_import_from_geff_respects_external_solution_column(tmp_path): indices=[1, 2, 3], ) - # Export the root graph, not the filtered solution-only view, so the - # solution=False row survives into the geff (mimicking a solver-produced - # geff with rejected nodes). + # Export the full base graph so the solution=False row survives into the geff + # (mimicking a solver-produced geff with rejected nodes). tracks_path = tmp_path / "tracks.geff" - graph._root.to_geff(geff_store=tracks_path, zarr_format=2) + graph.to_geff(geff_store=tracks_path, zarr_format=2) tracks = import_from_geff(tracks_path) - node_ids = set(tracks.graph.node_ids()) + node_ids = set(tracks.graph_solution.node_ids()) assert node_ids == {1, 3}, ( "Node 2 has solution=False in the geff file and should be filtered out " f"by import_from_geff. Got node_ids={node_ids}." diff --git a/tests/import_export/test_import_segmentation.py b/tests/import_export/test_import_segmentation.py index fcbe8b5f..088de86c 100644 --- a/tests/import_export/test_import_segmentation.py +++ b/tests/import_export/test_import_segmentation.py @@ -7,7 +7,7 @@ load_segmentation, relabel_segmentation, ) -from funtracks.utils.tracksdata_utils import create_empty_graphview_graph +from funtracks.utils.tracksdata_utils import create_empty_graph class TestLoadSegmentation: @@ -46,7 +46,7 @@ def test_basic_relabeling(self): seg[1, 2, 2] = 20 # seg_id 20 at t=1 # Create graph with node_ids 1, 2 - graph = create_empty_graphview_graph() + graph = create_empty_graph() graph.add_node(index=1, attrs={"t": 0, "solution": True}) graph.add_node(index=2, attrs={"t": 1, "solution": True}) @@ -70,7 +70,7 @@ def test_relabeling_with_node_id_zero(self): seg[1, 2, 2] = 20 # seg_id 20 at t=1 # Create graph with node_ids 0, 1 (includes 0!) - graph = create_empty_graphview_graph() + graph = create_empty_graph() graph.add_node(index=0, attrs={"t": 0, "solution": True}) graph.add_node(index=1, attrs={"t": 1, "solution": True}) @@ -97,7 +97,7 @@ def test_no_relabeling_needed_same_ids(self): seg[0, 1, 1] = 1 seg[1, 2, 2] = 2 - graph = create_empty_graphview_graph() + graph = create_empty_graph() graph.add_node(index=1, attrs={"t": 0, "solution": True}) graph.add_node(index=2, attrs={"t": 1, "solution": True}) @@ -119,7 +119,7 @@ def test_multiple_nodes_same_timepoint(self): seg[0, 2, 2] = 20 seg[0, 3, 3] = 30 - graph = create_empty_graphview_graph() + graph = create_empty_graph() graph.add_node(index=1, attrs={"t": 0, "solution": True}) graph.add_node(index=2, attrs={"t": 0, "solution": True}) graph.add_node(index=3, attrs={"t": 0, "solution": True}) diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py index 8112cc60..636de386 100644 --- a/tests/import_export/test_internal_format.py +++ b/tests/import_export/test_internal_format.py @@ -14,20 +14,20 @@ @pytest.mark.parametrize("with_seg", [True, False]) @pytest.mark.parametrize("ndim", [3, 4]) -@pytest.mark.parametrize("is_solution", [True, False]) +@pytest.mark.parametrize("prefill_track_ids", [True, False]) def test_save_load( get_tracks, with_seg, ndim, - is_solution, + prefill_track_ids, ): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=is_solution) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=prefill_track_ids) data_path = Path( - f"tests/data/format_v1/test_save_load_{is_solution}_{ndim}_{with_seg}_0" + f"tests/data/format_v1/test_save_load_{prefill_track_ids}_{ndim}_{with_seg}_0" ) - loaded = load_v1_tracks(data_path, solution=is_solution) + loaded = load_v1_tracks(data_path) assert loaded.ndim == ndim # Check feature keys and important properties match (allow tuple vs list diff) assert loaded.features.time_key == tracks.features.time_key @@ -60,12 +60,9 @@ def test_save_load( assert loaded.scale == tracks.scale assert loaded.ndim == tracks.ndim - if is_solution: - loaded_annotator = loaded.track_annotator - tracks_annotator = tracks.track_annotator - assert ( - loaded_annotator.tracklet_id_to_nodes == tracks_annotator.tracklet_id_to_nodes - ) + loaded_annotator = loaded.track_annotator + tracks_annotator = tracks.track_annotator + assert loaded_annotator.tracklet_id_to_nodes == tracks_annotator.tracklet_id_to_nodes if with_seg: assert_array_almost_equal(loaded.segmentation, tracks.segmentation) @@ -73,27 +70,33 @@ def test_save_load( assert loaded.segmentation is None # graphs_equal doesn't exist for TracksData, so we check properties - assert set(loaded.graph.node_attr_keys()) == set(tracks.graph.node_attr_keys()) - assert set(loaded.graph.edge_attr_keys()) == set(tracks.graph.edge_attr_keys()) - assert loaded.graph.num_nodes() == tracks.graph.num_nodes() - assert loaded.graph.num_edges() == tracks.graph.num_edges() - assert set(loaded.graph.node_ids()) == set(tracks.graph.node_ids()) + assert set(loaded.graph_solution.node_attr_keys()) == set( + tracks.graph_solution.node_attr_keys() + ) + assert set(loaded.graph_solution.edge_attr_keys()) == set( + tracks.graph_solution.edge_attr_keys() + ) + assert loaded.graph_solution.num_nodes() == tracks.graph_solution.num_nodes() + assert loaded.graph_solution.num_edges() == tracks.graph_solution.num_edges() + assert set(loaded.graph_solution.node_ids()) == set(tracks.graph_solution.node_ids()) # edge_ids dont matter, only the actual edges: - assert sorted(loaded.graph.edge_list()) == sorted(tracks.graph.edge_list()) + assert sorted(loaded.graph_solution.edge_list()) == sorted( + tracks.graph_solution.edge_list() + ) @pytest.mark.parametrize("with_seg", [True, False]) @pytest.mark.parametrize("ndim", [3, 4]) -@pytest.mark.parametrize("is_solution", [True, False]) +@pytest.mark.parametrize("prefill_track_ids", [True, False]) def test_delete( get_tracks, with_seg, ndim, - is_solution, + prefill_track_ids, tmp_path, ): reference_path = Path( - f"tests/data/format_v1/test_save_load_{is_solution}_{ndim}_{with_seg}_0" + f"tests/data/format_v1/test_save_load_{prefill_track_ids}_{ndim}_{with_seg}_0" ) # Copy reference data to temporary location @@ -115,7 +118,7 @@ def test_load_without_features(tmp_path, graph_2d_with_segmentation): shutil.copytree(reference_path, tracks_path) # Load the original data first to verify it loads correctly - load_v1_tracks(tracks_path, solution=True) + load_v1_tracks(tracks_path) # Modify the copy to test backward compatibility attrs_path = tracks_path / "attrs.json" diff --git a/tests/import_export/test_solution_roundtrip.py b/tests/import_export/test_solution_roundtrip.py index babf3183..a716fa1c 100644 --- a/tests/import_export/test_solution_roundtrip.py +++ b/tests/import_export/test_solution_roundtrip.py @@ -20,11 +20,11 @@ def _roundtrip(tracks, tmp_path, name="rt.geff"): def test_geff_roundtrip_preserves_solution_schema(get_tracks, tmp_path): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) loaded = _roundtrip(tracks, tmp_path) - edge_schema = loaded.graph._edge_attr_schemas()["solution"] - node_schema = loaded.graph._node_attr_schemas()["solution"] + edge_schema = loaded.graph_solution._edge_attr_schemas()["solution"] + node_schema = loaded.graph_solution._node_attr_schemas()["solution"] assert edge_schema.dtype == pl.Boolean assert edge_schema.default_value is True @@ -33,9 +33,9 @@ def test_geff_roundtrip_preserves_solution_schema(get_tracks, tmp_path): def test_add_edge_is_solution_true_after_geff_roundtrip(get_tracks, tmp_path): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) loaded = _roundtrip(tracks, tmp_path) - g = loaded.graph + g = loaded.graph_solution # find any source at frame t and target at t+1 with no edge between them rows = list(g.node_attrs(attr_keys=["node_id", "t"]).sort("t").iter_rows(named=True)) @@ -61,10 +61,10 @@ def test_add_edge_is_solution_true_after_geff_roundtrip(get_tracks, tmp_path): def test_add_node_is_solution_true_after_geff_roundtrip(get_tracks, tmp_path): - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) loaded = _roundtrip(tracks, tmp_path) - new_id = max(loaded.graph.node_ids()) + 1 + new_id = max(loaded.graph_solution.node_ids()) + 1 AddNode( loaded, new_id, diff --git a/tests/import_export/test_utils.py b/tests/import_export/test_utils.py index cbfa1bcc..b9e75400 100644 --- a/tests/import_export/test_utils.py +++ b/tests/import_export/test_utils.py @@ -3,7 +3,7 @@ def test_rename_feature_basic(get_tracks): """Test that rename_feature renames a feature in annotators and features dict.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=False) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=False) # Rename area feature to custom name rename_feature(tracks, "area", "my_area") @@ -19,7 +19,7 @@ def test_rename_feature_basic(get_tracks): def test_rename_feature_updates_position_key(get_tracks): """Test that renaming position feature updates position_key in FeatureDict.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) original_pos_key = tracks.features.position_key new_key = "custom_position" @@ -32,7 +32,7 @@ def test_rename_feature_updates_position_key(get_tracks): def test_rename_feature_updates_tracklet_key(get_tracks): """Test that renaming tracklet feature updates tracklet_key in FeatureDict.""" - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) original_track_key = tracks.features.tracklet_key new_key = "custom_track" diff --git a/tests/user_actions/test_user_actions_force.py b/tests/user_actions/test_user_actions_force.py index 1096b958..045842f1 100644 --- a/tests/user_actions/test_user_actions_force.py +++ b/tests/user_actions/test_user_actions_force.py @@ -7,42 +7,42 @@ def test_user_force_add_downstream(get_tracks): """Test force adding a node of which the track id has an upstream division event. Should break the edges of the division event to allow this new edge.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) # upstream division, with force attrs = {"t": 2, "track_id": 1, "pos": [3, 4]} UserAddNode(tracks, node=7, attributes=attrs, force=True) assert tracks.get_track_id(7) == 1 - assert [1, 2] not in tracks.graph.edge_list() - assert [1, 3] not in tracks.graph.edge_list() - assert [1, 7] in tracks.graph.edge_list() + assert [1, 2] not in tracks.graph_solution.edge_list() + assert [1, 3] not in tracks.graph_solution.edge_list() + assert [1, 7] in tracks.graph_solution.edge_list() def test_user_force_add_upstream(get_tracks): """Test force adding a node upstream, of which the track id co-exists with the parent track id. Should break the edge with the parent track to allow this new edge.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) # downstream parent division, with force attrs = {"t": 0, "track_id": 3, "pos": [3, 4]} UserAddNode(tracks, node=7, attributes=attrs, force=True) assert tracks.get_track_id(7) == 3 - assert [1, 2] in tracks.graph.edge_list() # still there - assert [1, 3] not in tracks.graph.edge_list() # should be removed - assert [7, 3] in tracks.graph.edge_list() # new forced edge + assert [1, 2] in tracks.graph_solution.edge_list() # still there + assert [1, 3] not in tracks.graph_solution.edge_list() # should be removed + assert [7, 3] in tracks.graph_solution.edge_list() # new forced edge def test_auto_assign_new_track_id(get_tracks): """Test that adding a node with a track id that already exists at the current time point raises a warning and auto-assigns a new track id instead.""" - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) # existing track id at current time --> allowed, with warning with pytest.warns(UserWarning, match="Starting a new track, because track id"): attrs = {"t": 1, "track_id": 2, "pos": [3, 4]} # combination exists already UserAddNode(tracks, node=7, attributes=attrs) - assert tracks.graph.has_node(7) + assert tracks.graph_solution.has_node(7) assert tracks.get_track_id(7) == 6 # new assigned track id diff --git a/tests/user_actions/test_user_add_delete_edge.py b/tests/user_actions/test_user_add_delete_edge.py index 921f3e1d..9364b78f 100644 --- a/tests/user_actions/test_user_add_delete_edge.py +++ b/tests/user_actions/test_user_add_delete_edge.py @@ -8,32 +8,32 @@ @pytest.mark.parametrize("with_seg", [True, False]) class TestUserAddDeleteEdge: def test_user_add_edge(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # add an edge from 4 to 6 (will make 4 a division and 5 will need to relabel # track id) edge = (4, 6) old_child = 5 old_track_id = tracks.get_track_id(old_child) - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) action = UserAddEdge(tracks, edge) - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) != old_track_id inverse = action.inverse() - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) == old_track_id inverse.inverse() - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) != old_track_id def test_user_add_merge_edge(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # add an edge from 2 to 4 (there is already an edge from 3 to 4) edge = (2, 4) old_edge = (3, 4) - assert not tracks.graph.has_edge(*edge) - assert tracks.graph.has_edge(*old_edge) + assert not tracks.graph_solution.has_edge(*edge) + assert tracks.graph_solution.has_edge(*old_edge) with pytest.raises( InvalidActionError, match="Cannot make a merge edge in a tracking solution" ): @@ -43,37 +43,37 @@ def test_user_add_merge_edge(self, get_tracks, ndim, with_seg): match="Removing edge .* to add new edge without merging.", ): action = UserAddEdge(tracks, edge, force=True) - assert tracks.graph.has_edge(*edge) - assert not tracks.graph.has_edge(*old_edge) + assert tracks.graph_solution.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*old_edge) inverse = action.inverse() - assert not tracks.graph.has_edge(*edge) - assert tracks.graph.has_edge(*old_edge) + assert not tracks.graph_solution.has_edge(*edge) + assert tracks.graph_solution.has_edge(*old_edge) inverse.inverse() - assert tracks.graph.has_edge(*edge) - assert not tracks.graph.has_edge(*old_edge) + assert tracks.graph_solution.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*old_edge) def test_user_delete_edge(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # delete edge (1, 3). (1,2) is now not a division anymore edge = (1, 3) old_child = 2 old_track_id = tracks.get_track_id(old_child) new_track_id = tracks.get_track_id(1) - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) action = UserDeleteEdge(tracks, edge) - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) == new_track_id inverse = action.inverse() - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) == old_track_id double_inv = inverse.inverse() - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) == new_track_id # TODO: error if edge doesn't exist? @@ -84,23 +84,23 @@ def test_user_delete_edge(self, get_tracks, ndim, with_seg): old_child = 5 old_track_id = tracks.get_track_id(old_child) - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) action = UserDeleteEdge(tracks, edge) - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) != old_track_id inverse = action.inverse() - assert tracks.graph.has_edge(*edge) + assert tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) == old_track_id inverse.inverse() - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph_solution.has_edge(*edge) assert tracks.get_track_id(old_child) != old_track_id def test_add_edge_missing_node(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises(InvalidActionError, match="Source node .* not in solution yet"): UserAddEdge(tracks, (10, 11)) with pytest.raises(InvalidActionError, match="Target node .* not in solution yet"): @@ -108,7 +108,7 @@ def test_add_edge_missing_node(get_tracks): def test_add_edge_triple_div(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises( InvalidActionError, match="Expected degree of 0 or 1 before adding edge" ): @@ -116,18 +116,18 @@ def test_add_edge_triple_div(get_tracks): def test_delete_missing_edge(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) with pytest.raises(InvalidActionError, match="Edge .* not in solution"): UserDeleteEdge(tracks, (10, 11)) def test_delete_edge_triple_div(get_tracks): - tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True) attrs = {} attrs["solution"] = True attrs["iou"] = 0.9 - tracks.graph.add_edge( + tracks.graph_solution.add_edge( source_id=1, target_id=6, attrs=attrs, diff --git a/tests/user_actions/test_user_add_delete_node.py b/tests/user_actions/test_user_add_delete_node.py index 5524f021..8b21d4b1 100644 --- a/tests/user_actions/test_user_add_delete_node.py +++ b/tests/user_actions/test_user_add_delete_node.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("with_seg", [True, False]) class TestUserAddDeleteNode: def test_user_add_invalid_node(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # duplicate node with pytest.raises(InvalidActionError, match="Node .* already exists"): attrs = {"t": 5, "track_id": 1} @@ -34,7 +34,7 @@ def test_user_add_invalid_node(self, get_tracks, ndim, with_seg): UserAddNode(tracks, node=7, attributes=attrs) def test_user_add_node(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # add a node to replace a skip edge between node 4 in time 2 and node 5 in time 4 node_id = 7 track_id = 3 @@ -55,7 +55,7 @@ def test_user_add_node(self, get_tracks, ndim, with_seg): del attributes["pos"] else: pixels = None - graph = tracks.graph + graph = tracks.graph_solution assert not graph.has_node(node_id) assert graph.has_edge(4, 5) action = UserAddNode(tracks, node_id, attributes, pixels=pixels) @@ -84,11 +84,11 @@ def test_user_add_node(self, get_tracks, ndim, with_seg): # TODO: error if node already exists? def test_user_delete_node(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # delete node in middle of track. Should skip-connect 3 and 5 with span 3 node_id = 4 - graph = tracks.graph + graph = tracks.graph_solution assert graph.has_node(node_id) assert graph.has_edge(3, node_id) assert graph.has_edge(node_id, 5) @@ -121,14 +121,14 @@ def test_user_delete_node(self, get_tracks, ndim, with_seg): # TODO: error if node doesn't exist? def test_user_delete_node_after_division(self, get_tracks, ndim, with_seg): - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # delete first node after division. Should relabel the other child # to be the same track as parent parent_node = 1 node_id = 2 sib = 3 - graph = tracks.graph + graph = tracks.graph_solution assert graph.has_node(node_id) assert graph.has_edge(parent_node, node_id) parent_track_id = tracks.get_track_id(parent_node) @@ -158,8 +158,8 @@ def test_user_delete_node_after_division(self, get_tracks, ndim, with_seg): def test_user_delete_nodes(self, get_tracks, ndim, with_seg): """Test bulk deletion of multiple nodes in a single action.""" # Graph structure: 1 → 2, 1 → 3 → 4 → 5, and 6 (separate) - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - graph = tracks.graph + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) + graph = tracks.graph_solution # Save original state original_nodes = set(graph.node_ids()) diff --git a/tests/user_actions/test_user_swap_predecessors.py b/tests/user_actions/test_user_swap_predecessors.py index d33faeb3..82a60c88 100644 --- a/tests/user_actions/test_user_swap_predecessors.py +++ b/tests/user_actions/test_user_swap_predecessors.py @@ -10,30 +10,30 @@ class TestUserSwapPredecessors: @pytest.mark.parametrize("order", [(5, 6), (6, 5)]) def test_one_predecessor(self, get_tracks, ndim, with_seg, order): """Test swapping when one node has a predecessor and one doesn't.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Node 5 (t=4) has pred 4, node 6 (t=4) has no pred - assert tracks.graph.has_edge(4, 5) - assert list(tracks.graph.predecessors(6)) == [] + assert tracks.graph_solution.has_edge(4, 5) + assert list(tracks.graph_solution.predecessors(6)) == [] old_track_id_5 = tracks.get_track_id(5) old_track_id_6 = tracks.get_track_id(6) action = UserSwapPredecessors(tracks, order) - assert tracks.graph.has_edge(4, 6) - assert not tracks.graph.has_edge(4, 5) + assert tracks.graph_solution.has_edge(4, 6) + assert not tracks.graph_solution.has_edge(4, 5) assert tracks.get_track_id(6) == old_track_id_5 assert tracks.get_track_id(5) != old_track_id_5 action.inverse() - assert tracks.graph.has_edge(4, 5) - assert not tracks.graph.has_edge(4, 6) + assert tracks.graph_solution.has_edge(4, 5) + assert not tracks.graph_solution.has_edge(4, 6) assert tracks.get_track_id(5) == old_track_id_5 assert tracks.get_track_id(6) == old_track_id_6 def test_same_predecessor_raises(self, get_tracks, ndim, with_seg): """Test error when both nodes have the same predecessor.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Nodes 2 and 3 both have predecessor 1 with pytest.raises(InvalidActionError, match="same predecessor"): @@ -41,7 +41,7 @@ def test_same_predecessor_raises(self, get_tracks, ndim, with_seg): def test_different_predecessors(self, get_tracks, ndim, with_seg): """Test swapping when both nodes have different predecessors.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) UserAddEdge(tracks, (2, 6)) @@ -51,20 +51,20 @@ def test_different_predecessors(self, get_tracks, ndim, with_seg): action = UserSwapPredecessors(tracks, (5, 6)) - assert tracks.graph.has_edge(4, 6) - assert tracks.graph.has_edge(2, 5) - assert not tracks.graph.has_edge(4, 5) - assert not tracks.graph.has_edge(2, 6) + assert tracks.graph_solution.has_edge(4, 6) + assert tracks.graph_solution.has_edge(2, 5) + assert not tracks.graph_solution.has_edge(4, 5) + assert not tracks.graph_solution.has_edge(2, 6) action.inverse() - assert tracks.graph.has_edge(4, 5) - assert tracks.graph.has_edge(2, 6) + assert tracks.graph_solution.has_edge(4, 5) + assert tracks.graph_solution.has_edge(2, 6) assert tracks.get_track_id(5) == old_track_id_5 assert tracks.get_track_id(6) == old_track_id_6 def test_different_times_valid(self, get_tracks, ndim, with_seg): """Test swapping nodes at different times when predecessors are valid.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Add edge 2->6 so node 6 (t=4) has pred 2 (t=1) # Node 4 (t=2) has pred 3 (t=1) @@ -73,18 +73,18 @@ def test_different_times_valid(self, get_tracks, ndim, with_seg): action = UserSwapPredecessors(tracks, (4, 6)) - assert tracks.graph.has_edge(3, 6) - assert tracks.graph.has_edge(2, 4) - assert not tracks.graph.has_edge(3, 4) - assert not tracks.graph.has_edge(2, 6) + assert tracks.graph_solution.has_edge(3, 6) + assert tracks.graph_solution.has_edge(2, 4) + assert not tracks.graph_solution.has_edge(3, 4) + assert not tracks.graph_solution.has_edge(2, 6) action.inverse() - assert tracks.graph.has_edge(3, 4) - assert tracks.graph.has_edge(2, 6) + assert tracks.graph_solution.has_edge(3, 4) + assert tracks.graph_solution.has_edge(2, 6) def test_different_times_invalid_raises(self, get_tracks, ndim, with_seg): """Test error when predecessor would not be before swapped node.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Node 3 (t=1) has pred 1 (t=0), node 4 (t=2) has pred 3 (t=1) # pred of 4 (t=1) is not before node 3 (t=1) @@ -93,7 +93,7 @@ def test_different_times_invalid_raises(self, get_tracks, ndim, with_seg): def test_wrong_count_raises(self, get_tracks, ndim, with_seg): """Test error when not exactly two nodes provided.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) with pytest.raises( InvalidActionError, match="You can only swap a pair of two nodes" @@ -107,7 +107,7 @@ def test_wrong_count_raises(self, get_tracks, ndim, with_seg): def test_no_predecessors_raises(self, get_tracks, ndim, with_seg): """Test error when neither node has a predecessor.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) # Delete edge so node 5 has no predecessor like node 6 UserDeleteEdge(tracks, (4, 5)) diff --git a/tests/user_actions/test_user_update_node_attrs.py b/tests/user_actions/test_user_update_node_attrs.py index 5e85a03b..3b3c98b4 100644 --- a/tests/user_actions/test_user_update_node_attrs.py +++ b/tests/user_actions/test_user_update_node_attrs.py @@ -9,11 +9,17 @@ class TestUserUpdateNodeAttrs: def test_user_update_node_attrs(self, get_tracks, ndim, with_seg): """Test basic node attribute update functionality.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - - tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object) - tracks.graph.add_node_attr_key("confidence", default_value=0, dtype=pl.Float64) - tracks.graph.add_node_attr_key("validated", default_value=False, dtype=pl.Boolean) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) + + tracks.graph_solution.add_node_attr_key( + "label", default_value=None, dtype=pl.Object + ) + tracks.graph_solution.add_node_attr_key( + "confidence", default_value=0, dtype=pl.Float64 + ) + tracks.graph_solution.add_node_attr_key( + "validated", default_value=False, dtype=pl.Boolean + ) # Add custom attributes to update custom_attrs = {"label": "my_label", "confidence": 0.95, "validated": True} @@ -43,10 +49,14 @@ def test_user_update_node_attrs(self, get_tracks, ndim, with_seg): def test_user_update_existing_attrs(self, get_tracks, ndim, with_seg): """Test updating attributes that already exist.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) - tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object) - tracks.graph.add_node_attr_key("score", default_value=None, dtype=pl.Float64) + tracks.graph_solution.add_node_attr_key( + "label", default_value=None, dtype=pl.Object + ) + tracks.graph_solution.add_node_attr_key( + "score", default_value=None, dtype=pl.Float64 + ) # Set initial custom attributes tracks._set_node_attr(1, "label", "old_label") @@ -67,7 +77,7 @@ def test_user_update_existing_attrs(self, get_tracks, ndim, with_seg): def test_protected_time_attr(self, get_tracks, ndim, with_seg): """Test that time attribute cannot be updated.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) time_key = tracks.features.time_key with pytest.raises(ValueError, match="Cannot update attribute"): @@ -75,14 +85,14 @@ def test_protected_time_attr(self, get_tracks, ndim, with_seg): def test_protected_track_id_attr(self, get_tracks, ndim, with_seg): """Test that track_id attribute cannot be updated.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) with pytest.raises(ValueError, match="Cannot update attribute"): UserUpdateNodeAttrs(tracks, node=1, attrs={"track_id": 999}) def test_protected_area_attr(self, get_tracks, ndim, with_seg): """Test that area attribute (managed by annotator) cannot be updated.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) if with_seg: # area only exists when segmentation is present with pytest.raises(ValueError, match="Cannot update attribute"): @@ -90,7 +100,7 @@ def test_protected_area_attr(self, get_tracks, ndim, with_seg): def test_protected_pos_attr(self, get_tracks, ndim, with_seg): """Test that position attribute (managed by annotator) cannot be updated.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) if with_seg: # pos is managed by RegionpropsAnnotator when seg exists with pytest.raises(ValueError, match="Cannot update attribute"): @@ -98,8 +108,10 @@ def test_protected_pos_attr(self, get_tracks, ndim, with_seg): def test_action_history_integration(self, get_tracks, ndim, with_seg): """Test that action integrates properly with action history.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) + tracks.graph_solution.add_node_attr_key( + "label", default_value=None, dtype=pl.Object + ) # Initially empty assert len(tracks.action_history.undo_stack) == 0 diff --git a/tests/user_actions/test_user_update_nodes_attrs.py b/tests/user_actions/test_user_update_nodes_attrs.py index bb853887..b327b5a0 100644 --- a/tests/user_actions/test_user_update_nodes_attrs.py +++ b/tests/user_actions/test_user_update_nodes_attrs.py @@ -10,10 +10,14 @@ class TestUserUpdateNodesAttrs: def test_user_update_nodes_attrs(self, get_tracks, ndim, with_seg): """Test basic bulk node attribute update functionality.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) - tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object) - tracks.graph.add_node_attr_key("confidence", default_value=0, dtype=pl.Float64) + tracks.graph_solution.add_node_attr_key( + "label", default_value=None, dtype=pl.Object + ) + tracks.graph_solution.add_node_attr_key( + "confidence", default_value=0, dtype=pl.Float64 + ) attrs = {"label": ["my_label", "my_label"], "confidence": [0.95, 0.95]} UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs=attrs) @@ -24,8 +28,10 @@ def test_user_update_nodes_attrs(self, get_tracks, ndim, with_seg): def test_single_history_entry(self, get_tracks, ndim, with_seg): """Updating multiple nodes creates only one history entry.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) + tracks.graph_solution.add_node_attr_key( + "label", default_value=None, dtype=pl.Object + ) action = UserUpdateNodesAttrs( tracks, nodes=[1, 2, 3], attrs={"label": ["x", "x", "x"]} @@ -36,8 +42,10 @@ def test_single_history_entry(self, get_tracks, ndim, with_seg): def test_undo_redo(self, get_tracks, ndim, with_seg): """Undo restores all nodes' attrs to defaults; redo re-applies them.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("score", default_value=0, dtype=pl.Float64) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) + tracks.graph_solution.add_node_attr_key( + "score", default_value=0, dtype=pl.Float64 + ) action = UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": [0.9, 0.9]}) @@ -56,8 +64,10 @@ def test_undo_redo(self, get_tracks, ndim, with_seg): def test_per_node_attrs(self, get_tracks, ndim, with_seg): """Test bulk update with different values per node.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - tracks.graph.add_node_attr_key("score", default_value=0, dtype=pl.Float64) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) + tracks.graph_solution.add_node_attr_key( + "score", default_value=0, dtype=pl.Float64 + ) UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": [0.1, 0.9]}) @@ -66,10 +76,10 @@ def test_per_node_attrs(self, get_tracks, ndim, with_seg): def test_array_attr(self, get_tracks, ndim, with_seg): """Test bulk update with array-valued attributes.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) spatial_dims = ndim - 1 - tracks.graph.add_node_attr_key( + tracks.graph_solution.add_node_attr_key( "custom_pos", default_value=None, dtype=pl.Array(pl.Float64, spatial_dims) ) @@ -81,21 +91,21 @@ def test_array_attr(self, get_tracks, ndim, with_seg): def test_values_not_list_raises(self, get_tracks, ndim, with_seg): """Non-list values raise ValueError.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) with pytest.raises(ValueError, match="must be a list"): UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": 0.9}) def test_values_length_mismatch_raises(self, get_tracks, ndim, with_seg): """List length not matching nodes length raises ValueError.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) with pytest.raises(ValueError, match="length"): UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": [0.1]}) def test_protected_attr_raises(self, get_tracks, ndim, with_seg): """Passing a protected attribute raises ValueError.""" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True) time_key = tracks.features.time_key with pytest.raises(ValueError, match="Cannot update attribute"): diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py index d2d67a12..1a9a2f3b 100644 --- a/tests/user_actions/test_user_update_segmentation.py +++ b/tests/user_actions/test_user_update_segmentation.py @@ -26,7 +26,7 @@ def pixels_equal_mask(self, pixels, tracks, node_id): ) def test_user_update_seg_smaller(self, get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node_id = 3 edge = (1, 3) @@ -51,14 +51,14 @@ def test_user_update_seg_smaller(self, get_tracks, ndim): updated_pixels=[(pixels_to_remove, node_id)], current_track_id=1, ) - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert self.pixels_equal_mask(remaining_pixels, tracks, node_id) assert tracks.get_position(node_id) == new_position assert tracks.get_node_attr(node_id, "area") == 1 assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(0.0, abs=0.01) inverse = action.inverse() - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert self.pixels_equal_mask(orig_pixels, tracks, node_id) assert tracks.get_position(node_id) == orig_position assert tracks.get_node_attr(node_id, "area") == orig_area @@ -71,7 +71,7 @@ def test_user_update_seg_smaller(self, get_tracks, ndim): assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(0.0, abs=0.01) def test_user_update_seg_bigger(self, get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node_id = 3 edge = (1, 3) @@ -95,26 +95,26 @@ def test_user_update_seg_bigger(self, get_tracks, ndim): action = UserUpdateSegmentation( tracks, new_value=3, updated_pixels=[(pixels_to_add, 0)], current_track_id=1 ) - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert self.pixels_equal_mask(all_pixels, tracks, node_id) assert tracks.get_node_attr(node_id, "area") == orig_area + 1 assert tracks.get_edge_attr(edge, iou_key) != orig_iou inverse = action.inverse() - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert self.pixels_equal_mask(orig_pixels, tracks, node_id) assert tracks.get_position(node_id) == orig_position assert tracks.get_node_attr(node_id, "area") == orig_area assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(orig_iou, abs=0.01) inverse.inverse() - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert self.pixels_equal_mask(all_pixels, tracks, node_id) assert tracks.get_node_attr(node_id, "area") == orig_area + 1 assert tracks.get_edge_attr(edge, iou_key) != orig_iou def test_invalid_action_with_segmentation(self, get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node_id = 1 # Paint on top of node 1 with track id 3: because of the downstream division, this @@ -157,12 +157,12 @@ def test_invalid_action_with_segmentation(self, get_tracks, ndim): # assert that the segmentation now has the new value assert np.asarray(tracks.segmentation[t, y, x]) == new_value - assert tracks.graph.has_node(new_value) + assert tracks.graph_solution.has_node(new_value) assert len(update_seg_action.actions) == 2 # one for adding a node, # and one for updating existing node 1 def test_user_erase_seg(self, get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node_id = 3 edge = (1, 3) @@ -182,24 +182,24 @@ def test_user_erase_seg(self, get_tracks, ndim): updated_pixels=[(pixels_to_remove, node_id)], current_track_id=1, ) - assert not tracks.graph.has_node(node_id) + assert not tracks.graph_solution.has_node(node_id) inverse = action.inverse() - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) self.pixels_equal_mask(orig_pixels, tracks, node_id) assert tracks.get_position(node_id) == orig_position assert tracks.get_node_attr(node_id, "area") == orig_area assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(orig_iou, abs=0.01) inverse.inverse() - assert not tracks.graph.has_node(node_id) + assert not tracks.graph_solution.has_node(node_id) def test_user_erase_seg_history_size(self, get_tracks, ndim): """An erase via UserUpdateSegmentation must add exactly one history entry. Regression test for a bug where the nested UserDeleteNode also registered itself, leaving two entries per fill and corrupting undo behavior.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) node_id = 6 pixels = td_mask_to_pixels( tracks.get_mask(node_id), tracks.get_time(node_id), ndim=tracks.ndim @@ -217,7 +217,7 @@ def test_user_two_erases_then_two_undos(self, get_tracks, ndim): tracks.action_history.undo(). Reproduces bug_paint_undo: the second undo crashed because the buggy history had a duplicate UserDeleteNode entry that tried to re-add an already-restored node.""" - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) pixels_5 = td_mask_to_pixels( tracks.get_mask(5), tracks.get_time(5), ndim=tracks.ndim ) @@ -228,23 +228,23 @@ def test_user_two_erases_then_two_undos(self, get_tracks, ndim): UserUpdateSegmentation( tracks, new_value=0, updated_pixels=[(pixels_5, 5)], current_track_id=1 ) - assert not tracks.graph.has_node(5) + assert not tracks.graph_solution.has_node(5) UserUpdateSegmentation( tracks, new_value=0, updated_pixels=[(pixels_6, 6)], current_track_id=1 ) - assert not tracks.graph.has_node(6) + assert not tracks.graph_solution.has_node(6) assert tracks.action_history.undo() is True - assert tracks.graph.has_node(6) - assert not tracks.graph.has_node(5) + assert tracks.graph_solution.has_node(6) + assert not tracks.graph_solution.has_node(5) assert tracks.action_history.undo() is True - assert tracks.graph.has_node(5) - assert tracks.graph.has_node(6) + assert tracks.graph_solution.has_node(5) + assert tracks.graph_solution.has_node(6) def test_user_add_seg(self, get_tracks, ndim): - tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) + tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True) # draw a new node just like node 6 but in time 3 (instead of 4) old_node_id = 6 node_id = 7 @@ -260,7 +260,7 @@ def test_user_add_seg(self, get_tracks, ndim): position = tracks.get_position(old_node_id) area = tracks.get_node_attr(old_node_id, "area") - assert not tracks.graph.has_node(node_id) + assert not tracks.graph_solution.has_node(node_id) assert np.sum(tracks.segmentation == node_id) == 0 action = UserUpdateSegmentation( @@ -270,22 +270,22 @@ def test_user_add_seg(self, get_tracks, ndim): current_track_id=10, ) assert np.sum(np.asarray(tracks.segmentation) == node_id) == len(pixels_to_add[0]) - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert tracks.get_position(node_id) == position assert tracks.get_node_attr(node_id, "area") == area assert tracks.get_track_id(node_id) == 10 inverse = action.inverse() - assert not tracks.graph.has_node(node_id) + assert not tracks.graph_solution.has_node(node_id) inverse.inverse() - assert tracks.graph.has_node(node_id) + assert tracks.graph_solution.has_node(node_id) assert tracks.get_position(node_id) == position assert tracks.get_node_attr(node_id, "area") == area assert tracks.get_track_id(node_id) == 10 def test_missing_seg(get_tracks): - tracks = get_tracks(ndim=3, with_seg=False, is_solution=True) + tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True) with pytest.raises(ValueError, match="Cannot update non-existing segmentation"): UserUpdateSegmentation(tracks, 0, [], 1) diff --git a/tests/utils/test_tracksdata_utils.py b/tests/utils/test_tracksdata_utils.py index 32231579..e4eb0793 100644 --- a/tests/utils/test_tracksdata_utils.py +++ b/tests/utils/test_tracksdata_utils.py @@ -6,7 +6,7 @@ import pytest from funtracks.utils.tracksdata_utils import ( - create_empty_graphview_graph, + create_empty_graph, pixels_to_td_mask, td_mask_to_pixels, ) @@ -142,7 +142,7 @@ def test_pixels_coordinate_offset(ndim): def test_memory_graph_survives_thread_boundary(): - """A GraphView created in a worker thread must remain accessible from the main thread. + """A base graph created in a worker thread must stay accessible from the main thread. Regression test: nodes_from_segmentation previously used database=':memory:', which caused 'no such table: Metadata' when the graph crossed a thread boundary @@ -152,7 +152,7 @@ def test_memory_graph_survives_thread_boundary(): result = {} def worker(): - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=["pos"], ndim=3, ) @@ -165,22 +165,23 @@ def worker(): graph = result["graph"] - # This calls graph.metadata internally via BaseGraph.from_other(). + # create_empty_graph now returns the base graph; build a view to exercise + # detach(), which calls graph.metadata internally via BaseGraph.from_other(). # With :memory: + default connection pool it opens a new empty DB → crash. - detached = graph.detach() + detached = graph.filter().subgraph().detach() assert detached.num_nodes() == 1 -def test_create_empty_graphview_graph_with_solution_attr(): +def test_create_empty_graph_with_solution_attr(): """Test that passing solution as a node/edge attribute does not raise. - Regression test: create_empty_graphview_graph unconditionally added the + Regression test: create_empty_graph unconditionally added the solution attribute at the end, even when it was already added via the node_attributes / edge_attributes loop, causing a ValueError. """ # Should not raise ValueError even though solution is listed explicitly - graph = create_empty_graphview_graph( + graph = create_empty_graph( node_attributes=["solution"], edge_attributes=["solution"], node_default_values=[True],