11"""Advantage computation utilities for GRPO training."""
22
3- from typing import Any , Dict , List , Optional , Tuple
3+ from typing import Any
4+
45import numpy as np
56from accelerate import Accelerator
67from loguru import logger
@@ -27,8 +28,8 @@ def _normalize_rewards(rewards: np.ndarray, epsilon: float = EPSILON) -> np.ndar
2728
2829def _compute_kl_advantages (
2930 gathered_kl : np .ndarray ,
30- kl_stat_tracker : Optional [ PerPromptStatTracker ] ,
31- prompts : Optional [ List [ str ]] ,
31+ kl_stat_tracker : PerPromptStatTracker | None ,
32+ prompts : list [ str ] | None ,
3233 use_per_prompt : bool ,
3334) -> np .ndarray :
3435 """Compute KL advantages (negative because KL is a penalty).
@@ -45,23 +46,22 @@ def _compute_kl_advantages(
4546 if use_per_prompt and kl_stat_tracker is not None :
4647 # KL is a penalty (larger KL is worse), so use negative KL
4748 return kl_stat_tracker .update (prompts , - gathered_kl )
48- else :
49- # Direct normalization on full shape
50- # Normalize negative KL to maintain consistency with per_prompt mode
51- return _normalize_rewards (- gathered_kl )
49+ # Direct normalization on full shape
50+ # Normalize negative KL to maintain consistency with per_prompt mode
51+ return _normalize_rewards (- gathered_kl )
5252
5353
5454def compute_advantages ( # noqa: PLR0913, PLR0912, PLR0915
5555 cfg : Config ,
5656 accelerator : Accelerator ,
5757 pipeline : Any , # Any pipeline with tokenizer.batch_decode method (e.g., diffusers.DiffusionPipeline)
58- samples : Dict [str , Any ],
59- gathered_rewards : Dict [str , np .ndarray ],
58+ samples : dict [str , Any ],
59+ gathered_rewards : dict [str , np .ndarray ],
6060 gathered_kl : np .ndarray ,
61- stat_tracker : Optional [ PerPromptStatTracker ] ,
62- reward_stat_trackers : Optional [ Dict [ str , PerPromptStatTracker ]] ,
63- kl_stat_tracker : Optional [ PerPromptStatTracker ] ,
64- ) -> Tuple [np .ndarray , Dict [str , Any ]]:
61+ stat_tracker : PerPromptStatTracker | None ,
62+ reward_stat_trackers : dict [ str , PerPromptStatTracker ] | None ,
63+ kl_stat_tracker : PerPromptStatTracker | None ,
64+ ) -> tuple [np .ndarray , dict [str , Any ]]:
6565 """Compute advantages from gathered rewards and KL divergence.
6666
6767 Supports two modes:
@@ -185,31 +185,29 @@ def compute_advantages( # noqa: PLR0913, PLR0912, PLR0915
185185
186186 # Sum weighted advantages
187187 advantages = sum (weighted_advantages_list )
188- else :
189- # Mode 1 (default): Weight rewards first, then compute advantages
190- if cfg .per_prompt_stat_tracking :
191- if stat_tracker is None :
192- raise ConfigurationError (
193- "stat_tracker must be provided when per_prompt_stat_tracking=True"
194- )
195- prompt_ids = accelerator .gather (samples ["prompt_ids" ]).cpu ().numpy ()
196- prompts = pipeline .tokenizer .batch_decode (
197- prompt_ids , skip_special_tokens = True
188+ # Mode 1 (default): Weight rewards first, then compute advantages
189+ elif cfg .per_prompt_stat_tracking :
190+ if stat_tracker is None :
191+ msg = "stat_tracker must be provided when per_prompt_stat_tracking=True"
192+ raise ConfigurationError (msg )
193+ prompt_ids = accelerator .gather (samples ["prompt_ids" ]).cpu ().numpy ()
194+ prompts = pipeline .tokenizer .batch_decode (
195+ prompt_ids , skip_special_tokens = True
196+ )
197+ advantages = stat_tracker .update (prompts , gathered_rewards ["avg" ])
198+ if accelerator .is_local_main_process :
199+ logger .info (
200+ f"len(prompts) { len (prompts )} | len unique { len (set (prompts ))} "
198201 )
199- advantages = stat_tracker .update (prompts , gathered_rewards ["avg" ])
200- if accelerator .is_local_main_process :
201- logger .info (
202- f"len(prompts) { len (prompts )} | len unique { len (set (prompts ))} "
203- )
204- group_size , trained_prompt_num = stat_tracker .get_stats ()
205- zero_std_ratio = calculate_zero_std_ratio (prompts , gathered_rewards )
206- log_dict = {
207- "group_size" : group_size ,
208- "trained_prompt_num" : trained_prompt_num ,
209- "zero_std_ratio" : zero_std_ratio ,
210- }
211- stat_tracker .clear ()
212- else :
213- advantages = _normalize_rewards (gathered_rewards ["avg" ])
202+ group_size , trained_prompt_num = stat_tracker .get_stats ()
203+ zero_std_ratio = calculate_zero_std_ratio (prompts , gathered_rewards )
204+ log_dict = {
205+ "group_size" : group_size ,
206+ "trained_prompt_num" : trained_prompt_num ,
207+ "zero_std_ratio" : zero_std_ratio ,
208+ }
209+ stat_tracker .clear ()
210+ else :
211+ advantages = _normalize_rewards (gathered_rewards ["avg" ])
214212
215213 return advantages , log_dict
0 commit comments