Skip to content

Commit 7d4b464

Browse files
Merge branch 'main' into convert_fp32_to_bf16
2 parents 35f0908 + d6dc4c9 commit 7d4b464

8 files changed

Lines changed: 153 additions & 80 deletions

File tree

build/nvcr.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ RUN python -m pip install --upgrade pip
4646
RUN pip install --upgrade --force-reinstall torch torchaudio torchvision --index-url https://download.pytorch.org/whl/cu128
4747

4848
# Install main package + flash attention
49-
RUN COPY . ${SOURCE_DIR}
49+
COPY . ${SOURCE_DIR}
5050
RUN cd ${SOURCE_DIR}
5151
RUN pip install --no-cache-dir ${SOURCE_DIR} && \
5252
pip install --no-cache-dir ${SOURCE_DIR}[flash-attn]

docs/advanced-data-preprocessing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ We recommend inspecting the data and chat template to decide if you need to pass
509509
Depending on various scenarios users might need to decide on how to use chat template with their data or which chat template to use for their use case.
510510
511511
Following are the Guidelines from us in a flow chart :
512-
![guidelines for chat template](docs/images/chat_template_guide.jpg)
512+
![guidelines for chat template](images/chat_template_guide.jpg)
513513
514514
Here are some scenarios addressed in the flow chart:
515515
1. Depending on the model the tokenizer for the model may or may not have a chat template

docs/tuning-techniques.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
## LoRA Tuning Example
2626

27-
Set `peft_method` to `"lora"`. You can additionally pass any arguments from [LoraConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L21).
27+
Set `peft_method` to `"lora"`. You can additionally pass any arguments from [LoraConfig](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig).
2828
```py
2929
# Args you can pass
3030
r: int =8
@@ -340,7 +340,7 @@ You can see details on a sample configuration of Accelerated GPTQ-LoRA [here](ht
340340

341341
To use GPTQ-LoRA technique, you can set the `quantized_lora_config` defined [here](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/acceleration_configs/quantized_lora_config.py). See the Notes section of FMS Acceleration doc [below](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/README.md#fms-acceleration) for usage. The only kernel we are supporting currently is `triton_v2`.
342342

343-
In addition, LoRA tuning technique is required to be used, set `peft_method` to `"lora"` and pass any arguments from [LoraConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L21).
343+
In addition, LoRA tuning technique is required to be used, set `peft_method` to `"lora"` and pass any arguments from [LoraConfig](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig).
344344

345345
Example command to run:
346346

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ flash-attn = ["flash-attn>=2.8.3"]
4848
aim = ["aim>=3.19.0,<4.0"]
4949
mlflow = ["mlflow"]
5050
clearml = ["clearml==2.0.0"]
51-
fms-accel = ["fms-acceleration>=0.6"]
51+
fms-accel = ["fms-acceleration>=0.6.2"]
5252
gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
5353
mamba = ["mamba_ssm[causal-conv1d]>=2.0.0,<3.0.0"]
5454
scanner-dev = ["HFResourceScanner>=0.1.0"]

tuning/config/peft_config.py

Lines changed: 105 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
# Standard
1616
from dataclasses import dataclass, field
1717
from enum import Enum
18-
from typing import List
18+
from typing import List, Optional
1919

2020
# Third Party
21+
from peft import LoraConfig as HFLoraConfig
2122
from transformers.utils.quantization_config import Mxfp4Config as HfMxfp4Config
2223

2324

@@ -40,49 +41,125 @@ def to_hf_config(self):
4041

4142

4243
@dataclass
43-
class LoraConfig:
44+
class LoraConfig(HFLoraConfig):
4445
"""
45-
This is the configuration class to store the configuration of a [`LoraModel`].
46+
This is the configuration class that extends peft.LoraConfig with a few defaults.
4647
4748
Args:
48-
r (`int`):
49-
Lora attention dimension (the "rank").
50-
target_modules (List[str]]):
51-
The names of the modules to apply the adapter to. \
52-
If this is specified, only the modules with the specified \
53-
names will be replaced. Please specify modules as per model architecture. \
54-
If the value is ["all-linear"], \
55-
then LORA selects all linear and Conv1D modules as per model architecture, \
56-
except for the output layer.
5749
lora_alpha (`int`):
5850
The alpha parameter for Lora scaling.
5951
lora_dropout (`float`):
6052
The dropout probability for Lora layers.
61-
bias (`str`):
62-
Bias type for LoRA. Can be 'none', 'all' or 'lora_only'. \
63-
If 'all' or 'lora_only', the corresponding biases will be updated during training. \
64-
Be aware that this means that, even when disabling the adapters, the model \
65-
will not produce the same output as the base model would have without adaptation.
6653
"""
6754

68-
r: int = 8
6955
lora_alpha: int = 32
70-
target_modules: List[str] = field(
56+
lora_dropout: float = 0.05
57+
58+
# HACK: The following list of arguments listed below
59+
# is a fix which reduces the field annotation from
60+
# Optional[List[str], str] type to Optional[List[str]] type
61+
# This is done for compatibility with HFArgumentParser
62+
# Please see: https://github.com/huggingface/peft/issues/2798 for further explanation!
63+
target_modules: Optional[List[str]] = field(
7164
default=None,
7265
metadata={
73-
"help": "The names of the modules to apply LORA to. LORA selects modules which either \
74-
completely match or "
75-
'end with one of the strings. If the value is ["all-linear"], \
76-
then LORA selects all linear and Conv1D '
77-
"modules except for the output layer."
66+
"help": (
67+
"List of module names or regex expression of the module names to replace with LoRA."
68+
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
69+
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D "
70+
"(if the model is a PreTrainedModel, the output layer excluded). "
71+
"If not specified, modules will be chosen according to the model architecture, "
72+
"If the architecture is not known, an error will be raised -- "
73+
"in this case, you should specify the target modules manually. "
74+
"To avoid targeting any modules (because you want to apply `target_parameters`) "
75+
", set `target_modules=[]`."
76+
),
7877
},
7978
)
80-
target_parameters: List[str] = field(
79+
exclude_modules: Optional[List[str]] = field(
8180
default=None,
82-
metadata={"help": "The names/regex of the parameters to apply LORA to"},
81+
metadata={
82+
"help": (
83+
"List of module names or regex expression of the module names to exclude from Lora."
84+
)
85+
},
8386
)
84-
bias = "none"
85-
lora_dropout: float = 0.05
87+
init_lora_weights: bool = field(
88+
default=True,
89+
metadata={
90+
"help": (
91+
"How to initialize the weights of the LoRA layers. "
92+
"Passing True (default) results in the default initialization from "
93+
"the reference implementation from "
94+
"Microsoft, with the LoRA B weight being set to 0. "
95+
"This means that without further training, "
96+
"the LoRA adapter will be a no-op. "
97+
"Setting the initialization to False leads to random initialization of "
98+
"LoRA A and B, meaning that LoRA is not a no-op before training; "
99+
"this setting is intended for debugging purposes."
100+
),
101+
},
102+
)
103+
layers_to_transform: Optional[list[int]] = field(
104+
default=None,
105+
metadata={
106+
"help": (
107+
"The layer indexes to transform, is this argument is specified, "
108+
"PEFT will transform only the layers indexes that are specified inside this list. "
109+
"If a single integer is passed, PEFT will transform only the layer at this index. "
110+
"This only works when target_modules is a list of str."
111+
)
112+
},
113+
)
114+
layers_pattern: Optional[list[str]] = field(
115+
default=None,
116+
metadata={
117+
"help": (
118+
"The layer pattern name, used only if `layers_to_transform` is different to None "
119+
"and if the layer pattern is not in the common layers pattern. "
120+
"This only works when target_modules is a list of str. "
121+
"This should target the `nn.ModuleList` of the "
122+
"model, which is often called `'layers'` or `'h'`."
123+
)
124+
},
125+
)
126+
trainable_token_indices: Optional[list[int]] = field(
127+
default=None,
128+
metadata={
129+
"help": (
130+
"Lets you specify which token indices to selectively fine-tune "
131+
"without requiring to re-train the "
132+
"whole embedding matrix using the `peft.TrainableTokensModel` method. "
133+
"You can specify token indices in two ways. "
134+
"Either you specify a list of indices which will then target the model's input "
135+
"embedding layer (or, if not found, `embed_tokens`). "
136+
"(Not supported yet) Alternatively, you can specify a dictionary "
137+
"where the key is the name of the embedding module "
138+
"and the values are the list of token indices, e.g. "
139+
"`{'embed_tokens': [0, 1, ...]}`. Note that training "
140+
"with FSDP requires `use_orig_params=True` to "
141+
"avoid issues with non-uniform `requires_grad`."
142+
)
143+
},
144+
)
145+
loftq_config: Optional[dict] = field(
146+
default_factory=dict,
147+
metadata={
148+
"help": (
149+
"The configuration of LoftQ. If this is passed, "
150+
"then LoftQ will be used to quantize the backbone "
151+
"weights and initialize Lora layers. Also set `init_lora_weights='loftq'` "
152+
"in this case."
153+
)
154+
},
155+
)
156+
157+
def __post_init__(self):
158+
# If target_modules is a single-element list, convert it into a plain string
159+
if self.target_modules == ["all-linear"]:
160+
self.target_modules = "all-linear"
161+
162+
super().__post_init__()
86163

87164

88165
@dataclass

tuning/data/setup_dataprocessor.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def process_dataconfig_file(
159159

160160

161161
# Data Format 1: Pretokenized Data
162-
def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):
162+
def _get_pretokenized_dataset_handlers(
163+
data_args: DataArguments, is_eval_present, is_eval_tokenized
164+
):
163165

164166
# if the provided train dataset is pretokenized
165167
# however user provides formatting flags, error out
@@ -168,6 +170,7 @@ def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):
168170
or data_args.data_formatter_template
169171
or data_args.dataset_text_field
170172
or data_args.instruction_template
173+
or data_args.dataset_conversation_field
171174
):
172175
raise ValueError(
173176
"fields response_template, data_formatter_template,"
@@ -177,7 +180,7 @@ def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):
177180

178181
# if the train dataset is pretokenized
179182
# ensure validation dataset is pretokenized otherwise error out
180-
if is_eval_tokenized:
183+
if is_eval_present and not is_eval_tokenized:
181184
raise ValueError(
182185
"validation data should be pretokenized to be used \
183186
along with pretokenized train data"
@@ -189,7 +192,9 @@ def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):
189192

190193
### Data format 2
191194
# pylint: disable=unused-argument
192-
def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False):
195+
def _get_dataset_formatting_handlers(
196+
data_args: DataArguments, packing, is_padding_free=False
197+
):
193198

194199
if data_args.response_template is None:
195200
if packing is False:
@@ -253,7 +258,7 @@ def _get_chat_dataset_handlers(data_args, tokenizer_kwargs):
253258
fn_kwargs["formatted_text_column_name"] = data_args.dataset_text_field
254259
fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs
255260
if data_args.dataset_conversation_field is not None:
256-
fn_kwargs["conversation_column"] = data_args.dataset_conversation_field
261+
fn_kwargs["conversation_column_name"] = data_args.dataset_conversation_field
257262

258263
kwargs = {"fn_kwargs": fn_kwargs, "batched": False, "remove_columns": "all"}
259264

@@ -284,14 +289,14 @@ def _get_default_dataset_handlers(data_args, tokenizer_kwargs):
284289

285290

286291
### Vsion Data Format
287-
def _get_vision_dataset_handlers(data_args, processor_kwargs):
292+
def _get_vision_dataset_handlers(data_args: DataArguments, processor_kwargs):
288293

289294
handlers = []
290295

291296
# First data handler configuration
292297
handler_fn_kwargs1 = {
293-
"dataset_text_field": data_args.dataset_text_field,
294-
"conversation_column": data_args.dataset_text_field,
298+
"formatted_text_column_name": data_args.dataset_text_field,
299+
"conversation_column_name": data_args.dataset_conversation_field,
295300
}
296301
handler_kwargs1 = {
297302
"fn_kwargs": handler_fn_kwargs1,
@@ -403,7 +408,7 @@ def _process_raw_data_args(
403408
if is_traindata_tokenized:
404409
# Data Format 1: Pretokenized Data
405410
handlers, dataset_text_field = _get_pretokenized_dataset_handlers(
406-
data_args, (is_eval_dataset_present and not is_evaldata_tokenized)
411+
data_args, is_eval_dataset_present, is_evaldata_tokenized
407412
)
408413
elif processor and data_args.dataset_text_field and data_args.dataset_image_field:
409414

0 commit comments

Comments
 (0)