@@ -51,7 +51,7 @@ def _compute_kl_advantages(
5151 return _normalize_rewards (- gathered_kl )
5252
5353
54- def compute_advantages (
54+ def compute_advantages ( # noqa: PLR0913, PLR0912, PLR0915
5555 cfg : Config ,
5656 accelerator : Accelerator ,
5757 pipeline : Any , # Any pipeline with tokenizer.batch_decode method (e.g., diffusers.DiffusionPipeline)
@@ -88,10 +88,11 @@ def compute_advantages(
8888 # Mode 2: Compute advantages for each reward separately, then weight them
8989 if cfg .per_prompt_stat_tracking :
9090 if reward_stat_trackers is None :
91- raise ConfigurationError (
91+ msg = (
9292 "reward_stat_trackers must be provided when weight_advantages=True "
9393 "and per_prompt_stat_tracking=True"
9494 )
95+ raise ConfigurationError (msg )
9596 prompt_ids = accelerator .gather (samples ["prompt_ids" ]).cpu ().numpy ()
9697 prompts = pipeline .tokenizer .batch_decode (
9798 prompt_ids , skip_special_tokens = True
@@ -100,7 +101,7 @@ def compute_advantages(
100101 # Compute advantages for each raw reward separately
101102 weighted_advantages_list = []
102103
103- for reward_name in cfg .reward_fn . keys () :
104+ for reward_name in cfg .reward_fn :
104105 raw_reward_key = f"{ reward_name } _raw"
105106 # Compute advantage for this reward using its own stat_tracker
106107 reward_advantages = reward_stat_trackers [reward_name ].update (
@@ -113,10 +114,11 @@ def compute_advantages(
113114 # Handle KL as a reward: compute advantage for KL, then subtract with kl_reward weight
114115 if cfg .sample .kl_reward > 0 :
115116 if kl_stat_tracker is None :
116- raise ConfigurationError (
117+ msg = (
117118 "kl_stat_tracker must be provided when weight_advantages=True, "
118119 "per_prompt_stat_tracking=True, and kl_reward > 0"
119120 )
121+ raise ConfigurationError (msg )
120122 kl_advantages = _compute_kl_advantages (
121123 gathered_kl , kl_stat_tracker , prompts , use_per_prompt = True
122124 )
@@ -133,13 +135,13 @@ def compute_advantages(
133135 f"len(prompts) { len (prompts )} | len unique { len (set (prompts ))} "
134136 )
135137 # Use the first stat_tracker for logging
136- first_reward_name = list ( cfg .reward_fn . keys ())[ 0 ]
138+ first_reward_name = next ( iter ( cfg .reward_fn ))
137139 group_size , trained_prompt_num = reward_stat_trackers [
138140 first_reward_name
139141 ].get_stats ()
140142 # Calculate zero_std_ratio for each raw reward
141143 zero_std_ratios = {}
142- for reward_name in cfg .reward_fn . keys () :
144+ for reward_name in cfg .reward_fn :
143145 raw_reward_key = f"{ reward_name } _raw"
144146 zero_std_ratios [f"zero_std_ratio_{ reward_name } " ] = (
145147 calculate_zero_std_ratio (
@@ -160,7 +162,7 @@ def compute_advantages(
160162 else :
161163 # No per-prompt tracking: compute advantages for each raw reward, then weight
162164 weighted_advantages_list = []
163- for reward_name in cfg .reward_fn . keys () :
165+ for reward_name in cfg .reward_fn :
164166 raw_reward_key = f"{ reward_name } _raw"
165167 raw_rewards = gathered_rewards [
166168 raw_reward_key
0 commit comments