Skip to content

Commit 3c9ba77

Browse files
committed
Updated LoraConfig to subclass from peft.LoraConfig
1 parent bc39f95 commit 3c9ba77

4 files changed

Lines changed: 20 additions & 43 deletions

File tree

docs/tuning-techniques.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
## LoRA Tuning Example
2626

27-
Set `peft_method` to `"lora"`. You can additionally pass any arguments from [LoraConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L21).
27+
Set `peft_method` to `"lora"`. You can additionally pass any arguments from [LoraConfig](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig).
2828
```py
2929
# Args you can pass
3030
r: int =8
@@ -340,7 +340,7 @@ You can see details on a sample configuration of Accelerated GPTQ-LoRA [here](ht
340340

341341
To use GPTQ-LoRA technique, you can set the `quantized_lora_config` defined [here](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/acceleration_configs/quantized_lora_config.py). See the Notes section of FMS Acceleration doc [below](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/README.md#fms-acceleration) for usage. The only kernel we are supporting currently is `triton_v2`.
342342

343-
In addition, LoRA tuning technique is required to be used, set `peft_method` to `"lora"` and pass any arguments from [LoraConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L21).
343+
In addition, LoRA tuning technique is required to be used, set `peft_method` to `"lora"` and pass any arguments from [LoraConfig](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig).
344344

345345
Example command to run:
346346

tuning/config/peft_config.py

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import List
1919

2020
# Third Party
21+
from peft import LoraConfig as _LoraConfig
2122
from transformers.utils.quantization_config import Mxfp4Config as HfMxfp4Config
2223

2324

@@ -40,50 +41,26 @@ def to_hf_config(self):
4041

4142

4243
@dataclass
43-
class LoraConfig:
44+
class LoraConfig(_LoraConfig):
4445
"""
45-
This is the configuration class to store the configuration of a [`LoraModel`].
46+
This is the configuration class that extends peft.LoraConfig with a few defaults.
4647
4748
Args:
48-
r (`int`):
49-
Lora attention dimension (the "rank").
50-
target_modules (List[str]]):
51-
The names of the modules to apply the adapter to. \
52-
If this is specified, only the modules with the specified \
53-
names will be replaced. Please specify modules as per model architecture. \
54-
If the value is ["all-linear"], \
55-
then LORA selects all linear and Conv1D modules as per model architecture, \
56-
except for the output layer.
5749
lora_alpha (`int`):
5850
The alpha parameter for Lora scaling.
5951
lora_dropout (`float`):
6052
The dropout probability for Lora layers.
61-
bias (`str`):
62-
Bias type for LoRA. Can be 'none', 'all' or 'lora_only'. \
63-
If 'all' or 'lora_only', the corresponding biases will be updated during training. \
64-
Be aware that this means that, even when disabling the adapters, the model \
65-
will not produce the same output as the base model would have without adaptation.
6653
"""
67-
68-
r: int = 8
6954
lora_alpha: int = 32
70-
target_modules: List[str] = field(
71-
default=None,
72-
metadata={
73-
"help": "The names of the modules to apply LORA to. LORA selects modules which either \
74-
completely match or "
75-
'end with one of the strings. If the value is ["all-linear"], \
76-
then LORA selects all linear and Conv1D '
77-
"modules except for the output layer."
78-
},
79-
)
80-
target_parameters: List[str] = field(
81-
default=None,
82-
metadata={"help": "The names/regex of the parameters to apply LORA to"},
83-
)
84-
bias = "none"
8555
lora_dropout: float = 0.05
8656

57+
def __post_init__(self):
58+
# If target_modules is exactly the single-element list ["all-linear"], convert it into the plain string "all-linear"
59+
if self.target_modules == ["all-linear"]:
60+
self.target_modules = "all-linear"
61+
62+
super().__post_init__()
63+
8764

8865
@dataclass
8966
class PromptTuningConfig:

tuning/sft_trainer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def train(
7171
data_args: configs.DataArguments,
7272
train_args: configs.TrainingArguments,
7373
peft_config: Optional[ # pylint: disable=redefined-outer-name
74-
Union[peft_config.LoraConfig, LoraConfig, peft_config.PromptTuningConfig]
74+
Union[LoraConfig, peft_config.PromptTuningConfig]
7575
] = None,
7676
quantization_config: Optional[peft_config.Mxfp4Config] = None,
7777
trainer_controller_args: TrainerControllerCallback = None,
@@ -92,8 +92,7 @@ def train(
9292
model_args: tuning.config.configs.ModelArguments
9393
data_args: tuning.config.configs.DataArguments
9494
train_args: tuning.config.configs.TrainingArguments
95-
peft_config: peft_config.LoraConfig for Lora tuning | \
96-
LoraConfig (peft.LoraConfig): for activated Lora (aLoRA) tuning | \
95+
peft_config: LoraConfig (peft.LoraConfig): for Lora or activated Lora (aLoRA) tuning | \
9796
peft_config.PromptTuningConfig for prompt tuning | \
9897
None for full fine tuning
9998
The peft configuration to pass to trainer
@@ -110,7 +109,7 @@ def train(
110109
tracker will automatically be added.
111110
exp_metadata: Dict of key value pairs passed to train to be recorded by the tracker.
112111
quantized_lora_config: tuning.config.acceleration_configs.QuantizedLoraConfig \
113-
Should be used in combination with peft_config.LoraConfig for Lora tuning \
112+
Should be used in combination with LoraConfig for Lora tuning \
114113
fusedops_kernels_config: tuning.config.acceleration_configs.FusedOpsAndKernelsConfig \
115114
Should be used in combination with quantized_lora_config. Also currently
116115
fused_lora and fast_kernels must be used together (may change in future). \
@@ -855,7 +854,7 @@ def main():
855854
sys.exit(INTERNAL_ERROR_EXIT_CODE)
856855

857856
if isinstance(
858-
tune_config, (peft_config.LoraConfig, LoraConfig)
857+
tune_config, LoraConfig
859858
): # aLoraConfig subclasses LoraConfig
860859
try:
861860
if training_args.save_model_dir:

tuning/utils/config_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,10 @@ def get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path):
113113
hf_peft_config = alora_config
114114
elif isinstance(tuning_config, peft_config.LoraConfig):
115115
lora_config = asdict(tuning_config)
116-
if lora_config["target_modules"] == ["all-linear"]:
117-
lora_config["target_modules"] = "all-linear"
118-
hf_peft_config = HFLoraConfig(task_type=task_type, **lora_config)
116+
117+
if not hasattr(lora_config, "task_type"):
118+
lora_config["task_type"]=task_type
119+
hf_peft_config = HFLoraConfig(**lora_config)
119120
elif isinstance(tuning_config, peft_config.PromptTuningConfig):
120121
hf_peft_config = HFPromptTuningConfig(
121122
task_type=task_type,

0 commit comments

Comments
 (0)