Skip to content

Commit b0c170c

Browse files
committed
separate callbacks from train
Signed-off-by: Dushyant Behl <dushyantbehl@users.noreply.github.com>
1 parent 9357243 commit b0c170c

2 files changed

Lines changed: 25 additions & 53 deletions

File tree

tuning/aim_loader.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

tuning/sft_trainer.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Standard
22
from datetime import datetime
3-
from typing import Optional, Union
3+
from typing import Optional, Union, List
44
import json
55
import os, time
66

77
# Third Party
8-
from peft.utils.other import fsdp_auto_wrap_policy
8+
import transformers
99
from transformers import (
1010
AutoModelForCausalLM,
1111
AutoTokenizer,
@@ -16,10 +16,10 @@
1616
TrainerCallback,
1717
)
1818
from transformers.utils import logging
19+
from peft.utils.other import fsdp_auto_wrap_policy
1920
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
2021
import datasets
2122
import fire
22-
import transformers
2323

2424
# Local
2525
from tuning.config import configs, peft_config, tracker_configs
@@ -29,6 +29,7 @@
2929
from tuning.tracker.tracker import Tracker
3030
from tuning.tracker.aimstack_tracker import AimStackTracker
3131

32+
logger = logging.get_logger("sft_trainer")
3233

3334
class PeftSavingCallback(TrainerCallback):
3435
def on_save(self, args, state, control, **kwargs):
@@ -83,15 +84,15 @@ def _track_loss(self, loss_key, log_file, logs, state):
8384
with open(log_file, "a") as f:
8485
f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
8586

86-
8787
def train(
8888
model_args: configs.ModelArguments,
8989
data_args: configs.DataArguments,
9090
train_args: configs.TrainingArguments,
9191
peft_config: Optional[
9292
Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
9393
] = None,
94-
tracker_config: Optional[Union[tracker_configs.AimConfig]] = None
94+
callbacks: Optional[List[TrainerCallback]] = None,
95+
tracker: Optional[Tracker] = None,
9596
):
9697
"""Call the SFTTrainer
9798
@@ -105,7 +106,6 @@ def train(
105106
The peft configuration to pass to trainer
106107
"""
107108
run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
108-
logger = logging.get_logger("sft_trainer")
109109

110110
# Validate parameters
111111
if (not isinstance(train_args.num_train_epochs, float)) or (
@@ -122,17 +122,6 @@ def train(
122122
train_args.fsdp = ""
123123
train_args.fsdp_config = {"xla": False}
124124

125-
# Initialize the tracker early so we can calculate custom metrics like model_load_time.
126-
tracker_name = train_args.tracker
127-
if tracker_name == 'aim':
128-
if tracker_config is not None:
129-
tracker = AimStackTracker(tracker_config)
130-
else:
131-
logger.error("Tracker name is set to "+tracker_name+" but config is None.")
132-
else:
133-
logger.info('No tracker set so just set a dummy API which does nothing')
134-
tracker = Tracker()
135-
136125
task_type = "CAUSAL_LM"
137126

138127
model_load_time = time.time()
@@ -259,15 +248,6 @@ def train(
259248
)
260249
packing = False
261250

262-
# club and register callbacks
263-
file_logger_callback = FileLoggingCallback(logger)
264-
peft_saving_callback = PeftSavingCallback()
265-
callbacks = [peft_saving_callback, file_logger_callback]
266-
267-
tracker_callback = tracker.get_hf_callback()
268-
if tracker_callback is not None:
269-
callbacks.append(tracker_callback)
270-
271251
trainer = SFTTrainer(
272252
model=model,
273253
tokenizer=tokenizer,
@@ -288,7 +268,6 @@ def train(
288268
)
289269
trainer.train()
290270

291-
292271
def main(**kwargs):
293272
parser = transformers.HfArgumentParser(
294273
dataclass_types=(
@@ -331,8 +310,26 @@ def main(**kwargs):
331310
else:
332311
tracker_config=None
333312

334-
train(model_args, data_args, training_args, tune_config, tracker_config)
313+
# Initialize the tracker early so we can calculate custom metrics like model_load_time.
314+
tracker_name = training_args.tracker
315+
if tracker_name == 'aim':
316+
if tracker_config is not None:
317+
tracker = AimStackTracker(tracker_config)
318+
else:
319+
logger.error("Tracker name is set to "+tracker_name+" but config is None.")
320+
else:
321+
tracker = Tracker()
322+
323+
# Initialize callbacks
324+
file_logger_callback = FileLoggingCallback(logger)
325+
peft_saving_callback = PeftSavingCallback()
326+
callbacks = [peft_saving_callback, file_logger_callback]
327+
328+
tracker_callback = tracker.get_hf_callback()
329+
if tracker_callback is not None:
330+
callbacks.append(tracker_callback)
335331

332+
train(model_args, data_args, training_args, tune_config, callbacks, tracker)
336333

337334
if __name__ == "__main__":
338335
fire.Fire(main)

0 commit comments

Comments (0)