Skip to content

Commit 1facf72

Browse files
committed
Add RL unit tests: normalize and match_approximately
1 parent e3dbd54 commit 1facf72

1 file changed

Lines changed: 79 additions & 0 deletions

File tree

tests/unit/rl_utils_test.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,85 @@ def test_with_incomplete_reasoning_tags(self):
107107
self.assertFalse(has_correct_format)
108108

109109

110+
class TestNormalizeFinalAnswer(unittest.TestCase):
    """Tests for utils_rl.normalize_final_answer."""

    @pytest.mark.cpu_only
    def test_normalize_final_answer(self):
        """Check normalize_final_answer against a table of (raw, expected) pairs.

        Each pair runs inside its own subTest so that, where the runner
        supports sub-tests, one failing input is reported with its raw value
        and does not hide the remaining cases.
        """
        cases = [
            # Comma-separated numbers, \boxed{}, and a leading $ are all
            # normalized to plain integers.
            ("1,000", "1000"),
            ("$1,000", "1000"),
            ("\\boxed{1,000}", "1000"),
            # Expressions with '=' are split on '='; trailing unit words are
            # stripped.
            ("x = 10", "10"),
            ("total = 100 meters", "100"),
            ("42 mph", "42"),
            # \text{}, \textbf{}, and \overline{} wrappers are removed,
            # leaving the inner content.
            ("\\text{hello}", "hello"),
            ("\\textbf{42}", "42"),
            ("\\overline{AB}", "AB"),
            # Content inside $...$ is extracted.
            ("The answer is $\\frac{1}{2}$", "\\frac{1}{2}"),
            # Shorthand \fracab and \sqrta are expanded to their full LaTeX
            # forms.
            ("\\fracab", "\\frac{a}{b}"),
            ("\\sqrta", "\\sqrt{a}"),
        ]
        for raw, expected in cases:
            with self.subTest(raw=raw):
                self.assertEqual(utils_rl.normalize_final_answer(raw), expected)
136+
137+
138+
class TestMatchFormatApproximately(unittest.TestCase):
    """Tests for utils_rl.match_format_approximately.

    Scoring rule these tests exercise, applied per tag:
      * tag appears exactly once         -> reward_partial_format_match (0.5)
      * tag is absent or appears > once  -> penalty_incorrect_format (-0.5)
    Four tags are checked, so a total lies between -2.0 (all wrong) and
    2.0 (all correct).
    """

    def setUp(self):
        self.config = _make_config()

    def _score(self, completion):
        """Score one completion and return the scorer's list of results."""
        batch = [completion]
        return utils_rl.match_format_approximately(None, batch, self.config)

    @pytest.mark.cpu_only
    def test_score_0_no_tags_present(self):
        """Plain text without any tag: four penalties, 4 * -0.5 = -2.0."""
        result = self._score("The answer is 42.")
        self.assertEqual(result[0], -2.0)

    @pytest.mark.cpu_only
    def test_score_1_duplicate_reasoning_start_tag(self):
        """<reasoning> opened twice: one penalty, three rewards, 3*0.5 + (-0.5) = 1.0."""
        result = self._score(
            "<reasoning><reasoning>think</reasoning><answer>42</answer>"
        )
        self.assertEqual(result[0], 1.0)

    @pytest.mark.cpu_only
    def test_score_2_only_answer_tags_present(self):
        """Just the answer open/close pair: two rewards and two penalties cancel to 0.0."""
        result = self._score("<answer>42</answer>")
        self.assertEqual(result[0], 0.0)

    @pytest.mark.cpu_only
    def test_score_4_all_tags_present_exactly_once(self):
        """Every tag present exactly once: four rewards, 4 * 0.5 = 2.0."""
        result = self._score("<reasoning>think</reasoning><answer>42</answer>")
        self.assertEqual(result[0], 2.0)

    @pytest.mark.cpu_only
    def test_score_multiple_completions(self):
        """A batch of completions yields one score per completion, in order."""
        well_formed = "<reasoning>think</reasoning><answer>42</answer>"  # 2.0
        untagged = "no tags here"  # -2.0
        scores = utils_rl.match_format_approximately(
            None, [well_formed, untagged], self.config
        )
        self.assertEqual(len(scores), 2)
        self.assertEqual(scores[0], 2.0)
        self.assertEqual(scores[1], -2.0)
187+
188+
110189
class TestExtractHashAnswer(unittest.TestCase):
111190
"""Tests for utils_rl.extract_hash_answer."""
112191

0 commit comments

Comments
 (0)