From e042c1fd5d665b7de7f10f579c919297ae22f814 Mon Sep 17 00:00:00 2001 From: Surbhi Jain Date: Thu, 26 Mar 2026 21:44:49 +0000 Subject: [PATCH] Add open-r1/OpenR1-Math-220k dataset and nvidia/OpenMathReasoning to RL and fix reward function Co-authored-by: A9isha --- .github/workflows/build_and_test_maxtext.yml | 2 + src/maxtext/configs/post_train/rl.yml | 28 +- src/maxtext/configs/types.py | 9 + .../trainers/post_train/rl/evaluate_rl.py | 169 +++--- .../post_train/rl/math_verify_pool.py | 260 ++++++++++ .../trainers/post_train/rl/train_rl.py | 287 +++++------ .../trainers/post_train/rl/utils_rl.py | 484 +++++++++++------- tests/post_training/unit/evaluate_rl_test.py | 212 ++++++++ .../unit/math_verify_pool_test.py | 183 +++++++ tests/post_training/unit/rl_utils_test.py | 122 ++--- tests/post_training/unit/train_rl_test.py | 129 ++++- 11 files changed, 1378 insertions(+), 507 deletions(-) create mode 100644 src/maxtext/trainers/post_train/rl/math_verify_pool.py create mode 100644 tests/post_training/unit/evaluate_rl_test.py create mode 100644 tests/post_training/unit/math_verify_pool_test.py diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index d54f98fbec..bf36401e50 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -199,6 +199,7 @@ jobs: base_image: maxtext-unit-test-tpu:py312 cloud_runner: linux-x86-ct6e-180-4tpu pytest_marker: 'not cpu_only and not gpu_only and not integration_test and not post_training' + pytest_addopts: '--ignore=tests/post_training' xla_python_client_mem_fraction: 0.75 tf_force_gpu_allow_growth: false container_resource_option: "--privileged" @@ -217,6 +218,7 @@ jobs: base_image: maxtext-unit-test-tpu:py312 cloud_runner: linux-x86-ct6e-180-4tpu pytest_marker: 'not cpu_only and not gpu_only and integration_test and not post_training' + pytest_addopts: '--ignore=tests/post_training' xla_python_client_mem_fraction: 0.75 tf_force_gpu_allow_growth: false container_resource_option: "--privileged" diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index f5ce74ced0..4f29702fc5 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -148,6 +148,7 @@ generation_configs: num_eval_passes: 1 # Number of generation passes during evaluation eval_corr_lst: False # If True, only include correct responses in the list during evaluation eval_make_lst: False # If True, return a list of (question, answer, responses) during evaluation +eval_mode: "pass" # Evaluation mode ("pass" for pass@K, "maj" for majority voting maj@K, "pass_at_1" for pass@1 estimation) # ====== Inference ====== # === Generation during GRPO training === @@ -190,6 +191,12 @@ reward_ratio_guess_to_answer_low: 0.0 penalty_incorrect_format: 0.0 penalty_incorrect_answer: 0.0 +# ====== Configuration for math_verify Pool ====== +# Global timeout (seconds) for math_verify calls across all examples in a batch +math_verify_timeout: 120 +# Max worker processes for the math_verify pool. 
null ⇒ min(batch_size, cpu_count()) +math_verify_num_procs: null + # ====== Special tokens/templates for GSM8K reasoning ====== reasoning_start_token: '' reasoning_end_token: '' @@ -198,10 +205,23 @@ solution_end_token: '' chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json' skip_jax_distributed_system: True -# # TODO(@mazumdera): fix this -# Dataset Configuration -dataset_name: 'gsm8k' # huggingface:open-r1/DAPO-Math-17k-Processed -eval_dataset_name: 'gsm8k' # huggingface:BytedTsinghua-SIA/AIME-2024 +# ====== Dataset Configuration ====== +# Supported values for dataset_name: +# ['openai/gsm8k', 'nvidia/OpenMathInstruct-2', 'nvidia/OpenMathReasoning', 'open-r1/OpenR1-Math-220k', 'bethgelab/CuratedThoughts'] +# +# Scenarios: +# 1. dataset_name='openai/gsm8k' and eval_dataset_name='openai/gsm8k' +# Loads the train and test splits of GSM8K directly. +# +# 2. Datasets other than 'gsm8k' with same eval: dataset_name=eval_dataset_name= +# The dataset has no separate test split, so the training data is +# automatically split into train and test sets internally. +# +# 3. Train and evaluation on different datasets: dataset_name=, eval_dataset_name= +# Loads separate dataset for training and evaluation (e.g., train on OpenMathInstruct-2, eval on GSM8K). +dataset_name: 'openai/gsm8k' +eval_dataset_name: 'openai/gsm8k' train_split: 'train' eval_split: 'test' +hf_name: 'main' # subset of Hugging Face dataset tokenizer_type: 'huggingface' diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index af809e2b42..43338ed507 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1830,6 +1830,10 @@ class RLEvaluation(BaseModel): False, description="If True, return a list of (question, answer, responses) during evaluation.", ) + eval_mode: Literal["pass", "maj", "pass_at_1"] = Field( + "pass", + description="Evaluation mode to use ('pass' for pass@K, 'maj' for maj@K, 'pass_at_1' for pass@1 estimation).", + ) class Reward(BaseModel): @@ -1847,6 +1851,11 @@ class Reward(BaseModel): ) penalty_incorrect_format: float = Field(-0.5, description="Penalty for an incorrect format.") penalty_incorrect_answer: float = Field(-1.0, description="Penalty for an incorrect answer.") + math_verify_timeout: int = Field(300, description="Global timeout (seconds) for math_verify calls per batch.") + math_verify_num_procs: int | None = Field( + None, + description=("Max worker processes for the math_verify pool. None ⇒ " "min(batch_size, cpu_count())."), + ) class SpecialTokens(BaseModel): diff --git a/src/maxtext/trainers/post_train/rl/evaluate_rl.py b/src/maxtext/trainers/post_train/rl/evaluate_rl.py index 29494b9a19..df537fc825 100644 --- a/src/maxtext/trainers/post_train/rl/evaluate_rl.py +++ b/src/maxtext/trainers/post_train/rl/evaluate_rl.py @@ -16,7 +16,11 @@ """ RL Evaluation Module. """ -from math_verify import parse +import collections +import json +import re +from typing import Any + from tqdm.auto import tqdm from tunix.rl.rollout.base_rollout import RolloutConfig @@ -86,85 +90,97 @@ def generate_responses( return multiple_call_responses -def score_responses(tmvp_config, question, responses, answer): - """ - Score a set of responses for a single question. +def _score_single( + extracted_response: str, + raw_response: str, + answers: list[str], + tmvp_config: Any, + match_format: re.Pattern[str], +) -> tuple[bool, bool, bool]: + """Score one (extracted answer, raw response) pair. 
Returns (is_correct, is_partially_correct, has_correct_format).""" + has_correct_format = match_format.search(raw_response) is not None + try: + is_correct, is_partially_correct = utils_rl.check_correctness(extracted_response, answers, tmvp_config) + if tmvp_config.debug.rl: + max_logging.log(f"Result has_correct_format: {has_correct_format}") + max_logging.log(f"Result is_correct: {is_correct}") + max_logging.log(f"Result is_partially_correct: {is_partially_correct}") + except Exception as e: # pylint: disable=broad-exception-caught + is_correct, is_partially_correct = False, False + if tmvp_config.debug.rl: + max_logging.log(f"Evaluation Exception: {e} — SKIPPED") + return is_correct, is_partially_correct, has_correct_format + + +def score_responses(tmvp_config, question, responses, answers): + """Score a set of responses for a single question. Args: tmvp_config: Configuration object question: The evaluation question responses: List of generated responses for this question - answer: The correct answer + answers: List of correct answers Returns: Tuple of (is_correct, is_partially_correct, has_correct_format) """ - match_format = utils_rl.get_match_format_regex(tmvp_config) - answer_fallback = utils_rl.get_answer_fallback_regex(tmvp_config) - if tmvp_config.debug.rl: max_logging.log("========================================") max_logging.log(f"Evaluation Question: {question}") - max_logging.log(f"Evaluation Answer: {answer}") + max_logging.log(f"Evaluation Answer: {answers}") max_logging.log(f"Evaluation Responses: {responses}") max_logging.log("========================================") - is_correct = False - is_partially_correct = False - has_correct_format = False - - for response in responses: - # Extract answer: prefer the full format match; fall back to the last - # ... tag if full format match is not found, so result - # scoring is decoupled from format. - full_match = match_format.search(response) - if full_match is not None: - extracted_response = full_match.group(1) - else: - # Find the *last* occurrence of the answer tag (most likely the final answer). 
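For orientation, the per-response scoring path shared by all three eval modes can be sketched by composing the helpers this patch adds to utils_rl.py (an illustrative sketch, not part of the patch; `score_one_response` is a hypothetical name, and the module-level `utils_rl` import already used in this file is assumed):

  def score_one_response(response, acceptable_answers, tmvp_config):
    # Format check: does the response contain the full reasoning/solution tag structure?
    match_format = utils_rl.get_match_format_regex(tmvp_config)
    has_correct_format = match_format.search(response) is not None
    # Answer check: take the last solution tag and verify it against every acceptable answer.
    extracted = utils_rl.extract_answer(response, tmvp_config)
    is_correct, is_partially_correct = utils_rl.check_correctness(extracted, acceptable_answers, tmvp_config)
    return is_correct, is_partially_correct, has_correct_format
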
- fallback_matches = answer_fallback.findall(response) - extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000" + eval_mode = getattr(tmvp_config, "eval_mode", "pass") + match_format = utils_rl.get_match_format_regex(tmvp_config) + extracted_responses = [utils_rl.extract_answer(r, tmvp_config) for r in responses] + + if not extracted_responses: + return False, False, False + + if eval_mode == "maj": + # extract the single-most frequent response + counter = collections.Counter(extracted_responses) + majority = counter.most_common(1)[0][0] if tmvp_config.debug.rl: - used = "full format" if full_match is not None else "answer-tag fallback" - max_logging.log(f"Evaluation extracted_response ({used}): {extracted_response}") - - # Check exact correctness - try: - # Fix LaTeX escaping issues for both ground truth and extracted answer - norm_answer = utils_rl.fix_latex_escaping(answer) - norm_extracted = utils_rl.fix_latex_escaping(extracted_response) - # Normalize Normalize for certain datasets and parse - if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name: - norm_extracted = utils_rl.normalize_final_answer(norm_extracted).strip() - norm_answer = utils_rl.normalize_final_answer(answer).strip() - is_correct = utils_rl.math_verify_func([utils_rl.boxed(norm_answer)], [utils_rl.boxed(norm_extracted)])[0] > 0.1 - if tmvp_config.debug.rl: - # is_correct is a tuple, if first value is 1.0 means it's a match; - # 0.0 means a mismatch. e.g. (0.0, (['3', '3'], ['3/5', '\\frac{3}{5}'])) - max_logging.log(f"Result is_correct: {is_correct}") - - val_extracted = parse(utils_rl.boxed(norm_extracted)) - val_answer = parse(utils_rl.boxed(norm_answer)) - - # Check partial correctness if values can be extracted (within 10%) - if val_extracted and val_answer: - ratio = (val_extracted[0] + utils_rl.EPSILON) / (val_answer[0] + utils_rl.EPSILON) - is_partially_correct = 0.9 <= ratio <= 1.1 - - except Exception as e: - if tmvp_config.debug.rl: - max_logging.log(f"Evaluation Exception: {e}") - max_logging.log("SKIPPED") - - # Check format correctness (requires the full ...... structure) - if full_match is not None: - has_correct_format = True - - # Early exit if all criteria are met - if is_correct and is_partially_correct and has_correct_format: - break + max_logging.log(f"Majority Response: {majority} (Count: {counter[majority]})") - return is_correct, is_partially_correct, has_correct_format + # Check the format for the majority response + has_correct_format = any( + match_format.search(responses[idx]) is not None + for idx, response in enumerate(extracted_responses) + if response == majority + ) + is_correct, is_partially_correct, _ = _score_single(majority, responses[0], answers, tmvp_config, match_format) + return is_correct, is_partially_correct, has_correct_format + + if eval_mode == "pass": + result = False, False, False + for extracted, response in zip(extracted_responses, responses): + result = _score_single(extracted, response, answers, tmvp_config, match_format) + # Early exit if all criteria are met + if all(result): + return result + return result + + if eval_mode == "pass_at_1": + # Estimate pass@1: fraction of N samples that are correct per problem. + # Returns floats in [0, 1] instead of booleans. 
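    # Example (illustrative values only): with per-response correctness
    # [False, True, False, True] for one problem, the three eval modes
    # aggregate the same per-response results differently:
    #   eval_mode="pass"      -> keep scoring until one response meets all criteria (pass@K-style)
    #   eval_mode="pass_at_1" -> fraction of correct responses, here 2/4 = 0.5
    #   eval_mode="maj"       -> grade only the most frequently extracted answer (maj@K)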
+ scores = [ + _score_single(extracted_response, response, answers, tmvp_config, match_format) + for extracted_response, response in zip(extracted_responses, responses) + ] + n_samples = len(scores) + frac_correct = sum(s[0] for s in scores) / n_samples + frac_partial = sum(s[1] for s in scores) / n_samples + frac_format = sum(s[2] for s in scores) / n_samples + if tmvp_config.debug.rl: + max_logging.log(f"{frac_correct*n_samples:.0f}/{n_samples} correct") + max_logging.log(f"{frac_partial*n_samples:.0f}/{n_samples} partial") + max_logging.log(f"{frac_format*n_samples:.0f}/{n_samples} format") + return frac_correct, frac_partial, frac_format + + raise ValueError(f"Unknown eval_mode: {eval_mode!r}") def evaluate( @@ -210,28 +226,29 @@ def evaluate( # Score each question-answer pair for question, responses, answer in zip(questions, multiple_call_responses, answers): + # decode the json-encoded list of acceptable answers + answer = list(dict.fromkeys(json.loads(answer))) is_correct, is_partially_correct, has_correct_format = score_responses( tmvp_config=tmvp_config, question=question, responses=responses, - answer=answer, + answers=answer, ) - # Update counters - if is_correct: - corr += 1 - if corr_lst and make_lst: + # Update counters. For "pass" and "maj" modes, scores are booleans + # (True=1, False=0). For "pass_at_1" mode, scores are floats in [0, 1] + # representing the fraction of samples correct. Using += works for both: + # bool is a subtype of int in Python, so True += is the same as += 1. + corr += is_correct + partially_corr += is_partially_correct + corr_format += has_correct_format + + if make_lst: + if corr_lst and is_correct: response_lst.append((question, answer, responses)) - else: - if not corr_lst and make_lst: + elif not corr_lst and not is_correct: response_lst.append((question, answer, responses)) - if is_partially_correct: - partially_corr += 1 - - if has_correct_format: - corr_format += 1 - total += 1 # Print progress every 10 items diff --git a/src/maxtext/trainers/post_train/rl/math_verify_pool.py b/src/maxtext/trainers/post_train/rl/math_verify_pool.py new file mode 100644 index 0000000000..2eeadf1747 --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/math_verify_pool.py @@ -0,0 +1,260 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Process-isolated, timeout-bounded math_verify grader for RL reward computation. + +This module provides a persistent multiprocessing pool that grades model-generated +math answers against ground-truth solutions using the math_verify library. It is +used during GRPO/GSPO training to compute rewards for each example in a batch. + +- Workers are spawned as separate CPU-only processes so that JAX/XLA inside + the trainer never competes with grader workers for accelerator resources. +- math_verify and other heavy dependencies are only imported once per + worker at startup via `silent_worker_init`, avoiding multi-second cold-start + latency on every grading call. 
+- The pool is module-level and persistent across training steps. Workers are + recycled after `_MAX_TASKS_PER_CHILD` tasks to bound sympy's internal cache + growth in long-running training jobs. +- Grading is subject to a global wall-clock timeout across all examples in a + batch (configured via `math_verify_timeout` in rl.yml). Items that do not + complete within the deadline are dropped; their scores remain at the + pre-call default. If any items time out, the pool is torn down and recreated + on the next call to recover from stuck threads. +- Shutdown escalates from SIGTERM to SIGKILL to handle workers blocked inside + native C extensions that ignore Python signals. +""" + +import atexit +import itertools +import multiprocessing +import os +import threading +import time +from typing import Any, Callable, Optional + +# Module-level persistent pool state. +_POOL = None +_POOL_NUM_PROCS = None +_DEFAULT_MAX_PROCS = 8 +# Recycle a worker after this many tasks. Bounds sympy's internal cache +# growth in long-lived workers without paying the cold-start cost too often. +_MAX_TASKS_PER_CHILD = 100 +# SIGTERM grace period before escalating to SIGKILL on stuck workers. +_TERMINATE_GRACE_SECONDS = 2.0 + + +def silent_worker_init() -> None: + """Pool initializer: hide accelerators and pre-import math_verify. + + Runs once per worker immediately after spawn. The env vars must be set + before math_verify load. Then we eagerly import math_verify so the first real + grading call doesn't pay multi-second cold-start latency inside the + per-item timeout. + """ + os.environ["JAX_PLATFORMS"] = "cpu" + os.environ["TPU_VISIBLE_DEVICES"] = "" + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + # Quiet TF / TPU log noise in workers. Override unconditionally — the + # parent trainer process often sets these to 0 for its own debugging, + # and `setdefault` would inherit that loud value into every spawned + # grader worker. We want the workers silent regardless. + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + os.environ["TPU_MIN_LOG_LEVEL"] = "3" + os.environ["TPU_STDERR_LOG_LEVEL"] = "3" + os.environ["GRPC_VERBOSITY"] = "ERROR" + try: + # Eagerly import the heavy grader stack so all subsequent + # `verify_math_worker` calls in this worker are fast. + import math_verify # pylint: disable=import-outside-toplevel,unused-import + from math_verify import parse, verify # pylint: disable=import-outside-toplevel,unused-import + from math_verify.parser import ( # pylint: disable=import-outside-toplevel,unused-import + ExprExtractionConfig, + LatexExtractionConfig, + ) + from sympy.parsing import sympy_parser # pylint: disable=import-outside-toplevel,unused-import + from sympy import Basic, MatrixBase # pylint: disable=import-outside-toplevel,unused-import + except Exception: # pylint: disable=broad-exception-caught + # If the import fails, individual jobs will fail too and return 0.0; + # don't crash the worker at startup. + pass + + +def are_equal_under_sympy(gold: Any, prediction: Any) -> bool: + """Returns True if gold and prediction are symbolically equal using SymPy. + + Parses both values as symbolic expressions with implicit multiplication + support (e.g. '2x' == '2*x'). Returns False if parsing fails or the + expressions differ. 
+ """ + from sympy.parsing import sympy_parser # pylint: disable=import-outside-toplevel + + try: + gold_expr = sympy_parser.parse_expr(str(gold), evaluate=False) + pred_expr = sympy_parser.parse_expr(str(prediction), evaluate=False) + if gold_expr == pred_expr: + return True + except Exception: # pylint: disable=broad-exception-caught + pass + return False + + +def verify_math_worker(golds: list[str], predictions: list[str]) -> float: + """Worker-side math_verify grader.""" + try: # pylint: disable=too-many-nested-blocks + from math_verify import parse, verify # pylint: disable=import-outside-toplevel + from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig # pylint: disable=import-outside-toplevel + from sympy import Basic, MatrixBase # pylint: disable=import-outside-toplevel + + gold_targets = (ExprExtractionConfig(), LatexExtractionConfig()) + pred_targets = (ExprExtractionConfig(), LatexExtractionConfig()) + + extracted_predictions = list( + itertools.chain.from_iterable(parse(pred, pred_targets, parsing_timeout=None) for pred in predictions) + ) + extracted_golds = list( + itertools.chain.from_iterable(parse(gold, gold_targets, parsing_timeout=None) for gold in golds) + ) + if not extracted_predictions or not extracted_golds: + return 0.0 + + for gold in extracted_golds: + for pred in extracted_predictions: + if isinstance(gold, (Basic, MatrixBase)) and isinstance(pred, (Basic, MatrixBase)): + if are_equal_under_sympy(gold, pred): + return 1.0 + + if "**" not in str(gold) and "**" not in str(pred): + if verify(gold, pred, timeout_seconds=None): + return 1.0 + else: + if verify(gold, pred, timeout_seconds=None): + return 1.0 + return 0.0 + except Exception: # pylint: disable=broad-exception-caught + return 0.0 + + +def _shutdown_pool() -> None: + """Tear down the persistent pool, if any. Safe to call repeatedly. + + SIGTERM first with a grace period, then SIGKILL any survivors. SIGKILL is + handled by the kernel, so workers stuck in sympy/FLINT C extensions (which + ignore Python signals) are still reaped. Without this, `Pool.terminate()` + internally calls `p.join()` on each worker and blocks forever on the stuck + ones. + """ + global _POOL, _POOL_NUM_PROCS + pool = _POOL + _POOL = None + _POOL_NUM_PROCS = None + if pool is None: + return + try: + workers = list(getattr(pool, "_pool", [])) + for w in workers: + if w.is_alive(): + w.terminate() + deadline = time.monotonic() + _TERMINATE_GRACE_SECONDS + for w in workers: + remaining = max(0.0, deadline - time.monotonic()) + w.join(timeout=remaining) + for w in workers: + if w.is_alive(): + w.kill() + w.join(timeout=1.0) + # Workers SIGKILLed mid-write leave the outqueue lock orphaned, so + # pool.terminate() / pool.join() block forever on the internal + # _result_handler / _task_handler threads. Run terminate in a daemon + # thread with a bounded wait: pool._state flips to TERMINATE so the + # worker-handler stops spawning replacements, and we return even if the + # handler threads never unblock. Those threads leak, but they are daemon + # and cheap; a stuck trainer is not. 
+ t = threading.Thread(target=pool.terminate, daemon=True) + t.start() + t.join(timeout=_TERMINATE_GRACE_SECONDS) + except Exception: # pylint: disable=broad-exception-caught + pass + + +def _get_pool(num_procs: int) -> multiprocessing.pool.Pool: + """Return the persistent pool, creating or resizing it as needed.""" + global _POOL, _POOL_NUM_PROCS + if _POOL is None or _POOL_NUM_PROCS != num_procs: + _shutdown_pool() + ctx = multiprocessing.get_context("spawn") + _POOL = ctx.Pool( + processes=num_procs, + initializer=silent_worker_init, + maxtasksperchild=_MAX_TASKS_PER_CHILD, + ) + _POOL_NUM_PROCS = num_procs + return _POOL + + +# ensures global worker pool is cleanly shut down when program finishes execution +atexit.register(_shutdown_pool) + + +def math_verify_pool( + trainer_config: Any, + items: list[tuple[int, list[str], list[str]]], + scores: list[float], + timeout: float = 300, + num_procs: Optional[int] = None, + log_fn: Optional[Callable[[str], None]] = None, +) -> list[float]: + """Grade a batch of (idx, golds, predictions) items in spawned CPU workers. + + Uses a persistent module-level pool. The first call pays the spawn + + math_verify-import cost; subsequent calls reuse warm workers and grade. + """ + if not items: + return scores + + cpu_count = multiprocessing.cpu_count() + if num_procs is None: + num_procs = min(_DEFAULT_MAX_PROCS, len(items), cpu_count) + else: + num_procs = max(1, min(num_procs, len(items), cpu_count)) + + cnt = 0 + pool = _get_pool(num_procs) + active_jobs = [(idx, pool.apply_async(verify_math_worker, (golds, predictions))) for (idx, golds, predictions) in items] + start_time = time.time() + while active_jobs and (time.time() - start_time < timeout): + # Iterate backwards to safely remove items from the list without skipping elements + for i in range(len(active_jobs) - 1, -1, -1): + idx, job = active_jobs[i] + if job.ready(): + try: + # .get(0) returns immediately since ready() was true + score = job.get(0) + if score > 0.0: + scores[idx] = max(scores[idx], trainer_config.reward_exact_answer) + cnt += 1 + except Exception as e: # pylint: disable=broad-exception-caught + if log_fn: + log_fn(f"math_verify_pool failed ({e}) for idx: {idx}") + active_jobs.pop(i) + + # Small sleep to prevent high CPU usage during the loop + time.sleep(0.1) + + if log_fn: + log_fn(f"math_verify_pool: Processed {cnt}/{len(items)} items ({len(active_jobs)} timed out).") + if len(active_jobs) > 0: + _shutdown_pool() + return scores diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 2ac465d7cf..6b71768f91 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -45,15 +45,15 @@ from __future__ import annotations from functools import wraps -from typing import Sequence +from typing import Any, Optional, Sequence +import datasets import grain import jax import json import logging import os import pathwaysutils -import tensorflow_datasets as tfds from absl import app from absl import logging as absl_logging @@ -67,8 +67,7 @@ from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner from tunix.sft import metrics_logger, profiler -# for vLLM we can skip JAX precompilation with this flag, it makes startup faster -os.environ["SKIP_JAX_PRECOMPILE"] = "1" +os.environ["TOKENIZERS_PARALLELISM"] = "0" from maxtext.configs import pyconfig from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR @@ -79,52 +78,24 @@ def get_dataset( - model_tokenizer, tmvp_config, 
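For reference, the rewritten get_dataset in this file boils down to a direct datasets.load_dataset call; with the rl.yml defaults (dataset_name='openai/gsm8k', hf_name='main', train_split='train') the default path is roughly this sketch (not part of the patch):

  import datasets

  data = datasets.load_dataset("openai/gsm8k", name="main", split="train")
  print(len(data))  # number of GSM8K training examples
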
data_dir, split="train", data_files=None, dataset_name=None + tmvp_config: Any, + split: str = "train", + data_files: Optional[str] = None, + dataset_name: Optional[str] = None, ) -> grain.MapDataset: """Download data""" - if not os.path.exists(data_dir): - os.makedirs(data_dir) - - if dataset_name is None: - raise ValueError("dataset_name must be provided") - - if dataset_name.startswith("huggingface:"): - import datasets # pylint: disable=import-outside-toplevel - - if data_files is None: - hf_dataset_name = dataset_name.replace("huggingface:", "") - data = datasets.load_dataset(hf_dataset_name, split=split, cache_dir=data_dir) - if tmvp_config.debug.rl: - max_logging.log(f"Loaded Hugging Face dataset {hf_dataset_name} with split {split}. Size: {len(data)}") - else: # data_files have been provided, useful for using slices of large datasets like nvidia/OpenMathInstruct-2 - data = datasets.load_dataset( - "parquet", - data_files={tmvp_config.train_split: data_files}, - split=split, - cache_dir=data_dir, - ) - else: - builder_kwargs = {"file_format": tfds.core.FileFormat.ARRAY_RECORD} - data = tfds.data_source( - dataset_name, + if data_files is None: + data = datasets.load_dataset(dataset_name, name=tmvp_config.hf_name, split=split) + else: # data_files have been provided, useful for using slices of large datasets like nvidia/OpenMathInstruct-2 + data = datasets.load_dataset( + "parquet", + data_files={split: data_files}, split=split, - data_dir=data_dir, - builder_kwargs=builder_kwargs, - download=True, ) + if tmvp_config.debug.rl: + max_logging.log(f"Loaded Hugging Face dataset {dataset_name} with split {split}. Size: {len(data)}") - template_config = load_data_template_from_file(tmvp_config.chat_template_path) - if template_config is None: - raise ValueError( - f"Chat template is required for processing dataset but failed to load from {tmvp_config.chat_template_path}" - ) - - loaded_dataset = ( - grain.MapDataset.source(data) - .shuffle(seed=tmvp_config.data_shuffle_seed) - .map(lambda x: utils_rl.process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x)) - ) - return loaded_dataset + return data def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices): @@ -191,54 +162,60 @@ def get_max_train_steps(trainer_config): ) -def prepare_datasets(trainer_config, model_tokenizer): - """Setup and return train and test datasets.""" - home = os.path.expanduser("~") + "/" - train_data_dir = f"{home}/data/train" - test_data_dir = f"{home}/data/test" - if not os.path.exists(train_data_dir): - os.makedirs(train_data_dir) - if not os.path.exists(test_data_dir): - os.makedirs(test_data_dir) - - # Load datasets - if trainer_config.dataset_name == "huggingface:nvidia/OpenMathInstruct-2": - import datasets # pylint: disable=import-outside-toplevel - - def prepare_openinstructmath2_dataset( - split: str = "train_1M", - seed: int = 42, - test_size: float = 0.05, - ): - """Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split.""" - max_logging.log( - "WARNING: For reproducible experiments, preprocess the dataset once and " - "define your own HfDataset subclass that directly uses the preprocessed datasets." 
- ) +def prepare_train_and_eval_dataset( + trainer_config: Any, + test_size: float = 0.05, +) -> dict[str, datasets.Dataset]: + """Load and split the dataset into train and validation sets using HF's train_test_split.""" + max_logging.log( + "WARNING: For reproducible experiments, preprocess the dataset once and " + "define your own HfDataset subclass that directly uses the preprocessed datasets." + ) - # Load the original dataset - original_ds = datasets.load_dataset( - "parquet", - data_files={trainer_config.train_split: trainer_config.hf_train_files}, - split=split, - cache_dir=train_data_dir, - ) + original_ds = get_dataset( + trainer_config, + split=trainer_config.train_split, + data_files=trainer_config.hf_train_files, + dataset_name=trainer_config.dataset_name, + ) - # Split into train and validation sets using HF's train_test_split - split_ds = original_ds.train_test_split(test_size=test_size, seed=seed) + if "OpenMathReasoning" in trainer_config.dataset_name: + original_ds = original_ds.filter(lambda x: x.get("problem_type") == "has_answer_extracted") - return { - "train": split_ds["train"], - "validation": split_ds["test"], - } + # Split into train and validation sets using HF's train_test_split + split_ds = original_ds.train_test_split(test_size=test_size, seed=trainer_config.data_shuffle_seed) + + return { + "train": split_ds["train"], + "validation": split_ds["test"], + } - split_name = trainer_config.train_split if trainer_config.train_split != "train" else "train_1M" - splits = prepare_openinstructmath2_dataset(split=split_name) - template_config = load_data_template_from_file(trainer_config.chat_template_path) - if template_config is None: - raise ValueError( - f"Chat template is required for processing dataset but failed to load from {trainer_config.chat_template_path}" - ) + +def prepare_datasets( + trainer_config: Any, + model_tokenizer: AutoTokenizer, +) -> tuple[grain.IterDataset, grain.IterDataset | None]: + """Setup and return train and test datasets.""" + template_config = load_data_template_from_file(trainer_config.chat_template_path) + if template_config is None: + raise ValueError( + f"Chat template is required for processing dataset but failed to load from {trainer_config.chat_template_path}" + ) + + # Prepare train and test data from training data for certain datasets + eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None) + test_dataset = None + if ( + trainer_config.dataset_name + in [ + "nvidia/OpenMathInstruct-2", + "nvidia/OpenMathReasoning", + "open-r1/OpenR1-Math-220k", + "bethgelab/CuratedThoughts", + ] + and eval_dataset_name == trainer_config.dataset_name + ): + splits = prepare_train_and_eval_dataset(trainer_config) train_dataset = ( grain.MapDataset.source(splits["train"]) @@ -250,8 +227,28 @@ def prepare_openinstructmath2_dataset( ) ) - test_dataset = ( - grain.MapDataset.source(splits["validation"]) + if trainer_config.num_test_batches > 0: + test_dataset = ( + grain.MapDataset.source(splits["validation"]) + .shuffle(seed=trainer_config.data_shuffle_seed) + .map( + lambda x: utils_rl.process_data( + trainer_config.dataset_name, model_tokenizer, template_config, trainer_config, x + ) + ) + ) + else: + if not eval_dataset_name: + eval_dataset_name = trainer_config.dataset_name + + train_dataset = get_dataset( + trainer_config, + split=trainer_config.train_split, + data_files=trainer_config.hf_train_files, + dataset_name=trainer_config.dataset_name, + ) + train_dataset = ( + grain.MapDataset.source(train_dataset) 
.shuffle(seed=trainer_config.data_shuffle_seed) .map( lambda x: utils_rl.process_data( @@ -259,28 +256,19 @@ def prepare_openinstructmath2_dataset( ) ) ) - else: - train_dataset = get_dataset( - model_tokenizer, - trainer_config, - train_data_dir, - trainer_config.train_split, - data_files=trainer_config.hf_train_files, - dataset_name=trainer_config.dataset_name, - ) - eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None) - if not eval_dataset_name: - eval_dataset_name = trainer_config.dataset_name - - test_dataset = get_dataset( - model_tokenizer, - trainer_config, - test_data_dir, - trainer_config.eval_split, - data_files=trainer_config.hf_eval_files, - dataset_name=eval_dataset_name, - ) + if trainer_config.num_test_batches > 0: + test_dataset = get_dataset( + trainer_config, + split=trainer_config.eval_split, + data_files=trainer_config.hf_eval_files, + dataset_name=eval_dataset_name, + ) + test_dataset = ( + grain.MapDataset.source(test_dataset) + .shuffle(seed=trainer_config.data_shuffle_seed) + .map(lambda x: utils_rl.process_data(eval_dataset_name, model_tokenizer, template_config, trainer_config, x)) + ) def _filter_long_prompts(x): tokens = model_tokenizer.tokenize(x["prompts"]) @@ -300,15 +288,15 @@ def _use_raw_prompt(x): dataset_size = int(trainer_config.num_batches * trainer_config.batch_size * trainer_config.train_fraction) train_dataset = train_dataset[:dataset_size] train_dataset = train_dataset.repeat(trainer_config.num_epoch) - train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size) - test_dataset = test_dataset.filter(_filter_long_prompts) - test_dataset = test_dataset[ - trainer_config.test_batch_start_index : trainer_config.num_test_batches * trainer_config.batch_size - ] + if trainer_config.num_test_batches > 0: + test_dataset = test_dataset.filter(_filter_long_prompts) + test_dataset = test_dataset[ + trainer_config.test_batch_start_index : trainer_config.num_test_batches * trainer_config.batch_size + ] + test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size) - test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size) return train_dataset, test_dataset @@ -469,8 +457,6 @@ def _reward_fn(**kwargs): reward_fns = [ # type: ignore make_reward_fn(utils_rl.match_format_exactly), make_reward_fn(utils_rl.match_format_approximately), - # TODO(atwigg): comment out to simplify reward and overlap with check_numbers - make_reward_fn(utils_rl.check_answer), make_reward_fn(utils_rl.check_numbers), ] @@ -573,14 +559,17 @@ def rl_train(argv: Sequence[str], kwargs: dict): train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer) if trainer_config.debug.rl: + max_logging.log("Train dataset samples:") for i, ele in enumerate(train_dataset): if i >= 5: break pprint(ele) - for i, ele in enumerate(test_dataset): - if i >= 5: - break - pprint(ele) + if trainer_config.num_test_batches > 0: + max_logging.log("Test dataset samples:") + for i, ele in enumerate(test_dataset): + if i >= 5: + break + pprint(ele) if trainer_config.debug.rl: max_logging.log("Reference Model initialized successfully") @@ -604,17 +593,22 @@ def rl_train(argv: Sequence[str], kwargs: dict): max_train_steps, ) - # Before we train the model, let's evaluate the model on the test set so we can - # see the improvement post training. 
- (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate( - trainer_config, - test_dataset, - rl_cluster=rl_cluster, - num_passes=trainer_config.num_eval_passes, - corr_lst=trainer_config.eval_corr_lst, - make_lst=trainer_config.eval_make_lst, - ) - max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") + # Run evaluation before training + if trainer_config.num_test_batches > 0: + # Update vllm with model parameters from checkpoint + rl_cluster.rollout.update_params(nnx.state(actor_model)) + + (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate( + trainer_config, + test_dataset, + rl_cluster=rl_cluster, + num_passes=trainer_config.num_eval_passes, + corr_lst=trainer_config.eval_corr_lst, + make_lst=trainer_config.eval_make_lst, + ) + max_logging.warning( + f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%" + ) # Start training if trainer_config.load_checkpoint_only_once: @@ -634,16 +628,19 @@ def rl_train(argv: Sequence[str], kwargs: dict): max_logging.warning("RL Training Completed Successfully!") - # Let's evaluate our model! - (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate( - trainer_config, - test_dataset, - rl_cluster=rl_cluster, - num_passes=trainer_config.num_eval_passes, - corr_lst=trainer_config.eval_corr_lst, - make_lst=trainer_config.eval_make_lst, - ) - max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") + # Run evaluation after training + if trainer_config.num_test_batches > 0: + (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate( + trainer_config, + test_dataset, + rl_cluster=rl_cluster, + num_passes=trainer_config.num_eval_passes, + corr_lst=trainer_config.eval_corr_lst, + make_lst=trainer_config.eval_make_lst, + ) + max_logging.warning( + f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%" + ) def main(argv: Sequence[str], kwargs: dict = None) -> None: diff --git a/src/maxtext/trainers/post_train/rl/utils_rl.py b/src/maxtext/trainers/post_train/rl/utils_rl.py index 3ff752bc30..9e1f115f2a 100644 --- a/src/maxtext/trainers/post_train/rl/utils_rl.py +++ b/src/maxtext/trainers/post_train/rl/utils_rl.py @@ -14,44 +14,48 @@ # pylint: disable=bare-except, consider-using-generator, chained-comparison, broad-exception-caught """RL Utils Module.""" +import itertools +import json import re import uuid +from typing import Any, Callable, Optional from etils import epath import optax import numpy as np - -from math_verify.errors import TimeoutException -from math_verify.metric import math_metric from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig from math_verify import parse from tunix.rl.agentic.parser.chat_template_parser import parser as agentic_chat_template_parser - -# initialize math_verify_func once -math_verify_func = math_metric( - gold_extraction_target=(LatexExtractionConfig(),), - pred_extraction_target=( - ExprExtractionConfig(), - LatexExtractionConfig(), - ), -) +from maxtext.trainers.post_train.rl.math_verify_pool import math_verify_pool, verify_math_worker +from maxtext.utils import max_logging -def boxed(x): - """Wraps the input string in a LaTeX boxed command if it's not already wrapped.""" - return "\\boxed{" + x + "}" if not x.startswith("\\boxed{") else x +EPSILON = 1e-6 +FALLBACK_ANSWER = "-1000000" -EPSILON = 1e-6 # Constants for 
normalization SUBSTITUTIONS = [ + # Collapse double backslashes first so subsequent rules see canonical form + # (mirrors Tunix `_strip_string` line 116). + ("\\\\", "\\"), + # Tunix `_strip_string` lines 120-121: tfrac/dfrac → frac. + ("\\tfrac", "\\frac"), + ("\\dfrac", "\\frac"), ("an ", ""), ("a ", ""), (".$", "$"), ("\\$", ""), (r"\ ", ""), + # Tunix `_normalize` lines 281-282: set-style answers. + (" or ", ","), + (" and ", ","), + # Tunix `_normalize` lines 284-286: scale words. + ("million", "*10^6"), + ("billion", "*10^9"), + ("trillion", "*10^12"), (" ", ""), ("mbox", "text"), (",\\text{and}", ","), @@ -59,28 +63,44 @@ def boxed(x): ("\\text{m}", "\\text{}"), ] +UNITS = [ + "yard", + "foot", + "feet", + "mile", + "day", + "week", + "month", + "year", + "hour", + "minute", + "second", + "centimeter", + "meter", + "cm", + "mm", + "km", + "inch", + "degree", + "pound", + "cent", + "mph", +] + REMOVED_EXPRESSIONS = [ + "\\left", + "\\right", + "\\!", "square", "ways", "integers", "dollars", - "mph", - "inches", - "hours", - "km", "units", "\\ldots", "sue", "points", - "feet", - "minutes", "digits", - "cents", - "degrees", - "cm", "gm", - "pounds", - "meters", "meals", "edges", "students", @@ -104,9 +124,43 @@ def boxed(x): ] -# Let's define a RegEx for checking whether the format matches. -# -def get_match_format_regex(tmvp_config): +def math_verify_func( + items: list[tuple[int, list[str], list[str]]], + scores: list[float], + timeout: float = 300, + trainer_config: Optional[Any] = None, +) -> list[float]: + """Verifies a batch of math problems, handling timeouts and exceptions. + + Dispatches to a spawn-based multiprocessing pool (`math_verify_pool`) + so that hung sympy calls inside `math_verify` can be + killed via `pool.terminate()` and so the workers cannot contend for the + trainer's TPU. + """ + if not items: + return scores + + num_procs = None + if trainer_config is not None: + timeout = getattr(trainer_config, "math_verify_timeout", timeout) + num_procs = getattr(trainer_config, "math_verify_num_procs", None) + + return math_verify_pool( + trainer_config, + items, + scores, + timeout=timeout, + num_procs=num_procs, + log_fn=max_logging.log, + ) + + +def boxed(x: str) -> str: + """Wraps the input string in a LaTeX boxed command if it's not already wrapped.""" + return "\\boxed{" + x + "}" if not x.startswith("\\boxed{") else x + + +def get_match_format_regex(tmvp_config: Any) -> re.Pattern[str]: """Returns a compiled regex to extract the answer from a completion.""" match_format = re.compile( ( @@ -123,7 +177,7 @@ def get_match_format_regex(tmvp_config): return match_format -def get_answer_fallback_regex(tmvp_config): +def get_answer_fallback_regex(tmvp_config: Any) -> re.Pattern[str]: """Returns a compiled regex that finds the *last* answer tag in a completion. Used as a fallback when the full ...... @@ -136,7 +190,7 @@ def get_answer_fallback_regex(tmvp_config): ) -def match_format_exactly(prompts, completions, tmvp_config, **kargs): +def match_format_exactly(prompts: list[str], completions: list[str], tmvp_config: Any, **kargs: Any) -> list[float]: """ Give the model a reward of tmvp_config.reward_exact_format_match points if the format matches exactly. 
""" @@ -152,7 +206,7 @@ def match_format_exactly(prompts, completions, tmvp_config, **kargs): return scores -def match_format_approximately(prompts, completions, tmvp_config, **kargs): +def match_format_approximately(prompts: list[str], completions: list[str], tmvp_config: Any, **kargs: Any) -> list[float]: """ We also reward the model if the format of the output matches partially. """ @@ -197,14 +251,24 @@ def normalize_final_answer(final_answer: str) -> str: """ final_answer = final_answer.split("=")[-1] + # Inject implicit mixed numbers BEFORE the substitutions strip spaces + # (mirrors Tunix `_inject_implicit_mixed_number`): "7 3/4" -> "7+3/4". + final_answer = re.sub(r"([0-9]) +([0-9])", r"\1+\2", final_answer) + # Apply substitutions and removals for before, after in SUBSTITUTIONS: final_answer = final_answer.replace(before, after) + for unit in UNITS: + final_answer = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", final_answer) for expr in REMOVED_EXPRESSIONS: final_answer = final_answer.replace(expr, "") # Extract and normalize LaTeX math - final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub( + r".*?(\d+)?\s*\$\s*(\d+)?\s*(\\frac\{.*?\}\{.*?\}|\d+/\d+)\s*\$.*", + lambda m: f"${w}{m.group(3)}$" if (w := (m.group(1) or m.group(2))) else f"${m.group(3)}$", + final_answer, + ) final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) @@ -220,6 +284,26 @@ def normalize_final_answer(final_answer: str) -> str: final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) final_answer = final_answer.replace("$", "") + # Leading-zero fixups (Tunix `_strip_string` lines 143-149). Spaces have + # already been stripped above, so we only need the start-of-string and + # post-`{` cases. + if final_answer.startswith("."): + final_answer = "0" + final_answer + final_answer = final_answer.replace("{.", "{0.") + + # Strip a single layer of outer braces (Tunix `_normalize` lines 309-310): + # "{42}" -> "42". + if len(final_answer) >= 2 and final_answer[0] == "{" and final_answer[-1] == "}": + final_answer = final_answer[1:-1] + + # Integer-float collapse (Tunix `_normalize` lines 313-314): "2.0" -> "2". + try: + f = float(final_answer) + if abs(f - round(f)) < 1e-7: + final_answer = str(int(round(f))) + except (ValueError, OverflowError): + pass + # Normalize numbers if final_answer.replace(",", "").isdigit(): final_answer = final_answer.replace(",", "") @@ -227,92 +311,13 @@ def normalize_final_answer(final_answer: str) -> str: return final_answer -def check_answer(prompts, completions, answer, tmvp_config, **kargs): - """ - Reward the model if the answer is correct. A reward is also given if the answer - does not match exactly, i.e., based on how close the answer is to the correct - value. 
- """ - match_format = get_match_format_regex(tmvp_config) - answer_fallback = get_answer_fallback_regex(tmvp_config) - - extracted_responses = [] - for c in completions: - full_match = match_format.search(c) - if full_match is not None: - extracted_responses.append(full_match.group(1)) - else: - fallback_matches = answer_fallback.findall(c) - extracted_responses.append(fallback_matches[-1].strip() if fallback_matches else None) - - scores = [] - for guess, true_answer in zip(extracted_responses, answer): - score = 0 - if guess is None: - scores.append(0) - continue - # Normalize for certain datasets - if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name: - guess = normalize_final_answer(guess) - true_answer = normalize_final_answer(true_answer) - # Try math_verify first for robust comparison - verified_correct = False - mv_output = None - true_answer_fixed = true_answer - guess_fixed = guess - try: - # Fix LaTeX escaping issues for both ground truth and extracted answer - true_answer_fixed = fix_latex_escaping(true_answer) - guess_fixed = fix_latex_escaping(guess) - - mv_output = math_verify_func([boxed(true_answer_fixed)], [boxed(guess_fixed)]) - if mv_output and mv_output[0] > 0.1: - verified_correct = True - except (TimeoutException, Exception): - pass - - # Correct answer gets tmvp_config.reward_exact_answer points! - if guess == true_answer: - score += tmvp_config.reward_exact_answer - # Give credit if spaces are seen but otherwise the answers match (useful for simple datasets like gsm8k) - elif guess.strip() == true_answer.strip(): - score += tmvp_config.reward_white_space_format_match - # Answers match upon robust comparison with math_verify - elif verified_correct: - score += tmvp_config.reward_exact_answer - else: - # We also reward it if the answer is close via ratios! - # Ie if the answer is within some range, reward it! - try: - # Fix LaTeX escaping issues for both ground truth and extracted answer - true_answer_fixed = fix_latex_escaping(true_answer) - guess_fixed = fix_latex_escaping(guess) - val_true = parse(boxed(true_answer_fixed.strip())) - val_guess = parse(boxed(guess_fixed.strip())) - - ratio = (val_guess[0] + EPSILON) / (val_true[0] + EPSILON) - if ratio >= 0.9 and ratio <= 1.1: - score += tmvp_config.reward_ratio_guess_to_answer_high - elif ratio >= 0.8 and ratio <= 1.2: - score += tmvp_config.reward_ratio_guess_to_answer_low - else: - score += tmvp_config.penalty_incorrect_answer # Penalize wrong answers - except: - score += tmvp_config.penalty_incorrect_format # Penalize - scores.append(score) - return scores - - -# Sometimes, the text between `` and `` might not be one -# number; it can be a sentence. So, we extract the number and compare the answer. 
- - -def get_match_numbers_regex(tmvp_config): - """Returns a compiled regex to extract the answer from a completion.""" - match_numbers = re.compile(rf"{tmvp_config.solution_start_token}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL) - if tmvp_config.debug.rl: - match_numbers.findall(f"{tmvp_config.solution_start_token} 0.34 {tmvp_config.solution_end_token}") - return match_numbers +def preprocess_math_string(text: str) -> str: + """Fix common formatting issues in text.""" + # Normalize text + text = normalize_final_answer(text).strip() + # Fix LaTeX escaping issues + text = fix_latex_escaping(text) + return text def fix_latex_escaping(text: str) -> str: @@ -413,7 +418,9 @@ def fix_latex_escaping(text: str) -> str: return text -def check_numbers(prompts, completions, answer, tmvp_config, **kargs): +def check_numbers( + prompts: list[str], completions: list[str], answer: list[str], tmvp_config: Any, **kargs: Any +) -> list[float]: """ Reward the model if the answer is correct using math_verify for robust comparison. Handles both numeric values and mathematical expressions with LaTeX. @@ -421,48 +428,41 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs): question = kargs["question"] # Extract full answer content from solution tags (not just first number) - match_format = get_match_format_regex(tmvp_config) - answer_fallback = get_answer_fallback_regex(tmvp_config) - - extracted_responses = [] - for c in completions: - full_match = match_format.search(c) - if full_match is not None: - extracted_responses.append(full_match.group(1)) - else: - fallback_matches = answer_fallback.findall(c) - extracted_responses.append(fallback_matches[-1].strip() if fallback_matches else None) + extracted_responses = [extract_answer(c, tmvp_config) for c in completions] + true_answers = [list(dict.fromkeys(json.loads(acceptable_answers))) for acceptable_answers in answer] - scores = [] - - for guess, true_answer in zip(extracted_responses, answer): + scores = [tmvp_config.penalty_incorrect_format] * len(completions) # Default to penalty for incorrect format + math_verify_queue = [] + for gen_idx, (guess, unique_answers) in enumerate(zip(extracted_responses, true_answers)): if guess is None: - scores.append(0) continue - # Try math_verify first for robust comparison of both numbers and expressions - try: - # Fix LaTeX escaping issues for both ground truth and extracted answer - true_answer_fixed = fix_latex_escaping(true_answer) - guess_fixed = fix_latex_escaping(guess) - - # Normalize for certain datasets - if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name: - true_answer_fixed = normalize_final_answer(true_answer_fixed) - guess_fixed = normalize_final_answer(guess_fixed) - - # Use math_verify to compare answers (handles both numeric and expression comparison) - score, _ = math_verify_func([boxed(true_answer_fixed)], [boxed(guess_fixed)]) - # Return scaled score: 1.5 for exact/correct, 0 otherwise - scores.append(tmvp_config.reward_exact_answer if score > 0.1 else 0.0) - except (TimeoutException, Exception): - # Fallback to simple numeric comparison if math_verify fails - try: - guess_val = float(normalize_final_answer(guess).strip()) - true_val = float(normalize_final_answer(true_answer).strip()) - scores.append(tmvp_config.reward_exact_answer if guess_val == true_val else 0.0) - except: - scores.append(0) + if guess == FALLBACK_ANSWER: + scores[gen_idx] = tmvp_config.penalty_incorrect_answer + continue + + has_exact_match = False + for true_answer 
in unique_answers: + # 1. Check for exact or whitespace-normalized match first for a quick reward + if guess == true_answer: + scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_exact_answer) + has_exact_match = True + elif guess.strip() == true_answer.strip(): + scores[gen_idx] = max(scores[gen_idx], tmvp_config.reward_white_space_format_match) + has_exact_match = True + + if not has_exact_match: + norm_guess = preprocess_math_string(guess) + norm_answers = [] + for true_answer in unique_answers: + norm_answer = preprocess_math_string(true_answer) + norm_answers.append(boxed(norm_answer)) + math_verify_queue.append((gen_idx, norm_answers, [boxed(norm_guess)])) + + if math_verify_queue: + # 2. Try math_verify for robust mathematical correctness checking + scores = math_verify_func(math_verify_queue, scores, trainer_config=tmvp_config) + if tmvp_config.debug.rl: debug_log_path = epath.Path(tmvp_config.base_output_directory) / tmvp_config.run_name / "debug_rl_logs" debug_log_path.mkdir(parents=True, exist_ok=True) @@ -481,6 +481,15 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs): return scores +def extract_answer(response: str, tmvp_config: Any) -> str: + """Function to extract the answer from the text based on the tmvp_config format.""" + answer_fallback = get_answer_fallback_regex(tmvp_config) + # Find the *last* occurrence of the answer tag (most likely the final answer). + fallback_matches = answer_fallback.findall(response) + extracted_response = fallback_matches[-1].strip() if fallback_matches else FALLBACK_ANSWER + return extracted_response + + def extract_hash_answer(text: str) -> str | None: """Function to extract only the answer hash from the text.""" if "####" not in text: @@ -488,7 +497,40 @@ def extract_hash_answer(text: str) -> str | None: return text.split("####")[1].strip() -def get_optimizer(tmvp_config, max_train_steps): +def check_correctness(extracted_response: str, acceptable_answers: list[str], tmvp_config: Any) -> tuple[bool, bool]: + """Handles math verification and partial correctness logic.""" + norm_response = preprocess_math_string(extracted_response) + norm_answers = [] + for answer in acceptable_answers: + norm_answers.append(preprocess_math_string(answer)) + + # Check exact correctness first + score = verify_math_worker([boxed(norm_answer) for norm_answer in norm_answers], [boxed(norm_response)]) + if score > 0.0: + return True, True # Exact correctness implies partial correctness + + # Check partial correctness if values can be extracted (within 10%) + is_partially_correct = False + try: + predictions = parse(boxed(norm_response), (ExprExtractionConfig(), LatexExtractionConfig())) + golds = list( + itertools.chain.from_iterable( + parse(boxed(norm_answer), (ExprExtractionConfig(), LatexExtractionConfig())) for norm_answer in norm_answers + ) + ) + is_partially_correct = any( + 0.9 <= (float(pred) + EPSILON) / (float(gold) + EPSILON) <= 1.1 for pred in predictions for gold in golds + ) + except: + if tmvp_config.debug.rl: + max_logging.log( + f"check_correctness failed for extracted response: {extracted_response} and answers: {acceptable_answers}" + ) + + return False, is_partially_correct + + +def get_optimizer(tmvp_config: Any, max_train_steps: int) -> optax.GradientTransformation: """Function to obtain an optax optimizer, currently we use adamw.""" schedule = optax.schedules.warmup_cosine_decay_schedule( init_value=0.0, @@ -526,7 +568,9 @@ def make_optimizer(learning_rate): return 
optax.inject_hyperparams(make_optimizer)(learning_rate=schedule) -def format_maxtext_messages(messages: list[str], template_config: dict, tmvp_config) -> list[dict[str, str]]: +def format_maxtext_messages( + messages: list[str], template_config: dict[str, Any], tmvp_config: Any +) -> list[dict[str, str]]: """Helper to inject MaxText's system prompt into the input user messages.""" if template_config is None: raise ValueError("template_config cannot be None for format_maxtext_messages.") @@ -546,7 +590,67 @@ def format_maxtext_messages(messages: list[str], template_config: dict, tmvp_con return formatted_messages -def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x): +def process_answer(question: str, answer: str, question_type: str) -> list[str]: + """Function to process the answer based on the question type.""" + if question_type == "MCQ": + # For MCQs, we need to process the response to get the acceptable answers + # e.g., returns "10" and "A" if answer="A" and question="What is 5+5? (A) 10, (B) 11, ..." + return process_mcq(question, answer) + + return [answer, answer] + + +def process_mcq(question: str, answer: str) -> list[str]: + """Extracts options from MCQ question and returns a list of acceptable answers based on the provided answer key.""" + pattern = r""" + (?: + \(?([A-E])\)? # Matches (A) or A + | # OR + \\(?:text|mathrm|textbf) # Matches \text, \mathrm, or \textbf + \{([A-E])\} # Matches {A} or {B} or {C} etc. + ) + [\s\.\:\}\$]*\s* # Matches trailing punctuation/whitespace + (.*?) # The actual content of the option + (?= # Lookahead for the next option or end + \s* + (?:\\quad|\\qquad|\\hspace|\n|\r) + \s* + (?:\(?|\\(?:text|mathrm|textbf)) + | $ + ) + """ + matches = re.findall(pattern, question, re.DOTALL | re.VERBOSE) + options = {} + for m in matches: + letter = m[0] or m[1] + value = m[2].strip() + clean_value = re.sub(r"\\q?quad|\\textbf{|[~$]", "", value).strip() + options[letter] = clean_value + + # List of answer formats that should be accepted + acceptable_answers = [answer] + if answer in options: + # If the answer is a Letter (e.g., "B"), add the corresponding value + acceptable_answers.append(options[answer]) + else: + # If the answer is a value, find the corresponding letter + for letter, value in options.items(): + norm_value = normalize_final_answer(value) + norm_answer = normalize_final_answer(answer) + if norm_answer == norm_value: + acceptable_answers.append(letter) + break + + return acceptable_answers + + +def process_data( + dataset_name: str, + model_tokenizer: Any, + template_config: dict[str, Any], + tmvp_config: Any, + x: dict[str, Any], +) -> dict[str, str]: """Function to process input dataset""" def _to_str(val): @@ -554,36 +658,30 @@ def _to_str(val): return val.decode("utf-8") return str(val) - # Handle DAPO dataset schema - # originally (prompt is a list, answer is in reward_model) - # https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/viewer/default/train?row=0 - # but using https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed/viewer/all/train?row=1 - # so question is prompt and answer is solution + for key in ["problem", "prompt", "question"]: + if key in x: + question = _to_str(x[key]) + break - question = x.get("question", x.get("prompt")) - answer = x.get("answer") - if answer is None and "solution" in x: - answer = x["solution"] - - # Handle OpenMathInstruct-2 - if "problem" in x: - question = x["problem"] - if "expected_answer" in x: - answer = x["expected_answer"] + for key in 
["answer", "solution", "expected_answer"]: + if key in x: + answer = _to_str(x[key]) + break # Handle AIME-2024 if "extra_info" in x and isinstance(x["extra_info"], dict) and "raw_problem" in x["extra_info"]: - question = x["extra_info"]["raw_problem"] - + question = _to_str(x["extra_info"]["raw_problem"]) if "reward_model" in x and isinstance(x["reward_model"], dict) and "ground_truth" in x["reward_model"]: - answer = x["reward_model"]["ground_truth"] + answer = _to_str(x["reward_model"]["ground_truth"]) - question = _to_str(question) - answer = _to_str(answer) - - if dataset_name == "gsm8k": + if dataset_name == "openai/gsm8k": answer = extract_hash_answer(answer) + question_type = "default" + if "question_type" in x: + question_type = _to_str(x["question_type"]) + processed_answer = process_answer(question, answer, question_type) + messages = [question] formatted_messages = format_maxtext_messages(messages, template_config, tmvp_config) @@ -594,16 +692,22 @@ def _to_str(val): ) return { - # pre-formatted prompts for evaluation + # passed to model forward pass "prompts": prompts, - # raw question for AgenticGRPOLearner to bypass formatting - "question": question, # passed to reward functions - "answer": answer, + "question": question, + # list of acceptable answers passed to reward functions + "answer": json.dumps(processed_answer), # string-encode the list to prevent grain from flattening it while batching } -def get_correctness_metrics(prompts, completions, rewards, advantages, **kwargs): +def get_correctness_metrics( + prompts: Any, + completions: Any, + rewards: np.ndarray, + advantages: Any, + **kwargs: Any, +) -> dict[str, tuple[float | int, Callable[..., Any]]]: """Compute correctness statistics metrics based on rewards.""" del prompts, completions, advantages, kwargs solve_all = (rewards > 0.1).all() @@ -637,7 +741,7 @@ class MaxTextChatParser(agentic_chat_template_parser.DefaultChatTemplateParser): special tokens using the shared helper. """ - def __init__(self, model_tokenizer, template_config, tmvp_config): + def __init__(self, model_tokenizer: Any, template_config: dict[str, Any], tmvp_config: Any) -> None: super().__init__(model_tokenizer) self.template_config = template_config self.tmvp_config = tmvp_config diff --git a/tests/post_training/unit/evaluate_rl_test.py b/tests/post_training/unit/evaluate_rl_test.py new file mode 100644 index 0000000000..205e4a8797 --- /dev/null +++ b/tests/post_training/unit/evaluate_rl_test.py @@ -0,0 +1,212 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for evaluate_rl.py (CPU-only).""" + +import unittest +import pytest +from types import SimpleNamespace + +from maxtext.trainers.post_train.rl import evaluate_rl + +pytestmark = [pytest.mark.post_training] + + +def _make_config(eval_mode="pass"): + """Create a minimal config object with the parameters required by score_responses.""" + return SimpleNamespace( + reasoning_start_token="", + reasoning_end_token="", + solution_start_token="", + solution_end_token="", + reward_exact_answer=3.0, + reward_exact_format_match=2.0, + reward_partial_format_match=0.5, + reward_white_space_format_match=1.5, + reward_ratio_guess_to_answer_high=1.0, + reward_ratio_guess_to_answer_low=0.5, + penalty_incorrect_format=-0.5, + penalty_incorrect_answer=-0.5, + dataset_name="test", + debug=SimpleNamespace(rl=False), + eval_mode=eval_mode, + ) + + +class TestScoreResponses(unittest.TestCase): + """Tests for evaluate_rl.score_responses parsing and correctness logic.""" + + def setUp(self): + self.config = _make_config(eval_mode="pass") + self.maj_config = _make_config(eval_mode="maj") + + @pytest.mark.cpu_only + def test_nested_tags(self): + """Response with nested reasoning tags still extracts the correct answer.""" + is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( + tmvp_config=self.config, + question="What is 11/3?", + responses=[ + "Need to use and , " + " and 11/3" + ], + answers=["11/3"], + ) + self.assertTrue(is_correct) + self.assertTrue(is_partially_correct) + self.assertTrue(has_correct_format) + + @pytest.mark.cpu_only + def test_with_extra_ending_tags(self): + """Answer with extra ending tags such as .""" + is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( + tmvp_config=self.config, + question=( + "James buys a new wardrobe. He buys 10 suits and 10 dress pants. " + "He also buys 3 dress shirts per suit. The suits cost $750 each and " + "the dress pants cost 1/5 that cost. The dress shirts were $60 each. " + "How much did everything cost?" + ), + responses=[ + "This is the sum of the cost of the suits, the pants, and the " + "shirts: $7500 + $1500 + $1800 = $10800.\n\n\n" + "10800" + ], + answers=["10,800"], + ) + self.assertTrue(is_correct) + self.assertTrue(is_partially_correct) + self.assertTrue(has_correct_format) + + @pytest.mark.cpu_only + def test_with_incomplete_reasoning_tags(self): + """(1) Incomplete reasoning tags still extracts the correct answer.""" + """(2) Currency symbols works with math_verify.""" + is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( + tmvp_config=self.config, + question="What is the price of the item?", + responses=["The item costs $16.$16"], + answers=["16"], + ) + self.assertTrue(is_correct) + self.assertTrue(is_partially_correct) + self.assertFalse(has_correct_format) + + @pytest.mark.cpu_only + def test_for_mcq_value(self): + """Test for MCQ, where model responds with a math value.""" + is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( + tmvp_config=self.config, + question=( + r"What is the quantity of the item? 
" + r"(A) $2 \frac{1}{3}$ (B) $3 \frac{1}{3}$ " + r"(C) $1 \frac{2}{3}$ (D) $1 \frac{1}{3}$ (E) 2" + ), + responses=["The answer is 3\frac{1}{3}.3\frac{1}{3}"], + answers=["3\frac{1}{3}", "B"], + ) + self.assertTrue(is_correct) + self.assertTrue(is_partially_correct) + self.assertFalse(has_correct_format) + + @pytest.mark.cpu_only + def test_for_mcq_option(self): + """Test for MCQ, where model responds with an option.""" + is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( + tmvp_config=self.config, + question=( + r"What is the quantity of the item? " + r"(A) $2 \frac{1}{3}$ (B) $3 \frac{1}{3}$ " + r"(C) $1 \frac{2}{3}$ (D) $1 \frac{1}{3}$ (E) 2" + ), + responses=["The answer is B.B"], + answers=["3\frac{1}{3}", "B"], + ) + self.assertTrue(is_correct) + self.assertTrue(is_partially_correct) + self.assertFalse(has_correct_format) + + @pytest.mark.cpu_only + def test_majority_eval_mode(self): + is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( + tmvp_config=self.maj_config, + question="What is the quantity of the item?", + responses=[ + r"The item is 3\frac{1}{3}3\frac{1}{3}", + r"It is 3\frac{1}{3}3\frac{1}{3}", + r"The item is 3\frac{1}{3}\frac{1}{3}", + ], + answers=[r"3\frac{1}{3}"], + ) + self.assertTrue(is_correct) + self.assertTrue(is_partially_correct) + self.assertTrue(has_correct_format) + + @pytest.mark.cpu_only + def test_pass_at_1_eval_mode(self): + """pass@1 returns fraction of correct samples, not a boolean.""" + config = _make_config(eval_mode="pass_at_1") + # 3 out of 4 samples are correct → expect 0.75 + is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( + tmvp_config=config, + question="What is 2+2?", + responses=[ + "2+2=44", + "2+2=44", + "2+2=44", + "2+2=55", + ], + answers=["4"], + ) + self.assertAlmostEqual(is_correct, 0.75) + self.assertAlmostEqual(is_partially_correct, 0.75) + self.assertAlmostEqual(has_correct_format, 1.0) + + @pytest.mark.cpu_only + def test_pass_at_1_all_wrong(self): + """pass@1 with all wrong samples returns 0.0.""" + config = _make_config(eval_mode="pass_at_1") + is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( + tmvp_config=config, + question="What is 2+2?", + responses=[ + "2+2=55", + "2+2=66", + ], + answers=["4"], + ) + self.assertAlmostEqual(is_correct, 0.0) + self.assertAlmostEqual(is_partially_correct, 0.0) + self.assertAlmostEqual(has_correct_format, 1.0) + + @pytest.mark.cpu_only + def test_pass_at_1_all_correct(self): + """pass@1 with all correct samples returns 1.0.""" + config = _make_config(eval_mode="pass_at_1") + is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( + tmvp_config=config, + question="What is 2+2?", + responses=[ + "2+2=44", + "Simple: 44", + ], + answers=["4"], + ) + self.assertAlmostEqual(is_correct, 1.0) + self.assertAlmostEqual(is_partially_correct, 1.0) + self.assertAlmostEqual(has_correct_format, 1.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/post_training/unit/math_verify_pool_test.py b/tests/post_training/unit/math_verify_pool_test.py new file mode 100644 index 0000000000..23c84431c7 --- /dev/null +++ b/tests/post_training/unit/math_verify_pool_test.py @@ -0,0 +1,183 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for math_verify_pool grading and score-assignment logic.""" +import pytest +import unittest +from types import SimpleNamespace +from unittest.mock import patch + +import sympy + +from maxtext.trainers.post_train.rl import math_verify_pool as mvp +from maxtext.trainers.post_train.rl.math_verify_pool import ( + are_equal_under_sympy, + math_verify_pool, + verify_math_worker, +) + +pytestmark = [pytest.mark.post_training] + + +def _make_config(reward=1.0): + return SimpleNamespace(reward_exact_answer=reward) + + +class _FakeAsyncResult: + """Stand-in for multiprocessing.pool.AsyncResult. + + Runs the target synchronously at construction and caches the outcome so + `ready()`/`get()` match the real `AsyncResult` contract without spawning + a worker. Lets us drive `math_verify_pool`'s busy-poll in-process. + """ + + def __init__(self, fn, args): + try: + self._value = fn(*args) + self._exc = None + except Exception as exc: # pylint: disable=broad-except + self._value = None + self._exc = exc + + def ready(self): + return True + + def get(self, timeout=None): # pylint: disable=unused-argument + if self._exc is not None: + raise self._exc + return self._value + + +class _FakePool: + """Minimal pool stub: runs `apply_async` synchronously in-process.""" + + def apply_async(self, fn, args): + return _FakeAsyncResult(fn, args) + + +def _fake_get_pool(num_procs): # pylint: disable=unused-argument + return _FakePool() + + +class VerifyMathWorkerTest(unittest.TestCase): + """Unit tests for the in-process grader (no pool, no spawned workers).""" + + def test_exact_numeric_match(self): + score = verify_math_worker(["\\boxed{42}"], ["\\boxed{42}"]) + self.assertEqual(score, 1.0) + + def test_numeric_mismatch(self): + score = verify_math_worker(["\\boxed{42}"], ["\\boxed{99}"]) + self.assertEqual(score, 0.0) + + def test_multiple_golds_one_matches(self): + score = verify_math_worker(["\\boxed{1}", "\\boxed{42}"], ["\\boxed{42}"]) + self.assertEqual(score, 1.0) + + def test_multiple_golds_none_matches(self): + score = verify_math_worker(["\\boxed{1}", "\\boxed{2}"], ["\\boxed{99}"]) + self.assertEqual(score, 0.0) + + def test_empty_prediction_returns_zero(self): + score = verify_math_worker(["\\boxed{42}"], [""]) + self.assertEqual(score, 0.0) + + def test_fraction_equivalent_to_decimal(self): + # 1/2 and 0.5 are numerically equal — verify() should catch this even + # if are_equal_under_sympy's structural match does not. + score = verify_math_worker(["\\boxed{\\frac{1}{2}}"], ["\\boxed{0.5}"]) + self.assertEqual(score, 1.0) + + +class AreEqualUnderSympyTest(unittest.TestCase): + """Tests for the structural sympy AST equality helper. + + `are_equal_under_sympy` is invoked first inside the worker and short-circuits + `verify()` when it returns True. Its job is cheap structural equality on + unevaluated expressions. 
+ """ + + def test_same_integer(self): + self.assertTrue(are_equal_under_sympy(sympy.Integer(42), sympy.Integer(42))) + + def test_different_integer(self): + self.assertFalse(are_equal_under_sympy(sympy.Integer(42), sympy.Integer(99))) + + def test_same_symbol(self): + x = sympy.Symbol("x") + self.assertTrue(are_equal_under_sympy(x, x)) + + def test_malformed_input_does_not_raise(self): + # Unparsable strings must not propagate an exception; they return False. + self.assertFalse(are_equal_under_sympy("$$$", "%%%")) + + +@patch.object(mvp, "_get_pool", _fake_get_pool) +class MathVerifyPoolScoreAssignmentTest(unittest.TestCase): + """Regression tests for the score-assignment bug. + + Prior version granted `reward_exact_answer` on every completed job, ignoring + the grader's score. These tests exist to keep that bug from coming back. + + `_get_pool` is patched with an in-process fake so the busy-poll drains on + the first iteration — no spawn, no 300s global_timeout. + """ + + def test_correct_answer_gets_reward(self): + items = [(0, ["\\boxed{42}"], ["\\boxed{42}"])] + scores = [0.0] + result = math_verify_pool(_make_config(1.0), items, scores) + self.assertEqual(result[0], 1.0) + + def test_wrong_answer_does_not_get_reward(self): + items = [(0, ["\\boxed{42}"], ["\\boxed{99}"])] + scores = [0.0] + result = math_verify_pool(_make_config(1.0), items, scores) + self.assertEqual(result[0], 0.0) + + def test_wrong_answer_preserves_prior_penalty(self): + # `check_numbers` seeds scores[idx] with `penalty_incorrect_format`; a + # wrong grader verdict must not overwrite that with the reward. + items = [(0, ["\\boxed{42}"], ["\\boxed{99}"])] + scores = [-0.5] + result = math_verify_pool(_make_config(1.0), items, scores) + self.assertEqual(result[0], -0.5) + + def test_mixed_batch_scores_each_item_independently(self): + items = [ + (0, ["\\boxed{1}"], ["\\boxed{1}"]), # correct + (1, ["\\boxed{2}"], ["\\boxed{99}"]), # wrong + (2, ["\\boxed{3}"], ["\\boxed{3}"]), # correct + ] + scores = [0.0, 0.0, 0.0] + result = math_verify_pool(_make_config(1.0), items, scores) + self.assertEqual(result[0], 1.0) + self.assertEqual(result[1], 0.0) + self.assertEqual(result[2], 1.0) + + def test_reward_uses_max_not_overwrite(self): + # A correct answer must not lower an already-higher pre-existing score. 
+ items = [(0, ["\\boxed{42}"], ["\\boxed{42}"])] + scores = [0.7] + result = math_verify_pool(_make_config(0.3), items, scores) + self.assertEqual(result[0], 0.7) + + def test_empty_items_returns_scores_unchanged(self): + scores = [0.1, 0.2, 0.3] + result = math_verify_pool(_make_config(1.0), [], scores) + self.assertEqual(result, [0.1, 0.2, 0.3]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/post_training/unit/rl_utils_test.py b/tests/post_training/unit/rl_utils_test.py index 1dd8a7d0be..a972fbaa22 100644 --- a/tests/post_training/unit/rl_utils_test.py +++ b/tests/post_training/unit/rl_utils_test.py @@ -18,21 +18,13 @@ import pytest from types import SimpleNamespace -pytestmark = [pytest.mark.post_training] - -evaluate_rl = pytest.importorskip( - "maxtext.trainers.post_train.rl.evaluate_rl", - reason="tunix (required by evaluate_rl) is not installed GPU", -) +from maxtext.trainers.post_train.rl import utils_rl -utils_rl = pytest.importorskip( - "maxtext.trainers.post_train.rl.utils_rl", - reason="tunix (required by utils_rl) is not installed GPU", -) +pytestmark = [pytest.mark.post_training] def _make_config(): - """Create a minimal config object with the parameters required by score_responses.""" + """Create a minimal config object.""" return SimpleNamespace( reasoning_start_token="", reasoning_end_token="", @@ -51,63 +43,17 @@ def _make_config(): ) -class TestScoreResponses(unittest.TestCase): - """Tests for evaluate_rl.score_responses parsing and correctness logic.""" - - def setUp(self): - self.config = _make_config() - - @pytest.mark.cpu_only - def test_nested_tags(self): - """Response with nested reasoning tags still extracts the correct answer.""" - is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( - tmvp_config=self.config, - question="What is 11/3?", - responses=[ - "Need to use and , " - " and 11/3" - ], - answer="11/3", - ) - self.assertTrue(is_correct) - self.assertTrue(is_partially_correct) - self.assertTrue(has_correct_format) - - @pytest.mark.cpu_only - def test_with_extra_ending_tags(self): - """Answer with extra ending tags such as .""" - is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( - tmvp_config=self.config, - question=( - "James buys a new wardrobe. He buys 10 suits and 10 dress pants. " - "He also buys 3 dress shirts per suit. The suits cost $750 each and " - "the dress pants cost 1/5 that cost. The dress shirts were $60 each. " - "How much did everything cost?" - ), - responses=[ - "This is the sum of the cost of the suits, the pants, and the " - "shirts: $7500 + $1500 + $1800 = $10800.\n\n\n" - "10800" - ], - answer="10,800", - ) - self.assertTrue(is_correct) - self.assertTrue(is_partially_correct) - self.assertTrue(has_correct_format) +class TestProcessAnswer(unittest.TestCase): + """Tests for utils_rl.process_answer.""" @pytest.mark.cpu_only - def test_with_incomplete_reasoning_tags(self): - """(1) Incomplete reasoning tags still extracts the correct answer.""" - """(2) Currency symbols works with math_verify.""" - is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses( - tmvp_config=self.config, - question="What is the price of the item?", - responses=["The item costs $16.$16"], - answer="16", + def test_for_mcq(self): + self.assertEqual(len(utils_rl.process_answer("(A) 1\n(B) 2\n(C) 3\n", "B", "MCQ")), 2) + self.assertEqual(len(utils_rl.process_answer("A. 1\nB. 
2\n(C) 3\n", "B", "MCQ")), 2) + self.assertEqual( + len(utils_rl.process_answer("$\\textbf{(A)}~\\frac{1}{24}\\qquad\\textbf{(B)}~\\frac{1}{12}$", "B", "MCQ")), 2 ) - self.assertTrue(is_correct) - self.assertTrue(is_partially_correct) - self.assertFalse(has_correct_format) + self.assertEqual(len(utils_rl.process_answer("$(\\mathrm {A}) \\ 1 \\qquad (\\mathrm {B}) \\ 2$", "B", "MCQ")), 2) class TestNormalizeFinalAnswer(unittest.TestCase): @@ -138,6 +84,7 @@ def test_latex_wrappers(self): def test_dollar_math_extraction(self): # Content inside $...$ is extracted self.assertEqual(utils_rl.normalize_final_answer("The answer is $\\frac{1}{2}$"), "\\frac{1}{2}") + self.assertEqual(utils_rl.normalize_final_answer("The answer is 3 $\\frac{1}{2}$"), "3\\frac{1}{2}") @pytest.mark.cpu_only def test_shorthand_frac_and_sqrt(self): @@ -222,7 +169,7 @@ def test_extraction_succeeds_full_format(self): """Full format allows extraction.""" scores = self._check( completions=["40 + 2 = 4242"], - answer=["42"], + answer=['["42"]'], ) self.assertEqual(scores[0], self.config.reward_exact_answer) @@ -231,16 +178,16 @@ def test_extraction_fails_no_tags(self): """Plain-text completion without any tags yields score 0 (cannot extract).""" scores = self._check( completions=["The answer is 42."], - answer=["42"], + answer=['["42"]'], ) - self.assertEqual(scores[0], 0) + self.assertEqual(scores[0], self.config.penalty_incorrect_format) @pytest.mark.cpu_only def test_extraction_fails_answer_tags_only(self): """ tag alone (no block) is matched by the regex as a fallback, score 1.5.""" scores = self._check( completions=["42"], - answer=["42"], + answer=['["42"]'], ) self.assertEqual(scores[0], self.config.reward_exact_answer) @@ -249,9 +196,9 @@ def test_extraction_fails_reasoning_tags_only(self): """ block with no tag cannot be extracted, score 0.""" scores = self._check( completions=["The answer is 42."], - answer=["42"], + answer=['["42"]'], ) - self.assertEqual(scores[0], 0) + self.assertEqual(scores[0], self.config.penalty_incorrect_format) @pytest.mark.cpu_only def test_extraction_batch_mixed(self): @@ -261,10 +208,25 @@ def test_extraction_batch_mixed(self): "thinking7", # extractable "just 7", # not extractable ], - answer=["7", "7"], + answer=['["7"]', '["7"]'], + ) + self.assertEqual(scores[0], self.config.reward_exact_answer) + self.assertEqual(scores[1], self.config.penalty_incorrect_format) + + @pytest.mark.cpu_only + def test_extraction_for_mcq(self): + """Batch with two multiple-choice questions and one single-answer question.""" + scores = self._check( + completions=[ + "thinking7", + "thinkingA", + "thinkingA", + ], + answer=['["7", "B"]', '["7", "A"]', '["7", "7"]'], ) self.assertEqual(scores[0], self.config.reward_exact_answer) - self.assertEqual(scores[1], 0) + self.assertEqual(scores[1], self.config.reward_exact_answer) + self.assertEqual(scores[2], self.config.penalty_incorrect_answer) # extracted "A" does not match "7" # --------------------------------------------------------------- # Scenario 2: extraction succeeds, value matches/mismatches the answer @@ -275,7 +237,7 @@ def test_extracted_matches_integer_answer(self): """Extracted integer equal to reference answer earns 1.5.""" scores = self._check( completions=["simple100"], - answer=["100"], + answer=['["100"]'], ) self.assertEqual(scores[0], self.config.reward_exact_answer) @@ -284,16 +246,16 @@ def test_extracted_does_not_match_answer(self): """Extracted number that differs from the reference answer earns 0.0.""" scores = self._check( 
completions=["wrong path99"], - answer=["42"], + answer=['["42"]'], ) - self.assertEqual(scores[0], 0.0) + self.assertEqual(scores[0], self.config.penalty_incorrect_answer) @pytest.mark.cpu_only def test_extracted_matches_comma_formatted_number(self): """Comma-formatted guess (e.g. '1,000') normalizes to match integer answer '1000'.""" scores = self._check( completions=["cost calculation1,000"], - answer=["1000"], + answer=['["1000"]'], ) self.assertEqual(scores[0], self.config.reward_exact_answer) @@ -302,7 +264,7 @@ def test_extracted_matches_with_currency_prefix(self): """Leading '$' in extracted answer is normalized away before comparison.""" scores = self._check( completions=["price is $16$16"], - answer=["16"], + answer=['["16"]'], ) self.assertEqual(scores[0], self.config.reward_exact_answer) @@ -311,9 +273,9 @@ def test_extracted_non_numeric_no_match(self): """Non-numeric extraction that cannot be float-converted and does not math-verify returns 0.""" scores = self._check( completions=["thinkingblue"], - answer=["red"], + answer=['["red"]'], ) - self.assertEqual(scores[0], 0.0) + self.assertEqual(scores[0], self.config.penalty_incorrect_format) class TestExtractHashAnswer(unittest.TestCase): diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py index 8f07b01433..adce3bf9e1 100644 --- a/tests/post_training/unit/train_rl_test.py +++ b/tests/post_training/unit/train_rl_test.py @@ -21,14 +21,9 @@ from types import SimpleNamespace import jax +from maxtext.trainers.post_train.rl import train_rl pytestmark = [pytest.mark.post_training] - -# Same as in rl_utils_test.py. -train_rl = pytest.importorskip( - "maxtext.trainers.post_train.rl.train_rl", - reason="Tunix is not installed on the GPU image", -) from maxtext.utils import model_creation_utils @@ -306,17 +301,29 @@ def tokenize_side_effect(text): mock_tokenizer.tokenize.side_effect = tokenize_side_effect # Define dataset mock data - train_data = [{"prompts": "short"}, {"prompts": "long"}, {"prompts": "short"}, {"prompts": "long"}] - test_data = [{"prompts": "short"}, {"prompts": "long"}] + train_data = [ + {"question": "short", "answer": "a1"}, + {"question": "long", "answer": "a2"}, + {"question": "short", "answer": "a3"}, + {"question": "long", "answer": "a4"}, + ] + test_data = [{"question": "short", "answer": "a5"}, {"question": "long", "answer": "a6"}] train_map_ds = grain.MapDataset.source(train_data) test_map_ds = grain.MapDataset.source(test_data) - def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files=None, dataset_name=None): + def get_dataset_side_effect(config, split, data_files=None, dataset_name=None): if split == "train": return train_map_ds else: return test_map_ds + def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config, trainer_config, x): + return { + "prompts": x["question"], + "question": x["question"], + "answer": f"[{x['answer'], x['answer']}]", + } + # Configs trainer_config = SimpleNamespace( debug=SimpleNamespace(rl=False), @@ -327,6 +334,8 @@ def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files eval_split="eval", hf_train_files=None, hf_eval_files=None, + chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", + data_shuffle_seed=42, max_prefill_predict_length=10, batch_size=2, num_batches=2, @@ -336,11 +345,9 @@ def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files test_batch_start_index=0, ) - # Patch everything! 
with ( mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect), - mock.patch("maxtext.trainers.post_train.rl.train_rl.os.makedirs"), - mock.patch("maxtext.trainers.post_train.rl.train_rl.os.path.exists", return_value=True), + mock.patch("maxtext.trainers.post_train.rl.utils_rl.process_data", side_effect=get_filtered_data_side_effect), ): train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer) @@ -366,6 +373,104 @@ def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files self.assertEqual(len(test_batch["prompts"]), 1) self.assertEqual(test_batch["prompts"][0], "short") + @pytest.mark.cpu_only + @mock.patch("datasets.load_dataset") + def test_prepare_datasets_with_split(self, mock_load): + mock_ds = mock.MagicMock() + mock_split_result = { + "train": [{"question": "q1", "answer": "a1"}, {"question": "q2", "answer": "a2"}], + "test": [{"question": "q3", "answer": "a3"}], + } + mock_ds.train_test_split.return_value = mock_split_result + mock_load.return_value = mock_ds + mock_config = SimpleNamespace( + debug=SimpleNamespace(rl=False), + dataset_name="open-r1/OpenR1-Math-220k", + eval_dataset_name="open-r1/OpenR1-Math-220k", + train_split="train", + hf_train_files="hf://open-r1/OpenR1-Math-220k/data/dummy.parquet", + chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", + data_shuffle_seed=42, + num_batches=1, + batch_size=5, + train_fraction=1.0, + num_epoch=1, + num_test_batches=1, + test_batch_start_index=0, + rl=SimpleNamespace(use_agentic_rollout=False), + reasoning_start_token="", + reasoning_end_token="", + solution_start_token="", + solution_end_token="", + max_prefill_predict_length=256, + ) + + train_ds, test_ds = train_rl.prepare_datasets( + trainer_config=mock_config, + model_tokenizer=mock.MagicMock(), + ) + + mock_load.assert_called_once_with( + "parquet", + data_files={mock_config.train_split: mock_config.hf_train_files}, + split=mock_config.train_split, + ) + mock_ds.train_test_split.assert_called_once_with(test_size=0.05, seed=mock_config.data_shuffle_seed) + train_batches, test_batches = list(train_ds), list(test_ds) + total_train_examples = sum(len(batch["question"]) for batch in train_batches) + assert total_train_examples == 2 + total_test_examples = sum(len(batch["question"]) for batch in test_batches) + assert total_test_examples == 1 + + @pytest.mark.cpu_only + @mock.patch("datasets.load_dataset") + def test_prepare_datasets_without_split(self, mock_load): + mock_ds = mock.MagicMock() + mock_load.return_value = mock_ds + mock_config = SimpleNamespace( + debug=SimpleNamespace(rl=False), + dataset_name="openai/gsm8k", + eval_dataset_name="openai/gsm8k", + train_split="train", + eval_split="test", + hf_train_files="hf://openai/gsm8k/data/dummy.parquet", + hf_eval_files="hf://openai/gsm8k/data/dummy.parquet", + chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", + data_shuffle_seed=42, + num_batches=1, + batch_size=5, + train_fraction=1.0, + num_epoch=1, + num_test_batches=1, + test_batch_start_index=0, + rl=SimpleNamespace(use_agentic_rollout=False), + reasoning_start_token="", + reasoning_end_token="", + solution_start_token="", + solution_end_token="", + max_prefill_predict_length=256, + ) + + _, _ = train_rl.prepare_datasets( + trainer_config=mock_config, + model_tokenizer=mock.MagicMock(), + ) + + expected_calls = [ + mock.call( + "parquet", + data_files={mock_config.train_split: mock_config.hf_train_files}, + 
split=mock_config.train_split, + ), + mock.call( + "parquet", + data_files={mock_config.eval_split: mock_config.hf_eval_files}, + split=mock_config.eval_split, + ), + ] + mock_load.assert_has_calls(expected_calls, any_order=True) + assert mock_load.call_count == len(expected_calls) + if __name__ == "__main__": unittest.main()
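Usage note (illustrative, not part of the patch): process_data now emits "answer" as a JSON-encoded list of acceptable answers (for an MCQ, the option value plus its letter), string-encoded so grain does not flatten the list while batching, and the updated rl_utils tests pass answers in that form to the reward path. A minimal sketch of consuming that field is below; it assumes only the json stdlib, and the helper name decode_acceptable_answers is hypothetical rather than code from this change.

import json

def decode_acceptable_answers(answer_field: str) -> list[str]:
  # Reverse the json.dumps() string-encoding applied in process_data.
  return [str(a) for a in json.loads(answer_field)]

# Example row, shaped like process_data output for an MCQ question.
row = {"question": "What is 5+5? (A) 10 (B) 11", "answer": json.dumps(["10", "A"])}
acceptable = decode_acceptable_answers(row["answer"])
assert acceptable == ["10", "A"]  # either the option value or its letter counts as correct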