Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,12 +411,20 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None
with self._root.node_added.blocked():
parent_node_ids = self._root.bulk_add_nodes(nodes, indices=indices)

# Defensive: drop NODE_ID from emitted/local-stored attrs in case the root
# backend (e.g. older SQL paths) injected it.
emitted_nodes = [
{key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID}
for node_attrs in nodes
]
if self._is_root_rx_graph:
# The rx root stored these exact dict objects by reference (and does not
# inject NODE_ID), so reuse them in the local view. Root and view then share
# attribute storage, staying in sync without per-write copies — the invariant
# the read-skip in update_node_attrs/update_edge_attrs relies on. (subgraph()
# already shares this way for pre-existing nodes; this keeps add-through-view
# consistent with it.)
emitted_nodes = nodes
else:
# Defensive: non-rx roots (e.g. SQL) may inject NODE_ID; store filtered copies.
emitted_nodes = [
{key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID}
for node_attrs in nodes
]
if self.sync:
node_ids = self._bulk_add_nodes_local(emitted_nodes)
self._add_id_mappings(list(zip(node_ids, parent_node_ids, strict=True)))
Expand Down Expand Up @@ -547,7 +555,7 @@ def remove_node_from_view(self, node_id: int) -> None:
self._remove_node_local(node_id)

if view_signal_on:
self.node_removed.emit(node_id, old_attrs)
self.node_removed.emit([node_id], [old_attrs])

def _add_node_local(self, node_id: int) -> None:
"""
Expand Down Expand Up @@ -611,7 +619,7 @@ def add_node_to_view(self, node_id: int) -> None:
self._add_node_local(node_id)

if is_signal_on(self.node_added):
self.node_added.emit(node_id, self.nodes[node_id].to_dict())
self.node_added.emit([node_id], [self.nodes[node_id].to_dict()])

def add_edge(
self,
Expand Down
92 changes: 88 additions & 4 deletions src/tracksdata/graph/_test/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,10 +1334,18 @@ def test_remove_node_from_view_signals(graph_backend: BaseGraph) -> None:

view = graph_backend.filter().subgraph()

# `node_removed` is a Signal(list, list); slots iterate the batch exactly like
# the real consumers (SpatialFilter._remove_node, GraphArrayView._on_node_removed).
# A non-iterating slot would silently accept a buggy scalar emit — these don't.
view_calls: list[tuple[int, dict]] = []
root_calls: list[tuple[int, dict]] = []
view.node_removed.connect(lambda nid, attrs: view_calls.append((nid, attrs)))
graph_backend.node_removed.connect(lambda nid, attrs: root_calls.append((nid, attrs)))

def _record(sink: list, node_ids: list[int], old_attrs: list[dict]) -> None:
for nid, attrs in zip(node_ids, old_attrs, strict=True):
sink.append((nid, attrs))

view.node_removed.connect(lambda node_ids, old_attrs: _record(view_calls, node_ids, old_attrs))
graph_backend.node_removed.connect(lambda node_ids, old_attrs: _record(root_calls, node_ids, old_attrs))

view.remove_node_from_view(n1)

Expand Down Expand Up @@ -1508,10 +1516,18 @@ def test_add_node_to_view_signals(graph_backend: BaseGraph) -> None:
view = graph_backend.filter().subgraph()
view.remove_node_from_view(n1)

# `node_added` is a Signal(list, list); slots iterate the batch exactly like
# the real consumers (GraphArrayView._on_node_added). A non-iterating slot would
# silently accept a buggy scalar emit — these don't.
view_calls: list[tuple[int, dict]] = []
root_calls: list[tuple[int, dict]] = []
view.node_added.connect(lambda nid, attrs: view_calls.append((nid, attrs)))
graph_backend.node_added.connect(lambda nid, attrs: root_calls.append((nid, attrs)))

def _record(sink: list, node_ids: list[int], new_attrs: list[dict]) -> None:
for nid, attrs in zip(node_ids, new_attrs, strict=True):
sink.append((nid, attrs))

view.node_added.connect(lambda node_ids, new_attrs: _record(view_calls, node_ids, new_attrs))
graph_backend.node_added.connect(lambda node_ids, new_attrs: _record(root_calls, node_ids, new_attrs))

view.add_node_to_view(n1)

Expand Down Expand Up @@ -1543,6 +1559,74 @@ def test_add_node_to_view_validation(graph_backend: BaseGraph) -> None:
view.add_node_to_view(n0)


def test_view_only_helpers_emit_batched_signal(graph_backend: BaseGraph) -> None:
"""Regression: view-only remove/add must emit the batched ``Signal(list, list)``.

``remove_node_from_view``/``add_node_to_view`` previously emitted scalars
(``node_id``, ``attrs``) instead of single-element lists. Real consumers
(SpatialFilter, GraphArrayView) iterate the payload with ``zip(node_ids, attrs)``,
so a scalar emit raised ``EmitLoopError`` (``zip(int, dict)`` -> not iterable).
This slot reproduces that consumer pattern.
"""
graph_backend.add_node_attr_key("x", pl.Float64)
n0 = graph_backend.add_node({"t": 0, "x": 0.0})
n1 = graph_backend.add_node({"t": 1, "x": 1.0})

view = graph_backend.filter().subgraph()

removed: list[tuple[int, dict]] = []
added: list[tuple[int, dict]] = []

def on_removed(node_ids: list[int], old_attrs: list[dict]) -> None:
for nid, attrs in zip(node_ids, old_attrs, strict=True):
removed.append((nid, attrs))

def on_added(node_ids: list[int], new_attrs: list[dict]) -> None:
for nid, attrs in zip(node_ids, new_attrs, strict=True):
added.append((nid, attrs))

view.node_removed.connect(on_removed)
view.node_added.connect(on_added)

# Must not raise EmitLoopError from a consumer iterating the batch.
view.remove_node_from_view(n1)
view.add_node_to_view(n1)

assert removed == [(n1, removed[0][1])]
assert removed[0][1]["x"] == 1.0
assert added == [(n1, added[0][1])]
assert added[0][1]["x"] == 1.0
_ = n0


def test_add_nodes_via_view_shares_storage_with_root(graph_backend: BaseGraph) -> None:
"""Regression: nodes added *through a view* must stay in sync with the root.

``update_node_attrs``/``update_edge_attrs`` skip syncing the view's local store
for rx roots, on the assumption that root and view share attribute-dict storage.
The view add-path previously stored *copies*, so a root write was never reflected
in the view for nodes added through the view. This builds the funtracks pattern
(empty graph -> subgraph -> add via view) and asserts a root write is visible.
"""
graph_backend.add_node_attr_key("area", default_value=0.0, dtype=pl.Float64)

view = graph_backend.filter().subgraph()
if not view._is_root_rx_graph:
# Only rx roots share attribute storage root<->view; for copy-on-subgraph
# backends (e.g. SQL) a direct-on-root write is intentionally not propagated.
pytest.skip("shared-storage invariant only applies to rustworkx-family roots")

(node_id,) = view.bulk_add_nodes(nodes=[{"t": 0, "area": 10.0}])

# Write on the root; the view (and its readers) must observe it.
graph_backend.update_node_attrs(attrs={"area": 99.0}, node_ids=[node_id])

assert view.nodes[node_id]["area"] == 99.0
# Reading via the dataframe API must agree with the per-node accessor.
view_area = view.node_attrs().filter(pl.col(DEFAULT_ATTR_KEYS.NODE_ID) == node_id)["area"].item()
assert view_area == 99.0


def test_add_edge_to_view_basic(graph_backend: BaseGraph) -> None:
"""`add_edge_to_view` re-surfaces a root edge into the view, leaving root untouched."""
graph_backend.add_node_attr_key("x", pl.Float64)
Expand Down
Loading