Skip to content

Commit c67abe0

Browse files
committed
Generic Tracker API with command line arguments.
Tracker now takes command line arguments as config. AimStack is the default tracker, and the code includes an example of seamlessly recording additional metrics (such as 'model_load_time') into AimStack.
1 parent fc07060 commit c67abe0

6 files changed

Lines changed: 112 additions & 32 deletions

File tree

tuning/aim_loader.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

tuning/config/tracker_configs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from dataclasses import dataclass
from typing import Optional


@dataclass
class AimConfig:
    """Command-line configurable settings for the AimStack tracker.

    'repo' can point to a locally accessible directory (e.g., '~/.aim') or a
    remote repository hosted on a server. When 'remote_server_ip' or
    'remote_server_port' is set, it designates a remote aim repo.
    Otherwise, 'repo' specifies the directory, with a default of None
    representing '.aim'.
    """

    # Local Aim repo path; None means the Aim default ('.aim').
    repo: Optional[str] = None
    # Remote Aim server address/port; both must be set for remote tracking.
    remote_server_ip: Optional[str] = None
    remote_server_port: Optional[int] = None
    # Name of the experiment
    experiment: Optional[str] = None
    # Location of where run_hash is exported
    run_hash_export_location: Optional[str] = None

tuning/sft_trainer.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import time
23
from typing import Optional, Union
34

45
import datasets
@@ -10,11 +11,12 @@
1011
from transformers.utils import logging
1112
from transformers import TrainerCallback
1213
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
13-
from tuning.aim_loader import get_aimstack_callback
14-
from tuning.config import configs, peft_config
14+
from tuning.config import configs, peft_config, tracker_configs
1515
from tuning.data import tokenizer_data_utils
1616
from tuning.utils.config_utils import get_hf_peft_config
1717
from tuning.utils.data_type_utils import get_torch_dtype
18+
from tuning.tracker.tracker import Tracker
19+
from tuning.tracker.aimstack_tracker import AimStackTracker
1820

1921
class PeftSavingCallback(TrainerCallback):
2022
def on_save(self, args, state, control, **kwargs):
@@ -24,13 +26,13 @@ def on_save(self, args, state, control, **kwargs):
2426
if "pytorch_model.bin" in os.listdir(checkpoint_path):
2527
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
2628

27-
28-
2929
def train(
3030
model_args: configs.ModelArguments,
3131
data_args: configs.DataArguments,
3232
train_args: configs.TrainingArguments,
3333
peft_config: Optional[Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]] = None,
34+
tracker_name: Optional[str] = None,
35+
tracker_config: Optional[Union[tracker_configs.AimConfig]] = None
3436
):
3537
"""Call the SFTTrainer
3638
@@ -44,7 +46,6 @@ def train(
4446
The peft configuration to pass to trainer
4547
"""
4648
run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
47-
4849
logger = logging.get_logger("sft_trainer")
4950

5051
# Validate parameters
@@ -58,14 +59,29 @@ def train(
5859
train_args.fsdp = ""
5960
train_args.fsdp_config = {'xla':False}
6061

62+
# Initialize the tracker early so we can calculate custom metrics like model_load_time.
63+
64+
if tracker_name == 'aim':
65+
if tracker_config is not None:
66+
tracker = AimStackTracker(tracker_config)
67+
else:
68+
logger.error("Tracker name is set to "+tracker_name+" but config is None.")
69+
else:
70+
logger.info('No tracker set so just set a dummy API which does nothing')
71+
tracker = Tracker()
72+
6173
task_type = "CAUSAL_LM"
74+
75+
model_load_time = time.time()
6276
model = AutoModelForCausalLM.from_pretrained(
6377
model_args.model_name_or_path,
6478
cache_dir=train_args.cache_dir,
6579
torch_dtype=get_torch_dtype(model_args.torch_dtype),
6680
use_flash_attention_2=model_args.use_flash_attn,
6781
)
68-
82+
model_load_time = time.time() - model_load_time
83+
tracker.track(metric=model_load_time, name='model_load_time')
84+
6985
peft_config = get_hf_peft_config(task_type, peft_config)
7086

7187
model.gradient_checkpointing_enable()
@@ -130,8 +146,12 @@ def train(
130146
formatted_dataset = json_dataset['train'].map(lambda example : {f"{data_args.dataset_text_field}" : example[f"{data_args.dataset_text_field}"] + tokenizer.eos_token})
131147
logger.info(f"Dataset length is {len(formatted_dataset)}")
132148

133-
aim_callback = get_aimstack_callback()
134-
callbacks=[aim_callback,PeftSavingCallback()]
149+
# club and register callbacks
150+
callbacks = [PeftSavingCallback()]
151+
152+
tracker_callback = tracker.get_hf_callback()
153+
if tracker_callback is not None:
154+
callbacks.append(tracker_callback)
135155

136156
if train_args.packing:
137157
logger.info("Packing is set to True")
@@ -173,16 +193,30 @@ def main(**kwargs):
173193
configs.DataArguments,
174194
configs.TrainingArguments,
175195
peft_config.LoraConfig,
176-
peft_config.PromptTuningConfig))
196+
peft_config.PromptTuningConfig,
197+
tracker_configs.AimConfig))
177198
parser.add_argument('--peft_method', type=str.lower, choices=['pt', 'lora', None, 'none'], default="pt")
178-
model_args, data_args, training_args, lora_config, prompt_tuning_config, peft_method, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
179-
if peft_method.peft_method =="lora":
199+
parser.add_argument('--tracker', type=str.lower, choices=['aim', None, 'none'], default="aim")
200+
(model_args, data_args, training_args,
201+
lora_config, prompt_tuning_config, aim_config,
202+
additional, _) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
203+
204+
peft_method = additional.peft_method
205+
tracker_name = additional.tracker
206+
207+
if peft_method =="lora":
180208
tune_config=lora_config
181-
elif peft_method.peft_method =="pt":
209+
elif peft_method =="pt":
182210
tune_config=prompt_tuning_config
183211
else:
184212
tune_config=None
185-
train(model_args, data_args, training_args, tune_config)
213+
214+
if tracker_name == "aim":
215+
tracker_config=aim_config
216+
else:
217+
tracker_config=None
218+
219+
train(model_args, data_args, training_args, tune_config, tracker_name, tracker_config)
186220

187221
if __name__ == "__main__":
188222
fire.Fire(main)

tuning/tracker/__init__.py

Whitespace-only changes.

tuning/tracker/aimstack_tracker.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Standard
2+
import os
3+
4+
from tuning.tracker.tracker import Tracker
5+
6+
# Third Party
7+
from aim.hugging_face import AimCallback
8+
9+
class AimStackTracker(Tracker):
    """Tracker backed by AimStack's HuggingFace callback.

    Builds an AimCallback from the given AimConfig, starts the Aim run,
    optionally exports the run hash to a file, and exposes the callback
    for registration with the HF Trainer.
    """

    def __init__(self, tracker_config):
        """Initialize the Aim run from an AimConfig.

        Remote server settings take precedence over a local 'repo' path;
        when neither is given, AimCallback falls back to its default repo.
        """
        super().__init__(tracker_config)
        c = self.config
        if (c.remote_server_ip is not None and
                c.remote_server_port is not None):
            # Fix: cast port to str — it is declared as int in AimConfig, and
            # string concatenation with an int raises TypeError.
            aim_callback = AimCallback(
                repo="aim://" + c.remote_server_ip + ":" + str(c.remote_server_port) + "/",
                experiment=c.experiment)
        elif c.repo:
            # Fix: this was a separate 'if', whose 'else' branch silently
            # overwrote the remote-server callback built above.
            aim_callback = AimCallback(repo=c.repo, experiment=c.experiment)
        else:
            aim_callback = AimCallback(experiment=c.experiment)

        run = aim_callback.experiment  # Initialize Aim run
        run_hash = run.hash  # Extract the hash

        # store the run hash
        if c.run_hash_export_location:
            with open(c.run_hash_export_location, 'w') as f:
                f.write(str(run_hash) + '\n')

        # Save Internal State
        self.hf_callback = aim_callback
        self.run = run

    def get_hf_callback(self):
        """Return the HuggingFace TrainerCallback wired to this Aim run."""
        return self.hf_callback

    def track(self, metric, name, stage='additional_metrics'):
        """Record a scalar metric into the Aim run under the given context subset."""
        context = {'subset': stage}
        self.run.track(metric, name=name, context=context)

tuning/tracker/tracker.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Generic Tracker API
2+
3+
class Tracker:
    """Do-nothing base tracker used when no tracker is configured.

    Concrete trackers (e.g. AimStackTracker) override get_hf_callback()
    and track() to wire a real experiment tracker into the training loop.
    """

    def __init__(self, tracker_config=None) -> None:
        # Fix: default of None allows creating the dummy `Tracker()` with no
        # arguments, as the training entry point does when no tracker is set.
        self.config = tracker_config

    def get_hf_callback(self):
        # Fix: was missing 'self', so tracker.get_hf_callback() raised
        # TypeError when invoked on an instance.
        return None

    def track(self, metric, name, stage='additional_metrics'):
        # Fix: 'stage' now has a default, matching the subclass signature so
        # callers may omit it (e.g. tracker.track(metric=t, name='model_load_time')).
        pass

0 commit comments

Comments
 (0)