Skip to content

Commit 4e983ea

Browse files
committed
fix: code refactor
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent 1c910f3 commit 4e983ea

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

  • plugins/online-data-mixing/src/fms_acceleration_odm/odm

plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,13 +349,13 @@ def update_sampling_weights(self, model, accelerator, state):
349349
if accelerator:
350350
eval_dataset_dict[self.id2cat[c]] = (
351351
accelerator.prepare(self.eval_dataset_dict_dl[self.id2cat[c]])
352-
if self.eval_dataset_dict_dl[self.id2cat[c]]
352+
if self.eval_dataset_dict_dl.get(self.id2cat[c], None)
353353
else None
354354
)
355355
else:
356-
eval_dataset_dict[self.id2cat[c]] = self.eval_dataset_dict_dl[
357-
self.id2cat[c]
358-
]
356+
eval_dataset_dict[self.id2cat[c]] = self.eval_dataset_dict_dl.get(
357+
self.id2cat[c], None
358+
)
359359
for c in tqdm(
360360
range(self.total_categories), total=self.total_categories, desc="Categories"
361361
): # for trian loss you dont need to iterate over eval dataset.

0 commit comments

Comments
 (0)