Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions benchmarks/graph_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@

from itertools import pairwise

import numpy as np
import polars as pl

import tracksdata as td
from benchmarks.common import BACKENDS, IS_CI
from tracksdata.attrs import NodeAttr
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.graph.filters import BBoxSpatialFilter

# Total node count. Tuned so the current (pre-fix) `remove_node` finishes
# within the per-benchmark timeout — see PR discussion for sizing rationale.
Expand Down Expand Up @@ -105,3 +107,60 @@ def time_update_node_attrs_view_with_listener(self, backend_name: str, n_nodes:

def time_filter_node_ids(self, backend_name: str, n_nodes: int) -> None:
self.graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) >= 1).node_ids()


def _build_bbox_graph(backend_name: str, n_nodes: int) -> td.graph.BaseGraph:
"""Graph whose nodes carry a bbox, so a real BBoxSpatialFilter can index them."""
graph = BACKENDS[backend_name]()
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Array(pl.Int64, 4))
graph.add_node_attr_key("score", dtype=pl.Float64)
nodes_per_lineage = max(2, n_nodes // N_LINEAGES)
for lineage in range(N_LINEAGES):
graph.bulk_add_nodes(
[
{
DEFAULT_ATTR_KEYS.T: t,
DEFAULT_ATTR_KEYS.BBOX: np.asarray([lineage, t, lineage + 2, t + 2]),
"score": 0.0,
}
for t in range(nodes_per_lineage)
]
)
return graph


class SpatialFilterUpdateBenchmark:
"""`update_node_attrs` with a live BBoxSpatialFilter attached to `node_updated`.

This is the `assign_tracklet_ids` regression case: a non-spatial bulk write
(e.g. ``tracklet_id``) emits ``node_updated``, and before the fix the filter
re-indexes every node in the rtree even though no bbox/frame changed. The
`_noop`-listener benchmark above only times the producer-side payload build;
this one times the consumer-side rtree work that actually regressed.
"""

param_names = ("backend", "n_nodes")
params = (tuple(BACKENDS), NODE_SIZES)

number = 1
warmup_time = 0
timeout = 300

def setup(self, backend_name: str, n_nodes: int) -> None:
self.graph = _build_bbox_graph(backend_name, n_nodes)
# Attach a real spatial filter so node_updated drives rtree mutations.
self.spatial_filter = BBoxSpatialFilter(
self.graph, frame_attr_key=DEFAULT_ATTR_KEYS.T, bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX
)
self.target_ids = self.graph.node_ids()

def time_update_non_spatial_attr_with_filter(self, backend_name: str, n_nodes: int) -> None:
# Non-spatial write: bbox/frame untouched -> should be O(1) for the filter
# after the fix, O(N) rtree churn before it.
self.graph.update_node_attrs(node_ids=self.target_ids, attrs={"score": 1.0})

def time_update_bbox_attr_with_filter(self, backend_name: str, n_nodes: int) -> None:
# Spatial write: bbox genuinely changes -> filter must re-index. Guards
# against the fix over-eagerly skipping legitimate updates.
new_bboxes = [np.asarray([i, i, i + 3, i + 3]) for i in range(len(self.target_ids))]
self.graph.update_node_attrs(node_ids=self.target_ids, attrs={DEFAULT_ATTR_KEYS.BBOX: new_bboxes})
Loading