@@ -520,12 +520,13 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
520520
521521 # unpacking tensor with non-tensor components
522522 qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
523- if not len(qs_key) and "quant_type" not in qs_dict:
524-     raise ValueError("Expected packed or unpacked quant_state items, found neither")
525- elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
526-     raise ValueError(
527-         f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
528-     )
523+ if "quant_type" not in qs_dict:
524+     if not qs_key:
525+         raise ValueError("Expected packed or unpacked quant_state items, found neither")
526+     elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
527+         raise ValueError(
528+             f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
529+         )
529530
530531 # unpacking minor and non-tensor quant state items if necessary
531532 if len (qs_key ) == 1 :
@@ -558,7 +559,7 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
558559 )
559560 return quant_state
560561
561- def as_dict(self, packed=False):
562+ def as_dict(self, packed: bool = False) -> dict[str, Any]:
562563 """
563564 returns dict of tensors and strings to use in serialization via _save_to_state_dict()
564565 param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
@@ -569,7 +570,7 @@ def as_dict(self, packed=False):
569570     "blocksize": self.blocksize,
570571     "quant_map": self.code,
571572     "dtype": str(self.dtype).strip("torch."),
572-     "shape": tuple(self.shape),
573+     "shape": tuple(self.shape) if self.shape is not None else None,
573574 }
574575 if self .nested :
575576 qs_dict .update (
@@ -581,13 +582,16 @@ def as_dict(self, packed=False):
581582 "nested_offset" : self .offset .item (),
582583 },
583584 )
584- if not packed:
585+ if not packed or self.quant_type is None:
585586     return qs_dict
586587
587588 # packed format allows serialization of non-tensor components, critical for saving in safetensors format
588589 qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
589590 non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
590- qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
591+ key = "quant_state.bitsandbytes__"
592+ if self.quant_type is not None:
593+     key += self.quant_type
594+ qs_packed_dict[key] = pack_dict_to_tensor(non_tensor_dict)
591595 return qs_packed_dict
592596
593597 def to (self , device ):
@@ -995,7 +999,7 @@ def dequantize_4bit(
995999 """Dequantizes a packed 4-bit quantized tensor.
9961000
9971001 The input tensor is dequantized by dividing it into blocks of `blocksize` values.
998- The the absolute maximum value within these blocks is used for scaling
1002+ The absolute maximum value within these blocks is used for scaling
9991003 the non-linear dequantization.
10001004
10011005 Args:
0 commit comments