Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 51 additions & 32 deletions bitsandbytes/backends/cpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,21 @@ def _(
return out

if has_avx512bf16():
# Optional accelerated path: try to obtain a `gemm_4bit_forward` callable from a
# locally stored kernel package via `kernels.get_local_kernel`. The handle stays
# None when loading raises, so code that uses it must check for None first.
gemm_4bit_forward_kernel = None
try:
# Imported inside the try so that a failing import is handled by the same
# except branch as a failing kernel load.
from pathlib import Path

from kernels import get_local_kernel

# NOTE(review): the repo path below is an absolute, machine-specific store
# path. On hosts where it does not exist this presumably always ends in the
# except branch — confirm whether a portable lookup is intended.
gemm_4bit_forward_kernel = get_local_kernel(
repo_path=Path(
"/workspace/nix/nix/store/vvsb2xvj5zkzfd37r1k1d5j23hpa9n86-quantization_bitsandbytes-torch-ext"
),
package_name="quantization_bitsandbytes",
).gemm_4bit_forward
except Exception as exc: # pragma: no cover - best effort fallback
# Deliberately broad catch: the failure is reported as a warning and the
# handle is reset to None instead of propagating the exception.
gemm_4bit_forward_kernel = None
logger.warning("Failed to load CPU gemm_4bit kernel: %s", exc)

@register_kernel("bitsandbytes::gemv_4bit", "cpu")
def _(
Expand All @@ -239,38 +254,42 @@ def _(
final_out_shape = (*A.shape[:-1], shapeB[0])
A = A.reshape(-1, A.shape[-1])
out_shape = (*A.shape[:-1], shapeB[0])
out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
M = A.shape[0]
N = shapeB[0]
K = A.shape[1]
x_strideM = A.stride(0)
out_strideM = out.stride(0)
if quant_type == "fp4":
lib.gemv_4bit_inference_cpu_fp4_bf16(
ct.c_int64(M),
ct.c_int64(N),
ct.c_int64(K),
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(out),
ct.c_int64(blocksize),
ct.c_int64(x_strideM),
ct.c_int64(out_strideM),
)
elif quant_type == "nf4":
lib.gemv_4bit_inference_cpu_nf4_bf16(
ct.c_int64(M),
ct.c_int64(N),
ct.c_int64(K),
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(out),
ct.c_int64(blocksize),
ct.c_int64(x_strideM),
ct.c_int64(out_strideM),
)
quant_type_num = 1 if quant_type == "fp4" else 0
if gemm_4bit_forward_kernel is not None:
out = gemm_4bit_forward_kernel(A, B, absmax, blocksize, quant_type_num)
else:
out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
M = A.shape[0]
N = shapeB[0]
K = A.shape[1]
x_strideM = A.stride(0)
out_strideM = out.stride(0)
if quant_type == "fp4":
lib.gemv_4bit_inference_cpu_fp4_bf16(
ct.c_int64(M),
ct.c_int64(N),
ct.c_int64(K),
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(out),
ct.c_int64(blocksize),
ct.c_int64(x_strideM),
ct.c_int64(out_strideM),
)
elif quant_type == "nf4":
lib.gemv_4bit_inference_cpu_nf4_bf16(
ct.c_int64(M),
ct.c_int64(N),
ct.c_int64(K),
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(out),
ct.c_int64(blocksize),
ct.c_int64(x_strideM),
ct.c_int64(out_strideM),
)

if dtype != torch.bfloat16:
out = out.to(dtype)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ classifiers = [
dependencies = [
"torch>=2.3,<3",
"numpy>=1.17",
"packaging>=20.9"
"packaging>=20.9",
"kernels>=0.11.1"
Comment thread
jiqing-feng marked this conversation as resolved.
Outdated
]

[project.urls]
Expand Down
1 change: 0 additions & 1 deletion tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
quant_type=quant_type,
)
B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state)
B_q = B_q.t()
absmax = state.absmax
out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)

Expand Down