diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 3f99a3c156..133d66c730 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -54,6 +54,9 @@ rl: grpo_epsilon: 0.2 loss_algo: 'grpo' # grpo or gspo-token + # Specifies the method for aggregating loss across the batch. + loss_agg_mode: 'sequence-mean-token-mean' # 'token-mean' | 'sequence-mean' | 'sequence-mean-token-mean' + # ====== Agentic Rollout ====== # If true, uses the async AgenticGRPOLearner, which overlaps rollout generation # with training for faster throughput via online vLLM inference. diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 2c089aee3a..0964292f49 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1963,6 +1963,10 @@ class RL(BaseModel): grpo_beta: float = Field(0.08, description="Coefficient for the KL divergence penalty (β).") grpo_epsilon: float = Field(0.2, description="Epsilon value for clipping in the GRPO loss.") loss_algo: Literal["grpo", "gspo-token"] = Field("grpo", description="Loss algorithm, i.e., 'grpo' or 'gspo-token'.") + loss_agg_mode: Literal["token-mean", "sequence-mean", "sequence-mean-token-mean"] = Field( + "sequence-mean-token-mean", + description="Specifies the method for aggregating loss across the batch.", + ) use_agentic_rollout: bool = Field( False, description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts.", diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 51458955c3..b1fae97634 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -571,6 +571,7 @@ def _reward_fn(**kwargs): beta=trainer_config.rl.grpo_beta, epsilon=trainer_config.rl.grpo_epsilon, loss_algo=trainer_config.rl.loss_algo, + loss_agg_mode=trainer_config.rl.loss_agg_mode, ) rl_trainer = GrpoLearner( rl_cluster=rl_cluster,