From 61052b25bde907997b8266d66a1e5b6a4d16a368 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 1 Jul 2026 13:24:58 -0700 Subject: [PATCH] fix: RXFilter.node_ids() dropped selected nodes with include flags When include_targets/include_sources was set (or edge filters were present), the rustworkx filter rebuilt the node selection purely from edge endpoints: - explicitly selected or attribute-filtered nodes without matching edges (e.g. isolated nodes) were dropped from node_ids(); - both endpoints were added regardless of which include flag was set, pulling in sources that include_targets never asked for. Keep the base selection and only extend it with the endpoints the include flags request, matching SQLFilter. Pure edge-filtered queries still return the matching edges' endpoints. --- src/tracksdata/graph/_rustworkx_graph.py | 36 +++++++++++-------- .../graph/_test/test_graph_backends.py | 23 ++++++++++++ 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 5f6a2645..bd14972c 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -199,26 +199,34 @@ def _current_node_ids(self) -> list[int]: @cache_method def node_ids(self) -> list[int]: - # if there are no edge filters, we can return the current node ids + # if there are no edge filters nor include flags, we can return the current node ids if not self._edge_attr_comps and (not self._include_targets and not self._include_sources): return self._current_node_ids() - # find nodes that are connected to edges that pass the edge filters + edges_df = self._edge_attrs() + node_filtered = self._node_ids is not None or bool(self._node_attr_comps) + node_ids = [] - edge_node_ids = ( - self._edge_attrs() - .select( - DEFAULT_ATTR_KEYS.EDGE_SOURCE, - DEFAULT_ATTR_KEYS.EDGE_TARGET, + if node_filtered or not self._edge_attr_comps: + # nodes selected by `node_ids`/node filters (or all nodes when + # unfiltered) are always kept; the include flags only extend the + # selection with the respective edge endpoints, matching SQLFilter. + node_ids.append(np.asarray(self._current_node_ids(), dtype=int)) + else: + # only edge filters: nodes are the endpoints of the matching edges + node_ids.append( + edges_df.select( + DEFAULT_ATTR_KEYS.EDGE_SOURCE, + DEFAULT_ATTR_KEYS.EDGE_TARGET, + ) + .to_numpy() + .ravel() ) - .to_numpy() - .ravel() - ) - node_ids.append(edge_node_ids) - if self._node_attr_comps: - # if there are node filters, we need to add the nodes that pass the node filters - node_ids.append(self._current_node_ids()) + if self._include_sources: + node_ids.append(edges_df[DEFAULT_ATTR_KEYS.EDGE_SOURCE].to_numpy()) + if self._include_targets: + node_ids.append(edges_df[DEFAULT_ATTR_KEYS.EDGE_TARGET].to_numpy()) node_ids = [v for v in node_ids if len(v) > 0] diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 641138e5..bf8ff3d0 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -726,6 +726,29 @@ def test_edge_attrs_include_targets(graph_backend: BaseGraph) -> None: assert single_exclusive_edge_ids == expected_single_exclusive, msg +def test_filter_node_ids_with_include_flags(graph_backend: BaseGraph) -> None: + """Selected nodes must be kept when include flags extend the selection.""" + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) + + node0 = graph_backend.add_node({"t": 0}) + node1 = graph_backend.add_node({"t": 1}) + isolated = graph_backend.add_node({"t": 0}) + + graph_backend.add_edge(node0, node1, attrs={"weight": 0.5}) + + # explicitly selected nodes without edges must not be dropped + node_ids = graph_backend.filter(node_ids=[node0, isolated], include_targets=True).node_ids() + assert sorted(node_ids) == sorted([node0, node1, isolated]) + + # node attribute filters behave the same way + node_ids = graph_backend.filter(NodeAttr("t") == 0, include_targets=True).node_ids() + assert sorted(node_ids) == sorted([node0, node1, isolated]) + + # include_sources only extends with edge sources, not targets + node_ids = graph_backend.filter(node_ids=[node0, isolated], include_sources=True).node_ids() + assert sorted(node_ids) == sorted([node0, isolated]) + + def test_from_ctc( ctc_data_dir: Path, graph_backend: BaseGraph,