@@ -101,6 +101,7 @@ def __init__(
             clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
             beta=grpo_config.get("beta", 0.01),
             loss_variation=grpo_config.get("loss_variation", "sample_level"),
+            adv=grpo_config.get("algo"),
         )
 
         # Reference model is initialized from policy model.
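Note: the new `adv` keyword forwards the configured algorithm name into `PolicyLoss`; presumably the loss branches on it internally (that change is outside the hunks shown here).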
@@ -137,6 +138,8 @@ def __init__(
             eta_min=0.1 * grpo_config.get("lr", 1e-6),
         )
 
+        self.adv = grpo_config.get("algo")
+
     def setup(self):
         super().setup()
         if (not self.plugin.pp_size > 1 and self.rank == 0) or (
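For orientation, here is a minimal sketch of the `grpo_config` dict these `.get()` calls read from. The keys and defaults are taken from the lines above, the set of recognized `algo` values from the branches added further down; the dict itself is illustrative, not part of the diff:

```python
# Illustrative config sketch (not part of the diff).
# "algo" selects the advantage estimator added in this PR:
#   "GRPO" / "DAPO"  -> group-mean baseline, normalized by within-group reward std
#   "REINFORCE_PPB"  -> group-mean baseline, whitened per micro-batch later
#   "RLOO"           -> leave-one-out baseline
grpo_config = {
    "lr": 1e-6,
    "beta": 0.01,  # KL penalty coefficient
    "clip_eps_low": 0.2,
    "clip_eps_high": 0.2,
    "loss_variation": "sample_level",
    "algo": "RLOO",
}
```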
@@ -204,9 +207,23 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 # [minibatch_size x num_generations]
                 reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
 
-                reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-                # [minibatch_size x num_generations]
-                advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+                if self.adv == "GRPO" or self.adv == "DAPO":
+                    # Normalize by the within-group reward std (GRPO/DAPO).
+                    reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+                    # [minibatch_size x num_generations]
+                    advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+
+                elif self.adv == "REINFORCE_PPB":
+                    # Mean-centered only; whitening happens later, per micro-batch.
+                    # [minibatch_size x num_generations]
+                    advantages = (reward - reward_mean).unsqueeze(dim=-1)
+
+                elif self.adv == "RLOO":
+                    # Leave-one-out baseline: (r_i - mean) * n / (n - 1) == r_i - mean_{j != i}(r_j)
+                    advantages = (
+                        reward * self.num_generations / (self.num_generations - 1)
+                        - reward_mean * self.num_generations / (self.num_generations - 1)
+                    ).unsqueeze(dim=-1)
 
                 # [minibatch_size x num_of_generation]
                 loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
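The RLOO branch is an algebraic shortcut for the leave-one-out baseline: r_i - mean_{j != i}(r_j) = (r_i - mean) * n / (n - 1). A self-contained sketch verifying the equivalence (tensor names are illustrative, not from the trainer):

```python
import torch

n = 8                                   # stands in for self.num_generations
group_reward = torch.randn(4, n)        # [minibatch_size, num_generations]

reward = group_reward.reshape(-1)       # flattened rewards, row-major
reward_mean = group_reward.mean(dim=1).repeat_interleave(n, dim=0)

# Form used in the diff: scale (r_i - mean) by n / (n - 1).
adv_scaled = reward * n / (n - 1) - reward_mean * n / (n - 1)

# Explicit leave-one-out baseline: mean of the other n - 1 rewards in the group.
loo_mean = (group_reward.sum(dim=1, keepdim=True) - group_reward) / (n - 1)
adv_loo = (group_reward - loo_mean).reshape(-1)

assert torch.allclose(adv_scaled, adv_loo, atol=1e-6)
```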
@@ -358,10 +375,34 @@ def _criterion(outputs, inputs):
                         per_token_kl = 0.0
                         kl.append(torch.tensor(0.0))
 
+                    inputs["advantages"] = inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1)
+
+                    if self.adv == "REINFORCE_PPB":
+                        inputs["advantages"] = inputs["advantages"] - self.policy_loss_fn.beta * per_token_kl
+                        advantages_forward_micro_batch_mean = torch.sum(
+                            inputs["advantages"] * inputs["action_mask"]
+                        ) / (torch.sum(inputs["action_mask"]) + 1e-4)
+                        advantages_forward_micro_batch_std = torch.sqrt(
+                            torch.sum(
+                                (inputs["advantages"] - advantages_forward_micro_batch_mean) ** 2
+                                * inputs["action_mask"]
+                            )
+                            / (torch.sum(inputs["action_mask"]) + 1e-4)
+                            + 1e-8
+                        )
+                        inputs["advantages"] = (
+                            (inputs["advantages"] - advantages_forward_micro_batch_mean)
+                            * inputs["action_mask"]
+                            / advantages_forward_micro_batch_std
+                        )
+
+                        # KL is already folded into the advantages; do not add it again in the loss.
+                        per_token_kl = 0.0
+
                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
                         inputs["old_action_log_probs"],
-                        inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
+                        inputs["advantages"],
                         per_token_kl,
                         inputs["action_mask"],
                         loss_mask=inputs["loss_mask"],
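The REINFORCE_PPB branch above folds the per-token KL penalty into the advantages and then whitens them (zero mean, unit variance) over the valid tokens of the micro-batch. A standalone sketch of that masked whitening, using the same epsilons as the diff (function and tensor names are illustrative):

```python
import torch

def masked_whiten(adv: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Zero-mean, unit-variance normalization computed over masked (valid) tokens only."""
    denom = mask.sum() + 1e-4
    mean = (adv * mask).sum() / denom
    var = (((adv - mean) ** 2) * mask).sum() / denom
    std = torch.sqrt(var + 1e-8)
    # Masked positions are zeroed so padding contributes nothing to the loss.
    return (adv - mean) * mask / std

adv = torch.randn(2, 6)                  # [batch, seq_len] token-level advantages
mask = (torch.rand(2, 6) > 0.3).float()  # 1.0 for valid action tokens
whitened = masked_whiten(adv, mask)
```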
@@ -420,10 +461,39 @@ def _criterion(outputs, inputs):
                     per_token_kl = 0.0
                     kl = None
 
+                advantages_forward_micro_batch = advantages_forward_micro_batch.repeat_interleave(
+                    action_log_probs.size(-1),
+                    dim=-1,
+                )
+
+                if self.adv == "REINFORCE_PPB":
+                    advantages_forward_micro_batch = (
+                        advantages_forward_micro_batch - self.policy_loss_fn.beta * per_token_kl
+                    )
+                    advantages_forward_micro_batch_mean = torch.sum(
+                        advantages_forward_micro_batch * action_mask_forward_micro_batch
+                    ) / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+                    advantages_forward_micro_batch_std = torch.sqrt(
+                        torch.sum(
+                            (advantages_forward_micro_batch - advantages_forward_micro_batch_mean) ** 2
+                            * action_mask_forward_micro_batch
+                        )
+                        / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+                        + 1e-8
+                    )
+                    advantages_forward_micro_batch = (
+                        (advantages_forward_micro_batch - advantages_forward_micro_batch_mean)
+                        * action_mask_forward_micro_batch
+                        / advantages_forward_micro_batch_std
+                    )
+
+                    # KL is already folded into the advantages; do not add it again in the loss.
+                    per_token_kl = 0.0
+
                 loss, _ = self.policy_loss_fn(
                     action_log_probs,
                     old_action_log_probs_micro_batch,
-                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                    advantages_forward_micro_batch,
                     per_token_kl,
                     action_mask_forward_micro_batch,
                     loss_mask=loss_mask_forward_micro_batch,
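This second block applies the same REINFORCE_PPB treatment on the non-pipeline-parallel path, mirroring the `_criterion` closure above. In both places `per_token_kl` is zeroed after the fold-in, so the KL penalty, already absorbed into the advantages, is not counted a second time inside `policy_loss_fn`.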