11# Standard
2+ from datetime import datetime
23from typing import Optional , Union
4+ import json
35import os
46
57# Third Party
1719from trl import DataCollatorForCompletionOnlyLM , SFTTrainer
1820import datasets
1921import fire
20- import torch
2122import transformers
2223
2324# Local
@@ -39,6 +40,43 @@ def on_save(self, args, state, control, **kwargs):
3940 os .remove (os .path .join (checkpoint_path , "pytorch_model.bin" ))
4041
4142
class FileLoggingCallback(TrainerCallback):
    """Exports metrics, e.g., training loss, to a file in the checkpoint directory."""

    def __init__(self, logger):
        """Keep a reference to the caller's logger.

        Args:
            logger: Logger object supplied by the caller. It is stored on the
                instance; on_log does not write to it.
        """
        self.logger = logger

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append the loss from this log event to train_loss.jsonl.

        If the log contains both "loss" and "epoch", one JSON object per line
        is appended to train_loss.jsonl inside args.output_dir (the file is
        created on first write). Any other log event is ignored.

        Args:
            args: Training arguments; only output_dir is read here.
            state: Trainer state; is_world_process_zero and global_step are read.
            control: Unused; accepted to match the callback signature.
            logs: Dict of logged values for this event, or None.
            **kwargs: Unused; accepted to match the callback signature.
        """
        # All processes get the logs from this node; only update from process 0.
        if not state.is_world_process_zero:
            return

        # If either key is missing, assume this event is something else,
        # e.g., train completion, and skip it.
        if logs is None or "loss" not in logs or "epoch" not in logs:
            return

        log_obj = {
            "name": "loss",
            "data": {
                "epoch": round(logs["epoch"], 2),
                "step": state.global_step,
                "value": logs["loss"],
                "timestamp": datetime.now().isoformat(),
            },
        }

        # Guard against the output directory not having been created yet, so
        # that opening the file for append cannot fail on a missing parent.
        if args.output_dir:
            os.makedirs(args.output_dir, exist_ok=True)
        log_file_path = os.path.join(args.output_dir, "train_loss.jsonl")

        # Append the current log to the jsonl file.
        with open(log_file_path, "a", encoding="utf-8") as log_file:
            log_file.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
78+
79+
4280def train (
4381 model_args : configs .ModelArguments ,
4482 data_args : configs .DataArguments ,
@@ -175,7 +213,9 @@ def train(
175213 logger .info (f"Validation dataset length is { len (formatted_validation_dataset )} " )
176214
177215 aim_callback = get_aimstack_callback ()
178- callbacks = [aim_callback , PeftSavingCallback ()]
216+ file_logger_callback = FileLoggingCallback (logger )
217+ peft_saving_callback = PeftSavingCallback ()
218+ callbacks = [aim_callback , peft_saving_callback , file_logger_callback ]
179219
180220 if train_args .packing :
181221 logger .info ("Packing is set to True" )
0 commit comments