
Commit 0e60ecd

Merge pull request #69 from tharapalanivel/trainer_image
Initial commit for trainer image
2 parents 0f09dab + 4312222 commit 0e60ecd

3 files changed: 286 additions & 1 deletion


.gitignore

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@ durations/*
 coverage*.xml
 dist
 htmlcov
-build
 test
 
 # IDEs

build/Dockerfile

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
FROM registry.access.redhat.com/ubi9/ubi AS release

ARG CUDA_VERSION=11.8.0
ARG USER=tuning
ARG USER_UID=1000

USER root

RUN dnf remove -y --disableplugin=subscription-manager \
        subscription-manager \
        # we install a newer version of requests via pip
        python3.11-requests \
    && dnf install -y make \
        # to help with debugging
        procps \
    && dnf clean all

ENV LANG=C.UTF-8 \
    LC_ALL=C.UTF-8

ENV CUDA_VERSION=$CUDA_VERSION \
    NV_CUDA_LIB_VERSION=11.8.0-1 \
    NVIDIA_VISIBLE_DEVICES=all \
    NVIDIA_DRIVER_CAPABILITIES=compute,utility \
    NV_CUDA_CUDART_VERSION=11.8.89-1 \
    NV_CUDA_COMPAT_VERSION=520.61.05-1

RUN dnf config-manager \
        --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
    && dnf install -y \
        cuda-cudart-11-8-${NV_CUDA_CUDART_VERSION} \
        cuda-compat-11-8-${NV_CUDA_COMPAT_VERSION} \
    && echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf \
    && echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf \
    && dnf clean all

ENV CUDA_HOME="/usr/local/cuda" \
    PATH="/usr/local/nvidia/bin:${CUDA_HOME}/bin:${PATH}" \
    LD_LIBRARY_PATH="/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$CUDA_HOME/lib64:$CUDA_HOME/extras/CUPTI/lib64:${LD_LIBRARY_PATH}"

ENV NV_NVTX_VERSION=11.8.86-1 \
    NV_LIBNPP_VERSION=11.8.0.86-1 \
    NV_LIBCUBLAS_VERSION=11.11.3.6-1 \
    NV_LIBNCCL_PACKAGE_VERSION=2.15.5-1+cuda11.8

RUN dnf config-manager \
        --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
    && dnf install -y \
        cuda-libraries-11-8-${NV_CUDA_LIB_VERSION} \
        cuda-nvtx-11-8-${NV_NVTX_VERSION} \
        libnpp-11-8-${NV_LIBNPP_VERSION} \
        libcublas-11-8-${NV_LIBCUBLAS_VERSION} \
        libnccl-${NV_LIBNCCL_PACKAGE_VERSION} \
    && dnf clean all

ENV NV_CUDA_CUDART_DEV_VERSION=11.8.89-1 \
    NV_NVML_DEV_VERSION=11.8.86-1 \
    NV_LIBCUBLAS_DEV_VERSION=11.11.3.6-1 \
    NV_LIBNPP_DEV_VERSION=11.8.0.86-1 \
    NV_LIBNCCL_DEV_PACKAGE_VERSION=2.15.5-1+cuda11.8

RUN dnf config-manager \
        --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
    && dnf install -y \
        cuda-command-line-tools-11-8-${NV_CUDA_LIB_VERSION} \
        cuda-libraries-devel-11-8-${NV_CUDA_LIB_VERSION} \
        cuda-minimal-build-11-8-${NV_CUDA_LIB_VERSION} \
        cuda-cudart-devel-11-8-${NV_CUDA_CUDART_DEV_VERSION} \
        cuda-nvml-devel-11-8-${NV_NVML_DEV_VERSION} \
        libcublas-devel-11-8-${NV_LIBCUBLAS_DEV_VERSION} \
        libnpp-devel-11-8-${NV_LIBNPP_DEV_VERSION} \
        libnccl-devel-${NV_LIBNCCL_DEV_PACKAGE_VERSION} \
    && dnf clean all

ENV LIBRARY_PATH="$CUDA_HOME/lib64/stubs"

RUN dnf install -y python3.11 git && \
    ln -s /usr/bin/python3.11 /bin/python && \
    python -m ensurepip --upgrade

RUN mkdir /app

WORKDIR /tmp
RUN python -m pip install packaging && \
    python -m pip install --upgrade pip && \
    python -m pip install torch && \
    python -m pip install wheel

# TODO: move to installing the wheel once proper releases are set up, instead of cloning the repo
RUN git clone https://github.com/foundation-model-stack/fms-hf-tuning.git && \
    cd fms-hf-tuning && \
    python -m pip install -r requirements.txt && \
    python -m pip install -r flashattn_requirements.txt && \
    python -m pip install -U datasets && \
    python -m pip install /tmp/fms-hf-tuning

RUN mkdir -p /licenses
COPY LICENSE /licenses/

COPY launch_training.py /app
RUN chmod +x /app/launch_training.py

# TODO: find a better way to address this hack
RUN touch /.aim_profile && \
    chmod -R 777 /.aim_profile && \
    mkdir /.cache && \
    chmod -R 777 /.cache

# create the tuning user and give it ownership of the app dir
RUN useradd -u $USER_UID tuning -m -g 0 --system && \
    chown -R $USER:0 /app && \
    chmod -R g+rwX /app

WORKDIR /app
USER ${USER}

CMD [ "tail", "-f", "/dev/null" ]
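The image's default CMD only idles (tail -f /dev/null), so a training job is started by overriding the container command. Below is a minimal sketch of driving the launcher in a container built from this Dockerfile; the fms-hf-tuning:dev tag, the host data directory, and the config filename are illustrative assumptions, and GPU device flags are omitted.

# Sketch: run launch_training.py in a container built from this Dockerfile.
# Assumptions: the image was tagged fms-hf-tuning:dev, podman is on PATH, and
# /srv/tuning-data on the host holds config.json plus room for outputs.
import subprocess

subprocess.run(
    [
        "podman", "run", "--rm",
        "-v", "/srv/tuning-data:/data:Z",  # hypothetical host directory
        "-e", "SFT_TRAINER_CONFIG_JSON_PATH=/data/config.json",
        "fms-hf-tuning:dev",
        "python", "/app/launch_training.py",  # override the idle CMD
    ],
    check=True,
)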

build/launch_training.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
# Copyright The SFT Trainer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper script around SFT Trainer for running under Train Conductor.

Reads the SFT Trainer configuration either from the environment variable
`SFT_TRAINER_CONFIG_JSON_PATH` (the path to a JSON config file with parameters)
or from `SFT_TRAINER_CONFIG_JSON_ENV_VAR` (a base64-encoded config string to parse).
"""

# Standard
import base64
import glob
import json
import logging
import os
import pickle
import shutil
import tempfile

# First Party
from tuning import sft_trainer
from tuning.config import configs, peft_config
from tuning.utils.merge_model_utils import create_merged_model

# Third Party
import transformers


def txt_to_obj(txt):
    base64_bytes = txt.encode("ascii")
    message_bytes = base64.b64decode(base64_bytes)
    obj = pickle.loads(message_bytes)
    return obj


def get_highest_checkpoint(dir_path):
    checkpoint_dir = ""
    for curr_dir in os.listdir(dir_path):
        if curr_dir.startswith("checkpoint"):
            if checkpoint_dir:
                curr_dir_num = int(checkpoint_dir.split("-")[-1])
                new_dir_num = int(curr_dir.split("-")[-1])
                if new_dir_num > curr_dir_num:
                    checkpoint_dir = curr_dir
            else:
                checkpoint_dir = curr_dir

    return checkpoint_dir


def main():
    LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
    logging.basicConfig(level=LOGLEVEL)

    logging.info("Attempting to launch training script")
    parser = transformers.HfArgumentParser(
        dataclass_types=(
            configs.ModelArguments,
            configs.DataArguments,
            configs.TrainingArguments,
            peft_config.LoraConfig,
            peft_config.PromptTuningConfig,
        )
    )
    peft_method_parsed = "pt"
    json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
    json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")

    # accepts either a path to a JSON file or an encoded string config
    if json_path:
        (
            model_args,
            data_args,
            training_args,
            lora_config,
            prompt_tuning_config,
        ) = parser.parse_json_file(json_path, allow_extra_keys=True)

        contents = ""
        with open(json_path, "r") as f:
            contents = json.load(f)
        peft_method_parsed = contents.get("peft_method")
        logging.debug(f"Input params parsed: {contents}")
    elif json_env_var:
        job_config_dict = txt_to_obj(json_env_var)
        logging.debug(f"Input params parsed: {job_config_dict}")

        (
            model_args,
            data_args,
            training_args,
            lora_config,
            prompt_tuning_config,
        ) = parser.parse_dict(job_config_dict, allow_extra_keys=True)

        peft_method_parsed = job_config_dict.get("peft_method")
    else:
        raise ValueError(
            "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' "
            "or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
        )

    tune_config = None
    merge_model = False
    if peft_method_parsed == "lora":
        tune_config = lora_config
        merge_model = True
    elif peft_method_parsed == "pt":
        tune_config = prompt_tuning_config

    logging.debug(
        f"Parameters used to launch training: model_args {model_args}, "
        f"data_args {data_args}, training_args {training_args}, tune_config {tune_config}"
    )

    original_output_dir = training_args.output_dir
    with tempfile.TemporaryDirectory() as tempdir:
        training_args.output_dir = tempdir
        sft_trainer.train(model_args, data_args, training_args, tune_config)

        if merge_model:
            export_path = os.getenv(
                "LORA_MERGE_MODELS_EXPORT_PATH", original_output_dir
            )

            # get the highest checkpoint dir (the last checkpoint)
            lora_checkpoint_dir = get_highest_checkpoint(training_args.output_dir)
            full_checkpoint_dir = os.path.join(
                training_args.output_dir, lora_checkpoint_dir
            )

            logging.info(
                f"Merging lora tuned checkpoint {lora_checkpoint_dir} "
                f"with base model into output path: {export_path}"
            )

            create_merged_model(
                checkpoint_models=full_checkpoint_dir,
                export_path=export_path,
                base_model=model_args.model_name_or_path,
                save_tokenizer=True,
            )
        else:
            # copy the last checkpoint into the mounted output dir
            pt_checkpoint_dir = get_highest_checkpoint(training_args.output_dir)
            logging.info(
                f"Copying last checkpoint {pt_checkpoint_dir} into output dir {original_output_dir}"
            )
            shutil.copytree(
                os.path.join(training_args.output_dir, pt_checkpoint_dir),
                original_output_dir,
                dirs_exist_ok=True,
            )

        # copy over any loss logs
        for file in glob.glob(f"{training_args.output_dir}/*loss.jsonl"):
            shutil.copy(file, original_output_dir)


if __name__ == "__main__":
    main()
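For reference, the inverse of txt_to_obj above: a sketch of how a caller might produce the value for SFT_TRAINER_CONFIG_JSON_ENV_VAR by pickling the config dict and base64-encoding it. The helper name obj_to_txt and the config keys shown are illustrative, not a confirmed schema.

# Sketch of the encoding txt_to_obj() expects: pickle.dumps -> base64 -> ASCII.
import base64
import pickle

def obj_to_txt(obj):
    message_bytes = pickle.dumps(obj)
    base64_bytes = base64.b64encode(message_bytes)
    return base64_bytes.decode("ascii")

job_config = {
    "model_name_or_path": "/models/base-model",  # hypothetical values
    "output_dir": "/data/output",
    "peft_method": "pt",
}
print(obj_to_txt(job_config))  # export this as SFT_TRAINER_CONFIG_JSON_ENV_VAR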
