Skip to content

Commit 1b41f88

Browse files
committed
fix: ignore UP038 rule and format code
- Add UP038 to ignore list in pyproject.toml (isinstance doesn't support X | Y syntax) - Run ruff format to fix formatting issues - Fix trailing whitespace and end-of-file issues
1 parent 0956cd0 commit 1b41f88

19 files changed

Lines changed: 181 additions & 524 deletions

genrl/advantages.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,7 @@ def compute_advantages(
9494
)
9595
raise ConfigurationError(msg)
9696
prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
97-
prompts = pipeline.tokenizer.batch_decode(
98-
prompt_ids, skip_special_tokens=True
99-
)
97+
prompts = pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
10098

10199
# Compute advantages for each raw reward separately
102100
weighted_advantages_list = []
@@ -119,9 +117,7 @@ def compute_advantages(
119117
"per_prompt_stat_tracking=True, and kl_reward > 0"
120118
)
121119
raise ConfigurationError(msg)
122-
kl_advantages = _compute_kl_advantages(
123-
gathered_kl, kl_stat_tracker, prompts, use_per_prompt=True
124-
)
120+
kl_advantages = _compute_kl_advantages(gathered_kl, kl_stat_tracker, prompts, use_per_prompt=True)
125121
# Subtract KL advantages with kl_reward as weight
126122
# kl_advantages is already negative (because KL is a penalty),
127123
# so we directly multiply by kl_reward to get a negative contribution
@@ -131,22 +127,16 @@ def compute_advantages(
131127
advantages = sum(weighted_advantages_list)
132128

133129
if accelerator.is_local_main_process:
134-
logger.info(
135-
f"len(prompts) {len(prompts)} | len unique {len(set(prompts))}"
136-
)
130+
logger.info(f"len(prompts) {len(prompts)} | len unique {len(set(prompts))}")
137131
# Use the first stat_tracker for logging
138132
first_reward_name = next(iter(cfg.reward_fn))
139-
group_size, trained_prompt_num = reward_stat_trackers[
140-
first_reward_name
141-
].get_stats()
133+
group_size, trained_prompt_num = reward_stat_trackers[first_reward_name].get_stats()
142134
# Calculate zero_std_ratio for each raw reward
143135
zero_std_ratios = {}
144136
for reward_name in cfg.reward_fn:
145137
raw_reward_key = f"{reward_name}_raw"
146-
zero_std_ratios[f"zero_std_ratio_{reward_name}"] = (
147-
calculate_zero_std_ratio(
148-
prompts, gathered_rewards, reward_key=f"ori_{raw_reward_key}"
149-
)
138+
zero_std_ratios[f"zero_std_ratio_{reward_name}"] = calculate_zero_std_ratio(
139+
prompts, gathered_rewards, reward_key=f"ori_{raw_reward_key}"
150140
)
151141
log_dict = {
152142
"group_size": group_size,
@@ -164,9 +154,7 @@ def compute_advantages(
164154
weighted_advantages_list = []
165155
for reward_name in cfg.reward_fn:
166156
raw_reward_key = f"{reward_name}_raw"
167-
raw_rewards = gathered_rewards[
168-
raw_reward_key
169-
] # Shape: (total_batch_size, num_timesteps)
157+
raw_rewards = gathered_rewards[raw_reward_key] # Shape: (total_batch_size, num_timesteps)
170158
# Direct normalization on full shape
171159
reward_advantages = _normalize_rewards(raw_rewards)
172160
# Weight the advantages
@@ -175,9 +163,7 @@ def compute_advantages(
175163

176164
# Handle KL as a reward: compute advantage for KL, then subtract with kl_reward weight
177165
if cfg.sample.kl_reward > 0:
178-
kl_advantages = _compute_kl_advantages(
179-
gathered_kl, None, None, use_per_prompt=False
180-
)
166+
kl_advantages = _compute_kl_advantages(gathered_kl, None, None, use_per_prompt=False)
181167
# Subtract KL advantages with kl_reward as weight
182168
# kl_advantages is already negative (because KL is a penalty),
183169
# so we directly multiply by kl_reward to get a negative contribution
@@ -191,14 +177,10 @@ def compute_advantages(
191177
msg = "stat_tracker must be provided when per_prompt_stat_tracking=True"
192178
raise ConfigurationError(msg)
193179
prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
194-
prompts = pipeline.tokenizer.batch_decode(
195-
prompt_ids, skip_special_tokens=True
196-
)
180+
prompts = pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
197181
advantages = stat_tracker.update(prompts, gathered_rewards["avg"])
198182
if accelerator.is_local_main_process:
199-
logger.info(
200-
f"len(prompts) {len(prompts)} | len unique {len(set(prompts))}"
201-
)
183+
logger.info(f"len(prompts) {len(prompts)} | len unique {len(set(prompts))}")
202184
group_size, trained_prompt_num = stat_tracker.get_stats()
203185
zero_std_ratio = calculate_zero_std_ratio(prompts, gathered_rewards)
204186
log_dict = {

genrl/config.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ class AccelerateConfig:
3333
@dataclass
3434
class TrainConfig:
3535
batch_size: int = 8
36-
gradient_accumulation_steps: int | None = (
37-
None # if None, derive from sample settings
38-
)
36+
gradient_accumulation_steps: int | None = None # if None, derive from sample settings
3937
num_inner_epochs: int = 1
4038
timestep_fraction: float = 0.99
4139
beta: float = 0.0
@@ -156,30 +154,15 @@ def build_dataclass(cls, src: dict[str, Any]):
156154
if field_name in src:
157155
val = src[field_name]
158156
# dispatch based on nested dataclass types
159-
if (
160-
isinstance(field_def.default, FSDPConfig)
161-
or field_def.type == FSDPConfig
162-
):
157+
if isinstance(field_def.default, FSDPConfig) or field_def.type == FSDPConfig:
163158
kwargs[field_name] = build_dataclass(FSDPConfig, val)
164-
elif (
165-
isinstance(field_def.default, AccelerateConfig)
166-
or field_def.type == AccelerateConfig
167-
):
159+
elif isinstance(field_def.default, AccelerateConfig) or field_def.type == AccelerateConfig:
168160
kwargs[field_name] = build_dataclass(AccelerateConfig, val)
169-
elif (
170-
isinstance(field_def.default, TrainConfig)
171-
or field_def.type == TrainConfig
172-
):
161+
elif isinstance(field_def.default, TrainConfig) or field_def.type == TrainConfig:
173162
kwargs[field_name] = build_dataclass(TrainConfig, val)
174-
elif (
175-
isinstance(field_def.default, SampleConfig)
176-
or field_def.type == SampleConfig
177-
):
163+
elif isinstance(field_def.default, SampleConfig) or field_def.type == SampleConfig:
178164
kwargs[field_name] = build_dataclass(SampleConfig, val)
179-
elif (
180-
isinstance(field_def.default, ProjectPaths)
181-
or field_def.type == ProjectPaths
182-
):
165+
elif isinstance(field_def.default, ProjectPaths) or field_def.type == ProjectPaths:
183166
kwargs[field_name] = build_dataclass(ProjectPaths, val)
184167
else:
185168
kwargs[field_name] = val

genrl/data.py

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ def __getitem__(self, idx: int | tuple[int, int]) -> dict:
3232
def collate_fn(examples: list[dict]) -> tuple[int | None, list[str], list[dict]]:
3333
"""Batch prompts while preserving a consistent epoch tag."""
3434
epoch_tags = [example.get("epoch") for example in examples]
35-
epoch_tag = (
36-
epoch_tags[0] if all(tag == epoch_tags[0] for tag in epoch_tags) else None
37-
)
35+
epoch_tag = epoch_tags[0] if all(tag == epoch_tags[0] for tag in epoch_tags) else None
3836
prompts = [example["prompt"] for example in examples]
3937
metadatas = [example["metadata"] for example in examples]
4038
return epoch_tag, prompts, metadatas
@@ -66,9 +64,7 @@ def __getitem__(self, idx: int | tuple[int, int]) -> dict:
6664
def collate_fn(examples: list[dict]) -> tuple[int | None, list[str], list[dict]]:
6765
"""Batch Geneval items while preserving epoch tags."""
6866
epoch_tags = [example.get("epoch") for example in examples]
69-
epoch_tag = (
70-
epoch_tags[0] if all(tag == epoch_tags[0] for tag in epoch_tags) else None
71-
)
67+
epoch_tag = epoch_tags[0] if all(tag == epoch_tags[0] for tag in epoch_tags) else None
7268
prompts = [example["prompt"] for example in examples]
7369
metadatas = [example["metadata"] for example in examples]
7470
return epoch_tag, prompts, metadatas
@@ -93,9 +89,7 @@ def __init__(self, dataset: str, split: str = "train"):
9389
self.file_path = os.path.join(dataset, f"{split}.json")
9490
self._prompts = None
9591
self._metadatas = None
96-
self._file_size = (
97-
os.path.getsize(self.file_path) if os.path.exists(self.file_path) else 0
98-
)
92+
self._file_size = os.path.getsize(self.file_path) if os.path.exists(self.file_path) else 0
9993

10094
# Optimization strategy:
10195
# - For training data, load directly to memory even if large (frequent random access needed)
@@ -185,11 +179,7 @@ def _get_item_lazy(self, idx: int) -> dict:
185179

186180
start_offset = self._line_offsets[idx]
187181
# Calculate end offset (start of next line or end of file)
188-
end_offset = (
189-
self._line_offsets[idx + 1]
190-
if idx + 1 < len(self._line_offsets)
191-
else self._file_size
192-
)
182+
end_offset = self._line_offsets[idx + 1] if idx + 1 < len(self._line_offsets) else self._file_size
193183

194184
with open(self.file_path, encoding="utf-8") as f:
195185
f.seek(start_offset)
@@ -257,9 +247,7 @@ def collate_fn(examples: list[dict]) -> tuple[int | None, list[str], list[dict]]
257247
Tuple of (epoch_tag, prompts, metadatas) where epoch_tag is None if inconsistent.
258248
"""
259249
epoch_tags = [example.get("epoch") for example in examples]
260-
epoch_tag = (
261-
epoch_tags[0] if all(tag == epoch_tags[0] for tag in epoch_tags) else None
262-
)
250+
epoch_tag = epoch_tags[0] if all(tag == epoch_tags[0] for tag in epoch_tags) else None
263251
prompts = [example["prompt"] for example in examples]
264252
metadatas = [example["metadata"] for example in examples]
265253
return epoch_tag, prompts, metadatas
@@ -292,9 +280,9 @@ def __init__(
292280
self.rank = rank
293281
self.seed = seed
294282
self.total_samples = self.num_replicas * self.batch_size
295-
assert (
296-
self.total_samples % self.k == 0
297-
), f"k can not div n*b, k{k}-num_replicas{num_replicas}-batch_size{batch_size}"
283+
assert self.total_samples % self.k == 0, (
284+
f"k can not div n*b, k{k}-num_replicas{num_replicas}-batch_size{batch_size}"
285+
)
298286
self.m = self.total_samples // self.k
299287
self.epoch = 0
300288

@@ -305,27 +293,21 @@ def __iter__(self):
305293
g.manual_seed(self.seed + self.epoch)
306294
indices = torch.randperm(len(self.dataset), generator=g)[: self.m].tolist()
307295
repeated_indices = [idx for idx in indices for _ in range(self.k)]
308-
shuffled_indices = torch.randperm(
309-
len(repeated_indices), generator=g
310-
).tolist()
296+
shuffled_indices = torch.randperm(len(repeated_indices), generator=g).tolist()
311297
shuffled_samples = [repeated_indices[i] for i in shuffled_indices]
312298
per_card_samples = []
313299
for i in range(self.num_replicas):
314300
start = i * self.batch_size
315301
end = start + self.batch_size
316-
per_card_samples.append(
317-
[(self.epoch, idx) for idx in shuffled_samples[start:end]]
318-
)
302+
per_card_samples.append([(self.epoch, idx) for idx in shuffled_samples[start:end]])
319303
yield per_card_samples[self.rank]
320304

321305
def set_epoch(self, epoch: int):
322306
"""Set epoch tag to keep RNG in sync across workers."""
323307
self.epoch = epoch
324308

325309

326-
def build_dataloaders(
327-
cfg, accelerator
328-
) -> tuple[DataLoader, DataLoader, DistributedKRepeatSampler]:
310+
def build_dataloaders(cfg, accelerator) -> tuple[DataLoader, DataLoader, DistributedKRepeatSampler]:
329311
"""Construct train/eval dataloaders and sampler with epoch tags.
330312
331313
Args:
@@ -350,9 +332,7 @@ def build_dataloaders(
350332
collate_fn = JsonPromptDataset.collate_fn
351333
else:
352334
msg = "Only general_ocr, geneval, or filtered_prompts prompt_fn supported"
353-
raise NotImplementedError(
354-
msg
355-
)
335+
raise NotImplementedError(msg)
356336

357337
train_sampler = DistributedKRepeatSampler(
358338
dataset=train_dataset,

0 commit comments

Comments
 (0)