
Commit f3151d2

yueshen2016claude authored and committed
fix: pass include_buffers=True to init_empty_weights for Gemma-4 support (#1169)
### What does this PR do?

**Type of change:** Bug fix

Pass `include_buffers=True` to `init_empty_weights()` when computing the device map for HuggingFace models. Gemma-4 registers some of its model weights as buffers rather than parameters, so without this flag they are not accounted for during device map computation, causing incorrect placement or OOM errors.

### Usage

```python
# No API change; fix is internal to get_model() in example_utils.py
# Simply run HF PTQ with a Gemma-4 model as usual:
python hf_ptq.py --pyt_ckpt_path google/gemma-4-... --quantize ...
```

### Testing

Manually tested with Gemma-4 model loading via `hf_ptq.py`.

### Before your PR is "*Ready for review*"

- Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: N/A
- Did you write any new necessary tests?: N/A
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A

### Additional Information

Gemma-4 uses buffers instead of parameters for some model weights, requiring `include_buffers=True` for correct device map estimation with `accelerate`'s `init_empty_weights`.

## Summary by CodeRabbit

* **Bug Fixes**
  * Improved model initialization in quantization examples by ensuring buffers are included during temporary model construction, resulting in more accurate device-map inference for model optimization.

Signed-off-by: James Shen <yueshen@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 62bde15 commit f3151d2

File tree

1 file changed

+1
-1
lines changed


examples/llm_ptq/example_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -643,7 +643,7 @@ def has_pack_quantized_config(config):
         auto_model_module = getattr(transformers, architecture)
         from_config = auto_model_module._from_config

-        with init_empty_weights():
+        with init_empty_weights(include_buffers=True):
             # When computing the device_map, assuming bfloat16 precision by default,
             # unless specified by the hf_config.
             torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
```
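To illustrate why the flag matters, here is a minimal, self-contained analog of `accelerate`'s `init_empty_weights`. The `empty_weights` helper and `BufferHeavy` module below are hypothetical illustrations, not ModelOpt or Gemma code: the point is that without `include_buffers=True`, buffer tensors are still materialized on a real device and therefore stay invisible to meta-device-based device-map estimation.

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def empty_weights(include_buffers: bool = False):
    """Illustrative analog of accelerate.init_empty_weights: patch nn.Module
    so newly registered parameters (and optionally buffers) land on the
    'meta' device, which records shape/dtype without allocating storage."""
    old_register_parameter = nn.Module.register_parameter
    old_register_buffer = nn.Module.register_buffer

    def register_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            module._parameters[name] = nn.Parameter(
                module._parameters[name].to("meta"),
                requires_grad=param.requires_grad,
            )

    def register_buffer(module, name, tensor, persistent=True):
        old_register_buffer(module, name, tensor, persistent=persistent)
        if tensor is not None:
            module._buffers[name] = module._buffers[name].to("meta")

    nn.Module.register_parameter = register_parameter
    if include_buffers:
        nn.Module.register_buffer = register_buffer
    try:
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        nn.Module.register_buffer = old_register_buffer


class BufferHeavy(nn.Module):
    """Toy stand-in for a model that stores some weights as buffers,
    as the PR description says Gemma-4 does."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        # A ~4 MB weight held as a buffer rather than a parameter.
        self.register_buffer("scale", torch.ones(1024, 1024))


with empty_weights():  # buffers still allocated -> undercounted in device map
    m = BufferHeavy()
print(m.proj.weight.device, m.scale.device)  # meta cpu

with empty_weights(include_buffers=True):  # buffers skipped and counted too
    m = BufferHeavy()
print(m.proj.weight.device, m.scale.device)  # meta meta
```

The real fix simply forwards this flag to `accelerate`, so the temporary model built for `infer_auto_device_map`-style estimation accounts for buffer-held weights without allocating them.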
