Skip to content

Commit 3f0ae30

Browse files
Ximingwang-09纬杭
andauthored
Optimize Evaluation Phase By adding torch.no_grad() (#237)
* eval optimize * fix --------- Co-authored-by: 纬杭 <ximing.wxm@antgroup.com>
1 parent ca2fc85 commit 3f0ae30

2 files changed

Lines changed: 22 additions & 19 deletions

File tree

scripts/train_eagle3_offline.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -487,13 +487,14 @@ def main():
487487
eval_plosses = [[] for _ in range(eagle3_model.length)]
488488

489489
for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"):
490-
plosses, _, acces = eagle3_model(
491-
input_ids=data["input_ids"].cuda(),
492-
attention_mask=data["attention_mask"].cuda(),
493-
loss_mask=data["loss_mask"].unsqueeze(-1).cuda(),
494-
hidden_states=data["hidden_state"].cuda(),
495-
target=data["target"].cuda(),
496-
)
490+
with torch.no_grad():
491+
plosses, _, acces = eagle3_model(
492+
input_ids=data["input_ids"].cuda(),
493+
attention_mask=data["attention_mask"].cuda(),
494+
loss_mask=data["loss_mask"].unsqueeze(-1).cuda(),
495+
hidden_states=data["hidden_state"].cuda(),
496+
target=data["target"].cuda(),
497+
)
497498
acces = torch.stack(acces).cpu().tolist()
498499
eval_acces = [eval_acces[i] + [acces[i]] for i in range(len(acces))]
499500
eval_plosses = [

scripts/train_eagle3_online.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -589,19 +589,21 @@ def main():
589589

590590
for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"):
591591
if args.is_vlm:
592-
plosses, _, acces = eagle3_model(
593-
input_ids=data["input_ids"].cuda(),
594-
attention_mask=data["attention_mask"].cuda(),
595-
loss_mask=data["loss_mask"].cuda(),
596-
pixel_values=data["pixel_values"].cuda(),
597-
image_grid_thw=data["image_grid_thw"].cuda(),
598-
)
592+
with torch.no_grad():
593+
plosses, _, acces = eagle3_model(
594+
input_ids=data["input_ids"].cuda(),
595+
attention_mask=data["attention_mask"].cuda(),
596+
loss_mask=data["loss_mask"].cuda(),
597+
pixel_values=data["pixel_values"].cuda(),
598+
image_grid_thw=data["image_grid_thw"].cuda(),
599+
)
599600
else:
600-
plosses, _, acces = eagle3_model(
601-
input_ids=data["input_ids"].cuda(),
602-
attention_mask=data["attention_mask"].cuda(),
603-
loss_mask=data["loss_mask"].cuda(),
604-
)
601+
with torch.no_grad():
602+
plosses, _, acces = eagle3_model(
603+
input_ids=data["input_ids"].cuda(),
604+
attention_mask=data["attention_mask"].cuda(),
605+
loss_mask=data["loss_mask"].cuda(),
606+
)
605607
acces = torch.stack(acces).cpu().tolist()
606608

607609
eval_acces = [eval_acces[i] + [acces[i]] for i in range(len(acces))]

0 commit comments

Comments
 (0)