Skip to content

Commit ff7eb45

Browse files
author
Pooya Moradi
committed
Plumb rl.loss_agg_mode to tunix GrpoConfig
tunix's GrpoConfig defaults loss_agg_mode to 'sequence-mean-token-mean', but GPU NeMo-RL stacks use 'token-mean'. With group-normalized advantages the two modes produce materially different losses, breaking GPU<->TPU recipe parity. Adds the field to the RL Pydantic schema + rl.yml default + passes it through to GrpoConfig construction so users can override via cmdline: 'rl.loss_agg_mode=token-mean'.
1 parent be4fd71 commit ff7eb45

3 files changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ rl:
5454
grpo_epsilon: 0.2
5555
loss_algo: 'grpo' # grpo or gspo-token
5656

57+
# Specifies the method for aggregating loss across the batch.
58+
loss_agg_mode: 'sequence-mean-token-mean' # 'token-mean' | 'sequence-mean' | 'sequence-mean-token-mean'
59+
5760
# ====== Agentic Rollout ======
5861
# If true, uses the async AgenticGRPOLearner, which overlaps rollout generation
5962
# with training for faster throughput via online vLLM inference.

src/maxtext/configs/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1963,6 +1963,10 @@ class RL(BaseModel):
19631963
grpo_beta: float = Field(0.08, description="Coefficient for the KL divergence penalty (β).")
19641964
grpo_epsilon: float = Field(0.2, description="Epsilon value for clipping in the GRPO loss.")
19651965
loss_algo: Literal["grpo", "gspo-token"] = Field("grpo", description="Loss algorithm, i.e., 'grpo' or 'gspo-token'.")
1966+
loss_agg_mode: Literal["token-mean", "sequence-mean", "sequence-mean-token-mean"] = Field(
1967+
"sequence-mean-token-mean",
1968+
description="Specifies the method for aggregating loss across the batch.",
1969+
)
19661970
use_agentic_rollout: bool = Field(
19671971
False,
19681972
description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts.",

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ def _reward_fn(**kwargs):
571571
beta=trainer_config.rl.grpo_beta,
572572
epsilon=trainer_config.rl.grpo_epsilon,
573573
loss_algo=trainer_config.rl.loss_algo,
574+
loss_agg_mode=trainer_config.rl.loss_agg_mode,
574575
)
575576
rl_trainer = GrpoLearner(
576577
rl_cluster=rl_cluster,

0 commit comments

Comments
 (0)