Skip to content

Commit ae00589

Browse files
committed
Add open-r1/OpenR1-Math-220k dataset to RL
1 parent 17dda16 commit ae00589

3 files changed

Lines changed: 28 additions & 29 deletions

File tree

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ def score_responses(tmvp_config, question, responses, answers):
100100
Returns:
101101
Tuple of (is_correct, is_partially_correct, has_correct_format)
102102
"""
103-
104-
answers = list(dict.fromkeys(answers))
105103
if tmvp_config.debug.rl:
106104
max_logging.log("========================================")
107105
max_logging.log(f"Evaluation Question: {question}")
@@ -174,17 +172,14 @@ def evaluate(
174172

175173
# Generate responses for all prompts in the batch
176174
multiple_call_responses = generate_responses(
177-
tmvp_config=tmvp_config,
178175
prompts=prompts,
179176
rl_cluster=rl_cluster,
180177
num_passes=num_passes,
181178
)
182179

183180
# Score each question-answer pair
184181
for question, responses, answer in zip(questions, multiple_call_responses, answers):
185-
answer = (
186-
json.loads(answer) if isinstance(answer, str) else answer
187-
) # decode the json-encoded list of acceptable answers
182+
answer = list(dict.fromkeys(json.loads(answer)))
188183
is_correct, is_partially_correct, has_correct_format = score_responses(
189184
tmvp_config=tmvp_config,
190185
question=question,

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def prepare_train_and_eval_dataset(
324324
)
325325

326326
original_ds = datasets.load_dataset(
327-
trainer_config.dataset_name,
327+
"parquet",
328328
data_files={trainer_config.train_split: trainer_config.hf_train_files},
329329
split=trainer_config.train_split,
330330
)

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,8 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
379379

380380
# Extract full answer content from solution tags (not just first number)
381381
extracted_responses = [extract_answer(c, tmvp_config) for c in completions]
382-
true_answers = [list(dict.fromkeys(json.loads(acceptable_answers))) if isinstance(acceptable_answers, str) else acceptable_answers for acceptable_answers in answer]
383-
382+
true_answers = [list(dict.fromkeys(acceptable_answers)) for acceptable_answers in answer]
383+
384384
if tmvp_config.debug.rl:
385385
max_logging.log("START ============================")
386386
max_logging.log(f"Question: {question[0]}")
@@ -417,24 +417,26 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
417417
if math_verify_queue:
418418
# 2. Try math_verify for robust mathematical correctness checking
419419
math_verify_results = math_verify_func(math_verify_queue)
420-
for (gen_idx, norm_answers, norm_guess), score in zip(math_verify_queue, math_verify_results):
420+
for (gen_idx, norm_answers, norm_guesses), score in zip(math_verify_queue, math_verify_results):
421421
if score > 0.1:
422422
scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_exact_answer)
423-
print("-------- Found a math_verify match -----------")
424423
else:
425424
# 3. As a fallback, try numeric comparison if both can be parsed as numbers
426425
try:
427-
predictions = parse(boxed(norm_guess), PRED_EXTRACTION_TARGET, parsing_timeout=None)
428-
golds = list(itertools.chain.from_iterable(parse(boxed(norm_answer), GOLD_EXTRACTION_TARGET, parsing_timeout=None) for norm_answer in norm_answers))
426+
predictions = parse(norm_guesses[0], PRED_EXTRACTION_TARGET, parsing_timeout=None)
427+
golds = list(itertools.chain.from_iterable(parse(norm_answer, GOLD_EXTRACTION_TARGET, parsing_timeout=None) for norm_answer in norm_answers))
429428
for gold in golds:
430429
for pred in predictions:
431-
ratio = (float(pred) + EPSILON) / (float(gold) + EPSILON)
432-
if 0.9 <= ratio <= 1.1:
433-
scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_ratio_guess_to_answer_high)
434-
elif 0.8 <= ratio <= 1.2:
435-
scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_ratio_guess_to_answer_low)
436-
else:
437-
scores[gen_idx] = max(scores[gen_idx], tmvp_config.penalty_incorrect_answer) # Penalize wrong answers
430+
try:
431+
ratio = (float(pred) + EPSILON) / (float(gold) + EPSILON)
432+
if 0.9 <= ratio <= 1.1:
433+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_ratio_guess_to_answer_high)
434+
elif 0.8 <= ratio <= 1.2:
435+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_ratio_guess_to_answer_low)
436+
else:
437+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.penalty_incorrect_answer)
438+
except:
439+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.penalty_incorrect_answer)
438440
except:
439441
scores[gen_idx] = max(scores[gen_idx], tmvp_config.penalty_incorrect_format) # Penalize if we can't parse numbers at all
440442
return scores
@@ -462,20 +464,22 @@ def check_correctness(extracted_response, acceptable_answers, tmvp_config):
462464
# Check exact correctness first
463465
for answer in acceptable_answers:
464466
norm_answers.append(preprocess_math_string(tmvp_config.dataset_name, answer))
465-
is_correct = math_verify_func([(0, [boxed(norm_answer) for norm_answer in norm_answers], [boxed(norm_response)])])[0] > 0.1
467+
is_correct = verify_math([boxed(norm_answer) for norm_answer in norm_answers], [boxed(norm_response)]) > 0.1
466468
if is_correct:
467469
return True, True # Exact correctness implies partial correctness
468470

469471
# Check partial correctness if values can be extracted (within 10%)
470-
predictions = parse(boxed(norm_response), PRED_EXTRACTION_TARGET, parsing_timeout=None)
471-
golds = list(itertools.chain.from_iterable(parse(boxed(norm_answer), GOLD_EXTRACTION_TARGET, parsing_timeout=None) for norm_answer in norm_answers))
472-
is_partially_correct = any(
473-
0.9 <= (float(pred) + EPSILON) / (float(gold) + EPSILON) <= 1.1 for pred in predictions for gold in golds
474-
)
475-
if is_partially_correct:
476-
return False, True # Not exactly correct, but partially correct
472+
is_partially_correct = False
473+
try:
474+
predictions = parse(boxed(norm_response), PRED_EXTRACTION_TARGET, parsing_timeout=None)
475+
golds = list(itertools.chain.from_iterable(parse(boxed(norm_answer), GOLD_EXTRACTION_TARGET, parsing_timeout=None) for norm_answer in norm_answers))
476+
is_partially_correct = any(
477+
0.9 <= (float(pred) + EPSILON) / (float(gold) + EPSILON) <= 1.1 for pred in predictions for gold in golds
478+
)
479+
except:
480+
pass
477481

478-
return False, False # Not correct at all
482+
return False, is_partially_correct
479483

480484

481485
def get_optimizer(tmvp_config, max_train_steps):

0 commit comments

Comments
 (0)