@@ -89,7 +89,7 @@ def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
-    peft_configs: Optional[
+    peft_config: Optional[  # pylint: disable=redefined-outer-name
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
 ):
@@ -99,7 +99,7 @@ def train(
         model_args: tuning.config.configs.ModelArguments
         data_args: tuning.config.configs.DataArguments
         train_args: tuning.config.configs.TrainingArguments
-        peft_configs: peft_config.LoraConfig for Lora tuning | \
+        peft_config: peft_config.LoraConfig for Lora tuning | \
             peft_config.PromptTuningConfig for prompt tuning | \
             None for fine tuning
             The peft configuration to pass to trainer
@@ -131,7 +131,7 @@ def train(
         use_flash_attention_2=model_args.use_flash_attn,
     )
 
-    peft_configs = get_hf_peft_config(task_type, peft_configs)
+    peft_config = get_hf_peft_config(task_type, peft_config)
 
     model.gradient_checkpointing_enable()
 
@@ -264,10 +264,10 @@ def train(
         args=train_args,
         max_seq_length=model_max_length,
         callbacks=callbacks,
-        peft_config=peft_configs,
+        peft_config=peft_config,
     )
 
-    if run_distributed and peft_configs is not None:
+    if run_distributed and peft_config is not None:
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
         )
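For context, a minimal sketch of how the renamed `peft_config` argument might be passed to `train()`. Only the `train()` signature, the `peft_config.LoraConfig` / `peft_config.PromptTuningConfig` types, and the `tuning.config.configs` path appear in this diff; the module holding `train()`, the dataclass field names, and the example values are assumptions for illustration.

```python
# Sketch only: field names and the sft_trainer module path are assumptions,
# not verified against the repository.
from tuning.config import configs, peft_config as peft_cfg
from tuning import sft_trainer  # hypothetical location of train()

model_args = configs.ModelArguments(model_name_or_path="facebook/opt-125m")
data_args = configs.DataArguments(training_data_path="train.jsonl")
train_args = configs.TrainingArguments(output_dir="outputs", num_train_epochs=1)

# LoRA tuning: pass a LoraConfig; pass a PromptTuningConfig for prompt tuning,
# or None for plain fine tuning, per the docstring above.
lora = peft_cfg.LoraConfig(r=8, lora_alpha=16)

sft_trainer.train(
    model_args=model_args,
    data_args=data_args,
    train_args=train_args,
    peft_config=lora,  # keyword renamed from peft_configs in this change
)
```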