diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 5f6a2645..8e172f97 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -1317,21 +1317,26 @@ def update_edge_attrs( edge_ids = self.edge_ids() size = len(edge_ids) + # broadcast scalars into a local copy so the caller's dict is not mutated + broadcast_attrs: dict[str, Any] = {} for key, value in attrs.items(): if key not in self.edge_attr_keys(): raise ValueError(f"Edge attribute key '{key}' not found in graph. Expected '{self.edge_attr_keys()}'") if np.isscalar(value): - attrs[key] = [value] * size + broadcast_attrs[key] = [value] * size - elif len(attrs[key]) != size: - raise ValueError(f"Attribute '{key}' has wrong size. Expected {size}, got {len(attrs[key])}") + elif len(value) != size: + raise ValueError(f"Attribute '{key}' has wrong size. Expected {size}, got {len(value)}") + + else: + broadcast_attrs[key] = value edge_map = self._graph.edge_index_map() for i, edge_id in enumerate(edge_ids): edge_attr = edge_map[edge_id][2] # 0=source, 1=target, 2=attributes - for key, value in attrs.items(): + for key, value in broadcast_attrs.items(): edge_attr[key] = value[i] def assign_tracklet_ids( diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 641138e5..099e1ff8 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -607,6 +607,11 @@ def test_update_edge_attrs(graph_backend: BaseGraph) -> None: with pytest.raises(ValueError): graph_backend.update_edge_attrs(edge_ids=[edge_id], attrs={"weight": [1.0, 2.0]}) + # the caller's attrs dict must not be mutated (e.g. scalar broadcast to a list) + user_attrs = {"weight": 2.0} + graph_backend.update_edge_attrs(edge_ids=[edge_id], attrs=user_attrs) + assert user_attrs == {"weight": 2.0} + def test_num_edges(graph_backend: BaseGraph) -> None: """Test counting edges."""