Skip to content

Commit b9380e4

Browse files
authored
Merge pull request #51 from foundation-model-stack/validation-loss-file
feat: track validation loss in logs file
2 parents 24e7385 + 1e817ca commit b9380e4

1 file changed

Lines changed: 26 additions & 19 deletions

File tree

tuning/sft_trainer.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,34 +47,41 @@ def __init__(self, logger):
4747
self.logger = logger
4848

4949
def on_log(self, args, state, control, logs=None, **kwargs):
    """Checks if this log contains keys of interest, e.g., loss or eval_loss,
    and if so, appends the relevant subdict of the log to train_loss.jsonl
    or eval_loss.jsonl in the model output dir (the file is created on the
    first write).

    Args:
        args: Trainer arguments; only ``output_dir`` is read here.
        state: Trainer state; provides ``is_world_process_zero`` and
            ``global_step``.
        control: Trainer control object (unused, required by callback API).
        logs: Dict of metrics logged at this step, or None.
        **kwargs: Additional callback arguments (unused).
    """
    # All processes get the logs from this node; only update from process 0.
    if not state.is_world_process_zero:
        return
    # Nothing to record if no metrics were passed.
    if logs is None:
        return

    # Track evaluation loss separately from training loss.
    log_file_path = os.path.join(args.output_dir, "train_loss.jsonl")
    eval_log_file_path = os.path.join(args.output_dir, "eval_loss.jsonl")
    if "loss" in logs and "epoch" in logs:
        self._track_loss("loss", log_file_path, logs, state)
    elif "eval_loss" in logs and "epoch" in logs:
        self._track_loss("eval_loss", eval_log_file_path, logs, state)
def _track_loss(self, loss_key, log_file, logs, state):
67+
try:
68+
# Take the subdict of the last log line; if any log_keys aren't part of this log
69+
# object, assume this line is something else, e.g., train completion, and skip.
70+
log_obj = {
71+
"name": loss_key,
72+
"data": {
73+
"epoch": round(logs["epoch"], 2),
74+
"step": state.global_step,
75+
"value": logs[loss_key],
76+
"timestamp": datetime.isoformat(datetime.now()),
77+
},
78+
}
79+
except KeyError:
80+
return
81+
82+
# append the current log to the jsonl file
83+
with open(log_file, "a") as f:
84+
f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
7885

7986

8087
def train(

0 commit comments

Comments (0)