-
Notifications
You must be signed in to change notification settings - Fork 66
fix: subclass Lora config from upstream peft.LoraConfig #609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
80b866c
de6baa7
f5be874
5acad56
3c5711e
8437bcf
dbc8b65
bfb995d
8475748
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,7 +71,7 @@ def train( | |
| data_args: configs.DataArguments, | ||
| train_args: configs.TrainingArguments, | ||
| peft_config: Optional[ # pylint: disable=redefined-outer-name | ||
| Union[peft_config.LoraConfig, LoraConfig, peft_config.PromptTuningConfig] | ||
| Union[LoraConfig, peft_config.PromptTuningConfig] | ||
| ] = None, | ||
| quantization_config: Optional[peft_config.Mxfp4Config] = None, | ||
| trainer_controller_args: TrainerControllerCallback = None, | ||
|
|
@@ -92,8 +92,7 @@ def train( | |
| model_args: tuning.config.configs.ModelArguments | ||
| data_args: tuning.config.configs.DataArguments | ||
| train_args: tuning.config.configs.TrainingArguments | ||
| peft_config: peft_config.LoraConfig for Lora tuning | \ | ||
| LoraConfig (peft.LoraConfig): for activated Lora (aLoRA) tuning | \ | ||
| peft_config: LoraConfig (peft.LoraConfig): for activated Lora (aLoRA) tuning | \ | ||
| peft_config.PromptTuningConfig for prompt tuning | \ | ||
| None for full fine tuning | ||
| The peft configuration to pass to trainer | ||
|
|
@@ -110,7 +109,7 @@ def train( | |
| tracker with automatically be added. | ||
| exp_metadata: Dict of key value pairs passed to train to be recoreded by the tracker. | ||
| quantized_lora_config: tuning.config.acceleration_configs.QuantizedLoraConfig \ | ||
| Should be used in combination with peft_config.LoraConfig for Lora tuning \ | ||
| Should be used in combination with LoraConfig for Lora tuning \ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a reference to the HuggingFace loraconfig arguments here?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| fusedops_kernels_config: tuning.config.acceleration_configs.FusedOpsAndKernelsConfig \ | ||
| Should be used in combination with quantized_lora_config. Also currently | ||
| fused_lora and fast_kernels must used together (may change in future). \ | ||
|
|
@@ -845,9 +844,7 @@ def main(): | |
| ) | ||
| sys.exit(INTERNAL_ERROR_EXIT_CODE) | ||
|
|
||
| if isinstance( | ||
| tune_config, (peft_config.LoraConfig, LoraConfig) | ||
| ): # aLoraConfig subclasses LoraConfig | ||
| if isinstance(tune_config, LoraConfig): # aLoraConfig subclasses LoraConfig | ||
| try: | ||
| if training_args.save_model_dir: | ||
| # Write number of added tokens to artifacts | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,7 +20,6 @@ | |
| import pickle | ||
|
|
||
| # Third Party | ||
| from peft import LoraConfig as HFLoraConfig | ||
| from peft import PromptTuningConfig as HFPromptTuningConfig | ||
|
|
||
| # Local | ||
|
|
@@ -112,10 +111,13 @@ def get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path): | |
| alora_config.task_type = task_type | ||
| hf_peft_config = alora_config | ||
| elif isinstance(tuning_config, peft_config.LoraConfig): | ||
| lora_config = asdict(tuning_config) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have removed this since Flattening the parameters using |
||
| if lora_config["target_modules"] == ["all-linear"]: | ||
| lora_config["target_modules"] = "all-linear" | ||
| hf_peft_config = HFLoraConfig(task_type=task_type, **lora_config) | ||
| if getattr(tuning_config, "target_modules") == ["all-linear"]: | ||
| setattr(tuning_config, "target_modules", "all-linear") | ||
|
|
||
| if getattr(tuning_config, "task_type") is None: | ||
| setattr(tuning_config, "task_type", task_type) | ||
|
|
||
| hf_peft_config = tuning_config | ||
| elif isinstance(tuning_config, peft_config.PromptTuningConfig): | ||
| hf_peft_config = HFPromptTuningConfig( | ||
| task_type=task_type, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@romitjain per your comment on line 58 above the other arguments are incompatible with arg parser but do we need to have arguments like this too?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, any field that accepts heterogeneous fields will need to be defined again.
For
init_lora_weights, the original type is: