Skip to content

Commit 23934d5

Browse files
Update config name
1 parent cd2c33f commit 23934d5

2 files changed

Lines changed: 20 additions & 15 deletions

File tree

src/transformers/quantizers/quantizer_bnb_4bit.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,11 @@ def check_quantized_param(
151151
# but it would wrongly use uninitialized weight there.
152152
return True
153153
elif (
154-
self.quantization_config.bnb_4bit_target_parameters is not None
154+
self.quantization_config.target_parameters is not None
155155
): # Check if the parameter name is in the list of target parameters for quantization
156156
return any(
157157
target_param
158-
for target_param in self.quantization_config.bnb_4bit_target_parameters
158+
for target_param in self.quantization_config.target_parameters
159159
if param_name.endswith("." + target_param) or param_name == target_param
160160
)
161161

@@ -242,7 +242,7 @@ def create_quantized_param(
242242
device=target_device,
243243
**param_kwargs,
244244
)
245-
elif self.quantization_config.bnb_4bit_target_parameters:
245+
elif self.quantization_config.target_parameters:
246246
# Normal nn.Parameter, i.e. outside of a Linear4bit layer.
247247
import bitsandbytes.nn.parametrize
248248

@@ -357,15 +357,15 @@ def _process_model_before_weight_loading(
357357
)
358358
# TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here
359359

360-
if self.quantization_config.bnb_4bit_target_parameters:
360+
if self.quantization_config.target_parameters:
361361
# TODO: consider when param is in a module specified by modules_to_not_convert
362362
matched_params = [
363363
param_name
364364
for param_name, _ in model.named_parameters()
365365
if any(
366366
filter(
367367
lambda target_param: param_name.endswith("." + target_param) or param_name == target_param,
368-
self.quantization_config.bnb_4bit_target_parameters,
368+
self.quantization_config.target_parameters,
369369
)
370370
)
371371
]
@@ -419,7 +419,7 @@ def is_trainable(self) -> bool:
419419
def _dequantize(self, model):
420420
from ..integrations import dequantize_and_replace
421421

422-
# TODO: support bnb_4bit_target_parameters
422+
# TODO: support target_parameters
423423

424424
model = dequantize_and_replace(
425425
model, self.modules_to_not_convert, quantization_config=self.quantization_config

src/transformers/utils/quantization_config.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -450,12 +450,13 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
450450
quantized again.
451451
bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
452452
This sets the storage type to pack the quanitzed 4-bit params.
453-
bnb_4bit_target_parameters (`list[str]`, *optional*):
454-
A list of extra parameters that should be quantized in 4-bit. This is useful for models that have
453+
target_parameters (`list[str]`, *optional*):
454+
A list of extra parameters that should be quantized. This is useful for models that have
455455
additional parameters that are not Linear layers. Parameters that exactly match or end with the names
456456
provided here will be quantized in addition to the Linear weights. As an example, for
457457
[Llama4](https://huggingface.co/collections/meta-llama/llama-4-67f0c30d9fe03840bc9d0164),
458-
you can pass: `bnb_4bit_target_parameters=['feed_forward.experts.gate_up_proj', 'feed_forward.experts.down_proj]`
458+
you can pass: `target_parameters=['feed_forward.experts.gate_up_proj', 'feed_forward.experts.down_proj]`.
459+
This feature is experimental and only supported for 4bit quantization.
459460
kwargs (`dict[str, Any]`, *optional*):
460461
Additional parameters from which to initialize the configuration object.
461462
"""
@@ -472,7 +473,7 @@ def __init__(
472473
bnb_4bit_quant_type="fp4",
473474
bnb_4bit_use_double_quant=False,
474475
bnb_4bit_quant_storage=None,
475-
bnb_4bit_target_parameters=None,
476+
target_parameters=None,
476477
**kwargs,
477478
):
478479
self.quant_method = QuantizationMethod.BITS_AND_BYTES
@@ -488,7 +489,7 @@ def __init__(
488489
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
489490
self.bnb_4bit_quant_type = bnb_4bit_quant_type
490491
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
491-
self.bnb_4bit_target_parameters = bnb_4bit_target_parameters
492+
self.target_parameters = target_parameters
492493

493494
if bnb_4bit_compute_dtype is None:
494495
self.bnb_4bit_compute_dtype = torch.float32
@@ -581,10 +582,14 @@ def post_init(self):
581582
"4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
582583
)
583584

584-
if self.bnb_4bit_target_parameters is not None and bnb_version < version.parse("0.48.0"):
585-
raise ValueError(
586-
"bnb_4bit_target_parameters requires bitsandbytes>=0.48.0 - please upgrade your bitsandbytes version"
587-
)
585+
if self.target_parameters:
586+
if not self.load_in_4bit:
587+
raise ValueError("target_parameters is only supported for 4bit quantization.")
588+
589+
if bnb_version < version.parse("0.48.0"):
590+
raise ValueError(
591+
"target_parameters requires bitsandbytes>=0.48.0 - please upgrade your bitsandbytes version"
592+
)
588593

589594
def is_quantizable(self):
590595
r"""

0 commit comments

Comments
 (0)