diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 5f6a2645..8fb9e9e4 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -950,7 +950,6 @@ def _filter_nodes_by_attrs( The IDs of the filtered nodes. """ rx_graph = self.rx_graph - node_map = None # entire graph attrs, time = _pop_time_eq(attrs) selected_nodes = None @@ -967,16 +966,15 @@ def _filter_nodes_by_attrs( elif node_ids is not None: selected_nodes = node_ids - if selected_nodes is not None: - # subgraph of selected nodes - rx_graph, node_map = rx_graph.subgraph_with_nodemap(selected_nodes) - _filter_func = _create_filter_func(attrs, self._node_attr_schemas()) - if node_map is None: + if selected_nodes is None: return list(rx_graph.filter_nodes(_filter_func)) - else: - return [node_map[n] for n in rx_graph.filter_nodes(_filter_func)] + + # evaluate the filter directly on the selected nodes' payloads; + # building an rx subgraph here would also copy every edge between + # the selected nodes just to discard them afterwards. + return [node_id for node_id in selected_nodes if _filter_func(rx_graph[node_id])] def node_ids(self) -> list[int]: """