Skip to content

Commit 517652d

Browse files
Add file logger callback & export train loss json file (#22)
* Add file logger callback & export train loss json file — Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* only update logs from process 0 — Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Export logs in jsonl format — Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Formatting — Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent bd0c917 commit 517652d

1 file changed

Lines changed: 42 additions & 2 deletions

File tree

tuning/sft_trainer.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Standard
2+
from datetime import datetime
23
from typing import Optional, Union
4+
import json
35
import os
46

57
# Third Party
@@ -17,7 +19,6 @@
1719
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
1820
import datasets
1921
import fire
20-
import torch
2122
import transformers
2223

2324
# Local
@@ -39,6 +40,43 @@ def on_save(self, args, state, control, **kwargs):
3940
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
4041

4142

43+
class FileLoggingCallback(TrainerCallback):
    """Exports metrics, e.g., training loss to a file in the checkpoint directory."""

    def __init__(self, logger):
        """
        Args:
            logger: logger instance kept for use by this callback.
        """
        self.logger = logger

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Checks if this log contains keys of interest, e.g., loss, and if so, creates
        train_loss.jsonl in the model output dir (if it doesn't already exist),
        appends the subdict of the log & dumps the file.

        Args:
            args: training arguments; only ``args.output_dir`` is read here.
            state: trainer state; ``is_world_process_zero`` and ``global_step`` are read.
            control: trainer control object (unused, required by the callback API).
            logs: dict of metrics for this logging step; skipped unless it
                contains both "loss" and "epoch".
        """
        # All processes get the logs from this node; only update from process 0.
        if not state.is_world_process_zero:
            return

        log_file_path = os.path.join(args.output_dir, "train_loss.jsonl")
        if logs is not None and "loss" in logs and "epoch" in logs:
            try:
                # Take the subdict of the last log line; if any log_keys aren't part of this log
                # object, assume this line is something else, e.g., train completion, and skip.
                log_obj = {
                    "name": "loss",
                    "data": {
                        "epoch": round(logs["epoch"], 2),
                        "step": state.global_step,
                        "value": logs["loss"],
                        # NOTE(review): naive local time — confirm whether UTC is wanted.
                        "timestamp": datetime.now().isoformat(),
                    },
                }
            except KeyError:
                return

            # append the current log to the jsonl file
            with open(log_file_path, "a") as log_file:
                log_file.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
78+
79+
4280
def train(
4381
model_args: configs.ModelArguments,
4482
data_args: configs.DataArguments,
@@ -175,7 +213,9 @@ def train(
175213
logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
176214

177215
aim_callback = get_aimstack_callback()
178-
callbacks = [aim_callback, PeftSavingCallback()]
216+
file_logger_callback = FileLoggingCallback(logger)
217+
peft_saving_callback = PeftSavingCallback()
218+
callbacks = [aim_callback, peft_saving_callback, file_logger_callback]
179219

180220
if train_args.packing:
181221
logger.info("Packing is set to True")

0 commit comments

Comments
 (0)