Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions tests/unit/rl_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,89 @@ def test_with_incomplete_reasoning_tags(self):
self.assertFalse(has_correct_format)


class TestNormalizeFinalAnswer(unittest.TestCase):
    """Exercises utils_rl.normalize_final_answer on representative answer strings."""

    def _assert_normalized(self, expectations):
        """Check each (raw, expected) pair in order, stopping at the first mismatch."""
        for raw, expected in expectations:
            self.assertEqual(utils_rl.normalize_final_answer(raw), expected)

    @pytest.mark.cpu_only
    def test_comma_boxed_and_currency(self):
        # Thousands separators, a \\boxed{} wrapper and a leading $ all reduce to the bare integer.
        self._assert_normalized(
            [
                ("1,000", "1000"),
                ("$1,000", "1000"),
                ("\\boxed{1,000}", "1000"),
            ]
        )

    @pytest.mark.cpu_only
    def test_equation_splitting_and_unit_removal(self):
        # For inputs containing '=', only the right-hand side survives; trailing unit words are dropped.
        self._assert_normalized(
            [
                ("x = 10", "10"),
                ("total = 100 meters", "100"),
                ("42 mph", "42"),
            ]
        )

    @pytest.mark.cpu_only
    def test_latex_wrappers(self):
        # \\text{}, \\textbf{} and \\overline{} are peeled off, keeping what they wrap.
        self._assert_normalized(
            [
                ("\\text{hello}", "hello"),
                ("\\textbf{42}", "42"),
                ("\\overline{AB}", "AB"),
            ]
        )

    @pytest.mark.cpu_only
    def test_dollar_math_extraction(self):
        # Only the content between the $...$ delimiters is kept.
        sentence = "The answer is $\\frac{1}{2}$"
        self._assert_normalized([(sentence, "\\frac{1}{2}")])

    @pytest.mark.cpu_only
    def test_shorthand_frac_and_sqrt(self):
        # Brace-less \\fracab and \\sqrta are rewritten into fully braced LaTeX.
        self._assert_normalized(
            [
                ("\\fracab", "\\frac{a}{b}"),
                ("\\sqrta", "\\sqrt{a}"),
            ]
        )
Comment thread
hengtaoguo marked this conversation as resolved.


class TestMatchFormatApproximatelyScores(unittest.TestCase):
    """Score checks for utils_rl.match_format_approximately.

    Scoring model these tests rely on:
      * a tag occurring exactly once earns reward_partial_format_match (0.5);
      * a tag that is missing or repeated earns penalty_incorrect_format (-0.5).

    Four tags are scored, so totals span -2.0 (every tag wrong) to 2.0 (every tag right).
    """

    def setUp(self):
        # Fresh config for every test so the cases stay independent.
        self.config = _make_config()

    def _score(self, completion):
        """Return the list of scores for the given list of completions (prompts are passed as None)."""
        scores = utils_rl.match_format_approximately(None, completion, self.config)
        return scores

    @pytest.mark.cpu_only
    def test_score_all_tags_present_exactly_once(self):
        # Every tag appears once: 4 * 0.5 = 2.0
        well_formed = "<reasoning>think</reasoning><answer>42</answer>"
        result = self._score([well_formed])
        self.assertEqual(result[0], 2.0)

    @pytest.mark.cpu_only
    def test_score_no_tags_present(self):
        # Nothing tagged: 4 * -0.5 = -2.0
        untagged = "The answer is 42."
        result = self._score([untagged])
        self.assertEqual(result[0], -2.0)

    @pytest.mark.cpu_only
    def test_score_only_answer_tags_present(self):
        # Answer pair present, reasoning pair missing: 2 * 0.5 + 2 * -0.5 = 0.0
        answer_only = "<answer>42</answer>"
        result = self._score([answer_only])
        self.assertEqual(result[0], 0.0)

    @pytest.mark.cpu_only
    def test_score_duplicate_reasoning_start_tag(self):
        # Opening <reasoning> repeated, the other three fine: 3 * 0.5 + 1 * -0.5 = 1.0
        doubled_open = "<reasoning><reasoning>think</reasoning><answer>42</answer>"
        result = self._score([doubled_open])
        self.assertEqual(result[0], 1.0)

    @pytest.mark.cpu_only
    def test_score_multiple_completions(self):
        # A batch of completions yields one score per entry, in input order.
        batch = [
            "<reasoning>think</reasoning><answer>42</answer>",  # 2.0
            "no tags here",  # -2.0
        ]
        expected_scores = (2.0, -2.0)
        scores = self._score(batch)
        self.assertEqual(len(scores), 2)
        for position, expected in enumerate(expected_scores):
            self.assertEqual(scores[position], expected)


class TestExtractHashAnswer(unittest.TestCase):
"""Tests for utils_rl.extract_hash_answer."""

Expand Down
Loading