From ff7eb45d54dcbb937661c0aecb870f5ee07b6d39 Mon Sep 17 00:00:00 2001 From: Pooya Moradi Date: Tue, 26 May 2026 12:04:47 +0000 Subject: [PATCH] Plumb rl.loss_agg_mode to tunix GrpoConfig tunix's GrpoConfig defaults loss_agg_mode to 'sequence-mean-token-mean', but GPU NeMo-RL stacks use 'token-mean'. With group-normalized advantages the two modes produce materially different losses, breaking GPU<->TPU recipe parity. Adds the field to the RL Pydantic schema and the rl.yml defaults, and passes it through to GrpoConfig construction, so users can override it on the command line: 'rl.loss_agg_mode=token-mean'. --- src/maxtext/configs/post_train/rl.yml | 3 +++ src/maxtext/configs/types.py | 4 ++++ src/maxtext/trainers/post_train/rl/train_rl.py | 1 + 3 files changed, 8 insertions(+) diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 3f99a3c156..133d66c730 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -54,6 +54,9 @@ rl: grpo_epsilon: 0.2 loss_algo: 'grpo' # grpo or gspo-token + # Method for aggregating loss across the batch; use 'token-mean' for parity with GPU NeMo-RL recipes. + loss_agg_mode: 'sequence-mean-token-mean' # 'token-mean' | 'sequence-mean' | 'sequence-mean-token-mean' + # ====== Agentic Rollout ====== # If true, uses the async AgenticGRPOLearner, which overlaps rollout generation # with training for faster throughput via online vLLM inference. 
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 2c089aee3a..0964292f49 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1963,6 +1963,10 @@ class RL(BaseModel): grpo_beta: float = Field(0.08, description="Coefficient for the KL divergence penalty (β).") grpo_epsilon: float = Field(0.2, description="Epsilon value for clipping in the GRPO loss.") loss_algo: Literal["grpo", "gspo-token"] = Field("grpo", description="Loss algorithm, i.e., 'grpo' or 'gspo-token'.") + loss_agg_mode: Literal["token-mean", "sequence-mean", "sequence-mean-token-mean"] = Field( + "sequence-mean-token-mean", + description="Specifies the method for aggregating loss across the batch.", + ) use_agentic_rollout: bool = Field( False, description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts.", diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 51458955c3..b1fae97634 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -571,6 +571,7 @@ def _reward_fn(**kwargs): beta=trainer_config.rl.grpo_beta, epsilon=trainer_config.rl.grpo_epsilon, loss_algo=trainer_config.rl.loss_algo, + loss_agg_mode=trainer_config.rl.loss_agg_mode, ) rl_trainer = GrpoLearner( rl_cluster=rl_cluster,