From e9d9cab98fc761a9dd0a09383dc0cae4816be39b Mon Sep 17 00:00:00 2001 From: qq_46439621 Date: Sat, 30 May 2026 09:16:00 +0000 Subject: [PATCH] test: add UT coverage for core datasets (aime, realworldqa, math, gsm8k, gpqa, dapo_math) - test_aime2026.py: new file with 4 tests for Aime2026Dataset - test_dapo_math.py: new file with 31 tests covering boxed functions, normalization, extractors, postprocessors and evaluators - test_aime_datasets.py: +4 tests (multi-record, field mapping, JDG class) - test_gsm8k.py: +4 tests (get_action, agent score paths) - test_gpqa.py: +7 tests (evaluator edge cases, postprocess, multi-row) - test_realworldqa.py: +5 tests (colon variants, tabs, dict refs) - test_math.py: +52 tests (boxed edge cases, normalize branches, strip_string v1/v2 paths, is_equiv versions, agent evaluator paths) Total: ~107 new tests, all passing pytest. --- tests/UT/datasets/test_aime2026.py | 63 +++++++ tests/UT/datasets/test_aime_datasets.py | 42 ++++- tests/UT/datasets/test_dapo_math.py | 164 +++++++++++++++++ tests/UT/datasets/test_gpqa.py | 49 ++++++ tests/UT/datasets/test_gsm8k.py | 37 ++++ tests/UT/datasets/test_math.py | 224 ++++++++++++++++++++++++ tests/UT/datasets/test_realworldqa.py | 32 ++++ 7 files changed, 610 insertions(+), 1 deletion(-) create mode 100644 tests/UT/datasets/test_aime2026.py create mode 100644 tests/UT/datasets/test_dapo_math.py diff --git a/tests/UT/datasets/test_aime2026.py b/tests/UT/datasets/test_aime2026.py new file mode 100644 index 00000000..8fa13401 --- /dev/null +++ b/tests/UT/datasets/test_aime2026.py @@ -0,0 +1,63 @@ +import unittest +from unittest.mock import patch, mock_open + +from datasets import Dataset + +from ais_bench.benchmark.datasets.aime2026 import Aime2026Dataset + + +class TestAime2026Dataset(unittest.TestCase): + @patch("ais_bench.benchmark.datasets.aime2026.get_data_path", return_value="/fake/path.jsonl") + @patch("builtins.open") + def test_aime2026_load(self, mock_open_file, mock_get_path): + line = '{"question": "What is 1+1?", "answer": "2"}' + m = mock_open(read_data=line + "\n") + mock_open_file.return_value = m.return_value + ds = Aime2026Dataset.load("/any") + self.assertIsInstance(ds, Dataset) + self.assertEqual(len(ds), 1) + row = ds[0] + self.assertEqual(row["question"], "What is 1+1?") + self.assertEqual(row["answer"], "2") + + @patch("ais_bench.benchmark.datasets.aime2026.get_data_path", return_value="/fake/path.jsonl") + @patch("builtins.open") + def test_aime2026_load_multiple_records(self, mock_open_file, mock_get_path): + lines = '{"id": 1, "value": "a"}\n{"id": 2, "value": "b"}\n{"id": 3, "value": "c"}\n' + m = mock_open(read_data=lines) + mock_open_file.return_value = m.return_value + ds = Aime2026Dataset.load("/any") + self.assertIsInstance(ds, Dataset) + self.assertEqual(len(ds), 3) + self.assertEqual(ds[0]["id"], 1) + self.assertEqual(ds[1]["value"], "b") + self.assertEqual(ds[2]["value"], "c") + + @patch("ais_bench.benchmark.datasets.aime2026.get_data_path", return_value="/fake/path.jsonl") + @patch("builtins.open") + def test_aime2026_load_strips_whitespace(self, mock_open_file, mock_get_path): + line = ' {"key": "value"} ' + m = mock_open(read_data=line + "\n") + mock_open_file.return_value = m.return_value + ds = Aime2026Dataset.load("/any") + self.assertIsInstance(ds, Dataset) + self.assertEqual(len(ds), 1) + self.assertEqual(ds[0]["key"], "value") + + @patch("ais_bench.benchmark.datasets.aime2026.get_data_path", return_value="/fake/path.jsonl") + @patch("builtins.open") + def test_aime2026_load_fields_preserved(self, mock_open_file, mock_get_path): + line = '{"origin_prompt": "Solve this", "gold_answer": "42", "id": 101, "subject": "math"}' + m = mock_open(read_data=line + "\n") + mock_open_file.return_value = m.return_value + ds = Aime2026Dataset.load("/any") + self.assertIsInstance(ds, Dataset) + row = ds[0] + self.assertEqual(row["origin_prompt"], "Solve this") + self.assertEqual(row["gold_answer"], "42") + self.assertEqual(row["id"], 101) + self.assertEqual(row["subject"], "math") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/UT/datasets/test_aime_datasets.py b/tests/UT/datasets/test_aime_datasets.py index 24f707c0..1cc380d5 100644 --- a/tests/UT/datasets/test_aime_datasets.py +++ b/tests/UT/datasets/test_aime_datasets.py @@ -4,7 +4,7 @@ from datasets import Dataset from ais_bench.benchmark.datasets.aime2024 import Aime2024Dataset -from ais_bench.benchmark.datasets.aime2025 import Aime2025Dataset +from ais_bench.benchmark.datasets.aime2025 import Aime2025Dataset, Aime2025JDGDataset class TestAimeDatasets(unittest.TestCase): @@ -31,6 +31,46 @@ def test_aime2025_load(self, mock_open_file, mock_get_path): self.assertIsInstance(ds, Dataset) self.assertEqual(len(ds), 1) + @patch("ais_bench.benchmark.datasets.aime2024.get_data_path", return_value="/fake/path.jsonl") + @patch("builtins.open") + def test_aime2024_multiple_records(self, mock_open_file, mock_get_path): + line1 = '{"origin_prompt": "Q1?", "gold_answer": "1"}' + line2 = '{"origin_prompt": "Q2?", "gold_answer": "2"}' + m = mock_open(read_data=line1 + "\n" + line2 + "\n") + mock_open_file.return_value = m.return_value + ds = Aime2024Dataset.load("/any") + self.assertIsInstance(ds, Dataset) + self.assertEqual(len(ds), 2) + self.assertIn("question", ds[0]) + self.assertIn("question", ds[1]) + + @patch("ais_bench.benchmark.datasets.aime2024.get_data_path", return_value="/fake/path.jsonl") + @patch("builtins.open") + def test_aime2024_fields_mapping(self, mock_open_file, mock_get_path): + line = '{"origin_prompt": "What is 1+1?", "gold_answer": "2"}' + m = mock_open(read_data=line + "\n") + mock_open_file.return_value = m.return_value + ds = Aime2024Dataset.load("/any") + row = ds[0] + self.assertEqual(row["question"], "What is 1+1?") + self.assertEqual(row["answer"], "2") + + @patch("ais_bench.benchmark.datasets.aime2025.get_data_path", return_value="/fake/path.jsonl") + @patch("builtins.open") + def test_aime2025_multiple_records(self, mock_open_file, mock_get_path): + line1 = '{"field": "value1"}' + line2 = '{"field": "value2"}' + m = mock_open(read_data=line1 + "\n" + line2 + "\n") + mock_open_file.return_value = m.return_value + ds = Aime2025Dataset.load("/any") + self.assertIsInstance(ds, Dataset) + self.assertEqual(len(ds), 2) + + def test_aime2025_jdg_dataset_class(self): + instance = Aime2025JDGDataset.__new__(Aime2025JDGDataset) + result = instance._get_dataset_class() + self.assertIs(result, Aime2025Dataset) + if __name__ == "__main__": unittest.main() diff --git a/tests/UT/datasets/test_dapo_math.py b/tests/UT/datasets/test_dapo_math.py new file mode 100644 index 00000000..739771e8 --- /dev/null +++ b/tests/UT/datasets/test_dapo_math.py @@ -0,0 +1,164 @@ +import pytest + +from ais_bench.benchmark.datasets.dapo_math import ( + last_boxed_only_string, remove_boxed, normalize_final_answer, + extract_pred_by_minerva, extract_pred_by_strict_box, + dapo_math_postprocess, dapo_math_postprocess_v2, + DAPOMathEvaluator, DAPOMathEvaluatorV2, +) +from ais_bench.benchmark.utils.logging.exceptions import AISBenchDataContentError + + +class TestBoxedFunctions: + def test_last_boxed_found(self): + result = last_boxed_only_string(r"The answer is \boxed{42}.") + assert result == r"\boxed{42}" + + def test_last_boxed_not_found(self): + result = last_boxed_only_string("No boxed expression here.") + assert result is None + + def test_last_boxed_unclosed(self): + result = last_boxed_only_string(r"\boxed{42") + assert result is None + + def test_last_boxed_nested(self): + result = last_boxed_only_string(r"First \boxed{a} then \boxed{\frac{1}{2}} end.") + assert result == r"\boxed{\frac{1}{2}}" + + def test_remove_boxed_success(self): + result = remove_boxed(r"\boxed{42}") + assert result == "42" + + def test_remove_boxed_invalid_prefix(self): + with pytest.raises(AISBenchDataContentError): + remove_boxed(r"\notboxed{42}") + + def test_remove_boxed_invalid_suffix(self): + with pytest.raises(AISBenchDataContentError): + remove_boxed(r"\boxed{42") + + +class TestNormalizeFinalAnswer: + def test_basic(self): + assert normalize_final_answer("42") == "42" + + def test_with_substitutions(self): + result = normalize_final_answer("an apple") + assert result == "apple" + + def test_with_removed_expressions(self): + result = normalize_final_answer("5 degrees") + assert result == "5" + + def test_with_comma_numbers(self): + result = normalize_final_answer("1,234") + assert result == "1234" + + def test_with_textbf(self): + result = normalize_final_answer(r"\textbf{hello}") + assert result == "hello" + + def test_with_overline(self): + result = normalize_final_answer(r"\overline{AB}") + assert result == "AB" + + def test_with_boxed(self): + result = normalize_final_answer(r"\boxed{99}") + assert result == "99" + + def test_with_dollar_signs(self): + result = normalize_final_answer("$42$") + assert result == "42" + + def test_with_frac_shorthand(self): + result = normalize_final_answer(r"frac12") + assert result == r"frac{1}{2}" + + def test_with_sqrt_shorthand(self): + result = normalize_final_answer(r"sqrt3") + assert result == r"sqrt{3}" + + +class TestExtractPredByMinerva: + def test_with_valid_answer(self): + result = extract_pred_by_minerva("The answer is Answer: 42\nmore text") + assert result == "42" + + def test_with_no_match_returns_invalid(self): + result = extract_pred_by_minerva("No answer pattern here.") + assert result == "[INVALID]" + + def test_case_insensitive(self): + result = extract_pred_by_minerva("The answer is answer: 100\n") + assert result == "100" + + +class TestExtractPredByStrictBox: + def test_with_boxed_answer(self): + result = extract_pred_by_strict_box(r"some text \boxed{42} end") + assert result == "42" + + def test_without_boxed_returns_empty(self): + result = extract_pred_by_strict_box("no boxed expression here") + assert result == "" + + +class TestDapoMathPostprocessors: + def test_dapo_math_postprocess(self): + result = dapo_math_postprocess("blah Answer: 42\n") + assert result == "42" + + def test_dapo_math_postprocess_v2(self): + result = dapo_math_postprocess_v2(r"blah \boxed{42}") + assert result == "42" + + +class TestDAPOMathEvaluator: + def test_score_basic(self): + evaluator = DAPOMathEvaluator() + result = evaluator.score(["42", "100"], ["42", "200"]) + assert result["accuracy"] == 50.0 + assert len(result["details"]) == 2 + assert result["details"][0]["correct"] is True + assert result["details"][1]["correct"] is False + + def test_score_all_correct(self): + evaluator = DAPOMathEvaluator() + result = evaluator.score(["1", "2", "3"], ["1", "2", "3"]) + assert result["accuracy"] == 100.0 + assert all(d["correct"] for d in result["details"]) + + def test_score_all_wrong(self): + evaluator = DAPOMathEvaluator() + result = evaluator.score(["1", "2", "3"], ["4", "5", "6"]) + assert result["accuracy"] == 0.0 + assert not any(d["correct"] for d in result["details"]) + + def test_score_length_mismatch(self): + evaluator = DAPOMathEvaluator() + result = evaluator.score(["1", "2"], ["1"]) + assert "error" in result + assert "different length" in result["error"] + + def test_score_normalizes_references(self): + evaluator = DAPOMathEvaluator() + result = evaluator.score(["1234"], ["1,234"]) + assert result["accuracy"] == 100.0 + assert result["details"][0]["correct"] is True + + +class TestDAPOMathEvaluatorV2: + def test_score_basic(self): + evaluator = DAPOMathEvaluatorV2() + result = evaluator.score(["42", "100"], ["42", "200"]) + assert result["accuracy"] == 50.0 + assert len(result["details"]) == 2 + assert result["details"][0]["correct"] is True + assert result["details"][1]["correct"] is False + + def test_score_length_mismatch(self): + evaluator = DAPOMathEvaluatorV2() + result = evaluator.score(["1", "2"], ["1"]) + assert "error" in result + assert "different length" in result["error"] diff --git a/tests/UT/datasets/test_gpqa.py b/tests/UT/datasets/test_gpqa.py index 185fc2ad..14eaf59a 100644 --- a/tests/UT/datasets/test_gpqa.py +++ b/tests/UT/datasets/test_gpqa.py @@ -45,6 +45,55 @@ def test_evaluator_and_postprocess(self): self.assertIn("accuracy", out) self.assertEqual(GPQA_Simple_Eval_postprocess("Answer: B"), "B") + def test_evaluator_score_length_mismatch(self): + eva = GPQAEvaluator() + result = eva.score(['A', 'B'], ['A']) + self.assertIn('error', result) + + def test_evaluator_score_all_correct(self): + eva = GPQAEvaluator() + result = eva.score(['A', 'B', 'C'], ['A', 'B', 'C']) + self.assertEqual(result['accuracy'], 100.0) + self.assertEqual(len(result['details']), 3) + + def test_evaluator_score_all_wrong(self): + eva = GPQAEvaluator() + result = eva.score(['A', 'B'], ['C', 'D']) + self.assertEqual(result['accuracy'], 0.0) + + def test_evaluator_score_details_structure(self): + eva = GPQAEvaluator() + result = eva.score(['A'], ['B']) + detail = result['details'][0] + self.assertIn('pred', detail) + self.assertIn('answer', detail) + self.assertIn('correct', detail) + + def test_postprocess_no_match(self): + result = GPQA_Simple_Eval_postprocess("This has no answer pattern") + self.assertIsNone(result) + + def test_postprocess_case_insensitive(self): + result = GPQA_Simple_Eval_postprocess("answer: B") + self.assertEqual(result, 'B') + + @patch("ais_bench.benchmark.datasets.gpqa.get_data_path", return_value="/fake/path") + @patch("builtins.open") + def test_dataset_multiple_rows(self, mock_open_file, mock_get_path): + content = ( + "h0,h1,h2,h3,h4,h5,h6,Question,A,B,C,D\n" + ",,,,,,,Q1,opt1_a,opt1_b,opt1_c,opt1_d\n" + ",,,,,,,Q2,opt2_a,opt2_b,opt2_c,opt2_d\n" + ) + m = mock_open(read_data=content) + mock_open_file.return_value = m.return_value + ds = GPQADataset.load("/any", name="file.csv") + self.assertEqual(len(ds), 2) + self.assertEqual(ds[0]['answer'], 'D') + self.assertEqual(ds[0]['D'], 'opt1_a') + self.assertEqual(ds[1]['answer'], 'C') + self.assertEqual(ds[1]['C'], 'opt2_a') + if __name__ == "__main__": unittest.main() diff --git a/tests/UT/datasets/test_gsm8k.py b/tests/UT/datasets/test_gsm8k.py index 8fa9e5f3..bb760e22 100644 --- a/tests/UT/datasets/test_gsm8k.py +++ b/tests/UT/datasets/test_gsm8k.py @@ -122,6 +122,43 @@ def test_agent_evaluator_score(self): out3 = eva.score(predictions=['5'], references=[5], steps=steps3) self.assertIn('follow_acc', out3) + def test_get_action_found(self): + eva = Gsm8kAgentEvaluator.__new__(Gsm8kAgentEvaluator) + eva.action = 'PythonInterpreter' + steps = [{'type': 'PythonInterpreter', 'result': {'text': '5'}, 'errmsg': ''}] + result = eva.get_action(steps) + self.assertIsNotNone(result) + self.assertEqual(result['type'], 'PythonInterpreter') + + def test_get_action_not_found(self): + eva = Gsm8kAgentEvaluator.__new__(Gsm8kAgentEvaluator) + eva.action = 'PythonInterpreter' + steps = [{'type': 'Other', 'result': {'text': '5'}, 'errmsg': ''}] + result = eva.get_action(steps) + self.assertIsNone(result) + + def test_score_with_action_error(self): + eva = Gsm8kAgentEvaluator.__new__(Gsm8kAgentEvaluator) + eva.action = 'PythonInterpreter' + steps = [[ + {'type': 'PythonInterpreter', 'errmsg': 'error occurred', 'result': {'text': '5'}} + ]] + out = eva.score(predictions=['6'], references=[5], steps=steps) + self.assertIn('code_acc', out) + self.assertEqual(out['code_acc'], 0) + self.assertEqual(out['action_pct'], 100) + + def test_score_pred_correct_no_action(self): + eva = Gsm8kAgentEvaluator.__new__(Gsm8kAgentEvaluator) + eva.action = 'PythonInterpreter' + steps = [ + [{'type': 'Other'}], + [{'type': 'PythonInterpreter', 'errmsg': '', 'result': {'text': '7'}}] + ] + out = eva.score(predictions=['5', '6'], references=[5, 7], steps=steps) + self.assertIn('follow_acc', out) + self.assertEqual(out['follow_acc'], 50.0) + if __name__ == '__main__': unittest.main() diff --git a/tests/UT/datasets/test_math.py b/tests/UT/datasets/test_math.py index 4d599004..85d55092 100644 --- a/tests/UT/datasets/test_math.py +++ b/tests/UT/datasets/test_math.py @@ -87,6 +87,22 @@ def test_extract_boxed_answer_no_box(self): result = extract_boxed_answer(text) self.assertIsNone(result) + def test_last_boxed_only_string_multiple(self): + text = "First \\boxed{1} then \\boxed{2}" + result = last_boxed_only_string(text) + self.assertEqual(result, '\\boxed{2}') + + def test_extract_boxed_answer_fbox(self): + from ais_bench.benchmark.utils.logging.exceptions import AISBenchDataContentError + text = "The answer is \\fbox{99}" + with self.assertRaises(AISBenchDataContentError): + extract_boxed_answer(text) + + def test_extract_boxed_answer_no_double_brace(self): + text = "\\boxed{42}" + result = extract_boxed_answer(text, strip_double_curly_brace=False) + self.assertEqual(result, '42') + class TestNormalizeFinalAnswer(unittest.TestCase): """测试 normalize_final_answer""" @@ -141,6 +157,27 @@ def test_normalize_with_comma_digits(self): result = normalize_final_answer("100,000") self.assertEqual(result, '100000') + def test_normalize_with_le(self): + result = normalize_final_answer("x \\le 5") + self.assertNotIn('\\le', result) + self.assertIn('<', result) + + def test_normalize_textbf(self): + result = normalize_final_answer("\\textbf{hello}") + self.assertEqual(result, 'hello') + + def test_normalize_overline(self): + result = normalize_final_answer("\\overline{AB}") + self.assertEqual(result, 'AB') + + def test_normalize_sqrt_shorthand(self): + result = normalize_final_answer("sqrt3") + self.assertIn('sqrt{3}', result) + + def test_normalize_with_non_digit_comma(self): + result = normalize_final_answer("hello,world") + self.assertIn('hello', result) + class TestExtractAnswer(unittest.TestCase): """测试 extract_answer""" @@ -163,6 +200,16 @@ def test_extract_answer_not_found(self): result = extract_answer(text) self.assertEqual(result, '') + def test_extract_answer_with_extra_text(self): + text = "blah\nANSWER: 42\nmore stuff" + result = extract_answer(text) + self.assertEqual(result, '42') + + def test_extract_answer_multiple_patterns(self): + text = "first ANSWER: wrong\nsecond ANSWER: right" + result = extract_answer(text) + self.assertEqual(result, 'wrong') + class TestMATHDataset(unittest.TestCase): """测试 MATHDataset""" @@ -400,6 +447,128 @@ def test_strip_string_v2_with_trailing_zeros(self): result = evaluator._strip_string_v2('42.000') self.assertEqual(result, '42') + def test_fix_fracs_short_substring(self): + evaluator = MATHEvaluator() + result = evaluator._fix_fracs('\\frac1') + self.assertEqual(result, '\\frac1') + + def test_strip_string_empty(self): + evaluator = MATHEvaluator() + result = evaluator._strip_string('') + self.assertEqual(result, '') + + def test_strip_string_leading_dot(self): + evaluator = MATHEvaluator() + result = evaluator._strip_string('.5') + self.assertEqual(result, '\\frac{1}{2}') + + def test_strip_string_with_circ(self): + evaluator = MATHEvaluator() + result = evaluator._strip_string('90^{\\circ}') + self.assertNotIn('circ', result) + + def test_strip_string_with_percentage(self): + evaluator = MATHEvaluator() + result = evaluator._strip_string('50\\%') + self.assertNotIn('%', result) + + def test_strip_string_with_tfrac(self): + evaluator = MATHEvaluator() + result = evaluator._strip_string('\\tfrac{1}{2}') + self.assertNotIn('tfrac', result) + self.assertIn('frac', result) + + def test_strip_string_with_dfrac(self): + evaluator = MATHEvaluator() + result = evaluator._strip_string('\\dfrac{1}{2}') + self.assertNotIn('dfrac', result) + + def test_strip_string_with_left_right(self): + evaluator = MATHEvaluator() + result = evaluator._strip_string('\\left(42\\right)') + self.assertNotIn('\\left', result) + self.assertNotIn('\\right', result) + + def test_is_equiv_v2_same(self): + evaluator = MATHEvaluator(version='v2') + result = evaluator.is_equiv('42', '42') + self.assertTrue(result) + + def test_is_equiv_v2_different(self): + evaluator = MATHEvaluator(version='v2') + result = evaluator.is_equiv('42', '43') + self.assertFalse(result) + + def test_is_equiv_with_normalization(self): + evaluator = MATHEvaluator() + result = evaluator.is_equiv('100,000', '100000') + self.assertTrue(result) + + def test_score_details_structure(self): + evaluator = MATHEvaluator() + result = evaluator.score(['42'], ['42']) + detail = result['details'][0] + self.assertIn('pred', detail) + self.assertIn('answer', detail) + self.assertIn('correct', detail) + self.assertTrue(detail['correct']) + + def test_fix_a_slash_b_non_int(self): + evaluator = MATHEvaluator() + result = evaluator._fix_a_slash_b('a/b') + self.assertEqual(result, 'a/b') + + def test_fix_sqrt_no_sqrt(self): + evaluator = MATHEvaluator() + result = evaluator._fix_sqrt('42') + self.assertEqual(result, '42') + + def test_strip_string_v2_with_cdot(self): + evaluator = MATHEvaluator(version='v2') + result = evaluator._strip_string_v2('3\\cdot4') + self.assertNotIn('\\cdot', result) + + def test_strip_string_v2_with_mbox(self): + evaluator = MATHEvaluator(version='v2') + result = evaluator._strip_string_v2('\\mbox{text}') + self.assertNotIn('\\mbox', result) + + def test_strip_string_v2_with_inf(self): + evaluator = MATHEvaluator(version='v2') + result = evaluator._strip_string_v2('inf') + self.assertIn('\\infty', result) + + def test_strip_string_v2_with_mathbf(self): + evaluator = MATHEvaluator(version='v2') + result = evaluator._strip_string_v2('\\mathbf{x}') + self.assertNotIn('\\mathbf', result) + + def test_strip_string_v2_trailing_dot(self): + evaluator = MATHEvaluator(version='v2') + result = evaluator._strip_string_v2('42.') + self.assertEqual(result, '42') + + def test_strip_string_v2_with_equals_short(self): + evaluator = MATHEvaluator(version='v2') + result = evaluator._strip_string_v2('x=42') + self.assertEqual(result, '42') + + def test_strip_string_v2_leading_dot(self): + evaluator = MATHEvaluator(version='v2') + result = evaluator._strip_string_v2('.5') + self.assertTrue(result.startswith('0')) + + def test_strip_string_v2_empty(self): + evaluator = MATHEvaluator(version='v2') + result = evaluator._strip_string_v2('') + self.assertEqual(result, '') + + def test_remove_right_units_multiple_splits(self): + from ais_bench.benchmark.utils.logging.exceptions import AISBenchDataContentError + evaluator = MATHEvaluator() + with self.assertRaises(AISBenchDataContentError): + evaluator._remove_right_units('a\\text{ b\\text{ c') + class TestMATHAgentEvaluator(unittest.TestCase): """测试 MATHAgentEvaluator""" @@ -461,6 +630,61 @@ def test_score_basic(self): self.assertIn('code_acc', result) self.assertIn('action_pct', result) + def test_soft_equal_key_error(self): + evaluator = MATHAgentEvaluator() + step = {} + result = evaluator.soft_equal('pred', '42', step) + self.assertFalse(result) + + def test_soft_equal_type_error(self): + evaluator = MATHAgentEvaluator() + step = {'result': None} + result = evaluator.soft_equal('pred', '42', step) + self.assertFalse(result) + + def test_score_with_action_no_error(self): + evaluator = MATHAgentEvaluator() + predictions = ['42'] + references = ['43'] + steps = [ + [{'type': 'PythonInterpreter', 'result': {'text': '42'}, 'errmsg': None}] + ] + result = evaluator.score(predictions, references, steps) + self.assertEqual(result['code_acc'], 100) + self.assertEqual(result['action_pct'], 100) + + def test_score_pred_correct_with_action(self): + evaluator = MATHAgentEvaluator() + predictions = ['42'] + references = ['42'] + steps = [ + [{'type': 'PythonInterpreter', 'result': {'text': '42'}, 'errmsg': None}] + ] + result = evaluator.score(predictions, references, steps) + self.assertEqual(result['follow_acc'], 100) + self.assertEqual(result['reasoning_acc'], 100) + + def test_score_pred_correct_no_action(self): + evaluator = MATHAgentEvaluator() + predictions = ['42', '43'] + references = ['42', '44'] + steps = [ + [], + [{'type': 'PythonInterpreter', 'result': {'text': '43'}, 'errmsg': None}] + ] + result = evaluator.score(predictions, references, steps) + self.assertEqual(result['follow_acc'], 50.0) + + def test_get_action_returns_last_match(self): + evaluator = MATHAgentEvaluator() + steps = [ + {'type': 'PythonInterpreter', 'id': 1}, + {'type': 'Other'}, + {'type': 'PythonInterpreter', 'id': 2}, + ] + result = evaluator.get_action(steps) + self.assertEqual(result['id'], 2) + if __name__ == '__main__': unittest.main() diff --git a/tests/UT/datasets/test_realworldqa.py b/tests/UT/datasets/test_realworldqa.py index bc3c6043..c1da8d41 100644 --- a/tests/UT/datasets/test_realworldqa.py +++ b/tests/UT/datasets/test_realworldqa.py @@ -146,6 +146,38 @@ def test_open_prompt_format(self): self.assertIn("ANSWER: [ANSWER]", formatted) self.assertIn("step by step", formatted) + def test_extract_answer_with_colon_no_space(self): + evaluator = RealworldQAEvaluator() + result = evaluator._extract_answer("ANSWER:answer") + self.assertEqual(result, "answer") + + def test_extract_answer_with_multiple_colons(self): + evaluator = RealworldQAEvaluator() + result = evaluator._extract_answer("ANSWER: a: b") + self.assertEqual(result, "a: b") + + def test_normalize_answer_tabs(self): + evaluator = RealworldQAEvaluator() + self.assertEqual(evaluator._normalize_answer("\ttab\t"), "tab") + + def test_score_with_all_dict_references(self): + evaluator = RealworldQAEvaluator() + result = evaluator.score( + ["ANSWER: A", "ANSWER: B"], + [{"answer": "A"}, {"answer": "B"}], + ) + self.assertEqual(result["accuracy"], 100.0) + self.assertTrue(all(d["correct"] for d in result["details"])) + + def test_score_mixed_string_and_dict(self): + evaluator = RealworldQAEvaluator() + result = evaluator.score( + ["ANSWER: A", "ANSWER: B", "ANSWER: C"], + ["A", {"answer": "B"}, "C"], + ) + self.assertEqual(result["accuracy"], 100.0) + self.assertEqual(len(result["details"]), 3) + if __name__ == "__main__": unittest.main() \ No newline at end of file