diff --git a/tests/unit/rl_utils_test.py b/tests/unit/rl_utils_test.py
index 027af794cf..452e5f474a 100644
--- a/tests/unit/rl_utils_test.py
+++ b/tests/unit/rl_utils_test.py
@@ -107,6 +107,89 @@ def test_with_incomplete_reasoning_tags(self):
self.assertFalse(has_correct_format)
+class TestNormalizeFinalAnswer(unittest.TestCase):
+ """Tests for utils_rl.normalize_final_answer."""
+
+ @pytest.mark.cpu_only
+ def test_comma_boxed_and_currency(self):
+ # Comma-separated numbers, \\boxed{}, and leading $ are all normalized to plain integers
+ self.assertEqual(utils_rl.normalize_final_answer("1,000"), "1000")
+ self.assertEqual(utils_rl.normalize_final_answer("$1,000"), "1000")
+ self.assertEqual(utils_rl.normalize_final_answer("\\boxed{1,000}"), "1000")
+
+ @pytest.mark.cpu_only
+ def test_equation_splitting_and_unit_removal(self):
+ # Expressions with '=' are split on '='; trailing unit words are stripped
+ self.assertEqual(utils_rl.normalize_final_answer("x = 10"), "10")
+ self.assertEqual(utils_rl.normalize_final_answer("total = 100 meters"), "100")
+ self.assertEqual(utils_rl.normalize_final_answer("42 mph"), "42")
+
+ @pytest.mark.cpu_only
+ def test_latex_wrappers(self):
+ # \\text{}, \\textbf{}, and \\overline{} wrappers are removed, leaving inner content
+ self.assertEqual(utils_rl.normalize_final_answer("\\text{hello}"), "hello")
+ self.assertEqual(utils_rl.normalize_final_answer("\\textbf{42}"), "42")
+ self.assertEqual(utils_rl.normalize_final_answer("\\overline{AB}"), "AB")
+
+ @pytest.mark.cpu_only
+ def test_dollar_math_extraction(self):
+ # Content inside $...$ is extracted
+ self.assertEqual(utils_rl.normalize_final_answer("The answer is $\\frac{1}{2}$"), "\\frac{1}{2}")
+
+ @pytest.mark.cpu_only
+ def test_shorthand_frac_and_sqrt(self):
+ # Shorthand \\fracab and \\sqrta are expanded to their full LaTeX forms
+ self.assertEqual(utils_rl.normalize_final_answer("\\fracab"), "\\frac{a}{b}")
+ self.assertEqual(utils_rl.normalize_final_answer("\\sqrta"), "\\sqrt{a}")
+
+
+class TestMatchFormatApproximatelyScores(unittest.TestCase):
+ """Tests for utils_rl.match_format_approximately.
+
+ Each tag that appears exactly once contributes reward_partial_format_match (0.5).
+ Each tag that is absent or appears more than once contributes penalty_incorrect_format (-0.5).
+ With 4 tags the score ranges from -2.0 (all wrong) to 2.0 (all correct).
+ """
+
+ def setUp(self):
+ self.config = _make_config()
+
+ def _score(self, completion):
+ return utils_rl.match_format_approximately(None, completion, self.config)
+
+ @pytest.mark.cpu_only
+ def test_score_all_tags_present_exactly_once(self):
+ # All four tags present exactly once -> 4 * 0.5 = 2.0
+ self.assertEqual(self._score(["think42"])[0], 2.0)
+
+ @pytest.mark.cpu_only
+ def test_score_no_tags_present(self):
+ # No tags at all -> 4 * -0.5 = -2.0
+ self.assertEqual(self._score(["The answer is 42."])[0], -2.0)
+
+ @pytest.mark.cpu_only
+ def test_score_only_answer_tags_present(self):
+ # Only ... present -> 2 * 0.5 + 2 * -0.5 = 0.0
+ self.assertEqual(self._score(["42"])[0], 0.0)
+
+ @pytest.mark.cpu_only
+ def test_score_duplicate_reasoning_start_tag(self):
+ # Duplicate tag -> 3 * 0.5 + 1 * -0.5 = 1.0
+ self.assertEqual(self._score(["think42"])[0], 1.0)
+
+ @pytest.mark.cpu_only
+ def test_score_multiple_completions(self):
+ # Multiple completions at once -> one score per entry
+ multi_completions = [
+ "think42", # 2.0
+ "no tags here", # -2.0
+ ]
+ scores = self._score(multi_completions)
+ self.assertEqual(len(scores), 2)
+ self.assertEqual(scores[0], 2.0)
+ self.assertEqual(scores[1], -2.0)
+
+
class TestExtractHashAnswer(unittest.TestCase):
"""Tests for utils_rl.extract_hash_answer."""