Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,21 +1317,26 @@ def update_edge_attrs(
edge_ids = self.edge_ids()

size = len(edge_ids)
# broadcast scalars into a local copy so the caller's dict is not mutated
broadcast_attrs: dict[str, Any] = {}
for key, value in attrs.items():
if key not in self.edge_attr_keys():
raise ValueError(f"Edge attribute key '{key}' not found in graph. Expected '{self.edge_attr_keys()}'")

if np.isscalar(value):
attrs[key] = [value] * size
broadcast_attrs[key] = [value] * size

elif len(attrs[key]) != size:
raise ValueError(f"Attribute '{key}' has wrong size. Expected {size}, got {len(attrs[key])}")
elif len(value) != size:
raise ValueError(f"Attribute '{key}' has wrong size. Expected {size}, got {len(value)}")

else:
broadcast_attrs[key] = value

edge_map = self._graph.edge_index_map()

for i, edge_id in enumerate(edge_ids):
edge_attr = edge_map[edge_id][2] # 0=source, 1=target, 2=attributes
for key, value in attrs.items():
for key, value in broadcast_attrs.items():
edge_attr[key] = value[i]

def assign_tracklet_ids(
Expand Down
5 changes: 5 additions & 0 deletions src/tracksdata/graph/_test/test_graph_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,11 @@ def test_update_edge_attrs(graph_backend: BaseGraph) -> None:
with pytest.raises(ValueError):
graph_backend.update_edge_attrs(edge_ids=[edge_id], attrs={"weight": [1.0, 2.0]})

# the caller's attrs dict must not be mutated (e.g. scalar broadcast to a list)
user_attrs = {"weight": 2.0}
graph_backend.update_edge_attrs(edge_ids=[edge_id], attrs=user_attrs)
assert user_attrs == {"weight": 2.0}


def test_num_edges(graph_backend: BaseGraph) -> None:
"""Test counting edges."""
Expand Down
Loading