Skip to content

Commit 4907de8

Browse files
committed
Lint and fmt fixes
Signed-off-by: romit <romit@ibm.com>
1 parent f5be874 commit 4907de8

2 files changed

Lines changed: 53 additions & 34 deletions

File tree

tuning/config/peft_config.py

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import List, Optional
1919

2020
# Third Party
21-
from peft import LoraConfig as _LoraConfig
21+
from peft import LoraConfig as HFLoraConfig
2222
from transformers.utils.quantization_config import Mxfp4Config as HfMxfp4Config
2323

2424

@@ -41,7 +41,7 @@ def to_hf_config(self):
4141

4242

4343
@dataclass
44-
class LoraConfig(_LoraConfig):
44+
class LoraConfig(HFLoraConfig):
4545
"""
4646
This is the configuration class that extends peft.LoraConfig with a few defaults.
4747
@@ -63,71 +63,92 @@ class LoraConfig(_LoraConfig):
6363
default=None,
6464
metadata={
6565
"help": (
66-
"List of module names or regex expression of the module names to replace with LoRA. "
66+
"List of module names or regex expression of the module names to replace with LoRA."
6767
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
6868
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D "
6969
"(if the model is a PreTrainedModel, the output layer excluded). "
70-
"If not specified, modules will be chosen according to the model architecture, If the architecture is "
71-
"not known, an error will be raised -- in this case, you should specify the target modules manually. "
72-
"To avoid targeting any modules (because you want to apply `target_parameters`), set "
73-
"`target_modules=[]`."
70+
"If not specified, modules will be chosen according to the model architecture, "
71+
"If the architecture is not known, an error will be raised -- "
72+
"in this case, you should specify the target modules manually. "
73+
"To avoid targeting any modules (because you want to apply `target_parameters`) "
74+
", set `target_modules=[]`."
7475
),
7576
},
7677
)
77-
exclude_modules: List[str] | None = field(
78+
exclude_modules: Optional[List[str]] = field(
7879
default=None,
7980
metadata={
80-
"help": "List of module names or regex expression of the module names to exclude from Lora."
81+
"help": (
82+
"List of module names or regex expression of the module names to exclude from Lora."
83+
)
8184
},
8285
)
83-
init_lora_weights: (bool) = field(
86+
init_lora_weights: bool = field(
8487
default=True,
8588
metadata={
8689
"help": (
8790
"How to initialize the weights of the LoRA layers. "
88-
"Passing True (default) results in the default initialization from the reference implementation from "
89-
"Microsoft, with the LoRA B weight being set to 0. This means that without further training, the LoRA "
90-
"adapter will be a no-op. "
91-
"Setting the initialization to False leads to random initialization of LoRA A and B, meaning that LoRA "
92-
"is not a no-op before training; this setting is intended for debugging purposes. "
91+
"Passing True (default) results in the default initialization from "
92+
"the reference implementation from "
93+
"Microsoft, with the LoRA B weight being set to 0. "
94+
"This means that without further training, "
95+
"the LoRA adapter will be a no-op. "
96+
"Setting the initialization to False leads to random initialization of "
97+
"LoRA A and B, meaning that LoRA is not a no-op before training; "
98+
"this setting is intended for debugging purposes."
9399
),
94100
},
95101
)
96102
layers_to_transform: Optional[list[int]] = field(
97103
default=None,
98104
metadata={
99-
"help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index. "
100-
"This only works when target_modules is a list of str."
105+
"help": (
106+
"The layer indexes to transform, is this argument is specified, "
107+
"PEFT will transform only the layers indexes that are specified inside this list. "
108+
"If a single integer is passed, PEFT will transform only the layer at this index. "
109+
"This only works when target_modules is a list of str."
110+
)
101111
},
102112
)
103113
layers_pattern: Optional[list[str]] = field(
104114
default=None,
105115
metadata={
106-
"help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern."
107-
"This only works when target_modules is a list of str. This should target the `nn.ModuleList` of the "
108-
"model, which is often called `'layers'` or `'h'`."
116+
"help": (
117+
"The layer pattern name, used only if `layers_to_transform` is different to None "
118+
"and if the layer pattern is not in the common layers pattern. "
119+
"This only works when target_modules is a list of str. "
120+
"This should target the `nn.ModuleList` of the "
121+
"model, which is often called `'layers'` or `'h'`."
122+
)
109123
},
110124
)
111125
trainable_token_indices: Optional[list[int]] = field(
112126
default=None,
113127
metadata={
114128
"help": (
115-
"Lets you specify which token indices to selectively fine-tune without requiring to re-train the "
116-
"whole embedding matrix using the `peft.TrainableTokensModel` method. You can specify token indices "
117-
"in two ways. Either you specify a list of indices which will then target the model's input embedding "
118-
"layer (or, if not found, `embed_tokens`). (Not supported yet) Alternatively, you can specify a dictionary where the key "
119-
"is the name of the embedding module and the values are the list of token indices, e.g. "
120-
"`{'embed_tokens': [0, 1, ...]}`. Note that training with FSDP requires `use_orig_params=True` to "
129+
"Lets you specify which token indices to selectively fine-tune "
130+
"without requiring to re-train the "
131+
"whole embedding matrix using the `peft.TrainableTokensModel` method. "
132+
"You can specify token indices in two ways. "
133+
"Either you specify a list of indices which will then target the model's input "
134+
"embedding layer (or, if not found, `embed_tokens`). "
135+
"(Not supported yet) Alternatively, you can specify a dictionary "
136+
"where the key is the name of the embedding module "
137+
"and the values are the list of token indices, e.g. "
138+
"`{'embed_tokens': [0, 1, ...]}`. Note that training "
139+
"with FSDP requires `use_orig_params=True` to "
121140
"avoid issues with non-uniform `requires_grad`."
122141
)
123142
},
124143
)
125-
loftq_config: dict = field(
144+
loftq_config: Optional[dict] = field(
126145
default_factory=dict,
127146
metadata={
128147
"help": (
129-
"The configuration of LoftQ. If this is passed, then LoftQ will be used to quantize the backbone "
130-
"weights and initialize Lora layers. Also set `init_lora_weights='loftq'` in this case."
148+
"The configuration of LoftQ. If this is passed, "
149+
"then LoftQ will be used to quantize the backbone "
150+
"weights and initialize Lora layers. Also set `init_lora_weights='loftq'` "
151+
"in this case."
131152
)
132153
},
133154
)

tuning/utils/config_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import pickle
2121

2222
# Third Party
23-
from peft import LoraConfig as HFLoraConfig
2423
from peft import PromptTuningConfig as HFPromptTuningConfig
2524

2625
# Local
@@ -112,11 +111,10 @@ def get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path):
112111
alora_config.task_type = task_type
113112
hf_peft_config = alora_config
114113
elif isinstance(tuning_config, peft_config.LoraConfig):
115-
lora_config = asdict(tuning_config)
114+
if getattr(tuning_config, "task_type") is None:
115+
setattr(tuning_config, "task_type", task_type)
116116

117-
if not hasattr(lora_config, "task_type"):
118-
lora_config["task_type"] = task_type
119-
hf_peft_config = HFLoraConfig(**lora_config)
117+
hf_peft_config = tuning_config
120118
elif isinstance(tuning_config, peft_config.PromptTuningConfig):
121119
hf_peft_config = HFPromptTuningConfig(
122120
task_type=task_type,

0 commit comments

Comments
 (0)