Skip to content

Commit 93e2d03

Browse files
committed
code format using black and minor nit to use public interface
Signed-off-by: Dushyant Behl <dushyantbehl@users.noreply.github.com>
1 parent c93329f commit 93e2d03

7 files changed

Lines changed: 84 additions & 58 deletions

File tree

scripts/run_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
99
If these things change in the future, we should consider breaking it up.
1010
"""
11+
1112
# Standard
1213
import argparse
1314
import json

tuning/config/configs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,8 @@ class TrainingArguments(transformers.TrainingArguments):
5959
)
6060
tracker: str.lower = field(
6161
default=None,
62-
metadata={"help": "Experiment tracker to use. Requires additional configs, see tuning.configs/tracker_configs.py"}
62+
choices=["aim", None, "none"],
63+
metadata={
64+
"help": "Experiment tracker to use. Requires additional configs, see tuning.configs/tracker_configs.py"
65+
},
6366
)

tuning/config/tracker_configs.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1+
# Standard
12
from dataclasses import dataclass
23

4+
35
@dataclass
46
class AimConfig:
57
# Name of the experiment
68
experiment: str = None
7-
# 'repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
8-
# When 'remote_server_ip' or 'remote_server_port' is set, it designates a remote aim repo.
9+
# 'aim_repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
10+
# When 'aim_remote_server_ip' or 'aim_remote_server_port' is set, it designates a remote aim repo.
911
# Otherwise, 'repo' specifies the directory, with a default of None representing '.aim'.
10-
aim_repo: str = None
12+
# See https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html for documentation on Aim remote server tracking.
13+
aim_repo: str = ".aim"
1114
aim_remote_server_ip: str = None
1215
aim_remote_server_port: int = None
13-
# Location of where run_hash is exported, if unspecified this is output to
16+
# Location of where run_hash is exported, if unspecified this is output to
1417
# training_args.output_dir/.aim_run_hash if the output_dir is set else not exported.
1518
aim_run_hash_export_path: str = None

tuning/sft_trainer.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# Standard
22
from datetime import datetime
3-
from typing import Optional, Union, List, Dict
3+
from typing import Dict, List, Optional, Union
44
import json
5-
import os, time
5+
import os
6+
import time
67

78
# Third Party
8-
import transformers
9+
from peft.utils.other import fsdp_auto_wrap_policy
910
from transformers import (
1011
AutoModelForCausalLM,
1112
AutoTokenizer,
@@ -16,21 +17,22 @@
1617
TrainerCallback,
1718
)
1819
from transformers.utils import logging
19-
from peft.utils.other import fsdp_auto_wrap_policy
2020
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
2121
import datasets
2222
import fire
23+
import transformers
2324

2425
# Local
2526
from tuning.config import configs, peft_config, tracker_configs
2627
from tuning.data import tokenizer_data_utils
27-
from tuning.utils.config_utils import get_hf_peft_config
28-
from tuning.utils.data_type_utils import get_torch_dtype
2928
from tuning.trackers.tracker import Tracker
3029
from tuning.trackers.tracker_factory import get_tracker
30+
from tuning.utils.config_utils import get_hf_peft_config
31+
from tuning.utils.data_type_utils import get_torch_dtype
3132

3233
logger = logging.get_logger("sft_trainer")
3334

35+
3436
class PeftSavingCallback(TrainerCallback):
3537
def on_save(self, args, state, control, **kwargs):
3638
checkpoint_path = os.path.join(
@@ -41,6 +43,7 @@ def on_save(self, args, state, control, **kwargs):
4143
if "pytorch_model.bin" in os.listdir(checkpoint_path):
4244
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
4345

46+
4447
class FileLoggingCallback(TrainerCallback):
4548
"""Exports metrics, e.g., training loss to a file in the checkpoint directory."""
4649

@@ -84,6 +87,7 @@ def _track_loss(self, loss_key, log_file, logs, state):
8487
with open(log_file, "a") as f:
8588
f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
8689

90+
8791
def train(
8892
model_args: configs.ModelArguments,
8993
data_args: configs.DataArguments,
@@ -93,7 +97,7 @@ def train(
9397
] = None,
9498
callbacks: Optional[List[TrainerCallback]] = None,
9599
tracker: Optional[Tracker] = None,
96-
exp_metadata: Optional[Dict] = None
100+
exp_metadata: Optional[Dict] = None,
97101
):
98102
"""Call the SFTTrainer
99103
@@ -105,6 +109,11 @@ def train(
105109
peft_config.PromptTuningConfig for prompt tuning | \
106110
None for fine tuning
107111
The peft configuration to pass to trainer
112+
callbacks: List of callbacks to attach with SFTTrainer.
113+
tracker: One of the available trackers in tuning.trackers.tracker_factory.REGISTERED_TRACKERS
114+
Initialized using tuning.trackers.tracker_factory.get_tracker
115+
Using configs in tuning.config.tracker_configs
116+
exp_metadata: Dict of key value pairs passed to train to be recorded by the tracker.
108117
"""
109118
run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
110119

@@ -133,7 +142,7 @@ def train(
133142
torch_dtype=get_torch_dtype(model_args.torch_dtype),
134143
use_flash_attention_2=model_args.use_flash_attn,
135144
)
136-
additional_metrics['model_load_time'] = time.time() - model_load_time
145+
additional_metrics["model_load_time"] = time.time() - model_load_time
137146

138147
peft_config = get_hf_peft_config(task_type, peft_config)
139148

@@ -269,16 +278,17 @@ def train(
269278
if tracker is not None:
270279
# Currently tracked only on process zero.
271280
if trainer.is_world_process_zero():
272-
for k,v in additional_metrics.items():
273-
tracker.track(metric=v, name=k, stage='additional_metrics')
274-
tracker.set_params(params=exp_metadata, name='experiment_metadata')
281+
for k, v in additional_metrics.items():
282+
tracker.track(metric=v, name=k, stage="additional_metrics")
283+
tracker.set_params(params=exp_metadata, name="experiment_metadata")
275284

276285
if run_distributed and peft_config is not None:
277286
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
278287
model
279288
)
280289
trainer.train()
281290

291+
282292
def main(**kwargs):
283293
parser = transformers.HfArgumentParser(
284294
dataclass_types=(
@@ -300,6 +310,7 @@ def main(**kwargs):
300310
"--exp_metadata",
301311
type=str,
302312
default=None,
313+
help='Pass a json string representing K:V pairs to be associated to the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'',
303314
)
304315
(
305316
model_args,
@@ -313,18 +324,18 @@ def main(**kwargs):
313324
) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
314325

315326
peft_method = additional.peft_method
316-
if peft_method =="lora":
317-
tune_config=lora_config
318-
elif peft_method =="pt":
319-
tune_config=prompt_tuning_config
327+
if peft_method == "lora":
328+
tune_config = lora_config
329+
elif peft_method == "pt":
330+
tune_config = prompt_tuning_config
320331
else:
321-
tune_config=None
332+
tune_config = None
322333

323334
tracker_name = training_args.tracker
324335
if tracker_name == "aim":
325-
tracker_config=aim_config
336+
tracker_config = aim_config
326337
else:
327-
tracker_config=None
338+
tracker_config = None
328339

329340
# Initialize callbacks
330341
file_logger_callback = FileLoggingCallback(logger)
@@ -343,7 +354,9 @@ def main(**kwargs):
343354
try:
344355
metadata = json.loads(additional.exp_metadata)
345356
if metadata is None or not isinstance(metadata, Dict):
346-
logger.warning('metadata cannot be converted to simple k:v dict ignoring')
357+
logger.warning(
358+
"metadata cannot be converted to simple k:v dict ignoring"
359+
)
347360
metadata = None
348361
except:
349362
logger.error("failed while parsing extra metadata. pass a valid json")
@@ -355,8 +368,9 @@ def main(**kwargs):
355368
peft_config=tune_config,
356369
callbacks=callbacks,
357370
tracker=tracker,
358-
exp_metadata=metadata
371+
exp_metadata=metadata,
359372
)
360373

374+
361375
if __name__ == "__main__":
362376
fire.Fire(main)
Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,60 @@
11
# Standard
22
import os
33

4+
# Third Party
5+
from aim.hugging_face import AimCallback
6+
7+
# Local
48
from .tracker import Tracker
59
from tuning.config.tracker_configs import AimConfig
610

7-
# Third Party
8-
from aim.hugging_face import AimCallback
911

1012
class CustomAimCallback(AimCallback):
1113

1214
# A path to export run hash generated by Aim
1315
# This is used to link back to the experiments from outside aimstack
14-
run_hash_export_path = None
16+
hash_export_path = None
1517

1618
def on_init_end(self, args, state, control, **kwargs):
1719

1820
if state and not state.is_world_process_zero:
1921
return
2022

21-
self.setup() # initializes the run_hash
23+
self.setup() # initializes the run_hash
2224

2325
# Store the run hash
2426
# Change default run hash path to output directory
25-
if self.run_hash_export_path is None:
27+
if self.hash_export_path is None:
2628
if args and args.output_dir:
2729
# args.output_dir/.aim_run_hash
28-
self.run_hash_export_path = os.path.join(
29-
args.output_dir,
30-
'.aim_run_hash'
31-
)
30+
self.hash_export_path = os.path.join(
31+
args.output_dir, ".aim_run_hash"
32+
)
3233

33-
if self.run_hash_export_path:
34-
with open(self.run_hash_export_path, 'w') as f:
35-
f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n')
34+
if self.hash_export_path:
35+
with open(self.hash_export_path, "w") as f:
36+
hash = self.experiment.hash
37+
f.write('{"run_hash":"' + str(hash) + '"}\n')
3638

3739
def on_train_begin(self, args, state, control, model=None, **kwargs):
3840
# call directly to make sure hyper parameters and model info is recorded.
3941
self.setup(args=args, state=state, model=model)
4042

4143
def track_metrics(self, metric, name, context):
42-
if self._run is not None:
43-
self._run.track(metric, name=name, context=context)
44+
run = self.experiment
45+
if run is not None:
46+
run.track(metric, name=name, context=context)
47+
4448

4549
def set_params(self, params, name):
46-
if self._run is not None:
47-
for key, value in params.items():
48-
self._run.set((name, key), value, strict=False)
50+
run = self.experiment
51+
if run is not None:
52+
[run.set((name, key), value, strict=False) for key, value in params.items()]
4953

50-
class AimStackTracker(Tracker):
5154

55+
class AimStackTracker(Tracker):
5256
def __init__(self, tracker_config: AimConfig):
53-
super().__init__(name='aim', tracker_config=tracker_config)
57+
super().__init__(name="aim", tracker_config=tracker_config)
5458

5559
def get_hf_callback(self):
5660
c = self.config
@@ -60,25 +64,25 @@ def get_hf_callback(self):
6064
repo = c.aim_repo
6165
hash_export_path = c.aim_run_hash_export_path
6266

63-
if (ip is not None and port is not None):
67+
if ip is not None and port is not None:
6468
aim_callback = CustomAimCallback(
65-
repo="aim://" + ip +":"+ port + "/",
66-
experiment=exp)
69+
repo="aim://" + ip + ":" + port + "/", experiment=exp
70+
)
6771
if repo:
6872
aim_callback = CustomAimCallback(repo=repo, experiment=exp)
6973
else:
7074
aim_callback = CustomAimCallback(experiment=exp)
7175

72-
aim_callback.run_hash_export_path = hash_export_path
76+
aim_callback.hash_export_path = hash_export_path
7377
self.hf_callback = aim_callback
7478
return self.hf_callback
7579

76-
def track(self, metric, name, stage='additional_metrics'):
77-
context={'subset' : stage}
80+
def track(self, metric, name, stage="additional_metrics"):
81+
context = {"subset": stage}
7882
self.hf_callback.track_metrics(metric, name=name, context=context)
7983

80-
def set_params(self, params, name='extra_params'):
84+
def set_params(self, params, name="extra_params"):
8185
try:
8286
self.hf_callback.set_params(params, name)
8387
except:
84-
pass
88+
pass

tuning/trackers/tracker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Generic Tracker API
22

3+
34
class Tracker:
45
def __init__(self, name=None, tracker_config=None) -> None:
56
if tracker_config is not None:
@@ -19,4 +20,4 @@ def track(self, metric, name, stage):
1920
# Object passed here is supposed to be a KV object
2021
# for the parameters to be associated with a run
2122
def set_params(self, params, name):
22-
pass
23+
pass

tuning/trackers/tracker_factory.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from .tracker import Tracker
1+
# Local
22
from .aimstack_tracker import AimStackTracker
3+
from .tracker import Tracker
4+
5+
REGISTERED_TRACKERS = {"aim": AimStackTracker}
36

4-
REGISTERED_TRACKERS = {
5-
"aim" : AimStackTracker
6-
}
77

88
def get_tracker(tracker_name, tracker_config):
99
if tracker_name in REGISTERED_TRACKERS:
1010
T = REGISTERED_TRACKERS[tracker_name]
1111
return T(tracker_config)
1212
else:
13-
return Tracker()
13+
return Tracker()

0 commit comments

Comments
 (0)