Skip to content

Commit 04a07ed

Browse files
committed
Add open-r1/OpenR1-Math-220k dataset to RL
1 parent ae6671b commit 04a07ed

3 files changed

Lines changed: 105 additions & 115 deletions

File tree

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,12 @@ def score_responses(tmvp_config, question, responses, answers):
114114
has_correct_format = False
115115

116116
for response in responses:
117-
# Extract answer: prefer the full format match; fall back to the last
118-
# <answer>...</answer> tag if full format match is not found, so result
119-
# scoring is decoupled from format.
120-
extracted_response, match_method = utils_rl.extract_answer(response, tmvp_config)
117+
# Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure)
118+
match_format = utils_rl.get_match_format_regex(tmvp_config)
119+
if match_format.search(response) is not None:
120+
has_correct_format = True
121+
122+
extracted_response = utils_rl.extract_answer(response, tmvp_config)
121123

122124
# Check exact and partial correctness (within 10%)
123125
try:
@@ -130,10 +132,6 @@ def score_responses(tmvp_config, question, responses, answers):
130132
max_logging.log(f"Evaluation Exception: {e}")
131133
max_logging.log("SKIPPED")
132134

133-
# Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure)
134-
if match_method == "full format":
135-
has_correct_format = True
136-
137135
# Early exit if all criteria are met
138136
if is_correct and is_partially_correct and has_correct_format:
139137
break

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -653,15 +653,17 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
653653

654654
# Before we train the model, let's evaluate the model on the test set so we can
655655
# see the improvement post training.
656-
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
657-
trainer_config,
658-
test_dataset,
659-
rl_cluster=rl_cluster,
660-
num_passes=trainer_config.num_eval_passes,
661-
corr_lst=trainer_config.eval_corr_lst,
662-
make_lst=trainer_config.eval_make_lst,
663-
)
664-
max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
656+
if trainer_config.num_test_batches > 0:
657+
max_logging.warning("Starting evaluation before RL training...")
658+
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
659+
trainer_config,
660+
test_dataset,
661+
rl_cluster=rl_cluster,
662+
num_passes=trainer_config.num_eval_passes,
663+
corr_lst=trainer_config.eval_corr_lst,
664+
make_lst=trainer_config.eval_make_lst,
665+
)
666+
max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
665667

666668
# Start training
667669
if trainer_config.load_checkpoint_only_once:
@@ -682,15 +684,17 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
682684
max_logging.warning("RL Training Completed Successfully!")
683685

684686
# Let's evaluate our model!
685-
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
686-
trainer_config,
687-
test_dataset,
688-
rl_cluster=rl_cluster,
689-
num_passes=trainer_config.num_eval_passes,
690-
corr_lst=trainer_config.eval_corr_lst,
691-
make_lst=trainer_config.eval_make_lst,
692-
)
693-
max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
687+
if trainer_config.num_test_batches > 0:
688+
max_logging.warning("Starting evaluation after RL training...")
689+
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
690+
trainer_config,
691+
test_dataset,
692+
rl_cluster=rl_cluster,
693+
num_passes=trainer_config.num_eval_passes,
694+
corr_lst=trainer_config.eval_corr_lst,
695+
make_lst=trainer_config.eval_make_lst,
696+
)
697+
max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
694698

695699

696700
def main(argv: Sequence[str]) -> None:

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 77 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# pylint: disable=bare-except, consider-using-generator, chained-comparison, broad-exception-caught
1616
"""RL Utils Module."""
1717
import concurrent
18+
import itertools
1819
import json
1920
import re
2021
import optax
@@ -94,23 +95,28 @@
9495
]
9596

9697

97-
def math_verify_func(golds, predictions, timeout=5):
98-
"""Wrapper around math_verify's verify function to handle timeouts and exceptions gracefully."""
99-
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
100-
future = executor.submit(verify_math, golds, predictions)
101-
try:
102-
return future.result(timeout=timeout)
103-
except concurrent.futures.TimeoutError:
104-
return 0.0
105-
except Exception:
106-
return 0.0
98+
def math_verify_func(items, timeout=5):
99+
"""Verifies a batch of math problems, handling timeouts and exceptions."""
100+
with concurrent.futures.ThreadPoolExecutor() as executor:
101+
future_to_index = {executor.submit(verify_math, golds, predictions): idx for idx, (_, golds, predictions) in enumerate(items)}
102+
results = [0.0] * len(items)
103+
for future in concurrent.futures.as_completed(future_to_index):
104+
index = future_to_index[future]
105+
try:
106+
results[index] = future.result(timeout=timeout)
107+
except (concurrent.futures.TimeoutError, Exception):
108+
max_logging.log(f"math_verify_func failed for golds: {items[index][1]} and predictions: {items[index][2]}")
109+
return results
107110

108111

109112
def verify_math(golds, predictions):
110113
"""Runs mathematical expression evaluation on ground-truth and predictions."""
111114

112-
extracted_predictions = [parse(pred, PRED_EXTRACTION_TARGET, parsing_timeout=None) for pred in predictions]
113-
extracted_golds = [parse(gold, GOLD_EXTRACTION_TARGET, parsing_timeout=None) for gold in golds]
115+
extracted_predictions = list(itertools.chain.from_iterable(parse(pred, PRED_EXTRACTION_TARGET, parsing_timeout=None) for pred in predictions))
116+
extracted_golds = list(itertools.chain.from_iterable(parse(gold, GOLD_EXTRACTION_TARGET, parsing_timeout=None) for gold in golds))
117+
# If no predictions or golds were extracted, return 0.0
118+
if not extracted_predictions or not extracted_golds:
119+
return 0.0
114120

115121
return max(
116122
[
@@ -372,7 +378,8 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
372378
question = kargs["question"]
373379

374380
# Extract full answer content from solution tags (not just first number)
375-
extracted_responses = [extract_answer(c, tmvp_config)[0] for c in completions]
381+
extracted_responses = [extract_answer(c, tmvp_config) for c in completions]
382+
true_answers = [list(dict.fromkeys(json.loads(acceptable_answers))) if isinstance(acceptable_answers, str) else acceptable_answers for acceptable_answers in answer]
376383

377384
if tmvp_config.debug.rl:
378385
max_logging.log("START ============================")
@@ -382,79 +389,63 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
382389
max_logging.log(f"Extracted: {extracted_responses[0]}")
383390
max_logging.log("END ==============================")
384391

385-
scores = []
386-
for guess, acceptable_answers in zip(extracted_responses, answer):
392+
scores = [tmvp_config.penalty_incorrect_format] * len(completions) # Default to penalty for incorrect format
393+
math_verify_queue = []
394+
for gen_idx, (guess, unique_answers) in enumerate(zip(extracted_responses, true_answers)):
387395
if guess is None:
388-
scores.append(0)
396+
scores[gen_idx] = 0
389397
continue
390398

391-
norm_guess = preprocess_math_string(tmvp_config.dataset_name, guess)
392-
393-
# decode the json-encoded list of acceptable answers
394-
answers = json.loads(acceptable_answers) if isinstance(acceptable_answers, str) else acceptable_answers
395-
unique_answers = list(dict.fromkeys(answers))
396-
max_score = tmvp_config.penalty_incorrect_format
399+
has_exact_match = False
397400
for true_answer in unique_answers:
398-
norm_answer = preprocess_math_string(tmvp_config.dataset_name, true_answer)
399-
400401
# 1. Check for exact or whitespace-normalized match first for a quick reward
401402
if guess == true_answer:
402-
max_score = max(max_score, tmvp_config.reward_exact_answer)
403-
continue
403+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_exact_answer)
404+
has_exact_match = True
404405
elif guess.strip() == true_answer.strip():
405-
max_score = max(max_score, tmvp_config.reward_white_space_format_match)
406-
continue
406+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_white_space_format_match)
407+
has_exact_match = True
408+
409+
if not has_exact_match:
410+
norm_guess = preprocess_math_string(tmvp_config.dataset_name, guess)
411+
norm_answers = []
412+
for true_answer in unique_answers:
413+
norm_answer = preprocess_math_string(tmvp_config.dataset_name, true_answer)
414+
norm_answers.append(boxed(norm_answer))
415+
math_verify_queue.append((gen_idx, norm_answers, [boxed(norm_guess)]))
416+
417+
if math_verify_queue:
418+
# 2. Try math_verify for robust mathematical correctness checking
419+
math_verify_results = math_verify_func(math_verify_queue)
420+
for (gen_idx, norm_answers, norm_guess), score in zip(math_verify_queue, math_verify_results):
421+
if score > 0.1:
422+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_exact_answer)
423+
print("-------- Found a math_verify match -----------")
407424
else:
425+
# 3. As a fallback, try numeric comparison if both can be parsed as numbers
408426
try:
409-
# 2. Try math_verify for robust comparison
410-
if math_verify_func([boxed(norm_answer)], [boxed(norm_guess)]) > 0.1:
411-
max_score = max(max_score, tmvp_config.reward_exact_answer)
412-
continue
413-
except (TimeoutException, Exception):
414-
if tmvp_config.debug.rl:
415-
max_logging.log(
416-
f"math_verify_func failed for gold: {norm_answer} and prediction: "
417-
f"{norm_guess}, falling back to numeric comparison."
418-
)
419-
420-
# 3. As a fallback, try numeric comparison if both can be parsed as numbers
421-
try:
422-
predictions = parse(boxed(norm_guess))
423-
golds = parse(boxed(norm_answer))
424-
for gold in golds:
425-
for pred in predictions:
426-
ratio = (float(pred) + EPSILON) / (float(gold) + EPSILON)
427-
if 0.9 <= ratio <= 1.1:
428-
max_score = max(max_score, tmvp_config.reward_ratio_guess_to_answer_high)
429-
elif 0.8 <= ratio <= 1.2:
430-
max_score = max(max_score, tmvp_config.reward_ratio_guess_to_answer_low)
431-
else:
432-
max_score = max(max_score, tmvp_config.penalty_incorrect_answer) # Penalize wrong answers
433-
except:
434-
max_score = max(max_score, tmvp_config.penalty_incorrect_format) # Penalize if we can't parse numbers at all
435-
436-
scores.append(max_score)
427+
predictions = parse(boxed(norm_guess), PRED_EXTRACTION_TARGET, parsing_timeout=None)
428+
golds = list(itertools.chain.from_iterable(parse(boxed(norm_answer), GOLD_EXTRACTION_TARGET, parsing_timeout=None) for norm_answer in norm_answers))
429+
for gold in golds:
430+
for pred in predictions:
431+
ratio = (float(pred) + EPSILON) / (float(gold) + EPSILON)
432+
if 0.9 <= ratio <= 1.1:
433+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_ratio_guess_to_answer_high)
434+
elif 0.8 <= ratio <= 1.2:
435+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_ratio_guess_to_answer_low)
436+
else:
437+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.penalty_incorrect_answer) # Penalize wrong answers
438+
except:
439+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.penalty_incorrect_format) # Penalize if we can't parse numbers at all
437440
return scores
438441

439442

440443
def extract_answer(response, tmvp_config) -> str | None:
441444
"""Function to extract the answer from the text based on the tmvp_config format."""
442-
match_format = get_match_format_regex(tmvp_config)
443445
answer_fallback = get_answer_fallback_regex(tmvp_config)
444-
445-
full_match = match_format.search(response)
446-
if full_match is not None:
447-
extracted_response = full_match.group(1)
448-
else:
449-
# Find the *last* occurrence of the answer tag (most likely the final answer).
450-
fallback_matches = answer_fallback.findall(response)
451-
extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000"
452-
453-
match_method = "full format" if full_match is not None else "answer-tag fallback"
454-
if tmvp_config.debug.rl:
455-
max_logging.log(f"Evaluation extracted_response ({match_method}): {extracted_response}")
456-
457-
return extracted_response, match_method
446+
fallback_matches = answer_fallback.findall(response)
447+
extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000"
448+
return extracted_response
458449

459450

460451
def extract_hash_answer(text: str) -> str | None:
@@ -466,26 +457,23 @@ def extract_hash_answer(text: str) -> str | None:
466457

467458
def check_correctness(extracted_response, acceptable_answers, tmvp_config):
468459
"""Handles math verification and partial correctness logic."""
469-
for answer in acceptable_answers:
470-
# Check exact correctness first
471-
norm_answer = preprocess_math_string(tmvp_config.dataset_name, answer)
472-
norm_extracted = preprocess_math_string(tmvp_config.dataset_name, extracted_response)
473-
is_correct = math_verify_func([boxed(norm_answer)], [boxed(norm_extracted)]) > 0.1
474-
if is_correct:
475-
return True, True # Exact correctness implies partial correctness
460+
norm_answers = []
461+
norm_response = preprocess_math_string(tmvp_config.dataset_name, extracted_response)
462+
# Check exact correctness first
463+
for answer in acceptable_answers:
464+
norm_answers.append(preprocess_math_string(tmvp_config.dataset_name, answer))
465+
is_correct = math_verify_func([(0, [boxed(norm_answer) for norm_answer in norm_answers], [boxed(norm_response)])])[0] > 0.1
466+
if is_correct:
467+
return True, True # Exact correctness implies partial correctness
476468

477469
# Check partial correctness if values can be extracted (within 10%)
478-
for answer in acceptable_answers:
479-
norm_answer = preprocess_math_string(tmvp_config.dataset_name, answer)
480-
norm_extracted = preprocess_math_string(tmvp_config.dataset_name, extracted_response)
481-
val_extracted = parse(boxed(norm_extracted))
482-
val_answer = parse(boxed(norm_answer))
483-
if val_extracted and val_answer:
484-
is_partially_correct = any(
485-
0.9 <= (float(pred) + EPSILON) / (float(gold) + EPSILON) <= 1.1 for pred in val_extracted for gold in val_answer
486-
)
487-
if is_partially_correct:
488-
return False, True # Not exactly correct, but partially correct
470+
predictions = parse(boxed(norm_response), PRED_EXTRACTION_TARGET, parsing_timeout=None)
471+
golds = list(itertools.chain.from_iterable(parse(boxed(norm_answer), GOLD_EXTRACTION_TARGET, parsing_timeout=None) for norm_answer in norm_answers))
472+
is_partially_correct = any(
473+
0.9 <= (float(pred) + EPSILON) / (float(gold) + EPSILON) <= 1.1 for pred in predictions for gold in golds
474+
)
475+
if is_partially_correct:
476+
return False, True # Not exactly correct, but partially correct
489477

490478
return False, False # Not correct at all
491479

0 commit comments

Comments
 (0)