From 241383ecb79b1bc8ee611b414a5c7aa2889203ce Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Wed, 11 Mar 2026 18:06:43 +0000 Subject: [PATCH] Add RL unit tests: normalize and match_approximately --- tests/unit/rl_utils_test.py | 83 +++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tests/unit/rl_utils_test.py b/tests/unit/rl_utils_test.py index 027af794cf..452e5f474a 100644 --- a/tests/unit/rl_utils_test.py +++ b/tests/unit/rl_utils_test.py @@ -107,6 +107,89 @@ def test_with_incomplete_reasoning_tags(self): self.assertFalse(has_correct_format) +class TestNormalizeFinalAnswer(unittest.TestCase): + """Tests for utils_rl.normalize_final_answer.""" + + @pytest.mark.cpu_only + def test_comma_boxed_and_currency(self): + # Comma-separated numbers, \\boxed{}, and leading $ are all normalized to plain integers + self.assertEqual(utils_rl.normalize_final_answer("1,000"), "1000") + self.assertEqual(utils_rl.normalize_final_answer("$1,000"), "1000") + self.assertEqual(utils_rl.normalize_final_answer("\\boxed{1,000}"), "1000") + + @pytest.mark.cpu_only + def test_equation_splitting_and_unit_removal(self): + # Expressions with '=' are split on '='; trailing unit words are stripped + self.assertEqual(utils_rl.normalize_final_answer("x = 10"), "10") + self.assertEqual(utils_rl.normalize_final_answer("total = 100 meters"), "100") + self.assertEqual(utils_rl.normalize_final_answer("42 mph"), "42") + + @pytest.mark.cpu_only + def test_latex_wrappers(self): + # \\text{}, \\textbf{}, and \\overline{} wrappers are removed, leaving inner content + self.assertEqual(utils_rl.normalize_final_answer("\\text{hello}"), "hello") + self.assertEqual(utils_rl.normalize_final_answer("\\textbf{42}"), "42") + self.assertEqual(utils_rl.normalize_final_answer("\\overline{AB}"), "AB") + + @pytest.mark.cpu_only + def test_dollar_math_extraction(self): + # Content inside $...$ is extracted + self.assertEqual(utils_rl.normalize_final_answer("The answer is $\\frac{1}{2}$"), "\\frac{1}{2}") + + @pytest.mark.cpu_only + def test_shorthand_frac_and_sqrt(self): + # Shorthand \\fracab and \\sqrta are expanded to their full LaTeX forms + self.assertEqual(utils_rl.normalize_final_answer("\\fracab"), "\\frac{a}{b}") + self.assertEqual(utils_rl.normalize_final_answer("\\sqrta"), "\\sqrt{a}") + + +class TestMatchFormatApproximatelyScores(unittest.TestCase): + """Tests for utils_rl.match_format_approximately. + + Each tag that appears exactly once contributes reward_partial_format_match (0.5). + Each tag that is absent or appears more than once contributes penalty_incorrect_format (-0.5). + With 4 tags the score ranges from -2.0 (all wrong) to 2.0 (all correct). + """ + + def setUp(self): + self.config = _make_config() + + def _score(self, completion): + return utils_rl.match_format_approximately(None, completion, self.config) + + @pytest.mark.cpu_only + def test_score_all_tags_present_exactly_once(self): + # All four tags present exactly once -> 4 * 0.5 = 2.0 + self.assertEqual(self._score(["think42"])[0], 2.0) + + @pytest.mark.cpu_only + def test_score_no_tags_present(self): + # No tags at all -> 4 * -0.5 = -2.0 + self.assertEqual(self._score(["The answer is 42."])[0], -2.0) + + @pytest.mark.cpu_only + def test_score_only_answer_tags_present(self): + # Only ... present -> 2 * 0.5 + 2 * -0.5 = 0.0 + self.assertEqual(self._score(["42"])[0], 0.0) + + @pytest.mark.cpu_only + def test_score_duplicate_reasoning_start_tag(self): + # Duplicate tag -> 3 * 0.5 + 1 * -0.5 = 1.0 + self.assertEqual(self._score(["think42"])[0], 1.0) + + @pytest.mark.cpu_only + def test_score_multiple_completions(self): + # Multiple completions at once -> one score per entry + multi_completions = [ + "think42", # 2.0 + "no tags here", # -2.0 + ] + scores = self._score(multi_completions) + self.assertEqual(len(scores), 2) + self.assertEqual(scores[0], 2.0) + self.assertEqual(scores[1], -2.0) + + class TestExtractHashAnswer(unittest.TestCase): """Tests for utils_rl.extract_hash_answer."""