Skip to content

Commit 7e261d2

Browse files
authored
feat: Support gpt-oss class of models with flash attention 3 support (#603)
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
1 parent 607984c commit 7e261d2

10 files changed

Lines changed: 213 additions & 41 deletions

File tree

build/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,4 +237,4 @@ USER ${USER}
237237
COPY --from=python-installations /home/${USER}/.local /home/${USER}/.local
238238
ENV PYTHONPATH="/home/${USER}/.local/lib/python${PYTHON_VERSION}/site-packages"
239239

240-
CMD [ "python", "/app/accelerate_launch.py" ]
240+
CMD [ "python", "/app/accelerate_launch.py" ]

build/nvcr.Dockerfile

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright The FMS HF Tuning Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
## Global Args #################################################################
16+
## If the nvcr container is updated, ensure to check the torch and python
17+
## installation version inside the dockerfile before pushing changes.
18+
ARG NVCR_IMAGE_VERSION=25.02-py3
19+
20+
# This is based on what is inside the NVCR image already
21+
ARG PYTHON_VERSION=3.12
22+
23+
## Base Layer ##################################################################
24+
FROM nvcr.io/nvidia/pytorch:${NVCR_IMAGE_VERSION} AS dev
25+
26+
ARG USER=root
27+
ARG USER_UID=0
28+
ARG WORKDIR=/app
29+
ARG SOURCE_DIR=${WORKDIR}/fms-hf-tuning
30+
31+
ARG ENABLE_FMS_ACCELERATION=true
32+
ARG ENABLE_AIM=true
33+
ARG ENABLE_ALORA=true
34+
ARG ENABLE_MLFLOW=true
35+
ARG ENABLE_SCANNER=true
36+
ARG ENABLE_CLEARML=true
37+
ARG ENABLE_TRITON_KERNELS=true
38+
ARG ENABLE_MAMBA_SUPPORT=true
39+
40+
# Ensures to always build mamba_ssm from source
41+
ENV PIP_NO_BINARY=mamba-ssm,mamba_ssm
42+
43+
RUN python -m pip install --upgrade pip
44+
45+
# upgrade torch as the base layer contains only torch 2.7
46+
RUN pip install --upgrade --force-reinstall torch torchaudio torchvision --index-url https://download.pytorch.org/whl/cu128
47+
48+
# Install main package + flash attention
49+
COPY . ${SOURCE_DIR}
50+
WORKDIR ${SOURCE_DIR}
51+
RUN pip install --no-cache-dir ${SOURCE_DIR} && \
52+
pip install --no-cache-dir ${SOURCE_DIR}[flash-attn]
53+
54+
# Optional extras
55+
RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
56+
pip install --no-cache-dir ${SOURCE_DIR}[fms-accel] && \
57+
python -m fms_acceleration.cli install fms_acceleration_peft && \
58+
python -m fms_acceleration.cli install fms_acceleration_foak && \
59+
python -m fms_acceleration.cli install fms_acceleration_aadp && \
60+
python -m fms_acceleration.cli install fms_acceleration_moe; \
61+
fi
62+
63+
RUN if [[ "${ENABLE_ALORA}" == "true" ]]; then \
64+
pip install --no-cache-dir ${SOURCE_DIR}[activated-lora]; \
65+
fi
66+
RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
67+
pip install --no-cache-dir ${SOURCE_DIR}[aim]; \
68+
fi
69+
RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
70+
pip install --no-cache-dir ${SOURCE_DIR}[mlflow]; \
71+
fi
72+
RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
73+
pip install --no-cache-dir ${SOURCE_DIR}[scanner-dev]; \
74+
fi
75+
RUN if [[ "${ENABLE_CLEARML}" == "true" ]]; then \
76+
pip install --no-cache-dir ${SOURCE_DIR}[clearml]; \
77+
fi
78+
RUN if [[ "${ENABLE_MAMBA_SUPPORT}" == "true" ]]; then \
79+
pip install --no-cache-dir ${SOURCE_DIR}[mamba]; \
80+
fi
81+
RUN if [[ "${ENABLE_TRITON_KERNELS}" == "true" ]]; then \
82+
pip install --no-cache-dir "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"; \
83+
fi
84+
85+
RUN chmod -R g+rwX $WORKDIR /tmp
86+
RUN mkdir -p /.cache && chmod -R 777 /.cache
87+
88+
# Set Triton environment variables for qLoRA
89+
ENV TRITON_HOME="/tmp/triton_home"
90+
ENV TRITON_DUMP_DIR="/tmp/triton_dump_dir"
91+
ENV TRITON_CACHE_DIR="/tmp/triton_cache_dir"
92+
ENV TRITON_OVERRIDE_DIR="/tmp/triton_override_dir"
93+
94+
WORKDIR $WORKDIR
95+
96+
CMD [ "python", "/app/fms-hf-tuning/build/accelerate_launch.py" ]

pyproject.toml

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,19 @@ classifiers=[
2727
"Programming Language :: Python :: 3.12"
2828
]
2929
dependencies = [
30-
"numpy>=1.26.4,<2.0",
31-
"accelerate>=0.20.3,!=0.34,<1.7",
32-
"transformers>=4.53.0,<=4.55.4",
33-
"torch>2.6.0,<=2.8.0",
30+
"numpy>=1.26.4,<2.2.0",
31+
"accelerate>=1.9.0,<2.0.0",
32+
"transformers>=4.55.0,<=4.55.4",
33+
"torch>2.7.0,<2.9.0",
3434
"sentencepiece>=0.1.99,<0.3",
35-
"tokenizers>=0.13.3,<1.0",
35+
"tokenizers<=0.22",
3636
"tqdm>=4.66.2,<5.0",
37-
"trl>=0.13,<0.20",
38-
"peft>=0.15.0,<=0.15.2",
39-
"protobuf>=5.28.0,<6.0.0",
40-
"datasets>=3.5.0,<4.0",
37+
"trl>=0.19.1,<0.20.0",
38+
"peft>=0.17.0,<0.18.0",
39+
"datasets>=4.0.0,<5.0.0",
4140
"simpleeval>=0.9.13,<2.0",
4241
"pillow>=11.0.0,<12.0",
42+
"kernels<=0.9.0",
4343
]
4444

4545
[project.optional-dependencies]
@@ -54,14 +54,10 @@ mamba = ["mamba_ssm[causal-conv1d]>=2.0.0,<3.0.0"]
5454
scanner-dev = ["HFResourceScanner>=0.1.0"]
5555
activated-lora = ["alora>=0.3.0"]
5656

57-
5857
[tool.setuptools.packages.find]
5958
exclude = ["tests", "tests.*"]
6059
namespaces = false
6160

62-
[tool.setuptools_scm]
63-
version_file = "tuning/_version.py"
64-
6561
[project.urls]
6662
Homepage = "https://github.com/foundation-model-stack/fms-hf-tuning"
6763
Repository = "https://github.com/foundation-model-stack/fms-hf-tuning"

tests/test_sft_trainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def test_parse_arguments(job_config):
407407
_,
408408
_,
409409
_,
410+
_,
410411
) = sft_trainer.parse_arguments(parser, job_config_copy)
411412
assert str(model_args.torch_dtype) == "torch.bfloat16"
412413
assert data_args.dataset_text_field == "output"
@@ -432,6 +433,7 @@ def test_parse_arguments_defaults(job_config):
432433
_,
433434
_,
434435
_,
436+
_,
435437
) = sft_trainer.parse_arguments(parser, job_config_defaults)
436438
assert str(model_args.torch_dtype) == "torch.bfloat16"
437439
assert model_args.use_flash_attn is False
@@ -454,7 +456,9 @@ def test_parse_arguments_peft_method(job_config):
454456
_,
455457
_,
456458
_,
459+
_,
457460
) = sft_trainer.parse_arguments(parser, job_config_pt)
461+
458462
assert isinstance(tune_config, peft_config.PromptTuningConfig)
459463

460464
job_config_lora = copy.deepcopy(job_config)
@@ -471,6 +475,7 @@ def test_parse_arguments_peft_method(job_config):
471475
_,
472476
_,
473477
_,
478+
_,
474479
) = sft_trainer.parse_arguments(parser, job_config_lora)
475480
assert isinstance(tune_config, peft_config.LoraConfig)
476481
assert not tune_config.target_modules

tuning/config/configs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ class ModelArguments:
6161
tokenizer classes."
6262
},
6363
)
64+
flash_attn_implementation: Optional[str] = field(
65+
default="flash_attention_2",
66+
metadata={
67+
"help": "Flash Attention implementation to choose.\
68+
For almost all models this does not need to be passed; the default flash_attention_2 is used.\
69+
Requires use_flash_attn=True flag to be enabled."
70+
},
71+
)
6472

6573

6674
@dataclass

tuning/config/peft_config.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,30 @@
1414

1515
# Standard
1616
from dataclasses import dataclass, field
17+
from enum import Enum
1718
from typing import List
1819

20+
# Third Party
21+
from transformers.utils.quantization_config import Mxfp4Config as HfMxfp4Config
22+
23+
24+
class QUANT_METHOD(Enum):
25+
MXFP4 = "mxfp4"
26+
27+
28+
class PEFT_METHOD(Enum):
29+
PT = "pt"
30+
LORA = "lora"
31+
ALORA = "alora"
32+
33+
34+
@dataclass
35+
class Mxfp4Config:
36+
dequantize: bool = True
37+
38+
def to_hf_config(self):
39+
return HfMxfp4Config(dequantize=self.dequantize)
40+
1941

2042
@dataclass
2143
class LoraConfig:
@@ -55,6 +77,10 @@ class LoraConfig:
5577
"modules except for the output layer."
5678
},
5779
)
80+
target_parameters: List[str] = field(
81+
default=None,
82+
metadata={"help": "The names/regex of the parameters to apply LORA to"},
83+
)
5884
bias = "none"
5985
lora_dropout: float = 0.05
6086

tuning/data/data_processors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def _load_dataset(
146146
load_path = builder if builder else data_path
147147

148148
try:
149-
return datasets.load_dataset(path=load_path, **load_kwargs)
149+
return datasets.load_dataset(load_path, **load_kwargs)
150150
except DatasetNotFoundError as e:
151151
# Reraise with a more context-specific message if needed
152152
raise e

0 commit comments

Comments
 (0)