Skip to content

Commit 6476df0

Browse files
SurbhiJainUSC and A9isha
committed
Add open-r1/OpenR1-Math-220k dataset and nvidia/OpenMathReasoning to RL and fix reward function
Co-authored-by: A9isha <mazumdera@google.com>
1 parent 62674fd commit 6476df0

11 files changed

Lines changed: 1293 additions & 475 deletions

File tree

.github/workflows/build_and_test_maxtext.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ jobs:
199199
base_image: maxtext-unit-test-tpu:py312
200200
cloud_runner: linux-x86-ct6e-180-4tpu
201201
pytest_marker: 'not cpu_only and not gpu_only and not integration_test and not post_training'
202+
pytest_addopts: '--ignore=tests/post_training'
202203
xla_python_client_mem_fraction: 0.75
203204
tf_force_gpu_allow_growth: false
204205
container_resource_option: "--privileged"
@@ -217,6 +218,7 @@ jobs:
217218
base_image: maxtext-unit-test-tpu:py312
218219
cloud_runner: linux-x86-ct6e-180-4tpu
219220
pytest_marker: 'not cpu_only and not gpu_only and integration_test and not post_training'
221+
pytest_addopts: '--ignore=tests/post_training'
220222
xla_python_client_mem_fraction: 0.75
221223
tf_force_gpu_allow_growth: false
222224
container_resource_option: "--privileged"

src/maxtext/configs/post_train/rl.yml

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ generation_configs:
148148
num_eval_passes: 1 # Number of generation passes during evaluation
149149
eval_corr_lst: False # If True, only include correct responses in the list during evaluation
150150
eval_make_lst: False # If True, return a list of (question, answer, responses) during evaluation
151+
eval_mode: "pass" # Evaluation mode ("pass" for pass@K, "maj" for majority voting maj@K, "pass_at_1" for pass@1 estimation)
151152

152153
# ====== Inference ======
153154
# === Generation during GRPO training ===
@@ -190,6 +191,12 @@ reward_ratio_guess_to_answer_low: 0.0
190191
penalty_incorrect_format: 0.0
191192
penalty_incorrect_answer: 0.0
192193

194+
# ====== Configuration for math_verify Pool ======
195+
# Global timeout (seconds) for math_verify calls across all examples in a batch
196+
math_verify_timeout: 300
197+
# Max worker processes for the math_verify pool. null ⇒ min(batch_size, cpu_count())
198+
math_verify_num_procs: null
199+
193200
# ====== Special tokens/templates for GSM8K reasoning ======
194201
reasoning_start_token: '<reasoning>'
195202
reasoning_end_token: '</reasoning>'
@@ -198,10 +205,23 @@ solution_end_token: '</answer>'
198205
chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
199206
skip_jax_distributed_system: True
200207

201-
# # TODO(@mazumdera): fix this
202-
# Dataset Configuration
203-
dataset_name: 'gsm8k' # huggingface:open-r1/DAPO-Math-17k-Processed
204-
eval_dataset_name: 'gsm8k' # huggingface:BytedTsinghua-SIA/AIME-2024
208+
# ====== Dataset Configuration ======
209+
# Supported values for dataset_name:
210+
# ['openai/gsm8k', 'nvidia/OpenMathInstruct-2', 'nvidia/OpenMathReasoning', 'open-r1/OpenR1-Math-220k', 'bethgelab/CuratedThoughts']
211+
#
212+
# Scenarios:
213+
# 1. dataset_name='openai/gsm8k' and eval_dataset_name='openai/gsm8k'
214+
# Loads the train and test splits of GSM8K directly.
215+
#
216+
# 2. Datasets other than 'gsm8k' with same eval: dataset_name=eval_dataset_name=<dataset>
217+
# The dataset has no separate test split, so the training data is
218+
# automatically split into train and test sets internally.
219+
#
220+
# 3. Train and evaluation on different datasets: dataset_name=<train_dataset>, eval_dataset_name=<eval_dataset>
221+
# Loads separate dataset for training and evaluation (e.g., train on OpenMathInstruct-2, eval on GSM8K).
222+
dataset_name: 'openai/gsm8k'
223+
eval_dataset_name: 'openai/gsm8k'
205224
train_split: 'train'
206225
eval_split: 'test'
226+
hf_name: 'main' # subset of Hugging Face dataset
207227
tokenizer_type: 'huggingface'

src/maxtext/configs/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1830,6 +1830,10 @@ class RLEvaluation(BaseModel):
18301830
False,
18311831
description="If True, return a list of (question, answer, responses) during evaluation.",
18321832
)
1833+
eval_mode: Literal["pass", "maj", "pass_at_1"] = Field(
1834+
"pass",
1835+
description="Evaluation mode to use ('pass' for pass@K, 'maj' for maj@K, 'pass_at_1' for pass@1 estimation).",
1836+
)
18331837

18341838

18351839
class Reward(BaseModel):
@@ -1847,6 +1851,11 @@ class Reward(BaseModel):
18471851
)
18481852
penalty_incorrect_format: float = Field(-0.5, description="Penalty for an incorrect format.")
18491853
penalty_incorrect_answer: float = Field(-1.0, description="Penalty for an incorrect answer.")
1854+
math_verify_timeout: int = Field(300, description="Global timeout (seconds) for math_verify calls per batch.")
1855+
math_verify_num_procs: int | None = Field(
1856+
None,
1857+
description=("Max worker processes for the math_verify pool. None ⇒ " "min(batch_size, cpu_count())."),
1858+
)
18501859

18511860

18521861
class SpecialTokens(BaseModel):

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 87 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
"""
1717
RL Evaluation Module.
1818
"""
19-
from math_verify import parse
19+
import collections
20+
import json
21+
2022
from tqdm.auto import tqdm
2123
from tunix.rl.rollout.base_rollout import RolloutConfig
2224

@@ -86,85 +88,91 @@ def generate_responses(
8688
return multiple_call_responses
8789

8890

89-
def score_responses(tmvp_config, question, responses, answer):
90-
"""
91-
Score a set of responses for a single question.
91+
def _score_single(extracted, response, answers, tmvp_config, match_format):
92+
"""Score one (extracted answer, raw response) pair. Returns (is_correct, is_partially_correct, has_correct_format)."""
93+
has_correct_format = match_format.search(response) is not None
94+
try:
95+
is_correct, is_partially_correct = utils_rl.check_correctness(extracted, answers, tmvp_config)
96+
if tmvp_config.debug.rl:
97+
max_logging.log(f"Result has_correct_format: {has_correct_format}")
98+
max_logging.log(f"Result is_correct: {is_correct}")
99+
max_logging.log(f"Result is_partially_correct: {is_partially_correct}")
100+
except Exception as e: # pylint: disable=broad-exception-caught
101+
is_correct, is_partially_correct = False, False
102+
if tmvp_config.debug.rl:
103+
max_logging.log(f"Evaluation Exception: {e} — SKIPPED")
104+
return is_correct, is_partially_correct, has_correct_format
105+
106+
107+
def score_responses(tmvp_config, question, responses, answers):
108+
"""Score a set of responses for a single question.
92109
93110
Args:
94111
tmvp_config: Configuration object
95112
question: The evaluation question
96113
responses: List of generated responses for this question
97-
answer: The correct answer
114+
answers: List of correct answers
98115
99116
Returns:
100117
Tuple of (is_correct, is_partially_correct, has_correct_format)
101118
"""
102-
match_format = utils_rl.get_match_format_regex(tmvp_config)
103-
answer_fallback = utils_rl.get_answer_fallback_regex(tmvp_config)
104-
105119
if tmvp_config.debug.rl:
106120
max_logging.log("========================================")
107121
max_logging.log(f"Evaluation Question: {question}")
108-
max_logging.log(f"Evaluation Answer: {answer}")
122+
max_logging.log(f"Evaluation Answer: {answers}")
109123
max_logging.log(f"Evaluation Responses: {responses}")
110124
max_logging.log("========================================")
111125

112-
is_correct = False
113-
is_partially_correct = False
114-
has_correct_format = False
115-
116-
for response in responses:
117-
# Extract answer: prefer the full format match; fall back to the last
118-
# <answer>...</answer> tag if full format match is not found, so result
119-
# scoring is decoupled from format.
120-
full_match = match_format.search(response)
121-
if full_match is not None:
122-
extracted_response = full_match.group(1)
123-
else:
124-
# Find the *last* occurrence of the answer tag (most likely the final answer).
125-
fallback_matches = answer_fallback.findall(response)
126-
extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000"
126+
eval_mode = getattr(tmvp_config, "eval_mode", "pass")
127+
match_format = utils_rl.get_match_format_regex(tmvp_config)
128+
extracted_responses = [utils_rl.extract_answer(r, tmvp_config) for r in responses]
129+
130+
if not extracted_responses:
131+
return False, False, False
132+
133+
if eval_mode == "maj":
134+
# extract the single-most frequent response
135+
counter = collections.Counter(extracted_responses)
136+
majority = counter.most_common(1)[0][0]
127137
if tmvp_config.debug.rl:
128-
used = "full format" if full_match is not None else "answer-tag fallback"
129-
max_logging.log(f"Evaluation extracted_response ({used}): {extracted_response}")
130-
131-
# Check exact correctness
132-
try:
133-
# Fix LaTeX escaping issues for both ground truth and extracted answer
134-
norm_answer = utils_rl.fix_latex_escaping(answer)
135-
norm_extracted = utils_rl.fix_latex_escaping(extracted_response)
136-
# Normalize Normalize for certain datasets and parse
137-
if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name:
138-
norm_extracted = utils_rl.normalize_final_answer(norm_extracted).strip()
139-
norm_answer = utils_rl.normalize_final_answer(answer).strip()
140-
is_correct = utils_rl.math_verify_func([utils_rl.boxed(norm_answer)], [utils_rl.boxed(norm_extracted)])[0] > 0.1
141-
if tmvp_config.debug.rl:
142-
# is_correct is a tuple, if first value is 1.0 means it's a match;
143-
# 0.0 means a mismatch. e.g. (0.0, (['3', '3'], ['3/5', '\\frac{3}{5}']))
144-
max_logging.log(f"Result is_correct: {is_correct}")
145-
146-
val_extracted = parse(utils_rl.boxed(norm_extracted))
147-
val_answer = parse(utils_rl.boxed(norm_answer))
148-
149-
# Check partial correctness if values can be extracted (within 10%)
150-
if val_extracted and val_answer:
151-
ratio = (val_extracted[0] + utils_rl.EPSILON) / (val_answer[0] + utils_rl.EPSILON)
152-
is_partially_correct = 0.9 <= ratio <= 1.1
153-
154-
except Exception as e:
155-
if tmvp_config.debug.rl:
156-
max_logging.log(f"Evaluation Exception: {e}")
157-
max_logging.log("SKIPPED")
158-
159-
# Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure)
160-
if full_match is not None:
161-
has_correct_format = True
162-
163-
# Early exit if all criteria are met
164-
if is_correct and is_partially_correct and has_correct_format:
165-
break
138+
max_logging.log(f"Majority Response: {majority} (Count: {counter[majority]})")
166139

167-
return is_correct, is_partially_correct, has_correct_format
140+
# Check the format for the majority response
141+
has_correct_format = any(
142+
match_format.search(responses[idx]) is not None
143+
for idx, response in enumerate(extracted_responses)
144+
if response == majority
145+
)
146+
is_correct, is_partially_correct, _ = _score_single(majority, responses[0], answers, tmvp_config, match_format)
147+
return is_correct, is_partially_correct, has_correct_format
148+
149+
if eval_mode == "pass":
150+
result = False, False, False
151+
for extracted, response in zip(extracted_responses, responses):
152+
result = _score_single(extracted, response, answers, tmvp_config, match_format)
153+
# Early exit if all criteria are met
154+
if all(result):
155+
return result
156+
return result
157+
158+
if eval_mode == "pass_at_1":
159+
# Estimate pass@1: fraction of N samples that are correct per problem.
160+
# Returns floats in [0, 1] instead of booleans.
161+
scores = [
162+
_score_single(extracted_response, response, answers, tmvp_config, match_format)
163+
for extracted_response, response in zip(extracted_responses, responses)
164+
]
165+
n_samples = len(scores)
166+
frac_correct = sum(s[0] for s in scores) / n_samples
167+
frac_partial = sum(s[1] for s in scores) / n_samples
168+
frac_format = sum(s[2] for s in scores) / n_samples
169+
if tmvp_config.debug.rl:
170+
max_logging.log(f"{frac_correct*n_samples:.0f}/{n_samples} correct")
171+
max_logging.log(f"{frac_partial*n_samples:.0f}/{n_samples} partial")
172+
max_logging.log(f"{frac_format*n_samples:.0f}/{n_samples} format")
173+
return frac_correct, frac_partial, frac_format
174+
175+
raise ValueError(f"Unknown eval_mode: {eval_mode!r}")
168176

169177

170178
def evaluate(
@@ -210,28 +218,29 @@ def evaluate(
210218

211219
# Score each question-answer pair
212220
for question, responses, answer in zip(questions, multiple_call_responses, answers):
221+
# decode the json-encoded list of acceptable answers
222+
answer = list(dict.fromkeys(json.loads(answer)))
213223
is_correct, is_partially_correct, has_correct_format = score_responses(
214224
tmvp_config=tmvp_config,
215225
question=question,
216226
responses=responses,
217-
answer=answer,
227+
answers=answer,
218228
)
219229

220-
# Update counters
221-
if is_correct:
222-
corr += 1
223-
if corr_lst and make_lst:
230+
# Update counters. For "pass" and "maj" modes, scores are booleans
231+
# (True=1, False=0). For "pass_at_1" mode, scores are floats in [0, 1]
232+
# representing the fraction of samples correct. Using += works for both:
233+
# bool is a subtype of int in Python, so True += is the same as += 1.
234+
corr += is_correct
235+
partially_corr += is_partially_correct
236+
corr_format += has_correct_format
237+
238+
if make_lst:
239+
if corr_lst and is_correct:
224240
response_lst.append((question, answer, responses))
225-
else:
226-
if not corr_lst and make_lst:
241+
elif not corr_lst and not is_correct:
227242
response_lst.append((question, answer, responses))
228243

229-
if is_partially_correct:
230-
partially_corr += 1
231-
232-
if has_correct_format:
233-
corr_format += 1
234-
235244
total += 1
236245

237246
# Print progress every 10 items
@@ -243,8 +252,8 @@ def evaluate(
243252

244253
# Prepare return values
245254
to_return = (
246-
corr,
247-
total,
255+
corr * num_passes,
256+
total * num_passes,
248257
corr / total * 100 if total > 0 else 0,
249258
partially_corr / total * 100 if total > 0 else 0,
250259
corr_format / total * 100 if total > 0 else 0,

0 commit comments

Comments
 (0)