Fix load_adapter OOM caused by full-model warmup sizing#46145
Conversation
load_adapter passed every named parameter on the model, including the base model, as expected_keys to _load_pretrained_model. Downstream, caching_allocator_warmup summed those into a full base-model byte count and issued a single same-size allocation on top of the already-resident base model, OOMing whenever the base model occupies more than ~half of GPU memory. The file already defined an is_adapter_key helper for identifying parameters belonging to the freshly-injected adapter, but it was declared after the _load_pretrained_model call. Hoist the helper above the call and apply it to expected_keys. Adds a regression test that captures the device map passed to caching_allocator_warmup during load_adapter and asserts it contains only adapter-owned parameter names, not base-model names.
BenjaminBossan
left a comment
There was a problem hiding this comment.
Thanks for creating this PR. I confirmed that memory usage is actually doubled without the amendment. Here is a small reproducer:
import os
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
path = "/tmp/peft/llama"
model_id = "meta-llama/Llama-3.2-3B"
if not os.path.exists(os.path.join(path, "adapter_model.safetensors")):
model = model = AutoModelForCausalLM.from_pretrained(model_id)
config = LoraConfig()
model = get_peft_model(model, config)
model.save_pretrained(path)
del model
print(f"LoRA adapter did not exist, saved it to {path}")
model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
model.load_adapter(path)Setting a breakpoint before and after the self._load_pretrained_model, I can see that the VRAM usage doubles. With the provided fix, this is no longer the case. Thus, from my side, the PR looks good, just some small comments.
As I'm not very knowledgeable about the overall weight loading machinery in Transformers, I defer to @Cyrilvallez to judge if this is the best solution to the problem.
| def capture_warmup(model, expanded_device_map, hf_quantizer): | ||
| captured_device_maps.append(dict(expanded_device_map)) | ||
|
|
||
| modeling_utils.caching_allocator_warmup = capture_warmup |
There was a problem hiding this comment.
How about using unittest.mock.patch?
| # after loading, no meta device should be remaining | ||
| self.assertFalse(any((p.device.type == "meta") for p in model.parameters())) | ||
|
|
||
| def test_peft_load_adapter_warmup_uses_adapter_expected_keys(self): |
There was a problem hiding this comment.
The test somewhat relies on implementation details to work. As it's not easy to check the actual effect in terms of memory usage, I'd say it's fine. But if, for instance, caching_allocator_warmup is no longer used by _load_pretrained_model, the test would break and it would be hard to debug. So let's expand the test description to include how exactly this is being tested.
Cyrilvallez
left a comment
There was a problem hiding this comment.
Indeed, good catch! Looks good to me on the logic size. Will let @BenjaminBossan merge when he's happy with the test (I see a few comments there)
|
@Yooniel Do you agree with my feedback regarding the test? If yes, would you please update it accordingly? |
|
Thanks for the comments! I updated the test to use |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
BenjaminBossan
left a comment
There was a problem hiding this comment.
Thanks for updating the test, PR LGTM.
…46145) * Fix load_adapter OOM caused by full-model warmup sizing load_adapter passed every named parameter on the model, including the base model, as expected_keys to _load_pretrained_model. Downstream, caching_allocator_warmup summed those into a full base-model byte count and issued a single same-size allocation on top of the already-resident base model, OOMing whenever the base model occupies more than ~half of GPU memory. The file already defined an is_adapter_key helper for identifying parameters belonging to the freshly-injected adapter, but it was declared after the _load_pretrained_model call. Hoist the helper above the call and apply it to expected_keys. Adds a regression test that captures the device map passed to caching_allocator_warmup during load_adapter and asserts it contains only adapter-owned parameter names, not base-model names. * Address review: use unittest.mock.patch and expand test docstring
…46145) * Fix load_adapter OOM caused by full-model warmup sizing load_adapter passed every named parameter on the model, including the base model, as expected_keys to _load_pretrained_model. Downstream, caching_allocator_warmup summed those into a full base-model byte count and issued a single same-size allocation on top of the already-resident base model, OOMing whenever the base model occupies more than ~half of GPU memory. The file already defined an is_adapter_key helper for identifying parameters belonging to the freshly-injected adapter, but it was declared after the _load_pretrained_model call. Hoist the helper above the call and apply it to expected_keys. Adds a regression test that captures the device map passed to caching_allocator_warmup during load_adapter and asserts it contains only adapter-owned parameter names, not base-model names. * Address review: use unittest.mock.patch and expand test docstring
What does this PR do?
Fixes an OOM in
load_adapteron configurations where the base model occupies more than ~half of GPU memory, e.g. Gemma-3-27B in bf16 on a single H100/H200 or Llama-70B on a single 80 GB GPU.Root cause
load_adapterpasses every named parameter on the model, base model included, asexpected_keysto_load_pretrained_model. Downstream,caching_allocator_warmupsums those into a full base-model byte count and issues a single same-size allocation on top of the already-resident base model, OOMing.The allocation attempt, 51.87 GiB, is essentially the size of the base model already resident on the GPU.
Fix
Hoist the existing
is_adapter_keyhelper above the_load_pretrained_modelcall and apply it toexpected_keys, so warmup is sized only from adapter parameters. The downstreammissing_keysfilter that already used the helper is preserved.Tests
Adds a regression test that captures the device map passed to
caching_allocator_warmupduringload_adapterand asserts it contains only adapter-owned parameter names, not base-model names. Without the fix, the test fails with 84 base-model parameter names leaking into the warmup.Also verified the original GH200 repro locally: before the fix,
load_adaptertried to allocate 51.87 GiB and OOMed; after the fix, the adapter loads successfully.Related
load_adapterOOM (state-dict materialization inload_best_model_at_end), not warmup over-allocation.No associated issue was filed; this is a focused bugfix PR with a local repro, root-cause analysis, and regression test.
Code Agent Policy
Before submitting
Who can review?
caching_allocator_warmuppath.integrations/peft.pyand concerns adapter loading semantics.