From 28fb23bd482e8dd05c13543c1b966975dd1fa10e Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 1 Jul 2026 13:22:31 -0700 Subject: [PATCH] fix: align RXFilter edge query semantics with SQLFilter The rustworkx backend's filter returned wrong edges in two ways: 1. With include_sources=True, an edge whose endpoints were both selected was visited twice (out_edges of its source and in_edges of its target) and returned duplicated. 2. Endpoint membership was only enforced when explicit node_ids were given AND no include flag was set. Edges leaving the selected set leaked into the result for attribute-filtered queries (e.g. filter(NodeAttr('t') == 1).edge_attrs() returned edges into t == 2) and whenever include_sources/include_targets was set. This also made filter(...).edge_attrs() disagree with filter(...).subgraph() and with the SQL backend. Deduplicate visited edges and enforce SQLFilter's rule: the source must be in the selected set unless include_sources, the target unless include_targets. --- src/tracksdata/graph/_rustworkx_graph.py | 26 ++++++++++++--- .../graph/_test/test_graph_backends.py | 33 +++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 5f6a2645..dab1e7b4 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -244,20 +244,36 @@ def _edge_attrs(self) -> pl.DataFrame: data = {k: [] for k in self._graph.edge_attr_keys()} data[DEFAULT_ATTR_KEYS.EDGE_ID] = [] + # Endpoint membership constraints, matching SQLFilter semantics: + # when nodes are selected (explicit `node_ids` or node attribute filters), + # an edge's source must be in the selected set unless `include_sources`, + # and its target must be in the selected set unless `include_targets`. check_node_ids = None - if self._node_ids is not None and not (self._include_targets or self._include_sources): + if self._node_ids is not None or self._node_attr_comps: check_node_ids = set(node_ids) + # An edge can be visited twice (out_edges of its source and in_edges of + # its target) when `include_sources` is set; deduplicate by the identity + # of the payload dict, which is unique per edge (unlike EDGE_ID, it is + # always present, e.g. view-local edges of a SQL-rooted GraphView). + seen_edges = set() + # TODO: at this point I think we are better creating a rx subgraph # and using the filter method for node_id in node_ids: for nf in neigh_funcs: for src, tgt, attr in nf(node_id): - if _filter_func(attr): - if check_node_ids is not None: - if src not in check_node_ids or tgt not in check_node_ids: - continue + if id(attr) in seen_edges: + continue + seen_edges.add(id(attr)) + if check_node_ids is not None: + if not self._include_sources and src not in check_node_ids: + continue + if not self._include_targets and tgt not in check_node_ids: + continue + + if _filter_func(attr): sources.append(src) targets.append(tgt) for k in data.keys(): diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 641138e5..74bba72d 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -726,6 +726,39 @@ def test_edge_attrs_include_targets(graph_backend: BaseGraph) -> None: assert single_exclusive_edge_ids == expected_single_exclusive, msg +def test_edge_attrs_include_sources(graph_backend: BaseGraph) -> None: + """Edges must be unique and respect endpoint membership with include_sources.""" + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) + + # node0 -> node1 -> node2 + node0 = graph_backend.add_node({"t": 0}) + node1 = graph_backend.add_node({"t": 1}) + node2 = graph_backend.add_node({"t": 2}) + + edge0 = graph_backend.add_edge(node0, node1, attrs={"weight": 0.1}) + edge1 = graph_backend.add_edge(node1, node2, attrs={"weight": 0.2}) + + # include_sources=True with both endpoints of edge0 selected: + # - edge0: node0 -> node1 ✓ (must appear exactly once) + # - edge1: node1 -> node2 ✗ (node2 not selected and include_targets=False) + edge_ids = graph_backend.filter(node_ids=[node0, node1], include_sources=True).edge_ids() + assert list(edge_ids) == [edge0] + + # include_sources=True selecting only the target of edge0: + # - edge0: node0 -> node1 ✓ (in-edge with source outside the selection) + edge_ids = graph_backend.filter(node_ids=[node1], include_sources=True).edge_ids() + assert list(edge_ids) == [edge0] + + # node attribute filters must constrain edge endpoints the same way + assert graph_backend.filter(NodeAttr("t") == 1).edge_ids() == [] + + edge_ids = graph_backend.filter(NodeAttr("t") == 1, include_targets=True).edge_ids() + assert list(edge_ids) == [edge1] + + edge_ids = graph_backend.filter(NodeAttr("t") == 1, include_sources=True).edge_ids() + assert list(edge_ids) == [edge0] + + def test_from_ctc( ctc_data_dir: Path, graph_backend: BaseGraph,