Skip to content

Commit 0f10655

Browse files
SurbhiJainUSC authored and A9isha committed
Add open-r1/OpenR1-Math-220k dataset and nvidia/OpenMathReasoning to RL and fix reward function
Co-authored-by: A9isha <mazumdera@google.com>
1 parent 5182e3b commit 0f10655

11 files changed

Lines changed: 1220 additions & 394 deletions

File tree

.github/workflows/build_and_test_maxtext.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ jobs:
199199
base_image: maxtext-unit-test-tpu:py312
200200
cloud_runner: linux-x86-ct6e-180-4tpu
201201
pytest_marker: 'not cpu_only and not gpu_only and not integration_test and not post_training'
202+
pytest_addopts: '--ignore=tests/post_training'
202203
xla_python_client_mem_fraction: 0.75
203204
tf_force_gpu_allow_growth: false
204205
container_resource_option: "--privileged"
@@ -217,6 +218,7 @@ jobs:
217218
base_image: maxtext-unit-test-tpu:py312
218219
cloud_runner: linux-x86-ct6e-180-4tpu
219220
pytest_marker: 'not cpu_only and not gpu_only and integration_test and not post_training'
221+
pytest_addopts: '--ignore=tests/post_training'
220222
xla_python_client_mem_fraction: 0.75
221223
tf_force_gpu_allow_growth: false
222224
container_resource_option: "--privileged"

src/maxtext/configs/post_train/rl.yml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ generation_configs:
148148
num_eval_passes: 1 # Number of generation passes during evaluation
149149
eval_corr_lst: False # If True, only include correct responses in the list during evaluation
150150
eval_make_lst: False # If True, return a list of (question, answer, responses) during evaluation
151+
eval_mode: "pass" # Evaluation mode ("pass" for pass@K, "maj" for majority voting maj@K, "pass_at_1" for pass@1 estimation)
151152

152153
# ====== Inference ======
153154
# === Generation during GRPO training ===
@@ -190,6 +191,12 @@ reward_ratio_guess_to_answer_low: 0.0
190191
penalty_incorrect_format: 0.0
191192
penalty_incorrect_answer: 0.0
192193

194+
# ====== Configuration for math_verify Pool ======
195+
# Timeout (seconds) for math_verify
196+
math_verify_timeout: 300
197+
# Max worker processes for the math_verify pool. null ⇒ min(batch_size, cpu_count())
198+
math_verify_num_procs: null
199+
193200
# ====== Special tokens/templates for GSM8K reasoning ======
194201
reasoning_start_token: '<reasoning>'
195202
reasoning_end_token: '</reasoning>'
@@ -200,8 +207,8 @@ skip_jax_distributed_system: True
200207

201208
# # TODO(@mazumdera): fix this
202209
# Dataset Configuration
203-
dataset_name: 'gsm8k' # huggingface:open-r1/DAPO-Math-17k-Processed
204-
eval_dataset_name: 'gsm8k' # huggingface:BytedTsinghua-SIA/AIME-2024
210+
dataset_name: 'gsm8k' # open-r1/DAPO-Math-17k-Processed
211+
eval_dataset_name: 'gsm8k' # BytedTsinghua-SIA/AIME-2024
205212
train_split: 'train'
206213
eval_split: 'test'
207214
tokenizer_type: 'huggingface'

src/maxtext/configs/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1817,6 +1817,10 @@ class RLEvaluation(BaseModel):
18171817
False,
18181818
description="If True, return a list of (question, answer, responses) during evaluation.",
18191819
)
1820+
eval_mode: Literal["pass", "maj", "pass_at_1"] = Field(
1821+
"pass",
1822+
description="Evaluation mode to use ('pass' for pass@K, 'maj' for maj@K, 'pass_at_1' for pass@1 estimation).",
1823+
)
18201824

18211825

18221826
class Reward(BaseModel):
@@ -1834,6 +1838,11 @@ class Reward(BaseModel):
18341838
)
18351839
penalty_incorrect_format: float = Field(-0.5, description="Penalty for an incorrect format.")
18361840
penalty_incorrect_answer: float = Field(-1.0, description="Penalty for an incorrect answer.")
1841+
math_verify_timeout: int = Field(300, description="Timeout (seconds) for math_verify call per batch.")
1842+
math_verify_num_procs: int | None = Field(
1843+
None,
1844+
description=("Max worker processes for the math_verify pool. None ⇒ " "min(batch_size, cpu_count())."),
1845+
)
18371846

18381847

18391848
class SpecialTokens(BaseModel):

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 109 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
"""
1717
RL Evaluation Module.
1818
"""
19-
from math_verify import parse
19+
import collections
20+
import json
21+
2022
from tqdm.auto import tqdm
2123
from tunix.rl.rollout.base_rollout import RolloutConfig
2224

@@ -86,85 +88,128 @@ def generate_responses(
8688
return multiple_call_responses
8789

8890

89-
def score_responses(tmvp_config, question, responses, answer):
91+
def score_responses(tmvp_config, question, responses, answers):
9092
"""
9193
Score a set of responses for a single question.
9294
9395
Args:
9496
tmvp_config: Configuration object
9597
question: The evaluation question
9698
responses: List of generated responses for this question
97-
answer: The correct answer
99+
answers: List of acceptable answers for this question
98100
99101
Returns:
100102
Tuple of (is_correct, is_partially_correct, has_correct_format)
101103
"""
102-
match_format = utils_rl.get_match_format_regex(tmvp_config)
103-
answer_fallback = utils_rl.get_answer_fallback_regex(tmvp_config)
104-
105104
if tmvp_config.debug.rl:
106105
max_logging.log("========================================")
107106
max_logging.log(f"Evaluation Question: {question}")
108-
max_logging.log(f"Evaluation Answer: {answer}")
107+
max_logging.log(f"Evaluation Answer: {answers}")
109108
max_logging.log(f"Evaluation Responses: {responses}")
110109
max_logging.log("========================================")
111110

112-
is_correct = False
113-
is_partially_correct = False
114-
has_correct_format = False
111+
eval_mode = getattr(tmvp_config, "eval_mode", "pass")
112+
match_format = utils_rl.get_match_format_regex(tmvp_config)
115113

114+
extracted_responses = []
116115
for response in responses:
117-
# Extract answer: prefer the full format match; fall back to the last
118-
# <answer>...</answer> tag if full format match is not found, so result
119-
# scoring is decoupled from format.
120-
full_match = match_format.search(response)
121-
if full_match is not None:
122-
extracted_response = full_match.group(1)
123-
else:
124-
# Find the *last* occurrence of the answer tag (most likely the final answer).
125-
fallback_matches = answer_fallback.findall(response)
126-
extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000"
116+
extracted_response = utils_rl.extract_answer(response, tmvp_config)
117+
extracted_responses.append(extracted_response)
118+
119+
if not extracted_responses:
120+
return False, False, False
121+
122+
if eval_mode == "maj":
123+
# extract the single-most frequent response
124+
counter = collections.Counter(extracted_responses)
125+
majority_response = counter.most_common(1)[0][0]
127126
if tmvp_config.debug.rl:
128-
used = "full format" if full_match is not None else "answer-tag fallback"
129-
max_logging.log(f"Evaluation extracted_response ({used}): {extracted_response}")
127+
max_logging.log(f"Majority Response: {majority_response} (Count: {counter[majority_response]})")
128+
129+
# Check the format for the majority_response
130+
has_correct_format = False
131+
for idx, extracted_response in enumerate(extracted_responses):
132+
if extracted_response == majority_response:
133+
if match_format.search(responses[idx]) is not None:
134+
has_correct_format = True
135+
break
130136

131-
# Check exact correctness
132137
try:
133-
# Fix LaTeX escaping issues for both ground truth and extracted answer
134-
norm_answer = utils_rl.fix_latex_escaping(answer)
135-
norm_extracted = utils_rl.fix_latex_escaping(extracted_response)
136-
# Normalize Normalize for certain datasets and parse
137-
if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name:
138-
norm_extracted = utils_rl.normalize_final_answer(norm_extracted).strip()
139-
norm_answer = utils_rl.normalize_final_answer(answer).strip()
140-
is_correct = utils_rl.math_verify_func([utils_rl.boxed(norm_answer)], [utils_rl.boxed(norm_extracted)])[0] > 0.1
138+
is_correct, is_partially_correct = utils_rl.check_correctness(majority_response, answers, tmvp_config)
141139
if tmvp_config.debug.rl:
142-
# is_correct is a tuple, if first value is 1.0 means it's a match;
143-
# 0.0 means a mismatch. e.g. (0.0, (['3', '3'], ['3/5', '\\frac{3}{5}']))
140+
max_logging.log(f"Result has_correct_format: {has_correct_format}")
144141
max_logging.log(f"Result is_correct: {is_correct}")
145-
146-
val_extracted = parse(utils_rl.boxed(norm_extracted))
147-
val_answer = parse(utils_rl.boxed(norm_answer))
148-
149-
# Check partial correctness if values can be extracted (within 10%)
150-
if val_extracted and val_answer:
151-
ratio = (val_extracted[0] + utils_rl.EPSILON) / (val_answer[0] + utils_rl.EPSILON)
152-
is_partially_correct = 0.9 <= ratio <= 1.1
153-
142+
max_logging.log(f"Result is_partially_correct: {is_partially_correct}")
154143
except Exception as e:
144+
is_correct, is_partially_correct = False, False
155145
if tmvp_config.debug.rl:
156-
max_logging.log(f"Evaluation Exception: {e}")
146+
max_logging.log(f"Evaluation Exception on majority answer: {e}")
157147
max_logging.log("SKIPPED")
148+
return is_correct, is_partially_correct, has_correct_format
158149

159-
# Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure)
160-
if full_match is not None:
161-
has_correct_format = True
150+
if eval_mode == "pass":
151+
for idx, response in enumerate(responses):
152+
is_correct, is_partially_correct, has_correct_format = False, False, False
153+
if match_format.search(response) is not None:
154+
has_correct_format = True
155+
156+
# Check exact and partial correctness (within 10%)
157+
try:
158+
is_correct, is_partially_correct = utils_rl.check_correctness(extracted_responses[idx], answers, tmvp_config)
159+
if tmvp_config.debug.rl:
160+
max_logging.log(f"Result is_correct: {is_correct}")
161+
max_logging.log(f"Result is_partially_correct: {is_partially_correct}")
162+
except Exception as e:
163+
if tmvp_config.debug.rl:
164+
max_logging.log(f"Evaluation Exception: {e}")
165+
max_logging.log("SKIPPED")
166+
167+
# Early exit if all criteria are met
168+
if is_correct and is_partially_correct and has_correct_format:
169+
return is_correct, is_partially_correct, has_correct_format
170+
return is_correct, is_partially_correct, has_correct_format
171+
172+
if eval_mode == "pass_at_1":
173+
# Estimate pass@1: fraction of N samples that are correct per problem.
174+
# Returns floats in [0, 1] instead of booleans.
175+
n_samples = len(responses)
176+
n_correct = 0
177+
n_partially_correct = 0
178+
n_correct_format = 0
179+
180+
for idx, response in enumerate(responses):
181+
if match_format.search(response) is not None:
182+
n_correct_format += 1
183+
184+
try:
185+
sample_correct, sample_partial = utils_rl.check_correctness(extracted_responses[idx], answers, tmvp_config)
186+
if sample_correct:
187+
n_correct += 1
188+
if sample_partial:
189+
n_partially_correct += 1
190+
if tmvp_config.debug.rl:
191+
max_logging.log(f"Sample {idx}: correct={sample_correct}, partial={sample_partial}")
192+
except Exception as e:
193+
if tmvp_config.debug.rl:
194+
max_logging.log(f"Evaluation Exception on sample {idx}: {e}")
195+
max_logging.log("SKIPPED")
196+
197+
frac_correct = n_correct / n_samples
198+
frac_partially_correct = n_partially_correct / n_samples
199+
frac_correct_format = n_correct_format / n_samples
200+
201+
if tmvp_config.debug.rl:
202+
max_logging.log(
203+
f"pass@1: {n_correct}/{n_samples} correct, "
204+
f"{n_partially_correct}/{n_samples} partial, "
205+
f"{n_correct_format}/{n_samples} format"
206+
)
162207

163-
# Early exit if all criteria are met
164-
if is_correct and is_partially_correct and has_correct_format:
165-
break
208+
return frac_correct, frac_partially_correct, frac_correct_format
166209

167-
return is_correct, is_partially_correct, has_correct_format
210+
if tmvp_config.debug.rl:
211+
max_logging.log(f"Unknown eval mode: {eval_mode}")
212+
return False, False, False
168213

169214

170215
def evaluate(
@@ -210,28 +255,29 @@ def evaluate(
210255

211256
# Score each question-answer pair
212257
for question, responses, answer in zip(questions, multiple_call_responses, answers):
258+
# decode the json-encoded list of acceptable answers
259+
answer = list(dict.fromkeys(json.loads(answer)))
213260
is_correct, is_partially_correct, has_correct_format = score_responses(
214261
tmvp_config=tmvp_config,
215262
question=question,
216263
responses=responses,
217-
answer=answer,
264+
answers=answer,
218265
)
219266

220-
# Update counters
221-
if is_correct:
222-
corr += 1
223-
if corr_lst and make_lst:
267+
# Update counters. For "pass" and "maj" modes, scores are booleans
268+
# (True=1, False=0). For "pass_at_1" mode, scores are floats in [0, 1]
269+
# representing the fraction of samples correct. Using += works for both:
270+
# bool is a subtype of int in Python, so True += is the same as += 1.
271+
corr += is_correct
272+
partially_corr += is_partially_correct
273+
corr_format += has_correct_format
274+
275+
if make_lst:
276+
if corr_lst and is_correct:
224277
response_lst.append((question, answer, responses))
225-
else:
226-
if not corr_lst and make_lst:
278+
elif not corr_lst and not is_correct:
227279
response_lst.append((question, answer, responses))
228280

229-
if is_partially_correct:
230-
partially_corr += 1
231-
232-
if has_correct_format:
233-
corr_format += 1
234-
235281
total += 1
236282

237283
# Print progress every 10 items

0 commit comments

Comments (0)