@@ -589,19 +589,21 @@ def main():
589589
590590 for data in tqdm (eval_dataloader , desc = f"Evaluating Epoch { epoch } " ):
591591 if args .is_vlm :
592- plosses , _ , acces = eagle3_model (
593- input_ids = data ["input_ids" ].cuda (),
594- attention_mask = data ["attention_mask" ].cuda (),
595- loss_mask = data ["loss_mask" ].cuda (),
596- pixel_values = data ["pixel_values" ].cuda (),
597- image_grid_thw = data ["image_grid_thw" ].cuda (),
598- )
592+ with torch .no_grad ():
593+ plosses , _ , acces = eagle3_model (
594+ input_ids = data ["input_ids" ].cuda (),
595+ attention_mask = data ["attention_mask" ].cuda (),
596+ loss_mask = data ["loss_mask" ].cuda (),
597+ pixel_values = data ["pixel_values" ].cuda (),
598+ image_grid_thw = data ["image_grid_thw" ].cuda (),
599+ )
599600 else :
600- plosses , _ , acces = eagle3_model (
601- input_ids = data ["input_ids" ].cuda (),
602- attention_mask = data ["attention_mask" ].cuda (),
603- loss_mask = data ["loss_mask" ].cuda (),
604- )
601+ with torch .no_grad ():
602+ plosses , _ , acces = eagle3_model (
603+ input_ids = data ["input_ids" ].cuda (),
604+ attention_mask = data ["attention_mask" ].cuda (),
605+ loss_mask = data ["loss_mask" ].cuda (),
606+ )
605607 acces = torch .stack (acces ).cpu ().tolist ()
606608
607609 eval_acces = [eval_acces [i ] + [acces [i ]] for i in range (len (acces ))]
0 commit comments