132 changes: 111 additions & 21 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -141,17 +141,32 @@ def check_quantized_param(
**kwargs,
) -> bool:
import bitsandbytes as bnb
import bitsandbytes.nn.parametrize as bnb_parametrize

module, tensor_name = get_module_from_name(model, param_name)

parametrizations = getattr(module, "parametrizations", None)
has_bnb_4bit = (
isinstance(parametrizations, ModuleDict)
and tensor_name in parametrizations
and any(isinstance(p, bnb_parametrize.Bnb4bitParametrization) for p in parametrizations[tensor_name])
)

if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
if has_bnb_4bit:
# explicit parametrization registered already
return True
elif isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
# Add here check for loaded components' dtypes once serialization is implemented
return True
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
# but it would wrongly use uninitialized weight there.
return True
else:
return False
elif (
self.quantization_config.target_parameters is not None
): # Check if the parameter name is in the list of target parameters for quantization
return any(
param_name.endswith("." + target_param) or param_name == target_param
for target_param in self.quantization_config.target_parameters
)

return False
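
The target-parameter check above (mirrored later by the filter in _process_model_before_weight_loading) is a plain suffix-or-exact match against fully qualified parameter names. A minimal sketch of that rule; the helper and the parameter names here are illustrative, not part of the diff:

def matches_target(param_name: str, target_parameters: list[str]) -> bool:
    # A parameter is targeted when its name equals a target or ends with "." + target.
    return any(
        param_name == target or param_name.endswith("." + target)
        for target in target_parameters
    )

targets = ["feed_forward.experts.gate_up_proj"]
assert matches_target("model.layers.0.feed_forward.experts.gate_up_proj", targets)
assert not matches_target("model.layers.0.feed_forward.experts.gate_up_proj_scale", targets)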

def create_quantized_param(
self,
@@ -187,8 +202,9 @@ def create_quantized_param(
module._parameters[tensor_name] = new_value
return

if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
raise ValueError("this function only loads `Linear4bit components`")
# if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
# raise ValueError("this function only loads `Linear4bit components`")

if (
old_value.device == torch.device("meta")
and target_device not in ["meta", torch.device("meta")]
@@ -203,7 +219,7 @@

if not self.is_serializable:
raise ValueError(
"Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. "
"Detected 4bit weights but the version of bitsandbytes is not compatible with 4bit serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)

@@ -221,29 +237,61 @@
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)

param_kwargs = {}
if self.is_bnb_supports_quant_storage_module:
param_kwargs["module"] = module
if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
param_kwargs = {}
if self.is_bnb_supports_quant_storage_module:
param_kwargs["module"] = module

module._parameters[tensor_name] = bnb.nn.Params4bit.from_prequantized(
data=param_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=target_device,
**param_kwargs,
)
elif self.quantization_config.target_parameters:
# Normal nn.Parameter, i.e. outside of a Linear4bit layer.
import bitsandbytes.nn.parametrize

# Load the parameter on the target device
module._parameters[tensor_name] = torch.nn.Parameter(
param_value.to(target_device), requires_grad=False
)

# Apply the bitsandbytes parametrization to support dequantization
bitsandbytes.nn.parametrize.replace_parameter_4bit_prequantized(
module,
tensor_name,
qs_dict=quantized_stats,
device=target_device,
)

new_value = bnb.nn.Params4bit.from_prequantized(
data=param_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=target_device,
**param_kwargs,
)
else:
new_value = param_value.to("cpu")

# Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D):
if hasattr(module, "source_cls") and issubclass(module.source_cls, Conv1D):
new_value = new_value.T

kwargs = old_value.__dict__
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
kwargs = old_value.__dict__
module._parameters[tensor_name] = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(
target_device
)
else:
# This is a regular parameter, i.e. outside of a Linear4bit layer.
import bitsandbytes.nn.parametrize

module._parameters[tensor_name] = new_value
module._parameters[tensor_name] = torch.nn.Parameter(
param_value.to(target_device), requires_grad=False
)
bitsandbytes.nn.parametrize.replace_parameter_4bit(
module,
tensor_name,
compress_statistics=self.quantization_config.bnb_4bit_use_double_quant,
quant_type=self.quantization_config.bnb_4bit_quant_type,
)

# Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory
def adjust_max_memory(self, max_memory: dict[str, Union[int, str]]) -> dict[str, Union[int, str]]:
@@ -284,7 +332,6 @@ def update_device_map(self, device_map):
)
return device_map

# Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_before_weight_loading
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
@@ -317,6 +364,34 @@ def _process_model_before_weight_loading(
)
# TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbytes.py to here

if self.quantization_config.target_parameters:
matched_params = [
param_name
for param_name, _ in model.named_parameters()
if any(
param_name.endswith("." + target_param) or param_name == target_param
for target_param in self.quantization_config.target_parameters
)
and not any(
(key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert
)
]

if matched_params:
for param_name in matched_params:
module, tensor_name = get_module_from_name(model, param_name)

param = model.get_parameter(param_name)

quant_param = torch.nn.Parameter(
torch.empty((param.numel() + 1) // 2, dtype=torch.uint8),
requires_grad=False,
)

setattr(module, tensor_name, quant_param)

model.config.quantization_config = self.quantization_config

# Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
@@ -356,4 +431,19 @@ def _dequantize(self, model):
model = dequantize_and_replace(
model, self.modules_to_not_convert, quantization_config=self.quantization_config
)

# Remove parametrizations applied to target parameters, leaving the dequantized values
if self.quantization_config.target_parameters:
import torch.nn.utils.parametrize as P

for module_name, module in model.named_modules():
if P.is_parametrized(module):
for param_name in list(module.parametrizations.keys()):
full_name = f"{module_name}.{param_name}" if module_name else param_name
if any(
full_name.endswith("." + tp) or full_name == tp
for tp in self.quantization_config.target_parameters
):
P.remove_parametrizations(module, param_name, leave_parametrized=True)

return model
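
For parameters targeted via target_parameters, the loading path above keeps the packed 4-bit data behind a PyTorch parametrization, so ordinary attribute access returns a dequantized tensor, and _dequantize() later materializes that value through remove_parametrizations(..., leave_parametrized=True). The placeholder allocated in _process_model_before_weight_loading reserves (param.numel() + 1) // 2 uint8 elements because each byte packs two 4-bit values. Below is a minimal sketch of the parametrization pattern using a stand-in scaling transform in place of the actual Bnb4bitParametrization, whose internals are not shown in this diff:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class StandInDequant(nn.Module):
    # Stand-in for a dequantizing parametrization: forward() maps the stored tensor
    # to the value returned when the attribute is accessed.
    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale

    def forward(self, stored: torch.Tensor) -> torch.Tensor:
        return stored * self.scale  # real code would unpack and dequantize 4-bit data here

linear = nn.Linear(4, 4, bias=False)
parametrize.register_parametrization(linear, "weight", StandInDequant(scale=2.0))

# Attribute access now goes through the parametrization; the stored tensor lives in .original.
assert torch.allclose(linear.weight, linear.parametrizations.weight.original * 2.0)

# leave_parametrized=True writes the transformed value back as a plain nn.Parameter,
# mirroring what _dequantize() does for target parameters.
parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
assert isinstance(linear.weight, nn.Parameter)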
27 changes: 23 additions & 4 deletions src/transformers/utils/quantization_config.py
@@ -449,7 +449,14 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
This flag is used for nested quantization where the quantization constants from the first quantization are
quantized again.
bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
This sets the storage type to pack the quantized 4-bit params.
target_parameters (`list[str]`, *optional*):
A list of extra parameters that should be quantized. This is useful for models that have
additional parameters that are not part of Linear layers. Parameters that exactly match or end with the names
provided here will be quantized in addition to the Linear weights. As an example, for
[Llama4](https://huggingface.co/collections/meta-llama/llama-4-67f0c30d9fe03840bc9d0164),
you can pass: `target_parameters=['feed_forward.experts.gate_up_proj', 'feed_forward.experts.down_proj']`.
This feature is experimental and only supported for 4bit quantization.
kwargs (`dict[str, Any]`, *optional*):
Additional parameters from which to initialize the configuration object.
"""
@@ -466,6 +473,7 @@ def __init__(
bnb_4bit_quant_type="fp4",
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_storage=None,
target_parameters=None,
**kwargs,
):
self.quant_method = QuantizationMethod.BITS_AND_BYTES
@@ -481,6 +489,7 @@ def __init__(
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
self.bnb_4bit_quant_type = bnb_4bit_quant_type
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
self.target_parameters = target_parameters

if bnb_4bit_compute_dtype is None:
self.bnb_4bit_compute_dtype = torch.float32
@@ -539,6 +548,9 @@ def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""

bnb_version = version.parse(version.parse(importlib.metadata.version("bitsandbytes")).base_version)

if not isinstance(self.load_in_4bit, bool):
raise TypeError("load_in_4bit must be a boolean")

@@ -565,13 +577,20 @@ def post_init(self):
if not isinstance(self.bnb_4bit_use_double_quant, bool):
raise TypeError("bnb_4bit_use_double_quant must be a boolean")

if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
"0.39.0"
):
if self.load_in_4bit and not bnb_version >= version.parse("0.39.0"):
raise ValueError(
"4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
)

if self.target_parameters:
if not self.load_in_4bit:
raise ValueError("target_parameters is only supported for 4bit quantization.")

if bnb_version < version.parse("0.48.0"):
raise ValueError(
"target_parameters requires bitsandbytes>=0.48.0 - please upgrade your bitsandbytes version"
)

def is_quantizable(self):
r"""
Returns `True` if the model is quantizable, `False` otherwise.
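
Putting the two files together, a hedged end-to-end usage sketch; the checkpoint name and target parameter names below are illustrative rather than verified, and bitsandbytes>=0.48.0 plus a CUDA-capable setup are assumed:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    target_parameters=[
        "feed_forward.experts.gate_up_proj",
        "feed_forward.experts.down_proj",
    ],
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",  # illustrative checkpoint
    quantization_config=quant_config,
    device_map="auto",
)

# post_init() rejects unsupported combinations, e.g. target_parameters without 4-bit loading:
try:
    BitsAndBytesConfig(load_in_8bit=True, target_parameters=["feed_forward.experts.down_proj"])
except ValueError as err:
    print(err)  # "target_parameters is only supported for 4bit quantization."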