# REAL Loss (Rewards as Labels) for GRPO Training #8424
# Rewards as Labels: Revisiting RLVR from a Classification Perspective

**Version requirement**: ms-swift > 4.0

[Rewards as Labels: Revisiting RLVR from a Classification Perspective](https://arxiv.org/abs/2602.05630) proposes, for GRPO, treating rewards as labels and classifying samples within each group instead of computing advantages. This turns policy optimization into a classification problem and addresses two issues in the GRPO loss: **gradient misassignment** on positive samples and **gradient domination** on negative samples.
## Background and Motivation

The GRPO objective:

$$
J_{\mathrm{GRPO}}(\theta)=\mathbb{E}_{q,o\sim\pi_{\mathrm{old}}(\cdot|q)}\left[\frac{1}{|o|}\sum_{t=1}^{|o|}\min\left(\rho_tA_t,\mathrm{clip}(\rho_t,1-\epsilon,1+\epsilon)A_t\right)\right]
$$

where $\rho_t=\frac{\pi_\theta(o_t|q)}{\pi_{\mathrm{old}}(o_t|q)}$ is the probability ratio and $A_t$ is the advantage. The gradient is therefore:

$$
\nabla_{\theta} J_{\mathrm{GRPO}} = \mathbb{E}\left[\frac{1}{|o|}\sum_{t=1}^{|o|}\mathbb{I}_{\mathrm{clip}}\cdot A_t e^{s_t}\nabla_{\theta}\log\pi_{\theta}(o_t|q)\right]
$$

where $s_t=\log\frac{\pi_\theta(o_t|q)}{\pi_{\mathrm{old}}(o_t|q)}$ is the token's relative log-probability and $\mathbb{I}_{\mathrm{clip}}$ is the clipping indicator.

The per-token gradient weight in GRPO is thus:

$$
|\mathcal{W}_{\mathrm{GRPO}}|=\left\{\begin{array}{ll}\left|A\cdot e^s\right|, & \mathrm{if~}\mathbb{I}_{\mathrm{clip}}=1, \\ 0, & \text{otherwise.}\end{array}\right.
$$



- **Gradient misassignment (positive samples)**: as the relative log-probability $s$ decreases, the update magnitude shrinks. This is counterintuitive: correct tokens the model is less confident about need larger updates to be reinforced, yet GRPO puts more weight on already-confident tokens, so the poorly learned tokens receive too little attention.

- **Gradient domination (negative samples)**: as $s$ increases, the update magnitude grows exponentially. A few "blindly confident" incorrect tokens can then drown out the signal from the other negative samples in the group; with no upper bound, these samples can trigger excessively large parameter updates and make training hard to control.
To address these issues, REAL treats rewards directly as labels and performs in-group sample classification.



The classification logit for each sample:

$$
\bar{s}^k=\frac{1}{|o^k|}\sum_{t=1}^{|o^k|}\log\frac{\pi_\theta(o_t^k\mid q)}{\pi_{\mathrm{old}}(o_t^k\mid q)}
$$

- $\bar{s}^k > 0$: the sample is more likely overall under the current policy than under the old one; the model tends to **reinforce** it.
- $\bar{s}^k < 0$: the sample is less likely overall under the current policy; the model tends to **suppress** it.
The loss function:

$$
\mathcal{L}_{\mathrm{REAL}}=\log\left(1+\sum_{\mathcal{O}_+}e^{-\bar{s}^i/\tau}\right)+\log\left(1+\sum_{\mathcal{O}_-}e^{\bar{s}^j/\tau}\right)
$$

Gradient properties:

$$
|\mathcal{W}_{\mathrm{REAL}}|=
\begin{cases}
\frac{1}{\tau}\cdot\frac{1}{1+C_{+}e^{\bar{s}^{k}/\tau}}, & r=1, \\
\frac{1}{\tau}\cdot\frac{1}{1+C_{-}e^{-\bar{s}^{k}/\tau}}, & r=0.
\end{cases}
$$
## Parameter Settings

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--loss_type` | `str` | - | Set to `real` |
| `--real_tau` | `float` | `0.5` | Temperature controlling the sharpness of the decision boundary |

Training script reference:

[swift](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/real.sh)

## Important Notes

When setting parameters, make sure `per_device_train_batch_size` is divisible by `num_generations`, so that each training batch contains complete groups for in-group classification.
# Rewards as Labels: Revisiting RLVR from a Classification Perspective

**Version requirement**: ms-swift > 4.0

[Rewards as Labels: Revisiting RLVR from a Classification Perspective](https://arxiv.org/abs/2602.05630) proposes reformulating GRPO by treating rewards as labels and performing **in-group classification** instead of advantage estimation. This converts the policy optimization problem into a classification problem, addressing two key issues in the GRPO loss:
- **Gradient misassignment** on positive samples
- **Gradient domination** on negative samples
## Background and Motivation

The GRPO objective:

$$
J_{\mathrm{GRPO}}(\theta)=\mathbb{E}_{q,o\sim\pi_{\mathrm{old}}(\cdot|q)}\left[\frac{1}{|o|}\sum_{t=1}^{|o|}\min\left(\rho_tA_t,\mathrm{clip}(\rho_t,1-\epsilon,1+\epsilon)A_t\right)\right]
$$

where:
- $\rho_t = \frac{\pi_\theta(o_t|q)}{\pi_{\mathrm{old}}(o_t|q)}$ is the probability ratio
- $A_t$ is the advantage function

The corresponding gradient is:

$$
\nabla_{\theta} J_{\mathrm{GRPO}} = \mathbb{E}\left[\frac{1}{|o|}\sum_{t=1}^{|o|}\mathbb{I}_{\mathrm{clip}}\cdot A_t e^{s_t}\nabla_{\theta}\log\pi_{\theta}(o_t|q)\right]
$$

where:
- $s_t = \log \frac{\pi_\theta(o_t|q)}{\pi_{\mathrm{old}}(o_t|q)}$ is the relative log-probability
- $\mathbb{I}_{\mathrm{clip}}$ is the clipping indicator

Thus, the per-token gradient weight in GRPO is:

$$
|\mathcal{W}_{\mathrm{GRPO}}|=\left\{\begin{array}{ll}\left|A\cdot e^s\right|, & \mathrm{if~}\mathbb{I}_{\mathrm{clip}}=1, \\ 0, & \text{otherwise.}\end{array}\right.
$$



1. **Gradient misassignment (positive samples)**: as the relative log-probability $s$ decreases, the gradient magnitude also decreases. This is counterintuitive: correct tokens the model is less confident about should receive larger updates. Instead, GRPO assigns more weight to already-confident tokens, so under-trained tokens receive insufficient learning signal.

2. **Gradient domination (negative samples)**: as $s$ increases, the gradient magnitude grows exponentially. A few overconfident incorrect tokens can dominate the gradient and overwhelm the other negative signals within the same group. With no upper bound, this may cause unstable, excessively large parameter updates.
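Both failure modes follow from the weight $|A\cdot e^s|$, which is monotone increasing in $s$ regardless of the sign of $A$. A minimal numeric check (pure Python, clipping ignored) makes this concrete:

```python
import math

def grpo_weight(advantage, s):
    """Unclipped per-token GRPO gradient weight |A * e^s|."""
    return abs(advantage * math.exp(s))

# Positive samples (A > 0): weight shrinks as s drops, so the less-confident
# (but correct) tokens that most need reinforcement get the smallest updates.
assert grpo_weight(1.0, -2.0) < grpo_weight(1.0, -1.0) < grpo_weight(1.0, 0.0)

# Negative samples (A < 0): weight grows exponentially and without bound as s
# rises, so a few overconfident wrong tokens dominate the group's gradient.
assert grpo_weight(-1.0, 2.0) > 7 * grpo_weight(-1.0, 0.0)
```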
To address the above issues, REAL treats rewards directly as labels and performs **group-wise classification training**.



The classification logit for each sample is defined as:

$$
\bar{s}^k=\frac{1}{|o^k|}\sum_{t=1}^{|o^k|}\log\frac{\pi_\theta(o_t^k\mid q)}{\pi_{\mathrm{old}}(o_t^k\mid q)}
$$

- $\bar{s}^k > 0$: the sample is more likely under the current policy than under the old policy, so the model tends to **promote** it
- $\bar{s}^k < 0$: the sample is less likely under the current policy, so the model tends to **suppress** it
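A short sketch (hypothetical per-token log-probabilities, not the ms-swift API) of how $\bar{s}^k$ aggregates token-level log-ratios into one sequence-level logit:

```python
def sequence_logit(cur_logps, old_logps):
    """Mean per-token log-ratio: (1/|o|) * sum_t [log pi_theta - log pi_old]."""
    assert len(cur_logps) == len(old_logps) > 0
    return sum(c - o for c, o in zip(cur_logps, old_logps)) / len(cur_logps)

# A 3-token completion the current policy likes more than the old policy did.
s_bar = sequence_logit([-0.5, -1.0, -0.2], [-0.7, -1.1, -0.6])
assert s_bar > 0  # -> this sample would be promoted
```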
The loss function:

$$
\mathcal{L}_{\mathrm{REAL}}=\log\left(1+\sum_{\mathcal{O}_+}e^{-\bar{s}^i/\tau}\right)+\log\left(1+\sum_{\mathcal{O}_-}e^{\bar{s}^j/\tau}\right)
$$

Gradient properties:

$$
|\mathcal{W}_{\mathrm{REAL}}|=
\begin{cases}
\frac{1}{\tau}\cdot\frac{1}{1+C_{+}e^{\bar{s}^{k}/\tau}}, & r=1, \\
\frac{1}{\tau}\cdot\frac{1}{1+C_{-}e^{-\bar{s}^{k}/\tau}}, & r=0.
\end{cases}
$$
## Parameter Settings

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--loss_type` | `str` | - | Set to `real` |
| `--real_tau` | `float` | `0.5` | Temperature parameter controlling decision boundary sharpness |

Training script reference:

[swift](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/real.sh)

## Important Notes

When configuring training parameters, ensure that `per_device_train_batch_size` is divisible by `num_generations`. This guarantees that each training batch contains complete groups, which is required for correct in-group classification.
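The constraint can be sanity-checked up front; the values here are illustrative:

```python
# Illustrative values; each rank must hold only whole groups so that the
# in-group classification sees every completion of a group together.
per_device_train_batch_size = 8
num_generations = 8
assert per_device_train_batch_size % num_generations == 0
groups_per_batch = per_device_train_batch_size // num_generations
assert groups_per_batch >= 1
```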
Example script: start the vLLM rollout server, then launch GRPO training with `--loss_type real`.

```shell
CUDA_VISIBLE_DEVICES=2 \
swift rollout \
    --model Qwen/Qwen3-1.7B
```

```shell
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen3-1.7B \
    --dataset 'AI-MO/NuminaMath-TIR#5000' \
    --enable_thinking false \
    --reward_funcs accuracy \
    --use_vllm true \
    --vllm_mode server \
    --vllm_server_host 127.0.0.1 \
    --vllm_server_port 8000 \
    --tuner_type full \
    --torch_dtype bfloat16 \
    --load_from_cache_file true \
    --max_completion_length 4096 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 8 \
    --learning_rate 2e-6 \
    --gradient_accumulation_steps 1 \
    --save_total_limit 2 \
    --save_steps 500 \
    --logging_steps 1 \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --num_generations 8 \
    --temperature 0.6 \
    --system """You are a helpful math assistant. Solve the problem step by step and put your final answer within \\boxed{}.""" \
    --log_completions true \
    --num_iterations 1 \
    --beta 0.001 \
    --loss_type real \
    --deepspeed zero2
```
Trainer changes in `_compute_loss_and_metrics` (hunk `@@ -1202,6 +1202,8 @@`): the `real` branch sets the per-token loss to zero, since the loss is assembled at group level later.

```python
            soft_gate = torch.where(is_positive, gate_pos, gate_neg)
            per_token_loss = -soft_gate * advantages_expanded
        elif self.loss_type == 'real':
            per_token_loss = torch.zeros_like(per_token_logps)
        elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo']:
            coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
            if self.args.delta is not None:
```
@@ -1240,6 +1242,39 @@ def _compute_loss_and_metrics(self, model, inputs): | |
| elif self.loss_type == 'dr_grpo': | ||
| batch_size = completion_mask.shape[0] | ||
| loss = (per_token_loss * completion_mask).sum() / (batch_size * self.max_completion_length) | ||
| elif self.loss_type == 'real': | ||
| global_scores = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) | ||
|
|
||
| group_scores = global_scores.view(-1, self.num_generations) | ||
| group_rewards = advantages.view(-1, self.num_generations) | ||
|
li2zhi marked this conversation as resolved.
|
||
|
|
||
| pos_mask = (group_rewards > 0) | ||
| neg_mask = (group_rewards <= 0) | ||
| valid_mask = (pos_mask.sum(dim=1) != 0) & (neg_mask.sum(dim=1) != 0) | ||
|
|
||
| if not valid_mask.any(): | ||
| loss = torch.tensor(0., device=global_scores.device) * global_scores.mean() | ||
| else: | ||
| batch_scores = group_scores[valid_mask] | ||
| batch_pos_mask = pos_mask[valid_mask] | ||
| batch_neg_mask = neg_mask[valid_mask] | ||
|
|
||
| scaled_scores = batch_scores / self.real_tau | ||
| zeros = torch.zeros(batch_scores.size(0), 1, device=batch_scores.device, dtype=batch_scores.dtype) | ||
|
|
||
| # Negative Loss: log(1 + sum(e^{S_neg})) | ||
| neg_input = scaled_scores.masked_fill(~batch_neg_mask, float('-inf')) | ||
| neg_loss = torch.logsumexp(torch.cat([neg_input, zeros], dim=1), dim=1) | ||
|
|
||
| # Positive Loss: log(1 + sum(e^{-S_pos})) | ||
| pos_input = (-scaled_scores).masked_fill(~batch_pos_mask, float('-inf')) | ||
| pos_loss = torch.logsumexp(torch.cat([pos_input, zeros], dim=1), dim=1) | ||
|
|
||
| loss = (neg_loss + pos_loss).sum() / group_rewards.size(0) | ||
|
> **Reviewer:** should we account for the number of valid samples instead?
>
> **Author:** In the current implementation, invalid groups are not dropped but contribute zero loss (via masking + logsumexp), so the objective can be viewed as an expectation over the full data distribution in which invalid groups contribute exactly zero. If we instead normalized by the number of valid samples, the objective would become a conditional expectation over valid groups only, which introduces bias relative to the original sampling distribution. In addition, since the number of valid samples varies across batches and training stages, such normalization would lead to unstable gradient scaling (effectively changing the learning rate dynamically).
Inside the `real` branch, the KL penalty is still applied when `beta != 0`, after which the existing branches continue:

```python
            if self.beta != 0.0:
                kl_loss = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
                loss = loss + kl_loss * self.beta
        elif self.loss_type in ['cispo', 'dapo']:
            # CISPO and DAPO: normalize by total completion tokens across all processes
            normalizer = inputs['num_items_in_batch'] / self.accelerator.num_processes
```

The clipping metrics skip `real` just as they skip `sapo` (hunk `@@ -1279,7 +1314,7 @@ def masked_batch_mean`):

```python
            cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
            gathered_cispo_clip_ratio = self.accelerator.gather_for_metrics(cispo_clip_ratio)
            metrics_data['clipping'] = {'cispo_clip_ratio': gathered_cispo_clip_ratio.nanmean().item()}
        elif self.loss_type in ['sapo', 'real']:
            pass
        else:
            is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
```

The new hyperparameter is read in `_prepare_algorithm_params` (hunk `@@ -2183,6 +2218,9 @@`):

```python
        self.tau_pos = args.tau_pos
        self.tau_neg = args.tau_neg

        # REAL, https://arxiv.org/abs/2602.05630
        self.real_tau = args.real_tau

        # RLOO
        self.advantage_estimator = args.advantage_estimator
        self.kl_in_reward = args.kl_in_reward
```
> **Reviewer:** Perhaps we need to prompt the user that `scale_rewards` has been overridden (logger.info).
>
> **Author:** Thanks for pointing this out.