Skip to content

Commit 6aa9619

Browse files
authored
Cpu fused kernel (#1804)
* add template to support more dtypes Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update cmake list Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix typo Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix compile cpu Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * make different dtype works Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * use bf16 on CPU Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix state2 dtype Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * remove torch Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * rm torch Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable float to bf16 Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * rm dequantizeBlockwise4bitCpu Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix check Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable dequant 4bit kernel Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix typo Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix typo Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix dequantize Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * change input param Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix typo Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix input param Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * spliut 8bit and 4bit Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix typo Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix typo Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix input 
params Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix input params Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix typo Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable dequant4bit Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix reverse Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix dequant 4bit fallback path Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix fp4 dequant Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * rm _Float16 Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * tmp codes Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable gemv Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * change to 4bit dequant Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix def Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix type Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix absmax dtype Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix type Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix compile and type Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable gemv Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix shape Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix lib name Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * debug Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable gemv 4bit bf16 Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable avx512 check Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix check Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix endif Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix def Signed-off-by: 
jiqing-feng <jiqing.feng@intel.com> * fix position Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * rm duplicated func Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * rm useless code comments Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix out shape Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix comments Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * add reverse format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * check avx512bf15 Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix has_avx512bf16 Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix absmax shhape Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix compile Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix test_gemv Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * disable binsearch Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix lint Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix save Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent bd028b8 commit 6aa9619

File tree

10 files changed

+755
-81
lines changed

10 files changed

+755
-81
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,9 @@ if (BUILD_CPU)
282282
check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG)
283283
if (HAS_AVX512F_FLAG)
284284
target_compile_options(bitsandbytes PRIVATE -mavx512f)
285+
target_compile_options(bitsandbytes PRIVATE -mavx512dq)
286+
target_compile_options(bitsandbytes PRIVATE -mavx512bw)
287+
target_compile_options(bitsandbytes PRIVATE -mavx512vl)
285288
endif()
286289
if (HAS_AVX512BF16_FLAG)
287290
target_compile_options(bitsandbytes PRIVATE -mavx512bf16)

bitsandbytes/autograd/_functions.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,18 @@ def matmul_4bit(
374374
bias: Optional[torch.Tensor] = None,
375375
):
376376
assert quant_state is not None
377-
# Change dtype to bfloat16 on CPU
377+
# Change dtype to input dtype on CPU
378378
if A.device.type == "cpu":
379379
quant_state.dtype = A.dtype
380380

381+
if getattr(quant_state, "packing_format_for_cpu", False):
382+
out = F.gemv_4bit(A, B, out, state=quant_state)
383+
if bias is not None:
384+
out += bias
385+
return out
386+
else:
387+
return MatMul4Bit.apply(A, B, out, bias, quant_state)
388+
381389
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
382390
if A.shape[-1] % quant_state.blocksize != 0:
383391
warn(

bitsandbytes/backends/cpu/ops.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77

8-
from bitsandbytes.functional import get_ptr
8+
from bitsandbytes.functional import get_ptr, has_avx512bf16
99

1010
from ..._ops import register_kernel
1111
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
@@ -217,3 +217,62 @@ def _(
217217
raise ValueError
218218

219219
return out
220+
221+
if has_avx512bf16():

    @register_kernel("bitsandbytes::gemv_4bit", "cpu")
    def _(
        A: torch.Tensor,
        B: torch.Tensor,
        shapeB: Sequence[int],
        absmax: torch.Tensor,
        code: torch.Tensor,
        blocksize: int,
    ) -> torch.Tensor:
        """Fused CPU ``gemv_4bit`` using the native AVX512-BF16 kernel.

        Args:
            A: Activations of shape ``(..., K)``. Any dtype; converted to
                bfloat16 for the native call and converted back on return.
            B: Packed 4-bit quantized weights; must be uint8.
            shapeB: Logical weight shape; ``shapeB[0]`` is the output dim N.
            absmax: Per-block scales, passed through to the native kernel.
            code: Quantization lookup table; ``code[1] > 0`` selects the fp4
                symbol, otherwise nf4.
            blocksize: Quantization block size.

        Returns:
            Tensor of shape ``(..., N)`` in A's original dtype.
        """
        assert B.dtype == torch.uint8, "Only support uint8 qweight"
        dtype = A.dtype
        quant_type = "fp4" if code[1] > 0 else "nf4"
        # The native fused op only supports bf16 activations for now.
        if dtype != torch.bfloat16:
            A = A.to(torch.bfloat16)

        final_out_shape = (*A.shape[:-1], shapeB[0])
        A = A.reshape(-1, A.shape[-1])  # collapse leading dims -> (M, K)
        out = torch.empty((A.shape[0], shapeB[0]), dtype=A.dtype, device=A.device)
        M, K = A.shape
        N = shapeB[0]

        # Both quant types share the same C signature; dispatch on the symbol
        # name instead of duplicating two identical ctypes call sites.
        kernel = getattr(lib, f"gemv_4bit_inference_cpu_{quant_type}_bf16")
        kernel(
            ct.c_int64(M),
            ct.c_int64(N),
            ct.c_int64(K),
            get_ptr(A),
            get_ptr(B),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int64(blocksize),
            ct.c_int64(A.stride(0)),
            ct.c_int64(out.stride(0)),
        )

        if dtype != torch.bfloat16:
            out = out.to(dtype)

        return out.reshape(final_out_shape)

bitsandbytes/functional.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2103,4 +2103,138 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
21032103
return out
21042104

21052105

2106+
def _convert_weight_packed_for_cpu(
    qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32
) -> tuple[torch.Tensor, QuantState]:
    """
    Repack a 4-bit quantized weight into the blocked layout used by the fused CPU kernel.

    qweight: flat uint8 tensor of ``K * N / 2`` bytes, two 4-bit codes per byte
        (high nibble is the earlier element). If stored under another dtype it
        is reinterpreted as uint8 and the original storage dtype is stashed.
    quant_state: mutated in place. Original dtype / nested flag / storage shape
        are saved on it so `_convert_weight_packed_for_cpu_inverse` can undo
        the conversion. On return it holds fp32-flattened, transposed bf16
        absmax, ``dtype == torch.bfloat16`` and ``packing_format_for_cpu = True``.
    block_n: row-block size of the packed layout; N must be divisible by it.
    return: (packed_weight of shape [N, K/2] uint8, the updated quant_state)
    """
    if qweight.dtype != torch.uint8:
        # Remember the non-uint8 storage dtype so the inverse can view it back.
        quant_state.original_storage_type = qweight.dtype
        qweight = qweight.view(torch.uint8)
    # Stash original metadata for the best-effort inverse conversion.
    quant_state.original_dtype = quant_state.dtype
    quant_state.original_nested = quant_state.nested
    quant_state.original_qshape = qweight.shape

    qweight = qweight.reshape(-1)
    # Unpack each byte into two 4-bit codes: even index = high nibble,
    # odd index = low nibble.
    unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=qweight.device)
    unpacked_w[1::2] = qweight & 0xF
    unpacked_w[::2] = qweight >> 4
    qweight_final = unpacked_w.reshape(quant_state.shape).to(torch.uint8)  # (*, N, K)
    # pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit
    assert len(qweight_final.shape) == 2
    N, K = qweight_final.shape[0], qweight_final.shape[1]
    assert N % block_n == 0, "N must be divisible by block_n"
    assert K % 2 == 0, "K must be even"
    BLOCK_N = block_n
    # Half the width of a flattened 64-code row: 32 low codes followed by
    # 32 high codes, merged below into 32 packed bytes.
    BIT_COUNT = 32
    new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2]
    out_shape = [N, K // 2]
    qw = qweight_final.reshape(new_shape)  # (N/B, B, K/2, 2)
    # Swap the row-block and column axes so codes of one block are contiguous.
    qw = qw.transpose(-3, -2).contiguous()  # (N/B, K/2, B, 2)
    qw = qw.reshape(-1, BIT_COUNT * 2)  # [-1, 64]
    high = qw[:, BIT_COUNT:]  # high 32
    low = qw[:, :BIT_COUNT]  # low 32
    packed = ((high << 4) | low).to(torch.uint8)  # combine into one byte each
    final_qweight = packed.reshape(out_shape)
    if quant_state.nested:
        # Flatten nested (double) quantization: dequantize absmax to plain fp32.
        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
        absmax += quant_state.offset
        if absmax.dtype != torch.float32:
            absmax = absmax.float()

        quant_state.absmax = absmax
        quant_state.nested = False
        delattr(quant_state, "state2")

    # Reshape absmax to [N, K/blocksize], then transpose to [K/blocksize, N]
    # bf16 (the layout consumed by the fused kernel — see gemv_4bit CPU op).
    quant_state.absmax = (
        quant_state.absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
        .T.to(torch.bfloat16)
        .contiguous()
    )

    quant_state.dtype = torch.bfloat16
    quant_state.packing_format_for_cpu = True
    return final_qweight, quant_state
2158+
2159+
2160+
def _convert_weight_packed_for_cpu_inverse(
    packed_weight: torch.Tensor,
    quant_state: QuantState,
    block_n: int = 32,
) -> tuple[torch.Tensor, QuantState]:
    """
    Undo `_convert_weight_packed_for_cpu` (best effort).

    packed_weight: [N, K/2] uint8, output of `_convert_weight_packed_for_cpu` (final_qweight)
    quant_state: QuantState that was modified by `_convert_weight_packed_for_cpu`
    block_n: must match the value used when packing.
    Returns:
        qweight: uint8 in the original stashed storage shape (and storage dtype,
            if one was recorded)
        recovered_state: QuantState with partially restored fields (best-effort
            inverse; re-quantizing a nested absmax is lossy)
    """
    assert quant_state.packing_format_for_cpu, "only for packing format"
    assert packed_weight.dtype == torch.uint8
    assert len(packed_weight.shape) == 2, "packed_weight should be [N, K/2]"
    N, K_half = packed_weight.shape
    K = K_half * 2

    # 1) packed [N, K/2] -> [N//BLOCK_N, BLOCK_N, K/2, 2]
    BLOCK_N = block_n
    # Half the width of the reconstructed 64-code row (32 low + 32 high codes).
    BIT_COUNT = 32

    assert N % BLOCK_N == 0, "N must be divisible by block_n"
    assert K % 2 == 0, "K must be even"

    # [N, K/2] -> [-1, 32]: each row of 32 packed bytes expands to 64 codes.
    packed = packed_weight.reshape(-1, BIT_COUNT)  # [-1, 32]
    # split high/low nibbles
    high = (packed >> 4) & 0xF
    low = packed & 0xF
    # concatenate to [..., 64], first 32 are low, last 32 are high
    # (mirrors the row layout built by the forward conversion)
    qw = torch.cat([low, high], dim=-1).to(torch.uint8)  # [..., 64]

    # -> [N/BLOCK_N, K/2, BLOCK_N, 2] -> [N, K]
    qw = qw.reshape(N // BLOCK_N, K_half, BLOCK_N, 2)  # [N/B, K/2, B, 2]
    qw = qw.transpose(-3, -2).contiguous()  # [N/B, B, K/2, 2]
    qw = qw.reshape(N, K)  # [N, K]

    qweight = qw  # [N, K]

    # Re-pack two 4-bit codes per byte: even index -> high nibble,
    # odd index -> low nibble (inverse of the forward unpack).
    unpacked_w = qweight.reshape(-1).to(torch.int32)  # [K*N]
    high4 = (unpacked_w[::2] & 0xF).to(torch.uint8)
    low4 = (unpacked_w[1::2] & 0xF).to(torch.uint8)
    qweight = (high4 << 4) | low4  # [K*N/2]

    # 2) Best-effort restore of quant_state fields (absmax / dtype / nested flags, etc.)
    recovered_state = quant_state
    qweight = qweight.to(torch.uint8).reshape(recovered_state.original_qshape)

    # Re-apply nested (double) quantization of absmax if it was originally nested.
    if recovered_state.original_nested:
        # Undo the [K/blocksize, N] bf16 transpose done by the forward pass.
        absmax = recovered_state.absmax.T.reshape(-1).to(recovered_state.original_dtype)
        offset = absmax.mean()
        qabsmax, state2 = quantize_blockwise(absmax - offset, blocksize=256)
        recovered_state.absmax = qabsmax
        recovered_state.offset = offset
        recovered_state.state2 = state2
        recovered_state.nested = True

    recovered_state.dtype = recovered_state.original_dtype
    recovered_state.packing_format_for_cpu = False

    if getattr(recovered_state, "original_storage_type", None):
        qweight = qweight.view(recovered_state.original_storage_type)

    return qweight, recovered_state
2226+
2227+
2228+
def has_avx512bf16():
    """Probe the native library for AVX512-BF16 support.

    Calls ``lib.has_avx512bf16_cpu()`` and returns its result; returns
    False when the symbol is missing or the native call fails.
    """
    try:
        return lib.has_avx512bf16_cpu()
    except (AttributeError, RuntimeError, OSError):
        return False
2238+
2239+
21062240
C = 127.0

bitsandbytes/nn/modules.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212

1313
import bitsandbytes as bnb
1414
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
15-
from bitsandbytes.functional import QuantState
15+
from bitsandbytes.functional import (
16+
QuantState,
17+
_convert_weight_packed_for_cpu,
18+
_convert_weight_packed_for_cpu_inverse,
19+
has_avx512bf16,
20+
)
1621
from bitsandbytes.optim import GlobalOptimManager
1722
from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer
1823

@@ -311,9 +316,13 @@ def cpu(self):
311316
return self.to(device="cpu")
312317

313318
def cuda(self, device: Optional[int | device | str] = None, non_blocking: bool = False):
    """Move the parameter to a CUDA device, first undoing the CPU-only
    packed weight layout when one is active."""
    if getattr(self.quant_state, "packing_format_for_cpu", False):
        self.data, self.quant_state = _convert_weight_packed_for_cpu_inverse(self.data, self.quant_state)
    target = "cuda" if device is None else device
    return self.to(device=target, non_blocking=non_blocking)
315322

316323
def xpu(self, device: Optional[int | device | str] = None, non_blocking: bool = False):
    """Move the parameter to an XPU device, first undoing the CPU-only
    packed weight layout when one is active."""
    if getattr(self.quant_state, "packing_format_for_cpu", False):
        self.data, self.quant_state = _convert_weight_packed_for_cpu_inverse(self.data, self.quant_state)
    target = "xpu" if device is None else device
    return self.to(device=target, non_blocking=non_blocking)
318327

319328
@overload
@@ -479,6 +488,7 @@ def __init__(
479488
self.compute_type_is_set = compute_dtype is not None
480489
self.quant_state = None
481490
self.quant_storage = quant_storage
491+
self.support_avx512bf16_for_cpu = has_avx512bf16()
482492

483493
def set_compute_type(self, x):
484494
if x.dtype in [torch.float32, torch.bfloat16]:
@@ -505,14 +515,27 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
505515
save weight and bias,
506516
then fill state_dict with components of quant_state
507517
"""
518+
if getattr(self.weight.quant_state, "packing_format_for_cpu", False):
519+
self.weight.data, self.weight.quant_state = _convert_weight_packed_for_cpu_inverse(
520+
self.weight.data, self.weight.quant_state
521+
)
508522
super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
509-
510523
if getattr(self.weight, "quant_state", None) is not None:
511524
for k, v in self.weight.quant_state.as_dict(packed=True).items():
512525
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
513526

514527
def forward(self, x: torch.Tensor):
515528
fix_4bit_weight_quant_state_from_module(self)
529+
quant_state = self.weight.quant_state
530+
531+
if (
532+
not getattr(quant_state, "packing_format_for_cpu", False)
533+
and x.device.type == "cpu"
534+
and self.support_avx512bf16_for_cpu
535+
and not self.training
536+
and x.requires_grad == False
537+
):
538+
self.weight.data, quant_state = _convert_weight_packed_for_cpu(self.weight.data, quant_state)
516539

517540
# weights are cast automatically as Int8Params, but the bias has to be cast manually
518541
if self.bias is not None and self.bias.dtype != x.dtype:
@@ -527,9 +550,9 @@ def forward(self, x: torch.Tensor):
527550
x = x.to(self.compute_dtype)
528551

529552
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
530-
weight = self.weight.t()
553+
weight = self.weight if getattr(quant_state, "packing_format_for_cpu", False) else self.weight.t()
531554

532-
return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
555+
return bnb.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype)
533556

534557

535558
class LinearFP4(Linear4bit):

0 commit comments

Comments
 (0)