Skip to content

Commit 0d7759b

Browse files
author
Pooya Moradi
committed
Plumb rl.loss_agg_mode to tunix GrpoConfig
tunix's GrpoConfig defaults loss_agg_mode to 'sequence-mean-token-mean', but GPU NeMo-RL stacks use 'token-mean'. With group-normalized advantages, the two modes produce materially different losses, which breaks GPU<->TPU recipe parity. This commit adds the field to the RL Pydantic schema, sets its default in rl.yml, and passes it through to the GrpoConfig construction so that users can override it on the command line, e.g. 'rl.loss_agg_mode=token-mean'.
1 parent be4fd71 commit 0d7759b

3 files changed

Lines changed: 17 additions & 0 deletions

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ rl:
5454
grpo_epsilon: 0.2
5555
loss_algo: 'grpo' # grpo or gspo-token
5656

57+
# Loss aggregation mode passed to tunix's GrpoConfig. tunix defaults to
58+
# 'sequence-mean-token-mean'; set 'token-mean' for parity with GPU NeMo-RL
59+
# stacks that use token-mean aggregation. With group-normalized advantages
60+
# the two modes produce materially different losses.
61+
loss_agg_mode: 'sequence-mean-token-mean' # 'token-mean' | 'sequence-mean' | 'sequence-mean-token-mean'
62+
5763
# ====== Agentic Rollout ======
5864
# If true, uses the async AgenticGRPOLearner, which overlaps rollout generation
5965
# with training for faster throughput via online vLLM inference.

src/maxtext/configs/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1963,6 +1963,16 @@ class RL(BaseModel):
19631963
grpo_beta: float = Field(0.08, description="Coefficient for the KL divergence penalty (β).")
19641964
grpo_epsilon: float = Field(0.2, description="Epsilon value for clipping in the GRPO loss.")
19651965
loss_algo: Literal["grpo", "gspo-token"] = Field("grpo", description="Loss algorithm, i.e., 'grpo' or 'gspo-token'.")
1966+
loss_agg_mode: Literal["token-mean", "sequence-mean", "sequence-mean-token-mean"] = Field(
1967+
"sequence-mean-token-mean",
1968+
description=(
1969+
"Loss aggregation mode passed to tunix's GrpoConfig. tunix defaults"
1970+
" to 'sequence-mean-token-mean'; set 'token-mean' for parity with"
1971+
" GPU NeMo-RL stacks that use token-mean aggregation. With"
1972+
" group-normalized advantages the two modes produce materially"
1973+
" different losses."
1974+
),
1975+
)
19661976
use_agentic_rollout: bool = Field(
19671977
False,
19681978
description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts.",

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ def _reward_fn(**kwargs):
571571
beta=trainer_config.rl.grpo_beta,
572572
epsilon=trainer_config.rl.grpo_epsilon,
573573
loss_algo=trainer_config.rl.loss_algo,
574+
loss_agg_mode=trainer_config.rl.loss_agg_mode,
574575
)
575576
rl_trainer = GrpoLearner(
576577
rl_cluster=rl_cluster,

0 commit comments

Comments
 (0)