diff --git a/src/tracksdata/solvers/_nearest_neighbors_solver.py b/src/tracksdata/solvers/_nearest_neighbors_solver.py index 21915290..99240417 100644 --- a/src/tracksdata/solvers/_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_nearest_neighbors_solver.py @@ -296,6 +296,8 @@ def solve( if self.output_key not in graph.node_attr_keys(): graph.add_node_attr_key(self.output_key, pl.Boolean) + elif self.reset: + graph.update_node_attrs(attrs={self.output_key: False}) graph.update_node_attrs( node_ids=node_ids, diff --git a/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py b/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py index 10f91efd..a9a33296 100644 --- a/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py @@ -383,3 +383,26 @@ def test_nearest_neighbors_solver_solve_large_graph() -> None: for target_idx in [2, 3, 4, 5, 6]: target_edges = selected_edges.filter(selected_edges[DEFAULT_ATTR_KEYS.EDGE_TARGET] == nodes[target_idx]) assert len(target_edges) <= 1 # one parent constraint + + +def test_nearest_neighbors_solver_reset_node_solution() -> None: + """Re-solving with `reset=True` must clear stale node solution values.""" + graph = RustWorkXGraph() + + node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0}) + node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 1}) + node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 1}) + + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, pl.Float64) + graph.add_edge(node0, node1, {DEFAULT_ATTR_KEYS.EDGE_DIST: 1.0}) + graph.add_edge(node0, node2, {DEFAULT_ATTR_KEYS.EDGE_DIST: 2.0}) + + # First solve selects both edges (and all three nodes) + NearestNeighborsSolver(max_children=2).solve(graph) + assert graph.nodes[node2][DEFAULT_ATTR_KEYS.SOLUTION] + + # Second solve only keeps the best edge; node2 must be reset to False + NearestNeighborsSolver(max_children=1).solve(graph) + assert graph.nodes[node0][DEFAULT_ATTR_KEYS.SOLUTION] + assert graph.nodes[node1][DEFAULT_ATTR_KEYS.SOLUTION] + assert not graph.nodes[node2][DEFAULT_ATTR_KEYS.SOLUTION]