@@ -236,17 +236,35 @@ def create_quantized_param(
236236 if unexpected_keys is not None and k in unexpected_keys :
237237 unexpected_keys .remove (k )
238238
239- param_kwargs = {}
240- if self .is_bnb_supports_quant_storage_module :
241- param_kwargs ["module" ] = module
242-
243- module ._parameters [tensor_name ] = bnb .nn .Params4bit .from_prequantized (
244- data = param_value ,
245- quantized_stats = quantized_stats ,
246- requires_grad = False ,
247- device = target_device ,
248- ** param_kwargs ,
249- )
239+ if isinstance (module ._parameters [tensor_name ], bnb .nn .Params4bit ):
240+ param_kwargs = {}
241+ if self .is_bnb_supports_quant_storage_module :
242+ param_kwargs ["module" ] = module
243+
244+ module ._parameters [tensor_name ] = bnb .nn .Params4bit .from_prequantized (
245+ data = param_value ,
246+ quantized_stats = quantized_stats ,
247+ requires_grad = False ,
248+ device = target_device ,
249+ ** param_kwargs ,
250+ )
251+ elif self .quantization_config .bnb_4bit_target_parameters :
252+ # Normal nn.Parameter, i.e. outside of a Linear4bit layer.
253+ import bitsandbytes .nn .parametrize
254+
255+ # Load the parameter on the target device
256+ module ._parameters [tensor_name ] = torch .nn .Parameter (
257+ param_value .to (target_device ), requires_grad = False
258+ )
259+
260+ # Apply the bitsandbytes parametrization to support dequantization
261+ bitsandbytes .nn .parametrize .replace_parameter_4bit_prequantized (
262+ module ,
263+ tensor_name ,
264+ qs_dict = quantized_stats ,
265+ device = target_device ,
266+ )
267+
250268 else :
251269 new_value = param_value .to ("cpu" )
252270
@@ -359,20 +377,17 @@ def _process_model_before_weight_loading(
359377 ]
360378
361379 if any (matched_params ):
362- import bitsandbytes .nn .parametrize
363-
364380 for param_name in matched_params :
365381 module , tensor_name = get_module_from_name (model , param_name )
366382
367- # Fake quantize/replace parameter - we're in `init_empty_weights`
368- # TODO: we could probably just infer the dtype/shape
369- quantized_data , quant_state = bitsandbytes .functional .quantize_4bit (
370- model .get_parameter (param_name ).data ,
371- compress_statistics = self .quantization_config .bnb_4bit_use_double_quant ,
372- quant_type = self .quantization_config .bnb_4bit_quant_type ,
383+ param = model .get_parameter (param_name )
384+
385+ quant_param = torch .nn .Parameter (
386+ torch .empty ((param .numel () + 1 ) // 2 , dtype = torch .uint8 ),
387+ requires_grad = False ,
373388 )
374389
375- setattr (module , tensor_name , torch . nn . Parameter ( quantized_data , requires_grad = False ) )
390+ setattr (module , tensor_name , quant_param )
376391
377392 model .config .quantization_config = self .quantization_config
378393
0 commit comments