Skip to content

Commit 307c7f3

Browse files
committed
fix: apply ruff auto-fixes and resolve remaining linting errors
- Update type annotations to use built-in types (dict, list, tuple)
- Use X | None instead of Optional[X]
- Fix import ordering
- Remove unnecessary else after return
- Convert else-if to elif
- Extract exception messages to variables
1 parent d1331f7 commit 307c7f3

1 file changed

Lines changed: 36 additions & 38 deletions

File tree

genrl/advantages.py

Lines changed: 36 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
"""Advantage computation utilities for GRPO training."""
22

3-
from typing import Any, Dict, List, Optional, Tuple
3+
from typing import Any
4+
45
import numpy as np
56
from accelerate import Accelerator
67
from loguru import logger
@@ -27,8 +28,8 @@ def _normalize_rewards(rewards: np.ndarray, epsilon: float = EPSILON) -> np.ndar
2728

2829
def _compute_kl_advantages(
2930
gathered_kl: np.ndarray,
30-
kl_stat_tracker: Optional[PerPromptStatTracker],
31-
prompts: Optional[List[str]],
31+
kl_stat_tracker: PerPromptStatTracker | None,
32+
prompts: list[str] | None,
3233
use_per_prompt: bool,
3334
) -> np.ndarray:
3435
"""Compute KL advantages (negative because KL is a penalty).
@@ -45,23 +46,22 @@ def _compute_kl_advantages(
4546
if use_per_prompt and kl_stat_tracker is not None:
4647
# KL is a penalty (larger KL is worse), so use negative KL
4748
return kl_stat_tracker.update(prompts, -gathered_kl)
48-
else:
49-
# Direct normalization on full shape
50-
# Normalize negative KL to maintain consistency with per_prompt mode
51-
return _normalize_rewards(-gathered_kl)
49+
# Direct normalization on full shape
50+
# Normalize negative KL to maintain consistency with per_prompt mode
51+
return _normalize_rewards(-gathered_kl)
5252

5353

5454
def compute_advantages( # noqa: PLR0913, PLR0912, PLR0915
5555
cfg: Config,
5656
accelerator: Accelerator,
5757
pipeline: Any, # Any pipeline with tokenizer.batch_decode method (e.g., diffusers.DiffusionPipeline)
58-
samples: Dict[str, Any],
59-
gathered_rewards: Dict[str, np.ndarray],
58+
samples: dict[str, Any],
59+
gathered_rewards: dict[str, np.ndarray],
6060
gathered_kl: np.ndarray,
61-
stat_tracker: Optional[PerPromptStatTracker],
62-
reward_stat_trackers: Optional[Dict[str, PerPromptStatTracker]],
63-
kl_stat_tracker: Optional[PerPromptStatTracker],
64-
) -> Tuple[np.ndarray, Dict[str, Any]]:
61+
stat_tracker: PerPromptStatTracker | None,
62+
reward_stat_trackers: dict[str, PerPromptStatTracker] | None,
63+
kl_stat_tracker: PerPromptStatTracker | None,
64+
) -> tuple[np.ndarray, dict[str, Any]]:
6565
"""Compute advantages from gathered rewards and KL divergence.
6666
6767
Supports two modes:
@@ -185,31 +185,29 @@ def compute_advantages( # noqa: PLR0913, PLR0912, PLR0915
185185

186186
# Sum weighted advantages
187187
advantages = sum(weighted_advantages_list)
188-
else:
189-
# Mode 1 (default): Weight rewards first, then compute advantages
190-
if cfg.per_prompt_stat_tracking:
191-
if stat_tracker is None:
192-
raise ConfigurationError(
193-
"stat_tracker must be provided when per_prompt_stat_tracking=True"
194-
)
195-
prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
196-
prompts = pipeline.tokenizer.batch_decode(
197-
prompt_ids, skip_special_tokens=True
188+
# Mode 1 (default): Weight rewards first, then compute advantages
189+
elif cfg.per_prompt_stat_tracking:
190+
if stat_tracker is None:
191+
msg = "stat_tracker must be provided when per_prompt_stat_tracking=True"
192+
raise ConfigurationError(msg)
193+
prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
194+
prompts = pipeline.tokenizer.batch_decode(
195+
prompt_ids, skip_special_tokens=True
196+
)
197+
advantages = stat_tracker.update(prompts, gathered_rewards["avg"])
198+
if accelerator.is_local_main_process:
199+
logger.info(
200+
f"len(prompts) {len(prompts)} | len unique {len(set(prompts))}"
198201
)
199-
advantages = stat_tracker.update(prompts, gathered_rewards["avg"])
200-
if accelerator.is_local_main_process:
201-
logger.info(
202-
f"len(prompts) {len(prompts)} | len unique {len(set(prompts))}"
203-
)
204-
group_size, trained_prompt_num = stat_tracker.get_stats()
205-
zero_std_ratio = calculate_zero_std_ratio(prompts, gathered_rewards)
206-
log_dict = {
207-
"group_size": group_size,
208-
"trained_prompt_num": trained_prompt_num,
209-
"zero_std_ratio": zero_std_ratio,
210-
}
211-
stat_tracker.clear()
212-
else:
213-
advantages = _normalize_rewards(gathered_rewards["avg"])
202+
group_size, trained_prompt_num = stat_tracker.get_stats()
203+
zero_std_ratio = calculate_zero_std_ratio(prompts, gathered_rewards)
204+
log_dict = {
205+
"group_size": group_size,
206+
"trained_prompt_num": trained_prompt_num,
207+
"zero_std_ratio": zero_std_ratio,
208+
}
209+
stat_tracker.clear()
210+
else:
211+
advantages = _normalize_rewards(gathered_rewards["avg"])
214212

215213
return advantages, log_dict

0 commit comments

Comments
 (0)