Skip to content

Commit da0bda8

Browse files
committed
fix
1 parent 1facf72 commit da0bda8

1 file changed

Lines changed: 20 additions & 35 deletions

File tree

tests/unit/rl_utils_test.py

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -112,30 +112,30 @@ class TestNormalizeFinalAnswer(unittest.TestCase):
112112

113113
@pytest.mark.cpu_only
114114
def test_normalize_final_answer(self):
115-
"""Comma-separated numbers, \\boxed{}, and leading $ are all normalized to plain integers."""
115+
# Comma-separated numbers, \boxed{}, and a leading $ are all normalized to plain integer strings
116116
self.assertEqual(utils_rl.normalize_final_answer("1,000"), "1000")
117117
self.assertEqual(utils_rl.normalize_final_answer("$1,000"), "1000")
118118
self.assertEqual(utils_rl.normalize_final_answer("\\boxed{1,000}"), "1000")
119119

120-
"""Expressions with '=' are split on '='; trailing unit words are stripped."""
120+
# Expressions with '=' are split on '='; trailing unit words are stripped
121121
self.assertEqual(utils_rl.normalize_final_answer("x = 10"), "10")
122122
self.assertEqual(utils_rl.normalize_final_answer("total = 100 meters"), "100")
123123
self.assertEqual(utils_rl.normalize_final_answer("42 mph"), "42")
124124

125-
"""\\text{}, \\textbf{}, and \\overline{} wrappers are removed, leaving inner content."""
125+
# \text{}, \textbf{}, and \overline{} wrappers are removed, leaving inner content
126126
self.assertEqual(utils_rl.normalize_final_answer("\\text{hello}"), "hello")
127127
self.assertEqual(utils_rl.normalize_final_answer("\\textbf{42}"), "42")
128128
self.assertEqual(utils_rl.normalize_final_answer("\\overline{AB}"), "AB")
129129

130-
"""Content inside $...$ is extracted."""
130+
# Content inside $...$ is extracted
131131
self.assertEqual(utils_rl.normalize_final_answer("The answer is $\\frac{1}{2}$"), "\\frac{1}{2}")
132132

133-
"""Shorthand \\fracab and \\sqrta are expanded to their full LaTeX forms."""
133+
# Shorthand \fracab and \sqrta are expanded to their full LaTeX forms
134134
self.assertEqual(utils_rl.normalize_final_answer("\\fracab"), "\\frac{a}{b}")
135135
self.assertEqual(utils_rl.normalize_final_answer("\\sqrta"), "\\sqrt{a}")
136136

137137

138-
class TestMatchFormatApproximately(unittest.TestCase):
138+
class TestMatchFormatApproximatelyScores(unittest.TestCase):
139139
"""Tests for utils_rl.match_format_approximately.
140140
141141
Each tag that appears exactly once contributes reward_partial_format_match (0.5).
@@ -147,40 +147,25 @@ def setUp(self):
147147
self.config = _make_config()
148148

149149
def _score(self, completion):
150-
return utils_rl.match_format_approximately(None, [completion], self.config)
150+
return utils_rl.match_format_approximately(None, completion, self.config)
151151

152152
@pytest.mark.cpu_only
153-
def test_score_0_no_tags_present(self):
154-
"""No tags at all -> each of the 4 tags triggers penalty -> score = 4 * -0.5 = -2.0."""
155-
completion = "The answer is 42."
156-
self.assertEqual(self._score(completion)[0], -2.0)
157-
158-
@pytest.mark.cpu_only
159-
def test_score_1_duplicate_reasoning_start_tag(self):
160-
"""Duplicate <reasoning> tag -> that tag penalised; other three correct -> 3*0.5 + (-0.5) = 1.0."""
161-
completion = "<reasoning><reasoning>think</reasoning><answer>42</answer>"
162-
self.assertEqual(self._score(completion)[0], 1.0)
163-
164-
@pytest.mark.cpu_only
165-
def test_score_2_only_answer_tags_present(self):
166-
"""Only answer open/close tags present once -> 2 rewards + 2 penalties = 0.0."""
167-
completion = "<answer>42</answer>"
168-
self.assertEqual(self._score(completion)[0], 0.0)
169-
170-
@pytest.mark.cpu_only
171-
def test_score_4_all_tags_present_exactly_once(self):
172-
"""All four tags appear exactly once -> score = 4 * 0.5 = 2.0."""
173-
completion = "<reasoning>think</reasoning><answer>42</answer>"
174-
self.assertEqual(self._score(completion)[0], 2.0)
175-
176-
@pytest.mark.cpu_only
177-
def test_score_multiple_completions(self):
178-
"""Passing multiple completions returns one score per completion."""
179-
completions = [
153+
def test_partial_format_scores(self):
154+
"""Scores cover the full range depending on how many tags appear exactly once."""
155+
# All four tags present exactly once -> 4 * 0.5 = 2.0
156+
self.assertEqual(self._score(["<reasoning>think</reasoning><answer>42</answer>"])[0], 2.0)
157+
# No tags at all -> 4 * -0.5 = -2.0
158+
self.assertEqual(self._score(["The answer is 42."])[0], -2.0)
159+
# Only <answer>...</answer> present -> 2 * 0.5 + 2 * -0.5 = 0.0
160+
self.assertEqual(self._score(["<answer>42</answer>"])[0], 0.0)
161+
# Duplicate <reasoning> tag -> 3 * 0.5 + 1 * -0.5 = 1.0
162+
self.assertEqual(self._score(["<reasoning><reasoning>think</reasoning><answer>42</answer>"])[0], 1.0)
163+
# Multiple completions at once -> one score per entry
164+
multi_completions = [
180165
"<reasoning>think</reasoning><answer>42</answer>", # 2.0
181166
"no tags here", # -2.0
182167
]
183-
scores = utils_rl.match_format_approximately(None, completions, self.config)
168+
scores = self._score(multi_completions)
184169
self.assertEqual(len(scores), 2)
185170
self.assertEqual(scores[0], 2.0)
186171
self.assertEqual(scores[1], -2.0)

0 commit comments

Comments
 (0)