Skip to content

Commit 135eaed

Browse files
committed
improve val
1 parent 260e37a commit 135eaed

2 files changed

Lines changed: 11 additions & 0 deletions

File tree

ajet/backbone/trainer_verl.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,11 @@ def fit(self): # noqa: C901
734734
if is_last_step:
735735
last_val_metrics = val_metrics
736736
metrics.update(val_metrics)
737+
val_print_to_markdown_file_path = self.config.ajet.trainer_common.val_print_to_markdown_file_path
738+
if val_print_to_markdown_file_path:
739+
with open(val_print_to_markdown_file_path, mode="a+") as f:
740+
f.write(str(val_metrics))
741+
f.write('\n')
737742

738743
# Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
739744
esi_close_to_expiration = should_save_ckpt_esi(
@@ -782,6 +787,11 @@ def fit(self): # noqa: C901
782787
self.train_dataloader.sampler.update(batch=batch)
783788

784789
self.verl_logger.log(data=metrics, step=self.global_steps)
790+
train_print_to_markdown_file_path = self.config.ajet.trainer_common.train_print_to_markdown_file_path
791+
if train_print_to_markdown_file_path:
792+
with open(train_print_to_markdown_file_path, mode="a+") as f:
793+
f.write(str(metrics))
794+
f.write('\n')
785795
progress_bar.update(1)
786796
self.global_steps += 1
787797

ajet/default_config/ajet_default.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ ajet:
236236
val_pass_n: 4
237237
val_only: False
238238
val_print_to_markdown_file_path: null
239+
train_print_to_markdown_file_path: null
239240

240241
# save and test frequency (in step)
241242
save_freq: 20

0 commit comments

Comments
 (0)