-
Notifications
You must be signed in to change notification settings - Fork 42
【Test-Cases】: add UT coverage for core datasets (aime, realworldqa, math, gsm8k, gpqa, dapo_math) #315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
【Test-Cases】: add UT coverage for core datasets (aime, realworldqa, math, gsm8k, gpqa, dapo_math) #315
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| import unittest | ||
| from unittest.mock import patch, mock_open | ||
|
|
||
| from datasets import Dataset | ||
|
|
||
| from ais_bench.benchmark.datasets.aime2026 import Aime2026Dataset | ||
|
|
||
|
|
||
| class TestAime2026Dataset(unittest.TestCase): | ||
| @patch("ais_bench.benchmark.datasets.aime2026.get_data_path", return_value="/fake/path.jsonl") | ||
| @patch("builtins.open") | ||
| def test_aime2026_load(self, mock_open_file, mock_get_path): | ||
| line = '{"question": "What is 1+1?", "answer": "2"}' | ||
| m = mock_open(read_data=line + "\n") | ||
| mock_open_file.return_value = m.return_value | ||
| ds = Aime2026Dataset.load("/any") | ||
| self.assertIsInstance(ds, Dataset) | ||
| self.assertEqual(len(ds), 1) | ||
| row = ds[0] | ||
| self.assertEqual(row["question"], "What is 1+1?") | ||
| self.assertEqual(row["answer"], "2") | ||
|
|
||
| @patch("ais_bench.benchmark.datasets.aime2026.get_data_path", return_value="/fake/path.jsonl") | ||
| @patch("builtins.open") | ||
| def test_aime2026_load_multiple_records(self, mock_open_file, mock_get_path): | ||
| lines = '{"id": 1, "value": "a"}\n{"id": 2, "value": "b"}\n{"id": 3, "value": "c"}\n' | ||
| m = mock_open(read_data=lines) | ||
| mock_open_file.return_value = m.return_value | ||
| ds = Aime2026Dataset.load("/any") | ||
| self.assertIsInstance(ds, Dataset) | ||
| self.assertEqual(len(ds), 3) | ||
| self.assertEqual(ds[0]["id"], 1) | ||
| self.assertEqual(ds[1]["value"], "b") | ||
| self.assertEqual(ds[2]["value"], "c") | ||
|
|
||
| @patch("ais_bench.benchmark.datasets.aime2026.get_data_path", return_value="/fake/path.jsonl") | ||
| @patch("builtins.open") | ||
| def test_aime2026_load_strips_whitespace(self, mock_open_file, mock_get_path): | ||
| line = ' {"key": "value"} ' | ||
| m = mock_open(read_data=line + "\n") | ||
| mock_open_file.return_value = m.return_value | ||
| ds = Aime2026Dataset.load("/any") | ||
| self.assertIsInstance(ds, Dataset) | ||
| self.assertEqual(len(ds), 1) | ||
| self.assertEqual(ds[0]["key"], "value") | ||
|
|
||
| @patch("ais_bench.benchmark.datasets.aime2026.get_data_path", return_value="/fake/path.jsonl") | ||
| @patch("builtins.open") | ||
| def test_aime2026_load_fields_preserved(self, mock_open_file, mock_get_path): | ||
| line = '{"origin_prompt": "Solve this", "gold_answer": "42", "id": 101, "subject": "math"}' | ||
| m = mock_open(read_data=line + "\n") | ||
| mock_open_file.return_value = m.return_value | ||
| ds = Aime2026Dataset.load("/any") | ||
| self.assertIsInstance(ds, Dataset) | ||
| row = ds[0] | ||
| self.assertEqual(row["origin_prompt"], "Solve this") | ||
| self.assertEqual(row["gold_answer"], "42") | ||
| self.assertEqual(row["id"], 101) | ||
| self.assertEqual(row["subject"], "math") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -4,7 +4,7 @@ | |||||||||||||||||
| from datasets import Dataset | ||||||||||||||||||
|
|
||||||||||||||||||
| from ais_bench.benchmark.datasets.aime2024 import Aime2024Dataset | ||||||||||||||||||
| from ais_bench.benchmark.datasets.aime2025 import Aime2025Dataset | ||||||||||||||||||
| from ais_bench.benchmark.datasets.aime2025 import Aime2025Dataset, Aime2025JDGDataset | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class TestAimeDatasets(unittest.TestCase): | ||||||||||||||||||
|
|
@@ -31,6 +31,46 @@ def test_aime2025_load(self, mock_open_file, mock_get_path): | |||||||||||||||||
| self.assertIsInstance(ds, Dataset) | ||||||||||||||||||
| self.assertEqual(len(ds), 1) | ||||||||||||||||||
|
|
||||||||||||||||||
| @patch("ais_bench.benchmark.datasets.aime2024.get_data_path", return_value="/fake/path.jsonl") | ||||||||||||||||||
| @patch("builtins.open") | ||||||||||||||||||
| def test_aime2024_multiple_records(self, mock_open_file, mock_get_path): | ||||||||||||||||||
| line1 = '{"origin_prompt": "Q1?", "gold_answer": "1"}' | ||||||||||||||||||
| line2 = '{"origin_prompt": "Q2?", "gold_answer": "2"}' | ||||||||||||||||||
| m = mock_open(read_data=line1 + "\n" + line2 + "\n") | ||||||||||||||||||
| mock_open_file.return_value = m.return_value | ||||||||||||||||||
|
Comment on lines
+35
to
+40
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||||||
| ds = Aime2024Dataset.load("/any") | ||||||||||||||||||
| self.assertIsInstance(ds, Dataset) | ||||||||||||||||||
| self.assertEqual(len(ds), 2) | ||||||||||||||||||
| self.assertIn("question", ds[0]) | ||||||||||||||||||
| self.assertIn("question", ds[1]) | ||||||||||||||||||
|
|
||||||||||||||||||
| @patch("ais_bench.benchmark.datasets.aime2024.get_data_path", return_value="/fake/path.jsonl") | ||||||||||||||||||
| @patch("builtins.open") | ||||||||||||||||||
| def test_aime2024_fields_mapping(self, mock_open_file, mock_get_path): | ||||||||||||||||||
| line = '{"origin_prompt": "What is 1+1?", "gold_answer": "2"}' | ||||||||||||||||||
| m = mock_open(read_data=line + "\n") | ||||||||||||||||||
| mock_open_file.return_value = m.return_value | ||||||||||||||||||
| ds = Aime2024Dataset.load("/any") | ||||||||||||||||||
| row = ds[0] | ||||||||||||||||||
| self.assertEqual(row["question"], "What is 1+1?") | ||||||||||||||||||
| self.assertEqual(row["answer"], "2") | ||||||||||||||||||
|
|
||||||||||||||||||
| @patch("ais_bench.benchmark.datasets.aime2025.get_data_path", return_value="/fake/path.jsonl") | ||||||||||||||||||
| @patch("builtins.open") | ||||||||||||||||||
| def test_aime2025_multiple_records(self, mock_open_file, mock_get_path): | ||||||||||||||||||
| line1 = '{"field": "value1"}' | ||||||||||||||||||
| line2 = '{"field": "value2"}' | ||||||||||||||||||
| m = mock_open(read_data=line1 + "\n" + line2 + "\n") | ||||||||||||||||||
| mock_open_file.return_value = m.return_value | ||||||||||||||||||
| ds = Aime2025Dataset.load("/any") | ||||||||||||||||||
| self.assertIsInstance(ds, Dataset) | ||||||||||||||||||
| self.assertEqual(len(ds), 2) | ||||||||||||||||||
|
|
||||||||||||||||||
| def test_aime2025_jdg_dataset_class(self): | ||||||||||||||||||
| instance = Aime2025JDGDataset.__new__(Aime2025JDGDataset) | ||||||||||||||||||
| result = instance._get_dataset_class() | ||||||||||||||||||
| self.assertIs(result, Aime2025Dataset) | ||||||||||||||||||
|
Comment on lines
+69
to
+72
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||
| unittest.main() | ||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| import pytest | ||
|
|
||
| from ais_bench.benchmark.datasets.dapo_math import ( | ||
| last_boxed_only_string, remove_boxed, normalize_final_answer, | ||
| extract_pred_by_minerva, extract_pred_by_strict_box, | ||
| dapo_math_postprocess, dapo_math_postprocess_v2, | ||
| DAPOMathEvaluator, DAPOMathEvaluatorV2, | ||
| ) | ||
| from ais_bench.benchmark.utils.logging.exceptions import AISBenchDataContentError | ||
|
|
||
|
|
||
| class TestBoxedFunctions: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| def test_last_boxed_found(self): | ||
| result = last_boxed_only_string(r"The answer is \boxed{42}.") | ||
| assert result == r"\boxed{42}" | ||
|
|
||
| def test_last_boxed_not_found(self): | ||
| result = last_boxed_only_string("No boxed expression here.") | ||
| assert result is None | ||
|
|
||
| def test_last_boxed_unclosed(self): | ||
| result = last_boxed_only_string(r"\boxed{42") | ||
| assert result is None | ||
|
|
||
| def test_last_boxed_nested(self): | ||
| result = last_boxed_only_string(r"First \boxed{a} then \boxed{\frac{1}{2}} end.") | ||
| assert result == r"\boxed{\frac{1}{2}}" | ||
|
|
||
| def test_remove_boxed_success(self): | ||
| result = remove_boxed(r"\boxed{42}") | ||
| assert result == "42" | ||
|
|
||
| def test_remove_boxed_invalid_prefix(self): | ||
| with pytest.raises(AISBenchDataContentError): | ||
| remove_boxed(r"\notboxed{42}") | ||
|
|
||
| def test_remove_boxed_invalid_suffix(self): | ||
| with pytest.raises(AISBenchDataContentError): | ||
| remove_boxed(r"\boxed{42") | ||
|
|
||
|
|
||
| class TestNormalizeFinalAnswer: | ||
| def test_basic(self): | ||
| assert normalize_final_answer("42") == "42" | ||
|
|
||
| def test_with_substitutions(self): | ||
| result = normalize_final_answer("an apple") | ||
| assert result == "apple" | ||
|
|
||
| def test_with_removed_expressions(self): | ||
| result = normalize_final_answer("5 degrees") | ||
| assert result == "5" | ||
|
|
||
| def test_with_comma_numbers(self): | ||
| result = normalize_final_answer("1,234") | ||
| assert result == "1234" | ||
|
|
||
| def test_with_textbf(self): | ||
| result = normalize_final_answer(r"\textbf{hello}") | ||
| assert result == "hello" | ||
|
|
||
| def test_with_overline(self): | ||
| result = normalize_final_answer(r"\overline{AB}") | ||
| assert result == "AB" | ||
|
|
||
| def test_with_boxed(self): | ||
| result = normalize_final_answer(r"\boxed{99}") | ||
| assert result == "99" | ||
|
|
||
| def test_with_dollar_signs(self): | ||
| result = normalize_final_answer("$42$") | ||
| assert result == "42" | ||
|
|
||
| def test_with_frac_shorthand(self): | ||
| result = normalize_final_answer(r"frac12") | ||
| assert result == r"frac{1}{2}" | ||
|
|
||
| def test_with_sqrt_shorthand(self): | ||
| result = normalize_final_answer(r"sqrt3") | ||
| assert result == r"sqrt{3}" | ||
|
|
||
|
|
||
| class TestExtractPredByMinerva: | ||
| def test_with_valid_answer(self): | ||
| result = extract_pred_by_minerva("The answer is Answer: 42\nmore text") | ||
| assert result == "42" | ||
|
|
||
| def test_with_no_match_returns_invalid(self): | ||
| result = extract_pred_by_minerva("No answer pattern here.") | ||
| assert result == "[INVALID]" | ||
|
|
||
| def test_case_insensitive(self): | ||
| result = extract_pred_by_minerva("The answer is answer: 100\n") | ||
| assert result == "100" | ||
|
|
||
|
|
||
| class TestExtractPredByStrictBox: | ||
| def test_with_boxed_answer(self): | ||
| result = extract_pred_by_strict_box(r"some text \boxed{42} end") | ||
| assert result == "42" | ||
|
|
||
| def test_without_boxed_returns_empty(self): | ||
| result = extract_pred_by_strict_box("no boxed expression here") | ||
| assert result == "" | ||
|
|
||
|
|
||
| class TestDapoMathPostprocessors: | ||
| def test_dapo_math_postprocess(self): | ||
| result = dapo_math_postprocess("blah Answer: 42\n") | ||
| assert result == "42" | ||
|
|
||
| def test_dapo_math_postprocess_v2(self): | ||
| result = dapo_math_postprocess_v2(r"blah \boxed{42}") | ||
| assert result == "42" | ||
|
|
||
|
|
||
| class TestDAPOMathEvaluator: | ||
| def test_score_basic(self): | ||
| evaluator = DAPOMathEvaluator() | ||
| result = evaluator.score(["42", "100"], ["42", "200"]) | ||
| assert result["accuracy"] == 50.0 | ||
| assert len(result["details"]) == 2 | ||
| assert result["details"][0]["correct"] is True | ||
| assert result["details"][1]["correct"] is False | ||
|
|
||
| def test_score_all_correct(self): | ||
| evaluator = DAPOMathEvaluator() | ||
| result = evaluator.score(["1", "2", "3"], ["1", "2", "3"]) | ||
| assert result["accuracy"] == 100.0 | ||
| assert all(d["correct"] for d in result["details"]) | ||
|
|
||
| def test_score_all_wrong(self): | ||
| evaluator = DAPOMathEvaluator() | ||
| result = evaluator.score(["1", "2", "3"], ["4", "5", "6"]) | ||
| assert result["accuracy"] == 0.0 | ||
| assert not any(d["correct"] for d in result["details"]) | ||
|
|
||
| def test_score_length_mismatch(self): | ||
| evaluator = DAPOMathEvaluator() | ||
| result = evaluator.score(["1", "2"], ["1"]) | ||
| assert "error" in result | ||
| assert "different length" in result["error"] | ||
|
|
||
| def test_score_normalizes_references(self): | ||
| evaluator = DAPOMathEvaluator() | ||
| result = evaluator.score(["1234"], ["1,234"]) | ||
| assert result["accuracy"] == 100.0 | ||
| assert result["details"][0]["correct"] is True | ||
|
|
||
|
|
||
| class TestDAPOMathEvaluatorV2: | ||
| def test_score_basic(self): | ||
| evaluator = DAPOMathEvaluatorV2() | ||
| result = evaluator.score(["42", "100"], ["42", "200"]) | ||
| assert result["accuracy"] == 50.0 | ||
| assert len(result["details"]) == 2 | ||
| assert result["details"][0]["correct"] is True | ||
| assert result["details"][1]["correct"] is False | ||
|
|
||
| def test_score_length_mismatch(self): | ||
| evaluator = DAPOMathEvaluatorV2() | ||
| result = evaluator.score(["1", "2"], ["1"]) | ||
| assert "error" in result | ||
| assert "different length" in result["error"] | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -45,6 +45,55 @@ def test_evaluator_and_postprocess(self): | |||||||||||||||||||||||
| self.assertIn("accuracy", out) | ||||||||||||||||||||||||
| self.assertEqual(GPQA_Simple_Eval_postprocess("Answer: B"), "B") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def test_evaluator_score_length_mismatch(self): | ||||||||||||||||||||||||
| eva = GPQAEvaluator() | ||||||||||||||||||||||||
| result = eva.score(['A', 'B'], ['A']) | ||||||||||||||||||||||||
| self.assertIn('error', result) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def test_evaluator_score_all_correct(self): | ||||||||||||||||||||||||
| eva = GPQAEvaluator() | ||||||||||||||||||||||||
| result = eva.score(['A', 'B', 'C'], ['A', 'B', 'C']) | ||||||||||||||||||||||||
| self.assertEqual(result['accuracy'], 100.0) | ||||||||||||||||||||||||
| self.assertEqual(len(result['details']), 3) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def test_evaluator_score_all_wrong(self): | ||||||||||||||||||||||||
| eva = GPQAEvaluator() | ||||||||||||||||||||||||
| result = eva.score(['A', 'B'], ['C', 'D']) | ||||||||||||||||||||||||
| self.assertEqual(result['accuracy'], 0.0) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def test_evaluator_score_details_structure(self): | ||||||||||||||||||||||||
| eva = GPQAEvaluator() | ||||||||||||||||||||||||
| result = eva.score(['A'], ['B']) | ||||||||||||||||||||||||
| detail = result['details'][0] | ||||||||||||||||||||||||
| self.assertIn('pred', detail) | ||||||||||||||||||||||||
| self.assertIn('answer', detail) | ||||||||||||||||||||||||
| self.assertIn('correct', detail) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def test_postprocess_no_match(self): | ||||||||||||||||||||||||
| result = GPQA_Simple_Eval_postprocess("This has no answer pattern") | ||||||||||||||||||||||||
| self.assertIsNone(result) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def test_postprocess_case_insensitive(self): | ||||||||||||||||||||||||
| result = GPQA_Simple_Eval_postprocess("answer: B") | ||||||||||||||||||||||||
| self.assertEqual(result, 'B') | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @patch("ais_bench.benchmark.datasets.gpqa.get_data_path", return_value="/fake/path") | ||||||||||||||||||||||||
| @patch("builtins.open") | ||||||||||||||||||||||||
| def test_dataset_multiple_rows(self, mock_open_file, mock_get_path): | ||||||||||||||||||||||||
| content = ( | ||||||||||||||||||||||||
| "h0,h1,h2,h3,h4,h5,h6,Question,A,B,C,D\n" | ||||||||||||||||||||||||
| ",,,,,,,Q1,opt1_a,opt1_b,opt1_c,opt1_d\n" | ||||||||||||||||||||||||
| ",,,,,,,Q2,opt2_a,opt2_b,opt2_c,opt2_d\n" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| m = mock_open(read_data=content) | ||||||||||||||||||||||||
| mock_open_file.return_value = m.return_value | ||||||||||||||||||||||||
|
Comment on lines
+81
to
+89
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||||||||||||
| ds = GPQADataset.load("/any", name="file.csv") | ||||||||||||||||||||||||
| self.assertEqual(len(ds), 2) | ||||||||||||||||||||||||
| self.assertEqual(ds[0]['answer'], 'D') | ||||||||||||||||||||||||
| self.assertEqual(ds[0]['D'], 'opt1_a') | ||||||||||||||||||||||||
| self.assertEqual(ds[1]['answer'], 'C') | ||||||||||||||||||||||||
| self.assertEqual(ds[1]['C'], 'opt2_a') | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||
| unittest.main() | ||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using `new_callable=mock_open` directly in the `@patch` decorator is cleaner and avoids manual assignment of `mock_open_file.return_value = m.return_value`.