From dfe44edf33efd5a756d6bb39f8cf54168e90b78a Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 1 Jul 2026 13:16:23 -0700 Subject: [PATCH] fix: reset stale node solution values in NearestNeighborsSolver When re-solving a graph whose nodes already carried the solution attribute, NearestNeighborsSolver only reset the edge values and left node values from the previous solve untouched. Nodes dropped from the new solution kept solution=True, so the returned solution subgraph contained stale nodes. Mirror the edge path (and ILPSolver) by resetting node values to False when reset=True and the output key already exists. --- .../solvers/_nearest_neighbors_solver.py | 2 ++ .../_test/test_nearest_neighbors_solver.py | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/tracksdata/solvers/_nearest_neighbors_solver.py b/src/tracksdata/solvers/_nearest_neighbors_solver.py index 21915290..99240417 100644 --- a/src/tracksdata/solvers/_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_nearest_neighbors_solver.py @@ -296,6 +296,8 @@ def solve( if self.output_key not in graph.node_attr_keys(): graph.add_node_attr_key(self.output_key, pl.Boolean) + elif self.reset: + graph.update_node_attrs(attrs={self.output_key: False}) graph.update_node_attrs( node_ids=node_ids, diff --git a/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py b/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py index 10f91efd..a9a33296 100644 --- a/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py @@ -383,3 +383,26 @@ def test_nearest_neighbors_solver_solve_large_graph() -> None: for target_idx in [2, 3, 4, 5, 6]: target_edges = selected_edges.filter(selected_edges[DEFAULT_ATTR_KEYS.EDGE_TARGET] == nodes[target_idx]) assert len(target_edges) <= 1 # one parent constraint + + +def test_nearest_neighbors_solver_reset_node_solution() -> None: + """Re-solving with `reset=True` must clear stale node solution values.""" + graph = RustWorkXGraph() + + node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0}) + node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 1}) + node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 1}) + + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, pl.Float64) + graph.add_edge(node0, node1, {DEFAULT_ATTR_KEYS.EDGE_DIST: 1.0}) + graph.add_edge(node0, node2, {DEFAULT_ATTR_KEYS.EDGE_DIST: 2.0}) + + # First solve selects both edges (and all three nodes) + NearestNeighborsSolver(max_children=2).solve(graph) + assert graph.nodes[node2][DEFAULT_ATTR_KEYS.SOLUTION] + + # Second solve only keeps the best edge; node2 must be reset to False + NearestNeighborsSolver(max_children=1).solve(graph) + assert graph.nodes[node0][DEFAULT_ATTR_KEYS.SOLUTION] + assert graph.nodes[node1][DEFAULT_ATTR_KEYS.SOLUTION] + assert not graph.nodes[node2][DEFAULT_ATTR_KEYS.SOLUTION]