Skip to content

Commit bcf2617

Browse files
author
sjduan
committed
Feat: enhance golden validation error output with mismatch details
- Add detailed mismatch information to golden validation error messages - Show first 20 mismatched elements with indices and actual/expected values - Use vectorized operations for better performance - Improve debugging experience for test failures
1 parent 554bf89 commit bcf2617

1 file changed

Lines changed: 16 additions & 3 deletions

File tree

simpler_setup/code_runner.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -933,12 +933,25 @@ def _compare_with_golden(
933933
if not torch.allclose(actual, expected, rtol=self.rtol, atol=self.atol):
934934
# Find mismatches for better error reporting
935935
close_mask = torch.isclose(actual, expected, rtol=self.rtol, atol=self.atol)
936-
mismatches = (~close_mask).sum().item()
937936
total = actual.numel()
937+
mismatch_indices = torch.where(~close_mask.flatten())[0]
938+
n_show = min(20, mismatch_indices.numel())
939+
flat_actual = actual.flatten()
940+
flat_expected = expected.flatten()
941+
942+
# Efficiently extract values
943+
show_indices = mismatch_indices[:n_show]
944+
actual_vals = flat_actual[show_indices].tolist()
945+
expected_vals = flat_expected[show_indices].tolist()
946+
detail_str = "\n".join(
947+
f" [{idx}] actual={act}, expected={exp}"
948+
for idx, act, exp in zip(show_indices.tolist(), actual_vals, expected_vals)
949+
)
938950
raise AssertionError(
939951
f"Output '{name}' does not match golden.\n"
940-
f"Mismatched elements: {mismatches}/{total}\n"
941-
f"rtol={self.rtol}, atol={self.atol}"
952+
f"Mismatched elements: {mismatch_indices.numel()}/{total}\n"
953+
f"rtol={self.rtol}, atol={self.atol}\n"
954+
f"First {n_show} mismatches:\n{detail_str}"
942955
)
943956

944957
matched = torch.isclose(actual, expected, rtol=self.rtol, atol=self.atol).sum().item()

0 commit comments

Comments
 (0)