@@ -47,34 +47,41 @@ def __init__(self, logger):
4747 self .logger = logger
4848
def on_log(self, args, state, control, logs=None, **kwargs):
    """Route a log entry to the matching loss file in ``args.output_dir``.

    An entry holding both ``"loss"`` and ``"epoch"`` is appended to
    ``train_loss.jsonl``; otherwise an entry holding both ``"eval_loss"``
    and ``"epoch"`` is appended to ``eval_loss.jsonl``. Any other entry
    (or ``logs=None``) is ignored. The actual record building and file
    append is delegated to ``self._track_loss``.
    """
    # Every process receives the same logs; only process 0 records them.
    if not state.is_world_process_zero:
        return

    # Training and evaluation losses are kept in separate files. Listed in
    # priority order: an entry carrying both keys is recorded as "loss" only.
    destinations = (
        ("loss", os.path.join(args.output_dir, "train_loss.jsonl")),
        ("eval_loss", os.path.join(args.output_dir, "eval_loss.jsonl")),
    )

    # Entries without an epoch are something else (e.g. train completion).
    if logs is None or "epoch" not in logs:
        return

    for loss_key, destination in destinations:
        if loss_key in logs:
            self._track_loss(loss_key, destination, logs, state)
            return
65+
def _track_loss(self, loss_key, log_file, logs, state):
    """Append one loss record from ``logs`` to the JSON-lines file ``log_file``.

    The record has the form
    ``{"name": loss_key, "data": {"epoch", "step", "value", "timestamp"}}``
    where ``epoch`` is ``logs["epoch"]`` rounded to two decimals, ``step`` is
    ``state.global_step``, ``value`` is ``logs[loss_key]`` and ``timestamp``
    is the current local time in ISO format. If a required key is missing
    from ``logs`` the entry is skipped and the file is left untouched.
    """
    try:
        # A missing key means this log line is something else (e.g. train
        # completion), so it is skipped rather than treated as an error.
        epoch = round(logs["epoch"], 2)
        step = state.global_step
        value = logs[loss_key]
        timestamp = datetime.now().isoformat()
    except KeyError:
        return

    record = {
        "name": loss_key,
        "data": {
            "epoch": epoch,
            "step": step,
            "value": value,
            "timestamp": timestamp,
        },
    }

    # Append mode creates the file on first use; one JSON object per line.
    with open(log_file, "a") as sink:
        sink.write(json.dumps(record, sort_keys=True) + "\n")
7885
7986
8087def train (
0 commit comments