diff --git a/docs/features.md b/docs/features.md
index 8022a26e..ddfeb217 100644
--- a/docs/features.md
+++ b/docs/features.md
@@ -54,7 +54,7 @@ An abstract base class for components that compute and update features on a grap
|-----------|---------|--------------|-------------------|---------------|
| **RegionpropsAnnotator** | Extracts node features from segmentation using scikit-image's `regionprops` | `segmentation` must not be `None` | `pos`, `area`, `ellipse_axis_radii`, `circularity`, `perimeter` | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.RegionpropsAnnotator) |
| **EdgeAnnotator** | Computes edge features based on segmentation overlap between consecutive time frames | `segmentation` must not be `None` | `iou` (Intersection over Union) | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.EdgeAnnotator) |
-| **TrackAnnotator** | Computes tracklet and lineage IDs for SolutionTracks | Must be used with `SolutionTracks` (binary tree structure) | `tracklet_id`, `lineage_id` | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.TrackAnnotator) |
+| **TrackAnnotator** | Computes tracklet and lineage IDs | Requires `tracks.features.tracklet_key` to be set (i.e. a tracking solution with a binary-tree structure) | `tracklet_id`, `lineage_id` | [📚 API](../reference/funtracks/annotators/#funtracks.annotators.TrackAnnotator) |
### 5. AnnotatorRegistry
@@ -149,7 +149,8 @@ classDiagram
}
class Tracks {
- +graph: td.graph.GraphView
+ +graph_full: td.graph.BaseGraph
+ +graph_solution: td.graph.GraphView
+segmentation: ndarray|None
+features: FeatureDict
+annotators: AnnotatorRegistry
@@ -199,10 +200,10 @@ provided:
```python
# Uses default: tracklet_key="tracklet_id"
-tracks = SolutionTracks(graph=graph)
+tracks = Tracks(graph=graph, tracklet_attr="tracklet_id")
# Uses custom attribute name
-tracks = SolutionTracks(graph=graph, tracklet_attr="my_track_col")
+tracks = Tracks(graph=graph, tracklet_attr="my_track_col")
```
Custom attribute names are also supported through `FeatureDict`:
@@ -213,7 +214,7 @@ fd = FeatureDict(
tracklet_key="my_track_col",
lineage_key="my_lineage",
)
-tracks = SolutionTracks(graph=graph, features=fd)
+tracks = Tracks(graph=graph, features=fd)
```
When a `FeatureDict` is provided, the `tracklet_attr`/`pos_attr`/`time_attr` arguments
@@ -229,7 +230,7 @@ TrackAnnotator(tracks, tracklet_key="my_track", lineage_key="my_lineage")
RegionpropsAnnotator(tracks, pos_key="coordinates")
```
-When constructing `Tracks` or `SolutionTracks` directly, you have full control over
+When constructing `Tracks` directly, you have full control over
which attribute names are used.
**Through the import path** (`tracks_from_df`, `import_from_geff`), computed features
@@ -286,11 +287,11 @@ tracks = tracks_from_df(df, segmentation=seg)
**Scenario 2: Creating tracks from raw segmentation**
```python
-from funtracks.utils import create_empty_graphview_graph
+from funtracks.utils import create_empty_graph
from funtracks.data_model import Tracks
# Create empty graph and add nodes
-graph = create_empty_graphview_graph()
+graph = create_empty_graph()
graph.add_node(index=1, attrs={"t": 0})
tracks = Tracks(graph, segmentation=seg)
# Auto-detection: pos, area don't exist → compute them from segmentation
@@ -353,7 +354,7 @@ tracks.disable_features(["area"])
def compute(self, feature_keys=None):
# Compute feature values in bulk
if "custom" in self.features:
- for node in self.tracks.graph.node_ids():
+ for node in self.tracks.graph_solution.node_ids():
value = self._compute_custom(node)
self.tracks[node]["custom"] = value
diff --git a/docs/import-flow.md b/docs/import-flow.md
index 1cec7319..5710248b 100644
--- a/docs/import-flow.md
+++ b/docs/import-flow.md
@@ -25,7 +25,7 @@ graph LR
Validate["validate
check graph structure"]
ConstructGraph["construct_graph
build tracksdata graph"]
HandleSeg["handle_segmentation
load & relabel if needed"]
- CreateTracks[Create SolutionTracks]
+ CreateTracks[Create Tracks]
EnableFeatures["enable_features
register & compute"]
ValidateMap --> LoadSource
diff --git a/pyproject.toml b/pyproject.toml
index 0c5eb7a8..581810d9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ dependencies =[
"dask>=2025.5.0",
"pandas>=2.3.3",
"zarr>=2.18,<4",
- "tracksdata[spatial]==0.1.0rc4",
+ "tracksdata[spatial]>=0.1.0rc6",
"tqdm>=4.66.1",
# zarr 2.x's util.py imports cbuffer_sizes/cbuffer_metainfo from
# numcodecs.blosc, which numcodecs >= 0.16 removed. Pin numcodecs per
@@ -107,8 +107,6 @@ unfixable = [
]
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D"] # no docstrings in tests
-"src/funtracks/data_model/tracks.py" = ["D"] # Remove this when refactoring tracks
-"src/funtracks/data_model/solution_tracks.py" = ["D"] # Remove this when refactoring tracks
"__init__.py" = ["F401"] # unused imports allowed in __init__.py
# https://docs.astral.sh/ruff/formatter/
diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py
index dda79e80..3a9fb28a 100644
--- a/src/funtracks/actions/add_delete_edge.py
+++ b/src/funtracks/actions/add_delete_edge.py
@@ -45,35 +45,51 @@ def _apply(self) -> None:
Raises:
ValueError if an endpoint of the edge does not exist
"""
- # Check that both endpoints exist before computing edge attributes
+ # Check that both endpoints exist in the solution before adding the edge
for node in self.edge:
- if not self.tracks.graph.has_node(node):
+ if not self.tracks.graph_solution.has_node(node):
raise ValueError(
- f"Cannot add edge {self.edge}: endpoint {node} not in graph yet"
+ f"Cannot add edge {self.edge}: endpoint {node} not in solution yet"
)
- if self.tracks.graph.has_edge(*self.edge):
- raise ValueError(f"Edge {self.edge} already exists in the graph")
-
- attrs = dict(self.attributes)
-
- # Fill in missing edge attributes with schema defaults (includes
- # solution and any other registered edge attrs).
- schemas = self.tracks.graph._edge_attr_schemas()
- for attr in self.tracks.graph.edge_attr_keys():
- if attr not in attrs:
- # An edge added to a Tracks graph is by definition part of the
- # solution, so default `solution` to True rather than the schema
- # default, which can be wrong (e.g. Float64/0.0) on graphs loaded
- # from geff. An explicit caller-provided value still wins.
- attrs[attr] = True if attr == "solution" else schemas[attr].default_value
-
- # Create edge attributes for this specific edge
- self.tracks.graph.add_edge(
- source_id=self.edge[0],
- target_id=self.edge[1],
- attrs=attrs,
- )
+ if self.tracks.graph_solution.has_edge(*self.edge):
+ raise ValueError(f"Edge {self.edge} already exists in the solution")
+
+ if self.tracks.graph_full.has_edge(*self.edge):
+ # Revive a soft-deleted edge (already present in the full graph as a
+ # candidate): flip solution=True, apply any caller-provided attributes (so
+ # revive matches the add-new branch), and re-surface it in the solution view.
+ edge_id = self.tracks.graph_full.edge_id(self.edge[0], self.edge[1])
+ # Values are wrapped in single-element lists because update_edge_attrs
+ # reads a bare list value (e.g. a vector feature) as one-value-per-edge.
+ revive_attrs = {k: [v] for k, v in self.attributes.items() if k != "solution"}
+ revive_attrs["solution"] = [True]
+ self.tracks.graph_full.update_edge_attrs(
+ attrs=revive_attrs, edge_ids=[edge_id]
+ )
+ self.tracks.graph_solution.add_edge_to_view(self.edge[0], self.edge[1])
+ else:
+ attrs = dict(self.attributes)
+
+ # Fill in missing edge attributes with schema defaults (includes
+ # solution and any other registered edge attrs).
+ schemas = self.tracks.graph_solution._edge_attr_schemas()
+ for attr in self.tracks.graph_solution.edge_attr_keys():
+ if attr not in attrs:
+ # An edge added to a Tracks graph is by definition part of the
+ # solution, so default `solution` to True rather than the schema
+ # default, which can be wrong (e.g. Float64/0.0) on graphs loaded
+ # from geff. An explicit caller-provided value still wins.
+ attrs[attr] = (
+ True if attr == "solution" else schemas[attr].default_value
+ )
+
+ # Create edge attributes for this specific edge
+ self.tracks.graph_solution.add_edge(
+ source_id=self.edge[0],
+ target_id=self.edge[1],
+ attrs=attrs,
+ )
# Notify annotators to recompute features (will overwrite computed ones)
self.tracks.notify_annotators(self)
@@ -93,7 +109,7 @@ def __init__(self, tracks: Tracks, edge: Edge):
"""
super().__init__(tracks)
self.edge = edge
- if not self.tracks.graph.has_edge(*self.edge):
+ if not self.tracks.graph_solution.has_edge(*self.edge):
raise ValueError(f"Edge {self.edge} not in the graph, and cannot be removed")
# Save all edge feature values from the features dict
@@ -110,5 +126,12 @@ def inverse(self) -> BasicAction:
return AddEdge(self.tracks, self.edge, attributes=self.attributes)
def _apply(self) -> None:
- self.tracks.graph.remove_edge(*self.edge)
+ """Soft-delete the edge: flag solution=False in the full graph and remove it
+ from the solution view only. The edge is preserved in graph_full (as a
+ candidate) so the delete is reversible."""
+ edge_id = self.tracks.graph_full.edge_id(self.edge[0], self.edge[1])
+ self.tracks.graph_full.update_edge_attrs(
+ attrs={"solution": False}, edge_ids=[edge_id]
+ )
+ self.tracks.graph_solution.remove_edge_from_view(self.edge[0], self.edge[1])
self.tracks.notify_annotators(self)
diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py
index 83f87bfc..ed3df040 100644
--- a/src/funtracks/actions/add_delete_node.py
+++ b/src/funtracks/actions/add_delete_node.py
@@ -9,8 +9,7 @@
if TYPE_CHECKING:
from typing import Any
- from funtracks.data_model.solution_tracks import SolutionTracks
- from funtracks.data_model.tracks import Node
+ from funtracks.data_model.tracks import Node, Tracks
class AddNode(BasicAction):
@@ -22,7 +21,7 @@ class AddNode(BasicAction):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
node: Node,
attributes: dict[str, Any],
):
@@ -40,7 +39,7 @@ def __init__(
ValueError: If neither position nor a mask feature is in attributes.
"""
super().__init__(tracks)
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
self.node = int(node)
# Get keys from tracks features
@@ -76,14 +75,34 @@ def inverse(self) -> BasicAction:
return DeleteNode(self.tracks, self.node)
def _apply(self) -> None:
- """Add the node with all attributes from self.attributes."""
- attrs = dict(self.attributes)
- # A node added to a Tracks graph is by definition part of the solution,
- # so default `solution` to True rather than the column schema default,
- # which can be wrong on graphs loaded from geff. An explicit
- # caller-provided value still wins.
- attrs.setdefault("solution", True)
- self.tracks.graph.add_node(attrs=attrs, index=self.node, validate_keys=False)
+ """Add the node, or revive a soft-deleted one.
+
+ If the node still exists in the full graph (it was soft-deleted, so its
+ topology was preserved), revive it: flip solution=True and re-surface it in the
+ solution view. Otherwise add a genuinely new node.
+ """
+ if self.tracks.graph_full.has_node(self.node):
+ # Revive: same node id, topology preserved in graph_full. Flip it back into
+ # the solution, apply any caller-provided attributes (so revive matches the
+ # add-new branch), and re-surface it in the view in place (incident edges
+ # are revived separately by AddEdge).
+ # Values are wrapped in single-element lists because update_node_attrs
+ # reads a bare list value (pos, bbox, mask) as one-value-per-node.
+ revive_attrs = {k: [v] for k, v in self.attributes.items() if k != "solution"}
+ revive_attrs["solution"] = [True]
+ self.tracks.graph_full.update_node_attrs(
+ attrs=revive_attrs, node_ids=[self.node]
+ )
+ self.tracks.graph_solution.add_node_to_view(self.node)
+ else:
+ # Genuinely new node — default `solution` to True rather than the
+ # column schema default, which can be wrong (e.g. Float64/0.0) on
+ # graphs loaded from geff. An explicit caller-provided value still wins.
+ attrs = dict(self.attributes)
+ attrs.setdefault("solution", True)
+ self.tracks.graph_solution.add_node(
+ attrs=attrs, index=self.node, validate_keys=False
+ )
# Always notify annotators - they will check their own preconditions
self.tracks.notify_annotators(self)
@@ -93,15 +112,24 @@ class DeleteNode(BasicAction):
"""Action of deleting an existing node.
Saves all node feature values so the action can be inverted.
+
+ Low-level action — not meant to be used directly. It soft-deletes only the
+ node itself (incident edges are dropped from the view by
+ ``remove_node_from_view`` but keep ``solution=True`` in ``graph_full``).
+ Managing the incident edges' solution flags is the responsibility of the
+ enclosing user action (``UserDeleteNode``), which soft-deletes each incident
+ edge with its own ``DeleteEdge`` first. Applying a bare ``DeleteNode`` to a
+ node that still has in-solution edges therefore leaves ``graph_full``'s edge
+ flags inconsistent with ``graph_solution`` — always go through the user action.
"""
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
node: Node,
):
super().__init__(tracks)
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
self.node = int(node)
# Save all node feature values from the features dict
@@ -120,6 +148,12 @@ def inverse(self) -> BasicAction:
return AddNode(self.tracks, self.node, self.attributes)
def _apply(self) -> None:
- """Remove the node from the graph."""
- self.tracks.graph.remove_node(self.node)
+ """Soft-delete the node: flag solution=False in the full graph and remove it
+ from the solution view only. The node (and its topology) is preserved in
+ graph_full so the delete is reversible and the node remains a candidate.
+ """
+ self.tracks.graph_full.update_node_attrs(
+ attrs={"solution": False}, node_ids=[self.node]
+ )
+ self.tracks.graph_solution.remove_node_from_view(self.node)
self.tracks.notify_annotators(self)
diff --git a/src/funtracks/actions/update_segmentation.py b/src/funtracks/actions/update_segmentation.py
index 4d83a158..ffcc6b8d 100644
--- a/src/funtracks/actions/update_segmentation.py
+++ b/src/funtracks/actions/update_segmentation.py
@@ -61,13 +61,13 @@ def _apply(self) -> None:
if value == 0:
# val=0 means deleting (part of) the mask
- mask_old = self.tracks.graph.nodes[self.node][self.mask_key]
+ mask_old = self.tracks.graph_full.nodes[self.node][self.mask_key]
mask_subtracted = mask_old.__isub__(mask_new)
self.tracks.update_mask(self.node, mask_subtracted, mask_key=self.mask_key)
- elif self.tracks.graph.has_node(value):
+ elif self.tracks.graph_full.has_node(value):
# if node already exists:
- mask_old = self.tracks.graph.nodes[value][self.mask_key]
+ mask_old = self.tracks.graph_full.nodes[value][self.mask_key]
mask_combined = mask_old.__or__(mask_new)
self.tracks.update_mask(value, mask_combined, mask_key=self.mask_key)
diff --git a/src/funtracks/actions/update_track_id.py b/src/funtracks/actions/update_track_id.py
index a09db079..edaabedd 100644
--- a/src/funtracks/actions/update_track_id.py
+++ b/src/funtracks/actions/update_track_id.py
@@ -5,7 +5,7 @@
from ._base import BasicAction
if TYPE_CHECKING:
- from funtracks.data_model import SolutionTracks
+ from funtracks.data_model import Tracks
from funtracks.data_model.tracks import Node
@@ -31,13 +31,13 @@ class UpdateTrackIDs(BasicAction):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
start_node: Node,
tracklet_id: int | None = None,
lineage_id: int | None = None,
):
super().__init__(tracks)
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
self.start_node = start_node
# Capture old tracklet ID
diff --git a/src/funtracks/annotators/_edge_annotator.py b/src/funtracks/annotators/_edge_annotator.py
index b612f431..ee71b95d 100644
--- a/src/funtracks/annotators/_edge_annotator.py
+++ b/src/funtracks/annotators/_edge_annotator.py
@@ -82,14 +82,14 @@ def compute(self, feature_keys: list[str] | None = None) -> None:
# TODO: add skip edges
if self.iou_key in keys_to_compute:
nodes_by_frame = defaultdict(list)
- for n in self.tracks.graph.node_ids():
+ for n in self.graph.node_ids():
nodes_by_frame[self.tracks.get_time(n)].append(n)
for t in range(self.tracks.segmentation.shape[0] - 1):
nodes_in_t = nodes_by_frame[t]
edges = []
for node in nodes_in_t:
- for succ in self.tracks.graph.successors(node):
+ for succ in self.graph.successors(node):
edges.append((node, succ))
self._iou_update(edges)
@@ -105,8 +105,8 @@ def _iou_update(
"""
for edge in edges:
source, target = edge
- mask1 = self.tracks.graph.nodes[source]["mask"]
- mask2 = self.tracks.graph.nodes[target]["mask"]
+ mask1 = self.graph.nodes[source]["mask"]
+ mask2 = self.graph.nodes[target]["mask"]
iou = mask1.iou(mask2)
self.tracks._set_edge_attr(edge, self.iou_key, iou)
@@ -136,16 +136,16 @@ def update(self, action: BasicAction):
# Get all incident edges to the modified node
modified_node = action.node
edges_to_update = []
- for pred in self.tracks.graph.predecessors(modified_node):
+ for pred in self.graph.predecessors(modified_node):
edges_to_update.append((pred, modified_node))
- for succ in self.tracks.graph.successors(modified_node):
+ for succ in self.graph.successors(modified_node):
edges_to_update.append((modified_node, succ))
# Update IoU for each edge
for edge in edges_to_update:
source, target = edge
- mask1 = self.tracks.graph.nodes[source]["mask"]
- mask2 = self.tracks.graph.nodes[target]["mask"]
+ mask1 = self.graph.nodes[source]["mask"]
+ mask2 = self.graph.nodes[target]["mask"]
if mask1.mask.sum() == 0 or mask2.mask.sum() == 0:
empty_node = source if mask1.mask.sum() == 0 else target
frame = self.tracks.get_time(empty_node)
diff --git a/src/funtracks/annotators/_graph_annotator.py b/src/funtracks/annotators/_graph_annotator.py
index 7fb8583c..ba681508 100644
--- a/src/funtracks/annotators/_graph_annotator.py
+++ b/src/funtracks/annotators/_graph_annotator.py
@@ -35,7 +35,7 @@ def can_annotate(cls, tracks: Tracks) -> bool:
"""Check if this annotator can annotate the given tracks.
Subclasses should override this method to specify their requirements
- (e.g., segmentation, SolutionTracks, etc.).
+ (e.g., segmentation, Tracks, etc.).
Args:
tracks: The tracks to check compatibility with
@@ -52,6 +52,29 @@ def __init__(self, tracks: Tracks, features: dict[str, Feature]):
key: (feat, False) for key, feat in features.items()
}
+ @property
+ def graph(self):
+ """The graph this annotator iterates over and reads topology/masks from.
+
+ Defaults to the full graph. Detection features (`pos`, `area`, `iou`, ...) are
+ intrinsic to a node/edge — independent of solution membership — so they are
+ computed for *every* node/edge, including soft-deleted (`solution=False`)
+ candidates, keeping them valid for revive and ready for re-solving.
+ `TrackAnnotator` overrides this to `graph_solution`, since track ids
+ (`tracklet_id`, `lineage_id`) are derived from the solution topology. Root and
+ view share attribute storage, so writes made here are seen through the solution
+ view automatically.
+
+ KNOWN COST: computing over `graph_full` means `compute()`/`update()` run
+ regionprops / mask.iou for every candidate (solution=False) node/edge too. On
+ candidate graphs with many unselected detections (often 10-100x the solution
+ size) this is O(candidates) work that is mostly never read, since revive is rare.
+ This is a deliberate eagerness-for-readiness trade; if it becomes a bottleneck,
+ the lever is lazy compute-on-revive (default to graph_solution and compute a
+ candidate's intrinsic features only when it enters the solution).
+ """
+ return self.tracks.graph_full
+
def activate_features(self, keys: list[str]) -> None:
"""Activate computation of the given features in the annotation process.
@@ -106,7 +129,7 @@ def _filter_feature_keys(self, feature_keys: list[str] | None) -> list[str]:
def compute(self, feature_keys: list[str] | None = None) -> None:
"""Compute a set of features and add them to the tracks.
- This involves both updating the node/edge attributes on the tracks.graph
+ This involves both updating the node/edge attributes on `self.graph`
and adding the features to the FeatureDict, if necessary. This is distinct
from `update` to allow more efficient bulk computation of features.
@@ -123,8 +146,9 @@ def compute(self, feature_keys: list[str] | None = None) -> None:
def update(self, action: BasicAction) -> None:
"""Update a set of features based on the given action.
- This involves both updating the node or edge attributes on the tracks.graph
- and adding the features to the FeatureDict, if necessary. This is distinct
+ This involves both updating the node or edge attributes on `self.graph`
+ and adding the features to the FeatureDict, if
+ necessary. This is distinct
from `compute` to allow more efficient computation of features for single
elements.
diff --git a/src/funtracks/annotators/_regionprops_annotator.py b/src/funtracks/annotators/_regionprops_annotator.py
index 81ec42db..e35295a9 100644
--- a/src/funtracks/annotators/_regionprops_annotator.py
+++ b/src/funtracks/annotators/_regionprops_annotator.py
@@ -170,10 +170,10 @@ def compute(self, feature_keys: list[str] | None = None) -> None:
all_node_ids = []
all_values: dict[str, list] = {key: [] for key in keys_to_compute}
- for node_id in self.tracks.graph.node_ids():
- if not self.tracks.graph.has_node(node_id):
+ for node_id in self.graph.node_ids():
+ if not self.graph.has_node(node_id):
continue
- mask = self.tracks.graph.nodes[node_id]["mask"]
+ mask = self.graph.nodes[node_id]["mask"]
for region in regionprops_extended(mask, spacing=spacing):
all_node_ids.append(node_id)
for key in keys_to_compute:
@@ -203,7 +203,7 @@ def _regionprops_update(
spacing = None if self.tracks.scale is None else tuple(self.tracks.scale[1:])
for region in regionprops_extended(mask, spacing=spacing):
# Skip labels that aren't nodes in the graph (e.g., unselected detections)
- if not self.tracks.graph.has_node(node_id):
+ if not self.graph.has_node(node_id):
continue
for key in feature_keys:
value = getattr(region, self.regionprops_names[key])
@@ -240,7 +240,7 @@ def update(self, action: BasicAction):
time = self.tracks.get_time(node)
- if self.tracks.graph.nodes[node]["mask"].mask.sum() == 0:
+ if self.graph.nodes[node]["mask"].mask.sum() == 0:
warnings.warn(
f"Cannot find label {node} in frame {time}: "
"updating regionprops values to None",
@@ -250,7 +250,7 @@ def update(self, action: BasicAction):
value = None
self.tracks._set_node_attr(node, key, value)
else:
- mask = self.tracks.graph.nodes[node]["mask"]
+ mask = self.graph.nodes[node]["mask"]
self._regionprops_update(node, mask, keys_to_compute)
def change_key(self, old_key: str, new_key: str) -> None:
diff --git a/src/funtracks/annotators/_track_annotator.py b/src/funtracks/annotators/_track_annotator.py
index d8c66e42..38c81c7a 100644
--- a/src/funtracks/annotators/_track_annotator.py
+++ b/src/funtracks/annotators/_track_annotator.py
@@ -8,7 +8,6 @@
import tracksdata as td
from funtracks.actions import AddNode, DeleteNode, UpdateTrackIDs
-from funtracks.data_model import SolutionTracks
from funtracks.features import LineageID, TrackletID
from ._graph_annotator import GraphAnnotator
@@ -17,6 +16,7 @@
from collections.abc import Iterable
from funtracks.actions import BasicAction
+ from funtracks.data_model import Tracks
from funtracks.features import Feature
@@ -25,9 +25,12 @@
class TrackAnnotator(GraphAnnotator):
- """A graph annotator to compute tracklet and lineage IDs for SolutionTracks only.
+ """A graph annotator that computes tracklet and lineage IDs on the solution view.
- Currently, updating the tracklet and lineage IDs is left to Actions.
+ Registered on every Tracks — track ids are a core feature, not a separate "type"
+ of tracks. It reads/iterates `graph_solution`; on an empty solution view it simply
+ computes nothing until nodes are added. Updating the ids after construction is left
+ to Actions.
Attributes:
tracklet_id_to_nodes (dict[int, list[int]]): A mapping from tracklet ids to
@@ -38,33 +41,20 @@ class TrackAnnotator(GraphAnnotator):
max_lineage_id (int): the maximum lineage id used in the tracks
Args:
- tracks (SolutionTracks): The tracks to be annotated. Must be a solution.
- tracklet_key (str | None, optional): A key that already holds the tracklet ids
- on the graph. If provided, must be there for every node and already hold
- valid tracklet ids. Defaults to None.
- lineage_key (str | None, optional): A key that already holds the lineage ids
- on the graph. If provided, must be there for every node and already hold
- valid lineage ids. Defaults to None.
-
-
- Raises:
- ValueError: if the provided Tracks are not SolutionTracks (not a binary lineage
- tree)
+ tracks (Tracks): The tracks to annotate.
+ tracklet_key (str | None, optional): The node attribute holding tracklet ids.
+ If the graph already holds valid ids under this key they are reused;
+ otherwise they are computed. Defaults to DEFAULT_TRACKLET_KEY.
+ lineage_key (str | None, optional): The node attribute holding lineage ids.
+ Same semantics as tracklet_key. Defaults to DEFAULT_LINEAGE_KEY.
"""
- @classmethod
- def can_annotate(cls, tracks) -> bool:
- """Check if this annotator can annotate the given tracks.
-
- Requires tracks to be a SolutionTracks instance.
-
- Args:
- tracks: The tracks to check compatibility with
-
- Returns:
- True if tracks is a SolutionTracks instance, False otherwise
- """
- return isinstance(tracks, SolutionTracks)
+ @property
+ def graph(self):
+ """Track ids (`tracklet_id`, `lineage_id`) are derived from the solution
+ topology, so this annotator reads/iterates the solution view, not the full
+ graph (overriding the base default of `graph_full`)."""
+ return self.tracks.graph_solution
@classmethod
def get_available_features(cls, ndim: int = 3) -> dict[str, Feature]:
@@ -87,14 +77,11 @@ def get_available_features(cls, ndim: int = 3) -> dict[str, Feature]:
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
tracklet_key: str | None = DEFAULT_TRACKLET_KEY,
lineage_key: str | None = DEFAULT_LINEAGE_KEY,
):
- if not isinstance(tracks, SolutionTracks):
- raise ValueError("Currently the TrackAnnotator only works on SolutionTracks")
-
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
self.tracklet_key = (
tracklet_key if tracklet_key is not None else DEFAULT_TRACKLET_KEY
)
@@ -112,13 +99,13 @@ def __init__(
self.max_lineage_id = 0
# Initialize tracklet bookkeeping if track IDs already exist in the graph
- if tracks.graph.num_nodes() > 0:
+ if tracks.graph_solution.num_nodes() > 0:
max_id, id_to_nodes = self._get_max_id_and_map(self.tracklet_key)
self.max_tracklet_id = max_id
self.tracklet_id_to_nodes = id_to_nodes
# Initialize lineage bookkeeping if lineage IDs already exist
- if lineage_key is not None and tracks.graph.num_nodes() > 0:
+ if lineage_key is not None and tracks.graph_solution.num_nodes() > 0:
max_id, id_to_nodes = self._get_max_id_and_map(self.lineage_key)
self.max_lineage_id = max_id
self.lineage_id_to_nodes = id_to_nodes
@@ -133,9 +120,9 @@ def _get_max_id_and_map(self, key: str) -> tuple[int, dict[int, list[int]]]:
tuple[int, dict[int, list[int]]]: The maximum id value, and a mapping from
ids to a list of nodes with that id.
"""
- if key not in self.tracks.graph.node_attr_keys():
+ if key not in self.graph.node_attr_keys():
return 0, {}
- df = self.tracks.graph.node_attrs(attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, key])
+ df = self.graph.node_attrs(attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, key])
id_to_nodes = defaultdict(list)
for node, _id in zip(df[td.DEFAULT_ATTR_KEYS.NODE_ID], df[key], strict=True):
if _id is None:
@@ -199,12 +186,12 @@ def _assign_lineage_ids(self) -> None:
attributes will be updated.
"""
- lineages_internal = rx.weakly_connected_components(self.tracks.graph.rx_graph)
+ lineages_internal = rx.weakly_connected_components(self.graph.rx_graph)
# Map each component's internal node indices to external ids in one batched
# call. node_ids() rebuilds the full external-id list on every call, so calling
# it per node (as before) was O(N^2).
lineages_external = [
- self.tracks.graph._map_to_external(list(lin)) for lin in lineages_internal
+ self.graph._map_to_external(list(lin)) for lin in lineages_internal
]
max_id, ids_to_nodes = self._assign_ids(lineages_external, self.lineage_key)
@@ -222,7 +209,7 @@ def _assign_tracklet_ids(self) -> None:
# slow because each remove_edge syncs the view's bidirectional edge maps,
# whereas rustworkx edge removal is in-memory. copy() preserves node indices,
# so components map back through the original graph's id mapping.
- rx_copy = self.tracks.graph.rx_graph.copy()
+ rx_copy = self.graph.rx_graph.copy()
for node in rx_copy.node_indices():
if rx_copy.out_degree(node) >= 2:
for _, daughter, _ in list(rx_copy.out_edges(node)):
@@ -231,18 +218,23 @@ def _assign_tracklet_ids(self) -> None:
track_id = 1
all_node_ids = []
all_track_ids = []
+ id_to_nodes = {}
for tracklet in rx.weakly_connected_components(rx_copy):
# Batched internal -> external mapping (see _assign_lineage_ids).
- node_ids_external = self.tracks.graph._map_to_external(list(tracklet))
+ node_ids_external = self.graph._map_to_external(list(tracklet))
all_node_ids.extend(node_ids_external)
all_track_ids.extend([track_id] * len(node_ids_external))
- self.tracklet_id_to_nodes[track_id] = node_ids_external
+ id_to_nodes[track_id] = node_ids_external
track_id += 1
if all_node_ids:
- self.tracks.graph.update_node_attrs(
+ self.graph.update_node_attrs(
attrs={self.tracks.features.tracklet_key: all_track_ids},
node_ids=all_node_ids,
)
+ # Replace the bookkeeping wholesale (like _assign_lineage_ids) so stale
+ # entries — e.g. a phantom tracklet -1 seeded at init from a graph whose
+ # tracklet column still held the -1 sentinel — don't survive a recompute.
+ self.tracklet_id_to_nodes = id_to_nodes
self.max_tracklet_id = track_id - 1
def update(self, action: BasicAction) -> None:
@@ -309,18 +301,18 @@ def _handle_update_track_ids(self, action: UpdateTrackIDs) -> None:
still_in_tracklet = False
# Continue to all successors
- next_nodes.extend(self.tracks.graph.successors(node))
+ next_nodes.extend(self.graph.successors(node))
curr_nodes = next_nodes
# Bulk-write all collected node attribute changes in one call each
if update_tracklet and tracklet_nodes:
- self.tracks.graph.update_node_attrs(
+ self.graph.update_node_attrs(
attrs={self.tracklet_key: [new_tracklet_id] * len(tracklet_nodes)},
node_ids=tracklet_nodes,
)
if update_lineage and lineage_nodes:
- self.tracks.graph.update_node_attrs(
+ self.graph.update_node_attrs(
attrs={self.lineage_key: [new_lineage_id] * len(lineage_nodes)},
node_ids=lineage_nodes,
)
diff --git a/src/funtracks/candidate_graph/compute_graph.py b/src/funtracks/candidate_graph/compute_graph.py
index a0efa378..e77ee461 100644
--- a/src/funtracks/candidate_graph/compute_graph.py
+++ b/src/funtracks/candidate_graph/compute_graph.py
@@ -15,7 +15,7 @@ def compute_graph_from_seg(
iou: bool = False,
scale: list[float] | None = None,
t_start: int = 0,
-) -> td.graph.GraphView:
+) -> td.graph.BaseGraph:
"""Construct a candidate graph from a segmentation array. Nodes are placed at the
centroid of each segmentation and edges are added for all nodes in adjacent frames
within max_edge_distance.
@@ -37,7 +37,7 @@ def compute_graph_from_seg(
time values. Defaults to 0.
Returns:
- td.graph.GraphView: A candidate graph that can be passed to the motile solver
+ td.graph.BaseGraph: A candidate graph that can be passed to the motile solver
"""
# add nodes (including mask and bbox in the same bulk_add_nodes call)
cand_graph, node_frame_dict = nodes_from_segmentation(
@@ -73,7 +73,7 @@ def compute_graph_from_points_list(
points_list: np.ndarray,
max_edge_distance: float,
scale: list[float] | None = None,
-) -> td.graph.GraphView:
+) -> td.graph.BaseGraph:
"""Construct a candidate graph from a points list.
Args:
@@ -88,7 +88,7 @@ def compute_graph_from_points_list(
isotropic.
Returns:
- td.graph.GraphView: A candidate graph that can be passed to the motile solver.
+ td.graph.BaseGraph: A candidate graph that can be passed to the motile solver.
"""
# add nodes
cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale)
diff --git a/src/funtracks/candidate_graph/iou.py b/src/funtracks/candidate_graph/iou.py
index 0ca2d315..97dd6e71 100644
--- a/src/funtracks/candidate_graph/iou.py
+++ b/src/funtracks/candidate_graph/iou.py
@@ -80,7 +80,7 @@ def _get_iou_dict(segmentation, multiseg=False) -> dict[int, dict[int, float]]:
def add_iou(
- cand_graph: td.graph.GraphView,
+ cand_graph: td.graph.BaseGraph,
segmentation: np.ndarray,
node_frame_dict: dict[int, list[int]] | None = None,
multiseg=False,
@@ -92,7 +92,7 @@ def add_iou(
add IOU to an existing graph after the fact.
Args:
- cand_graph (td.graph.GraphView): Candidate graph with nodes and edges already
+ cand_graph (td.graph.BaseGraph): Candidate graph with nodes and edges already
populated.
segmentation (np.ndarray): segmentation that was used to create cand_graph.
Has shape ([h], t, [z], y, x), where h is the number of hypotheses if
diff --git a/src/funtracks/candidate_graph/utils.py b/src/funtracks/candidate_graph/utils.py
index 68948515..f1fee38c 100644
--- a/src/funtracks/candidate_graph/utils.py
+++ b/src/funtracks/candidate_graph/utils.py
@@ -9,7 +9,7 @@
from tqdm import tqdm
from tracksdata.nodes import Mask
-from ..utils.tracksdata_utils import create_empty_graphview_graph
+from ..utils.tracksdata_utils import create_empty_graph
logger = logging.getLogger(__name__)
@@ -19,7 +19,7 @@ def nodes_from_segmentation(
scale: list[float] | None = None,
mask: bool = True,
t_start: int = 0,
-) -> tuple[td.graph.GraphView, dict[int, list[Any]]]:
+) -> tuple[td.graph.BaseGraph, dict[int, list[Any]]]:
"""Extract candidate nodes from a segmentation. Returns a tracksdata graph
with only nodes, and also a dictionary from frames to node_ids for
efficient edge adding.
@@ -50,7 +50,7 @@ def nodes_from_segmentation(
time values. Defaults to 0.
Returns:
- tuple[td.graph.GraphView, dict[int, list[Any]]]: A candidate graph with only
+ tuple[td.graph.BaseGraph, dict[int, list[Any]]]: A candidate graph with only
nodes, and a mapping from time frames to node ids.
"""
logger.debug("Extracting nodes from segmentation")
@@ -66,7 +66,7 @@ def nodes_from_segmentation(
)
node_attributes = ["pos", "area", "mask", "bbox"] if mask else ["pos", "area"]
- cand_graph = create_empty_graphview_graph(
+ cand_graph = create_empty_graph(
node_attributes=node_attributes,
position_attrs=["pos"],
ndim=segmentation.ndim,
@@ -115,7 +115,7 @@ def nodes_from_segmentation(
def nodes_from_points_list(
points_list: np.ndarray,
scale: list[float] | None = None,
-) -> tuple[td.graph.GraphView, dict[int, list[Any]]]:
+) -> tuple[td.graph.BaseGraph, dict[int, list[Any]]]:
"""Extract candidate nodes from a list of points. Uses the index of the
point in the list as its unique id.
Returns a tracksdata graph with only nodes, and also a dictionary from frames to
@@ -130,7 +130,7 @@ def nodes_from_points_list(
implies the data is isotropic.
Returns:
- tuple[td.graph.GraphView, dict[int, list[Any]]]: A candidate graph with only
+ tuple[td.graph.BaseGraph, dict[int, list[Any]]]: A candidate graph with only
nodes, and a mapping from time frames to node ids.
"""
logger.info("Extracting nodes from points list")
@@ -143,7 +143,7 @@ def nodes_from_points_list(
)
points_list = points_list * np.array(scale)
- cand_graph = create_empty_graphview_graph(
+ cand_graph = create_empty_graph(
node_attributes=["pos"],
position_attrs=["pos"],
ndim=ndim,
@@ -168,11 +168,11 @@ def nodes_from_points_list(
return cand_graph, node_frame_dict
-def _compute_node_frame_dict(cand_graph: td.graph.GraphView) -> dict[int, list[Any]]:
+def _compute_node_frame_dict(cand_graph: td.graph.BaseGraph) -> dict[int, list[Any]]:
"""Compute dictionary from time frames to node ids for candidate graph.
Args:
- cand_graph (td.graph.GraphView): A tracksdata graph
+ cand_graph (td.graph.BaseGraph): A tracksdata graph
Returns:
dict[int, list[Any]]: A mapping from time frames to lists of node ids.
@@ -187,12 +187,12 @@ def _compute_node_frame_dict(cand_graph: td.graph.GraphView) -> dict[int, list[A
return node_frame_dict
-def create_kdtree(cand_graph: td.graph.GraphView, node_ids: list[Any]) -> KDTree:
+def create_kdtree(cand_graph: td.graph.BaseGraph, node_ids: list[Any]) -> KDTree:
"""Create a kdtree with the given nodes from the candidate graph.
Will fail if provided node ids are not in the candidate graph.
Args:
- cand_graph (td.graph.GraphView): A candidate graph
+ cand_graph (td.graph.BaseGraph): A candidate graph
node_ids (list[Any]): The nodes within the candidate graph to
include in the KDTree. Useful for limiting to one time frame.
Must be a list (not a generic iterable) to preserve order for
@@ -212,7 +212,7 @@ def create_kdtree(cand_graph: td.graph.GraphView, node_ids: list[Any]) -> KDTree
def add_cand_edges(
- cand_graph: td.graph.GraphView,
+ cand_graph: td.graph.BaseGraph,
max_edge_distance: float,
node_frame_dict: None | dict[int, list[Any]] = None,
iou_dict: dict[int, dict[int, float]] | None = None,
@@ -221,7 +221,7 @@ def add_cand_edges(
frames that are closer than max_edge_distance.
Args:
- cand_graph (td.graph.GraphView): Candidate graph with only nodes populated.
+ cand_graph (td.graph.BaseGraph): Candidate graph with only nodes populated.
Will be modified in-place to add edges.
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes within this distance in adjacent frames will by connected
diff --git a/src/funtracks/data_model/__init__.py b/src/funtracks/data_model/__init__.py
index ef0f4187..8ff09431 100644
--- a/src/funtracks/data_model/__init__.py
+++ b/src/funtracks/data_model/__init__.py
@@ -1,2 +1 @@
from .tracks import Tracks # noqa
-from .solution_tracks import SolutionTracks # noqa
diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py
deleted file mode 100644
index 9b3eaaee..00000000
--- a/src/funtracks/data_model/solution_tracks.py
+++ /dev/null
@@ -1,244 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import tracksdata as td
-
-from funtracks.features import FeatureDict
-
-from .tracks import Tracks
-
-if TYPE_CHECKING:
- from funtracks.annotators import TrackAnnotator
-
- from .tracks import Node
-
-
-class SolutionTracks(Tracks):
- """Difference from Tracks: every node must have a tracklet id"""
-
- def __init__(
- self,
- graph: td.graph.GraphView,
- time_attr: str | None = None,
- pos_attr: str | tuple[str] | list[str] | None = None,
- tracklet_attr: str | None = None,
- lineage_attr: str | None = None,
- scale: list[float] | None = None,
- ndim: int | None = None,
- features: FeatureDict | None = None,
- _segmentation: td.array.GraphArrayView | None = None,
- ):
- """Initialize a SolutionTracks object.
-
- SolutionTracks extends Tracks to ensure every node has a tracklet id. A
- TrackAnnotator is automatically added to manage track IDs.
-
- Args:
- graph (td.graph.GraphView): Tracksdata graph with nodes as detections
- and edges as links.
- time_attr (str | None): Graph attribute name for time. Defaults to "time"
- if None.
- pos_attr (str | tuple[str, ...] | list[str] | None): Graph attribute
- name(s) for position. Can be:
- - Single string for one attribute containing position array
- - List/tuple of strings for multi-axis (one attribute per axis)
- Defaults to "pos" if None.
- tracklet_attr (str | None): Graph attribute name for tracklet/track IDs.
- Defaults to "tracklet_id" if None.
- lineage_attr (str | None): Graph attribute name for lineage IDs.
- Defaults to "lineage_id" if None.
- scale (list[float] | None): Scaling factors for each dimension (including
- time). If None, all dimensions scaled by 1.0.
- ndim (int | None): Number of dimensions (3 for 2D+time, 4 for 3D+time).
- If None, inferred from segmentation or scale.
- features (FeatureDict | None): Pre-built FeatureDict with feature
- definitions. If provided, time_attr/pos_attr/tracklet_attr are ignored.
- Assumes that all features in the dict already exist on the graph (will
- be activated but not recomputed). If None, core computed features (pos,
- area, tracklet_id) are auto-detected by checking if they exist on the
- graph.
- _segmentation (GraphArrayView | None): Internal parameter for reusing an
- existing GraphArrayView instance. Not intended for public use.
- """
- super().__init__(
- graph,
- time_attr=time_attr,
- pos_attr=pos_attr,
- tracklet_attr=tracklet_attr,
- lineage_attr=lineage_attr,
- scale=scale,
- ndim=ndim,
- features=features,
- _segmentation=_segmentation,
- )
-
- self.track_annotator = self._get_track_annotator()
-
- def _get_track_annotator(self) -> TrackAnnotator:
- """Get the TrackAnnotator instance from the annotator registry.
-
- Returns:
- TrackAnnotator: The track annotator instance
-
- Raises:
- RuntimeError: If no TrackAnnotator is registered
- """
- from funtracks.annotators import TrackAnnotator
-
- for annotator in self.annotators:
- if isinstance(annotator, TrackAnnotator):
- return annotator
- raise RuntimeError(
- "No TrackAnnotator registered for this SolutionTracks instance"
- )
-
- @classmethod
- def from_tracks(cls, tracks: Tracks):
- force_recompute = False
- # Check if all nodes have a value at features.tracklet_key before trusting
- # existing track IDs
- if (
- tracks.features.tracklet_key is not None
- and (
- tracks.graph.node_attrs(attr_keys=tracks.features.tracklet_key)[
- tracks.features.tracklet_key
- ]
- == -1
- ).any()
- # Attributes are no longer None, so 0 now means non-computed
- ):
- force_recompute = True
-
- soln_tracks = cls(
- tracks.graph,
- scale=tracks.scale,
- ndim=tracks.ndim,
- features=tracks.features,
- _segmentation=tracks.segmentation,
- )
- if force_recompute:
- soln_tracks.enable_features(
- [
- soln_tracks.features.tracklet_key, # type: ignore[list-item]
- soln_tracks.features.lineage_key, # type: ignore[list-item]
- ]
- )
- return soln_tracks
-
- @property
- def max_track_id(self) -> int:
- return self.track_annotator.max_tracklet_id
-
- @property
- def track_id_to_node(self) -> dict[int, list[int]]:
- return self.track_annotator.tracklet_id_to_nodes
-
- def get_next_track_id(self) -> int:
- """Return the next available track_id.
-
- The max_tracklet_id in TrackAnnotator is updated automatically when
- a node is added or track IDs are updated via UpdateTrackIDs.
- """
- return self.track_annotator.max_tracklet_id + 1
-
- def get_next_lineage_id(self) -> int:
- """Return the next available lineage_id.
-
- The max_lineage_id in TrackAnnotator is updated automatically when
- a node is added or lineage IDs are updated via UpdateTrackIDs.
- """
- return self.track_annotator.max_lineage_id + 1
-
- def get_track_id(self, node) -> int:
- if self.features.tracklet_key is None:
- raise ValueError("Tracklet key not initialized in features")
- track_id = self.get_node_attr(node, self.features.tracklet_key)
- return track_id
-
- def get_track_ids(self, nodes) -> list[int]:
- """Batch version of get_track_id — one SQL query fetching all nodes in the graph.
- NOTE: always fetches the entire graph internally. Optimised for bulk (all-node)
- calls. For small subsets or single nodes use get_track_id() instead."""
-
- if self.features.tracklet_key is None:
- raise ValueError("Tracklet key not initialized in features")
- tracklet_key = self.features.tracklet_key
- df = self.graph.node_attrs(attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, tracklet_key])
- id_to_val = dict(
- zip(
- df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list(),
- df[tracklet_key].to_list(),
- strict=True,
- )
- )
- return [id_to_val[node] for node in nodes]
-
- def get_lineage_id(self, node) -> int | None:
- """Get the lineage ID for a node.
-
- Args:
- node: The node to get lineage ID for
-
- Returns:
- The lineage ID, or None if lineage feature is not enabled
- """
- if self.features.lineage_key is None:
- return None
- return self.get_node_attr(node, self.features.lineage_key)
-
- def get_track_neighbors(
- self, track_id: int, time: int
- ) -> tuple[Node | None, Node | None]:
- """Get the last node with the given track id before time, and the first node
- with the track id after time, if any. Does not assume that a node with
- the given track_id and time is already in tracks, but it can be.
-
- Args:
- track_id (int): The track id to search for
- time (int): The time point to find the immediate predecessor and successor
- for
-
- Returns:
- tuple[Node | None, Node | None]: The last node before time with the given
- track id, and the first node after time with the given track id,
- or Nones if there are no such nodes.
- """
- annotator = self.track_annotator
- if (
- track_id not in annotator.tracklet_id_to_nodes
- or len(annotator.tracklet_id_to_nodes[track_id]) == 0
- ):
- return None, None
- candidates = annotator.tracklet_id_to_nodes[track_id]
- candidates.sort(key=lambda n: self.get_time(n))
-
- pred = None
- succ = None
- for cand in candidates:
- if self.get_time(cand) < time:
- pred = cand
- elif self.get_time(cand) > time:
- succ = cand
- break
- return (
- int(pred) if pred is not None else None,
- int(succ) if succ is not None else None,
- )
-
- def has_track_id_at_time(self, track_id: int, time: int) -> bool:
- """Function to check if a node with given track id exists at given time point.
-
- Args:
- track_id (int): The track id to search for.
- time (int): The time point to check.
-
- Returns:
- True if a node with given track id exists at given time point.
- """
-
- nodes = self.track_id_to_node.get(track_id)
- if not nodes:
- return False
-
- return time in [self.get_time(node) for node in nodes]
diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py
index 0cc16dc8..6131dcb6 100644
--- a/src/funtracks/data_model/tracks.py
+++ b/src/funtracks/data_model/tracks.py
@@ -17,6 +17,7 @@
from tracksdata.nodes import Mask
from funtracks.actions.action_history import ActionHistory
+from funtracks.annotators import TrackAnnotator
from funtracks.features import (
Feature,
FeatureDict,
@@ -52,8 +53,11 @@ class Tracks:
position attribute. Edges in the graph represent links across time.
Attributes:
- graph (td.graph.GraphView): A graph with nodes representing detections and
- and edges representing links across time.
+ graph_full (td.graph.BaseGraph): The full graph (first-class): every node/edge
+ ever known, including soft-deleted (solution=False) candidates. Nodes
+ represent detections, edges represent links across time.
+ graph_solution (td.graph.GraphView): A solution==True view derived from
+ graph_full; the user-visible tracking solution.
features (FeatureDict): Dictionary of features tracked on graph nodes/edges.
annotators (AnnotatorRegistry): List of annotators that compute features.
scale (list[float] | None): How much to scale each dimension by, including time.
@@ -65,7 +69,7 @@ class Tracks:
def __init__(
self,
- graph: td.graph.GraphView,
+ graph: td.graph.BaseGraph,
time_attr: str | None = None,
pos_attr: str | tuple[str, ...] | list[str] | None = None,
tracklet_attr: str | None = None,
@@ -78,8 +82,9 @@ def __init__(
"""Initialize a Tracks object.
Args:
- graph (td.graph.GraphView): tracksdata directed graph with nodes as detections
- and edges as links.
+ graph (td.graph.BaseGraph): the full base graph (graph_full) with nodes as
+ detections and edges as links. Must be a base graph, not a view;
+ graph_solution is built internally as its solution==True view.
time_attr (str | None): Graph attribute name for time. Defaults to "time"
if None.
pos_attr (str | tuple[str, ...] | list[str] | None): Graph attribute
@@ -88,7 +93,9 @@ def __init__(
- List/tuple of strings for multi-axis (one attribute per axis)
Defaults to "pos" if None.
tracklet_attr (str | None): Graph attribute name for tracklet/track IDs.
- Defaults to "tracklet_id" if None.
+ Defaults to "tracklet_id" if None. Every Tracks gets a TrackAnnotator
+ and track ids (no "plain" vs "solution" distinction); existing ids on
+ the graph are reused, otherwise they are computed.
lineage_attr (str | None): Graph attribute name for lineage IDs.
Defaults to "lineage_id" if None.
scale (list[float] | None): Scaling factors for each dimension (including
@@ -103,7 +110,23 @@ def __init__(
_segmentation (GraphArrayView | None): Internal parameter for reusing an
existing GraphArrayView instance. Not intended for public use.
"""
- self.graph = graph
+ # graph_full is the first-class object: the base graph holding every node/edge
+ # ever known, including soft-deleted (solution=False) candidates. graph_solution
+ # is derived from it as a solution==True view.
+ if isinstance(graph, td.graph.GraphView):
+ raise ValueError(
+ "Tracks requires the full base graph (graph_full), not a GraphView. "
+ "graph_solution is built internally as a solution==True view of it."
+ )
+ if "solution" not in graph.node_attr_keys():
+ graph.add_node_attr_key("solution", default_value=True, dtype=pl.Boolean)
+ if "solution" not in graph.edge_attr_keys():
+ graph.add_edge_attr_key("solution", default_value=True, dtype=pl.Boolean)
+ self.graph_full = graph
+ self.graph_solution = graph.filter(
+ td.NodeAttr("solution") == True, # noqa: E712
+ td.EdgeAttr("solution") == True, # noqa: E712
+ ).subgraph()
if _segmentation is not None:
# Reuse provided segmentation instance (internal use only)
self.segmentation = _segmentation
@@ -115,8 +138,10 @@ def __init__(
seg_shape = graph.metadata.get("segmentation_shape")
if seg_shape is not None:
try:
+ # Render the segmentation from the solution view so soft-deleted
+ # nodes drop out of the array, mirroring the user-visible graph.
array_view = GraphArrayView(
- graph=graph,
+ graph=self.graph_solution,
shape=seg_shape,
attr_key="node_id",
offset=0,
@@ -173,6 +198,11 @@ def __init__(
else:
self._setup_core_computed_features()
+ # 4. Enforce the track-id invariant on BOTH paths: every Tracks has a tracklet
+ # key and a registered TrackAnnotator, with tracklet_id/lineage_id registered
+ # and computed. A provided FeatureDict that omitted them is completed here.
+ self._ensure_track_features()
+
def _get_feature_set(
self,
time_attr: str | None,
@@ -193,11 +223,10 @@ def _get_feature_set(
- Single string: one attribute containing position array (e.g., "pos")
- List/tuple: multiple attributes, one per axis (e.g., ["y", "x"])
- None: defaults to "pos"
- tracklet_key: Graph attribute name for tracklet/track IDs
- (e.g., "tracklet_id").
- If None, defaults to "tracklet_id"
- lineage_key: Graph attribute name for lineage IDs (e.g., "lineage_id").
- if None, defaults to "lineage_id"
+ tracklet_key: Graph attribute name for tracklet/track IDs.
+ Defaults to "tracklet_id" if None (every Tracks gets track ids).
+ lineage_key: Graph attribute name for lineage IDs.
+ Defaults to "lineage_id" if None.
Returns:
FeatureDict initialized with time feature and position if no segmentation
@@ -206,10 +235,11 @@ def _get_feature_set(
time_key = time_attr if time_attr is not None else "time"
if pos_attr is None:
pos_attr = "pos"
- if tracklet_key is None:
- tracklet_key = "tracklet_id"
- if lineage_key is None:
- lineage_key = "lineage_id"
+ # Every Tracks has a tracklet/lineage key (no "plain" vs "solution" split):
+ # default them like time/pos. _ensure_track_features() then registers and
+ # computes the ids; on an empty solution view that is a no-op.
+ tracklet_key = tracklet_key if tracklet_key is not None else "tracklet_id"
+ lineage_key = lineage_key if lineage_key is not None else "lineage_id"
# Build static features dict - always include time
features: dict[str, Feature] = {time_key: Time()}
@@ -247,7 +277,7 @@ def _get_feature_set(
# else: single pos_attr with segmentation - RegionpropsAnnotator will handle it
# Register solution feature when present on the graph
- if "solution" in self.graph.node_attr_keys():
+ if "solution" in self.graph_solution.node_attr_keys():
feature_dict["solution"] = Solution()
# Register mask and bbox features if segmentation exists
@@ -265,7 +295,7 @@ def _get_annotators(self) -> AnnotatorRegistry:
Creates annotators conditionally:
- RegionpropsAnnotator: Only if segmentation is provided
- EdgeAnnotator: Only if segmentation is provided
- - TrackAnnotator: Only if this is a SolutionTracks instance
+ - TrackAnnotator: Always (every Tracks has track ids)
Each annotator is configured with appropriate keys from self.features.
@@ -277,7 +307,6 @@ def _get_annotators(self) -> AnnotatorRegistry:
AnnotatorRegistry,
EdgeAnnotator,
RegionpropsAnnotator,
- TrackAnnotator,
)
annotator_list: list[GraphAnnotator] = []
@@ -296,15 +325,16 @@ def _get_annotators(self) -> AnnotatorRegistry:
if EdgeAnnotator.can_annotate(self):
annotator_list.append(EdgeAnnotator(self))
- # TrackAnnotator: requires SolutionTracks (checked in can_annotate)
- if TrackAnnotator.can_annotate(self):
- annotator_list.append(
- TrackAnnotator(
- self, # type: ignore[arg-type]
- tracklet_key=self.features.tracklet_key,
- lineage_key=self.features.lineage_key,
- )
+ # TrackAnnotator is registered on every Tracks — track ids are a core feature,
+ # not a separate "type" of tracks. On an empty solution view it simply computes
+ # nothing until nodes are added.
+ annotator_list.append(
+ TrackAnnotator(
+ self,
+ tracklet_key=self.features.tracklet_key,
+ lineage_key=self.features.lineage_key,
)
+ )
return AnnotatorRegistry(annotator_list)
def _activate_features_from_dict(self) -> None:
@@ -326,23 +356,22 @@ def _check_existing_feature(self, key: str) -> bool:
bool: True if the key is on the first sampled node or there are no nodes,
and False if missing from the first node.
"""
- if self.graph.num_nodes() == 0:
+ if self.graph_solution.num_nodes() == 0:
return True
# Check which attributes exist
- node_attrs = set(self.graph.node_attr_keys())
+ node_attrs = set(self.graph_solution.node_attr_keys())
return key in node_attrs
def _setup_core_computed_features(self) -> None:
- """Sets up core computed features (position, tracklet, lineage).
+ """Sets up core computed position features.
- Registers position/tracklet/lineage features from annotators into
- FeatureDict. For each core feature:
- - Activates any features that are detected to already exist on the graph
- - Enables (computes) any features that don't exist yet
+ Registers the position feature from the RegionpropsAnnotator into the
+ FeatureDict, activating it if it already exists on the graph or computing it
+ otherwise. Track-id features are handled separately by _ensure_track_features.
"""
# Import here to avoid circular dependency
- from funtracks.annotators import RegionpropsAnnotator, TrackAnnotator
+ from funtracks.annotators import RegionpropsAnnotator
core_features: list[str] = []
for annotator in self.annotators:
@@ -351,14 +380,12 @@ def _setup_core_computed_features(self) -> None:
if self.features.position_key is None:
self.features.position_key = pos_key
core_features.append(pos_key)
- elif isinstance(annotator, TrackAnnotator):
- tracklet_key = annotator.tracklet_key
- self.features.tracklet_key = tracklet_key
- core_features.append(tracklet_key)
- lineage_key = annotator.lineage_key
- self.features.lineage_key = lineage_key
- core_features.append(lineage_key)
- for key in core_features:
+ self._register_core_features(core_features)
+
+ def _register_core_features(self, keys: list[str]) -> None:
+ """Register each key as a feature: activate it if it already exists on the
+ graph, otherwise enable (compute) it."""
+ for key in keys:
if self._check_existing_feature(key):
if key not in self.features:
feature, _ = self.annotators.all_features[key]
@@ -367,38 +394,53 @@ def _setup_core_computed_features(self) -> None:
else:
self.enable_features([key])
- def nodes(self):
- return np.array(self.graph.node_ids())
+ def _ensure_track_features(self) -> None:
+ """Ensure the track-id core features exist on this Tracks.
- def edges(self):
- return np.array(self.graph.edge_ids())
+ Every Tracks has a registered TrackAnnotator and a tracklet key (no "plain"
+ vs "solution" split). This syncs features.tracklet_key/lineage_key from the
+ annotator and registers + computes (or activates, if already present)
+ tracklet_id/lineage_id. Runs on both the provided-FeatureDict and the
+ auto-detect construction paths; a no-op on an empty solution view.
+ """
+ annotator = self.track_annotator
+ self.features.tracklet_key = annotator.tracklet_key
+ self.features.lineage_key = annotator.lineage_key
+ # A tracklet column can exist yet still hold the -1 sentinel ("not computed",
+ # the column default). Trusting it would activate stale ids and seed a phantom
+ # tracklet -1 in the annotator bookkeeping, so force a recompute from topology.
+ if self._has_uncomputed_track_ids(annotator.tracklet_key):
+ self.enable_features([annotator.tracklet_key, annotator.lineage_key])
+ else:
+ self._register_core_features([annotator.tracklet_key, annotator.lineage_key])
- def in_degree(self, nodes: np.ndarray | None = None) -> np.ndarray:
- """Get the in-degree edge_ids of the nodes in the graph."""
- if nodes is not None:
- # make sure nodes is a numpy array
- if not isinstance(nodes, np.ndarray):
- nodes = np.array(nodes)
+ def _has_uncomputed_track_ids(self, tracklet_key: str) -> bool:
+ """True if the tracklet column exists but any node still holds the -1 sentinel.
- return np.array([self.graph.in_degree(node.item()) for node in nodes])
- else:
- return np.array(self.graph.in_degree())
+ A missing column returns False: _register_core_features computes it from scratch.
+ """
+ if self.graph_solution.num_nodes() == 0:
+ return False
+ if tracklet_key not in self.graph_solution.node_attr_keys():
+ return False
+ values = self.graph_solution.node_attrs(attr_keys=[tracklet_key])[tracklet_key]
+ return bool((values == -1).any())
- def out_degree(self, nodes: np.ndarray | None = None) -> np.ndarray:
- if nodes is not None:
- # make sure nodes is a numpy array
- if not isinstance(nodes, np.ndarray):
- nodes = np.array(nodes)
+ def nodes(self):
+ """Return the node ids of the solution graph as a numpy array."""
+ return np.array(self.graph_solution.node_ids())
- return np.array([self.graph.out_degree(node.item()) for node in nodes])
- else:
- return np.array(self.graph.out_degree())
+ def edges(self):
+ """Return the edge ids of the solution graph as a numpy array."""
+ return np.array(self.graph_solution.edge_ids())
def predecessors(self, node: int) -> list[int]:
- return list(self.graph.predecessors(node))
+ """Return the predecessors of a node in the solution graph."""
+ return list(self.graph_solution.predecessors(node))
def successors(self, node: int) -> list[int]:
- return list(self.graph.successors(node))
+ """Return the successors of a node in the solution graph."""
+ return list(self.graph_solution.successors(node))
def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.ndarray:
"""Get the positions of nodes in the graph. Optionally include the
@@ -409,7 +451,7 @@ def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.nd
For a single node use get_position() instead.
Args:
- node (Iterable[Node]): The node ids in the graph to get the positions of
+ nodes (Iterable[Node]): The node ids in the graph to get the positions of
incl_time (bool, optional): If true, include the time as the
first element of each position array. Defaults to False.
@@ -430,7 +472,10 @@ def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.nd
+ ([self.features.time_key] if incl_time else [])
)
- df = self.graph.node_attrs(attr_keys=attr_keys)
+ # Read from graph_full (consistent with get_position / the attr-helper policy):
+ # positions are intrinsic node attrs, so this also resolves soft-deleted
+ # (solution=False) nodes instead of KeyError-ing like a graph_solution query.
+ df = self.graph_full.node_attrs(attr_keys=attr_keys)
id_to_row = {
nid: i for i, nid in enumerate(df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list())
}
@@ -497,6 +542,7 @@ def set_positions(
self._set_nodes_attr(nodes, self.features.position_key, positions)
def set_position(self, node: Node, position: list | np.ndarray):
+ """Set the position of a single node."""
self.set_positions([node], np.expand_dims(np.array(position), axis=0))
def get_times(self, nodes: Iterable[Node]) -> Sequence[int]:
@@ -505,7 +551,7 @@ def get_times(self, nodes: Iterable[Node]) -> Sequence[int]:
For a single node use get_time() instead.
"""
nodes = list(nodes)
- df = self.graph.node_attrs(
+ df = self.graph_full.node_attrs(
attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, self.features.time_key]
)
id_to_val = dict(
@@ -546,7 +592,7 @@ def get_mask(
if self.segmentation is None:
return None
- mask = self.graph.nodes[node][mask_key]
+ mask = self.graph_full.nodes[node][mask_key]
return mask
def update_mask(
@@ -563,13 +609,13 @@ def update_mask(
mask_key: The feature key for the mask column.
Defaults to the standard mask key.
"""
- self.graph.nodes[node][mask_key] = mask
+ self.graph_full.nodes[node][mask_key] = mask
mask_feature = self.features.get(mask_key)
if mask_feature is not None:
# NOTE: all derived features of a mask are currently assumed to be
# its bounding box. Revisit if non-bbox derived features are added.
for derived_key in mask_feature.get("derived_features", []):
- self.graph.nodes[node][derived_key] = mask.bbox
+ self.graph_full.nodes[node][derived_key] = mask.bbox
def undo(self) -> bool:
"""Undo the last performed action from the action history.
@@ -603,10 +649,13 @@ def _get_new_node_ids(self, n: int) -> list[Node]:
Returns:
list[Node]: A list of new node ids.
"""
+ # Check against graph_full, not the solution view: a soft-deleted
+ # (solution=False) node still occupies its id in the full graph, so reissuing
+ # it to a genuinely new node would collide with the still-present root node.
ids = [self.node_id_counter + i for i in range(n)]
self.node_id_counter += n
for idx, _id in enumerate(ids):
- while self.graph.has_node(_id):
+ while self.graph_full.has_node(_id):
_id = self.node_id_counter
self.node_id_counter += 1
ids[idx] = _id
@@ -635,39 +684,56 @@ def _compute_ndim(
)
return ndim
+ # NOTE: the low-level attribute get/set helpers below target `graph_full`, not the
+ # solution view. Attribute *values* are intrinsic to a node/edge and live on the full
+ # graph; the solution view is a membership filter over it. Because graph_full ⊇
+ # graph_solution and their attr dicts are shared by reference, writing via graph_full
+ # is identical to writing via the view for any in-solution node (the view sees it
+ # automatically) and additionally works for soft-deleted (solution=False) candidates.
def _set_node_attr(self, node: Node, attr: str, value: Any):
if isinstance(value, np.ndarray):
value = list(value)
- self.graph.nodes[node][attr] = value
+ self.graph_full.nodes[node][attr] = value
def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any]):
nodes_list = list(nodes)
values_list = list(values)
if nodes_list:
- self.graph.update_node_attrs(attrs={attr: values_list}, node_ids=nodes_list)
+ self.graph_full.update_node_attrs(
+ attrs={attr: values_list}, node_ids=nodes_list
+ )
def get_node_attr(self, node: Node, attr: str):
- return self.graph.nodes[int(node)][attr]
+ """Get an attribute value for a single node (resolved on graph_full)."""
+ return self.graph_full.nodes[int(node)][attr]
def get_nodes_attr(self, nodes: Iterable[Node], attr: str):
+ """Get an attribute value for each of the given nodes."""
return [self.get_node_attr(node, attr) for node in nodes]
def _set_edge_attr(self, edge: Edge, attr: str, value: Any):
- edge_id = self.graph.edge_id(edge[0], edge[1])
- self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id])
+ edge_id = self.graph_full.edge_id(edge[0], edge[1])
+ # Wrap in a single-element list: update_edge_attrs reads a bare list value
+ # (e.g. a vector feature) as one-value-per-edge.
+ self.graph_full.update_edge_attrs(attrs={attr: [value]}, edge_ids=[edge_id])
def _set_edges_attr(self, edges: Iterable[Edge], attr: str, values: Iterable[Any]):
for edge, value in zip(edges, values, strict=False):
- edge_id = self.graph.edge_id(edge[0], edge[1])
- self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id])
+ edge_id = self.graph_full.edge_id(edge[0], edge[1])
+ self.graph_full.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id])
def get_edge_attr(self, edge: Edge, attr: str):
- if attr not in self.graph.edge_attr_keys():
+ """Get an attribute value for a single edge (resolved on graph_full).
+
+ Returns None if the attribute is not registered on the graph.
+ """
+ if attr not in self.graph_full.edge_attr_keys():
return None
- edge_id = self.graph.edge_id(edge[0], edge[1])
- return self.graph.edges[edge_id][attr]
+ edge_id = self.graph_full.edge_id(edge[0], edge[1])
+ return self.graph_full.edges[edge_id][attr]
def get_edges_attr(self, edges: Iterable[Edge], attr: str):
+ """Get an attribute value for each of the given edges."""
return [self.get_edge_attr(edge, attr) for edge in edges]
# ========== Feature Management ==========
@@ -750,24 +816,36 @@ def add_feature(self, key: str, feature: Feature) -> None:
# Add to the features dictionary
self.features[key] = feature
- # Perform custom graph operations when a feature is added
+ # Perform custom graph operations when a feature is added.
+ #
+ # Schema (attr-key) registration is done on graph_solution (the view), NOT on
+ # graph_full, even though annotators write the VALUES to graph_full. This relies
+ # on a tracksdata invariant: adding an attr key to a view propagates up to its
+ # root, so the column ends up on both. The reverse does NOT hold today — adding
+ # a key directly to the root is not propagated down into an existing view — so
+ # registering on graph_full would leave graph_solution without the column.
+ # If tracksdata ever makes view attr-key additions local, revisit this.
ft = feature["feature_type"]
- if "node" in ft and key not in self.graph.node_attr_keys():
+ if "node" in ft and key not in self.graph_solution.node_attr_keys():
# "mask" value_type maps to pl.Object via to_polars_dtype
dtype = to_polars_dtype(feature["value_type"])
num_values = feature.get("num_values")
if num_values is not None and num_values > 1:
dtype = pl.Array(dtype, num_values)
- self.graph.add_node_attr_key(
+ self.graph_solution.add_node_attr_key(
key,
default_value=feature["default_value"],
dtype=dtype,
)
- if "edge" in ft and key not in self.graph.edge_attr_keys():
- self.graph.add_edge_attr_key(
+ if "edge" in ft and key not in self.graph_solution.edge_attr_keys():
+ dtype = to_polars_dtype(feature["value_type"])
+ num_values = feature.get("num_values")
+ if num_values is not None and num_values > 1:
+ dtype = pl.Array(dtype, num_values)
+ self.graph_solution.add_edge_attr_key(
key,
default_value=feature["default_value"],
- dtype=to_polars_dtype(feature["value_type"]),
+ dtype=dtype,
)
def delete_feature(self, key: str) -> None:
@@ -799,8 +877,136 @@ def delete_feature(self, key: str) -> None:
else:
return
- # Perform custom graph operations when a feature is deleted
- if "node" in feature_type and key in self.graph.node_attr_keys():
- self.graph.remove_node_attr_key(key)
- if "edge" in feature_type and key in self.graph.edge_attr_keys():
- self.graph.remove_edge_attr_key(key)
+ # Perform custom graph operations when a feature is deleted. Schema ops go
+ # through graph_solution (the view) and propagate to the root — same tracksdata
+ # invariant as add_feature (see the note there).
+ if "node" in feature_type and key in self.graph_solution.node_attr_keys():
+ self.graph_solution.remove_node_attr_key(key)
+ if "edge" in feature_type and key in self.graph_solution.edge_attr_keys():
+ self.graph_solution.remove_edge_attr_key(key)
+
+ # ========== Track ID management (solution view) ==========
+ # These operate on the solution view via the TrackAnnotator, which every Tracks
+ # has (track ids are a core feature). On an empty solution view they are no-ops.
+
+ @property
+ def track_annotator(self) -> TrackAnnotator:
+ """The registered TrackAnnotator. Always present, since track ids are a core
+ feature of every Tracks (_get_annotators registers one unconditionally)."""
+ for annotator in self.annotators:
+ if isinstance(annotator, TrackAnnotator):
+ return annotator
+ raise RuntimeError(
+ "No TrackAnnotator registered on this Tracks — this should be unreachable "
+ "(_get_annotators always registers one)."
+ )
+
+ @property
+ def max_track_id(self) -> int:
+ """The maximum tracklet id currently in use."""
+ return self.track_annotator.max_tracklet_id
+
+ def get_next_track_id(self) -> int:
+ """Return the next available track_id.
+
+ The max_tracklet_id in TrackAnnotator is updated automatically when
+ a node is added or track IDs are updated via UpdateTrackIDs.
+ """
+ return self.track_annotator.max_tracklet_id + 1
+
+ def get_next_lineage_id(self) -> int:
+ """Return the next available lineage_id.
+
+ The max_lineage_id in TrackAnnotator is updated automatically when
+ a node is added or lineage IDs are updated via UpdateTrackIDs.
+ """
+ return self.track_annotator.max_lineage_id + 1
+
+ def get_track_id(self, node) -> int:
+ """Get the tracklet id of a single node."""
+ track_id = self.get_node_attr(node, self.features.tracklet_key)
+ return track_id
+
+ def get_track_ids(self, nodes) -> list[int]:
+ """Batch version of get_track_id — one query fetching all nodes in the graph.
+ NOTE: always fetches the entire graph internally. Optimised for bulk (all-node)
+ calls. For small subsets or single nodes use get_track_id() instead."""
+
+ tracklet_key = self.features.tracklet_key
+ df = self.graph_full.node_attrs(
+ attr_keys=[td.DEFAULT_ATTR_KEYS.NODE_ID, tracklet_key]
+ )
+ id_to_val = dict(
+ zip(
+ df[td.DEFAULT_ATTR_KEYS.NODE_ID].to_list(),
+ df[tracklet_key].to_list(),
+ strict=True,
+ )
+ )
+ return [id_to_val[node] for node in nodes]
+
+ def get_lineage_id(self, node) -> int:
+ """Get the lineage ID for a node.
+
+ Args:
+ node: The node to get lineage ID for
+
+ Returns:
+ The lineage ID.
+ """
+ return self.get_node_attr(node, self.features.lineage_key)
+
+ def get_track_neighbors(
+ self, track_id: int, time: int
+ ) -> tuple[Node | None, Node | None]:
+ """Get the last node with the given track id before time, and the first node
+ with the track id after time, if any. Does not assume that a node with
+ the given track_id and time is already in tracks, but it can be.
+
+ Args:
+ track_id (int): The track id to search for
+ time (int): The time point to find the immediate predecessor and successor
+ for
+
+ Returns:
+ tuple[Node | None, Node | None]: The last node before time with the given
+ track id, and the first node after time with the given track id,
+ or Nones if there are no such nodes.
+ """
+ if (
+ track_id not in self.track_annotator.tracklet_id_to_nodes
+ or len(self.track_annotator.tracklet_id_to_nodes[track_id]) == 0
+ ):
+ return None, None
+ candidates = sorted(
+ self.track_annotator.tracklet_id_to_nodes[track_id], key=self.get_time
+ )
+
+ pred = None
+ succ = None
+ for cand in candidates:
+ if self.get_time(cand) < time:
+ pred = cand
+ elif self.get_time(cand) > time:
+ succ = cand
+ break
+ return (
+ int(pred) if pred is not None else None,
+ int(succ) if succ is not None else None,
+ )
+
+ def has_track_id_at_time(self, track_id: int, time: int) -> bool:
+ """Function to check if a node with given track id exists at given time point.
+
+ Args:
+ track_id (int): The track id to search for.
+ time (int): The time point to check.
+
+ Returns:
+ True if a node with given track id exists at given time point.
+ """
+ nodes = self.track_annotator.tracklet_id_to_nodes.get(track_id)
+ if not nodes:
+ return False
+
+ return time in [self.get_time(node) for node in nodes]
diff --git a/src/funtracks/import_export/_export_segmentation.py b/src/funtracks/import_export/_export_segmentation.py
index cef2ccd8..97f9e146 100644
--- a/src/funtracks/import_export/_export_segmentation.py
+++ b/src/funtracks/import_export/_export_segmentation.py
@@ -45,7 +45,7 @@ def resolve_relabel_attr(
else:
return None
- existing_attrs = tracks.graph.node_attr_keys()
+ existing_attrs = tracks.graph_solution.node_attr_keys()
if label_attr not in existing_attrs:
raise ValueError(
f"relabel='{relabel}' resolved to attribute '{label_attr}', "
@@ -95,13 +95,12 @@ def export_segmentation(
shape = tracks.segmentation.shape
- graph = (
- tracks.graph.filter(node_ids=list(node_ids)).subgraph()
- if node_ids is not None
- else tracks.graph
- )
-
if label_attr is not None:
+ graph = (
+ tracks.graph_solution.filter(node_ids=list(node_ids)).subgraph()
+ if node_ids is not None
+ else tracks.graph_solution
+ )
view = GraphArrayView(graph, label_attr, shape=shape)
def get_frame(t: int) -> np.ndarray:
diff --git a/src/funtracks/import_export/_import_segmentation.py b/src/funtracks/import_export/_import_segmentation.py
index e1545d94..db04c8a3 100644
--- a/src/funtracks/import_export/_import_segmentation.py
+++ b/src/funtracks/import_export/_import_segmentation.py
@@ -46,11 +46,11 @@ def load_segmentation(segmentation: Path | np.ndarray | da.Array) -> da.Array:
def relabel_segmentation(
seg_array: da.Array | np.ndarray,
- graph: td.graph.GraphView,
+ graph: td.graph.BaseGraph,
node_ids: ArrayLike,
seg_ids: ArrayLike,
time_values: ArrayLike,
-) -> tuple[np.ndarray, td.graph.GraphView]:
+) -> tuple[np.ndarray, td.graph.BaseGraph]:
"""Relabel segmentation from seg_id to node_id.
Handles the case where node_id 0 exists by offsetting all node IDs by 1,
@@ -58,7 +58,7 @@ def relabel_segmentation(
Args:
seg_array: Segmentation array (dask or numpy)
- graph: tracksdata GraphView (will be relabeled if node_id 0 exists)
+ graph: tracksdata base graph (will be relabeled if node_id 0 exists)
node_ids: Array of node IDs
seg_ids: Array of segmentation IDs corresponding to each node
time_values: Array of time values for each node
diff --git a/src/funtracks/import_export/_tracks_builder.py b/src/funtracks/import_export/_tracks_builder.py
index 7ed3ce75..2bfdd0f8 100644
--- a/src/funtracks/import_export/_tracks_builder.py
+++ b/src/funtracks/import_export/_tracks_builder.py
@@ -1,6 +1,6 @@
"""Builder pattern for importing tracks from various formats.
-This module provides a unified interface for constructing SolutionTracks objects
+This module provides a unified interface for constructing Tracks objects
from different data sources (GEFF, CSV, etc.) while sharing common validation
and construction logic.
"""
@@ -15,7 +15,7 @@
import tracksdata as td
from geff._typing import InMemoryGeff
-from funtracks.data_model.solution_tracks import SolutionTracks
+from funtracks.data_model.tracks import Tracks
from funtracks.features import Feature
from funtracks.import_export._import_segmentation import (
load_segmentation,
@@ -39,7 +39,7 @@
)
from funtracks.utils.tracksdata_utils import (
add_masks_and_bboxes_to_graph,
- create_empty_graphview_graph,
+ create_empty_graph,
)
if TYPE_CHECKING:
@@ -415,7 +415,7 @@ def construct_graph(
self,
node_name_map: dict[str, str | list[str]] | None = None,
database: str | None = None,
- ) -> td.graph.GraphView:
+ ) -> td.graph.BaseGraph:
"""Construct Tracksdata graph from validated InMemoryGeff data.
Common logic shared across all formats.
@@ -427,7 +427,7 @@ def construct_graph(
If None (default), an in-memory/temp graph is used.
Returns:
- Tracksdata GraphView with standard keys
+ Tracksdata base graph with standard keys
Raises:
ValueError: If data not loaded or validated
@@ -481,7 +481,7 @@ def construct_graph(
default_value = 0
node_default_values.append(default_value)
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=list(self.in_memory_geff["node_props"].keys()),
edge_attributes=list(self.in_memory_geff["edge_props"].keys()),
node_default_values=node_default_values,
@@ -536,23 +536,17 @@ def construct_graph(
if self.TIME_ATTR != "t":
graph.remove_node_attr_key(self.TIME_ATTR)
- # create_empty_graphview_graph returns a filtered view, but that view is
- # a snapshot at filter time; nodes/edges added afterwards (potentially with
- # solution=False) bypass the filter. Re-filter the populated root so
- # solution=False rows are actually excluded.
- graph = graph._root.filter(
- td.NodeAttr("solution") == True, # noqa: E712
- td.EdgeAttr("solution") == True, # noqa: E712
- ).subgraph()
-
+ # Return the full base graph. Tracks builds the solution==True view internally,
+ # so solution=False candidates added during population stay in graph_full and
+ # are excluded from graph_solution by Tracks, not here.
return graph
def handle_segmentation(
self,
- graph: td.graph.GraphView,
+ graph: td.graph.BaseGraph,
segmentation: Path | np.ndarray | None,
scale: list[float] | None,
- ) -> tuple[np.ndarray | None, list[float] | None, td.graph.GraphView]:
+ ) -> tuple[np.ndarray | None, list[float] | None, td.graph.BaseGraph]:
"""Load, validate, and optionally relabel segmentation.
Common logic shared across all formats.
@@ -619,13 +613,13 @@ def handle_segmentation(
return new_segmentation, scale, graph
- # Structural keys that are handled by graph construction / SolutionTracks.__init__
+ # Structural keys that are handled by graph construction / Tracks.__init__
# and should not be registered as features by enable_features().
STRUCTURAL_KEYS = frozenset({"time", "id", "parent_id", "seg_id"})
def enable_features(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
name_map: dict[str, str | list[str]],
feature_type: Literal["node", "edge"] = "node",
) -> None:
@@ -638,7 +632,7 @@ def enable_features(
- Otherwise, if data was loaded for it, register it as a static feature.
Args:
- tracks: SolutionTracks object to add features to
+ tracks: Tracks object to add features to
name_map: Mapping from standard funtracks keys to source property
names (same format as node_name_map / edge_name_map).
feature_type: Type of features ("node" or "edge")
@@ -685,7 +679,7 @@ def build(
scale: list[float] | None = None,
node_name_map: dict[str, str | list[str]] | None = None,
database: str | None = None,
- ) -> SolutionTracks:
+ ) -> Tracks:
"""Orchestrate the full construction process.
Args:
@@ -697,7 +691,7 @@ def build(
If None (default), an in-memory/temp graph is used.
Returns:
- Fully constructed SolutionTracks object
+ Fully constructed Tracks object
Raises:
ValueError: If self.node_name_map is not set or validation fails
@@ -769,22 +763,27 @@ def build(
if segmentation_array is not None:
graph = add_masks_and_bboxes_to_graph(graph, segmentation_array)
- # 7. Create SolutionTracks
+ # 7. Create Tracks
# construct_graph() always stores time as "t" (tracksdata convention),
# regardless of TIME_ATTR, so we pass "t" here explicitly.
# If a FeatureDict was loaded (e.g., from GEFF metadata), use it directly
if hasattr(self, "features") and self.features is not None:
- tracks = SolutionTracks(
+ tracks = Tracks(
graph=graph,
ndim=self.ndim,
scale=scale,
features=self.features,
)
else:
- tracks = SolutionTracks(
+ # The builder always produces a solution, so declare tracklet/lineage
+ # intent to register a TrackAnnotator. Tracks.__init__ auto-detects whether
+ # these attrs already exist on the graph (activate) or need computing.
+ tracks = Tracks(
graph=graph,
pos_attr="pos",
time_attr="t",
+ tracklet_attr="tracklet_id",
+ lineage_attr="lineage_id",
ndim=self.ndim,
scale=scale,
)
diff --git a/src/funtracks/import_export/_utils.py b/src/funtracks/import_export/_utils.py
index 334cf42d..2f8d2feb 100644
--- a/src/funtracks/import_export/_utils.py
+++ b/src/funtracks/import_export/_utils.py
@@ -71,7 +71,7 @@ def infer_dtype_from_array(arr: ArrayLike) -> ValueType:
def filter_graph_with_ancestors(
- graph: td.graph.GraphView, nodes_to_keep: set[int]
+ graph: td.graph.BaseGraph, nodes_to_keep: set[int]
) -> list[int]:
"""Filter a graph to keep only the nodes in `nodes_to_keep` and their ancestors.
diff --git a/src/funtracks/import_export/_v1_format.py b/src/funtracks/import_export/_v1_format.py
index 34856ff1..1bff2c12 100644
--- a/src/funtracks/import_export/_v1_format.py
+++ b/src/funtracks/import_export/_v1_format.py
@@ -16,16 +16,14 @@
)
if TYPE_CHECKING:
- from ..data_model import SolutionTracks, Tracks
+ from ..data_model import Tracks
GRAPH_FILE = "graph.json"
SEG_FILE = "seg.npy"
ATTRS_FILE = "attrs.json"
-def load_v1_tracks(
- directory: Path, seg_required: bool = False, solution: bool = False
-) -> Tracks | SolutionTracks:
+def load_v1_tracks(directory: Path, seg_required: bool = False) -> Tracks:
"""Load a Tracks object from the given directory. Looks for files
in the format generated by Tracks.save.
@@ -35,8 +33,6 @@ def load_v1_tracks(
directory (Path): The directory containing tracks to load
seg_required (bool, optional): If true, raises a FileNotFoundError if the
segmentation file is not present in the directory. Defaults to False.
- solution (bool, optional): If true, returns a SolutionTracks object, otherwise
- returns a normal Tracks object. Defaults to False.
Returns:
Tracks: A tracks object loaded from the given directory
@@ -108,23 +104,24 @@ def load_v1_tracks(
)
features[td.DEFAULT_ATTR_KEYS.BBOX] = SegBbox(ndim)
- # filtering the warnings because the default values of time_attr and pos_attr are
- # not None. Therefore, new style Tracks attrs that have features instead of
- # pos_attr and time_attr will always trigger the warning. Updating default values
- # is breaking, and manually setting the attrs to None if features is present will
- # break if the attrs are changed/removed in the future. Can remove in v2.0.
- # Import at runtime to avoid circular dependency
- from ..data_model import SolutionTracks, Tracks
+ # Defensive: suppress Tracks.__init__'s "provided both FeatureDict and attr"
+ # warning in case a save carries both a stored FeatureDict and explicit
+ # pos/time/tracklet attrs. With a FeatureDict present we pass no attr args, so
+ # this normally does not fire, but keep the guard for forward-compat. Can remove
+ # in v2.0. Import at runtime to avoid circular dependency.
+ from ..data_model import Tracks
+
+ # v1 saves are always solutions, so they must come back with track ids. When the
+ # stored attrs don't carry a FeatureDict (older saves), declare tracklet/lineage
+ # intent explicitly so a TrackAnnotator is registered. Newer saves carry a
+ # FeatureDict that already encodes the tracklet/lineage keys.
+ if "features" not in attrs:
+ attrs.setdefault("tracklet_attr", "tracklet_id")
+ attrs.setdefault("lineage_attr", "lineage_id")
with warnings.catch_warnings():
- warnings.filterwarnings(
- "ignore", message="Provided both FeatureDict and pos_attr or time_attr"
- )
- tracks: Tracks
- if solution:
- tracks = SolutionTracks(graph_td, **attrs)
- else:
- tracks = Tracks(graph_td, **attrs)
+ warnings.filterwarnings("ignore", message="Provided both FeatureDict and pos")
+ tracks = Tracks(graph_td, **attrs)
return tracks
diff --git a/src/funtracks/import_export/_validation.py b/src/funtracks/import_export/_validation.py
index 415cf90c..27cc028c 100644
--- a/src/funtracks/import_export/_validation.py
+++ b/src/funtracks/import_export/_validation.py
@@ -22,7 +22,7 @@
def validate_graph_seg_match(
- graph: td.graph.GraphView,
+ graph: td.graph.BaseGraph,
segmentation: da.Array,
scale: list[float],
position_attr: list[str],
diff --git a/src/funtracks/import_export/csv/_export.py b/src/funtracks/import_export/csv/_export.py
index 6f13c546..f4d42c35 100644
--- a/src/funtracks/import_export/csv/_export.py
+++ b/src/funtracks/import_export/csv/_export.py
@@ -12,14 +12,15 @@
from .._utils import filter_graph_with_ancestors
if TYPE_CHECKING:
- from funtracks.data_model.solution_tracks import SolutionTracks
+ from funtracks.data_model.tracks import Tracks
def export_to_csv(
- tracks: SolutionTracks,
+ tracks: Tracks,
outfile: Path | str,
color_dict: dict[int, np.ndarray] | None = None,
node_ids: set[int] | None = None,
+ export_full: bool = False,
use_display_names: bool = False,
export_seg: bool = False,
seg_path: Path | str | None = None,
@@ -36,12 +37,16 @@ def export_to_csv(
tiff. If a color dictionary is provided, it will also export the tracklet colors.
Args:
- tracks: SolutionTracks object containing the tracking data to export
+ tracks: Tracks object containing the tracking data to export
outfile: Path to output CSV file
color_dict: dict[int, np.ndarray], optional. If provided, will be used to save the
hex colors.
node_ids: Optional set of node IDs to include. If provided, only these
nodes and their ancestors will be included in the output.
+ export_full: If True, export the full graph (every node/edge, including
+ soft-deleted/candidate ones with solution=False); the "solution" column is
+ then included so the two can be distinguished. If False (default), export
+ only the solution view.
use_display_names: If True, use feature display names as column headers.
If False (default), use raw feature keys for backward compatibility.
export_seg: Whether to export the segmentation alongside the CSV.
@@ -70,6 +75,10 @@ def export_to_csv(
"""
tracklet_key = tracks.features.tracklet_key
+ # Which graph to export: the full graph (incl. solution=False candidates) or just
+ # the solution view. Topology reads (node ids, ancestors, predecessors) go through
+ # this; attribute reads use the tracks helpers, which already target graph_full.
+ graph = tracks.graph_full if export_full else tracks.graph_solution
def convert_numpy_to_python(value):
"""Convert numpy types to native Python types."""
@@ -122,8 +131,10 @@ def convert_numpy_to_python(value):
# Skip derived features (e.g. bbox managed by mask)
if feature_name in derived_keys:
continue
- # Skip solution — graph is already filtered to solution=True
- if feature_name == "solution":
+ # Skip solution unless exporting the full graph: in the solution-only
+ # export it is always True (uninformative); in a full export it
+ # distinguishes solution nodes from soft-deleted/candidate ones.
+ if feature_name == "solution" and not export_full:
continue
feature_names.append(feature_name)
num_values = feature_dict.get("num_values", 1)
@@ -155,15 +166,15 @@ def convert_numpy_to_python(value):
# Determine which nodes to export
if node_ids is None:
- nodes_to_keep = tracks.graph.node_ids()
+ nodes_to_keep = graph.node_ids()
else:
- nodes_to_keep = filter_graph_with_ancestors(tracks.graph, node_ids)
+ nodes_to_keep = filter_graph_with_ancestors(graph, node_ids)
# Write CSV file
rows: list[dict[str, Any]] = []
for node_id in nodes_to_keep:
- parents = list(tracks.graph.predecessors(node_id))
+ parents = list(graph.predecessors(node_id))
parent_id = "" if len(parents) == 0 else parents[0]
row: dict[str, Any]
@@ -210,7 +221,7 @@ def rgb_to_hex(rgb):
track_id_to_hex = {}
- for track_id, nodes in tracks.track_id_to_node.items():
+ for track_id, nodes in tracks.track_annotator.tracklet_id_to_nodes.items():
if not nodes:
continue
first_node = nodes[0]
diff --git a/src/funtracks/import_export/csv/_import.py b/src/funtracks/import_export/csv/_import.py
index 983a4efd..7aeae87c 100644
--- a/src/funtracks/import_export/csv/_import.py
+++ b/src/funtracks/import_export/csv/_import.py
@@ -17,7 +17,7 @@
from .._tracks_builder import TracksBuilder, flatten_name_map
if TYPE_CHECKING:
- from funtracks.data_model.solution_tracks import SolutionTracks
+ from funtracks.data_model.tracks import Tracks
def _ensure_integer_ids(df: pd.DataFrame) -> pd.DataFrame:
@@ -169,12 +169,12 @@ def tracks_from_df(
segmentation: np.ndarray | None = None,
scale: list[float] | None = None,
node_name_map: dict[str, str | list[str]] | None = None,
-) -> SolutionTracks:
+) -> Tracks:
"""Import tracks from pandas DataFrame.
Turns a pandas DataFrame with columns:
time, [z], y, x, id, parent_id, [seg_id], [optional custom attr 1], ...
- into a SolutionTracks object.
+ into a Tracks object.
Cells without a parent_id will have an empty string or a -1 for the parent_id.
@@ -195,7 +195,7 @@ def tracks_from_df(
If None, column names are auto-inferred using fuzzy matching.
Returns:
- SolutionTracks: a solution tracks object
+ Tracks: a solution tracks object
Raises:
ValueError: if the segmentation IDs in the dataframe do not match the provided
diff --git a/src/funtracks/import_export/geff/_export.py b/src/funtracks/import_export/geff/_export.py
index ad496bcd..7ef00d1d 100644
--- a/src/funtracks/import_export/geff/_export.py
+++ b/src/funtracks/import_export/geff/_export.py
@@ -34,6 +34,11 @@ def write_to_geff(
geff store directly to *path*. Intended for internal save/load workflows
where the user picks the ``.geff`` path.
+ Note: only ``graph_solution`` is written. Soft-deleted (``solution=False``)
+ candidates are dropped, and reimport marks everything ``solution=True`` — so
+ undo-ability of past deletes does not survive a save/load round-trip (Phase-1
+ design; candidate persistence is deferred to the candidate/solver phase).
+
Args:
tracks: Tracks object containing a graph to save.
path: Destination path for the geff store.
@@ -65,6 +70,9 @@ def export_to_geff(
):
"""Export the Tracks graph to geff.
+ Only the solution graph is exported; soft-deleted (``solution=False``)
+ candidates are not included.
+
Args:
tracks (Tracks): Tracks object containing a graph to save.
directory (Path): Destination directory for saving the Zarr.
@@ -88,7 +96,7 @@ def export_to_geff(
if node_ids is not None:
nodes_to_keep = filter_graph_with_ancestors(
- tracks.graph, node_ids
+ tracks.graph_solution, node_ids
) # include the ancestors to make sure the graph is valid and has no missing
# parent nodes.
@@ -98,7 +106,7 @@ def export_to_geff(
# Include the FeatureDict in metadata only for full exports.
# Subgroup exports do not necessarily have valid tracklet/lineage IDs
- # and thus are not valid SolutionTracks
+ # and thus are not valid Tracks
graph, metadata = _build_geff_metadata(tracks, include_features=(node_ids is None))
# Save segmentation if present and requested
@@ -193,7 +201,7 @@ def _write_segmentation_shape(geff_path: Path, tracks: Tracks) -> None:
This allows import_from_geff to reconstruct the segmentation (GraphArrayView)
without requiring an external segmentation file.
"""
- seg_shape = tracks.graph.metadata.get("segmentation_shape")
+ seg_shape = tracks.graph_full.metadata.get("segmentation_shape")
if seg_shape is not None:
import zarr as _zarr
@@ -222,7 +230,7 @@ def split_position_attr(tracks: Tracks) -> tuple[td.graph.GraphView, list[str] |
if isinstance(pos_key, str):
# Position is stored as a single attribute, need to split
- new_graph = tracks.graph.detach()
+ new_graph = tracks.graph_solution.detach()
new_graph = new_graph.filter().subgraph()
# Register new attribute keys
@@ -254,6 +262,6 @@ def split_position_attr(tracks: Tracks) -> tuple[td.graph.GraphView, list[str] |
return new_graph, new_keys
elif pos_key is not None:
# Position is already split into separate attributes
- return tracks.graph, list(pos_key)
+ return tracks.graph_solution, list(pos_key)
else:
- return tracks.graph, None
+ return tracks.graph_solution, None
diff --git a/src/funtracks/import_export/geff/_import.py b/src/funtracks/import_export/geff/_import.py
index 060a481e..8cd6ce9d 100644
--- a/src/funtracks/import_export/geff/_import.py
+++ b/src/funtracks/import_export/geff/_import.py
@@ -13,7 +13,7 @@
if TYPE_CHECKING:
from pathlib import Path
- from funtracks.data_model.solution_tracks import SolutionTracks
+ from funtracks.data_model.tracks import Tracks
# defining constants here because they are only used in the context of import
@@ -170,7 +170,7 @@ def read_header(self, source_path: Path) -> None:
self._geff_axes = metadata.axes or []
# Read funtracks FeatureDict from GEFF extra metadata if present
- # This will be passed to SolutionTracks via the base build() method
+ # This will be passed to Tracks via the base build() method
if metadata.extra and "funtracks" in metadata.extra:
funtracks_extra = metadata.extra["funtracks"]
if "features" in funtracks_extra:
@@ -243,7 +243,7 @@ def construct_graph(
self,
node_name_map: dict[str, str | list[str]] | None = None,
database: str | None = None,
- ) -> td.graph.GraphView:
+ ) -> td.graph.BaseGraph:
"""Construct graph and prepare embedded segmentation data.
The GEFF format serialises mask data as plain numeric arrays (zarr
@@ -321,7 +321,7 @@ def import_from_geff(
scale: list[float] | None = None,
edge_name_map: dict[str, str | list[str]] | None = None,
database: str | None = None,
-) -> SolutionTracks:
+) -> Tracks:
"""Import tracks from GEFF format.
Args:
@@ -341,7 +341,7 @@ def import_from_geff(
If None (default), an in-memory/temp graph is used.
Returns:
- SolutionTracks object
+ Tracks object
"""
# Filter out None values and "None" strings from node_name_map
# (e.g., {"lineage_id": None} or {"lineage_id": "None"})
diff --git a/src/funtracks/user_actions/_user_swap_predecessors.py b/src/funtracks/user_actions/_user_swap_predecessors.py
index 2353a4f0..baefaaa7 100644
--- a/src/funtracks/user_actions/_user_swap_predecessors.py
+++ b/src/funtracks/user_actions/_user_swap_predecessors.py
@@ -9,7 +9,7 @@
from .user_delete_edge import UserDeleteEdge
if TYPE_CHECKING:
- from funtracks.data_model import SolutionTracks
+ from funtracks.data_model import Tracks
class UserSwapPredecessors(ActionGroup):
@@ -19,7 +19,7 @@ class UserSwapPredecessors(ActionGroup):
must be earlier in time than both nodes for the swap to be valid.
Args:
- tracks (SolutionTracks): The tracks to perform the swap on.
+ tracks (Tracks): The tracks to perform the swap on.
nodes (tuple[Node, Node]): A tuple with two nodes.
Raises:
@@ -30,17 +30,17 @@ class UserSwapPredecessors(ActionGroup):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
nodes: tuple[int, int],
):
super().__init__(tracks, actions=[])
- self.tracks: SolutionTracks # narrow type
+ self.tracks: Tracks # narrow type
if len(nodes) != 2:
raise InvalidActionError("You can only swap a pair of two nodes.")
node1, node2 = nodes
- graph = tracks.graph
+ graph = tracks.graph_solution
# Find predecessors
pred1 = graph.predecessors(node1)[0] if graph.predecessors(node1) else None
diff --git a/src/funtracks/user_actions/user_add_edge.py b/src/funtracks/user_actions/user_add_edge.py
index cf05f002..caf8f3c8 100644
--- a/src/funtracks/user_actions/user_add_edge.py
+++ b/src/funtracks/user_actions/user_add_edge.py
@@ -11,14 +11,14 @@
from .user_delete_edge import UserDeleteEdge
if TYPE_CHECKING:
- from funtracks.data_model import SolutionTracks
+ from funtracks.data_model import Tracks
class UserAddEdge(ActionGroup):
"""Assumes that the endpoints already exist and have track ids.
Args:
- tracks (SolutionTracks): the tracks to add the edge to
+ tracks (Tracks): the tracks to add the edge to
edge (tuple[int, int]): The edge to add
force (bool, optional): Whether to force the action by removing any conflicting
edges. Defaults to False.
@@ -29,26 +29,26 @@ class UserAddEdge(ActionGroup):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
edge: tuple[int, int],
force: bool = False,
_top_level: bool = True,
):
super().__init__(tracks, actions=[])
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
source, target = edge
- if not tracks.graph.has_node(source):
+ if not tracks.graph_solution.has_node(source):
raise InvalidActionError(
f"Source node {source} not in solution yet - must be added before edge"
)
- if not tracks.graph.has_node(target):
+ if not tracks.graph_solution.has_node(target):
raise InvalidActionError(
f"Target node {target} not in solution yet - must be added before edge"
)
# Check if making a merge. If yes and force, remove the other edge and update
# track ids.
- in_degree_target = self.tracks.graph.in_degree(target)
+ in_degree_target = len(self.tracks.predecessors(target)) # type: ignore
if in_degree_target > 0:
if not force:
raise InvalidActionError(
@@ -57,7 +57,7 @@ def __init__(
forceable=True,
)
else:
- pred = next(iter(self.tracks.graph.predecessors(target)))
+ pred = next(iter(self.tracks.graph_solution.predecessors(target)))
merge_edge = (pred, target)
warnings.warn(
f"Removing edge {merge_edge} to add new edge without merging.",
@@ -68,7 +68,7 @@ def __init__(
)
# update track ids if needed
- out_degree_source = self.tracks.graph.out_degree(source)
+ out_degree_source = len(self.tracks.successors(source))
if out_degree_source == 0: # joining two segments
# assign the track id and lineage id of the source node to the target
# and all downstream nodes
@@ -79,7 +79,7 @@ def __init__(
)
elif out_degree_source == 1: # creating a division
# assign a new track id to existing child (lineage stays the same)
- successor = next(iter(self.tracks.graph.successors(source)))
+ successor = next(iter(self.tracks.graph_solution.successors(source)))
self.actions.append(
UpdateTrackIDs(self.tracks, successor, self.tracks.get_next_track_id())
)
diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py
index d1160e7a..af4ad0f9 100644
--- a/src/funtracks/user_actions/user_add_node.py
+++ b/src/funtracks/user_actions/user_add_node.py
@@ -15,7 +15,7 @@
from .user_delete_edge import UserDeleteEdge
if TYPE_CHECKING:
- from funtracks.data_model.solution_tracks import SolutionTracks
+ from funtracks.data_model.tracks import Tracks
class UserAddNode(ActionGroup):
@@ -30,7 +30,7 @@ class UserAddNode(ActionGroup):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
node: int,
attributes: dict[str, Any],
pixels: tuple[np.ndarray, ...] | None = None,
@@ -39,7 +39,7 @@ def __init__(
):
"""
Args:
- tracks (SolutionTracks): the tracks to add the node to
+ tracks (Tracks): the tracks to add the node to
node (int): The node id of the new node to add
attributes (dict[str, Any]): A dictionary from attribute strings to values.
Must contain "time" and tracks.features.tracklet_key.
@@ -63,7 +63,7 @@ def __init__(
time point (forceable).
"""
super().__init__(tracks, actions=[])
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
# Get keys from tracks features
time_key = tracks.features.time_key
@@ -77,7 +77,7 @@ def __init__(
raise InvalidActionError(
f"Cannot add node without track id. Please add {track_id_key} attribute"
)
- if self.tracks.graph.has_node(node):
+ if self.tracks.graph_full.has_node(node):
raise InvalidActionError(
f"Node {node} already exists in the tracks, cannot add."
)
@@ -98,7 +98,7 @@ def __init__(
pred, succ = self.tracks.get_track_neighbors(track_id, time)
# check if you are adding a node to a track that divided previously
- if pred is not None and self.tracks.graph.out_degree(int(pred)) == 2:
+ if pred is not None and len(self.tracks.successors(pred)) == 2:
if not force:
raise InvalidActionError(
"Cannot add node here - upstream division event detected.",
@@ -118,11 +118,11 @@ def __init__(
# downstream
elif succ is not None:
# check pred of succ
- preds = self.tracks.graph.predecessors(succ)
+ preds = self.tracks.predecessors(succ)
pred_of_succ = preds[0] if preds else None
if (
pred_of_succ is not None
- and self.tracks.graph.out_degree(pred_of_succ) == 2
+ and len(self.tracks.successors(pred_of_succ)) == 2
):
if not force:
raise InvalidActionError(
diff --git a/src/funtracks/user_actions/user_delete_edge.py b/src/funtracks/user_actions/user_delete_edge.py
index 8af17c80..098a73ac 100644
--- a/src/funtracks/user_actions/user_delete_edge.py
+++ b/src/funtracks/user_actions/user_delete_edge.py
@@ -9,19 +9,19 @@
from ..actions.update_track_id import UpdateTrackIDs
if TYPE_CHECKING:
- from funtracks.data_model import SolutionTracks
+ from funtracks.data_model import Tracks
class UserDeleteEdge(ActionGroup):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
edge: tuple[int, int],
_top_level: bool = True,
):
"""
Args:
- tracks (SolutionTracks): The tracks to delete the edge from.
+ tracks (Tracks): The tracks to delete the edge from.
edge (tuple[int, int]): The edge to delete.
_top_level (bool): If True, add this action to the history and emit
refresh. Set to False when used as a sub-action inside a compound
@@ -31,12 +31,12 @@ def __init__(
InvalidActionError: If the edge does not exist in the graph.
"""
super().__init__(tracks, actions=[])
- self.tracks: SolutionTracks # Narrow type from base class
- if not self.tracks.graph.has_edge(*edge):
+ self.tracks: Tracks # Narrow type from base class
+ if not self.tracks.graph_solution.has_edge(*edge):
raise InvalidActionError(f"Edge {edge} not in solution, can't remove")
self.actions.append(DeleteEdge(tracks, edge))
- out_degree = self.tracks.graph.out_degree(edge[0])
+ out_degree = len(self.tracks.successors(edge[0]))
if out_degree == 0: # removed a normal (non division) edge
# orphaned segment gets new track id and new lineage id
new_track_id = self.tracks.get_next_track_id()
@@ -46,7 +46,7 @@ def __init__(
)
elif out_degree == 1: # removed a division edge
# sibling gets parent's track id (lineage stays the same)
- sibling = next(iter(self.tracks.graph.successors(edge[0])))
+ sibling = next(iter(self.tracks.graph_solution.successors(edge[0])))
new_track_id = self.tracks.get_track_id(edge[0])
self.actions.append(UpdateTrackIDs(self.tracks, sibling, new_track_id))
# orphaned child gets a new lineage id (now a separate component)
diff --git a/src/funtracks/user_actions/user_delete_node.py b/src/funtracks/user_actions/user_delete_node.py
index ef588560..86416587 100644
--- a/src/funtracks/user_actions/user_delete_node.py
+++ b/src/funtracks/user_actions/user_delete_node.py
@@ -10,20 +10,20 @@
from ..actions.update_track_id import UpdateTrackIDs
if TYPE_CHECKING:
- from funtracks.data_model import SolutionTracks
+ from funtracks.data_model import Tracks
class UserDeleteNode(ActionGroup):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
node: int,
pixels: None | tuple[np.ndarray, ...] = None,
_top_level: bool = True,
):
"""
Args:
- tracks (SolutionTracks): The tracks to delete the node from.
+ tracks (Tracks): The tracks to delete the node from.
node (int): The node id to delete.
pixels (tuple[np.ndarray, ...] | None): The pixels of the node in the
segmentation, if known. Will be computed if not provided.
@@ -33,7 +33,7 @@ def __init__(
action. Defaults to True.
"""
super().__init__(tracks, actions=[])
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
# delete adjacent edges
for pred in self.tracks.predecessors(node):
siblings = self.tracks.successors(pred)
diff --git a/src/funtracks/user_actions/user_delete_nodes.py b/src/funtracks/user_actions/user_delete_nodes.py
index 05f30bb2..279351b1 100644
--- a/src/funtracks/user_actions/user_delete_nodes.py
+++ b/src/funtracks/user_actions/user_delete_nodes.py
@@ -8,7 +8,7 @@
from .user_delete_node import UserDeleteNode
if TYPE_CHECKING:
- from funtracks.data_model import SolutionTracks
+ from funtracks.data_model import Tracks
class UserDeleteNodes(ActionGroup):
@@ -26,12 +26,12 @@ class UserDeleteNodes(ActionGroup):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
nodes: list[int],
pixels: None | list[tuple[np.ndarray, ...]] = None,
):
super().__init__(tracks, actions=[])
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
for i, node in enumerate(nodes):
self.actions.append(
UserDeleteNode(
diff --git a/src/funtracks/user_actions/user_update_node_attrs.py b/src/funtracks/user_actions/user_update_node_attrs.py
index 7cb8cfe2..27616200 100644
--- a/src/funtracks/user_actions/user_update_node_attrs.py
+++ b/src/funtracks/user_actions/user_update_node_attrs.py
@@ -8,7 +8,7 @@
if TYPE_CHECKING:
from typing import Any
- from funtracks.data_model import SolutionTracks
+ from funtracks.data_model import Tracks
class UserUpdateNodeAttrs(ActionGroup):
@@ -21,14 +21,14 @@ class UserUpdateNodeAttrs(ActionGroup):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
node: int,
attrs: dict[str, Any],
_top_level: bool = True,
):
"""
Args:
- tracks (SolutionTracks): The tracks to update the node attributes for
+ tracks (Tracks): The tracks to update the node attributes for
node (int): The node to update the attributes for
attrs (dict[str, Any]): A mapping from attribute name to new attribute
values for the given node.
@@ -40,7 +40,7 @@ def __init__(
ValueError: If a protected attribute is in the given attribute mapping.
"""
super().__init__(tracks, actions=[])
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
# Call the basic UpdateNodeAttrs action
self.actions.append(UpdateNodeAttrs(tracks, node, attrs))
diff --git a/src/funtracks/user_actions/user_update_nodes_attrs.py b/src/funtracks/user_actions/user_update_nodes_attrs.py
index e5cae4fd..40e088ec 100644
--- a/src/funtracks/user_actions/user_update_nodes_attrs.py
+++ b/src/funtracks/user_actions/user_update_nodes_attrs.py
@@ -8,7 +8,7 @@
if TYPE_CHECKING:
from typing import Any
- from funtracks.data_model import SolutionTracks
+ from funtracks.data_model import Tracks
class UserUpdateNodesAttrs(ActionGroup):
@@ -28,12 +28,12 @@ class UserUpdateNodesAttrs(ActionGroup):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
nodes: list[int],
attrs: dict[str, list[Any]],
):
super().__init__(tracks, actions=[])
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
for key, values in attrs.items():
if not isinstance(values, list):
raise ValueError(
diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py
index 6a312265..04941e52 100644
--- a/src/funtracks/user_actions/user_update_segmentation.py
+++ b/src/funtracks/user_actions/user_update_segmentation.py
@@ -12,13 +12,13 @@
from .user_delete_node import UserDeleteNode
if TYPE_CHECKING:
- from funtracks.data_model import SolutionTracks
+ from funtracks.data_model import Tracks
class UserUpdateSegmentation(ActionGroup):
def __init__(
self,
- tracks: SolutionTracks,
+ tracks: Tracks,
new_value: int,
updated_pixels: list[tuple[tuple[np.ndarray, ...], int]],
current_track_id: int,
@@ -30,7 +30,7 @@ def __init__(
add_node action doesn't have anything with pixels.
Args:
- tracks (SolutionTracks): The solution tracks that the user is updating.
+ tracks (Tracks): The solution tracks that the user is updating.
new_value (int): The new value that the user painted with
updated_pixels (list[tuple[tuple[np.ndarray, ...], int]]): A list of node
update actions, consisting of a numpy multi-index, pointing to the array
@@ -42,7 +42,7 @@ def __init__(
Defaults to False.
"""
super().__init__(tracks, actions=[])
- self.tracks: SolutionTracks # Narrow type from base class
+ self.tracks: Tracks # Narrow type from base class
node_to_select = None
if self.tracks.segmentation is None:
raise ValueError("Cannot update non-existing segmentation.")
@@ -62,7 +62,21 @@ def __init__(
"Can only update one time point at a time"
)
time = int(all_pixels[0][0])
- if self.tracks.graph.has_node(new_value):
+ if self.tracks.graph_full.has_node(new_value):
+ # An id that already names a node must take the update path: you cannot
+ # create a *new* node on a taken id. Ids are globally unique across
+ # graph_full (full ⊇ solution, never reused), so graph_full is the
+ # correct "does this id name a node?" check; using graph_solution here
+ # would misroute a soft-deleted id to the add path and silently revive
+ # the old node with stale attributes.
+ # NOTE: reviving a soft-deleted (solution=False) node by painting it
+ # back is not supported yet — UpdateNodeSeg reads the solution view and
+ # would fail. Guard explicitly until revive-by-paint is implemented.
+ if not self.tracks.graph_solution.has_node(new_value):
+ raise NotImplementedError(
+ f"Cannot paint onto node {new_value}: it is soft-deleted (not "
+ "in the solution). Revive-by-paint is not supported yet."
+ )
mask_pixels = pixels_to_td_mask(all_pixels, self.tracks.ndim)
self.actions.append(
UpdateNodeSeg(tracks, new_value, mask_pixels, added=True)
@@ -70,8 +84,6 @@ def __init__(
else:
time_key = tracks.features.time_key
tracklet_key = tracks.features.tracklet_key
- if tracklet_key is None:
- raise ValueError("Track ID key is not set in tracks features")
attrs: dict[str, int] = {
time_key: time,
tracklet_key: current_track_id,
@@ -96,7 +108,7 @@ def __init__(
time = pixels[0][0]
# check if all pixels of old_value are removed
mask_pixels = pixels_to_td_mask(pixels, self.tracks.ndim)
- mask_old_value = self.tracks.graph.nodes[old_value]["mask"]
+ mask_old_value = self.tracks.graph_full.nodes[old_value]["mask"]
# If pixels fully overlaps with old_value mask, delete node
if mask_pixels.intersection(mask_old_value) == mask_old_value.mask.sum():
self.actions.append(
diff --git a/src/funtracks/utils/__init__.py b/src/funtracks/utils/__init__.py
index c0ff3cd2..8c7867f4 100644
--- a/src/funtracks/utils/__init__.py
+++ b/src/funtracks/utils/__init__.py
@@ -10,10 +10,10 @@
setup_zarr_array,
setup_zarr_group,
)
-from .tracksdata_utils import create_empty_graphview_graph
+from .tracksdata_utils import create_empty_graph
__all__ = [
- "create_empty_graphview_graph",
+ "create_empty_graph",
"detect_zarr_spec_version",
"get_store_path",
"is_zarr_v3",
diff --git a/src/funtracks/utils/_segmentation_utils.py b/src/funtracks/utils/_segmentation_utils.py
index 73edf470..4c13e443 100644
--- a/src/funtracks/utils/_segmentation_utils.py
+++ b/src/funtracks/utils/_segmentation_utils.py
@@ -25,7 +25,7 @@ def relabel_segmentation_with_track_id(
# Division nodes have out_degree > 1; their outgoing edges are cut so that
# each daughter cell starts a new tracklet
division_nodes = {
- n for n in solution_graph.node_ids() if solution_graph.out_degree(n) > 1
+ n for n in solution_graph.node_ids() if len(solution_graph.successors(n)) > 1
}
visited: set = set()
diff --git a/src/funtracks/utils/tracksdata_utils.py b/src/funtracks/utils/tracksdata_utils.py
index 9291ec55..3b1f3887 100644
--- a/src/funtracks/utils/tracksdata_utils.py
+++ b/src/funtracks/utils/tracksdata_utils.py
@@ -67,7 +67,7 @@ def to_polars_dtype(dtype_or_value: str | Any) -> pl.DataType:
raise ValueError(f"Unsupported type: {type(dtype_or_value)}")
-def create_empty_graphview_graph(
+def create_empty_graph(
node_attributes: list[str] | None = None,
edge_attributes: list[str] | None = None,
node_default_values: list[Any] | None = None,
@@ -75,9 +75,9 @@ def create_empty_graphview_graph(
database: str | None = None,
position_attrs: list[str] | None = None,
ndim: int = 3,
-) -> td.graph.GraphView:
+) -> td.graph.BaseGraph:
"""
- Create an empty tracksdata GraphView with standard node and edge attributes.
+ Create an empty tracksdata base graph with standard node and edge attributes.
Parameters
----------
node_attributes : list[str] | None
@@ -102,8 +102,9 @@ def create_empty_graphview_graph(
Returns
-------
- td.graph.GraphView
- An empty tracksdata GraphView with standard node and edge attributes.
+ td.graph.BaseGraph
+ An empty tracksdata base graph with standard node and edge attributes
+ (including a `solution` flag). Tracks builds the solution==True view from it.
"""
if position_attrs is None:
position_attrs = ["pos"]
@@ -187,12 +188,8 @@ def create_empty_graphview_graph(
if "solution" not in graph_td.edge_attr_keys():
graph_td.add_edge_attr_key("solution", default_value=True, dtype=pl.Boolean)
- graph_td_sub = graph_td.filter(
- td.NodeAttr("solution") == True, # noqa: E712
- td.EdgeAttr("solution") == True, # noqa: E712
- ).subgraph()
-
- return graph_td_sub
+ # Return the full base graph; Tracks builds the solution==True view internally.
+ return graph_td
def assert_node_attrs_equal_with_masks(
@@ -202,8 +199,8 @@ def assert_node_attrs_equal_with_masks(
Fully compare the content of two graphs (node attributes and Masks)
"""
- if isinstance(object1, td.graph.GraphView) and (
- isinstance(object2, td.graph.GraphView)
+ if isinstance(object1, td.graph.BaseGraph) and (
+ isinstance(object2, td.graph.BaseGraph)
):
node_attrs1 = object1.node_attrs()
node_attrs2 = object2.node_attrs()
@@ -385,21 +382,21 @@ def segmentation_to_masks(
def add_masks_and_bboxes_to_graph(
- graph: td.graph.GraphView,
+ graph: td.graph.BaseGraph,
segmentation: np.ndarray,
-) -> td.graph.GraphView:
+) -> td.graph.BaseGraph:
"""Add mask and bbox attributes to graph nodes from segmentation.
Parameters
----------
- graph : td.graph.GraphView
+ graph : td.graph.BaseGraph
Graph to add attributes to
segmentation : np.ndarray
Segmentation array of shape (T, Z, Y, X) or (T, Y, X)
Returns
-------
- td.graph.GraphView
+ td.graph.BaseGraph
Graph with 'mask' and 'bbox' attributes added to nodes
"""
@@ -480,14 +477,11 @@ def td_relabel_nodes(graph, mapping: dict[int, int]) -> td.graph.IndexedRXGraph:
}
new_graph.add_edge(source_id, target_id, attrs)
- new_graph_sub = new_graph.filter(
- td.NodeAttr("solution") == True, # noqa: E712
- td.EdgeAttr("solution") == True, # noqa: E712
- ).subgraph()
- return new_graph_sub
+ # Return the full base graph; Tracks builds the solution==True view internally.
+ return new_graph
-def convert_graph_nx_to_td(graph_nx: nx.DiGraph) -> td.graph.GraphView:
+def convert_graph_nx_to_td(graph_nx: nx.DiGraph) -> td.graph.BaseGraph:
"""Convert a NetworkX DiGraph to a tracksdata graph.
Args:
@@ -584,10 +578,5 @@ def convert_graph_nx_to_td(graph_nx: nx.DiGraph) -> td.graph.GraphView:
attrs_copy["solution"] = True
graph_td.add_edge(source_id, target_id, attrs_copy)
- # Create subgraph (GraphView) with only solution nodes and edges
- graph_td_sub = graph_td.filter(
- td.NodeAttr("solution") == True, # noqa: E712
- td.EdgeAttr("solution") == True, # noqa: E712
- ).subgraph()
-
- return graph_td_sub
+ # Return the full base graph; Tracks builds the solution==True view internally.
+ return graph_td
diff --git a/tests/actions/test_action_history.py b/tests/actions/test_action_history.py
index 4365f22c..96fd5218 100644
--- a/tests/actions/test_action_history.py
+++ b/tests/actions/test_action_history.py
@@ -1,18 +1,18 @@
from funtracks.actions import AddNode
from funtracks.actions.action_history import ActionHistory
-from funtracks.data_model import SolutionTracks
-from funtracks.utils.tracksdata_utils import create_empty_graphview_graph
+from funtracks.data_model import Tracks
+from funtracks.utils.tracksdata_utils import create_empty_graph
# https://github.com/zaboople/klonk/blob/master/TheGURQ.md
def test_action_history():
history = ActionHistory()
- empty_graph = create_empty_graphview_graph(
+ empty_graph = create_empty_graph(
node_attributes=["track_id", "pos"],
edge_attributes=[],
)
- tracks = SolutionTracks(empty_graph, ndim=3, tracklet_attr="track_id", time_attr="t")
+ tracks = Tracks(empty_graph, ndim=3, tracklet_attr="track_id", time_attr="t")
pos = [0, 1]
action1 = AddNode(tracks, node=0, attributes={"t": 0, "pos": pos, "track_id": 1})
@@ -24,7 +24,7 @@ def test_action_history():
history.add_new_action(action1)
# undo the action
assert history.undo()
- assert tracks.graph.num_nodes() == 0
+ assert tracks.graph_solution.num_nodes() == 0
assert len(history.undo_stack) == 1
assert len(history.redo_stack) == 1
assert history._undo_pointer == -1
@@ -34,7 +34,7 @@ def test_action_history():
# redo the action
assert history.redo()
- assert tracks.graph.num_nodes() == 1
+ assert tracks.graph_solution.num_nodes() == 1
assert len(history.undo_stack) == 1
assert len(history.redo_stack) == 0
assert history._undo_pointer == 0
@@ -46,7 +46,7 @@ def test_action_history():
assert history.undo()
action2 = AddNode(tracks, node=10, attributes={"t": 10, "pos": pos, "track_id": 2})
history.add_new_action(action2)
- assert tracks.graph.num_nodes() == 1
+ assert tracks.graph_solution.num_nodes() == 1
# there are 3 things on the stack: action1, action1's inverse, and action 2
assert len(history.undo_stack) == 3
assert len(history.redo_stack) == 0
@@ -55,7 +55,7 @@ def test_action_history():
# undo back to after action 1
assert history.undo()
assert history.undo()
- assert tracks.graph.num_nodes() == 1
+ assert tracks.graph_solution.num_nodes() == 1
assert len(history.undo_stack) == 3
assert len(history.redo_stack) == 2
diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py
index de856974..13fb1213 100644
--- a/tests/actions/test_add_delete_edge.py
+++ b/tests/actions/test_add_delete_edge.py
@@ -8,9 +8,9 @@
AddEdge,
DeleteEdge,
)
-from funtracks.data_model import SolutionTracks
+from funtracks.data_model import Tracks
from funtracks.features import FeatureDict, LineageID, Position, Time, TrackletID
-from funtracks.utils.tracksdata_utils import create_empty_graphview_graph
+from funtracks.utils.tracksdata_utils import create_empty_graph
iou_key = "iou"
@@ -18,26 +18,26 @@
@pytest.mark.parametrize("ndim", [3, 4])
@pytest.mark.parametrize("with_seg", [True, False])
def test_add_delete_edges(get_tracks, ndim, with_seg):
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
- reference_graph = tracks.graph
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
+ reference_graph = tracks.graph_solution
reference_seg = np.asarray(tracks.segmentation).copy()
# Create an empty tracks with just nodes (no edges)
- for edge in tracks.graph.edge_list():
- tracks.graph.remove_edge(*edge)
+ for edge in tracks.graph_solution.edge_list():
+ tracks.graph_solution.remove_edge(*edge)
edges = [(1, 2), (1, 3), (3, 4), (4, 5)]
action = ActionGroup(tracks=tracks, actions=[AddEdge(tracks, edge) for edge in edges])
- with pytest.raises(ValueError, match="Edge .* already exists in the graph"):
+ with pytest.raises(ValueError, match="Edge .* already exists in the solution"):
AddEdge(tracks, (1, 2))
# TODO: What if adding an edge that already exists?
# TODO: test all the edge cases, invalid operations, etc. for all actions
- assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids())
+ assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids())
assert_frame_equal(
- tracks.graph.edge_attrs(),
+ tracks.graph_solution.edge_attrs(),
reference_graph.edge_attrs(),
check_row_order=False,
check_column_order=False,
@@ -47,16 +47,18 @@ def test_add_delete_edges(get_tracks, ndim, with_seg):
inverse = action.inverse()
- assert set(tracks.graph.edge_ids()) == set()
+ assert set(tracks.graph_solution.edge_ids()) == set()
if tracks.segmentation is not None:
assert_array_almost_equal(tracks.segmentation, reference_seg)
re_added = inverse.inverse()
- assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids())
- assert set(tracks.graph.edge_ids()) == set(reference_graph.edge_ids())
- assert sorted(tracks.graph.edge_list()) == sorted(reference_graph.edge_list())
+ assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids())
+ assert set(tracks.graph_solution.edge_ids()) == set(reference_graph.edge_ids())
+ assert sorted(tracks.graph_solution.edge_list()) == sorted(
+ reference_graph.edge_list()
+ )
assert_frame_equal(
- tracks.graph.edge_attrs(),
+ tracks.graph_solution.edge_attrs(),
reference_graph.edge_attrs(),
check_row_order=False,
check_column_order=False,
@@ -74,17 +76,19 @@ def test_add_delete_edges(get_tracks, ndim, with_seg):
# objects again — that's where the corruption surfaces.
re_added.inverse() # reset state: edges absent (fresh objects, no bug here)
inverse.inverse() # same DeleteEdge objects called again — must not crash
- assert set(tracks.graph.edge_ids()) == set(reference_graph.edge_ids())
+ assert set(tracks.graph_solution.edge_ids()) == set(reference_graph.edge_ids())
def test_add_edge_missing_endpoint(get_tracks):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
- with pytest.raises(ValueError, match="Cannot add edge .*: endpoint .* not in graph"):
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
+ with pytest.raises(
+ ValueError, match="Cannot add edge .*: endpoint .* not in solution"
+ ):
AddEdge(tracks, (10, 11))
def test_delete_missing_edge(get_tracks):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
with pytest.raises(
ValueError, match="Edge .* not in the graph, and cannot be removed"
):
@@ -97,7 +101,7 @@ def test_custom_edge_attributes_preserved(get_tracks, ndim, with_seg):
"""Test custom edge attributes preserved through add/delete/re-add cycles."""
from funtracks.features import Feature
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# Register custom edge features so they get saved by DeleteEdge
custom_features = {
@@ -138,29 +142,93 @@ def test_custom_edge_attributes_preserved(get_tracks, ndim, with_seg):
action = AddEdge(tracks, edge, attributes=custom_attrs)
# Verify all attributes are present after adding
- assert tracks.graph.has_edge(*edge)
+ assert tracks.graph_solution.has_edge(*edge)
for key, value in custom_attrs.items():
- edge_id = tracks.graph.edge_id(*edge)
- assert tracks.graph.edges[edge_id][key] == value, (
+ edge_id = tracks.graph_solution.edge_id(*edge)
+ assert tracks.graph_solution.edges[edge_id][key] == value, (
f"Attribute {key} not set correctly on edge"
)
# Delete the edge
delete_action = action.inverse()
- assert not tracks.graph.has_edge(*edge)
+ assert not tracks.graph_solution.has_edge(*edge)
# Re-add the edge by inverting the delete
delete_action.inverse()
- assert tracks.graph.has_edge(*edge)
+ assert tracks.graph_solution.has_edge(*edge)
# Verify all custom attributes are still present after re-adding
for key, value in custom_attrs.items():
- edge_id = tracks.graph.edge_id(*edge)
- assert tracks.graph.edges[edge_id][key] == value, (
+ edge_id = tracks.graph_solution.edge_id(*edge)
+ assert tracks.graph_solution.edges[edge_id][key] == value, (
f"Attribute {key} not preserved after delete/re-add cycle"
)
+def test_add_edge_revive_applies_new_attributes(get_tracks):
+ """Re-adding a soft-deleted edge with NEW attributes must apply them, not silently
+ keep the values preserved from before the delete.
+
+ Existing tests only re-add via inverse(), which reuses the SAME attrs that
+ soft-delete preserved, so the revive-vs-add-new asymmetry is invisible to them.
+ """
+ from funtracks.features import Feature
+
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
+ tracks.add_feature(
+ "weight",
+ Feature(
+ feature_type="edge",
+ value_type="float",
+ num_values=1,
+ display_name="Weight",
+ default_value=None,
+ ),
+ )
+
+ edge = (1, 5)
+ AddEdge(tracks, edge, attributes={"weight": 1.5})
+ DeleteEdge(tracks, edge) # soft-delete: weight=1.5 preserved in graph_full
+
+ # Fresh add of the (now soft-deleted) edge with a DIFFERENT weight.
+ AddEdge(tracks, edge, attributes={"weight": 9.9})
+
+ edge_id = tracks.graph_solution.edge_id(*edge)
+ assert tracks.graph_solution.edges[edge_id]["weight"] == 9.9
+
+
+def test_add_edge_revive_applies_new_vector_attributes(get_tracks):
+ """Reviving an edge with a non-scalar (vector) attribute must apply it correctly.
+
+ update_edge_attrs reads a bare list value as one-value-per-edge, so passing a
+ vector attr unwrapped for a single edge raises a size mismatch (or, for a
+ length-1 list, silently unwraps it to its element).
+ """
+ from funtracks.features import Feature
+
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
+ tracks.add_feature(
+ "flow",
+ Feature(
+ feature_type="edge",
+ value_type="float",
+ num_values=2,
+ display_name=["Flow Y", "Flow X"],
+ default_value=None,
+ ),
+ )
+
+ edge = (1, 5)
+ AddEdge(tracks, edge, attributes={"flow": [1.0, 2.0]})
+ DeleteEdge(tracks, edge) # soft-delete: flow preserved in graph_full
+
+ # Fresh add of the (now soft-deleted) edge with a DIFFERENT flow vector.
+ AddEdge(tracks, edge, attributes={"flow": [3.0, 4.0]})
+
+ edge_id = tracks.graph_solution.edge_id(*edge)
+ assert list(tracks.graph_solution.edges[edge_id]["flow"]) == [3.0, 4.0]
+
+
def test_add_edge_with_unregistered_edge_attr(tmp_path):
"""AddEdge must not crash when the graph has edge attrs absent from tracks.features.
@@ -175,7 +243,7 @@ def test_add_edge_with_unregistered_edge_attr(tmp_path):
# Build a graph with "custom_score" on every edge.
# This mirrors what the motile solver does: it writes edge attributes directly
# to the graph without going through tracks.add_feature().
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=["pos", "track_id", "lineage_id"],
edge_attributes=["custom_score"],
database=db_path,
@@ -203,7 +271,7 @@ def test_add_edge_with_unregistered_edge_attr(tmp_path):
indices=[1, 2],
)
- # Wrap in SolutionTracks without registering "custom_score" in features —
+ # Wrap in Tracks without registering "custom_score" in features —
# this is the scenario that triggers the bug.
features = FeatureDict(
features={
@@ -217,13 +285,13 @@ def test_add_edge_with_unregistered_edge_attr(tmp_path):
tracklet_key="track_id",
lineage_key="lineage_id",
)
- tracks = SolutionTracks(graph, ndim=3, features=features)
+ tracks = Tracks(graph, ndim=3, features=features)
# Sanity: "custom_score" is in the graph schema but NOT in tracks.features.
- assert "custom_score" in tracks.graph.edge_attr_keys()
+ assert "custom_score" in tracks.graph_solution.edge_attr_keys()
assert "custom_score" not in tracks.features
# Before the fix this raises: KeyError: 'custom_score'
AddEdge(tracks, (1, 2))
- assert tracks.graph.has_edge(1, 2)
+ assert tracks.graph_solution.has_edge(1, 2)
diff --git a/tests/actions/test_add_delete_nodes.py b/tests/actions/test_add_delete_nodes.py
index 9d6b99c4..b716355c 100644
--- a/tests/actions/test_add_delete_nodes.py
+++ b/tests/actions/test_add_delete_nodes.py
@@ -1,5 +1,6 @@
import numpy as np
import pytest
+import tracksdata as td
from numpy.testing import assert_array_almost_equal, assert_array_equal
from polars.testing import assert_frame_equal
from tracksdata.array import GraphArrayView
@@ -10,7 +11,7 @@
)
from funtracks.utils.tracksdata_utils import (
assert_node_attrs_equal_with_masks,
- create_empty_graphview_graph,
+ create_empty_graph,
)
from ..conftest import make_2d_disk_mask, make_3d_sphere_mask
@@ -20,8 +21,8 @@
@pytest.mark.parametrize("with_seg", [True, False])
def test_add_delete_nodes(get_tracks, ndim, with_seg):
# Get a tracks instance
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
- reference_graph = tracks.graph
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
+ reference_graph = tracks.graph_solution
reference_seg = np.asarray(tracks.segmentation).copy() if with_seg else None
# Start with an empty Tracks
@@ -32,17 +33,26 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg):
tracks.features.position_key,
]
edge_attributes = ["iou"] if with_seg else []
- empty_graph = create_empty_graphview_graph(
+ empty_graph = create_empty_graph(
node_attributes=node_attributes + (["area", "bbox", "mask"] if with_seg else []),
edge_attributes=edge_attributes,
ndim=ndim,
)
empty_seg = np.zeros_like(tracks.segmentation) if with_seg else None
- tracks.graph = empty_graph
+ # Reset the tracks onto the empty base graph, mirroring Tracks.__init__: graph_full
+ # is the base graph, graph_solution its solution==True view.
+ tracks.graph_full = empty_graph
+ tracks.graph_solution = empty_graph.filter(
+ td.NodeAttr("solution") == True, # noqa: E712
+ td.EdgeAttr("solution") == True, # noqa: E712
+ ).subgraph()
segmentation_shape = (5, 100, 100) if ndim == 3 else (5, 100, 100, 100)
tracks.segmentation = (
GraphArrayView(
- graph=tracks.graph, shape=segmentation_shape, attr_key="node_id", offset=0
+ graph=tracks.graph_solution,
+ shape=segmentation_shape,
+ attr_key="node_id",
+ offset=0,
)
if with_seg
else None
@@ -78,8 +88,8 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg):
actions.append(AddNode(tracks, node, attributes=attrs))
action = ActionGroup(tracks=tracks, actions=actions)
- assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids())
- data_tracks = tracks.graph.node_attrs()
+ assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids())
+ data_tracks = tracks.graph_solution.node_attrs()
data_reference = reference_graph.node_attrs()
if with_seg:
assert_array_almost_equal(tracks.segmentation, reference_seg)
@@ -93,16 +103,17 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg):
check_dtypes=False,
)
- # Invert the action to delete all the nodes
+ # Invert the action to delete all the nodes. They are soft-deleted, so they remain
+ # in graph_full (empty_graph) with solution=False but drop out of the solution view.
del_nodes = action.inverse()
- assert set(tracks.graph.node_ids()) == set(empty_graph.node_ids())
+ assert set(tracks.graph_solution.node_ids()) == set()
if with_seg:
assert_array_almost_equal(tracks.segmentation, empty_seg)
# Re-invert the action to add back all the nodes and their attributes
del_nodes.inverse()
- assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids())
- data_tracks = tracks.graph.node_attrs()
+ assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids())
+ data_tracks = tracks.graph_solution.node_attrs()
data_reference = reference_graph.node_attrs()
if with_seg:
assert_array_almost_equal(tracks.segmentation, reference_seg)
@@ -117,20 +128,45 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg):
)
+def test_add_node_revive_applies_new_attributes(get_tracks):
+ """Re-adding a soft-deleted node with NEW attributes must apply them, not silently
+ keep the values preserved from before the delete.
+
+ Existing tests only re-add via inverse(), which reuses the SAME attrs that
+ soft-delete preserved, so the revive-vs-add-new asymmetry is invisible to them.
+ Mirrors test_add_edge_revive_applies_new_attributes.
+ """
+ from funtracks.actions import DeleteNode
+
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
+
+ node_id = 100
+ AddNode(tracks, node_id, {"t": 2, "track_id": 10, "pos": [50.0, 50.0]})
+ DeleteNode(tracks, node_id) # soft-delete: attrs preserved in graph_full
+
+ # Fresh add of the (now soft-deleted) node with DIFFERENT attributes.
+ AddNode(tracks, node_id, {"t": 2, "track_id": 11, "pos": [60.0, 60.0]})
+
+ assert tracks.graph_solution.nodes[node_id]["track_id"] == 11
+ assert_array_almost_equal(
+ tracks.graph_solution.nodes[node_id]["pos"], np.array([60.0, 60.0])
+ )
+
+
def test_add_node_missing_time(get_tracks):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
with pytest.raises(ValueError, match="Must provide a time attribute for node"):
AddNode(tracks, 8, {})
def test_add_node_missing_pos(get_tracks):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
# First test: missing track_id raises an error
with pytest.raises(ValueError, match="Must provide a track_id attribute for node"):
AddNode(tracks, 8, {"t": 2})
# Second test: with track_id but without segmentation, missing pos raises an error
- tracks_no_seg = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks_no_seg = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
with pytest.raises(
ValueError, match="Must provide position or segmentation for node"
):
@@ -143,7 +179,7 @@ def test_custom_attributes_preserved(get_tracks, ndim, with_seg):
"""Test custom node attributes preserved through add/delete/re-add cycles."""
from funtracks.features import Feature
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# Register custom features so they get saved by DeleteNode
custom_features = {
@@ -197,36 +233,44 @@ def test_custom_attributes_preserved(get_tracks, ndim, with_seg):
node_id = 100
action = AddNode(tracks, node_id, custom_attrs.copy())
# Verify all attributes are present after adding
- assert tracks.graph.has_node(node_id)
+ assert tracks.graph_solution.has_node(node_id)
for key, value in custom_attrs.items():
if key == "pos":
- assert_array_almost_equal(tracks.graph.nodes[node_id][key], np.array(value))
+ assert_array_almost_equal(
+ tracks.graph_solution.nodes[node_id][key], np.array(value)
+ )
elif key == "mask":
continue
elif key == "bbox":
- assert_array_equal(np.asarray(tracks.graph.nodes[node_id][key]), value)
+ assert_array_equal(
+ np.asarray(tracks.graph_solution.nodes[node_id][key]), value
+ )
else:
- assert tracks.graph.nodes[node_id][key] == value, (
+ assert tracks.graph_solution.nodes[node_id][key] == value, (
f"Attribute {key} not preserved after add"
)
# Delete the node
delete_action = action.inverse()
- assert node_id not in tracks.graph.node_ids()
+ assert node_id not in tracks.graph_solution.node_ids()
# Re-add the node by inverting the delete
delete_action.inverse()
- assert node_id in tracks.graph.node_ids()
+ assert node_id in tracks.graph_solution.node_ids()
# Verify all custom attributes are still present after re-adding
for key, value in custom_attrs.items():
if key == "pos":
- assert_array_almost_equal(tracks.graph.nodes[node_id][key], np.array(value))
+ assert_array_almost_equal(
+ tracks.graph_solution.nodes[node_id][key], np.array(value)
+ )
elif key == "mask":
continue
elif key == "bbox":
- assert_array_equal(np.asarray(tracks.graph.nodes[node_id][key]), value)
+ assert_array_equal(
+ np.asarray(tracks.graph_solution.nodes[node_id][key]), value
+ )
else:
- assert tracks.graph.nodes[node_id][key] == value, (
+ assert tracks.graph_solution.nodes[node_id][key] == value, (
f"Attribute {key} not preserved after delete/re-add cycle"
)
diff --git a/tests/actions/test_base_action.py b/tests/actions/test_base_action.py
index 9e4c0047..2b9e3f56 100644
--- a/tests/actions/test_base_action.py
+++ b/tests/actions/test_base_action.py
@@ -6,7 +6,7 @@
def test_initialize_base_class(get_tracks):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
action = Action(tracks)
with pytest.raises(NotImplementedError):
action.inverse()
diff --git a/tests/actions/test_update_node_attrs.py b/tests/actions/test_update_node_attrs.py
index f027ead1..5171e9b4 100644
--- a/tests/actions/test_update_node_attrs.py
+++ b/tests/actions/test_update_node_attrs.py
@@ -8,7 +8,7 @@
@pytest.mark.parametrize("ndim", [3, 4])
def test_update_node_attrs(get_tracks, ndim):
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
node = 1
new_feature = Feature(
@@ -32,6 +32,6 @@ def test_update_node_attrs(get_tracks, ndim):
@pytest.mark.parametrize("attr", ["t", "area", "track_id"])
def test_update_protected_attr(get_tracks, attr):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
with pytest.raises(ValueError, match="Cannot update attribute .* manually"):
UpdateNodeAttrs(tracks, 1, {attr: 2})
diff --git a/tests/actions/test_update_node_segs.py b/tests/actions/test_update_node_segs.py
index 73a98af4..7115d019 100644
--- a/tests/actions/test_update_node_segs.py
+++ b/tests/actions/test_update_node_segs.py
@@ -10,15 +10,15 @@
@pytest.mark.parametrize("ndim", [3, 4])
def test_update_node_segs(get_tracks, ndim):
# Get tracks with segmentation
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
- reference_graph = tracks.graph.detach().filter().subgraph()
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
+ reference_graph = tracks.graph_solution.detach().filter().subgraph()
node = 1
time = tracks.get_time(node)
original_seg = np.asarray(tracks.segmentation).copy()
- original_area = tracks.graph.nodes[1]["area"]
- original_pos = tracks.graph.nodes[1]["pos"]
+ original_area = tracks.graph_solution.nodes[1]["area"]
+ original_pos = tracks.graph_solution.nodes[1]["pos"]
# Add a couple pixels to the first node
new_seg = np.asarray(tracks.segmentation).copy()
@@ -31,22 +31,22 @@ def test_update_node_segs(get_tracks, ndim):
action = UpdateNodeSeg(tracks, node, mask=mask, added=True)
- assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids())
- assert tracks.graph.nodes[1]["area"] == original_area + 1
- assert not np.allclose(tracks.graph.nodes[1]["pos"], original_pos)
+ assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids())
+ assert tracks.graph_solution.nodes[1]["area"] == original_area + 1
+ assert not np.allclose(tracks.graph_solution.nodes[1]["pos"], original_pos)
assert_array_almost_equal(tracks.segmentation, new_seg)
inverse = action.inverse()
- assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids())
+ assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids())
assert_series_equal(
reference_graph.nodes[1]["pos"],
- tracks.graph.nodes[1]["pos"],
+ tracks.graph_solution.nodes[1]["pos"],
)
assert_array_almost_equal(tracks.segmentation, original_seg)
inverse.inverse()
- assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids())
- assert tracks.graph.nodes[1]["area"] == original_area + 1
- assert not np.allclose(tracks.graph.nodes[1]["pos"], original_pos)
+ assert set(tracks.graph_solution.node_ids()) == set(reference_graph.node_ids())
+ assert tracks.graph_solution.nodes[1]["area"] == original_area + 1
+ assert not np.allclose(tracks.graph_solution.nodes[1]["pos"], original_pos)
assert_array_almost_equal(tracks.segmentation, new_seg)
diff --git a/tests/annotators/test_annotator_registry.py b/tests/annotators/test_annotator_registry.py
index 6a7e8e3d..30996ff3 100644
--- a/tests/annotators/test_annotator_registry.py
+++ b/tests/annotators/test_annotator_registry.py
@@ -1,9 +1,12 @@
import pytest
from funtracks.annotators import EdgeAnnotator, RegionpropsAnnotator, TrackAnnotator
-from funtracks.data_model import SolutionTracks, Tracks
+from funtracks.data_model import Tracks
track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"}
+# Even without an explicit tracklet_attr, every Tracks gets a TrackAnnotator and a
+# default tracklet key (no "plain" vs "solution" split).
+plain_attrs = {"time_attr": "t"}
def test_annotator_registry_init_with_segmentation(
@@ -14,31 +17,32 @@ def test_annotator_registry_init_with_segmentation(
tracks = Tracks(
graph_2d_with_segmentation,
ndim=3,
- **track_attrs,
+ **plain_attrs,
)
annotator_types = [type(ann) for ann in tracks.annotators]
assert RegionpropsAnnotator in annotator_types
assert EdgeAnnotator in annotator_types
- assert TrackAnnotator not in annotator_types # Not a SolutionTracks
+ assert TrackAnnotator in annotator_types # every Tracks has track ids
def test_annotator_registry_init_without_segmentation(graph_2d_with_position):
- """Test AnnotatorRegistry doesn't create annotators without segmentation."""
- tracks = Tracks(graph_2d_with_position, ndim=3, **track_attrs)
+ """Without segmentation: no regionprops/edge annotators, but a TrackAnnotator is
+ still registered (track ids are a core feature of every Tracks)."""
+ tracks = Tracks(graph_2d_with_position, ndim=3, **plain_attrs)
annotator_types = [type(ann) for ann in tracks.annotators]
assert RegionpropsAnnotator not in annotator_types
assert EdgeAnnotator not in annotator_types
- assert TrackAnnotator not in annotator_types
+ assert TrackAnnotator in annotator_types
def test_annotator_registry_init_solution_tracks(
graph_2d_with_segmentation,
):
- """Test AnnotatorRegistry creates all annotators for SolutionTracks with
+ """Test AnnotatorRegistry creates all annotators for Tracks with
segmentation."""
- tracks = SolutionTracks(
+ tracks = Tracks(
graph_2d_with_segmentation,
ndim=3,
**track_attrs,
@@ -57,13 +61,13 @@ def test_enable_disable_features(graph_2d_with_segmentation):
**track_attrs,
)
- nodes = list(tracks.graph.node_ids())
- edges = list(tracks.graph.edge_ids())
+ nodes = list(tracks.graph_solution.node_ids())
+ edges = list(tracks.graph_solution.edge_ids())
# Core features (time, pos) should be in tracks.features and computed
assert "pos" in tracks.features
assert "t" in tracks.features
- assert tracks.graph.nodes[nodes[0]]["pos"] is not None
+ assert tracks.graph_solution.nodes[nodes[0]]["pos"] is not None
# area and other features should NOT be in tracks.features initially
assert "area" not in tracks.features
@@ -78,9 +82,9 @@ def test_enable_disable_features(graph_2d_with_segmentation):
assert "circularity" in tracks.features
# Verify values are actually computed on the graph
- assert tracks.graph.nodes[nodes[0]]["circularity"] is not None
+ assert tracks.graph_solution.nodes[nodes[0]]["circularity"] is not None
if edges:
- assert None not in tracks.graph.edge_attrs()["iou"].to_list()
+ assert None not in tracks.graph_solution.edge_attrs()["iou"].to_list()
# Disable one feature
tracks.disable_features(["area"])
@@ -92,7 +96,7 @@ def test_enable_disable_features(graph_2d_with_segmentation):
assert "circularity" in tracks.features
# Values no longer exist in the graph for tracksdata
- # assert tracks.graph.nodes[1]["area"] is not None
+ # assert tracks.graph_solution.nodes[1]["area"] is not None
# Disable the remaining enabled features
tracks.disable_features(["pos", "iou", "circularity"])
@@ -122,7 +126,7 @@ def test_area_on_graph_not_auto_activated(graph_2d_with_segmentation):
def test_get_available_features(graph_2d_with_segmentation):
"""Test get_available_features returns all features from all annotators."""
- tracks = SolutionTracks(
+ tracks = Tracks(
graph_2d_with_segmentation,
ndim=3,
**track_attrs,
diff --git a/tests/annotators/test_edge_annotator.py b/tests/annotators/test_edge_annotator.py
index cbef21c7..49fc0d8a 100644
--- a/tests/annotators/test_edge_annotator.py
+++ b/tests/annotators/test_edge_annotator.py
@@ -4,7 +4,7 @@
from funtracks.actions import UpdateNodeSeg, UpdateTrackIDs
from funtracks.annotators import EdgeAnnotator
-from funtracks.data_model import SolutionTracks, Tracks
+from funtracks.data_model import Tracks
track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"}
@@ -42,7 +42,7 @@ def test_compute_all(self, get_graph, ndim):
# Compute values
ann.compute()
for key in all_features:
- assert key in tracks.graph.edge_attr_keys()
+ assert key in tracks.graph_solution.edge_attr_keys()
def test_update_all(self, get_graph, ndim) -> None:
graph = get_graph(ndim, with_seg=True)
@@ -78,7 +78,10 @@ def test_update_all(self, get_graph, ndim) -> None:
with pytest.warns(match="Cannot find label 1 in frame .*"):
UpdateNodeSeg(tracks, node_id, mask, added=False)
- assert tracks.graph.edges[tracks.graph.edge_id(*edge_id)]["iou"] == 0
+ assert (
+ tracks.graph_solution.edges[tracks.graph_solution.edge_id(*edge_id)]["iou"]
+ == 0
+ )
def test_add_remove_feature(self, get_graph, ndim):
graph = get_graph(ndim, with_seg=True)
@@ -130,8 +133,8 @@ def test_missing_seg(self, get_graph, ndim) -> None:
def test_ignores_irrelevant_actions(self, get_graph, ndim):
"""Test that EdgeAnnotator ignores actions that don't affect edges."""
- graph = get_graph(ndim, is_solution=True, with_seg=True)
- tracks = SolutionTracks(
+ graph = get_graph(ndim, prefill_track_ids=True, with_seg=True)
+ tracks = Tracks(
graph,
ndim=ndim,
**track_attrs,
@@ -140,7 +143,7 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim):
node_id = 3
edge = (1, 3)
- edge_id = tracks.graph.edge_id(*edge)
+ edge_id = tracks.graph_solution.edge_id(*edge)
initial_iou = tracks.get_edge_attr(edge, "iou")
# If we recomputed IoU now, it would be different
@@ -155,6 +158,6 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim):
UpdateTrackIDs(tracks, node_id, new_track_id)
# IoU should remain unchanged (no recomputation happened despite seg change)
- assert tracks.graph.edges[edge_id]["iou"] == initial_iou
+ assert tracks.graph_solution.edges[edge_id]["iou"] == initial_iou
# But track_id should be updated
assert tracks.get_track_id(node_id) == new_track_id
diff --git a/tests/annotators/test_features_on_full_graph.py b/tests/annotators/test_features_on_full_graph.py
new file mode 100644
index 00000000..49321289
--- /dev/null
+++ b/tests/annotators/test_features_on_full_graph.py
@@ -0,0 +1,69 @@
+"""Step 5 regression: detection features (regionprops, iou) live on graph_full.
+
+These features are intrinsic to a detection/link and must stay computed for *all*
+nodes/edges — including soft-deleted (solution=False) candidates — so the full and
+solution graphs never drift and candidates are ready for re-solving. Track-id features
+remain solution-only and are covered elsewhere.
+"""
+
+import pytest
+
+from funtracks.actions import DeleteEdge, DeleteNode
+from funtracks.annotators import EdgeAnnotator, RegionpropsAnnotator
+
+
+def _annotator(tracks, cls):
+ return next(ann for ann in tracks.annotators if isinstance(ann, cls))
+
+
+@pytest.mark.parametrize("ndim", [3, 4])
+def test_regionprops_persist_and_recompute_on_soft_deleted_node(get_tracks, ndim):
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
+ rp_ann = _annotator(tracks, RegionpropsAnnotator)
+ tracks.enable_features(["pos", "area"])
+ rp_ann.compute(["area"])
+
+ node = 5 # leaf node of the fixture graph
+ area_before = tracks.get_node_attr(node, "area")
+ assert area_before is not None
+
+ # Soft-delete: leaves the solution view, stays in graph_full as solution=False.
+ DeleteNode(tracks, node)
+ assert node not in tracks.graph_solution.node_ids()
+ assert node in tracks.graph_full.node_ids()
+ assert tracks.graph_full.nodes[node]["solution"] is False
+
+ # NEW: the detection feature is still readable (helpers read graph_full, so this no
+ # longer KeyErrors on an out-of-solution node) and its value is preserved.
+ assert tracks.get_node_attr(node, "area") == area_before
+
+ # NEW: a bulk recompute covers the soft-deleted candidate (iterates graph_full).
+ tracks._set_node_attr(node, "area", None) # wipe to prove recompute reaches it
+ rp_ann.compute(["area"])
+ assert tracks.get_node_attr(node, "area") == area_before
+
+
+@pytest.mark.parametrize("ndim", [3, 4])
+def test_iou_computed_on_soft_deleted_candidate_edge(get_tracks, ndim):
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
+ edge_ann = _annotator(tracks, EdgeAnnotator)
+ tracks.enable_features(["iou"])
+ edge_ann.compute(["iou"])
+
+ edge = (4, 5)
+ iou_before = tracks.get_edge_attr(edge, "iou")
+ assert iou_before is not None
+
+ # Soft-delete the edge: gone from the solution, kept in graph_full as a candidate.
+ DeleteEdge(tracks, edge)
+ assert not tracks.graph_solution.has_edge(*edge)
+ assert tracks.graph_full.has_edge(*edge)
+
+ # NEW: iou is still readable on the candidate edge (get_edge_attr reads graph_full).
+ assert tracks.get_edge_attr(edge, "iou") == iou_before
+
+ # NEW: a bulk recompute reaches the solution=False edge (compute iterates graph_full
+ # successors, which include candidate edges).
+ tracks._set_edge_attr(edge, "iou", 0.0) # wipe to prove recompute reaches it
+ edge_ann.compute(["iou"])
+ assert tracks.get_edge_attr(edge, "iou") == iou_before
diff --git a/tests/annotators/test_regionprops_annotator.py b/tests/annotators/test_regionprops_annotator.py
index 3193dd06..7ee98ee3 100644
--- a/tests/annotators/test_regionprops_annotator.py
+++ b/tests/annotators/test_regionprops_annotator.py
@@ -4,7 +4,7 @@
from funtracks.actions import UpdateNodeSeg, UpdateTrackIDs
from funtracks.annotators import RegionpropsAnnotator
-from funtracks.data_model import SolutionTracks, Tracks
+from funtracks.data_model import Tracks
track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"}
@@ -39,9 +39,9 @@ def test_compute_all(self, get_graph, ndim):
tracks.enable_features(list(rp_ann.all_features.keys()))
for key in rp_ann.all_features:
- assert key in tracks.graph.node_attr_keys()
- for node_id in tracks.graph.node_ids():
- value = tracks.graph.nodes[node_id][key]
+ assert key in tracks.graph_solution.node_attr_keys()
+ for node_id in tracks.graph_solution.node_ids():
+ value = tracks.graph_solution.nodes[node_id][key]
assert value is not None
def test_update_all(self, get_graph, ndim):
@@ -69,7 +69,7 @@ def test_update_all(self, get_graph, ndim):
UpdateNodeSeg(tracks, node_id, removal, added=False)
assert tracks.get_node_attr(node_id, "area") == expected_area
for key in rp_ann.features:
- assert key in tracks.graph.node_attr_keys()
+ assert key in tracks.graph_solution.node_attr_keys()
# segmentation is fully erased and you try to update
node_id = 1
@@ -80,8 +80,8 @@ def test_update_all(self, get_graph, ndim):
UpdateNodeSeg(tracks, node_id, mask, added=False)
# all regionprops features should be the defaults, because seg doesn't exist
for key in rp_ann.features:
- actual = tracks.graph.nodes[node_id][key]
- expected = tracks.graph._node_attr_schemas()[key].default_value
+ actual = tracks.graph_solution.nodes[node_id][key]
+ expected = tracks.graph_solution._node_attr_schemas()[key].default_value
# Convert to numpy arrays for comparison (handles both scalar and array types)
actual_np = np.asarray(actual)
expected_np = np.asarray(expected)
@@ -105,7 +105,7 @@ def test_add_remove_feature(self, get_graph, ndim):
tracks.disable_features([to_remove_key])
rp_ann.compute()
- assert to_remove_key not in tracks.graph.node_attr_keys()
+ assert to_remove_key not in tracks.graph_solution.node_attr_keys()
# add it back in
tracks.enable_features([to_remove_key])
@@ -155,7 +155,7 @@ def test_centroid_world_coords_with_scale(self, get_graph, ndim):
# Force recomputation so regionprops runs with the given scale as spacing
tracks.enable_features(["pos"])
- pos = np.array(tracks.graph.nodes[6]["pos"])
+ pos = np.array(tracks.graph_solution.nodes[6]["pos"])
expected = pixel_centroid * np.array(scale[1:])
bug_value = np.array([1.5] * len(pixel_centroid)) * np.array(
@@ -176,8 +176,8 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim):
"""Test that RegionpropsAnnotator ignores actions that don't affect
segmentation.
"""
- graph = get_graph(ndim, is_solution=True, with_seg=True)
- tracks = SolutionTracks(
+ graph = get_graph(ndim, prefill_track_ids=True, with_seg=True)
+ tracks = Tracks(
graph,
ndim=ndim,
**track_attrs,
@@ -192,7 +192,9 @@ def test_ignores_irrelevant_actions(self, get_graph, ndim):
# RegionpropsAnnotator, it would recompute area back to initial_area and
# the assertion below would fail.
fake_area = initial_area + 999
- tracks.graph.update_node_attrs(attrs={"area": [fake_area]}, node_ids=[node_id])
+ tracks.graph_solution.update_node_attrs(
+ attrs={"area": [fake_area]}, node_ids=[node_id]
+ )
original_track_id = tracks.get_track_id(node_id)
new_track_id = original_track_id + 100
diff --git a/tests/annotators/test_track_annotator.py b/tests/annotators/test_track_annotator.py
index 2f7fa0d2..a2106e19 100644
--- a/tests/annotators/test_track_annotator.py
+++ b/tests/annotators/test_track_annotator.py
@@ -16,7 +16,7 @@
@pytest.mark.parametrize("with_seg", [True, False])
class TestTrackAnnotator:
def test_init(self, get_tracks, ndim, with_seg) -> None:
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
ann = TrackAnnotator(tracks)
# Features start disabled by default
assert len(ann.all_features) == 2
@@ -32,7 +32,7 @@ def test_init(self, get_tracks, ndim, with_seg) -> None:
assert ann.max_tracklet_id == 5
def test_compute_all(self, get_tracks, ndim, with_seg) -> None:
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
ann = TrackAnnotator(tracks, tracklet_key=tracks.features.tracklet_key)
# Enable features
@@ -41,9 +41,9 @@ def test_compute_all(self, get_tracks, ndim, with_seg) -> None:
# Compute values
ann.compute()
- for node in tracks.graph.node_ids():
+ for node in tracks.graph_solution.node_ids():
for key in all_features:
- assert tracks.graph.nodes[node][key] is not None
+ assert tracks.graph_solution.nodes[node][key] is not None
lineages = [
[1, 2, 3, 4, 5],
@@ -69,7 +69,7 @@ def test_compute_all(self, get_tracks, ndim, with_seg) -> None:
assert len({id_set[0] for id_set in id_sets}) == len(id_sets)
def test_add_remove_feature(self, get_tracks, ndim, with_seg):
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
ann = TrackAnnotator(tracks, tracklet_key=tracks.features.tracklet_key)
# Enable features
ann.activate_features(list(ann.all_features.keys()))
@@ -79,7 +79,9 @@ def test_add_remove_feature(self, get_tracks, ndim, with_seg):
node_id = 6
edge_id = (4, 6)
attrs = {"iou": 0, "solution": True} if with_seg else {"solution": True}
- tracks.graph.add_edge(source_id=edge_id[0], target_id=edge_id[1], attrs=attrs)
+ tracks.graph_solution.add_edge(
+ source_id=edge_id[0], target_id=edge_id[1], attrs=attrs
+ )
to_remove_key = ann.lineage_key
orig_lin = tracks.get_node_attr(node_id, ann.lineage_key)
orig_tra = tracks.get_node_attr(node_id, ann.tracklet_key)
@@ -98,20 +100,20 @@ def test_add_remove_feature(self, get_tracks, ndim, with_seg):
assert tracks.get_node_attr(node_id, ann.lineage_key) != orig_lin
assert tracks.get_node_attr(node_id, ann.tracklet_key) != orig_tra
- def test_invalid(self, get_tracks, ndim, with_seg) -> None:
- # Create regular Tracks (not SolutionTracks) to test error handling
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=False)
- with pytest.raises(
- ValueError, match="Currently the TrackAnnotator only works on SolutionTracks"
- ):
- TrackAnnotator(tracks) # type: ignore
+ def test_always_has_track_annotator(self, get_tracks, ndim, with_seg) -> None:
+ # Every Tracks has a tracklet key and a registered TrackAnnotator, even when
+ # built without explicit track attributes (no "plain" vs "solution" split).
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=False)
+ assert tracks.features.tracklet_key is not None
+ assert tracks.features.lineage_key is not None
+ assert isinstance(tracks.track_annotator, TrackAnnotator)
def test_ignores_irrelevant_actions(self, get_tracks, ndim, with_seg):
"""Test that TrackAnnotator ignores actions that don't affect track IDs."""
if not with_seg:
pytest.skip("Test requires segmentation")
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
tracks.enable_features(["area", tracks.features.tracklet_key])
node_id = 3
@@ -133,7 +135,7 @@ def test_ignores_irrelevant_actions(self, get_tracks, ndim, with_seg):
def test_lineage_id_updated_on_add_and_delete_edge(
self, get_tracks, ndim, with_seg
) -> None:
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
tracks.enable_features(["lineage_id"])
# get the existing TrackAnnotator
@@ -155,7 +157,7 @@ def test_lineage_id_updated_on_add_and_delete_edge(
source_node = 3
target_node = 4
- edge = next(e for e in tracks.graph.edge_list() if set(e) == {3, 4})
+ edge = next(e for e in tracks.graph_solution.edge_list() if set(e) == {3, 4})
expected_lineage_id = ann.max_lineage_id + 1
UserDeleteEdge(tracks, edge=edge)
@@ -206,7 +208,7 @@ def test_lineage_id_updated_on_division(self, get_tracks, ndim, with_seg) -> Non
- New child (6): keeps same track_id, gets source's lineage_id
"""
# Graph structure: 1 → 2, 1 → 3 → 4 → 5, and 6 (separate)
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
tracks.enable_features(["lineage_id"])
ann = next(a for a in tracks.annotators if isinstance(a, TrackAnnotator))
@@ -255,7 +257,7 @@ def test_lineage_id_updated_on_delete_division_edge(
- Orphaned child (2): keeps same track_id, gets new lineage_id
"""
# Graph structure: 1 → 2, 1 → 3 → 4 → 5, and 6 (separate)
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
tracks.enable_features(["lineage_id"])
ann = next(a for a in tracks.annotators if isinstance(a, TrackAnnotator))
@@ -287,7 +289,7 @@ def test_lineage_id_updated_on_delete_division_edge(
def test_disabled_tracklet_key_does_nothing(self, get_tracks, ndim, with_seg) -> None:
"""Test that TrackAnnotator does nothing when tracklet_key is disabled."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
ann = TrackAnnotator(tracks)
# Don't activate any features - they should all be disabled
@@ -307,3 +309,27 @@ def test_disabled_tracklet_key_does_nothing(self, get_tracks, ndim, with_seg) ->
assert ann.lineage_id_to_nodes == original_lineage_map
assert ann.max_tracklet_id == original_max_tracklet
assert ann.max_lineage_id == original_max_lineage
+
+
+def test_recompute_replaces_stale_bookkeeping(get_tracks):
+ """A tracklet column still holding the -1 sentinel seeds a phantom tracklet -1 in
+ the bookkeeping at annotator init (_get_max_id_and_map only skips None). A full
+ compute() must replace tracklet_id_to_nodes wholesale (like _assign_lineage_ids
+ does), not merge into it — otherwise the phantom entry survives with its nodes
+ duplicated under their real tracklet ids.
+ """
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
+ node_ids = list(tracks.graph_full.node_ids())
+ tracks.graph_full.update_node_attrs(
+ attrs={"track_id": [-1] * len(node_ids)}, node_ids=node_ids
+ )
+
+ ann = TrackAnnotator(tracks, tracklet_key="track_id")
+ ann.activate_features(["track_id", "lineage_id"])
+ ann.compute()
+
+ assert -1 not in ann.tracklet_id_to_nodes
+ # Every solution node appears exactly once across the bookkeeping.
+ all_nodes = sorted(n for ns in ann.tracklet_id_to_nodes.values() for n in ns)
+ assert all_nodes == sorted(tracks.graph_solution.node_ids())
+ assert ann.max_tracklet_id == max(ann.tracklet_id_to_nodes.keys())
diff --git a/tests/benchmarks/bench_candidate_graph.py b/tests/benchmarks/bench_candidate_graph.py
index 7d9d63db..33078e87 100644
--- a/tests/benchmarks/bench_candidate_graph.py
+++ b/tests/benchmarks/bench_candidate_graph.py
@@ -9,7 +9,7 @@
from skimage.draw import disk
from funtracks.candidate_graph.compute_graph import compute_graph_from_seg
-from funtracks.data_model import SolutionTracks
+from funtracks.data_model import Tracks
NUM_FRAMES = 50
FRAME_SHAPE = (700, 1100)
@@ -43,6 +43,28 @@ def seg_data():
return _generate_segmentation()
+@pytest.fixture(scope="module")
+def _warm_spatial_compile():
+ """Warm the spatial_graph rtree JIT compile before the timed benchmark.
+
+ Building a Tracks with segmentation creates a GraphArrayView, which builds a
+ spatial_graph rtree whose Cython module is JIT-compiled by witty on first use and
+ cached process-wide. That compile is a one-time cost (seconds-to-tens-of-seconds on
+ a cold Windows runner) that would otherwise land inside the timed region of
+ test_graph_to_solution. Trigger it here (untimed) on a tiny graph with the same
+ dimensionality so the benchmark measures tracking work, not compilation.
+ """
+ tiny = _generate_segmentation(num_frames=2, frame_shape=(64, 64), cells_per_frame=2)
+ graph = compute_graph_from_seg(tiny, MAX_EDGE_DISTANCE, iou=True)
+ Tracks(
+ graph,
+ time_attr="t",
+ pos_attr="pos",
+ tracklet_attr="tracklet_id",
+ lineage_attr="lineage_id",
+ )
+
+
def test_compute_graph_from_seg(benchmark, seg_data):
benchmark.pedantic(
compute_graph_from_seg,
@@ -53,18 +75,26 @@ def test_compute_graph_from_seg(benchmark, seg_data):
)
-def test_graph_to_solution(benchmark, seg_data):
- """Benchmark candidate graph -> SolutionTracks (tracklet/lineage assignment).
+def test_graph_to_solution(benchmark, seg_data, _warm_spatial_compile):
+ """Benchmark candidate graph -> Tracks with track ids (tracklet/lineage assignment).
Candidate-graph construction is benchmarked separately above and is built here in
- (untimed) setup, so only the SolutionTracks construction -- dominated by
+ (untimed) setup, so only the Tracks construction -- dominated by
TrackAnnotator._assign_tracklet_ids and _assign_lineage_ids -- is measured.
+ Setting tracklet_attr/lineage_attr registers a TrackAnnotator and triggers the
+ track-id computation on construction. The _warm_spatial_compile fixture ensures the
+ one-time spatial_graph rtree JIT compile is not measured here.
"""
def setup():
- # Fresh graph per round: SolutionTracks construction mutates the graph
+ # Fresh graph per round: Tracks construction mutates the graph
# (adds tracklet/lineage IDs), so each measured call must start unannotated.
graph = compute_graph_from_seg(seg_data, MAX_EDGE_DISTANCE, iou=True)
- return (graph,), {"time_attr": "t", "pos_attr": "pos"}
+ return (graph,), {
+ "time_attr": "t",
+ "pos_attr": "pos",
+ "tracklet_attr": "tracklet_id",
+ "lineage_attr": "lineage_id",
+ }
- benchmark.pedantic(SolutionTracks, setup=setup, rounds=1, iterations=1)
+ benchmark.pedantic(Tracks, setup=setup, rounds=1, iterations=1)
diff --git a/tests/candidate_graph/test_compute_graph.py b/tests/candidate_graph/test_compute_graph.py
index 6abfff97..18c0c996 100644
--- a/tests/candidate_graph/test_compute_graph.py
+++ b/tests/candidate_graph/test_compute_graph.py
@@ -18,13 +18,13 @@ def test_graph_from_segmentation_2d(get_tracks):
)
# Same node IDs as the segmentation labels
- assert set(cand_graph.node_ids()) == set(tracks.graph.node_ids())
+ assert set(cand_graph.node_ids()) == set(tracks.graph_solution.node_ids())
# t, pos, area must match the source graph for every node
for node in cand_graph.node_ids():
for key in ["t", "pos", "area"]:
assert np.array(cand_graph.nodes[node][key]) == pytest.approx(
- np.array(tracks.graph.nodes[node][key]), abs=0.01
+ np.array(tracks.graph_solution.nodes[node][key]), abs=0.01
)
# mask and bbox must be present on every node
@@ -39,12 +39,14 @@ def test_graph_from_segmentation_2d(get_tracks):
# because t=3 has no nodes (add_cand_edges only links frame → frame+1)
assert sorted(cand_graph.edge_list()) == [[1, 2], [1, 3], [2, 4], [3, 4]]
- # For edges shared with tracks.graph, iou must agree
+ # For edges shared with tracks.graph_solution, iou must agree
cand_edges = {tuple(e) for e in cand_graph.edge_list()}
- ref_edges = {tuple(e) for e in tracks.graph.edge_list()}
+ ref_edges = {tuple(e) for e in tracks.graph_solution.edge_list()}
for src, tgt in cand_edges & ref_edges:
cand_iou = cand_graph.edges[cand_graph.edge_id(src, tgt)]["iou"]
- ref_iou = tracks.graph.edges[tracks.graph.edge_id(src, tgt)]["iou"]
+ ref_iou = tracks.graph_solution.edges[tracks.graph_solution.edge_id(src, tgt)][
+ "iou"
+ ]
assert cand_iou == pytest.approx(ref_iou, abs=0.01)
# lower edge distance: only (1, 3) is within 15 pixels (~11.2), (1, 2) is ~42 away
@@ -52,7 +54,7 @@ def test_graph_from_segmentation_2d(get_tracks):
segmentation=segmentation_2d,
max_edge_distance=15,
)
- assert set(cand_graph.node_ids()) == set(tracks.graph.node_ids())
+ assert set(cand_graph.node_ids()) == set(tracks.graph_solution.node_ids())
assert sorted(cand_graph.edge_list()) == [[1, 3]]
@@ -65,12 +67,12 @@ def test_graph_from_segmentation_3d(get_tracks):
max_edge_distance=100,
)
- assert set(cand_graph.node_ids()) == set(tracks.graph.node_ids())
+ assert set(cand_graph.node_ids()) == set(tracks.graph_solution.node_ids())
for node in cand_graph.node_ids():
for key in ["t", "pos", "area"]:
assert np.array(cand_graph.nodes[node][key]) == pytest.approx(
- np.array(tracks.graph.nodes[node][key]), abs=0.01
+ np.array(tracks.graph_solution.nodes[node][key]), abs=0.01
)
# mask and bbox must be present on every node
@@ -115,8 +117,8 @@ def test_graph_from_segmentation_t_start(get_tracks):
# t values should match the original tracks graph for shared nodes
for node in cand_graph.node_ids():
- if node in set(tracks.graph.node_ids()):
- assert cand_graph.nodes[node]["t"] == tracks.graph.nodes[node]["t"]
+ if node in set(tracks.graph_solution.node_ids()):
+ assert cand_graph.nodes[node]["t"] == tracks.graph_solution.nodes[node]["t"]
# Edges should still be formed between adjacent (shifted) frames
assert cand_graph.num_edges() > 0
diff --git a/tests/candidate_graph/test_iou.py b/tests/candidate_graph/test_iou.py
index 0fc3e28a..21c7a464 100644
--- a/tests/candidate_graph/test_iou.py
+++ b/tests/candidate_graph/test_iou.py
@@ -43,10 +43,12 @@ def test_add_iou_2d(get_tracks):
add_cand_edges(cand_graph, max_edge_distance=100, node_frame_dict=node_frame_dict)
add_iou(cand_graph, segmentation_2d, node_frame_dict=node_frame_dict)
- # For edges shared with tracks.graph, iou must agree
+ # For edges shared with tracks.graph_solution, iou must agree
cand_edges = {tuple(e) for e in cand_graph.edge_list()}
- ref_edges = {tuple(e) for e in tracks.graph.edge_list()}
+ ref_edges = {tuple(e) for e in tracks.graph_solution.edge_list()}
for src, tgt in cand_edges & ref_edges:
cand_iou = cand_graph.edges[cand_graph.edge_id(src, tgt)]["iou"]
- ref_iou = tracks.graph.edges[tracks.graph.edge_id(src, tgt)]["iou"]
+ ref_iou = tracks.graph_solution.edges[tracks.graph_solution.edge_id(src, tgt)][
+ "iou"
+ ]
assert cand_iou == pytest.approx(ref_iou, abs=0.01)
diff --git a/tests/candidate_graph/test_relabel_segmentation.py b/tests/candidate_graph/test_relabel_segmentation.py
index 9ea20e96..0e32f28d 100644
--- a/tests/candidate_graph/test_relabel_segmentation.py
+++ b/tests/candidate_graph/test_relabel_segmentation.py
@@ -30,7 +30,7 @@ def test_relabel_segmentation(get_tracks):
segmentation = np.asarray(tracks.segmentation)
# Use only nodes 1 and 2 (single tracklet: node 1 at t=0, node 2 at t=1)
- subgraph = tracks.graph.filter(node_ids=[1, 2]).subgraph()
+ subgraph = tracks.graph_solution.filter(node_ids=[1, 2]).subgraph()
relabeled = relabel_segmentation_with_track_id(subgraph, segmentation)
# Nodes 1 and 2 form one tracklet → both get label 1
diff --git a/tests/conftest.py b/tests/conftest.py
index 29477e6e..1224486a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,13 +9,13 @@
from tracksdata.nodes import Mask
from funtracks.utils.tracksdata_utils import (
- create_empty_graphview_graph,
+ create_empty_graph,
)
if TYPE_CHECKING:
from typing import Any
- from funtracks.data_model import SolutionTracks, Tracks
+ from funtracks.data_model import Tracks
def make_2d_disk_mask(center=(50, 50), radius=20) -> Mask:
@@ -134,7 +134,7 @@ def _make_graph(
with_iou: bool = False,
with_masks: bool = False,
database: str | None = None,
-) -> td.graph.GraphView:
+) -> td.graph.BaseGraph:
"""Generate a test graph with configurable features.
Args:
@@ -172,7 +172,7 @@ def _make_graph(
node_attributes.append(td.DEFAULT_ATTR_KEYS.BBOX)
node_default_values.append(0.0)
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=node_attributes,
node_default_values=node_default_values,
edge_attributes=edge_attributes,
@@ -288,28 +288,28 @@ def _make_graph(
@pytest.fixture
-def graph_clean(tmp_path) -> td.graph.GraphView:
+def graph_clean(tmp_path) -> td.graph.BaseGraph:
"""Base graph with only time - no positions or computed features."""
db_path = str(tmp_path / "graph_clean.db")
return _make_graph(ndim=3, database=db_path)
@pytest.fixture
-def graph_2d_with_position(tmp_path) -> td.graph.GraphView:
+def graph_2d_with_position(tmp_path) -> td.graph.BaseGraph:
"""Graph with 2D positions - for Tracks without segmentation."""
db_path = str(tmp_path / "graph_2d_position.db")
return _make_graph(ndim=3, with_pos=True, database=db_path)
@pytest.fixture
-def graph_2d_with_track_id(tmp_path) -> td.graph.GraphView:
- """Graph with 2D positions and track_id - for SolutionTracks without segmentation."""
+def graph_2d_with_track_id(tmp_path) -> td.graph.BaseGraph:
+ """Graph with 2D positions and track_id - for Tracks without segmentation."""
db_path = str(tmp_path / "graph_2d_track_id.db")
return _make_graph(ndim=3, with_pos=True, with_track_id=True, database=db_path)
@pytest.fixture
-def graph_2d_with_segmentation(tmp_path) -> td.graph.GraphView:
+def graph_2d_with_segmentation(tmp_path) -> td.graph.BaseGraph:
"""Graph with segmentation (masks/bboxes) and all computed features."""
db_path = str(tmp_path / "graph_2d_segmentation.db")
return _make_graph(
@@ -324,21 +324,21 @@ def graph_2d_with_segmentation(tmp_path) -> td.graph.GraphView:
@pytest.fixture
-def graph_3d_with_position(tmp_path) -> td.graph.GraphView:
+def graph_3d_with_position(tmp_path) -> td.graph.BaseGraph:
"""Graph with 3D positions - for Tracks without segmentation."""
db_path = str(tmp_path / "graph_3d_position.db")
return _make_graph(ndim=4, with_pos=True, database=db_path)
@pytest.fixture
-def graph_3d_with_track_id(tmp_path) -> td.graph.GraphView:
- """Graph with 3D positions and track_id - for SolutionTracks without segmentation."""
+def graph_3d_with_track_id(tmp_path) -> td.graph.BaseGraph:
+ """Graph with 3D positions and track_id - for Tracks without segmentation."""
db_path = str(tmp_path / "graph_3d_track_id.db")
return _make_graph(ndim=4, with_pos=True, with_track_id=True, database=db_path)
@pytest.fixture
-def graph_3d_with_segmentation(tmp_path) -> td.graph.GraphView:
+def graph_3d_with_segmentation(tmp_path) -> td.graph.BaseGraph:
"""Graph with segmentation (masks/bboxes) and all computed features."""
db_path = str(tmp_path / "graph_3d_segmentation.db")
return _make_graph(
@@ -353,22 +353,24 @@ def graph_3d_with_segmentation(tmp_path) -> td.graph.GraphView:
@pytest.fixture
-def get_tracks(get_graph) -> Callable[..., "Tracks | SolutionTracks"]:
- """Factory fixture to create Tracks or SolutionTracks instances.
+def get_tracks(get_graph) -> Callable[..., "Tracks"]:
+ """Factory fixture to create Tracks or Tracks instances.
Returns a factory function that can be called with:
ndim: 3 for 2D spatial + time, 4 for 3D spatial + time
with_seg: Whether to include segmentation (mask/bbox as node attributes)
- is_solution: Whether to return SolutionTracks instead of Tracks
+ prefill_track_ids: If True, the fixture graph ships with track_id/lineage_id
+ values already set (activated as-is); if False, they are computed from the
+ graph topology. Either way the resulting Tracks has track ids.
Example:
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
Note:
Uses a pre-built FeatureDict to avoid recomputing features that already
exist in the test graph fixtures.
"""
- from funtracks.data_model import SolutionTracks, Tracks
+ from funtracks.data_model import Tracks
from funtracks.features import (
Area,
FeatureDict,
@@ -385,12 +387,14 @@ def get_tracks(get_graph) -> Callable[..., "Tracks | SolutionTracks"]:
def _make_tracks(
ndim: int,
with_seg: bool = True,
- is_solution: bool = False,
- ) -> Tracks | SolutionTracks:
+ prefill_track_ids: bool = False,
+ ) -> Tracks:
# Determine axis names based on ndim
axis_names = ["z", "y", "x"] if ndim == 4 else ["y", "x"]
- graph = get_graph(ndim=ndim, is_solution=is_solution, with_seg=with_seg)
+ graph = get_graph(
+ ndim=ndim, prefill_track_ids=prefill_track_ids, with_seg=with_seg
+ )
# Build FeatureDict based on what exists in the graph
features_dict: dict[str, Any] = {
@@ -404,7 +408,7 @@ def _make_tracks(
features_dict["bbox"] = SegBbox(ndim)
features_dict["area"] = Area(ndim=ndim)
features_dict["iou"] = IoU()
- if is_solution:
+ if prefill_track_ids:
features_dict["track_id"] = TrackletID()
features_dict["lineage_id"] = LineageID()
@@ -412,31 +416,24 @@ def _make_tracks(
features=features_dict,
time_key="t",
position_key="pos",
- tracklet_key="track_id" if is_solution else None,
- lineage_key="lineage_id" if is_solution else None,
+ tracklet_key="track_id" if prefill_track_ids else None,
+ lineage_key="lineage_id" if prefill_track_ids else None,
)
- # Create the appropriate Tracks type with pre-built FeatureDict
- if is_solution:
- return SolutionTracks(
- graph,
- ndim=ndim,
- features=feature_dict,
- )
- else:
- return Tracks(
- graph,
- ndim=ndim,
- features=feature_dict,
- )
+ # Create the Tracks with the pre-built FeatureDict.
+ return Tracks(
+ graph,
+ ndim=ndim,
+ features=feature_dict,
+ )
return _make_tracks
@pytest.fixture
-def graph_2d_list(tmp_path) -> td.graph.GraphView:
+def graph_2d_list(tmp_path) -> td.graph.BaseGraph:
db_path = str(tmp_path / "graph_2d_list.db")
- graph = create_empty_graphview_graph(database=db_path)
+ graph = create_empty_graph(database=db_path)
nodes = [
{
@@ -475,13 +472,14 @@ def sphere(center, radius, shape):
@pytest.fixture
-def get_graph(tmp_path) -> Callable[..., td.graph.GraphView]:
+def get_graph(tmp_path) -> Callable[..., td.graph.BaseGraph]:
"""Factory fixture to create a graph with configurable features.
Args:
ndim: 3 for 2D spatial + time, 4 for 3D spatial + time
with_pos: Include position attribute (default True)
- is_solution: Include track_id and lineage_id (default False)
+ prefill_track_ids: Include track_id and lineage_id columns on the graph
+ (default False)
with_seg: Include mask, bbox, area, and iou (default False)
Returns:
@@ -489,22 +487,22 @@ def get_graph(tmp_path) -> Callable[..., td.graph.GraphView]:
Example:
graph = get_graph(ndim=3, with_seg=True)
- graph = get_graph(ndim=4, is_solution=True, with_seg=True)
+ graph = get_graph(ndim=4, prefill_track_ids=True, with_seg=True)
"""
counter = [0]
def _get_graph(
ndim: int = 3,
with_pos: bool = True,
- is_solution: bool = False,
+ prefill_track_ids: bool = False,
with_seg: bool = False,
- ) -> td.graph.GraphView:
+ ) -> td.graph.BaseGraph:
counter[0] += 1
db_path = str(tmp_path / f"graph_{counter[0]}.db")
return _make_graph(
ndim=ndim,
with_pos=with_pos,
- with_track_id=is_solution,
+ with_track_id=prefill_track_ids,
with_area=with_seg,
with_iou=with_seg,
with_masks=with_seg,
diff --git a/tests/data_model/test_soft_delete_roundtrip.py b/tests/data_model/test_soft_delete_roundtrip.py
new file mode 100644
index 00000000..852130d9
--- /dev/null
+++ b/tests/data_model/test_soft_delete_roundtrip.py
@@ -0,0 +1,163 @@
+"""Step 6: soft-delete round-trip + repeated delete<->undo stability.
+
+Soft-delete keeps the node/edge in ``graph_full`` (flag ``solution=False``) and only drops
+it from ``graph_solution``. These tests pin the core Phase-1 guarantees:
+- the full graph's topology is preserved across a leaf-node soft-delete (only flags flip);
+- delete -> undo restores the solution view exactly (ids + attributes);
+- repeated undo/redo through the real ``ActionHistory`` never drifts the view (invariant
+ #4 in the persistence plan) — the R2 in-place revive must be a true inverse of remove.
+
+A leaf node (5, edge 4->5) is used so the delete introduces no reconnection skip-edge; the
+skip-edge case is exercised in ``test_mid_track_delete_*`` to document that graph_full
+accumulates the skip edge as a candidate.
+"""
+
+import pytest
+
+from funtracks.user_actions import UserDeleteNode
+
+
+def _solution_state(tracks):
+ """A hashable snapshot of the solution view's topology."""
+ return (
+ tuple(sorted(tracks.graph_solution.node_ids())),
+ tuple(sorted(tracks.graph_solution.edge_list())),
+ )
+
+
+def _full_state(tracks):
+ return (
+ tuple(sorted(tracks.graph_full.node_ids())),
+ tuple(sorted(tracks.graph_full.edge_list())),
+ )
+
+
+def _positions(tracks):
+ """pos per solution node as plain float tuples (array-safe for == comparison)."""
+ return {
+ n: tuple(float(x) for x in tracks.get_node_attr(n, "pos"))
+ for n in tracks.graph_solution.node_ids()
+ }
+
+
+@pytest.mark.parametrize("ndim", [3, 4])
+@pytest.mark.parametrize("with_seg", [True, False])
+def test_soft_delete_keeps_leaf_node_in_full_graph(get_tracks, ndim, with_seg):
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
+ full_before = _full_state(tracks)
+ sol_nodes_before = set(tracks.graph_solution.node_ids())
+
+ UserDeleteNode(tracks, 5)
+
+ # Gone from the solution view ...
+ assert 5 not in tracks.graph_solution.node_ids()
+ assert set(tracks.graph_solution.node_ids()) == sol_nodes_before - {5}
+ assert not tracks.graph_solution.has_edge(4, 5)
+
+ # ... but preserved in the full graph as a soft-deleted candidate. Topology of the
+ # full graph is unchanged: only the solution flag flipped.
+ assert 5 in tracks.graph_full.node_ids()
+ assert tracks.graph_full.has_edge(4, 5)
+ assert tracks.graph_full.nodes[5]["solution"] is False
+ assert _full_state(tracks) == full_before
+
+
+@pytest.mark.parametrize("ndim", [3, 4])
+@pytest.mark.parametrize("with_seg", [True, False])
+def test_delete_undo_redo_roundtrip_identity(get_tracks, ndim, with_seg):
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
+ sol_ref = _solution_state(tracks)
+ full_ref = _full_state(tracks)
+ pos_ref = _positions(tracks)
+
+ UserDeleteNode(tracks, 5)
+ assert _solution_state(tracks) != sol_ref # actually changed
+
+ # Undo restores the solution view exactly, and the full graph is untouched.
+ assert tracks.undo()
+ assert _solution_state(tracks) == sol_ref
+ assert _full_state(tracks) == full_ref
+ assert _positions(tracks) == pos_ref
+
+ # Redo re-deletes; undo restores again — same states.
+ assert tracks.redo()
+ assert 5 not in tracks.graph_solution.node_ids()
+ assert tracks.undo()
+ assert _solution_state(tracks) == sol_ref
+
+
+@pytest.mark.parametrize("ndim", [3, 4])
+def test_repeated_delete_undo_is_stable(get_tracks, ndim):
+ """Invariant #4: N undo/redo cycles must not drift the solution view or full graph."""
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
+ sol_ref = _solution_state(tracks)
+ iou_ref = tracks.get_edge_attr((4, 5), "iou")
+
+ UserDeleteNode(tracks, 5)
+ deleted_sol = _solution_state(tracks)
+ full_after_delete = _full_state(tracks)
+
+ for _ in range(5):
+ assert tracks.undo()
+ assert _solution_state(tracks) == sol_ref
+ # Revived edge keeps its computed attribute.
+ assert tracks.get_edge_attr((4, 5), "iou") == iou_ref
+
+ assert tracks.redo()
+ assert _solution_state(tracks) == deleted_sol
+ # Full graph topology is identical every redo — no candidate accumulation.
+ assert _full_state(tracks) == full_after_delete
+
+ # Leave it restored and confirm a clean final identity.
+ assert tracks.undo()
+ assert _solution_state(tracks) == sol_ref
+
+
+@pytest.mark.parametrize("ndim", [3, 4])
+def test_mid_track_delete_leaves_skip_edge_candidate_in_full(get_tracks, ndim):
+ """Deleting a mid-track node adds a reconnection skip-edge (3->5). On undo it is
+ soft-deleted, so it persists in graph_full as a solution=False candidate while the
+ solution view returns to its original topology."""
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
+ sol_ref = _solution_state(tracks)
+ assert not tracks.graph_full.has_edge(3, 5)
+
+ UserDeleteNode(tracks, 4)
+ assert tracks.graph_solution.has_edge(3, 5)
+ assert not tracks.graph_solution.has_edge(3, 4)
+
+ assert tracks.undo()
+ # Solution view is back to the original topology (skip edge removed from the view) ...
+ assert _solution_state(tracks) == sol_ref
+ assert not tracks.graph_solution.has_edge(3, 5)
+ # ... but the skip edge now lives in graph_full as a candidate; node 4 is retained.
+ assert tracks.graph_full.has_edge(3, 5)
+ assert tracks.graph_full.edge_id(3, 5) is not None
+ assert 4 in tracks.graph_full.node_ids()
+
+
+@pytest.mark.parametrize("ndim", [3, 4])
+def test_attr_reads_resolve_for_soft_deleted_node(get_tracks, ndim):
+ """Regression: attribute reads must resolve for soft-deleted (solution=False) nodes,
+ and the bulk `get_positions` must agree with the single-node `get_position`. The bulk
+ path previously queried `graph_solution` and KeyError'd on a soft-deleted node while
+ `get_position` (graph_full) succeeded — a latent inconsistency invisible to tests that
+ only ever query in-solution nodes.
+ """
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
+ pos_single_before = tracks.get_position(5)
+ pos_bulk_before = tracks.get_positions([5])[0].tolist()
+ assert pos_bulk_before == pytest.approx(pos_single_before)
+
+ UserDeleteNode(tracks, 5)
+ assert 5 not in tracks.graph_solution.node_ids() # soft-deleted
+
+ # Both single and bulk position reads still resolve (graph_full) and agree.
+ pos_single_after = tracks.get_position(5)
+ pos_bulk_after = tracks.get_positions([5])[0].tolist()
+ assert pos_single_after == pytest.approx(pos_single_before)
+ assert pos_bulk_after == pytest.approx(pos_single_before)
+
+ # Other intrinsic attrs resolve too.
+ assert tracks.get_node_attr(5, "area") is not None
+ assert tracks.get_time(5) is not None
diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py
index 063fba80..4f1afa81 100644
--- a/tests/data_model/test_solution_tracks.py
+++ b/tests/data_model/test_solution_tracks.py
@@ -2,11 +2,11 @@
import polars as pl
from funtracks.actions import AddNode
-from funtracks.data_model import SolutionTracks, Tracks
+from funtracks.data_model import Tracks
from funtracks.import_export import export_to_csv
from funtracks.user_actions import UserUpdateSegmentation
from funtracks.utils.tracksdata_utils import (
- create_empty_graphview_graph,
+ create_empty_graph,
td_mask_to_pixels,
)
@@ -14,7 +14,7 @@
def test_recompute_track_ids(graph_2d_with_track_id):
- tracks = SolutionTracks(
+ tracks = Tracks(
graph_2d_with_track_id,
ndim=3,
**track_attrs,
@@ -22,8 +22,20 @@ def test_recompute_track_ids(graph_2d_with_track_id):
assert tracks.get_next_track_id() == 6
+def test_recompute_track_ids_when_sentinel_present(graph_2d_with_track_id):
+ """A tracklet column that exists but still holds the -1 sentinel means "not
+ computed": Tracks must recompute track ids from topology, not trust the sentinels."""
+ graph = graph_2d_with_track_id
+ node_ids = list(graph.node_ids())
+ graph.update_node_attrs(attrs={"track_id": [-1] * len(node_ids)}, node_ids=node_ids)
+
+ tracks = Tracks(graph, ndim=3, **track_attrs)
+
+ assert -1 not in tracks.get_track_ids(tracks.nodes())
+
+
def test_next_track_id(graph_2d_with_track_id):
- tracks = SolutionTracks(graph_2d_with_track_id, ndim=3, **track_attrs)
+ tracks = Tracks(graph_2d_with_track_id, ndim=3, **track_attrs)
assert tracks.get_next_track_id() == 6
AddNode(
tracks,
@@ -33,47 +45,8 @@ def test_next_track_id(graph_2d_with_track_id):
assert tracks.get_next_track_id() == 11
-def test_from_tracks_cls(graph_2d_with_segmentation):
- tracks = Tracks(
- graph_2d_with_segmentation,
- ndim=3,
- pos_attr="POSITION",
- time_attr="TIME",
- tracklet_attr=track_attrs["tracklet_attr"],
- scale=(2, 2, 2),
- )
- solution_tracks = SolutionTracks.from_tracks(tracks)
- assert solution_tracks.graph == tracks.graph
- assert solution_tracks.segmentation == tracks.segmentation
- assert solution_tracks.features.time_key == tracks.features.time_key
- assert solution_tracks.features.position_key == tracks.features.position_key
- assert solution_tracks.scale == tracks.scale
- assert solution_tracks.ndim == tracks.ndim
- assert solution_tracks.get_node_attr(6, tracks.features.tracklet_key) == 5
-
-
-def test_from_tracks_cls_recompute(graph_2d_with_segmentation):
- tracks = Tracks(
- graph_2d_with_segmentation,
- ndim=3,
- pos_attr="POSITION",
- time_attr="TIME",
- tracklet_attr=track_attrs["tracklet_attr"],
- scale=(2, 2, 2),
- )
- # delete track id (default value -1) on one node triggers reassignment of
- # track_ids even when recompute is False.
- tracks.graph.nodes[1][tracks.features.tracklet_key] = -1
- solution_tracks = SolutionTracks.from_tracks(tracks)
- # should have reassigned new track_id to node 6
- assert solution_tracks.get_node_attr(6, solution_tracks.features.tracklet_key) == 4
- assert (
- solution_tracks.get_node_attr(1, solution_tracks.features.tracklet_key) == 1
- ) # still 1
-
-
def test_update_segmentation(graph_2d_with_segmentation):
- tracks = SolutionTracks(
+ tracks = Tracks(
graph_2d_with_segmentation,
ndim=3,
**track_attrs,
@@ -90,42 +63,27 @@ def test_update_segmentation(graph_2d_with_segmentation):
def test_next_track_id_empty():
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=["pos", "track_id"],
edge_attributes=[],
)
- tracks = SolutionTracks(graph, ndim=4, **track_attrs)
+ tracks = Tracks(graph, ndim=4, **track_attrs)
assert tracks.get_next_track_id() == 1
-def test_get_lineage_id_without_lineage_key(graph_2d_with_track_id):
- """Test that get_lineage_id returns None when lineage_key is not set."""
- graph = graph_2d_with_track_id
- graph.add_node(
- attrs={"t": 1, "pos": [0, 0], "track_id": 1}, index=7, validate_keys=False
- )
- tracks = SolutionTracks(graph, ndim=3, **track_attrs)
-
- # Unset lineage_key to test the None path
- tracks.features.lineage_key = None
-
- # get_lineage_id should return None when lineage_key is not set
- assert tracks.get_lineage_id(1) is None
-
-
def test_export_to_csv_with_display_names(
graph_2d_with_segmentation, graph_3d_with_segmentation, tmp_path
):
"""Test CSV export with use_display_names=True option."""
# Test 2D with display names
- tracks = SolutionTracks(graph_2d_with_segmentation, **track_attrs, ndim=3)
+ tracks = Tracks(graph_2d_with_segmentation, **track_attrs, ndim=3)
tracks.enable_features(["area"])
temp_file = tmp_path / "test_export_2d_display.csv"
export_to_csv(tracks, temp_file, use_display_names=True)
with open(temp_file) as f:
lines = f.readlines()
- assert len(lines) == tracks.graph.num_nodes() + 1 # add header
+ assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header
# With display names: ID, Parent ID, Time, y, x, Tracklet ID,
# Lineage ID, Area
@@ -142,14 +100,14 @@ def test_export_to_csv_with_display_names(
assert lines[0].strip().split(",") == header
# Test 3D with display names (area display name is "Volume" in 3D)
- tracks = SolutionTracks(graph_3d_with_segmentation, **track_attrs, ndim=4)
+ tracks = Tracks(graph_3d_with_segmentation, **track_attrs, ndim=4)
tracks.enable_features(["area"])
temp_file = tmp_path / "test_export_3d_display.csv"
export_to_csv(tracks, temp_file, use_display_names=True)
with open(temp_file) as f:
lines = f.readlines()
- assert len(lines) == tracks.graph.num_nodes() + 1 # add header
+ assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header
# With display names: ID, Parent ID, Time, z, y, x,
# Tracklet ID, Lineage ID, Volume
@@ -171,7 +129,7 @@ def test_multi_axis_pos_attr_with_segmentation(graph_3d_with_segmentation):
"""pos_attr as list should be respected even when segmentation is present.
Scenario: graph has both a "pos" column AND individual z/y/x columns with
- distinct values. SolutionTracks(pos_attr=['z','y','x']) should use z/y/x
+ distinct values. Tracks(pos_attr=['z','y','x']) should use z/y/x
as the position_key, not fall back to "pos".
"""
graph = graph_3d_with_segmentation
@@ -187,7 +145,7 @@ def test_multi_axis_pos_attr_with_segmentation(graph_3d_with_segmentation):
graph.nodes[node]["y"] = float(pos[1]) + offset
graph.nodes[node]["x"] = float(pos[2]) + offset
- tracks = SolutionTracks(
+ tracks = Tracks(
graph=graph,
pos_attr=["z", "y", "x"],
ndim=4,
@@ -204,5 +162,5 @@ def test_multi_axis_pos_attr_with_segmentation(graph_3d_with_segmentation):
expected = [float(original_pos[i]) + offset for i in range(3)]
assert list(pos_from_tracks) == expected, (
f"Expected positions from z/y/x ({expected}), "
- f"got {list(pos_from_tracks)} — SolutionTracks used 'pos' instead"
+ f"got {list(pos_from_tracks)} — Tracks used 'pos' instead"
)
diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py
index 74dd717b..06f35b81 100644
--- a/tests/data_model/test_tracks.py
+++ b/tests/data_model/test_tracks.py
@@ -8,16 +8,16 @@
from funtracks.features import SegBbox, SegMask
from funtracks.user_actions import UserUpdateSegmentation
from funtracks.utils.tracksdata_utils import (
- create_empty_graphview_graph,
+ create_empty_graph,
to_polars_dtype,
)
track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"}
-def test_create_tracks(graph_3d_with_segmentation: td.graph.GraphView):
+def test_create_tracks(graph_3d_with_segmentation: td.graph.BaseGraph):
# create empty tracks
- empty_graph = create_empty_graphview_graph()
+ empty_graph = create_empty_graph()
tracks = Tracks(graph=empty_graph, ndim=3, **track_attrs) # type: ignore[arg-type]
assert tracks.features.position_key == "pos"
assert isinstance(tracks.features["pos"], dict)
@@ -96,7 +96,7 @@ def test_nodes_edges(graph_2d_with_segmentation):
assert set(tracks.nodes()) == {1, 2, 3, 4, 5, 6}
assert len(tracks.edges()) == 4 # rx graph starts from 0, sql from 1,
# so direct comparison of edges depends on backend
- assert set(map(tuple, tracks.graph.edge_list())) == {
+ assert set(map(tuple, tracks.graph_solution.edge_list())) == {
(1, 2),
(1, 3),
(3, 4),
@@ -104,18 +104,6 @@ def test_nodes_edges(graph_2d_with_segmentation):
}
-def test_degrees(graph_2d_with_segmentation):
- tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs)
- assert tracks.in_degree(np.array([1])) == 0
- assert tracks.in_degree(np.array([4])) == 1
- assert np.array_equal(tracks.in_degree(None), np.array([0, 1, 1, 1, 1, 0]))
- assert np.array_equal(tracks.out_degree(np.array([1, 4])), np.array([2, 1]))
- assert np.array_equal(
- tracks.out_degree(None),
- np.array([2, 0, 1, 1, 0, 0]),
- )
-
-
def test_predecessors_successors(graph_2d_with_segmentation):
tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs)
assert tracks.predecessors(2) == [1]
@@ -202,7 +190,7 @@ def test_set_pixels_no_segmentation(graph_2d_with_track_id):
def test_compute_ndim_errors():
- g = create_empty_graphview_graph()
+ g = create_empty_graph()
g.add_node_attr_key("pos", default_value=[0, 0], dtype=pl.List(pl.Int64))
g.add_node(index=1, attrs={"t": 0, "pos": [0, 0, 0], "solution": True})
@@ -225,7 +213,7 @@ def test_get_new_node_ids(graph_2d_with_position):
assert 1 not in ids # existing nodes skipped
assert 2 not in ids
for node_id in ids:
- assert not tracks.graph.has_node(node_id)
+ assert not tracks.graph_solution.has_node(node_id)
# second call must not overlap with first
ids2 = tracks._get_new_node_ids(2)
@@ -241,7 +229,9 @@ def test_undo_redo(graph_2d_with_segmentation):
assert tracks.redo() is False
# Perform an action - add a custom attribute
- tracks.graph.add_node_attr_key("custom_label", default_value=None, dtype=pl.Object)
+ tracks.graph_solution.add_node_attr_key(
+ "custom_label", default_value=None, dtype=pl.Object
+ )
action1 = UpdateNodeAttrs(tracks, node=1, attrs={"custom_label": "test_value"})
tracks.action_history.add_new_action(action1)
@@ -262,7 +252,9 @@ def test_undo_redo(graph_2d_with_segmentation):
assert tracks.redo() is False
# Perform another action
- tracks.graph.add_node_attr_key("another_label", default_value=None, dtype=pl.Object)
+ tracks.graph_solution.add_node_attr_key(
+ "another_label", default_value=None, dtype=pl.Object
+ )
action2 = UpdateNodeAttrs(tracks, node=2, attrs={"another_label": "second_value"})
tracks.action_history.add_new_action(action2)
assert tracks.get_node_attr(2, "another_label") == "second_value"
@@ -310,17 +302,17 @@ def test_to_polars_dtype_mask():
def test_add_feature_mask_creates_both_columns():
"""add_feature with mask and bbox Features creates both columns."""
- graph = create_empty_graphview_graph(ndim=3)
+ graph = create_empty_graph(ndim=3)
tracks = Tracks(graph, ndim=3, **track_attrs)
- assert "nuc_mask" not in tracks.graph.node_attr_keys()
- assert "nuc_bbox" not in tracks.graph.node_attr_keys()
+ assert "nuc_mask" not in tracks.graph_solution.node_attr_keys()
+ assert "nuc_bbox" not in tracks.graph_solution.node_attr_keys()
tracks.add_feature("nuc_mask", SegMask(ndim=3, bbox_key="nuc_bbox"))
tracks.add_feature("nuc_bbox", SegBbox(ndim=3))
- assert "nuc_mask" in tracks.graph.node_attr_keys()
- assert "nuc_bbox" in tracks.graph.node_attr_keys()
+ assert "nuc_mask" in tracks.graph_solution.node_attr_keys()
+ assert "nuc_bbox" in tracks.graph_solution.node_attr_keys()
assert "nuc_mask" in tracks.features
assert "nuc_bbox" in tracks.features
@@ -333,15 +325,15 @@ def test_delete_feature_mask_removes_both_columns(
assert "mask" in tracks.features
assert "bbox" in tracks.features
- assert "mask" in tracks.graph.node_attr_keys()
- assert "bbox" in tracks.graph.node_attr_keys()
+ assert "mask" in tracks.graph_solution.node_attr_keys()
+ assert "bbox" in tracks.graph_solution.node_attr_keys()
tracks.delete_feature("mask")
assert "mask" not in tracks.features
assert "bbox" not in tracks.features
- assert "mask" not in tracks.graph.node_attr_keys()
- assert "bbox" not in tracks.graph.node_attr_keys()
+ assert "mask" not in tracks.graph_solution.node_attr_keys()
+ assert "bbox" not in tracks.graph_solution.node_attr_keys()
def test_update_mask_syncs_bbox(graph_2d_with_segmentation):
@@ -353,8 +345,8 @@ def test_update_mask_syncs_bbox(graph_2d_with_segmentation):
new_mask = make_2d_disk_mask(center=(30, 30), radius=10)
tracks.update_mask(1, new_mask)
- stored_mask = tracks.graph.nodes[1][td.DEFAULT_ATTR_KEYS.MASK]
- stored_bbox = tracks.graph.nodes[1][td.DEFAULT_ATTR_KEYS.BBOX]
+ stored_mask = tracks.graph_solution.nodes[1][td.DEFAULT_ATTR_KEYS.MASK]
+ stored_bbox = tracks.graph_solution.nodes[1][td.DEFAULT_ATTR_KEYS.BBOX]
assert stored_mask is new_mask
assert np.array_equal(stored_bbox, new_mask.bbox)
diff --git a/tests/import_export/test_csv_export.py b/tests/import_export/test_csv_export.py
index d12beec3..ae13a721 100644
--- a/tests/import_export/test_csv_export.py
+++ b/tests/import_export/test_csv_export.py
@@ -16,14 +16,14 @@
)
def test_export_solution_to_csv(get_tracks, tmp_path, ndim, expected_header):
"""Test exporting tracks to CSV."""
- tracks = get_tracks(ndim=ndim, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=False, prefill_track_ids=True)
temp_file = tmp_path / "test_export.csv"
export_to_csv(tracks, temp_file)
with open(temp_file) as f:
lines = f.readlines()
- assert len(lines) == tracks.graph.num_nodes() + 1 # add header
+ assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header
assert lines[0].strip().split(",") == expected_header
# Check first data line (node 1: t=0, pos=[50, 50] or [50, 50, 50], track_id=1)
@@ -46,7 +46,7 @@ def test_export_solution_to_csv_with_seg_zarr(
get_tracks, tmp_path, ndim, expected_header
):
"""Test exporting tracks to CSV + segmentation painted by track_id as zarr."""
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
temp_file = tmp_path / "test_export.csv"
seg_dir = tmp_path / "test_export_seg"
export_to_csv(tracks, temp_file, export_seg=True, seg_path=seg_dir)
@@ -54,7 +54,7 @@ def test_export_solution_to_csv_with_seg_zarr(
with open(temp_file) as f:
lines = f.readlines()
- assert len(lines) == tracks.graph.num_nodes() + 1 # add header
+ assert len(lines) == tracks.graph_solution.num_nodes() + 1 # add header
assert lines[0].strip().split(",") == expected_header
# check the segmentation zarr
@@ -66,14 +66,16 @@ def test_export_solution_to_csv_with_seg_zarr(
seg_arr = seg_zarr[:]
unique_vals = set(seg_arr.flatten()) - {0}
label_key = tracks.features.tracklet_key
- track_ids = set(tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list())
+ track_ids = set(
+ tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list()
+ )
assert unique_vals == track_ids
@pytest.mark.parametrize("ndim", [3, 4], ids=["2d", "3d"])
def test_export_solution_to_csv_with_seg_tiff(get_tracks, tmp_path, ndim):
"""Test exporting tracks to CSV + segmentation as tiff painted by tracklet ID."""
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
temp_file = tmp_path / "test_export.csv"
seg_file = tmp_path / "test_export_seg.tif"
export_to_csv(
@@ -91,14 +93,16 @@ def test_export_solution_to_csv_with_seg_tiff(get_tracks, tmp_path, ndim):
# values should be tracklet_ids (not node_ids) — default seg_relabel="tracklet"
unique_vals = set(seg_arr.flatten()) - {0}
label_key = tracks.features.tracklet_key
- track_ids = set(tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list())
+ track_ids = set(
+ tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list()
+ )
assert unique_vals == track_ids
@pytest.mark.parametrize("ndim", [3, 4], ids=["2d", "3d"])
def test_export_solution_to_csv_with_seg_original_labels(get_tracks, tmp_path, ndim):
"""Test exporting tracks to CSV + segmentation with original (node_id) labels."""
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
temp_file = tmp_path / "test_export.csv"
seg_dir = tmp_path / "test_export_seg"
export_to_csv(
@@ -116,7 +120,7 @@ def test_export_solution_to_csv_with_seg_original_labels(get_tracks, tmp_path, n
# values should be node_ids (original labels), not track_ids
seg_arr = seg_zarr[:]
unique_vals = set(seg_arr.flatten()) - {0}
- node_ids = set(tracks.graph.node_ids())
+ node_ids = set(tracks.graph_solution.node_ids())
assert unique_vals == node_ids
@@ -124,11 +128,11 @@ def test_export_with_color_dict(get_tracks, tmp_path):
"""Test exporting with a color_dict adds a Tracklet ID Color column."""
import numpy as np
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
temp_file = tmp_path / "test_export_colors.csv"
# Build a color dict: node_id → [R, G, B] floats in [0, 1]
- node_ids = list(tracks.graph.node_ids())
+ node_ids = list(tracks.graph_solution.node_ids())
color_dict = {
node_id: np.array([0.1 * (i % 10), 0.5, 0.9])
for i, node_id in enumerate(node_ids)
@@ -146,7 +150,7 @@ def test_export_with_color_dict(get_tracks, tmp_path):
def test_export_with_display_names(get_tracks, tmp_path):
"""Test exporting with display names."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
temp_file = tmp_path / "test_export_display.csv"
export_to_csv(tracks, temp_file, use_display_names=True)
@@ -161,7 +165,7 @@ def test_export_with_display_names(get_tracks, tmp_path):
def test_export_filtered_nodes(get_tracks, tmp_path):
"""Test exporting only specific nodes."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
temp_file = tmp_path / "test_export_filtered.csv"
# Export only nodes 1 and 2 (and their ancestors)
@@ -177,7 +181,7 @@ def test_export_filtered_nodes(get_tracks, tmp_path):
def test_ignore_edge_features_at_export(get_tracks, tmp_path):
"""Test that edge features are ignored when exporting to csv"""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
temp_file = tmp_path / "test_export_node_features_only.csv"
# enable node and edge features
@@ -209,7 +213,7 @@ def test_export_solution_to_csv_with_seg_and_node_subset(
and segmentation must only include the resulting graph nodes, and nothing else.
"""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
csv_file = tmp_path / "export.csv"
seg_path = tmp_path / "seg"
@@ -266,8 +270,10 @@ def test_export_solution_to_csv_with_seg_and_node_subset(
else:
label_key = tracks.features.tracklet_key
- labels = tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list()
- node_to_label = dict(zip(tracks.graph.node_ids(), labels, strict=True))
+ labels = tracks.graph_solution.node_attrs(attr_keys=[label_key])[
+ label_key
+ ].to_list()
+ node_to_label = dict(zip(tracks.graph_solution.node_ids(), labels, strict=True))
expected = np.zeros_like(original)
@@ -278,3 +284,41 @@ def test_export_solution_to_csv_with_seg_and_node_subset(
expected_vals = {node_to_label[n] for n in graph_nodes}
assert unique_vals == expected_vals
+
+
+def test_export_full_vs_solution(get_tracks, tmp_path):
+ """export_full=True includes soft-deleted (solution=False) nodes and surfaces the
+ 'Solution' column; the default exports only the solution view."""
+ import pandas as pd
+
+ from funtracks.user_actions import UserDeleteNode
+
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
+ # Soft-delete a leaf node: dropped from the solution view, kept in graph_full.
+ UserDeleteNode(tracks, 5)
+ sol_n = tracks.graph_solution.num_nodes()
+ full_n = tracks.graph_full.num_nodes()
+ assert full_n == sol_n + 1
+
+ # Default: solution view only — node 5 absent.
+ sol_file = tmp_path / "sol.csv"
+ export_to_csv(tracks, sol_file)
+ sol = pd.read_csv(sol_file)
+ assert len(sol) == sol_n
+ assert 5 not in set(sol["id"])
+
+ # export_full=True: full graph — node 5 present.
+ full_file = tmp_path / "full.csv"
+ export_to_csv(tracks, full_file, export_full=True)
+ full = pd.read_csv(full_file)
+ assert len(full) == full_n
+ assert 5 in set(full["id"])
+
+ # With display names, the full export carries a 'Solution' column distinguishing
+ # the soft-deleted node (False) from the rest (True).
+ disp_file = tmp_path / "full_disp.csv"
+ export_to_csv(tracks, disp_file, export_full=True, use_display_names=True)
+ disp = pd.read_csv(disp_file)
+ assert "Solution" in disp.columns
+ assert not bool(disp.loc[disp["ID"] == 5, "Solution"].iloc[0])
+ assert bool(disp.loc[disp["ID"] == 1, "Solution"].iloc[0])
diff --git a/tests/import_export/test_csv_import.py b/tests/import_export/test_csv_import.py
index 613dc6b5..18fcd66f 100644
--- a/tests/import_export/test_csv_import.py
+++ b/tests/import_export/test_csv_import.py
@@ -2,7 +2,7 @@
import pandas as pd
import pytest
-from funtracks.data_model import SolutionTracks
+from funtracks.data_model import Tracks
from funtracks.import_export import tracks_from_df
@@ -42,9 +42,9 @@ def test_import_2d(self, simple_df_2d):
"""Test importing 2D DataFrame."""
tracks = tracks_from_df(simple_df_2d)
- assert isinstance(tracks, SolutionTracks)
- assert tracks.graph.num_nodes() == 4
- assert tracks.graph.num_edges() == 3
+ assert isinstance(tracks, Tracks)
+ assert tracks.graph_solution.num_nodes() == 4
+ assert tracks.graph_solution.num_edges() == 3
assert tracks.ndim == 3
def test_import_3d(self, df_3d):
@@ -52,7 +52,7 @@ def test_import_3d(self, df_3d):
tracks = tracks_from_df(df_3d)
assert tracks.ndim == 4
- assert tracks.graph.num_nodes() == 3
+ assert tracks.graph_solution.num_nodes() == 3
# Check z coordinate
pos = tracks.get_position(1)
assert len(pos) == 3 # z, y, x
@@ -80,12 +80,12 @@ def test_edges_created(self, simple_df_2d):
tracks = tracks_from_df(simple_df_2d)
# Check specific edges exist
- assert tracks.graph.has_edge(1, 2)
- assert tracks.graph.has_edge(1, 3)
- assert tracks.graph.has_edge(2, 4)
+ assert tracks.graph_solution.has_edge(1, 2)
+ assert tracks.graph_solution.has_edge(1, 3)
+ assert tracks.graph_solution.has_edge(2, 4)
# Check node 1 has two children (division)
- assert len(list(tracks.graph.successors(1))) == 2
+ assert len(list(tracks.graph_solution.successors(1))) == 2
class TestSegmentationHandling:
@@ -162,8 +162,8 @@ def test_single_node(self):
tracks = tracks_from_df(df)
- assert tracks.graph.num_nodes() == 1
- assert tracks.graph.num_edges() == 0
+ assert tracks.graph_solution.num_nodes() == 1
+ assert tracks.graph_solution.num_edges() == 0
def test_multiple_roots(self):
"""Test multiple independent lineages."""
@@ -179,11 +179,15 @@ def test_multiple_roots(self):
tracks = tracks_from_df(df)
- assert tracks.graph.num_nodes() == 4
- assert tracks.graph.num_edges() == 2
+ assert tracks.graph_solution.num_nodes() == 4
+ assert tracks.graph_solution.num_edges() == 2
# Should have two root nodes
- roots = [n for n in tracks.graph.node_ids() if tracks.graph.in_degree(n) == 0]
+ roots = [
+ n
+ for n in tracks.graph_solution.node_ids()
+ if len(tracks.predecessors(n)) == 0
+ ]
assert len(roots) == 2
def test_division_nan_parent(self):
@@ -207,10 +211,10 @@ def test_division_nan_parent(self):
tracks = tracks_from_df(df)
- assert tracks.graph.num_nodes() == 3
- assert tracks.graph.num_edges() == 2
+ assert tracks.graph_solution.num_nodes() == 3
+ assert tracks.graph_solution.num_edges() == 2
- children = list(tracks.graph.successors(1))
+ children = list(tracks.graph_solution.successors(1))
assert set(children) == {2, 3}
def test_division_event(self):
@@ -227,11 +231,11 @@ def test_division_event(self):
tracks = tracks_from_df(df)
- assert tracks.graph.num_nodes() == 3
- assert tracks.graph.num_edges() == 2
+ assert tracks.graph_solution.num_nodes() == 3
+ assert tracks.graph_solution.num_edges() == 2
# Node 1 should have two children
- children = list(tracks.graph.successors(1))
+ children = list(tracks.graph_solution.successors(1))
assert len(children) == 2
assert set(children) == {2, 3}
@@ -249,19 +253,23 @@ def test_long_track(self):
tracks = tracks_from_df(df)
- assert tracks.graph.num_nodes() == 10
- assert tracks.graph.num_edges() == 9
+ assert tracks.graph_solution.num_nodes() == 10
+ assert tracks.graph_solution.num_edges() == 9
# Should form a single linear chain
- roots = [n for n in tracks.graph.node_ids() if tracks.graph.in_degree(n) == 0]
+ roots = [
+ n
+ for n in tracks.graph_solution.node_ids()
+ if len(tracks.predecessors(n)) == 0
+ ]
assert len(roots) == 1
# Each non-leaf node should have exactly one child
non_leaves = [
- n for n in tracks.graph.node_ids() if tracks.graph.out_degree(n) > 0
+ n for n in tracks.graph_solution.node_ids() if len(tracks.successors(n)) > 0
]
for node in non_leaves:
- assert tracks.graph.out_degree(node) == 1
+ assert len(tracks.successors(node)) == 1
def test_orphaned_node_raises_error(self):
"""Test that node with invalid parent_id raises error."""
@@ -361,8 +369,8 @@ def test_seg_id_same_as_id(self, simple_df_2d):
tracks = tracks_from_df(simple_df_2d, node_name_map=name_map)
# Both id and seg_id should be present with same values
- assert tracks.graph.num_nodes() == 4
- for node_id in tracks.graph.node_ids():
+ assert tracks.graph_solution.num_nodes() == 4
+ for node_id in tracks.graph_solution.node_ids():
assert tracks.get_node_attr(node_id, "seg_id") == node_id
def test_duplicate_mapping_with_segmentation(self, simple_df_2d):
@@ -654,7 +662,7 @@ def test_empty_list_in_name_map_removed(self):
tracks = tracks_from_df(df, node_name_map=name_map)
assert tracks is not None
# The empty mapping should not result in a feature being added
- assert "ellipse_axis_radii" not in tracks.graph.node_attr_keys()
+ assert "ellipse_axis_radii" not in tracks.graph_solution.node_attr_keys()
def test_import_without_position_with_segmentation(self):
"""Test that position can be omitted when segmentation is provided.
@@ -687,8 +695,8 @@ def test_import_without_position_with_segmentation(self):
assert tracks is not None
# Position should be computed from segmentation centroids
- assert "pos" in tracks.graph.node_attr_keys()
- pos_1 = tracks.graph.nodes[1]["pos"]
+ assert "pos" in tracks.graph_solution.node_attr_keys()
+ pos_1 = tracks.graph_solution.nodes[1]["pos"]
# Centroid of 3x3 region at [2:5, 2:5] is approximately [3, 3]
np.testing.assert_array_almost_equal(pos_1, [3.0, 3.0], decimal=0)
diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py
index b67d9df6..05c8814d 100644
--- a/tests/import_export/test_export_to_geff.py
+++ b/tests/import_export/test_export_to_geff.py
@@ -4,7 +4,7 @@
import tifffile
import zarr
-from funtracks.data_model import SolutionTracks, Tracks
+from funtracks.data_model import Tracks
from funtracks.import_export import export_to_geff, import_from_geff, write_to_geff
@@ -32,13 +32,13 @@ def _assert_valid_geff_export(export_dir, expected_num_nodes=None):
@pytest.mark.parametrize("seg_relabel", ["tracklet", "lineage", None])
def test_export_segmentation_relabel(get_tracks, ndim, seg_relabel, tmp_path):
"""Test segmentation export with each relabel strategy."""
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
export_dir = tmp_path / "export"
export_dir.mkdir()
export_to_geff(tracks, export_dir, seg_relabel=seg_relabel)
- z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes())
+ z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes())
# Segmentation file must exist
seg_path = export_dir / "segmentation"
@@ -54,11 +54,13 @@ def test_export_segmentation_relabel(get_tracks, ndim, seg_relabel, tmp_path):
else:
label_key = tracks.features.tracklet_key
label_vals = set(
- tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list()
+ tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list()
)
assert unique_vals == label_vals
else:
- assert unique_vals == set(tracks.graph.node_ids())
+ # values should be original node_ids
+ node_ids_set = set(tracks.graph_solution.node_ids())
+ assert unique_vals == node_ids_set
# segmentation_shape must be in metadata
attrs = dict(z.attrs)
@@ -68,13 +70,13 @@ def test_export_segmentation_relabel(get_tracks, ndim, seg_relabel, tmp_path):
def test_export_no_segmentation_saved(get_tracks, tmp_path):
"""Test that save_segmentation=False suppresses segmentation file."""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
export_dir = tmp_path / "export"
export_dir.mkdir()
export_to_geff(tracks, export_dir, save_segmentation=False)
- z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes())
+ z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes())
assert not (export_dir / "segmentation").exists()
@@ -85,13 +87,13 @@ def test_export_no_segmentation_saved(get_tracks, tmp_path):
def test_export_without_seg_on_tracks(get_tracks, tmp_path):
"""Test export when tracks have no segmentation at all."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
export_dir = tmp_path / "export"
export_dir.mkdir()
export_to_geff(tracks, export_dir)
- z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes())
+ z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes())
assert not (export_dir / "segmentation").exists()
@@ -99,32 +101,21 @@ def test_export_without_seg_on_tracks(get_tracks, tmp_path):
assert "segmentation_shape" not in attrs
-@pytest.mark.parametrize("seg_relabel", ["tracklet", "lineage"])
-def test_export_seg_relabel_non_solution_raises(get_tracks, seg_relabel, tmp_path):
- """Relabeling by tracklet/lineage on non-solution tracks raises ValueError."""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=False)
-
- export_dir = tmp_path / "export"
- export_dir.mkdir()
- with pytest.raises(ValueError):
- export_to_geff(tracks, export_dir, seg_relabel=seg_relabel)
-
-
def test_export_segmentation_non_solution(get_tracks, tmp_path):
"""Non-solution tracks export segmentation fine with seg_relabel=None."""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=False)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=False)
export_dir = tmp_path / "export"
export_dir.mkdir()
export_to_geff(tracks, export_dir, seg_relabel=None)
- z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes())
+ z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes())
# No relabel: segmentation pixels keep original node_ids
seg_zarr = zarr.open(str(export_dir / "segmentation"), mode="r")
assert isinstance(seg_zarr, zarr.Array)
assert seg_zarr.shape == tracks.segmentation.shape
- assert set(seg_zarr[:].flatten()) - {0} == set(tracks.graph.node_ids())
+ assert set(seg_zarr[:].flatten()) - {0} == set(tracks.graph_solution.node_ids())
attrs = dict(z.attrs)
assert "segmentation_shape" in attrs
@@ -135,10 +126,10 @@ def test_export_segmentation_non_solution(get_tracks, tmp_path):
@pytest.mark.parametrize("ndim", [3, 4])
-@pytest.mark.parametrize("is_solution", [True, False])
-def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path):
+@pytest.mark.parametrize("prefill_track_ids", [True, False])
+def test_export_split_position_attrs(get_graph, ndim, prefill_track_ids, tmp_path):
"""Test export with split (list) position attributes."""
- graph = get_graph(ndim, is_solution=is_solution, with_seg=False)
+ graph = get_graph(ndim, prefill_track_ids=prefill_track_ids, with_seg=False)
pos_keys = ["y", "x"] if ndim == 3 else ["z", "y", "x"]
for key in pos_keys:
@@ -149,12 +140,11 @@ def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path):
graph.nodes[node][key] = pos[i]
graph.remove_node_attr_key("pos")
- tracks_cls = SolutionTracks if is_solution else Tracks
- tracks = tracks_cls(
+ tracks = Tracks(
graph,
time_attr="t",
pos_attr=pos_keys,
- tracklet_attr="track_id" if is_solution else None,
+ tracklet_attr="track_id" if prefill_track_ids else None,
ndim=ndim,
)
@@ -162,7 +152,7 @@ def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path):
export_dir.mkdir()
export_to_geff(tracks, export_dir, save_segmentation=False)
- z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes())
+ z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes())
# Verify axis names include the split position keys
axes = dict(z.attrs)["geff"]["axes"]
@@ -177,7 +167,7 @@ def test_export_split_position_attrs(get_graph, ndim, is_solution, tmp_path):
@pytest.mark.parametrize("ndim", [3, 4])
def test_export_node_subset(get_tracks, ndim, tmp_path):
"""Test exporting a subset of nodes includes ancestors."""
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
export_dir = tmp_path / "export"
export_dir.mkdir()
@@ -208,7 +198,7 @@ def test_export_node_subset(get_tracks, ndim, tmp_path):
def test_export_node_subset_seg_relabel(get_tracks, tmp_path):
"""Test subset export with relabeled segmentation."""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
export_dir = tmp_path / "export"
export_dir.mkdir()
@@ -224,8 +214,10 @@ def test_export_node_subset_seg_relabel(get_tracks, tmp_path):
original = np.asarray(tracks.segmentation[:])
label_key = tracks.features.tracklet_key
- labels = tracks.graph.node_attrs(attr_keys=[label_key])[label_key]
- node_to_label = dict(zip(tracks.graph.node_ids(), labels.to_list(), strict=True))
+ labels = tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key]
+ node_to_label = dict(
+ zip(tracks.graph_solution.node_ids(), labels.to_list(), strict=True)
+ )
graph_nodes = set(expected_graph_nodes)
expected = np.zeros_like(original)
@@ -240,7 +232,7 @@ def test_export_node_subset_seg_relabel(get_tracks, tmp_path):
def test_export_node_subset_without_seg(get_tracks, tmp_path):
"""Test subset export when tracks have no segmentation."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
export_dir = tmp_path / "export"
export_dir.mkdir()
@@ -258,7 +250,7 @@ def test_export_node_subset_without_seg(get_tracks, tmp_path):
def test_export_overwrite(get_tracks, tmp_path):
"""Test export with overwrite=True into non-empty directory."""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
export_dir = tmp_path / "export"
export_dir.mkdir()
@@ -266,7 +258,7 @@ def test_export_overwrite(get_tracks, tmp_path):
export_to_geff(tracks, export_dir, overwrite=True)
- _assert_valid_geff_export(export_dir, tracks.graph.num_nodes())
+ _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes())
# Segmentation is still written correctly alongside the overwritten dir
seg_zarr = zarr.open(str(export_dir / "segmentation"), mode="r")
@@ -276,7 +268,7 @@ def test_export_overwrite(get_tracks, tmp_path):
def test_export_non_directory_raises(get_tracks, tmp_path):
"""Test that exporting to a file path (not a directory) raises an error."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
file_path = tmp_path / "not_a_dir"
file_path.write_text("test")
@@ -292,13 +284,13 @@ def test_export_non_directory_raises(get_tracks, tmp_path):
@pytest.mark.parametrize("with_seg", [True, False])
def test_export_metadata(get_tracks, ndim, with_seg, tmp_path):
"""Test axes structure, segmentation_shape, and FeatureDict in metadata."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
export_dir = tmp_path / "export"
export_dir.mkdir()
export_to_geff(tracks, export_dir, save_segmentation=with_seg)
- z = _assert_valid_geff_export(export_dir, tracks.graph.num_nodes())
+ z = _assert_valid_geff_export(export_dir, tracks.graph_solution.num_nodes())
attrs = dict(z.attrs)
# Correct number of axes
@@ -327,7 +319,7 @@ def test_export_metadata(get_tracks, ndim, with_seg, tmp_path):
@pytest.mark.parametrize("ndim", [3, 4], ids=["2d", "3d"])
def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path):
"""Test that segmentation can be exported as tiff alongside the geff graph."""
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
export_dir = tmp_path / "export"
export_dir.mkdir()
@@ -344,7 +336,9 @@ def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path):
# values should be tracklet_ids (default seg_relabel="tracklet")
unique_vals = set(seg_arr.flatten()) - {0}
label_key = tracks.features.tracklet_key
- track_ids = set(tracks.graph.node_attrs(attr_keys=[label_key])[label_key].to_list())
+ track_ids = set(
+ tracks.graph_solution.node_attrs(attr_keys=[label_key])[label_key].to_list()
+ )
assert unique_vals == track_ids
# Check metadata references the tiff path with ../../ prefix (sibling of geff dir)
@@ -360,16 +354,16 @@ def test_export_to_geff_seg_tiff(get_tracks, ndim, tmp_path):
@pytest.mark.parametrize("with_seg", [False, True])
def test_write_to_geff_roundtrip(get_tracks, ndim, with_seg, tmp_path):
"""write_to_geff then import_from_geff recovers the tracks."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
geff_path = tmp_path / "my_tracks.geff"
write_to_geff(tracks, geff_path)
loaded = import_from_geff(geff_path)
- assert loaded.graph.num_nodes() == tracks.graph.num_nodes()
- assert loaded.graph.num_edges() == tracks.graph.num_edges()
- assert set(loaded.graph.node_ids()) == set(tracks.graph.node_ids())
+ assert loaded.graph_solution.num_nodes() == tracks.graph_solution.num_nodes()
+ assert loaded.graph_solution.num_edges() == tracks.graph_solution.num_edges()
+ assert set(loaded.graph_solution.node_ids()) == set(tracks.graph_solution.node_ids())
assert loaded.features.dump_json() == tracks.features.dump_json()
if with_seg:
@@ -383,7 +377,7 @@ def test_write_to_geff_roundtrip(get_tracks, ndim, with_seg, tmp_path):
def test_write_to_geff_no_parent_container(get_tracks, tmp_path):
"""write_to_geff writes directly to the path — no parent .zgroup or
tracks.geff subfolder."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
geff_path = tmp_path / "my_tracks.geff"
write_to_geff(tracks, geff_path)
@@ -400,19 +394,19 @@ def test_write_to_geff_no_parent_container(get_tracks, tmp_path):
def test_write_to_geff_overwrite(get_tracks, tmp_path):
"""write_to_geff with overwrite=True replaces existing store."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
geff_path = tmp_path / "my_tracks.geff"
write_to_geff(tracks, geff_path)
write_to_geff(tracks, geff_path, overwrite=True)
loaded = import_from_geff(geff_path)
- assert loaded.graph.num_nodes() == tracks.graph.num_nodes()
+ assert loaded.graph_solution.num_nodes() == tracks.graph_solution.num_nodes()
def test_write_to_geff_no_overwrite_raises(get_tracks, tmp_path):
"""write_to_geff with overwrite=False raises on existing store."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
geff_path = tmp_path / "my_tracks.geff"
write_to_geff(tracks, geff_path)
@@ -423,7 +417,7 @@ def test_write_to_geff_no_overwrite_raises(get_tracks, tmp_path):
def test_write_to_geff_metadata(get_tracks, tmp_path):
"""write_to_geff stores axes and FeatureDict metadata."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
geff_path = tmp_path / "my_tracks.geff"
write_to_geff(tracks, geff_path)
@@ -441,7 +435,7 @@ def test_write_to_geff_metadata(get_tracks, tmp_path):
def test_write_to_geff_segmentation_shape(get_tracks, tmp_path):
"""write_to_geff writes segmentation_shape when masks are present."""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
geff_path = tmp_path / "my_tracks.geff"
write_to_geff(tracks, geff_path)
diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py
index 5d6fdbd0..75cc3984 100644
--- a/tests/import_export/test_import_from_geff.py
+++ b/tests/import_export/test_import_from_geff.py
@@ -5,10 +5,10 @@
import zarr
from geff.testing.data import create_mock_geff
-from funtracks.data_model import SolutionTracks
+from funtracks.data_model import Tracks
from funtracks.import_export import export_to_geff, import_from_geff
from funtracks.import_export.geff._import import GeffTracksBuilder, import_graph_from_geff
-from funtracks.utils.tracksdata_utils import create_empty_graphview_graph
+from funtracks.utils.tracksdata_utils import create_empty_graph
@pytest.fixture
@@ -262,7 +262,7 @@ def test_duplicate_values_in_name_map(valid_geff):
tracks = import_from_geff(store, node_name_map)
# Both time and seg_id should be present with same values
- for node_id in tracks.graph.node_ids():
+ for node_id in tracks.graph_solution.node_ids():
assert tracks.get_node_attr(node_id, "seg_id") == tracks.get_node_attr(
node_id, "t"
)
@@ -316,11 +316,11 @@ def test_tracks_with_segmentation(valid_geff, invalid_geff, valid_segmentation,
assert hasattr(tracks, "segmentation")
assert tracks.segmentation.shape == valid_segmentation.shape
# Get last node by ID (don't rely on iteration order)
- last_node = max(tracks.graph.node_ids())
+ last_node = max(tracks.graph_solution.node_ids())
# With composite pos, position is stored as an array
- pos = tracks.graph.nodes[last_node]["pos"]
+ pos = tracks.graph_solution.nodes[last_node]["pos"]
coords = [
- tracks.graph.nodes[last_node]["t"],
+ tracks.graph_solution.nodes[last_node]["t"],
pos[0], # y
pos[1], # x
]
@@ -333,10 +333,10 @@ def test_tracks_with_segmentation(valid_geff, invalid_geff, valid_segmentation,
) # test that the seg id has been relabeled
# Check that only requested features are present and area is loaded from geff
- data = tracks.graph.nodes[last_node]
- assert "random_feature" in tracks.graph.node_attr_keys()
- assert "random_feature2" not in tracks.graph.node_attr_keys()
- assert "area" in tracks.graph.node_attr_keys()
+ data = tracks.graph_solution.nodes[last_node]
+ assert "random_feature" in tracks.graph_solution.node_attr_keys()
+ assert "random_feature2" not in tracks.graph_solution.node_attr_keys()
+ assert "area" in tracks.graph_solution.node_attr_keys()
assert data["area"] == 21 # loaded directly from geff, not recomputed
# Test that import fails with ValueError when invalid seg_ids are provided.
@@ -418,8 +418,8 @@ def test_features_loaded_from_name_map(valid_geff, valid_segmentation, tmp_path)
assert key in tracks.features
# Get last node by ID (don't rely on iteration order)
- max_node_id = max(tracks.graph.node_ids())
- data = tracks.graph.nodes[max_node_id]
+ max_node_id = max(tracks.graph_solution.node_ids())
+ data = tracks.graph_solution.nodes[max_node_id]
# All requested features should be present and loaded from geff
for key in feature_keys:
@@ -472,7 +472,7 @@ def test_import_from_geff_roundtrip_auto_axes(tmp_path):
import tracksdata as td
from tracksdata.nodes import Mask
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=[
"pos",
"area",
@@ -501,13 +501,15 @@ def test_import_from_geff_roundtrip_auto_axes(tmp_path):
indices=[1],
)
# The graph carries segmentation_shape in its metadata (set by motile-tracker),
- # but no dense segmentation array is attached to the SolutionTracks object.
+ # but no dense segmentation array is attached to the Tracks object.
graph._update_metadata(segmentation_shape=(5, 100, 100))
run_dir = tmp_path / "run"
run_dir.mkdir()
- st = SolutionTracks(graph, ndim=3, time_attr="t")
+ st = Tracks(
+ graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id"
+ )
export_to_geff(st, run_dir)
tracks_path = run_dir / "tracks.geff"
@@ -541,13 +543,13 @@ def test_import_from_geff_roundtrip_auto_axes(tmp_path):
# import_from_geff must read segmentation_shape back from zarr attrs and
# reconstruct a segmentation (GraphArrayView) — not return segmentation=None.
tracks = import_from_geff(tracks_path)
- assert tracks.graph.num_nodes() == 1
+ assert tracks.graph_solution.num_nodes() == 1
assert tracks.segmentation is not None, (
"segmentation should be reconstructed from masks after round-trip"
)
assert tracks.segmentation.shape == (5, 100, 100)
- node1 = tracks.graph.nodes[1]
+ node1 = tracks.graph_solution.nodes[1]
assert node1["pos"] is not None
np.testing.assert_array_almost_equal(node1["pos"], [50.0, 50.0])
@@ -600,7 +602,7 @@ def test_import_from_geff_warns_missing_segmentation_shape(tmp_path):
import tracksdata as td
import zarr as _zarr
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=[
"pos",
td.DEFAULT_ATTR_KEYS.MASK,
@@ -625,7 +627,9 @@ def test_import_from_geff_warns_missing_segmentation_shape(tmp_path):
run_dir = tmp_path / "run"
run_dir.mkdir()
- st = SolutionTracks(graph, ndim=3, time_attr="t")
+ st = Tracks(
+ graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id"
+ )
export_to_geff(st, run_dir)
tracks_path = run_dir / "tracks.geff"
@@ -652,11 +656,11 @@ def test_import_from_geff_warns_missing_segmentation_shape(tmp_path):
def test_get_time_works_after_import(valid_geff):
- """Regression test: tracks.get_time() must work on a SolutionTracks returned by
+ """Regression test: tracks.get_time() must work on a Tracks returned by
import_from_geff().
Previously, TracksBuilder.build() stored time as "t" in the graph (tracksdata
- convention) but created SolutionTracks(time_attr=TIME_ATTR) where TIME_ATTR="time".
+ convention) but created Tracks(time_attr=TIME_ATTR) where TIME_ATTR="time".
This caused features.time_key="time" while the graph only had attribute "t",
making get_time() raise KeyError: 'time'.
"""
@@ -664,13 +668,13 @@ def test_get_time_works_after_import(valid_geff):
name_map = {"time": "t", "pos": ["y", "x"]}
tracks = import_from_geff(store, name_map)
- for node_id in tracks.graph.node_ids():
+ for node_id in tracks.graph_solution.node_ids():
# This must not raise KeyError: 'time'
t = tracks.get_time(node_id)
assert isinstance(t, int), f"get_time() should return int, got {type(t)}"
# get_times() on all nodes must also work
- all_node_ids = list(tracks.graph.node_ids())
+ all_node_ids = list(tracks.graph_solution.node_ids())
times = tracks.get_times(all_node_ids)
assert len(times) == len(all_node_ids)
@@ -710,7 +714,7 @@ def test_bool_node_property_schema(geff_with_bool_prop):
name_map = {"time": "t", "pos": ["y", "x"], "is_dividing": "is_dividing"}
tracks = import_from_geff(geff_with_bool_prop, name_map)
- df = tracks.graph.node_attrs(attr_keys=["is_dividing"])
+ df = tracks.graph_solution.node_attrs(attr_keys=["is_dividing"])
assert df["is_dividing"].dtype == pl.Boolean, (
f"Expected pl.Boolean schema for 'is_dividing', got {df['is_dividing'].dtype}. "
"Likely cause: np.bool_ default_value fell through to int in construct_graph()."
@@ -729,10 +733,10 @@ def test_bool_node_property_values(geff_with_bool_prop):
name_map = {"time": "t", "pos": ["y", "x"], "is_dividing": "is_dividing"}
tracks = import_from_geff(geff_with_bool_prop, name_map)
- node_ids = sorted(tracks.graph.node_ids())
+ node_ids = sorted(tracks.graph_solution.node_ids())
expected = [True, False, True, False, True]
for node_id, exp in zip(node_ids, expected, strict=True):
- val = tracks.graph.nodes[node_id]["is_dividing"]
+ val = tracks.graph_solution.nodes[node_id]["is_dividing"]
assert type(val) is bool, (
f"Expected Python bool for 'is_dividing', got {type(val)} for node {node_id}"
"Likely cause: np.bool_ value not cast to bool in construct_graph()."
@@ -744,7 +748,7 @@ def test_3d_pos_survives_sql_roundtrip(tmp_path):
"""Regression test: 3D pos (z, y, x) must keep Array dtype through SQL roundtrip.
The construct_graph() method must pass the correct ndim to
- create_empty_graphview_graph() so the pos schema is Array(Float64, 3) not
+ create_empty_graph() so the pos schema is Array(Float64, 3) not
Array(Float64, 2). A schema mismatch causes SQLGraph.from_other() to
downgrade the column to List(Float64), which breaks downstream callers that
rely on to_numpy() returning a 2D float64 array.
@@ -768,14 +772,16 @@ def test_3d_pos_survives_sql_roundtrip(tmp_path):
tracks = import_from_geff(store, name_map)
# Verify the RX graph has correct Array dtype for 3D pos
- df_rx = tracks.graph.node_attrs(attr_keys=["pos"])
+ df_rx = tracks.graph_solution.node_attrs(attr_keys=["pos"])
assert df_rx["pos"].dtype == pl.Array(pl.Float64, 3), (
f"RX graph pos should be Array(Float64, 3), got {df_rx['pos'].dtype}"
)
# Convert to SQL and reload — this is where the schema mismatch used to surface
db_path = str(tmp_path / "test.db")
- td.graph.SQLGraph.from_other(tracks.graph, drivername="sqlite", database=db_path)
+ td.graph.SQLGraph.from_other(
+ tracks.graph_solution, drivername="sqlite", database=db_path
+ )
sql_graph2 = td.graph.SQLGraph("sqlite", db_path)
df_sql = sql_graph2.node_attrs(attr_keys=["pos"])
@@ -824,9 +830,10 @@ def test_geff_legacy_track_id_preserves_tracklet_ids():
def test_geff_roundtrip_preserves_tracklet_ids(get_tracks, tmp_path):
"""End-to-end round-trip: export then import should preserve tracklet IDs."""
- tracks_in = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks_in = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
expected = {
- int(nid): tracks_in.get_track_id(int(nid)) for nid in tracks_in.graph.node_ids()
+ int(nid): tracks_in.get_track_id(int(nid))
+ for nid in tracks_in.graph_solution.node_ids()
}
export_dir = tmp_path / "export"
@@ -871,7 +878,7 @@ def test_embedded_seg_ellipse_axis_radii_feature_metadata(tmp_path):
td.DEFAULT_ATTR_KEYS.BBOX,
]
# node_default_values must align with node_attributes by index.
- # pos/mask/bbox are handled by special cases in create_empty_graphview_graph
+ # pos/mask/bbox are handled by special cases in create_empty_graph
# and their slot values here are never accessed; only area, ellipse_axis_radii,
# track_id, and lineage_id go through the general loop.
node_default_values = [
@@ -883,7 +890,7 @@ def test_embedded_seg_ellipse_axis_radii_feature_metadata(tmp_path):
None, # mask — special-cased, slot unused
None, # bbox — special-cased, slot unused
]
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=node_attributes,
node_default_values=node_default_values,
edge_attributes=[],
@@ -911,7 +918,9 @@ def test_embedded_seg_ellipse_axis_radii_feature_metadata(tmp_path):
run_dir = tmp_path / "run"
run_dir.mkdir()
- st = SolutionTracks(graph, ndim=3, time_attr="t")
+ st = Tracks(
+ graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id"
+ )
export_to_geff(st, run_dir)
# Remove FeatureDict from GEFF metadata to simulate old/external GEFF
@@ -996,7 +1005,7 @@ def test_featuredict_survives_geff_roundtrip(tmp_path):
None, # mask — special-cased, slot unused
None, # bbox — special-cased, slot unused
]
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=node_attributes,
node_default_values=node_default_values,
edge_attributes=[],
@@ -1024,7 +1033,9 @@ def test_featuredict_survives_geff_roundtrip(tmp_path):
run_dir = tmp_path / "run"
run_dir.mkdir()
- st = SolutionTracks(graph, ndim=3, time_attr="t")
+ st = Tracks(
+ graph, ndim=3, time_attr="t", tracklet_attr="track_id", lineage_attr="lineage_id"
+ )
# Customize metadata that auto-detection cannot reproduce.
pos_key = st.features.position_key
assert isinstance(pos_key, str)
@@ -1050,7 +1061,7 @@ def test_invalid_featuredict_in_geff_falls_back_to_autodetect(get_tracks, tmp_pa
import_from_geff should silently fall back to auto-detection
instead of raising an exception.
"""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
run_dir = tmp_path / "run"
run_dir.mkdir()
export_to_geff(tracks, run_dir, save_segmentation=False)
@@ -1071,11 +1082,13 @@ def test_invalid_featuredict_in_geff_falls_back_to_autodetect(get_tracks, tmp_pa
# Should not raise — falls back to auto-detection
imported = import_from_geff(geff_path)
- # Verify the import produced a working SolutionTracks with auto-detected features
+ # Verify the import produced a working Tracks with auto-detected features
assert imported.features.time_key is not None
assert imported.features.position_key is not None
assert imported.features.tracklet_key is not None
- assert set(tracks.graph.node_ids()) == set(imported.graph.node_ids())
+ assert set(tracks.graph_solution.node_ids()) == set(
+ imported.graph_solution.node_ids()
+ )
def test_subgroup_export_omits_featuredict_and_recomputes_on_import(get_tracks, tmp_path):
@@ -1088,7 +1101,7 @@ def test_subgroup_export_omits_featuredict_and_recomputes_on_import(get_tracks,
On reimport, validation would strip the now-invalid tracklet_id but the
FeatureDict still referenced it, causing KeyError: 'tracklet_id'.
"""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
# Export only a subset: nodes 1, 3, 4, 5 (one branch of the division).
# filter_graph_with_ancestors will include node 1 as ancestor of 3.
@@ -1126,16 +1139,18 @@ def test_subgroup_export_omits_featuredict_and_recomputes_on_import(get_tracks,
},
)
- assert isinstance(imported, SolutionTracks)
+ assert isinstance(imported, Tracks)
assert imported.features.tracklet_key is not None
# The subgraph is a linear chain (1→3→4→5, no divisions), so all nodes
# should share a single tracklet_id and a single lineage_id.
- track_ids = {imported.get_track_id(nid) for nid in imported.graph.node_ids()}
+ track_ids = {imported.get_track_id(nid) for nid in imported.graph_solution.node_ids()}
assert len(track_ids) == 1, (
f"Linear chain should have one tracklet_id, got {track_ids}"
)
- lineage_ids = {imported.get_lineage_id(nid) for nid in imported.graph.node_ids()}
+ lineage_ids = {
+ imported.get_lineage_id(nid) for nid in imported.graph_solution.node_ids()
+ }
assert len(lineage_ids) == 1, (
f"Linear chain should have one lineage_id, got {lineage_ids}"
)
@@ -1155,7 +1170,7 @@ def test_import_from_geff_respects_external_solution_column(tmp_path):
# Separate y/x columns (not a 2-D 'pos' array) so tracksdata's to_geff
# registers them as space-typed axes — required to exercise the
# axes-branch of GeffTracksBuilder.infer_node_name_map.
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=["y", "x"], position_attrs=["y", "x"], ndim=3
)
graph.bulk_add_nodes(
@@ -1167,15 +1182,14 @@ def test_import_from_geff_respects_external_solution_column(tmp_path):
indices=[1, 2, 3],
)
- # Export the root graph, not the filtered solution-only view, so the
- # solution=False row survives into the geff (mimicking a solver-produced
- # geff with rejected nodes).
+ # Export the full base graph so the solution=False row survives into the geff
+ # (mimicking a solver-produced geff with rejected nodes).
tracks_path = tmp_path / "tracks.geff"
- graph._root.to_geff(geff_store=tracks_path, zarr_format=2)
+ graph.to_geff(geff_store=tracks_path, zarr_format=2)
tracks = import_from_geff(tracks_path)
- node_ids = set(tracks.graph.node_ids())
+ node_ids = set(tracks.graph_solution.node_ids())
assert node_ids == {1, 3}, (
"Node 2 has solution=False in the geff file and should be filtered out "
f"by import_from_geff. Got node_ids={node_ids}."
diff --git a/tests/import_export/test_import_segmentation.py b/tests/import_export/test_import_segmentation.py
index fcbe8b5f..088de86c 100644
--- a/tests/import_export/test_import_segmentation.py
+++ b/tests/import_export/test_import_segmentation.py
@@ -7,7 +7,7 @@
load_segmentation,
relabel_segmentation,
)
-from funtracks.utils.tracksdata_utils import create_empty_graphview_graph
+from funtracks.utils.tracksdata_utils import create_empty_graph
class TestLoadSegmentation:
@@ -46,7 +46,7 @@ def test_basic_relabeling(self):
seg[1, 2, 2] = 20 # seg_id 20 at t=1
# Create graph with node_ids 1, 2
- graph = create_empty_graphview_graph()
+ graph = create_empty_graph()
graph.add_node(index=1, attrs={"t": 0, "solution": True})
graph.add_node(index=2, attrs={"t": 1, "solution": True})
@@ -70,7 +70,7 @@ def test_relabeling_with_node_id_zero(self):
seg[1, 2, 2] = 20 # seg_id 20 at t=1
# Create graph with node_ids 0, 1 (includes 0!)
- graph = create_empty_graphview_graph()
+ graph = create_empty_graph()
graph.add_node(index=0, attrs={"t": 0, "solution": True})
graph.add_node(index=1, attrs={"t": 1, "solution": True})
@@ -97,7 +97,7 @@ def test_no_relabeling_needed_same_ids(self):
seg[0, 1, 1] = 1
seg[1, 2, 2] = 2
- graph = create_empty_graphview_graph()
+ graph = create_empty_graph()
graph.add_node(index=1, attrs={"t": 0, "solution": True})
graph.add_node(index=2, attrs={"t": 1, "solution": True})
@@ -119,7 +119,7 @@ def test_multiple_nodes_same_timepoint(self):
seg[0, 2, 2] = 20
seg[0, 3, 3] = 30
- graph = create_empty_graphview_graph()
+ graph = create_empty_graph()
graph.add_node(index=1, attrs={"t": 0, "solution": True})
graph.add_node(index=2, attrs={"t": 0, "solution": True})
graph.add_node(index=3, attrs={"t": 0, "solution": True})
diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py
index 8112cc60..636de386 100644
--- a/tests/import_export/test_internal_format.py
+++ b/tests/import_export/test_internal_format.py
@@ -14,20 +14,20 @@
@pytest.mark.parametrize("with_seg", [True, False])
@pytest.mark.parametrize("ndim", [3, 4])
-@pytest.mark.parametrize("is_solution", [True, False])
+@pytest.mark.parametrize("prefill_track_ids", [True, False])
def test_save_load(
get_tracks,
with_seg,
ndim,
- is_solution,
+ prefill_track_ids,
):
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=is_solution)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=prefill_track_ids)
data_path = Path(
- f"tests/data/format_v1/test_save_load_{is_solution}_{ndim}_{with_seg}_0"
+ f"tests/data/format_v1/test_save_load_{prefill_track_ids}_{ndim}_{with_seg}_0"
)
- loaded = load_v1_tracks(data_path, solution=is_solution)
+ loaded = load_v1_tracks(data_path)
assert loaded.ndim == ndim
# Check feature keys and important properties match (allow tuple vs list diff)
assert loaded.features.time_key == tracks.features.time_key
@@ -60,12 +60,9 @@ def test_save_load(
assert loaded.scale == tracks.scale
assert loaded.ndim == tracks.ndim
- if is_solution:
- loaded_annotator = loaded.track_annotator
- tracks_annotator = tracks.track_annotator
- assert (
- loaded_annotator.tracklet_id_to_nodes == tracks_annotator.tracklet_id_to_nodes
- )
+ loaded_annotator = loaded.track_annotator
+ tracks_annotator = tracks.track_annotator
+ assert loaded_annotator.tracklet_id_to_nodes == tracks_annotator.tracklet_id_to_nodes
if with_seg:
assert_array_almost_equal(loaded.segmentation, tracks.segmentation)
@@ -73,27 +70,33 @@ def test_save_load(
assert loaded.segmentation is None
# graphs_equal doesn't exist for TracksData, so we check properties
- assert set(loaded.graph.node_attr_keys()) == set(tracks.graph.node_attr_keys())
- assert set(loaded.graph.edge_attr_keys()) == set(tracks.graph.edge_attr_keys())
- assert loaded.graph.num_nodes() == tracks.graph.num_nodes()
- assert loaded.graph.num_edges() == tracks.graph.num_edges()
- assert set(loaded.graph.node_ids()) == set(tracks.graph.node_ids())
+ assert set(loaded.graph_solution.node_attr_keys()) == set(
+ tracks.graph_solution.node_attr_keys()
+ )
+ assert set(loaded.graph_solution.edge_attr_keys()) == set(
+ tracks.graph_solution.edge_attr_keys()
+ )
+ assert loaded.graph_solution.num_nodes() == tracks.graph_solution.num_nodes()
+ assert loaded.graph_solution.num_edges() == tracks.graph_solution.num_edges()
+ assert set(loaded.graph_solution.node_ids()) == set(tracks.graph_solution.node_ids())
# edge_ids dont matter, only the actual edges:
- assert sorted(loaded.graph.edge_list()) == sorted(tracks.graph.edge_list())
+ assert sorted(loaded.graph_solution.edge_list()) == sorted(
+ tracks.graph_solution.edge_list()
+ )
@pytest.mark.parametrize("with_seg", [True, False])
@pytest.mark.parametrize("ndim", [3, 4])
-@pytest.mark.parametrize("is_solution", [True, False])
+@pytest.mark.parametrize("prefill_track_ids", [True, False])
def test_delete(
get_tracks,
with_seg,
ndim,
- is_solution,
+ prefill_track_ids,
tmp_path,
):
reference_path = Path(
- f"tests/data/format_v1/test_save_load_{is_solution}_{ndim}_{with_seg}_0"
+ f"tests/data/format_v1/test_save_load_{prefill_track_ids}_{ndim}_{with_seg}_0"
)
# Copy reference data to temporary location
@@ -115,7 +118,7 @@ def test_load_without_features(tmp_path, graph_2d_with_segmentation):
shutil.copytree(reference_path, tracks_path)
# Load the original data first to verify it loads correctly
- load_v1_tracks(tracks_path, solution=True)
+ load_v1_tracks(tracks_path)
# Modify the copy to test backward compatibility
attrs_path = tracks_path / "attrs.json"
diff --git a/tests/import_export/test_solution_roundtrip.py b/tests/import_export/test_solution_roundtrip.py
index babf3183..a716fa1c 100644
--- a/tests/import_export/test_solution_roundtrip.py
+++ b/tests/import_export/test_solution_roundtrip.py
@@ -20,11 +20,11 @@ def _roundtrip(tracks, tmp_path, name="rt.geff"):
def test_geff_roundtrip_preserves_solution_schema(get_tracks, tmp_path):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
loaded = _roundtrip(tracks, tmp_path)
- edge_schema = loaded.graph._edge_attr_schemas()["solution"]
- node_schema = loaded.graph._node_attr_schemas()["solution"]
+ edge_schema = loaded.graph_solution._edge_attr_schemas()["solution"]
+ node_schema = loaded.graph_solution._node_attr_schemas()["solution"]
assert edge_schema.dtype == pl.Boolean
assert edge_schema.default_value is True
@@ -33,9 +33,9 @@ def test_geff_roundtrip_preserves_solution_schema(get_tracks, tmp_path):
def test_add_edge_is_solution_true_after_geff_roundtrip(get_tracks, tmp_path):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
loaded = _roundtrip(tracks, tmp_path)
- g = loaded.graph
+ g = loaded.graph_solution
# find any source at frame t and target at t+1 with no edge between them
rows = list(g.node_attrs(attr_keys=["node_id", "t"]).sort("t").iter_rows(named=True))
@@ -61,10 +61,10 @@ def test_add_edge_is_solution_true_after_geff_roundtrip(get_tracks, tmp_path):
def test_add_node_is_solution_true_after_geff_roundtrip(get_tracks, tmp_path):
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
loaded = _roundtrip(tracks, tmp_path)
- new_id = max(loaded.graph.node_ids()) + 1
+ new_id = max(loaded.graph_solution.node_ids()) + 1
AddNode(
loaded,
new_id,
diff --git a/tests/import_export/test_utils.py b/tests/import_export/test_utils.py
index cbfa1bcc..b9e75400 100644
--- a/tests/import_export/test_utils.py
+++ b/tests/import_export/test_utils.py
@@ -3,7 +3,7 @@
def test_rename_feature_basic(get_tracks):
"""Test that rename_feature renames a feature in annotators and features dict."""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=False)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=False)
# Rename area feature to custom name
rename_feature(tracks, "area", "my_area")
@@ -19,7 +19,7 @@ def test_rename_feature_basic(get_tracks):
def test_rename_feature_updates_position_key(get_tracks):
"""Test that renaming position feature updates position_key in FeatureDict."""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
original_pos_key = tracks.features.position_key
new_key = "custom_position"
@@ -32,7 +32,7 @@ def test_rename_feature_updates_position_key(get_tracks):
def test_rename_feature_updates_tracklet_key(get_tracks):
"""Test that renaming tracklet feature updates tracklet_key in FeatureDict."""
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
original_track_key = tracks.features.tracklet_key
new_key = "custom_track"
diff --git a/tests/user_actions/test_user_actions_force.py b/tests/user_actions/test_user_actions_force.py
index 1096b958..045842f1 100644
--- a/tests/user_actions/test_user_actions_force.py
+++ b/tests/user_actions/test_user_actions_force.py
@@ -7,42 +7,42 @@ def test_user_force_add_downstream(get_tracks):
"""Test force adding a node of which the track id has an upstream division event.
Should break the edges of the division event to allow this new edge."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
# upstream division, with force
attrs = {"t": 2, "track_id": 1, "pos": [3, 4]}
UserAddNode(tracks, node=7, attributes=attrs, force=True)
assert tracks.get_track_id(7) == 1
- assert [1, 2] not in tracks.graph.edge_list()
- assert [1, 3] not in tracks.graph.edge_list()
- assert [1, 7] in tracks.graph.edge_list()
+ assert [1, 2] not in tracks.graph_solution.edge_list()
+ assert [1, 3] not in tracks.graph_solution.edge_list()
+ assert [1, 7] in tracks.graph_solution.edge_list()
def test_user_force_add_upstream(get_tracks):
"""Test force adding a node upstream, of which the track id co-exists with the parent
track id. Should break the edge with the parent track to allow this new edge."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
# downstream parent division, with force
attrs = {"t": 0, "track_id": 3, "pos": [3, 4]}
UserAddNode(tracks, node=7, attributes=attrs, force=True)
assert tracks.get_track_id(7) == 3
- assert [1, 2] in tracks.graph.edge_list() # still there
- assert [1, 3] not in tracks.graph.edge_list() # should be removed
- assert [7, 3] in tracks.graph.edge_list() # new forced edge
+ assert [1, 2] in tracks.graph_solution.edge_list() # still there
+ assert [1, 3] not in tracks.graph_solution.edge_list() # should be removed
+ assert [7, 3] in tracks.graph_solution.edge_list() # new forced edge
def test_auto_assign_new_track_id(get_tracks):
"""Test that adding a node with a track id that already exists at the current time
point raises a warning and auto-assigns a new track id instead."""
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
# existing track id at current time --> allowed, with warning
with pytest.warns(UserWarning, match="Starting a new track, because track id"):
attrs = {"t": 1, "track_id": 2, "pos": [3, 4]} # combination exists already
UserAddNode(tracks, node=7, attributes=attrs)
- assert tracks.graph.has_node(7)
+ assert tracks.graph_solution.has_node(7)
assert tracks.get_track_id(7) == 6 # new assigned track id
diff --git a/tests/user_actions/test_user_add_delete_edge.py b/tests/user_actions/test_user_add_delete_edge.py
index 921f3e1d..9364b78f 100644
--- a/tests/user_actions/test_user_add_delete_edge.py
+++ b/tests/user_actions/test_user_add_delete_edge.py
@@ -8,32 +8,32 @@
@pytest.mark.parametrize("with_seg", [True, False])
class TestUserAddDeleteEdge:
def test_user_add_edge(self, get_tracks, ndim, with_seg):
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# add an edge from 4 to 6 (will make 4 a division and 5 will need to relabel
# track id)
edge = (4, 6)
old_child = 5
old_track_id = tracks.get_track_id(old_child)
- assert not tracks.graph.has_edge(*edge)
+ assert not tracks.graph_solution.has_edge(*edge)
action = UserAddEdge(tracks, edge)
- assert tracks.graph.has_edge(*edge)
+ assert tracks.graph_solution.has_edge(*edge)
assert tracks.get_track_id(old_child) != old_track_id
inverse = action.inverse()
- assert not tracks.graph.has_edge(*edge)
+ assert not tracks.graph_solution.has_edge(*edge)
assert tracks.get_track_id(old_child) == old_track_id
inverse.inverse()
- assert tracks.graph.has_edge(*edge)
+ assert tracks.graph_solution.has_edge(*edge)
assert tracks.get_track_id(old_child) != old_track_id
def test_user_add_merge_edge(self, get_tracks, ndim, with_seg):
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# add an edge from 2 to 4 (there is already an edge from 3 to 4)
edge = (2, 4)
old_edge = (3, 4)
- assert not tracks.graph.has_edge(*edge)
- assert tracks.graph.has_edge(*old_edge)
+ assert not tracks.graph_solution.has_edge(*edge)
+ assert tracks.graph_solution.has_edge(*old_edge)
with pytest.raises(
InvalidActionError, match="Cannot make a merge edge in a tracking solution"
):
@@ -43,37 +43,37 @@ def test_user_add_merge_edge(self, get_tracks, ndim, with_seg):
match="Removing edge .* to add new edge without merging.",
):
action = UserAddEdge(tracks, edge, force=True)
- assert tracks.graph.has_edge(*edge)
- assert not tracks.graph.has_edge(*old_edge)
+ assert tracks.graph_solution.has_edge(*edge)
+ assert not tracks.graph_solution.has_edge(*old_edge)
inverse = action.inverse()
- assert not tracks.graph.has_edge(*edge)
- assert tracks.graph.has_edge(*old_edge)
+ assert not tracks.graph_solution.has_edge(*edge)
+ assert tracks.graph_solution.has_edge(*old_edge)
inverse.inverse()
- assert tracks.graph.has_edge(*edge)
- assert not tracks.graph.has_edge(*old_edge)
+ assert tracks.graph_solution.has_edge(*edge)
+ assert not tracks.graph_solution.has_edge(*old_edge)
def test_user_delete_edge(self, get_tracks, ndim, with_seg):
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# delete edge (1, 3). (1,2) is now not a division anymore
edge = (1, 3)
old_child = 2
old_track_id = tracks.get_track_id(old_child)
new_track_id = tracks.get_track_id(1)
- assert tracks.graph.has_edge(*edge)
+ assert tracks.graph_solution.has_edge(*edge)
action = UserDeleteEdge(tracks, edge)
- assert not tracks.graph.has_edge(*edge)
+ assert not tracks.graph_solution.has_edge(*edge)
assert tracks.get_track_id(old_child) == new_track_id
inverse = action.inverse()
- assert tracks.graph.has_edge(*edge)
+ assert tracks.graph_solution.has_edge(*edge)
assert tracks.get_track_id(old_child) == old_track_id
double_inv = inverse.inverse()
- assert not tracks.graph.has_edge(*edge)
+ assert not tracks.graph_solution.has_edge(*edge)
assert tracks.get_track_id(old_child) == new_track_id
# TODO: error if edge doesn't exist?
@@ -84,23 +84,23 @@ def test_user_delete_edge(self, get_tracks, ndim, with_seg):
old_child = 5
old_track_id = tracks.get_track_id(old_child)
- assert tracks.graph.has_edge(*edge)
+ assert tracks.graph_solution.has_edge(*edge)
action = UserDeleteEdge(tracks, edge)
- assert not tracks.graph.has_edge(*edge)
+ assert not tracks.graph_solution.has_edge(*edge)
assert tracks.get_track_id(old_child) != old_track_id
inverse = action.inverse()
- assert tracks.graph.has_edge(*edge)
+ assert tracks.graph_solution.has_edge(*edge)
assert tracks.get_track_id(old_child) == old_track_id
inverse.inverse()
- assert not tracks.graph.has_edge(*edge)
+ assert not tracks.graph_solution.has_edge(*edge)
assert tracks.get_track_id(old_child) != old_track_id
def test_add_edge_missing_node(get_tracks):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
with pytest.raises(InvalidActionError, match="Source node .* not in solution yet"):
UserAddEdge(tracks, (10, 11))
with pytest.raises(InvalidActionError, match="Target node .* not in solution yet"):
@@ -108,7 +108,7 @@ def test_add_edge_missing_node(get_tracks):
def test_add_edge_triple_div(get_tracks):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
with pytest.raises(
InvalidActionError, match="Expected degree of 0 or 1 before adding edge"
):
@@ -116,18 +116,18 @@ def test_add_edge_triple_div(get_tracks):
def test_delete_missing_edge(get_tracks):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
with pytest.raises(InvalidActionError, match="Edge .* not in solution"):
UserDeleteEdge(tracks, (10, 11))
def test_delete_edge_triple_div(get_tracks):
- tracks = get_tracks(ndim=3, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=True, prefill_track_ids=True)
attrs = {}
attrs["solution"] = True
attrs["iou"] = 0.9
- tracks.graph.add_edge(
+ tracks.graph_solution.add_edge(
source_id=1,
target_id=6,
attrs=attrs,
diff --git a/tests/user_actions/test_user_add_delete_node.py b/tests/user_actions/test_user_add_delete_node.py
index 5524f021..8b21d4b1 100644
--- a/tests/user_actions/test_user_add_delete_node.py
+++ b/tests/user_actions/test_user_add_delete_node.py
@@ -9,7 +9,7 @@
@pytest.mark.parametrize("with_seg", [True, False])
class TestUserAddDeleteNode:
def test_user_add_invalid_node(self, get_tracks, ndim, with_seg):
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# duplicate node
with pytest.raises(InvalidActionError, match="Node .* already exists"):
attrs = {"t": 5, "track_id": 1}
@@ -34,7 +34,7 @@ def test_user_add_invalid_node(self, get_tracks, ndim, with_seg):
UserAddNode(tracks, node=7, attributes=attrs)
def test_user_add_node(self, get_tracks, ndim, with_seg):
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# add a node to replace a skip edge between node 4 in time 2 and node 5 in time 4
node_id = 7
track_id = 3
@@ -55,7 +55,7 @@ def test_user_add_node(self, get_tracks, ndim, with_seg):
del attributes["pos"]
else:
pixels = None
- graph = tracks.graph
+ graph = tracks.graph_solution
assert not graph.has_node(node_id)
assert graph.has_edge(4, 5)
action = UserAddNode(tracks, node_id, attributes, pixels=pixels)
@@ -84,11 +84,11 @@ def test_user_add_node(self, get_tracks, ndim, with_seg):
# TODO: error if node already exists?
def test_user_delete_node(self, get_tracks, ndim, with_seg):
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# delete node in middle of track. Should skip-connect 3 and 5 with span 3
node_id = 4
- graph = tracks.graph
+ graph = tracks.graph_solution
assert graph.has_node(node_id)
assert graph.has_edge(3, node_id)
assert graph.has_edge(node_id, 5)
@@ -121,14 +121,14 @@ def test_user_delete_node(self, get_tracks, ndim, with_seg):
# TODO: error if node doesn't exist?
def test_user_delete_node_after_division(self, get_tracks, ndim, with_seg):
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# delete first node after division. Should relabel the other child
# to be the same track as parent
parent_node = 1
node_id = 2
sib = 3
- graph = tracks.graph
+ graph = tracks.graph_solution
assert graph.has_node(node_id)
assert graph.has_edge(parent_node, node_id)
parent_track_id = tracks.get_track_id(parent_node)
@@ -158,8 +158,8 @@ def test_user_delete_node_after_division(self, get_tracks, ndim, with_seg):
def test_user_delete_nodes(self, get_tracks, ndim, with_seg):
"""Test bulk deletion of multiple nodes in a single action."""
# Graph structure: 1 → 2, 1 → 3 → 4 → 5, and 6 (separate)
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
- graph = tracks.graph
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
+ graph = tracks.graph_solution
# Save original state
original_nodes = set(graph.node_ids())
diff --git a/tests/user_actions/test_user_swap_predecessors.py b/tests/user_actions/test_user_swap_predecessors.py
index d33faeb3..82a60c88 100644
--- a/tests/user_actions/test_user_swap_predecessors.py
+++ b/tests/user_actions/test_user_swap_predecessors.py
@@ -10,30 +10,30 @@ class TestUserSwapPredecessors:
@pytest.mark.parametrize("order", [(5, 6), (6, 5)])
def test_one_predecessor(self, get_tracks, ndim, with_seg, order):
"""Test swapping when one node has a predecessor and one doesn't."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# Node 5 (t=4) has pred 4, node 6 (t=4) has no pred
- assert tracks.graph.has_edge(4, 5)
- assert list(tracks.graph.predecessors(6)) == []
+ assert tracks.graph_solution.has_edge(4, 5)
+ assert list(tracks.graph_solution.predecessors(6)) == []
old_track_id_5 = tracks.get_track_id(5)
old_track_id_6 = tracks.get_track_id(6)
action = UserSwapPredecessors(tracks, order)
- assert tracks.graph.has_edge(4, 6)
- assert not tracks.graph.has_edge(4, 5)
+ assert tracks.graph_solution.has_edge(4, 6)
+ assert not tracks.graph_solution.has_edge(4, 5)
assert tracks.get_track_id(6) == old_track_id_5
assert tracks.get_track_id(5) != old_track_id_5
action.inverse()
- assert tracks.graph.has_edge(4, 5)
- assert not tracks.graph.has_edge(4, 6)
+ assert tracks.graph_solution.has_edge(4, 5)
+ assert not tracks.graph_solution.has_edge(4, 6)
assert tracks.get_track_id(5) == old_track_id_5
assert tracks.get_track_id(6) == old_track_id_6
def test_same_predecessor_raises(self, get_tracks, ndim, with_seg):
"""Test error when both nodes have the same predecessor."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# Nodes 2 and 3 both have predecessor 1
with pytest.raises(InvalidActionError, match="same predecessor"):
@@ -41,7 +41,7 @@ def test_same_predecessor_raises(self, get_tracks, ndim, with_seg):
def test_different_predecessors(self, get_tracks, ndim, with_seg):
"""Test swapping when both nodes have different predecessors."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
UserAddEdge(tracks, (2, 6))
@@ -51,20 +51,20 @@ def test_different_predecessors(self, get_tracks, ndim, with_seg):
action = UserSwapPredecessors(tracks, (5, 6))
- assert tracks.graph.has_edge(4, 6)
- assert tracks.graph.has_edge(2, 5)
- assert not tracks.graph.has_edge(4, 5)
- assert not tracks.graph.has_edge(2, 6)
+ assert tracks.graph_solution.has_edge(4, 6)
+ assert tracks.graph_solution.has_edge(2, 5)
+ assert not tracks.graph_solution.has_edge(4, 5)
+ assert not tracks.graph_solution.has_edge(2, 6)
action.inverse()
- assert tracks.graph.has_edge(4, 5)
- assert tracks.graph.has_edge(2, 6)
+ assert tracks.graph_solution.has_edge(4, 5)
+ assert tracks.graph_solution.has_edge(2, 6)
assert tracks.get_track_id(5) == old_track_id_5
assert tracks.get_track_id(6) == old_track_id_6
def test_different_times_valid(self, get_tracks, ndim, with_seg):
"""Test swapping nodes at different times when predecessors are valid."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# Add edge 2->6 so node 6 (t=4) has pred 2 (t=1)
# Node 4 (t=2) has pred 3 (t=1)
@@ -73,18 +73,18 @@ def test_different_times_valid(self, get_tracks, ndim, with_seg):
action = UserSwapPredecessors(tracks, (4, 6))
- assert tracks.graph.has_edge(3, 6)
- assert tracks.graph.has_edge(2, 4)
- assert not tracks.graph.has_edge(3, 4)
- assert not tracks.graph.has_edge(2, 6)
+ assert tracks.graph_solution.has_edge(3, 6)
+ assert tracks.graph_solution.has_edge(2, 4)
+ assert not tracks.graph_solution.has_edge(3, 4)
+ assert not tracks.graph_solution.has_edge(2, 6)
action.inverse()
- assert tracks.graph.has_edge(3, 4)
- assert tracks.graph.has_edge(2, 6)
+ assert tracks.graph_solution.has_edge(3, 4)
+ assert tracks.graph_solution.has_edge(2, 6)
def test_different_times_invalid_raises(self, get_tracks, ndim, with_seg):
"""Test error when predecessor would not be before swapped node."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# Node 3 (t=1) has pred 1 (t=0), node 4 (t=2) has pred 3 (t=1)
# pred of 4 (t=1) is not before node 3 (t=1)
@@ -93,7 +93,7 @@ def test_different_times_invalid_raises(self, get_tracks, ndim, with_seg):
def test_wrong_count_raises(self, get_tracks, ndim, with_seg):
"""Test error when not exactly two nodes provided."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
with pytest.raises(
InvalidActionError, match="You can only swap a pair of two nodes"
@@ -107,7 +107,7 @@ def test_wrong_count_raises(self, get_tracks, ndim, with_seg):
def test_no_predecessors_raises(self, get_tracks, ndim, with_seg):
"""Test error when neither node has a predecessor."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
# Delete edge so node 5 has no predecessor like node 6
UserDeleteEdge(tracks, (4, 5))
diff --git a/tests/user_actions/test_user_update_node_attrs.py b/tests/user_actions/test_user_update_node_attrs.py
index 5e85a03b..3b3c98b4 100644
--- a/tests/user_actions/test_user_update_node_attrs.py
+++ b/tests/user_actions/test_user_update_node_attrs.py
@@ -9,11 +9,17 @@
class TestUserUpdateNodeAttrs:
def test_user_update_node_attrs(self, get_tracks, ndim, with_seg):
"""Test basic node attribute update functionality."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
-
- tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object)
- tracks.graph.add_node_attr_key("confidence", default_value=0, dtype=pl.Float64)
- tracks.graph.add_node_attr_key("validated", default_value=False, dtype=pl.Boolean)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
+
+ tracks.graph_solution.add_node_attr_key(
+ "label", default_value=None, dtype=pl.Object
+ )
+ tracks.graph_solution.add_node_attr_key(
+ "confidence", default_value=0, dtype=pl.Float64
+ )
+ tracks.graph_solution.add_node_attr_key(
+ "validated", default_value=False, dtype=pl.Boolean
+ )
# Add custom attributes to update
custom_attrs = {"label": "my_label", "confidence": 0.95, "validated": True}
@@ -43,10 +49,14 @@ def test_user_update_node_attrs(self, get_tracks, ndim, with_seg):
def test_user_update_existing_attrs(self, get_tracks, ndim, with_seg):
"""Test updating attributes that already exist."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
- tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object)
- tracks.graph.add_node_attr_key("score", default_value=None, dtype=pl.Float64)
+ tracks.graph_solution.add_node_attr_key(
+ "label", default_value=None, dtype=pl.Object
+ )
+ tracks.graph_solution.add_node_attr_key(
+ "score", default_value=None, dtype=pl.Float64
+ )
# Set initial custom attributes
tracks._set_node_attr(1, "label", "old_label")
@@ -67,7 +77,7 @@ def test_user_update_existing_attrs(self, get_tracks, ndim, with_seg):
def test_protected_time_attr(self, get_tracks, ndim, with_seg):
"""Test that time attribute cannot be updated."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
time_key = tracks.features.time_key
with pytest.raises(ValueError, match="Cannot update attribute"):
@@ -75,14 +85,14 @@ def test_protected_time_attr(self, get_tracks, ndim, with_seg):
def test_protected_track_id_attr(self, get_tracks, ndim, with_seg):
"""Test that track_id attribute cannot be updated."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
with pytest.raises(ValueError, match="Cannot update attribute"):
UserUpdateNodeAttrs(tracks, node=1, attrs={"track_id": 999})
def test_protected_area_attr(self, get_tracks, ndim, with_seg):
"""Test that area attribute (managed by annotator) cannot be updated."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
if with_seg: # area only exists when segmentation is present
with pytest.raises(ValueError, match="Cannot update attribute"):
@@ -90,7 +100,7 @@ def test_protected_area_attr(self, get_tracks, ndim, with_seg):
def test_protected_pos_attr(self, get_tracks, ndim, with_seg):
"""Test that position attribute (managed by annotator) cannot be updated."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
if with_seg: # pos is managed by RegionpropsAnnotator when seg exists
with pytest.raises(ValueError, match="Cannot update attribute"):
@@ -98,8 +108,10 @@ def test_protected_pos_attr(self, get_tracks, ndim, with_seg):
def test_action_history_integration(self, get_tracks, ndim, with_seg):
"""Test that action integrates properly with action history."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
- tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
+ tracks.graph_solution.add_node_attr_key(
+ "label", default_value=None, dtype=pl.Object
+ )
# Initially empty
assert len(tracks.action_history.undo_stack) == 0
diff --git a/tests/user_actions/test_user_update_nodes_attrs.py b/tests/user_actions/test_user_update_nodes_attrs.py
index bb853887..b327b5a0 100644
--- a/tests/user_actions/test_user_update_nodes_attrs.py
+++ b/tests/user_actions/test_user_update_nodes_attrs.py
@@ -10,10 +10,14 @@
class TestUserUpdateNodesAttrs:
def test_user_update_nodes_attrs(self, get_tracks, ndim, with_seg):
"""Test basic bulk node attribute update functionality."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
- tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object)
- tracks.graph.add_node_attr_key("confidence", default_value=0, dtype=pl.Float64)
+ tracks.graph_solution.add_node_attr_key(
+ "label", default_value=None, dtype=pl.Object
+ )
+ tracks.graph_solution.add_node_attr_key(
+ "confidence", default_value=0, dtype=pl.Float64
+ )
attrs = {"label": ["my_label", "my_label"], "confidence": [0.95, 0.95]}
UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs=attrs)
@@ -24,8 +28,10 @@ def test_user_update_nodes_attrs(self, get_tracks, ndim, with_seg):
def test_single_history_entry(self, get_tracks, ndim, with_seg):
"""Updating multiple nodes creates only one history entry."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
- tracks.graph.add_node_attr_key("label", default_value=None, dtype=pl.Object)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
+ tracks.graph_solution.add_node_attr_key(
+ "label", default_value=None, dtype=pl.Object
+ )
action = UserUpdateNodesAttrs(
tracks, nodes=[1, 2, 3], attrs={"label": ["x", "x", "x"]}
@@ -36,8 +42,10 @@ def test_single_history_entry(self, get_tracks, ndim, with_seg):
def test_undo_redo(self, get_tracks, ndim, with_seg):
"""Undo restores all nodes' attrs to defaults; redo re-applies them."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
- tracks.graph.add_node_attr_key("score", default_value=0, dtype=pl.Float64)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
+ tracks.graph_solution.add_node_attr_key(
+ "score", default_value=0, dtype=pl.Float64
+ )
action = UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": [0.9, 0.9]})
@@ -56,8 +64,10 @@ def test_undo_redo(self, get_tracks, ndim, with_seg):
def test_per_node_attrs(self, get_tracks, ndim, with_seg):
"""Test bulk update with different values per node."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
- tracks.graph.add_node_attr_key("score", default_value=0, dtype=pl.Float64)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
+ tracks.graph_solution.add_node_attr_key(
+ "score", default_value=0, dtype=pl.Float64
+ )
UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": [0.1, 0.9]})
@@ -66,10 +76,10 @@ def test_per_node_attrs(self, get_tracks, ndim, with_seg):
def test_array_attr(self, get_tracks, ndim, with_seg):
"""Test bulk update with array-valued attributes."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
spatial_dims = ndim - 1
- tracks.graph.add_node_attr_key(
+ tracks.graph_solution.add_node_attr_key(
"custom_pos", default_value=None, dtype=pl.Array(pl.Float64, spatial_dims)
)
@@ -81,21 +91,21 @@ def test_array_attr(self, get_tracks, ndim, with_seg):
def test_values_not_list_raises(self, get_tracks, ndim, with_seg):
"""Non-list values raise ValueError."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
with pytest.raises(ValueError, match="must be a list"):
UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": 0.9})
def test_values_length_mismatch_raises(self, get_tracks, ndim, with_seg):
"""List length not matching nodes length raises ValueError."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
with pytest.raises(ValueError, match="length"):
UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs={"score": [0.1]})
def test_protected_attr_raises(self, get_tracks, ndim, with_seg):
"""Passing a protected attribute raises ValueError."""
- tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=with_seg, prefill_track_ids=True)
time_key = tracks.features.time_key
with pytest.raises(ValueError, match="Cannot update attribute"):
diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py
index d2d67a12..1a9a2f3b 100644
--- a/tests/user_actions/test_user_update_segmentation.py
+++ b/tests/user_actions/test_user_update_segmentation.py
@@ -26,7 +26,7 @@ def pixels_equal_mask(self, pixels, tracks, node_id):
)
def test_user_update_seg_smaller(self, get_tracks, ndim):
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
node_id = 3
edge = (1, 3)
@@ -51,14 +51,14 @@ def test_user_update_seg_smaller(self, get_tracks, ndim):
updated_pixels=[(pixels_to_remove, node_id)],
current_track_id=1,
)
- assert tracks.graph.has_node(node_id)
+ assert tracks.graph_solution.has_node(node_id)
assert self.pixels_equal_mask(remaining_pixels, tracks, node_id)
assert tracks.get_position(node_id) == new_position
assert tracks.get_node_attr(node_id, "area") == 1
assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(0.0, abs=0.01)
inverse = action.inverse()
- assert tracks.graph.has_node(node_id)
+ assert tracks.graph_solution.has_node(node_id)
assert self.pixels_equal_mask(orig_pixels, tracks, node_id)
assert tracks.get_position(node_id) == orig_position
assert tracks.get_node_attr(node_id, "area") == orig_area
@@ -71,7 +71,7 @@ def test_user_update_seg_smaller(self, get_tracks, ndim):
assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(0.0, abs=0.01)
def test_user_update_seg_bigger(self, get_tracks, ndim):
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
node_id = 3
edge = (1, 3)
@@ -95,26 +95,26 @@ def test_user_update_seg_bigger(self, get_tracks, ndim):
action = UserUpdateSegmentation(
tracks, new_value=3, updated_pixels=[(pixels_to_add, 0)], current_track_id=1
)
- assert tracks.graph.has_node(node_id)
+ assert tracks.graph_solution.has_node(node_id)
assert self.pixels_equal_mask(all_pixels, tracks, node_id)
assert tracks.get_node_attr(node_id, "area") == orig_area + 1
assert tracks.get_edge_attr(edge, iou_key) != orig_iou
inverse = action.inverse()
- assert tracks.graph.has_node(node_id)
+ assert tracks.graph_solution.has_node(node_id)
assert self.pixels_equal_mask(orig_pixels, tracks, node_id)
assert tracks.get_position(node_id) == orig_position
assert tracks.get_node_attr(node_id, "area") == orig_area
assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(orig_iou, abs=0.01)
inverse.inverse()
- assert tracks.graph.has_node(node_id)
+ assert tracks.graph_solution.has_node(node_id)
assert self.pixels_equal_mask(all_pixels, tracks, node_id)
assert tracks.get_node_attr(node_id, "area") == orig_area + 1
assert tracks.get_edge_attr(edge, iou_key) != orig_iou
def test_invalid_action_with_segmentation(self, get_tracks, ndim):
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
node_id = 1
# Paint on top of node 1 with track id 3: because of the downstream division, this
@@ -157,12 +157,12 @@ def test_invalid_action_with_segmentation(self, get_tracks, ndim):
# assert that the segmentation now has the new value
assert np.asarray(tracks.segmentation[t, y, x]) == new_value
- assert tracks.graph.has_node(new_value)
+ assert tracks.graph_solution.has_node(new_value)
assert len(update_seg_action.actions) == 2 # one for adding a node,
# and one for updating existing node 1
def test_user_erase_seg(self, get_tracks, ndim):
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
node_id = 3
edge = (1, 3)
@@ -182,24 +182,24 @@ def test_user_erase_seg(self, get_tracks, ndim):
updated_pixels=[(pixels_to_remove, node_id)],
current_track_id=1,
)
- assert not tracks.graph.has_node(node_id)
+ assert not tracks.graph_solution.has_node(node_id)
inverse = action.inverse()
- assert tracks.graph.has_node(node_id)
+ assert tracks.graph_solution.has_node(node_id)
self.pixels_equal_mask(orig_pixels, tracks, node_id)
assert tracks.get_position(node_id) == orig_position
assert tracks.get_node_attr(node_id, "area") == orig_area
assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(orig_iou, abs=0.01)
inverse.inverse()
- assert not tracks.graph.has_node(node_id)
+ assert not tracks.graph_solution.has_node(node_id)
def test_user_erase_seg_history_size(self, get_tracks, ndim):
"""An erase via UserUpdateSegmentation must add exactly one history
entry. Regression test for a bug where the nested UserDeleteNode
also registered itself, leaving two entries per fill and corrupting
undo behavior."""
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
node_id = 6
pixels = td_mask_to_pixels(
tracks.get_mask(node_id), tracks.get_time(node_id), ndim=tracks.ndim
@@ -217,7 +217,7 @@ def test_user_two_erases_then_two_undos(self, get_tracks, ndim):
tracks.action_history.undo(). Reproduces bug_paint_undo: the second
undo crashed because the buggy history had a duplicate UserDeleteNode
entry that tried to re-add an already-restored node."""
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
pixels_5 = td_mask_to_pixels(
tracks.get_mask(5), tracks.get_time(5), ndim=tracks.ndim
)
@@ -228,23 +228,23 @@ def test_user_two_erases_then_two_undos(self, get_tracks, ndim):
UserUpdateSegmentation(
tracks, new_value=0, updated_pixels=[(pixels_5, 5)], current_track_id=1
)
- assert not tracks.graph.has_node(5)
+ assert not tracks.graph_solution.has_node(5)
UserUpdateSegmentation(
tracks, new_value=0, updated_pixels=[(pixels_6, 6)], current_track_id=1
)
- assert not tracks.graph.has_node(6)
+ assert not tracks.graph_solution.has_node(6)
assert tracks.action_history.undo() is True
- assert tracks.graph.has_node(6)
- assert not tracks.graph.has_node(5)
+ assert tracks.graph_solution.has_node(6)
+ assert not tracks.graph_solution.has_node(5)
assert tracks.action_history.undo() is True
- assert tracks.graph.has_node(5)
- assert tracks.graph.has_node(6)
+ assert tracks.graph_solution.has_node(5)
+ assert tracks.graph_solution.has_node(6)
def test_user_add_seg(self, get_tracks, ndim):
- tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
+ tracks = get_tracks(ndim=ndim, with_seg=True, prefill_track_ids=True)
# draw a new node just like node 6 but in time 3 (instead of 4)
old_node_id = 6
node_id = 7
@@ -260,7 +260,7 @@ def test_user_add_seg(self, get_tracks, ndim):
position = tracks.get_position(old_node_id)
area = tracks.get_node_attr(old_node_id, "area")
- assert not tracks.graph.has_node(node_id)
+ assert not tracks.graph_solution.has_node(node_id)
assert np.sum(tracks.segmentation == node_id) == 0
action = UserUpdateSegmentation(
@@ -270,22 +270,22 @@ def test_user_add_seg(self, get_tracks, ndim):
current_track_id=10,
)
assert np.sum(np.asarray(tracks.segmentation) == node_id) == len(pixels_to_add[0])
- assert tracks.graph.has_node(node_id)
+ assert tracks.graph_solution.has_node(node_id)
assert tracks.get_position(node_id) == position
assert tracks.get_node_attr(node_id, "area") == area
assert tracks.get_track_id(node_id) == 10
inverse = action.inverse()
- assert not tracks.graph.has_node(node_id)
+ assert not tracks.graph_solution.has_node(node_id)
inverse.inverse()
- assert tracks.graph.has_node(node_id)
+ assert tracks.graph_solution.has_node(node_id)
assert tracks.get_position(node_id) == position
assert tracks.get_node_attr(node_id, "area") == area
assert tracks.get_track_id(node_id) == 10
def test_missing_seg(get_tracks):
- tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
+ tracks = get_tracks(ndim=3, with_seg=False, prefill_track_ids=True)
with pytest.raises(ValueError, match="Cannot update non-existing segmentation"):
UserUpdateSegmentation(tracks, 0, [], 1)
diff --git a/tests/utils/test_tracksdata_utils.py b/tests/utils/test_tracksdata_utils.py
index 32231579..e4eb0793 100644
--- a/tests/utils/test_tracksdata_utils.py
+++ b/tests/utils/test_tracksdata_utils.py
@@ -6,7 +6,7 @@
import pytest
from funtracks.utils.tracksdata_utils import (
- create_empty_graphview_graph,
+ create_empty_graph,
pixels_to_td_mask,
td_mask_to_pixels,
)
@@ -142,7 +142,7 @@ def test_pixels_coordinate_offset(ndim):
def test_memory_graph_survives_thread_boundary():
- """A GraphView created in a worker thread must remain accessible from the main thread.
+ """A base graph created in a worker thread must stay accessible from the main thread.
Regression test: nodes_from_segmentation previously used database=':memory:',
which caused 'no such table: Metadata' when the graph crossed a thread boundary
@@ -152,7 +152,7 @@ def test_memory_graph_survives_thread_boundary():
result = {}
def worker():
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=["pos"],
ndim=3,
)
@@ -165,22 +165,23 @@ def worker():
graph = result["graph"]
- # This calls graph.metadata internally via BaseGraph.from_other().
+ # create_empty_graph now returns the base graph; build a view to exercise
+ # detach(), which calls graph.metadata internally via BaseGraph.from_other().
# With :memory: + default connection pool it opens a new empty DB → crash.
- detached = graph.detach()
+ detached = graph.filter().subgraph().detach()
assert detached.num_nodes() == 1
-def test_create_empty_graphview_graph_with_solution_attr():
+def test_create_empty_graph_with_solution_attr():
"""Test that passing solution as a node/edge attribute does not raise.
- Regression test: create_empty_graphview_graph unconditionally added the
+ Regression test: create_empty_graph unconditionally added the
solution attribute at the end, even when it was already added via the
node_attributes / edge_attributes loop, causing a ValueError.
"""
# Should not raise ValueError even though solution is listed explicitly
- graph = create_empty_graphview_graph(
+ graph = create_empty_graph(
node_attributes=["solution"],
edge_attributes=["solution"],
node_default_values=[True],