From 97c7456a78c47852b3e25dc1bd094cdd36b8bebb Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 14:40:18 -0700 Subject: [PATCH 1/2] add benchmark for node_update_attrs when BBoxSpatialFilter attached --- benchmarks/graph_mutations.py | 59 +++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/benchmarks/graph_mutations.py b/benchmarks/graph_mutations.py index 99678527..b20579cf 100644 --- a/benchmarks/graph_mutations.py +++ b/benchmarks/graph_mutations.py @@ -9,12 +9,14 @@ from itertools import pairwise +import numpy as np import polars as pl import tracksdata as td from benchmarks.common import BACKENDS, IS_CI from tracksdata.attrs import NodeAttr from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.graph.filters import BBoxSpatialFilter # Total node count. Tuned so the current (pre-fix) `remove_node` finishes # within the per-benchmark timeout — see PR discussion for sizing rationale. @@ -105,3 +107,60 @@ def time_update_node_attrs_view_with_listener(self, backend_name: str, n_nodes: def time_filter_node_ids(self, backend_name: str, n_nodes: int) -> None: self.graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) >= 1).node_ids() + + +def _build_bbox_graph(backend_name: str, n_nodes: int) -> td.graph.BaseGraph: + """Graph whose nodes carry a bbox, so a real BBoxSpatialFilter can index them.""" + graph = BACKENDS[backend_name]() + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Array(pl.Int64, 4)) + graph.add_node_attr_key("score", dtype=pl.Float64) + nodes_per_lineage = max(2, n_nodes // N_LINEAGES) + for lineage in range(N_LINEAGES): + graph.bulk_add_nodes( + [ + { + DEFAULT_ATTR_KEYS.T: t, + DEFAULT_ATTR_KEYS.BBOX: np.asarray([lineage, t, lineage + 2, t + 2]), + "score": 0.0, + } + for t in range(nodes_per_lineage) + ] + ) + return graph + + +class SpatialFilterUpdateBenchmark: + """`update_node_attrs` with a live BBoxSpatialFilter attached to `node_updated`. + + This is the `assign_tracklet_ids` regression case: a non-spatial bulk write + (e.g. ``tracklet_id``) emits ``node_updated``, and before the fix the filter + re-indexes every node in the rtree even though no bbox/frame changed. The + `_noop`-listener benchmark above only times the producer-side payload build; + this one times the consumer-side rtree work that actually regressed. + """ + + param_names = ("backend", "n_nodes") + params = (tuple(BACKENDS), NODE_SIZES) + + number = 1 + warmup_time = 0 + timeout = 300 + + def setup(self, backend_name: str, n_nodes: int) -> None: + self.graph = _build_bbox_graph(backend_name, n_nodes) + # Attach a real spatial filter so node_updated drives rtree mutations. + self.spatial_filter = BBoxSpatialFilter( + self.graph, frame_attr_key=DEFAULT_ATTR_KEYS.T, bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX + ) + self.target_ids = self.graph.node_ids() + + def time_update_non_spatial_attr_with_filter(self, backend_name: str, n_nodes: int) -> None: + # Non-spatial write: bbox/frame untouched -> should be O(1) for the filter + # after the fix, O(N) rtree churn before it. + self.graph.update_node_attrs(node_ids=self.target_ids, attrs={"score": 1.0}) + + def time_update_bbox_attr_with_filter(self, backend_name: str, n_nodes: int) -> None: + # Spatial write: bbox genuinely changes -> filter must re-index. Guards + # against the fix over-eagerly skipping legitimate updates. + new_bboxes = [np.asarray([i, i, i + 3, i + 3]) for i in range(len(self.target_ids))] + self.graph.update_node_attrs(node_ids=self.target_ids, attrs={DEFAULT_ATTR_KEYS.BBOX: new_bboxes}) From 3fbba1008bacb7be0539bbb97c5f9ca37fcaa6f2 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 30 Jun 2026 15:31:23 -0700 Subject: [PATCH 2/2] give node_update signal a list of changed_keys for listeners to make cheaper decisions on when to update themselves --- src/tracksdata/array/_graph_array.py | 7 ++ src/tracksdata/graph/_base_graph.py | 2 +- src/tracksdata/graph/_graph_view.py | 3 + src/tracksdata/graph/_rustworkx_graph.py | 2 + src/tracksdata/graph/_sql_graph.py | 1 + .../graph/_test/test_graph_backends.py | 16 ++++ .../graph/filters/_spatial_filter.py | 12 +++ .../filters/_test/test_spatial_filter.py | 76 +++++++++++++++++++ src/tracksdata/utils/_signal.py | 12 ++- 9 files changed, 126 insertions(+), 5 deletions(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index ac64983c..edee40cc 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -446,8 +446,15 @@ def _on_node_updated( node_ids: list[int], old_attrs: list[dict], new_attrs: list[dict], + changed_keys: set[str] | None = None, ) -> None: del node_ids + # The rendered output depends only on position (t/bbox), the mask, and the + # displayed attribute. If none changed, there is nothing to invalidate. + if changed_keys is not None and changed_keys.isdisjoint( + {DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.BBOX, DEFAULT_ATTR_KEYS.MASK, self._attr_key} + ): + return time_values: list[Any] = [] bboxes: list[Any] = [] for old_attr, new_attr in zip(old_attrs, new_attrs, strict=True): diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 44e2d1d0..f746673b 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -113,7 +113,7 @@ class BaseGraph(abc.ABC): _PRIVATE_METADATA_PREFIX = "__private_" node_added = Signal(list, list) node_removed = Signal(list, list) - node_updated = Signal(list, list, list) + node_updated = Signal(list, list, list, set) def __init__(self) -> None: self._cache = {} diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index fe434283..7da34d35 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -1017,15 +1017,18 @@ def update_node_attrs( if k in new_attrs: new_attrs[k] = v if np.isscalar(v) else v[i] new_attrs_by_id[node_id] = new_attrs + changed_keys = set(attrs.keys()) if root_signal_on: emit_node_updated_events( self._root.node_updated, ((node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) for node_id in node_ids), + changed_keys, ) if view_signal_on: emit_node_updated_events( self.node_updated, ((node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) for node_id in node_ids), + changed_keys, ) def update_edge_attrs( diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 254935e8..5f6a2645 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -1294,6 +1294,7 @@ def update_node_attrs( emit_node_updated_events( self.node_updated, ((node_id, old_attrs_by_id[node_id], dict(self._graph[node_id])) for node_id in node_ids), + set(attrs.keys()), ) def update_edge_attrs( @@ -1969,6 +1970,7 @@ def update_node_attrs( (external_node_id, old_attrs_by_id[external_node_id], dict(self._graph[local_node_id])) for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True) ), + set(attrs.keys()), ) def bulk_remove_nodes(self, node_ids: Sequence[int]) -> None: diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 2c703c15..b2f66b46 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -2208,6 +2208,7 @@ def update_node_attrs( emit_node_updated_events( self.node_updated, ((node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) for node_id in updated_node_ids), + set(attrs.keys()), ) def update_edge_attrs( diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 93731fba..641138e5 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -575,6 +575,22 @@ def test_update_node_attrs_emits_batched_node_updated_callback(graph_backend: Ba assert [attrs["x"] for attrs in calls[0][2]] == [10.0, 20.0, 30.0] +def test_update_node_attrs_node_updated_carries_changed_keys(graph_backend: BaseGraph) -> None: + """node_updated delivers the set of written keys as a 4th arg, so connectors + can skip work when none of the keys they track changed.""" + graph_backend.add_node_attr_key("x", pl.Float64) + graph_backend.add_node_attr_key("score", pl.Float64) + + node_ids = graph_backend.bulk_add_nodes([{"t": 0, "x": 1.0, "score": 0.0}]) + + calls: list[set] = [] + graph_backend.node_updated.connect(lambda node_ids, old_attrs, new_attrs, changed_keys: calls.append(changed_keys)) + + graph_backend.update_node_attrs(node_ids=node_ids, attrs={"score": 5.0}) + + assert calls == [{"score"}] + + def test_update_edge_attrs(graph_backend: BaseGraph) -> None: """Test updating edge attributes.""" node1 = graph_backend.add_node({"t": 0}) diff --git a/src/tracksdata/graph/filters/_spatial_filter.py b/src/tracksdata/graph/filters/_spatial_filter.py index a04ccf64..f3f78237 100644 --- a/src/tracksdata/graph/filters/_spatial_filter.py +++ b/src/tracksdata/graph/filters/_spatial_filter.py @@ -246,7 +246,11 @@ def _update_node( node_ids: list[int], old_attrs: list[dict[str, Any]], new_attrs: list[dict[str, Any]], + changed_keys: set[str] | None = None, ) -> None: + # Skip rtree churn when the update touches no spatial coordinate. + if changed_keys is not None and changed_keys.isdisjoint(self._attr_keys): + return self._remove_node(node_ids, old_attrs) self._add_node(node_ids, new_attrs) @@ -488,7 +492,15 @@ def _update_node( node_ids: list[int], old_attrs: list[dict[str, Any]], new_attrs: list[dict[str, Any]], + changed_keys: set[str] | None = None, ) -> None: + # Skip rtree churn when the update touches neither the bbox nor the frame. + if changed_keys is not None: + spatial_keys = {self._bbox_attr_key} + if self._frame_attr_key is not None: + spatial_keys.add(self._frame_attr_key) + if changed_keys.isdisjoint(spatial_keys): + return self._remove_node(node_ids, old_attrs) self._add_node(node_ids, new_attrs) diff --git a/src/tracksdata/graph/filters/_test/test_spatial_filter.py b/src/tracksdata/graph/filters/_test/test_spatial_filter.py index 6a6d0828..abf4e986 100644 --- a/src/tracksdata/graph/filters/_test/test_spatial_filter.py +++ b/src/tracksdata/graph/filters/_test/test_spatial_filter.py @@ -413,6 +413,82 @@ def test_bbox_spatial_filter_update_non_bbox_attr_no_error(graph_backend: BaseGr assert node_id in spatial_filter[0:0.5, 0:3, 0:3].node_ids() +def _spy_calls(obj: object, method_name: str) -> dict[str, int]: + """Wrap a bound method on ``obj`` with a call counter, still calling through.""" + calls = {"n": 0} + orig = getattr(obj, method_name) + + def wrapper(*args, **kwargs): + calls["n"] += 1 + return orig(*args, **kwargs) + + setattr(obj, method_name, wrapper) + return calls + + +def test_bbox_spatial_filter_skips_reindex_on_non_spatial_update(graph_backend: BaseGraph) -> None: + """A non-spatial write (e.g. tracklet id) must not touch the rtree. + + Regression guard for the assign_tracklet_ids slowdown: previously every + node_updated triggered a delete+reinsert per node even though bbox/frame + were unchanged. + """ + graph_backend.add_node_attr_key("bbox", pl.Array(pl.Int64, 4)) + graph_backend.add_node_attr_key("track_id", pl.Int64, -1) + node_id = graph_backend.add_node({"t": 0, "bbox": np.asarray([0, 0, 2, 2]), "track_id": -1}) + + for graph in [graph_backend, graph_backend.filter().subgraph()]: + spatial_filter = BBoxSpatialFilter(graph, frame_attr_key="t", bbox_attr_key="bbox") + removes = _spy_calls(spatial_filter, "_remove_node") + adds = _spy_calls(spatial_filter, "_add_node") + + graph.update_node_attrs(attrs={"track_id": 7}, node_ids=[node_id]) + + assert removes["n"] == 0 + assert adds["n"] == 0 + # Node is still indexed at its original position. + assert node_id in spatial_filter[0:0.5, 0:3, 0:3].node_ids() + + +def test_bbox_spatial_filter_reindexes_on_bbox_update(graph_backend: BaseGraph) -> None: + """A genuine bbox change must still re-index — guards against over-eager skipping. + + Asserts the rtree is mutated (delete + reinsert); the resulting spatial-query + correctness is covered by ``test_bbox_spatial_filter_updates_node_position``. + """ + graph_backend.add_node_attr_key("bbox", pl.Array(pl.Int64, 4)) + node_id = graph_backend.add_node({"t": 0, "bbox": np.asarray([0, 0, 2, 2])}) + + for graph in [graph_backend, graph_backend.filter().subgraph()]: + spatial_filter = BBoxSpatialFilter(graph, frame_attr_key="t", bbox_attr_key="bbox") + removes = _spy_calls(spatial_filter, "_remove_node") + adds = _spy_calls(spatial_filter, "_add_node") + + graph.update_node_attrs(attrs={"bbox": [np.asarray([20, 20, 22, 22])]}, node_ids=[node_id]) + + assert removes["n"] == 1 + assert adds["n"] == 1 + + +def test_point_spatial_filter_skips_reindex_on_non_spatial_update(graph_backend: BaseGraph) -> None: + """The point-based SpatialFilter must also skip the rtree on non-spatial writes.""" + graph_backend.add_node_attr_key("y", pl.Int64) + graph_backend.add_node_attr_key("x", pl.Int64) + graph_backend.add_node_attr_key("track_id", pl.Int64, -1) + node_id = graph_backend.add_node({"t": 0, "y": 10, "x": 20, "track_id": -1}) + + for graph in [graph_backend, graph_backend.filter().subgraph()]: + spatial_filter = SpatialFilter(graph, attr_keys=["y", "x"]) + removes = _spy_calls(spatial_filter, "_remove_node") + adds = _spy_calls(spatial_filter, "_add_node") + + graph.update_node_attrs(attrs={"track_id": 7}, node_ids=[node_id]) + + assert removes["n"] == 0 + assert adds["n"] == 0 + assert node_id in spatial_filter[5:15, 15:25].node_ids() + + def test_bbox_spatial_filter_handles_list_dtype(graph_backend: BaseGraph) -> None: """Ensure bounding boxes stored as list dtype still work with the spatial filter.""" graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) diff --git a/src/tracksdata/utils/_signal.py b/src/tracksdata/utils/_signal.py index 13ab2340..105786c3 100644 --- a/src/tracksdata/utils/_signal.py +++ b/src/tracksdata/utils/_signal.py @@ -38,13 +38,14 @@ def emit_node_added_events( def emit_node_updated_events( sig: Signal | SignalInstance, event_args: Iterable[tuple[int, dict[str, Any], dict[str, Any]]], + changed_keys: set[str], ) -> None: """ Emit a single batched ``node_updated`` event. - Connected slots always receive one call - ``(list_of_ids, list_of_old_attrs, list_of_new_attrs)``. No-op if there are - no events or the signal has no active listeners. + Connected slots receive one call + ``(list_of_ids, list_of_old_attrs, list_of_new_attrs, changed_keys)``. No-op + if there are no events or the signal has no active listeners. Parameters ---------- @@ -52,13 +53,16 @@ def emit_node_updated_events( The ``node_updated`` signal to emit on. event_args : Iterable[tuple[int, dict[str, Any], dict[str, Any]]] The ``(node_id, old_attrs, new_attrs)`` triples to emit. + changed_keys : set[str] + The attribute keys actually written by this update (uniform across the + batch). Lets connectors skip work when none of the keys they track changed. """ events = list(event_args) if len(events) == 0 or not is_signal_on(sig): return node_ids, old_attrs, new_attrs = zip(*events, strict=True) - sig.emit(list(node_ids), list(old_attrs), list(new_attrs)) + sig.emit(list(node_ids), list(old_attrs), list(new_attrs), set(changed_keys)) def emit_node_removed_events(