Skip to content

Commit 5b90399

Browse files
committed
add math verify pool test
1 parent 274056f commit 5b90399

1 file changed

Lines changed: 200 additions & 0 deletions

File tree

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for math_verify_pool grading and score-assignment logic."""
16+
import pytest
17+
18+
pytest.importorskip("math_verify")
19+
pytest.importorskip("sympy")
20+
21+
import unittest
22+
from types import SimpleNamespace
23+
from unittest.mock import patch
24+
25+
import sympy
26+
27+
from maxtext.trainers.post_train.rl import math_verify_pool as mvp
28+
from maxtext.trainers.post_train.rl.math_verify_pool import (
29+
are_equal_under_sympy,
30+
math_verify_pool,
31+
verify_math_worker,
32+
)
33+
34+
35+
def _make_config(reward=1.0):
  """Build a bare config stub carrying a single `reward_exact_answer` field.

  Args:
    reward: Value stored on the stub as `reward_exact_answer`.

  Returns:
    A `SimpleNamespace` whose only attribute is `reward_exact_answer`.
  """
  config = SimpleNamespace()
  config.reward_exact_answer = reward
  return config
37+
38+
39+
class _FakeAsyncResult:
  """In-process substitute for `multiprocessing.pool.AsyncResult`.

  The callable is executed eagerly inside the constructor and its outcome
  (return value or raised exception) is stored. `ready()` therefore always
  reports True and `get()` simply replays the stored outcome, which lets
  `math_verify_pool`'s polling loop be exercised without any worker process.
  """

  def __init__(self, fn, args):
    """Run `fn(*args)` immediately and remember what happened."""
    self._value = None
    self._exc = None
    try:
      outcome = fn(*args)
    except Exception as err:  # pylint: disable=broad-except
      # Deferred: the exception is surfaced later by `get()`.
      self._exc = err
    else:
      self._value = outcome

  def ready(self):
    """Always True, because the work already ran in `__init__`."""
    return True

  def get(self, timeout=None):  # pylint: disable=unused-argument
    """Return the stored value, or re-raise the stored exception.

    `timeout` is accepted for signature compatibility and ignored.
    """
    if self._exc is None:
      return self._value
    raise self._exc
62+
63+
64+
class _FakePool:
  """Pool stand-in whose `apply_async` executes the call immediately."""

  def apply_async(self, fn, args):
    """Run `fn(*args)` in-process and hand back a `_FakeAsyncResult`."""
    finished = _FakeAsyncResult(fn, args)
    return finished
69+
70+
71+
def _fake_get_pool(num_procs):  # pylint: disable=unused-argument
  """Return a fresh `_FakePool`; `num_procs` is accepted but ignored."""
  pool = _FakePool()
  return pool
73+
74+
75+
class VerifyMathWorkerTest(unittest.TestCase):
  """Calls `verify_math_worker` in-process: no pool, no spawned workers."""

  def test_exact_numeric_match(self):
    golds, preds = ["\\boxed{42}"], ["\\boxed{42}"]
    returned_idx, grade = verify_math_worker(0, golds, preds)
    self.assertEqual(returned_idx, 0)
    self.assertEqual(grade, 1.0)

  def test_numeric_mismatch(self):
    golds, preds = ["\\boxed{42}"], ["\\boxed{99}"]
    _, grade = verify_math_worker(0, golds, preds)
    self.assertEqual(grade, 0.0)

  def test_multiple_golds_one_matches(self):
    golds = ["\\boxed{1}", "\\boxed{42}"]
    preds = ["\\boxed{42}"]
    _, grade = verify_math_worker(0, golds, preds)
    self.assertEqual(grade, 1.0)

  def test_multiple_golds_none_matches(self):
    golds = ["\\boxed{1}", "\\boxed{2}"]
    preds = ["\\boxed{99}"]
    _, grade = verify_math_worker(0, golds, preds)
    self.assertEqual(grade, 0.0)

  def test_idx_is_echoed_back(self):
    golds, preds = ["\\boxed{5}"], ["\\boxed{5}"]
    returned_idx, _ = verify_math_worker(17, golds, preds)
    self.assertEqual(returned_idx, 17)

  def test_empty_prediction_returns_zero(self):
    golds, preds = ["\\boxed{42}"], [""]
    _, grade = verify_math_worker(0, golds, preds)
    self.assertEqual(grade, 0.0)

  def test_fraction_equivalent_to_decimal(self):
    # 1/2 equals 0.5 numerically; even when the structural check in
    # are_equal_under_sympy misses it, verify() is expected to accept it.
    golds = ["\\boxed{\\frac{1}{2}}"]
    preds = ["\\boxed{0.5}"]
    _, grade = verify_math_worker(0, golds, preds)
    self.assertEqual(grade, 1.0)
114+
115+
116+
class AreEqualUnderSympyTest(unittest.TestCase):
  """Covers the structural sympy-AST equality helper.

  The worker calls `are_equal_under_sympy` before `verify()` and skips
  `verify()` whenever the helper returns True. The helper is meant to be a
  cheap structural comparison of unevaluated expressions.
  """

  def test_same_integer(self):
    left = sympy.Integer(42)
    right = sympy.Integer(42)
    self.assertTrue(are_equal_under_sympy(left, right))

  def test_different_integer(self):
    left = sympy.Integer(42)
    right = sympy.Integer(99)
    self.assertFalse(are_equal_under_sympy(left, right))

  def test_same_symbol(self):
    symbol = sympy.Symbol("x")
    self.assertTrue(are_equal_under_sympy(symbol, symbol))

  def test_malformed_input_does_not_raise(self):
    # Strings that cannot be parsed must yield False, not an exception.
    verdict = are_equal_under_sympy("$$$", "%%%")
    self.assertFalse(verdict)
141+
142+
143+
@patch.object(mvp, "_get_pool", _fake_get_pool)
class MathVerifyPoolScoreAssignmentTest(unittest.TestCase):
  """Guards against the score-assignment regression.

  An earlier version awarded `reward_exact_answer` to every job that
  completed, regardless of the grader's score. These tests pin the corrected
  behaviour.

  `_get_pool` is replaced with an in-process fake, so the polling loop
  finishes on its first pass without spawning workers or waiting on the 300s
  global_timeout.
  """

  def test_correct_answer_gets_reward(self):
    config = _make_config(1.0)
    jobs = [(0, ["\\boxed{42}"], ["\\boxed{42}"])]
    initial = [0.0]
    graded = math_verify_pool(config, jobs, initial)
    self.assertEqual(graded[0], 1.0)

  def test_wrong_answer_does_not_get_reward(self):
    config = _make_config(1.0)
    jobs = [(0, ["\\boxed{42}"], ["\\boxed{99}"])]
    initial = [0.0]
    graded = math_verify_pool(config, jobs, initial)
    self.assertEqual(graded[0], 0.0)

  def test_wrong_answer_preserves_prior_penalty(self):
    # scores[idx] arrives pre-seeded by `check_numbers` with
    # `penalty_incorrect_format`; a failing grade must leave it untouched
    # rather than replacing it with the reward.
    config = _make_config(1.0)
    jobs = [(0, ["\\boxed{42}"], ["\\boxed{99}"])]
    initial = [-0.5]
    graded = math_verify_pool(config, jobs, initial)
    self.assertEqual(graded[0], -0.5)

  def test_mixed_batch_scores_each_item_independently(self):
    config = _make_config(1.0)
    jobs = [
        (0, ["\\boxed{1}"], ["\\boxed{1}"]),  # correct
        (1, ["\\boxed{2}"], ["\\boxed{99}"]),  # wrong
        (2, ["\\boxed{3}"], ["\\boxed{3}"]),  # correct
    ]
    initial = [0.0, 0.0, 0.0]
    graded = math_verify_pool(config, jobs, initial)
    for position, expected in enumerate((1.0, 0.0, 1.0)):
      self.assertEqual(graded[position], expected)

  def test_reward_uses_max_not_overwrite(self):
    # When the existing score already exceeds the reward, a correct answer
    # must not pull it down.
    config = _make_config(0.3)
    jobs = [(0, ["\\boxed{42}"], ["\\boxed{42}"])]
    initial = [0.7]
    graded = math_verify_pool(config, jobs, initial)
    self.assertEqual(graded[0], 0.7)

  def test_empty_items_returns_scores_unchanged(self):
    config = _make_config(1.0)
    initial = [0.1, 0.2, 0.3]
    graded = math_verify_pool(config, [], initial)
    self.assertEqual(graded, [0.1, 0.2, 0.3])
197+
198+
199+
if __name__ == "__main__":
  # Allows running this file directly as a script, not only via a test runner.
  unittest.main()

0 commit comments

Comments
 (0)