2 changes: 2 additions & 0 deletions .github/workflows/build_and_test_maxtext.yml
@@ -199,6 +199,7 @@ jobs:
base_image: maxtext-unit-test-tpu:py312
cloud_runner: linux-x86-ct6e-180-4tpu
pytest_marker: 'not cpu_only and not gpu_only and not integration_test and not post_training'
pytest_addopts: '--ignore=tests/post_training'
xla_python_client_mem_fraction: 0.75
tf_force_gpu_allow_growth: false
container_resource_option: "--privileged"
@@ -217,6 +218,7 @@ jobs:
base_image: maxtext-unit-test-tpu:py312
cloud_runner: linux-x86-ct6e-180-4tpu
pytest_marker: 'not cpu_only and not gpu_only and integration_test and not post_training'
pytest_addopts: '--ignore=tests/post_training'
xla_python_client_mem_fraction: 0.75
tf_force_gpu_allow_growth: false
container_resource_option: "--privileged"
28 changes: 24 additions & 4 deletions src/maxtext/configs/post_train/rl.yml
@@ -148,6 +148,7 @@ generation_configs:
num_eval_passes: 1 # Number of generation passes during evaluation
eval_corr_lst: False # If True, only include correct responses in the list during evaluation
eval_make_lst: False # If True, return a list of (question, answer, responses) during evaluation
eval_mode: "pass" # Evaluation mode ("pass" for pass@K, "maj" for majority voting maj@K, "pass_at_1" for pass@1 estimation)

# ====== Inference ======
# === Generation during GRPO training ===
@@ -190,6 +191,12 @@ reward_ratio_guess_to_answer_low: 0.0
penalty_incorrect_format: 0.0
penalty_incorrect_answer: 0.0

# ====== Configuration for math_verify Pool ======
# Global timeout (seconds) for math_verify calls across all examples in a batch
math_verify_timeout: 120
# Max worker processes for the math_verify pool. null ⇒ min(batch_size, cpu_count())
math_verify_num_procs: null

# ====== Special tokens/templates for GSM8K reasoning ======
reasoning_start_token: '<reasoning>'
reasoning_end_token: '</reasoning>'
@@ -198,10 +205,23 @@ solution_end_token: '</answer>'
chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
skip_jax_distributed_system: True

# # TODO(@mazumdera): fix this
# Dataset Configuration
dataset_name: 'gsm8k' # huggingface:open-r1/DAPO-Math-17k-Processed
eval_dataset_name: 'gsm8k' # huggingface:BytedTsinghua-SIA/AIME-2024
# ====== Dataset Configuration ======
# Supported values for dataset_name:
# ['openai/gsm8k', 'nvidia/OpenMathInstruct-2', 'nvidia/OpenMathReasoning', 'open-r1/OpenR1-Math-220k', 'bethgelab/CuratedThoughts']
#
# Scenarios:
# 1. dataset_name='openai/gsm8k' and eval_dataset_name='openai/gsm8k'
# Loads the train and test splits of GSM8K directly.
#
# 2. Same dataset for training and evaluation, other than 'openai/gsm8k': dataset_name = eval_dataset_name = <dataset>
# The dataset has no separate test split, so the training data is
# automatically split into train and test sets internally.
#
# 3. Training and evaluation on different datasets: dataset_name=<train_dataset>, eval_dataset_name=<eval_dataset>
#    Loads separate datasets for training and evaluation (e.g., train on OpenMathInstruct-2, eval on GSM8K).
dataset_name: 'openai/gsm8k'
eval_dataset_name: 'openai/gsm8k'
train_split: 'train'
eval_split: 'test'
hf_name: 'main' # subset of Hugging Face dataset
tokenizer_type: 'huggingface'
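For illustration only (not part of this diff), the three scenarios above could be handled roughly as follows with the Hugging Face datasets API; the helper name, the hf_name handling, and the 5% hold-out are assumptions rather than the actual MaxText loader:

from datasets import load_dataset

def load_train_eval(dataset_name, eval_dataset_name, hf_name="main"):
  # Hypothetical helper sketching the dataset scenarios documented in rl.yml.
  if dataset_name == eval_dataset_name == "openai/gsm8k":
    # Scenario 1: GSM8K ships its own train/test splits.
    return (load_dataset(dataset_name, hf_name, split="train"),
            load_dataset(dataset_name, hf_name, split="test"))
  if dataset_name == eval_dataset_name:
    # Scenario 2: no separate test split, so carve one out of the training data.
    splits = load_dataset(dataset_name, split="train").train_test_split(test_size=0.05)
    return splits["train"], splits["test"]
  # Scenario 3: different datasets for training and evaluation.
  return (load_dataset(dataset_name, split="train"),
          load_dataset(eval_dataset_name, split="test"))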
9 changes: 9 additions & 0 deletions src/maxtext/configs/types.py
@@ -1830,6 +1830,10 @@ class RLEvaluation(BaseModel):
False,
description="If True, return a list of (question, answer, responses) during evaluation.",
)
eval_mode: Literal["pass", "maj", "pass_at_1"] = Field(
"pass",
description="Evaluation mode to use ('pass' for pass@K, 'maj' for maj@K, 'pass_at_1' for pass@1 estimation).",
)


class Reward(BaseModel):
@@ -1847,6 +1851,11 @@ class Reward(BaseModel):
)
penalty_incorrect_format: float = Field(-0.5, description="Penalty for an incorrect format.")
penalty_incorrect_answer: float = Field(-1.0, description="Penalty for an incorrect answer.")
math_verify_timeout: int = Field(300, description="Global timeout (seconds) for math_verify calls per batch.")
math_verify_num_procs: int | None = Field(
None,
description=("Max worker processes for the math_verify pool. None ⇒ " "min(batch_size, cpu_count())."),
)


class SpecialTokens(BaseModel):
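For the math_verify_num_procs default, a minimal sketch of how None could be resolved to min(batch_size, cpu_count()); resolve_num_procs is a hypothetical helper, not an API introduced by this PR:

import os

def resolve_num_procs(num_procs: int | None, batch_size: int) -> int:
  # None means one worker per example in the batch, capped at the number of
  # CPUs visible to the process.
  if num_procs is None:
    return min(batch_size, os.cpu_count() or 1)
  return num_procs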
169 changes: 93 additions & 76 deletions src/maxtext/trainers/post_train/rl/evaluate_rl.py
@@ -16,7 +16,11 @@
"""
RL Evaluation Module.
"""
from math_verify import parse
import collections
import json
import re
from typing import Any

from tqdm.auto import tqdm
from tunix.rl.rollout.base_rollout import RolloutConfig

@@ -86,85 +90,97 @@ def generate_responses(
return multiple_call_responses


def score_responses(tmvp_config, question, responses, answer):
"""
Score a set of responses for a single question.
def _score_single(
extracted_response: str,
raw_response: str,
answers: list[str],
tmvp_config: Any,
match_format: re.Pattern[str],
) -> tuple[bool, bool, bool]:
"""Score one (extracted answer, raw response) pair. Returns (is_correct, is_partially_correct, has_correct_format)."""
has_correct_format = match_format.search(raw_response) is not None
try:
is_correct, is_partially_correct = utils_rl.check_correctness(extracted_response, answers, tmvp_config)
if tmvp_config.debug.rl:
max_logging.log(f"Result has_correct_format: {has_correct_format}")
max_logging.log(f"Result is_correct: {is_correct}")
max_logging.log(f"Result is_partially_correct: {is_partially_correct}")
except Exception as e: # pylint: disable=broad-exception-caught
is_correct, is_partially_correct = False, False
if tmvp_config.debug.rl:
max_logging.log(f"Evaluation Exception: {e} — SKIPPED")
return is_correct, is_partially_correct, has_correct_format


def score_responses(tmvp_config, question, responses, answers):
"""Score a set of responses for a single question.

Args:
tmvp_config: Configuration object
question: The evaluation question
responses: List of generated responses for this question
answer: The correct answer
answers: List of correct answers

Returns:
Tuple of (is_correct, is_partially_correct, has_correct_format); booleans for "pass"/"maj", fractions in [0, 1] for "pass_at_1"
"""
match_format = utils_rl.get_match_format_regex(tmvp_config)
answer_fallback = utils_rl.get_answer_fallback_regex(tmvp_config)

if tmvp_config.debug.rl:
max_logging.log("========================================")
max_logging.log(f"Evaluation Question: {question}")
max_logging.log(f"Evaluation Answer: {answer}")
max_logging.log(f"Evaluation Answer: {answers}")
max_logging.log(f"Evaluation Responses: {responses}")
max_logging.log("========================================")

is_correct = False
is_partially_correct = False
has_correct_format = False

for response in responses:
# Extract answer: prefer the full format match; fall back to the last
# <answer>...</answer> tag if full format match is not found, so result
# scoring is decoupled from format.
full_match = match_format.search(response)
if full_match is not None:
extracted_response = full_match.group(1)
else:
# Find the *last* occurrence of the answer tag (most likely the final answer).
fallback_matches = answer_fallback.findall(response)
extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000"
eval_mode = getattr(tmvp_config, "eval_mode", "pass")
match_format = utils_rl.get_match_format_regex(tmvp_config)
extracted_responses = [utils_rl.extract_answer(r, tmvp_config) for r in responses]

if not extracted_responses:
return False, False, False

if eval_mode == "maj":
# Extract the most frequent answer across all samples (majority vote).
counter = collections.Counter(extracted_responses)
majority = counter.most_common(1)[0][0]
if tmvp_config.debug.rl:
used = "full format" if full_match is not None else "answer-tag fallback"
max_logging.log(f"Evaluation extracted_response ({used}): {extracted_response}")

# Check exact correctness
try:
# Fix LaTeX escaping issues for both ground truth and extracted answer
norm_answer = utils_rl.fix_latex_escaping(answer)
norm_extracted = utils_rl.fix_latex_escaping(extracted_response)
# Normalize for certain datasets and parse
if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name:
norm_extracted = utils_rl.normalize_final_answer(norm_extracted).strip()
norm_answer = utils_rl.normalize_final_answer(answer).strip()
is_correct = utils_rl.math_verify_func([utils_rl.boxed(norm_answer)], [utils_rl.boxed(norm_extracted)])[0] > 0.1
if tmvp_config.debug.rl:
# is_correct is a tuple, if first value is 1.0 means it's a match;
# 0.0 means a mismatch. e.g. (0.0, (['3', '3'], ['3/5', '\\frac{3}{5}']))
max_logging.log(f"Result is_correct: {is_correct}")

val_extracted = parse(utils_rl.boxed(norm_extracted))
val_answer = parse(utils_rl.boxed(norm_answer))

# Check partial correctness if values can be extracted (within 10%)
if val_extracted and val_answer:
ratio = (val_extracted[0] + utils_rl.EPSILON) / (val_answer[0] + utils_rl.EPSILON)
is_partially_correct = 0.9 <= ratio <= 1.1

except Exception as e:
if tmvp_config.debug.rl:
max_logging.log(f"Evaluation Exception: {e}")
max_logging.log("SKIPPED")

# Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure)
if full_match is not None:
has_correct_format = True

# Early exit if all criteria are met
if is_correct and is_partially_correct and has_correct_format:
break
max_logging.log(f"Majority Response: {majority} (Count: {counter[majority]})")

return is_correct, is_partially_correct, has_correct_format
# Check the format for the majority response
has_correct_format = any(
match_format.search(responses[idx]) is not None
for idx, response in enumerate(extracted_responses)
if response == majority
)
is_correct, is_partially_correct, _ = _score_single(majority, responses[0], answers, tmvp_config, match_format)
return is_correct, is_partially_correct, has_correct_format

if eval_mode == "pass":
result = False, False, False
for extracted, response in zip(extracted_responses, responses):
result = _score_single(extracted, response, answers, tmvp_config, match_format)
# Early exit if all criteria are met
if all(result):
return result
return result

if eval_mode == "pass_at_1":
# Estimate pass@1: fraction of N samples that are correct per problem.
# Returns floats in [0, 1] instead of booleans.
scores = [
_score_single(extracted_response, response, answers, tmvp_config, match_format)
for extracted_response, response in zip(extracted_responses, responses)
]
n_samples = len(scores)
frac_correct = sum(s[0] for s in scores) / n_samples
frac_partial = sum(s[1] for s in scores) / n_samples
frac_format = sum(s[2] for s in scores) / n_samples
if tmvp_config.debug.rl:
max_logging.log(f"{frac_correct*n_samples:.0f}/{n_samples} correct")
max_logging.log(f"{frac_partial*n_samples:.0f}/{n_samples} partial")
max_logging.log(f"{frac_format*n_samples:.0f}/{n_samples} format")
return frac_correct, frac_partial, frac_format

raise ValueError(f"Unknown eval_mode: {eval_mode!r}")
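# Toy illustration (not part of this change) of how the three eval modes differ
# for one question with K=4 sampled answers and ground truth "7":
#   extracted_responses = ["7", "12", "7", "9"]
#   eval_mode="pass"      -> is_correct=True   (at least one sample matches)
#   eval_mode="maj"       -> is_correct=True   (the most frequent answer, "7", matches)
#   eval_mode="pass_at_1" -> frac_correct=0.5  (2 of 4 samples match)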


def evaluate(
@@ -210,28 +226,29 @@ def evaluate(

# Score each question-answer pair
for question, responses, answer in zip(questions, multiple_call_responses, answers):
# decode the json-encoded list of acceptable answers
answer = list(dict.fromkeys(json.loads(answer)))
is_correct, is_partially_correct, has_correct_format = score_responses(
tmvp_config=tmvp_config,
question=question,
responses=responses,
answer=answer,
answers=answer,
)

# Update counters
if is_correct:
corr += 1
if corr_lst and make_lst:
# Update counters. For "pass" and "maj" modes, scores are booleans
# (True=1, False=0). For "pass_at_1" mode, scores are floats in [0, 1]
# representing the fraction of samples correct. Using += works for both:
# bool is a subtype of int in Python, so adding True is the same as adding 1.
corr += is_correct
partially_corr += is_partially_correct
corr_format += has_correct_format

if make_lst:
if corr_lst and is_correct:
response_lst.append((question, answer, responses))
else:
if not corr_lst and make_lst:
elif not corr_lst and not is_correct:
response_lst.append((question, answer, responses))

if is_partially_correct:
partially_corr += 1

if has_correct_format:
corr_format += 1

total += 1

# Print progress every 10 items
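A tiny demonstration (illustrative only, not part of the diff) of why the += accumulation above handles both boolean and fractional scores:

corr = 0
corr += True   # bool is a subtype of int, so this adds 1 -> corr == 1
corr += 0.75   # pass_at_1 mode yields fractions, so this adds 0.75 -> corr == 1.75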