Skip to content

Commit 79ce819

Browse files
authored
Fix QuantState and dict conversions (#1729)
* Fix QuantState.as_dict
  Signed-off-by: cyy <cyyever@outlook.com>
* Fix QuantState.from_dict
  Signed-off-by: cyy <cyyever@outlook.com>
* Fix QuantState.as_dict
  Signed-off-by: cyy <cyyever@outlook.com>
* Add test
  Signed-off-by: cyy <cyyever@outlook.com>
* Fix comment
  Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
* Fix self.quant_type is None
  Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
1 parent d1049e7 commit 79ce819

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

bitsandbytes/functional.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -520,12 +520,13 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
520520

521521
# unpacking tensor with non-tensor components
522522
qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
523-
if not len(qs_key) and "quant_type" not in qs_dict:
524-
raise ValueError("Expected packed or unpacked quant_state items, found neither")
525-
elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
526-
raise ValueError(
527-
f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
528-
)
523+
if "quant_type" not in qs_dict:
524+
if not qs_key:
525+
raise ValueError("Expected packed or unpacked quant_state items, found neither")
526+
elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
527+
raise ValueError(
528+
f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
529+
)
529530

530531
# unpacking minor and non-tensor quant state items if necessary
531532
if len(qs_key) == 1:
@@ -558,7 +559,7 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
558559
)
559560
return quant_state
560561

561-
def as_dict(self, packed=False):
562+
def as_dict(self, packed: bool = False) -> dict[str, Any]:
562563
"""
563564
returns dict of tensors and strings to use in serialization via _save_to_state_dict()
564565
param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
@@ -569,7 +570,7 @@ def as_dict(self, packed=False):
569570
"blocksize": self.blocksize,
570571
"quant_map": self.code,
571572
"dtype": str(self.dtype).strip("torch."),
572-
"shape": tuple(self.shape),
573+
"shape": tuple(self.shape) if self.shape is not None else None,
573574
}
574575
if self.nested:
575576
qs_dict.update(
@@ -581,13 +582,16 @@ def as_dict(self, packed=False):
581582
"nested_offset": self.offset.item(),
582583
},
583584
)
584-
if not packed:
585+
if not packed or self.quant_type is None:
585586
return qs_dict
586587

587588
# packed format allows serialization of non-tensor components, critical for saving in safetensors format
588589
qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
589590
non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
590-
qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
591+
key = "quant_state.bitsandbytes__"
592+
if self.quant_type is not None:
593+
key += self.quant_type
594+
qs_packed_dict[key] = pack_dict_to_tensor(non_tensor_dict)
591595
return qs_packed_dict
592596

593597
def to(self, device):
@@ -995,7 +999,7 @@ def dequantize_4bit(
995999
"""Dequantizes a packed 4-bit quantized tensor.
9961000
9971001
The input tensor is dequantized by dividing it into blocks of `blocksize` values.
998-
The the absolute maximum value within these blocks is used for scaling
1002+
The absolute maximum value within these blocks is used for scaling
9991003
the non-linear dequantization.
10001004
10011005
Args:

tests/test_functional.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
118118
for i in range(iters):
119119
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
120120
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
121+
if i == 0:
122+
d = S.as_dict()
123+
S = F.QuantState.from_dict(d, device=torch.device(device))
121124
A2 = F.dequantize_blockwise(C, S)
122125
diff = torch.abs(A1 - A2).float()
123126
reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -134,6 +137,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
134137
for i in range(iters):
135138
A1 = torch.rand(1024, 1024, device=device, dtype=dtype)
136139
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
140+
if i == 0:
141+
d = S.as_dict()
142+
S = F.QuantState.from_dict(d, device=torch.device(device))
137143
A2 = F.dequantize_blockwise(C, S)
138144
diff = torch.abs(A1 - A2).float()
139145
reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -271,6 +277,9 @@ def test_fp8_quant(self, device):
271277
for i in range(10):
272278
A1 = torch.randn(1024, 1024, device=device)
273279
C, SC = F.quantize_blockwise(A1, code=code)
280+
if i == 0:
281+
d = SC.as_dict()
282+
SC = F.QuantState.from_dict(d, device=torch.device(device))
274283
A2 = F.dequantize_blockwise(C, SC)
275284
diff = torch.abs(A1 - A2)
276285
reldiff = diff / torch.abs(A1 + 1e-8)
@@ -1118,6 +1127,8 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11181127

11191128
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
11201129
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
1130+
d = SA.as_dict()
1131+
SA = F.QuantState.from_dict(d, device=torch.device(device))
11211132
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
11221133
del qa, SA
11231134

0 commit comments

Comments
 (0)