Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,20 +252,36 @@ def _edge_attrs(self) -> pl.DataFrame:
data = {k: [] for k in self._graph.edge_attr_keys()}
data[DEFAULT_ATTR_KEYS.EDGE_ID] = []

# Endpoint membership constraints, matching SQLFilter semantics:
# when nodes are selected (explicit `node_ids` or node attribute filters),
# an edge's source must be in the selected set unless `include_sources`,
# and its target must be in the selected set unless `include_targets`.
check_node_ids = None
if self._node_ids is not None and not (self._include_targets or self._include_sources):
if self._node_ids is not None or self._node_attr_comps:
check_node_ids = set(node_ids)

# An edge can be visited twice (out_edges of its source and in_edges of
# its target) when `include_sources` is set; deduplicate by the identity
# of the payload dict, which is unique per edge (unlike EDGE_ID, it is
# always present, e.g. view-local edges of a SQL-rooted GraphView).
seen_edges = set()

# TODO: at this point I think we are better creating a rx subgraph
# and using the filter method
for node_id in node_ids:
for nf in neigh_funcs:
for src, tgt, attr in nf(node_id):
if _filter_func(attr):
if check_node_ids is not None:
if src not in check_node_ids or tgt not in check_node_ids:
continue
if id(attr) in seen_edges:
continue
seen_edges.add(id(attr))

if check_node_ids is not None:
if not self._include_sources and src not in check_node_ids:
continue
if not self._include_targets and tgt not in check_node_ids:
continue

if _filter_func(attr):
sources.append(src)
targets.append(tgt)
for k in data.keys():
Expand Down
33 changes: 33 additions & 0 deletions src/tracksdata/graph/_test/test_graph_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,39 @@ def test_edge_attrs_include_targets(graph_backend: BaseGraph) -> None:
assert single_exclusive_edge_ids == expected_single_exclusive, msg


def test_edge_attrs_include_sources(graph_backend: BaseGraph) -> None:
"""Edges must be unique and respect endpoint membership with include_sources."""
graph_backend.add_edge_attr_key("weight", dtype=pl.Float64)

# node0 -> node1 -> node2
node0 = graph_backend.add_node({"t": 0})
node1 = graph_backend.add_node({"t": 1})
node2 = graph_backend.add_node({"t": 2})

edge0 = graph_backend.add_edge(node0, node1, attrs={"weight": 0.1})
edge1 = graph_backend.add_edge(node1, node2, attrs={"weight": 0.2})

# include_sources=True with both endpoints of edge0 selected:
# - edge0: node0 -> node1 ✓ (must appear exactly once)
# - edge1: node1 -> node2 ✗ (node2 not selected and include_targets=False)
edge_ids = graph_backend.filter(node_ids=[node0, node1], include_sources=True).edge_ids()
assert list(edge_ids) == [edge0]

# include_sources=True selecting only the target of edge0:
# - edge0: node0 -> node1 ✓ (in-edge with source outside the selection)
edge_ids = graph_backend.filter(node_ids=[node1], include_sources=True).edge_ids()
assert list(edge_ids) == [edge0]

# node attribute filters must constrain edge endpoints the same way
assert graph_backend.filter(NodeAttr("t") == 1).edge_ids() == []

edge_ids = graph_backend.filter(NodeAttr("t") == 1, include_targets=True).edge_ids()
assert list(edge_ids) == [edge1]

edge_ids = graph_backend.filter(NodeAttr("t") == 1, include_sources=True).edge_ids()
assert list(edge_ids) == [edge0]


def test_filter_node_ids_with_include_flags(graph_backend: BaseGraph) -> None:
"""Selected nodes must be kept when include flags extend the selection."""
graph_backend.add_edge_attr_key("weight", dtype=pl.Float64)
Expand Down
Loading