@@ -32,9 +32,7 @@ def __getitem__(self, idx: int | tuple[int, int]) -> dict:
3232 def collate_fn (examples : list [dict ]) -> tuple [int | None , list [str ], list [dict ]]:
3333 """Batch prompts while preserving a consistent epoch tag."""
3434 epoch_tags = [example .get ("epoch" ) for example in examples ]
35- epoch_tag = (
36- epoch_tags [0 ] if all (tag == epoch_tags [0 ] for tag in epoch_tags ) else None
37- )
35+ epoch_tag = epoch_tags [0 ] if all (tag == epoch_tags [0 ] for tag in epoch_tags ) else None
3836 prompts = [example ["prompt" ] for example in examples ]
3937 metadatas = [example ["metadata" ] for example in examples ]
4038 return epoch_tag , prompts , metadatas
@@ -66,9 +64,7 @@ def __getitem__(self, idx: int | tuple[int, int]) -> dict:
6664 def collate_fn (examples : list [dict ]) -> tuple [int | None , list [str ], list [dict ]]:
6765 """Batch Geneval items while preserving epoch tags."""
6866 epoch_tags = [example .get ("epoch" ) for example in examples ]
69- epoch_tag = (
70- epoch_tags [0 ] if all (tag == epoch_tags [0 ] for tag in epoch_tags ) else None
71- )
67+ epoch_tag = epoch_tags [0 ] if all (tag == epoch_tags [0 ] for tag in epoch_tags ) else None
7268 prompts = [example ["prompt" ] for example in examples ]
7369 metadatas = [example ["metadata" ] for example in examples ]
7470 return epoch_tag , prompts , metadatas
@@ -93,9 +89,7 @@ def __init__(self, dataset: str, split: str = "train"):
9389 self .file_path = os .path .join (dataset , f"{ split } .json" )
9490 self ._prompts = None
9591 self ._metadatas = None
96- self ._file_size = (
97- os .path .getsize (self .file_path ) if os .path .exists (self .file_path ) else 0
98- )
92+ self ._file_size = os .path .getsize (self .file_path ) if os .path .exists (self .file_path ) else 0
9993
10094 # Optimization strategy:
10195 # - For training data, load directly to memory even if large (frequent random access needed)
@@ -185,11 +179,7 @@ def _get_item_lazy(self, idx: int) -> dict:
185179
186180 start_offset = self ._line_offsets [idx ]
187181 # Calculate end offset (start of next line or end of file)
188- end_offset = (
189- self ._line_offsets [idx + 1 ]
190- if idx + 1 < len (self ._line_offsets )
191- else self ._file_size
192- )
182+ end_offset = self ._line_offsets [idx + 1 ] if idx + 1 < len (self ._line_offsets ) else self ._file_size
193183
194184 with open (self .file_path , encoding = "utf-8" ) as f :
195185 f .seek (start_offset )
@@ -257,9 +247,7 @@ def collate_fn(examples: list[dict]) -> tuple[int | None, list[str], list[dict]]
257247 Tuple of (epoch_tag, prompts, metadatas) where epoch_tag is None if inconsistent.
258248 """
259249 epoch_tags = [example .get ("epoch" ) for example in examples ]
260- epoch_tag = (
261- epoch_tags [0 ] if all (tag == epoch_tags [0 ] for tag in epoch_tags ) else None
262- )
250+ epoch_tag = epoch_tags [0 ] if all (tag == epoch_tags [0 ] for tag in epoch_tags ) else None
263251 prompts = [example ["prompt" ] for example in examples ]
264252 metadatas = [example ["metadata" ] for example in examples ]
265253 return epoch_tag , prompts , metadatas
@@ -292,9 +280,9 @@ def __init__(
292280 self .rank = rank
293281 self .seed = seed
294282 self .total_samples = self .num_replicas * self .batch_size
295- assert (
296- self . total_samples % self . k == 0
297- ), f"k can not div n*b, k { k } -num_replicas { num_replicas } -batch_size { batch_size } "
283+ assert self . total_samples % self . k == 0 , (
284+ f"k can not div n*b, k { k } -num_replicas { num_replicas } -batch_size { batch_size } "
285+ )
298286 self .m = self .total_samples // self .k
299287 self .epoch = 0
300288
@@ -305,27 +293,21 @@ def __iter__(self):
305293 g .manual_seed (self .seed + self .epoch )
306294 indices = torch .randperm (len (self .dataset ), generator = g )[: self .m ].tolist ()
307295 repeated_indices = [idx for idx in indices for _ in range (self .k )]
308- shuffled_indices = torch .randperm (
309- len (repeated_indices ), generator = g
310- ).tolist ()
296+ shuffled_indices = torch .randperm (len (repeated_indices ), generator = g ).tolist ()
311297 shuffled_samples = [repeated_indices [i ] for i in shuffled_indices ]
312298 per_card_samples = []
313299 for i in range (self .num_replicas ):
314300 start = i * self .batch_size
315301 end = start + self .batch_size
316- per_card_samples .append (
317- [(self .epoch , idx ) for idx in shuffled_samples [start :end ]]
318- )
302+ per_card_samples .append ([(self .epoch , idx ) for idx in shuffled_samples [start :end ]])
319303 yield per_card_samples [self .rank ]
320304
def set_epoch(self, epoch: int):
    """Store the epoch tag that keeps the RNG in sync across workers.

    The sampler's ``__iter__`` adds this value to ``seed`` when seeding its
    generator and attaches it to every index it yields.

    Args:
        epoch: Epoch number to record on the sampler.
    """
    self.epoch = epoch
324308
325309
326- def build_dataloaders (
327- cfg , accelerator
328- ) -> tuple [DataLoader , DataLoader , DistributedKRepeatSampler ]:
310+ def build_dataloaders (cfg , accelerator ) -> tuple [DataLoader , DataLoader , DistributedKRepeatSampler ]:
329311 """Construct train/eval dataloaders and sampler with epoch tags.
330312
331313 Args:
@@ -350,9 +332,7 @@ def build_dataloaders(
350332 collate_fn = JsonPromptDataset .collate_fn
351333 else :
352334 msg = "Only general_ocr, geneval, or filtered_prompts prompt_fn supported"
353- raise NotImplementedError (
354- msg
355- )
335+ raise NotImplementedError (msg )
356336
357337 train_sampler = DistributedKRepeatSampler (
358338 dataset = train_dataset ,
0 commit comments