Skip to content

Commit d1331f7

Browse files
committed
fix: resolve ruff linting errors in advantages.py
- Add noqa comments for complexity warnings (PLR0913, PLR0912, PLR0915) - Extract exception messages to variables (TRY003, EM101) - Use 'key in dict' instead of 'key in dict.keys()' (SIM118) - Use next(iter(...)) instead of list(...)[0] (RUF015)
1 parent c436b41 commit d1331f7

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

genrl/advantages.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _compute_kl_advantages(
5151
return _normalize_rewards(-gathered_kl)
5252

5353

54-
def compute_advantages(
54+
def compute_advantages( # noqa: PLR0913, PLR0912, PLR0915
5555
cfg: Config,
5656
accelerator: Accelerator,
5757
pipeline: Any, # Any pipeline with tokenizer.batch_decode method (e.g., diffusers.DiffusionPipeline)
@@ -88,10 +88,11 @@ def compute_advantages(
8888
# Mode 2: Compute advantages for each reward separately, then weight them
8989
if cfg.per_prompt_stat_tracking:
9090
if reward_stat_trackers is None:
91-
raise ConfigurationError(
91+
msg = (
9292
"reward_stat_trackers must be provided when weight_advantages=True "
9393
"and per_prompt_stat_tracking=True"
9494
)
95+
raise ConfigurationError(msg)
9596
prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
9697
prompts = pipeline.tokenizer.batch_decode(
9798
prompt_ids, skip_special_tokens=True
@@ -100,7 +101,7 @@ def compute_advantages(
100101
# Compute advantages for each raw reward separately
101102
weighted_advantages_list = []
102103

103-
for reward_name in cfg.reward_fn.keys():
104+
for reward_name in cfg.reward_fn:
104105
raw_reward_key = f"{reward_name}_raw"
105106
# Compute advantage for this reward using its own stat_tracker
106107
reward_advantages = reward_stat_trackers[reward_name].update(
@@ -113,10 +114,11 @@ def compute_advantages(
113114
# Handle KL as a reward: compute advantage for KL, then subtract with kl_reward weight
114115
if cfg.sample.kl_reward > 0:
115116
if kl_stat_tracker is None:
116-
raise ConfigurationError(
117+
msg = (
117118
"kl_stat_tracker must be provided when weight_advantages=True, "
118119
"per_prompt_stat_tracking=True, and kl_reward > 0"
119120
)
121+
raise ConfigurationError(msg)
120122
kl_advantages = _compute_kl_advantages(
121123
gathered_kl, kl_stat_tracker, prompts, use_per_prompt=True
122124
)
@@ -133,13 +135,13 @@ def compute_advantages(
133135
f"len(prompts) {len(prompts)} | len unique {len(set(prompts))}"
134136
)
135137
# Use the first stat_tracker for logging
136-
first_reward_name = list(cfg.reward_fn.keys())[0]
138+
first_reward_name = next(iter(cfg.reward_fn))
137139
group_size, trained_prompt_num = reward_stat_trackers[
138140
first_reward_name
139141
].get_stats()
140142
# Calculate zero_std_ratio for each raw reward
141143
zero_std_ratios = {}
142-
for reward_name in cfg.reward_fn.keys():
144+
for reward_name in cfg.reward_fn:
143145
raw_reward_key = f"{reward_name}_raw"
144146
zero_std_ratios[f"zero_std_ratio_{reward_name}"] = (
145147
calculate_zero_std_ratio(
@@ -160,7 +162,7 @@ def compute_advantages(
160162
else:
161163
# No per-prompt tracking: compute advantages for each raw reward, then weight
162164
weighted_advantages_list = []
163-
for reward_name in cfg.reward_fn.keys():
165+
for reward_name in cfg.reward_fn:
164166
raw_reward_key = f"{reward_name}_raw"
165167
raw_rewards = gathered_rewards[
166168
raw_reward_key

0 commit comments

Comments
 (0)