Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,26 +199,34 @@ def _current_node_ids(self) -> list[int]:

@cache_method
def node_ids(self) -> list[int]:
# if there are no edge filters, we can return the current node ids
# if there are no edge filters nor include flags, we can return the current node ids
if not self._edge_attr_comps and (not self._include_targets and not self._include_sources):
return self._current_node_ids()

# find nodes that are connected to edges that pass the edge filters
edges_df = self._edge_attrs()
node_filtered = self._node_ids is not None or bool(self._node_attr_comps)

node_ids = []
edge_node_ids = (
self._edge_attrs()
.select(
DEFAULT_ATTR_KEYS.EDGE_SOURCE,
DEFAULT_ATTR_KEYS.EDGE_TARGET,
if node_filtered or not self._edge_attr_comps:
# nodes selected by `node_ids`/node filters (or all nodes when
# unfiltered) are always kept; the include flags only extend the
# selection with the respective edge endpoints, matching SQLFilter.
node_ids.append(np.asarray(self._current_node_ids(), dtype=int))
else:
# only edge filters: nodes are the endpoints of the matching edges
node_ids.append(
edges_df.select(
DEFAULT_ATTR_KEYS.EDGE_SOURCE,
DEFAULT_ATTR_KEYS.EDGE_TARGET,
)
.to_numpy()
.ravel()
)
.to_numpy()
.ravel()
)
node_ids.append(edge_node_ids)

if self._node_attr_comps:
# if there are node filters, we need to add the nodes that pass the node filters
node_ids.append(self._current_node_ids())
if self._include_sources:
node_ids.append(edges_df[DEFAULT_ATTR_KEYS.EDGE_SOURCE].to_numpy())
if self._include_targets:
node_ids.append(edges_df[DEFAULT_ATTR_KEYS.EDGE_TARGET].to_numpy())

node_ids = [v for v in node_ids if len(v) > 0]

Expand Down
23 changes: 23 additions & 0 deletions src/tracksdata/graph/_test/test_graph_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,29 @@ def test_edge_attrs_include_targets(graph_backend: BaseGraph) -> None:
assert single_exclusive_edge_ids == expected_single_exclusive, msg


def test_filter_node_ids_with_include_flags(graph_backend: BaseGraph) -> None:
"""Selected nodes must be kept when include flags extend the selection."""
graph_backend.add_edge_attr_key("weight", dtype=pl.Float64)

node0 = graph_backend.add_node({"t": 0})
node1 = graph_backend.add_node({"t": 1})
isolated = graph_backend.add_node({"t": 0})

graph_backend.add_edge(node0, node1, attrs={"weight": 0.5})

# explicitly selected nodes without edges must not be dropped
node_ids = graph_backend.filter(node_ids=[node0, isolated], include_targets=True).node_ids()
assert sorted(node_ids) == sorted([node0, node1, isolated])

# node attribute filters behave the same way
node_ids = graph_backend.filter(NodeAttr("t") == 0, include_targets=True).node_ids()
assert sorted(node_ids) == sorted([node0, node1, isolated])

# include_sources only extends with edge sources, not targets
node_ids = graph_backend.filter(node_ids=[node0, isolated], include_sources=True).node_ids()
assert sorted(node_ids) == sorted([node0, isolated])


def test_from_ctc(
ctc_data_dir: Path,
graph_backend: BaseGraph,
Expand Down
Loading