Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions benchmarks/graph_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@

from itertools import pairwise

import numpy as np
import polars as pl

import tracksdata as td
from benchmarks.common import BACKENDS, IS_CI
from tracksdata.attrs import NodeAttr
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.graph.filters import BBoxSpatialFilter

# Total node count. Tuned so the current (pre-fix) `remove_node` finishes
# within the per-benchmark timeout — see PR discussion for sizing rationale.
Expand Down Expand Up @@ -105,3 +107,60 @@ def time_update_node_attrs_view_with_listener(self, backend_name: str, n_nodes:

def time_filter_node_ids(self, backend_name: str, n_nodes: int) -> None:
self.graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) >= 1).node_ids()


def _build_bbox_graph(backend_name: str, n_nodes: int) -> td.graph.BaseGraph:
"""Graph whose nodes carry a bbox, so a real BBoxSpatialFilter can index them."""
graph = BACKENDS[backend_name]()
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Array(pl.Int64, 4))
graph.add_node_attr_key("score", dtype=pl.Float64)
nodes_per_lineage = max(2, n_nodes // N_LINEAGES)
for lineage in range(N_LINEAGES):
graph.bulk_add_nodes(
[
{
DEFAULT_ATTR_KEYS.T: t,
DEFAULT_ATTR_KEYS.BBOX: np.asarray([lineage, t, lineage + 2, t + 2]),
"score": 0.0,
}
for t in range(nodes_per_lineage)
]
)
return graph


class SpatialFilterUpdateBenchmark:
"""`update_node_attrs` with a live BBoxSpatialFilter attached to `node_updated`.

This is the `assign_tracklet_ids` regression case: a non-spatial bulk write
(e.g. ``tracklet_id``) emits ``node_updated``, and before the fix the filter
re-indexes every node in the rtree even though no bbox/frame changed. The
`_noop`-listener benchmark above only times the producer-side payload build;
this one times the consumer-side rtree work that actually regressed.
"""

param_names = ("backend", "n_nodes")
params = (tuple(BACKENDS), NODE_SIZES)

number = 1
warmup_time = 0
timeout = 300

def setup(self, backend_name: str, n_nodes: int) -> None:
self.graph = _build_bbox_graph(backend_name, n_nodes)
# Attach a real spatial filter so node_updated drives rtree mutations.
self.spatial_filter = BBoxSpatialFilter(
self.graph, frame_attr_key=DEFAULT_ATTR_KEYS.T, bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX
)
self.target_ids = self.graph.node_ids()

def time_update_non_spatial_attr_with_filter(self, backend_name: str, n_nodes: int) -> None:
# Non-spatial write: bbox/frame untouched -> should be O(1) for the filter
# after the fix, O(N) rtree churn before it.
self.graph.update_node_attrs(node_ids=self.target_ids, attrs={"score": 1.0})

def time_update_bbox_attr_with_filter(self, backend_name: str, n_nodes: int) -> None:
# Spatial write: bbox genuinely changes -> filter must re-index. Guards
# against the fix over-eagerly skipping legitimate updates.
new_bboxes = [np.asarray([i, i, i + 3, i + 3]) for i in range(len(self.target_ids))]
self.graph.update_node_attrs(node_ids=self.target_ids, attrs={DEFAULT_ATTR_KEYS.BBOX: new_bboxes})
7 changes: 7 additions & 0 deletions src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,15 @@ def _on_node_updated(
node_ids: list[int],
old_attrs: list[dict],
new_attrs: list[dict],
changed_keys: set[str] | None = None,
) -> None:
del node_ids
# The rendered output depends only on position (t/bbox), the mask, and the
# displayed attribute. If none changed, there is nothing to invalidate.
if changed_keys is not None and changed_keys.isdisjoint(
{DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.BBOX, DEFAULT_ATTR_KEYS.MASK, self._attr_key}
):
return
time_values: list[Any] = []
bboxes: list[Any] = []
for old_attr, new_attr in zip(old_attrs, new_attrs, strict=True):
Expand Down
2 changes: 1 addition & 1 deletion src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class BaseGraph(abc.ABC):
_PRIVATE_METADATA_PREFIX = "__private_"
node_added = Signal(list, list)
node_removed = Signal(list, list)
node_updated = Signal(list, list, list)
node_updated = Signal(list, list, list, set)

def __init__(self) -> None:
self._cache = {}
Expand Down
3 changes: 3 additions & 0 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,15 +1017,18 @@ def update_node_attrs(
if k in new_attrs:
new_attrs[k] = v if np.isscalar(v) else v[i]
new_attrs_by_id[node_id] = new_attrs
changed_keys = set(attrs.keys())
if root_signal_on:
emit_node_updated_events(
self._root.node_updated,
((node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) for node_id in node_ids),
changed_keys,
)
if view_signal_on:
emit_node_updated_events(
self.node_updated,
((node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) for node_id in node_ids),
changed_keys,
)

def update_edge_attrs(
Expand Down
2 changes: 2 additions & 0 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,7 @@ def update_node_attrs(
emit_node_updated_events(
self.node_updated,
((node_id, old_attrs_by_id[node_id], dict(self._graph[node_id])) for node_id in node_ids),
set(attrs.keys()),
)

def update_edge_attrs(
Expand Down Expand Up @@ -1969,6 +1970,7 @@ def update_node_attrs(
(external_node_id, old_attrs_by_id[external_node_id], dict(self._graph[local_node_id]))
for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True)
),
set(attrs.keys()),
)

def bulk_remove_nodes(self, node_ids: Sequence[int]) -> None:
Expand Down
1 change: 1 addition & 0 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,6 +2208,7 @@ def update_node_attrs(
emit_node_updated_events(
self.node_updated,
((node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) for node_id in updated_node_ids),
set(attrs.keys()),
)

def update_edge_attrs(
Expand Down
16 changes: 16 additions & 0 deletions src/tracksdata/graph/_test/test_graph_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,22 @@ def test_update_node_attrs_emits_batched_node_updated_callback(graph_backend: Ba
assert [attrs["x"] for attrs in calls[0][2]] == [10.0, 20.0, 30.0]


def test_update_node_attrs_node_updated_carries_changed_keys(graph_backend: BaseGraph) -> None:
"""node_updated delivers the set of written keys as a 4th arg, so connectors
can skip work when none of the keys they track changed."""
graph_backend.add_node_attr_key("x", pl.Float64)
graph_backend.add_node_attr_key("score", pl.Float64)

node_ids = graph_backend.bulk_add_nodes([{"t": 0, "x": 1.0, "score": 0.0}])

calls: list[set] = []
graph_backend.node_updated.connect(lambda node_ids, old_attrs, new_attrs, changed_keys: calls.append(changed_keys))

graph_backend.update_node_attrs(node_ids=node_ids, attrs={"score": 5.0})

assert calls == [{"score"}]


def test_update_edge_attrs(graph_backend: BaseGraph) -> None:
"""Test updating edge attributes."""
node1 = graph_backend.add_node({"t": 0})
Expand Down
12 changes: 12 additions & 0 deletions src/tracksdata/graph/filters/_spatial_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,11 @@ def _update_node(
node_ids: list[int],
old_attrs: list[dict[str, Any]],
new_attrs: list[dict[str, Any]],
changed_keys: set[str] | None = None,
) -> None:
# Skip rtree churn when the update touches no spatial coordinate.
if changed_keys is not None and changed_keys.isdisjoint(self._attr_keys):
return
self._remove_node(node_ids, old_attrs)
self._add_node(node_ids, new_attrs)

Expand Down Expand Up @@ -488,7 +492,15 @@ def _update_node(
node_ids: list[int],
old_attrs: list[dict[str, Any]],
new_attrs: list[dict[str, Any]],
changed_keys: set[str] | None = None,
) -> None:
# Skip rtree churn when the update touches neither the bbox nor the frame.
if changed_keys is not None:
spatial_keys = {self._bbox_attr_key}
if self._frame_attr_key is not None:
spatial_keys.add(self._frame_attr_key)
if changed_keys.isdisjoint(spatial_keys):
return
self._remove_node(node_ids, old_attrs)
self._add_node(node_ids, new_attrs)

Expand Down
76 changes: 76 additions & 0 deletions src/tracksdata/graph/filters/_test/test_spatial_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,82 @@ def test_bbox_spatial_filter_update_non_bbox_attr_no_error(graph_backend: BaseGr
assert node_id in spatial_filter[0:0.5, 0:3, 0:3].node_ids()


def _spy_calls(obj: object, method_name: str) -> dict[str, int]:
"""Wrap a bound method on ``obj`` with a call counter, still calling through."""
calls = {"n": 0}
orig = getattr(obj, method_name)

def wrapper(*args, **kwargs):
calls["n"] += 1
return orig(*args, **kwargs)

setattr(obj, method_name, wrapper)
return calls


def test_bbox_spatial_filter_skips_reindex_on_non_spatial_update(graph_backend: BaseGraph) -> None:
"""A non-spatial write (e.g. tracklet id) must not touch the rtree.

Regression guard for the assign_tracklet_ids slowdown: previously every
node_updated triggered a delete+reinsert per node even though bbox/frame
were unchanged.
"""
graph_backend.add_node_attr_key("bbox", pl.Array(pl.Int64, 4))
graph_backend.add_node_attr_key("track_id", pl.Int64, -1)
node_id = graph_backend.add_node({"t": 0, "bbox": np.asarray([0, 0, 2, 2]), "track_id": -1})

for graph in [graph_backend, graph_backend.filter().subgraph()]:
spatial_filter = BBoxSpatialFilter(graph, frame_attr_key="t", bbox_attr_key="bbox")
removes = _spy_calls(spatial_filter, "_remove_node")
adds = _spy_calls(spatial_filter, "_add_node")

graph.update_node_attrs(attrs={"track_id": 7}, node_ids=[node_id])

assert removes["n"] == 0
assert adds["n"] == 0
# Node is still indexed at its original position.
assert node_id in spatial_filter[0:0.5, 0:3, 0:3].node_ids()


def test_bbox_spatial_filter_reindexes_on_bbox_update(graph_backend: BaseGraph) -> None:
"""A genuine bbox change must still re-index — guards against over-eager skipping.

Asserts the rtree is mutated (delete + reinsert); the resulting spatial-query
correctness is covered by ``test_bbox_spatial_filter_updates_node_position``.
"""
graph_backend.add_node_attr_key("bbox", pl.Array(pl.Int64, 4))
node_id = graph_backend.add_node({"t": 0, "bbox": np.asarray([0, 0, 2, 2])})

for graph in [graph_backend, graph_backend.filter().subgraph()]:
spatial_filter = BBoxSpatialFilter(graph, frame_attr_key="t", bbox_attr_key="bbox")
removes = _spy_calls(spatial_filter, "_remove_node")
adds = _spy_calls(spatial_filter, "_add_node")

graph.update_node_attrs(attrs={"bbox": [np.asarray([20, 20, 22, 22])]}, node_ids=[node_id])

assert removes["n"] == 1
assert adds["n"] == 1


def test_point_spatial_filter_skips_reindex_on_non_spatial_update(graph_backend: BaseGraph) -> None:
"""The point-based SpatialFilter must also skip the rtree on non-spatial writes."""
graph_backend.add_node_attr_key("y", pl.Int64)
graph_backend.add_node_attr_key("x", pl.Int64)
graph_backend.add_node_attr_key("track_id", pl.Int64, -1)
node_id = graph_backend.add_node({"t": 0, "y": 10, "x": 20, "track_id": -1})

for graph in [graph_backend, graph_backend.filter().subgraph()]:
spatial_filter = SpatialFilter(graph, attr_keys=["y", "x"])
removes = _spy_calls(spatial_filter, "_remove_node")
adds = _spy_calls(spatial_filter, "_add_node")

graph.update_node_attrs(attrs={"track_id": 7}, node_ids=[node_id])

assert removes["n"] == 0
assert adds["n"] == 0
assert node_id in spatial_filter[5:15, 15:25].node_ids()


def test_bbox_spatial_filter_handles_list_dtype(graph_backend: BaseGraph) -> None:
"""Ensure bounding boxes stored as list dtype still work with the spatial filter."""
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))
Expand Down
12 changes: 8 additions & 4 deletions src/tracksdata/utils/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,31 @@ def emit_node_added_events(
def emit_node_updated_events(
sig: Signal | SignalInstance,
event_args: Iterable[tuple[int, dict[str, Any], dict[str, Any]]],
changed_keys: set[str],
) -> None:
"""
Emit a single batched ``node_updated`` event.

Connected slots always receive one call
``(list_of_ids, list_of_old_attrs, list_of_new_attrs)``. No-op if there are
no events or the signal has no active listeners.
Connected slots receive one call
``(list_of_ids, list_of_old_attrs, list_of_new_attrs, changed_keys)``. No-op
if there are no events or the signal has no active listeners.

Parameters
----------
sig : Signal | SignalInstance
The ``node_updated`` signal to emit on.
event_args : Iterable[tuple[int, dict[str, Any], dict[str, Any]]]
The ``(node_id, old_attrs, new_attrs)`` triples to emit.
changed_keys : set[str]
The attribute keys actually written by this update (uniform across the
batch). Lets connectors skip work when none of the keys they track changed.
"""
events = list(event_args)
if len(events) == 0 or not is_signal_on(sig):
return

node_ids, old_attrs, new_attrs = zip(*events, strict=True)
sig.emit(list(node_ids), list(old_attrs), list(new_attrs))
sig.emit(list(node_ids), list(old_attrs), list(new_attrs), set(changed_keys))


def emit_node_removed_events(
Expand Down
Loading