diff --git a/simpler_setup/code_runner.py b/simpler_setup/code_runner.py index f382aa27..efdd53ff 100644 --- a/simpler_setup/code_runner.py +++ b/simpler_setup/code_runner.py @@ -933,12 +933,25 @@ def _compare_with_golden( if not torch.allclose(actual, expected, rtol=self.rtol, atol=self.atol): # Find mismatches for better error reporting close_mask = torch.isclose(actual, expected, rtol=self.rtol, atol=self.atol) - mismatches = (~close_mask).sum().item() total = actual.numel() + mismatch_indices = torch.where(~close_mask.flatten())[0] + n_show = min(20, mismatch_indices.numel()) + flat_actual = actual.flatten() + flat_expected = expected.flatten() + + # Efficiently extract values + show_indices = mismatch_indices[:n_show] + actual_vals = flat_actual[show_indices].tolist() + expected_vals = flat_expected[show_indices].tolist() + detail_str = "\n".join( + f" [{idx}] actual={act}, expected={exp}" + for idx, act, exp in zip(show_indices.tolist(), actual_vals, expected_vals) + ) raise AssertionError( f"Output '{name}' does not match golden.\n" - f"Mismatched elements: {mismatches}/{total}\n" - f"rtol={self.rtol}, atol={self.atol}" + f"Mismatched elements: {mismatch_indices.numel()}/{total}\n" + f"rtol={self.rtol}, atol={self.atol}\n" + f"First {n_show} mismatches:\n{detail_str}" ) matched = torch.isclose(actual, expected, rtol=self.rtol, atol=self.atol).sum().item()