diff --git a/benchmarks/graph_mutations.py b/benchmarks/graph_mutations.py
index 99678527..b20579cf 100644
--- a/benchmarks/graph_mutations.py
+++ b/benchmarks/graph_mutations.py
@@ -9,12 +9,14 @@
 
 from itertools import pairwise
 
+import numpy as np
 import polars as pl
 
 import tracksdata as td
 from benchmarks.common import BACKENDS, IS_CI
 from tracksdata.attrs import NodeAttr
 from tracksdata.constants import DEFAULT_ATTR_KEYS
+from tracksdata.graph.filters import BBoxSpatialFilter
 
 # Total node count. Tuned so the current (pre-fix) `remove_node` finishes
 # within the per-benchmark timeout — see PR discussion for sizing rationale.
@@ -105,3 +107,66 @@ def time_update_node_attrs_view_with_listener(self, backend_name: str, n_nodes:
 
     def time_filter_node_ids(self, backend_name: str, n_nodes: int) -> None:
         self.graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) >= 1).node_ids()
+
+
+def _build_bbox_graph(backend_name: str, n_nodes: int) -> td.graph.BaseGraph:
+    """Graph whose nodes carry a bbox, so a real BBoxSpatialFilter can index them.
+
+    Each of the ``N_LINEAGES`` batches adds ``max(2, n_nodes // N_LINEAGES)``
+    nodes carrying a ``T`` value, a 4-element integer bbox and a ``score`` of 0.0.
+    """
+    graph = BACKENDS[backend_name]()
+    graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Array(pl.Int64, 4))
+    graph.add_node_attr_key("score", dtype=pl.Float64)
+    nodes_per_lineage = max(2, n_nodes // N_LINEAGES)
+    for lineage in range(N_LINEAGES):
+        graph.bulk_add_nodes(
+            [
+                {
+                    DEFAULT_ATTR_KEYS.T: t,
+                    DEFAULT_ATTR_KEYS.BBOX: np.asarray([lineage, t, lineage + 2, t + 2]),
+                    "score": 0.0,
+                }
+                for t in range(nodes_per_lineage)
+            ]
+        )
+    return graph
+
+
+class SpatialFilterUpdateBenchmark:
+    """`update_node_attrs` with a live BBoxSpatialFilter attached to `node_updated`.
+
+    This is the `assign_tracklet_ids` regression case: a non-spatial bulk write
+    (e.g. ``tracklet_id``) emits ``node_updated``, and before the fix the filter
+    re-indexes every node in the rtree even though no bbox/frame changed. The
+    `_noop`-listener benchmark above only times the producer-side payload build;
+    this one times the consumer-side rtree work that actually regressed.
+    """
+
+    param_names = ("backend", "n_nodes")
+    params = (tuple(BACKENDS), NODE_SIZES)
+
+    number = 1
+    warmup_time = 0
+    timeout = 300
+
+    def setup(self, backend_name: str, n_nodes: int) -> None:
+        self.graph = _build_bbox_graph(backend_name, n_nodes)
+        # Attach a real spatial filter so node_updated drives rtree mutations.
+        self.spatial_filter = BBoxSpatialFilter(
+            self.graph, frame_attr_key=DEFAULT_ATTR_KEYS.T, bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX
+        )
+        self.target_ids = self.graph.node_ids()
+        # Replacement bboxes are fixture data: build them here so the spatial
+        # benchmark times only the update and re-index, not array construction.
+        self.new_bboxes = [np.asarray([i, i, i + 3, i + 3]) for i in range(len(self.target_ids))]
+
+    def time_update_non_spatial_attr_with_filter(self, backend_name: str, n_nodes: int) -> None:
+        # Non-spatial write: bbox/frame untouched -> should be O(1) for the filter
+        # after the fix, O(N) rtree churn before it.
+        self.graph.update_node_attrs(node_ids=self.target_ids, attrs={"score": 1.0})
+
+    def time_update_bbox_attr_with_filter(self, backend_name: str, n_nodes: int) -> None:
+        # Spatial write: bbox genuinely changes -> filter must re-index. Guards
+        # against the fix over-eagerly skipping legitimate updates.
+        self.graph.update_node_attrs(node_ids=self.target_ids, attrs={DEFAULT_ATTR_KEYS.BBOX: self.new_bboxes})