
Commit 5a01808

SurbhiJainUSC and A9isha committed
Add open-r1/OpenR1-Math-220k and nvidia/OpenMathReasoning datasets to RL and fix reward function
Co-authored-by: A9isha <mazumdera@google.com>
1 parent 5182e3b commit 5a01808

11 files changed: 1194 additions & 414 deletions

File tree

.github/workflows/build_and_test_maxtext.yml

Lines changed: 2 additions & 0 deletions
@@ -199,6 +199,7 @@ jobs:
 base_image: maxtext-unit-test-tpu:py312
 cloud_runner: linux-x86-ct6e-180-4tpu
 pytest_marker: 'not cpu_only and not gpu_only and not integration_test and not post_training'
+pytest_addopts: '--ignore=tests/post_training'
 xla_python_client_mem_fraction: 0.75
 tf_force_gpu_allow_growth: false
 container_resource_option: "--privileged"
@@ -217,6 +218,7 @@ jobs:
 base_image: maxtext-unit-test-tpu:py312
 cloud_runner: linux-x86-ct6e-180-4tpu
 pytest_marker: 'not cpu_only and not gpu_only and integration_test and not post_training'
+pytest_addopts: '--ignore=tests/post_training'
 xla_python_client_mem_fraction: 0.75
 tf_force_gpu_allow_growth: false
 container_resource_option: "--privileged"
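The new `pytest_addopts` line complements the existing marker filter rather than duplicating it: `not post_training` only deselects tests after collection, while `--ignore=tests/post_training` (a standard pytest option) stops pytest from collecting those modules at all, which presumably avoids even importing post-training dependencies in these TPU jobs.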

src/maxtext/configs/post_train/rl.yml

Lines changed: 9 additions & 2 deletions
@@ -148,6 +148,7 @@ generation_configs:
 num_eval_passes: 1 # Number of generation passes during evaluation
 eval_corr_lst: False # If True, only include correct responses in the list during evaluation
 eval_make_lst: False # If True, return a list of (question, answer, responses) during evaluation
+eval_mode: "pass" # Evaluation mode ("pass" for pass@K, "maj" for majority voting maj@K, "pass_at_1" for pass@1 estimation)

 # ====== Inference ======
 # === Generation during GRPO training ===
@@ -190,6 +191,12 @@ reward_ratio_guess_to_answer_low: 0.0
 penalty_incorrect_format: 0.0
 penalty_incorrect_answer: 0.0

+# ====== Configuration for math_verify Pool ======
+# Timeout (seconds) for math_verify
+math_verify_timeout: 300
+# Max worker processes for the math_verify pool. null ⇒ min(batch_size, cpu_count())
+math_verify_num_procs: null
+
 # ====== Special tokens/templates for GSM8K reasoning ======
 reasoning_start_token: '<reasoning>'
 reasoning_end_token: '</reasoning>'
@@ -200,8 +207,8 @@ skip_jax_distributed_system: True

 # # TODO(@mazumdera): fix this
 # Dataset Configuration
-dataset_name: 'gsm8k' # huggingface:open-r1/DAPO-Math-17k-Processed
-eval_dataset_name: 'gsm8k' # huggingface:BytedTsinghua-SIA/AIME-2024
+dataset_name: 'gsm8k' # open-r1/DAPO-Math-17k-Processed
+eval_dataset_name: 'gsm8k' # BytedTsinghua-SIA/AIME-2024
 train_split: 'train'
 eval_split: 'test'
 tokenizer_type: 'huggingface'
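For context, a minimal sketch of what the `math_verify_num_procs: null` default implies. The `resolve_num_procs` helper and `verify_fn` callback here are hypothetical stand-ins for the actual MaxText/math_verify wiring, which is not shown in this diff:

    import multiprocessing
    from concurrent.futures import ProcessPoolExecutor

    def resolve_num_procs(batch_size, configured=None):
        # null/None in the YAML means: one worker per batch element,
        # capped at the machine's CPU count (and at least one worker).
        if configured is not None:
            return configured
        return max(1, min(batch_size, multiprocessing.cpu_count()))

    def verify_batch(pairs, verify_fn, num_procs=None, timeout=300):
        # Run the checks in a process pool so one pathological LaTeX
        # expression cannot hang evaluation past `timeout` seconds.
        workers = resolve_num_procs(len(pairs), num_procs)
        results = []
        with ProcessPoolExecutor(max_workers=workers) as pool:
            futures = [pool.submit(verify_fn, gold, pred) for gold, pred in pairs]
            for future in futures:
                try:
                    results.append(future.result(timeout=timeout))
                except Exception:  # timeout or worker crash counts as incorrect
                    results.append(False)
        return results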

src/maxtext/configs/types.py

Lines changed: 9 additions & 0 deletions
@@ -1817,6 +1817,10 @@ class RLEvaluation(BaseModel):
       False,
       description="If True, return a list of (question, answer, responses) during evaluation.",
   )
+  eval_mode: Literal["pass", "maj", "pass_at_1"] = Field(
+      "pass",
+      description="Evaluation mode to use ('pass' for pass@K, 'maj' for maj@K, 'pass_at_1' for pass@1 estimation).",
+  )


 class Reward(BaseModel):
@@ -1834,6 +1838,11 @@ class Reward(BaseModel):
   )
   penalty_incorrect_format: float = Field(-0.5, description="Penalty for an incorrect format.")
   penalty_incorrect_answer: float = Field(-1.0, description="Penalty for an incorrect answer.")
+  math_verify_timeout: int = Field(300, description="Timeout (seconds) for math_verify call per batch.")
+  math_verify_num_procs: int | None = Field(
+      None,
+      description="Max worker processes for the math_verify pool. None ⇒ min(batch_size, cpu_count()).",
+  )


 class SpecialTokens(BaseModel):
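Because `eval_mode` is typed as a `Literal`, Pydantic rejects an unknown mode when the config is built rather than deep inside an evaluation run. A self-contained sketch of the same pattern; the demo model name is illustrative, not an actual MaxText class:

    from typing import Literal
    from pydantic import BaseModel, Field, ValidationError

    class EvalConfigDemo(BaseModel):
        # Mirrors the new field: only the three known modes are accepted.
        eval_mode: Literal["pass", "maj", "pass_at_1"] = Field("pass")

    print(EvalConfigDemo().eval_mode)                 # "pass" (the default)
    print(EvalConfigDemo(eval_mode="maj").eval_mode)  # "maj"

    try:
        EvalConfigDemo(eval_mode="avg")  # not in the Literal, so it is rejected
    except ValidationError as err:
        print(err.errors()[0]["type"])  # "literal_error" under Pydantic v2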

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 83 additions & 83 deletions
@@ -16,7 +16,9 @@
 """
 RL Evaluation Module.
 """
-from math_verify import parse
+import collections
+import json
+
 from tqdm.auto import tqdm
 from tunix.rl.rollout.base_rollout import RolloutConfig

@@ -86,85 +88,82 @@ def generate_responses(
   return multiple_call_responses


-def score_responses(tmvp_config, question, responses, answer):
-  """
-  Score a set of responses for a single question.
+def _score_single(extracted, response, answers, tmvp_config, match_format):
+  """Score one (extracted answer, raw response) pair. Returns (is_correct, is_partially_correct, has_correct_format)."""
+  has_correct_format = match_format.search(response) is not None
+  try:
+    is_correct, is_partially_correct = utils_rl.check_correctness(extracted, answers, tmvp_config)
+    if tmvp_config.debug.rl:
+      max_logging.log(f"Result has_correct_format: {has_correct_format}")
+      max_logging.log(f"Result is_correct: {is_correct}")
+      max_logging.log(f"Result is_partially_correct: {is_partially_correct}")
+  except Exception as e:  # pylint: disable=broad-exception-caught
+    is_correct, is_partially_correct = False, False
+    if tmvp_config.debug.rl:
+      max_logging.log(f"Evaluation Exception: {e} — SKIPPED")
+  return is_correct, is_partially_correct, has_correct_format

-  Args:
-    tmvp_config: Configuration object
-    question: The evaluation question
-    responses: List of generated responses for this question
-    answer: The correct answer
+
+def score_responses(tmvp_config, question, responses, answers):
+  """Score a set of responses for a single question.

   Returns:
-    Tuple of (is_correct, is_partially_correct, has_correct_format)
+    Tuple of (is_correct, is_partially_correct, has_correct_format).
   """
-  match_format = utils_rl.get_match_format_regex(tmvp_config)
-  answer_fallback = utils_rl.get_answer_fallback_regex(tmvp_config)
-
   if tmvp_config.debug.rl:
     max_logging.log("========================================")
     max_logging.log(f"Evaluation Question: {question}")
-    max_logging.log(f"Evaluation Answer: {answer}")
+    max_logging.log(f"Evaluation Answer: {answers}")
     max_logging.log(f"Evaluation Responses: {responses}")
     max_logging.log("========================================")

-  is_correct = False
-  is_partially_correct = False
-  has_correct_format = False
-
-  for response in responses:
-    # Extract answer: prefer the full format match; fall back to the last
-    # <answer>...</answer> tag if full format match is not found, so result
-    # scoring is decoupled from format.
-    full_match = match_format.search(response)
-    if full_match is not None:
-      extracted_response = full_match.group(1)
-    else:
-      # Find the *last* occurrence of the answer tag (most likely the final answer).
-      fallback_matches = answer_fallback.findall(response)
-      extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000"
+  eval_mode = getattr(tmvp_config, "eval_mode", "pass")
+  match_format = utils_rl.get_match_format_regex(tmvp_config)
+  extracted_responses = [utils_rl.extract_answer(r, tmvp_config) for r in responses]
+
+  if not extracted_responses:
+    return False, False, False
+
+  if eval_mode == "maj":
+    # Extract the single most frequent response.
+    counter = collections.Counter(extracted_responses)
+    majority = counter.most_common(1)[0][0]
     if tmvp_config.debug.rl:
-      used = "full format" if full_match is not None else "answer-tag fallback"
-      max_logging.log(f"Evaluation extracted_response ({used}): {extracted_response}")
-
-    # Check exact correctness
-    try:
-      # Fix LaTeX escaping issues for both ground truth and extracted answer
-      norm_answer = utils_rl.fix_latex_escaping(answer)
-      norm_extracted = utils_rl.fix_latex_escaping(extracted_response)
-      # Normalize for certain datasets and parse
-      if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name:
-        norm_extracted = utils_rl.normalize_final_answer(norm_extracted).strip()
-        norm_answer = utils_rl.normalize_final_answer(answer).strip()
-      is_correct = utils_rl.math_verify_func([utils_rl.boxed(norm_answer)], [utils_rl.boxed(norm_extracted)])[0] > 0.1
-      if tmvp_config.debug.rl:
-        # is_correct is a tuple, if first value is 1.0 means it's a match;
-        # 0.0 means a mismatch. e.g. (0.0, (['3', '3'], ['3/5', '\\frac{3}{5}']))
-        max_logging.log(f"Result is_correct: {is_correct}")
-
-      val_extracted = parse(utils_rl.boxed(norm_extracted))
-      val_answer = parse(utils_rl.boxed(norm_answer))
-
-      # Check partial correctness if values can be extracted (within 10%)
-      if val_extracted and val_answer:
-        ratio = (val_extracted[0] + utils_rl.EPSILON) / (val_answer[0] + utils_rl.EPSILON)
-        is_partially_correct = 0.9 <= ratio <= 1.1
-
-    except Exception as e:
-      if tmvp_config.debug.rl:
-        max_logging.log(f"Evaluation Exception: {e}")
-        max_logging.log("SKIPPED")
-
-    # Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure)
-    if full_match is not None:
-      has_correct_format = True
-
-    # Early exit if all criteria are met
-    if is_correct and is_partially_correct and has_correct_format:
-      break
+      max_logging.log(f"Majority Response: {majority} (Count: {counter[majority]})")

-  return is_correct, is_partially_correct, has_correct_format
+    # Check the format for the majority response.
+    has_correct_format = any(
+        match_format.search(responses[idx]) is not None
+        for idx, response in enumerate(extracted_responses)
+        if response == majority
+    )
+    is_correct, is_partially_correct, _ = _score_single(majority, responses[0], answers, tmvp_config, match_format)
+    return is_correct, is_partially_correct, has_correct_format
+
+  if eval_mode == "pass":
+    result = False, False, False
+    for extracted, response in zip(extracted_responses, responses):
+      result = _score_single(extracted, response, answers, tmvp_config, match_format)
+      # Early exit if all criteria are met.
+      if all(result):
+        return result
+    return result
+
+  if eval_mode == "pass_at_1":
+    # Estimate pass@1: fraction of N samples that are correct per problem.
+    # Returns floats in [0, 1] instead of booleans.
+    scores = [
+        _score_single(extracted, response, answers, tmvp_config, match_format)
+        for extracted, response in zip(extracted_responses, responses)
+    ]
+    n_samples = len(scores)
+    frac_correct = sum(s[0] for s in scores) / n_samples
+    frac_partial = sum(s[1] for s in scores) / n_samples
+    frac_format = sum(s[2] for s in scores) / n_samples
+    if tmvp_config.debug.rl:
+      max_logging.log(
+          f"pass@1: {frac_correct*n_samples:.0f}/{n_samples} correct, "
+          f"{frac_partial*n_samples:.0f}/{n_samples} partial, "
+          f"{frac_format*n_samples:.0f}/{n_samples} format"
+      )
+    return frac_correct, frac_partial, frac_format
+
+  if tmvp_config.debug.rl:
+    max_logging.log(f"Unknown eval mode: {eval_mode}")
+  raise ValueError(f"Unknown eval_mode: {eval_mode!r}")


 def evaluate(
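To make the three modes concrete, here is a small standalone sketch of their scoring semantics on toy data. The verdict tuples stand in for `_score_single` outputs; none of this is the MaxText code itself:

    import collections

    # Toy per-sample verdicts for one question with K=4 sampled responses:
    # (is_correct, is_partially_correct, has_correct_format) per sample.
    verdicts = [(False, False, True), (True, True, True), (False, True, False), (True, True, True)]
    extracted = ["7", "8", "7", "8"]  # answer strings extracted from the 4 responses

    # pass@K: take the first sample meeting ALL criteria, else the last sample's verdict.
    result = (False, False, False)
    for v in verdicts:
        result = v
        if all(v):
            break
    print(result)  # (True, True, True): sample 2 already meets every criterion

    # maj@K: only the most frequent extracted answer is scored.
    majority = collections.Counter(extracted).most_common(1)[0][0]
    print(majority)  # "7" (ties resolve to the first answer seen)

    # pass@1 estimate: per-criterion fraction of the K samples, floats in [0, 1].
    k = len(verdicts)
    print(tuple(sum(v[i] for v in verdicts) / k for i in range(3)))  # (0.5, 0.75, 0.75)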
@@ -210,28 +209,29 @@ def evaluate(

   # Score each question-answer pair
   for question, responses, answer in zip(questions, multiple_call_responses, answers):
+    # Decode the JSON-encoded list of acceptable answers, dropping duplicates.
+    answer = list(dict.fromkeys(json.loads(answer)))
     is_correct, is_partially_correct, has_correct_format = score_responses(
         tmvp_config=tmvp_config,
         question=question,
         responses=responses,
-        answer=answer,
+        answers=answer,
     )

-    # Update counters
-    if is_correct:
-      corr += 1
-      if corr_lst and make_lst:
+    # Update counters. For "pass" and "maj" modes, scores are booleans
+    # (True=1, False=0); for "pass_at_1" mode, they are floats in [0, 1]
+    # giving the fraction of correct samples. `+=` works for both, since
+    # bool is a subtype of int in Python: `corr += True` adds 1.
+    corr += is_correct
+    partially_corr += is_partially_correct
+    corr_format += has_correct_format
+
+    if make_lst:
+      if corr_lst and is_correct:
        response_lst.append((question, answer, responses))
-    else:
-      if not corr_lst and make_lst:
+      elif not corr_lst and not is_correct:
        response_lst.append((question, answer, responses))

-    if is_partially_correct:
-      partially_corr += 1
-
-    if has_correct_format:
-      corr_format += 1
-
    total += 1

    # Print progress every 10 items
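A side note on the `dict.fromkeys` idiom in the hunk above: it de-duplicates the decoded answer list while preserving order, which a plain `set` would not guarantee. For example:

    import json

    raw = '["1/2", "0.5", "1/2"]'  # a JSON-encoded list of acceptable answers
    print(list(dict.fromkeys(json.loads(raw))))  # ['1/2', '0.5']: order kept, duplicate dropped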
@@ -243,8 +243,8 @@ def evaluate(

   # Prepare return values
   to_return = (
-      corr,
-      total,
+      corr * num_passes,
+      total * num_passes,
       corr / total * 100 if total > 0 else 0,
       partially_corr / total * 100 if total > 0 else 0,
       corr_format / total * 100 if total > 0 else 0,
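One plausible reading of the scaled return values, inferred from this diff rather than documented in it: in "pass_at_1" mode `corr` accumulates per-question fractions, so multiplying by `num_passes` recovers sample-level counts, while the percentage metrics beneath stay per-question. A toy calculation under that assumption:

    # Hypothetical run: 2 questions, num_passes = 4 samples each.
    num_passes = 4
    per_question_frac_correct = [0.5, 0.75]  # pass@1 fractions from score_responses

    corr = sum(per_question_frac_correct)   # 1.25 question-equivalents
    total = len(per_question_frac_correct)  # 2 questions

    print(corr * num_passes)   # 5.0  -> correct samples across all passes (2 + 3)
    print(total * num_passes)  # 8    -> total generated samples
    print(corr / total * 100)  # 62.5 -> average pass@1 percentage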
