
Commit 3c71007

Authored by jiqing-feng
Hf kernel (#1814)
* enable hf kernel
* fix typo
* add kernels dep
* fix typo
* update tests
* optional for kernels
* update kernel
* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Parent: a65a985

File tree

3 files changed: +48, -34 lines

bitsandbytes/backends/cpu/ops.py

Lines changed: 47 additions & 32 deletions
```diff
@@ -219,6 +219,17 @@ def _(
     return out
 
 if has_avx512bf16():
+    gemm_4bit_forward_kernel = None
+    try:
+        from kernels import get_kernel
+
+        gemm_4bit_forward_kernel = get_kernel("kernels-community/quantization_bitsandbytes").gemm_4bit_forward
+    except Exception as exc:  # pragma: no cover - best effort fallback
+        gemm_4bit_forward_kernel = None
+        logger.warning(
+            "Failed to load CPU gemm_4bit_forward from kernels-community: %s. Please make sure you already `pip install kernels` and the kernels >= 0.11.1",
+            exc,
+        )
 
     @register_kernel("bitsandbytes::gemv_4bit", "cpu")
     def _(
```
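The loading logic above is the usual optional-dependency pattern: try to fetch a prebuilt op from the Hugging Face kernels hub, and keep a `None` sentinel so the built-in ctypes path still works when it fails. A minimal standalone sketch, assuming only the repo id and attribute name from this diff (`load_hub_kernel` is a hypothetical helper, not bitsandbytes API):

```python
from typing import Callable, Optional


def load_hub_kernel(repo_id: str, attr: str) -> Optional[Callable]:
    """Best-effort load of a prebuilt op from the HF kernels hub.

    Returns None when the `kernels` package is missing or the kernel
    cannot be fetched, so callers can keep an eager/ctypes fallback.
    """
    try:
        from kernels import get_kernel  # pip install "kernels>=0.11.1"
    except ImportError:
        return None
    try:
        return getattr(get_kernel(repo_id), attr)
    except Exception:
        return None


# The op wired up in this commit:
gemm_4bit_forward_kernel = load_hub_kernel(
    "kernels-community/quantization_bitsandbytes", "gemm_4bit_forward"
)
```

Per the warning string, the fast path needs `pip install "kernels>=0.11.1"`; without it the code silently stays on the ctypes kernels.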
```diff
@@ -239,38 +250,42 @@ def _(
         final_out_shape = (*A.shape[:-1], shapeB[0])
         A = A.reshape(-1, A.shape[-1])
         out_shape = (*A.shape[:-1], shapeB[0])
-        out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
-        M = A.shape[0]
-        N = shapeB[0]
-        K = A.shape[1]
-        x_strideM = A.stride(0)
-        out_strideM = out.stride(0)
-        if quant_type == "fp4":
-            lib.gemv_4bit_inference_cpu_fp4_bf16(
-                ct.c_int64(M),
-                ct.c_int64(N),
-                ct.c_int64(K),
-                get_ptr(A),
-                get_ptr(B),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int64(blocksize),
-                ct.c_int64(x_strideM),
-                ct.c_int64(out_strideM),
-            )
-        elif quant_type == "nf4":
-            lib.gemv_4bit_inference_cpu_nf4_bf16(
-                ct.c_int64(M),
-                ct.c_int64(N),
-                ct.c_int64(K),
-                get_ptr(A),
-                get_ptr(B),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int64(blocksize),
-                ct.c_int64(x_strideM),
-                ct.c_int64(out_strideM),
-            )
+        if gemm_4bit_forward_kernel is not None:
+            quant_type_num = 1 if quant_type == "fp4" else 0
+            out = gemm_4bit_forward_kernel(A, B, absmax, blocksize, quant_type_num)
+        else:
+            out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
+            M = A.shape[0]
+            N = shapeB[0]
+            K = A.shape[1]
+            x_strideM = A.stride(0)
+            out_strideM = out.stride(0)
+            if quant_type == "fp4":
+                lib.gemv_4bit_inference_cpu_fp4_bf16(
+                    ct.c_int64(M),
+                    ct.c_int64(N),
+                    ct.c_int64(K),
+                    get_ptr(A),
+                    get_ptr(B),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int64(blocksize),
+                    ct.c_int64(x_strideM),
+                    ct.c_int64(out_strideM),
+                )
+            elif quant_type == "nf4":
+                lib.gemv_4bit_inference_cpu_nf4_bf16(
+                    ct.c_int64(M),
+                    ct.c_int64(N),
+                    ct.c_int64(K),
+                    get_ptr(A),
+                    get_ptr(B),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int64(blocksize),
+                    ct.c_int64(x_strideM),
+                    ct.c_int64(out_strideM),
+                )
 
         if dtype != torch.bfloat16:
             out = out.to(dtype)
```
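Read on its own, the new branch is a two-way dispatcher: prefer the hub kernel when it loaded, otherwise fall back to the built-in ctypes kernels. A hedged paraphrase of that shape (the function name and stubbed fallback are illustrative; the fp4 -> 1 / nf4 -> 0 encoding comes from the diff):

```python
def gemv_4bit_cpu(A, B, absmax, blocksize, quant_type, hub_kernel=None):
    """Dispatch sketch: try the hub kernel first, ctypes path second."""
    if hub_kernel is not None:
        # Integer encoding used by the hub op in this commit: fp4 -> 1, nf4 -> 0.
        quant_type_num = 1 if quant_type == "fp4" else 0
        return hub_kernel(A, B, absmax, blocksize, quant_type_num)
    # Fallback: the lib.gemv_4bit_inference_cpu_{fp4,nf4}_bf16 path shown above.
    raise NotImplementedError("ctypes fallback elided in this sketch")
```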

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -45,7 +45,7 @@ classifiers = [
 dependencies = [
     "torch>=2.3,<3",
     "numpy>=1.17",
-    "packaging>=20.9"
+    "packaging>=20.9",
 ]
 
 [project.urls]
```

tests/test_ops.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -237,7 +237,6 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
             quant_type=quant_type,
         )
         B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state)
-        B_q = B_q.t()
         absmax = state.absmax
         out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)
```
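Net effect on callers, per the test change: the packed CPU weight is now passed to the op as-is, with no transpose. A hedged end-to-end sketch modeled on tests/test_ops.py (shapes, blocksize, and the `state.code` usage are illustrative assumptions; `quantize_4bit` is public bitsandbytes API, and the repack helper plus the op call appear verbatim in the diff):

```python
import torch
import bitsandbytes

A = torch.randn(1, 4096, dtype=torch.bfloat16)     # gemv: single-row activation
B = torch.randn(4096, 4096, dtype=torch.bfloat16)  # weight to quantize

# Quantize, then repack for the CPU backend with the helper from the diff.
B_q, state = bitsandbytes.functional.quantize_4bit(B, blocksize=64, quant_type="nf4")
B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state)

# After this commit the packed weight is consumed directly -- no B_q.t().
out = torch.ops.bitsandbytes.gemv_4bit.default(
    A, B_q, B.shape, state.absmax, state.code, 64
)
```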
