-
Notifications
You must be signed in to change notification settings - Fork 33.7k
FIX Restore LoRA hotswapping functionality #45682
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1101322
3e59566
6a15499
87356d8
85e8d28
8e111a6
96c477b
819e59b
a91877f
4f52050
8211bbc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1099,33 +1099,41 @@ def tearDown(self): | |
| torch.compiler.reset() | ||
| gc.collect() | ||
|
|
||
| def _check_model_hotswap(self, *, rank1, rank2, do_compile): | ||
| def _check_model_hotswap( | ||
| self, *, rank1, rank2, do_compile, model_id="hf-internal-testing/tiny-random-OPTForCausalLM" | ||
| ): | ||
| # utility method that checks that we can successfully hotswap adapters, with the model outputs corresponding to | ||
| # the respective adapters | ||
| from peft import LoraConfig | ||
|
|
||
| torch.manual_seed(0) | ||
| model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" | ||
|
|
||
| model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device) | ||
| input = torch.randint(0, 100, (1, 10)).to(torch_device) | ||
| with torch.inference_mode(): | ||
| base_output = model(input).logits | ||
|
|
||
| # create 2 adapters | ||
| model.add_adapter(LoraConfig(r=rank1, init_lora_weights=False), adapter_name="adapter_1") | ||
| model.add_adapter( | ||
| LoraConfig(r=rank1, init_lora_weights=False, target_modules=["q_proj", "v_proj"]), adapter_name="adapter_1" | ||
| ) | ||
| with torch.inference_mode(): | ||
| lora_1_output = model(input).logits | ||
|
|
||
| # second adapter may have a different rank | ||
| model.add_adapter(LoraConfig(r=rank2, init_lora_weights=False), adapter_name="adapter_2") | ||
| model.add_adapter( | ||
| LoraConfig(r=rank2, init_lora_weights=False, target_modules=["q_proj", "v_proj"]), adapter_name="adapter_2" | ||
| ) | ||
| model.set_adapter("adapter_2") | ||
| with torch.inference_mode(): | ||
| lora_2_output = model(input).logits | ||
|
|
||
| # sanity checks | ||
| self.assertFalse(torch.allclose(base_output, lora_1_output, atol=1e-6, rtol=1e-6)) | ||
| self.assertFalse(torch.allclose(base_output, lora_2_output, atol=1e-6, rtol=1e-6)) | ||
| self.assertFalse(torch.allclose(lora_1_output, lora_2_output, atol=1e-6, rtol=1e-6)) | ||
| atol = 2e-3 | ||
| rtol = 1e-6 | ||
| self.assertFalse(torch.allclose(base_output, lora_1_output, atol=atol, rtol=rtol)) | ||
| self.assertFalse(torch.allclose(base_output, lora_2_output, atol=atol, rtol=rtol)) | ||
| self.assertFalse(torch.allclose(lora_1_output, lora_2_output, atol=atol, rtol=rtol)) | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdirname: | ||
| path_1 = os.path.join(tmpdirname, "adapter_1") | ||
|
|
@@ -1143,26 +1151,32 @@ def _check_model_hotswap(self, *, rank1, rank2, do_compile): | |
| model.enable_peft_hotswap(target_rank=max(rank1, rank2)) | ||
|
|
||
| # load the first adapter without hotswap (hotswap requires an existing adapter) | ||
| model.load_adapter(path_1, adapter_name="adapter_1") | ||
| model.load_adapter(path_1, adapter_name="adapter_1", is_trainable=False) | ||
| if do_compile: | ||
| # compile the model after loading the first adapter | ||
| model = torch.compile(model, mode="reduce-overhead") | ||
| if "mixtral" not in model_id.lower(): | ||
| model = torch.compile(model, mode="reduce-overhead") | ||
| else: | ||
| # The tiny mixtral model is incompatible with 'reduce-overhead', resulting in: | ||
| # > torch.AcceleratorError: CUDA error: operation failed due to a previous error during capture | ||
| # For the purpose of this test, 'reduce-overhead' is not material, so we drop it here. | ||
| model = torch.compile(model) | ||
|
|
||
| with torch.inference_mode(): | ||
| lora_1_output_loaded = model(input).logits | ||
| self.assertTrue(torch.allclose(lora_1_output, lora_1_output_loaded, atol=1e-6, rtol=1e-6)) | ||
| self.assertTrue(torch.allclose(lora_1_output, lora_1_output_loaded, atol=atol, rtol=rtol)) | ||
|
|
||
| # hotswap in adapter_2 again, output should be same as lora_2_output | ||
| if enable_hotswap: | ||
| # after calling enable_peft_hotswap, hotswap will automatically be enabled | ||
| model.load_adapter(path_2, adapter_name="adapter_1") | ||
| model.load_adapter(path_2, adapter_name="adapter_1", is_trainable=False) | ||
| else: | ||
| # enable_peft_hotswap was not called, need to explicitly pass hotswap=True | ||
| model.load_adapter(path_2, adapter_name="adapter_1", hotswap=True) | ||
| model.load_adapter(path_2, adapter_name="adapter_1", hotswap=True, is_trainable=False) | ||
|
|
||
| with torch.inference_mode(): | ||
| lora_2_output_loaded = model(input).logits | ||
| self.assertTrue(torch.allclose(lora_2_output, lora_2_output_loaded, atol=1e-6, rtol=1e-6)) | ||
| self.assertTrue(torch.allclose(lora_2_output, lora_2_output_loaded, atol=atol, rtol=rtol)) | ||
|
|
||
| def test_hotswap_wrong_peft_type_raises(self): | ||
| # only LoRA is supported for now | ||
|
|
@@ -1313,3 +1327,17 @@ def test_maybe_load_adapters_path_not_overwritten_for_complete_model(self): | |
| # Load from the saved path and make sure it actually loads despite | ||
| # the invalid adapter config path | ||
| AutoModel.from_pretrained(tmp_dir) | ||
|
|
||
| def test_mixtral_hotswap_without_compile_works(self): | ||
| # test a model that usees weight conversion | ||
| model_id = "hf-internal-testing/Mixtral-tiny" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are the experts properly targeted ? (the only that actually do
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, the experts are not covered by this test. As mentioned, it wouldn't work for lack of support in PEFT itself. I have an item on my TODO list to see if expert (or rather, |
||
| self._check_model_hotswap(rank1=7, rank2=13, do_compile=False, model_id=model_id) | ||
|
|
||
| def test_mixtral_hotswap_with_compile_works(self): | ||
| # test a model that usees weight conversion | ||
| model_id = "hf-internal-testing/Mixtral-tiny" | ||
| with ( | ||
| torch._dynamo.config.patch(error_on_recompile=True), | ||
| torch._inductor.utils.fresh_inductor_cache(), | ||
| ): | ||
| self._check_model_hotswap(rank1=8, rank2=8, do_compile=True, model_id=model_id) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: This is not new code, it's just moved to a separate method to avoid duplication. Original code here:
transformers/src/transformers/integrations/peft.py
Lines 631 to 656 in 049d2bf