Skip to content

Commit 32789ad

Browse files
Additional 4bit CPU ops
1 parent 4960932 commit 32789ad

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

bitsandbytes/backends/cpu/ops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Sequence
12
import ctypes as ct
23
from typing import Optional
34

@@ -119,6 +120,10 @@ def _(
119120
) -> tuple[torch.Tensor, torch.Tensor]:
120121
torch._check_is_size(blocksize)
121122
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
123+
torch._check(
124+
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
125+
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
126+
)
122127

123128
n = A.numel()
124129

@@ -140,3 +145,39 @@ def _(
140145
packed = packed.squeeze().view(quant_storage).unsqueeze(1)
141146

142147
return packed, absmax.float()
148+
149+
150+
@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize NF4-packed uint8 data back to floating point on CPU.

    Each byte of `A` holds two 4-bit NF4 codes (high nibble first). Codes are
    looked up in `_NF4_QUANT_TABLE` and rescaled blockwise by `absmax`, then
    reshaped to `shape` and cast to `dtype`.
    """
    torch._check_is_size(blocksize)
    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    torch._check(
        A.dtype == torch.uint8,
        lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
    )

    # Split each packed byte into its two 4-bit codes; int64 so the codes can
    # be used directly as indices into the lookup table.
    hi = torch.bitwise_right_shift(A, 4).to(torch.int64)
    lo = torch.bitwise_and(A, 0x0F).to(torch.int64)

    # Interleave high/low codes and regroup into quantization blocks.
    codes = torch.cat((hi, lo), dim=1).reshape(-1, blocksize)

    # Map codes through the NF4 table and rescale each block by its absmax.
    values = _NF4_QUANT_TABLE[codes] * absmax[:, None]

    # Restore the caller-requested shape and output dtype.
    return values.reshape(-1, *shape[1:]).to(dtype)

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def forward(self, x: torch.Tensor):
480480

481481
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
482482

483-
return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
483+
return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
484484

485485

486486
class LinearFP4(Linear4bit):

tests/test_ops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,11 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
171171
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
172172
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
173173
if device == "cpu":
174-
pytest.skip("CPU implementation is not available")
174+
if quant_type != "nf4":
175+
pytest.skip("CPU implementation is only available for nf4")
176+
177+
if storage_dtype != torch.uint8:
178+
pytest.skip("CPU implementation only supports uint8 storage")
175179

176180
shape = (128, 128)
177181

0 commit comments

Comments (0)