Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions simpler_setup/code_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,12 +933,25 @@ def _compare_with_golden(
if not torch.allclose(actual, expected, rtol=self.rtol, atol=self.atol):
# Find mismatches for better error reporting
close_mask = torch.isclose(actual, expected, rtol=self.rtol, atol=self.atol)
mismatches = (~close_mask).sum().item()
total = actual.numel()
mismatch_indices = torch.where(~close_mask.flatten())[0]
n_show = min(20, mismatch_indices.numel())
flat_actual = actual.flatten()
flat_expected = expected.flatten()

# Efficiently extract values
show_indices = mismatch_indices[:n_show]
actual_vals = flat_actual[show_indices].tolist()
expected_vals = flat_expected[show_indices].tolist()
detail_str = "\n".join(
f" [{idx}] actual={act}, expected={exp}"
for idx, act, exp in zip(show_indices.tolist(), actual_vals, expected_vals)
)
raise AssertionError(
f"Output '{name}' does not match golden.\n"
f"Mismatched elements: {mismatches}/{total}\n"
f"rtol={self.rtol}, atol={self.atol}"
f"Mismatched elements: {mismatch_indices.numel()}/{total}\n"
f"rtol={self.rtol}, atol={self.atol}\n"
f"First {n_show} mismatches:\n{detail_str}"
)

matched = torch.isclose(actual, expected, rtol=self.rtol, atol=self.atol).sum().item()
Expand Down
Loading