Skip to content

Commit a47ad9f

Browse files
TimDettmers and claude
committed
feat: Add VQ quantize/dequantize CUDA kernels and Python API
- kQuantize_VQ<P_VAL, scalar_t>: warp-level VQ quantizer with shared memory codebook. Finds nearest codebook entry for each group of P weights via brute-force L2 distance in P-dimensional space.
- kDequantize_VQ<P_VAL, T, ABSMAX_T>: flat-layout VQ dequantizer with codebook lookup and absmax scaling.
- Launch wrappers, extern C bindings, torch op registration, and Python wrappers (quantize_vq, dequantize_vq) for both p=2 and p=4.
- Verified: MSE=0.039 (p=2, 4 bits/wt), MSE=0.107 (p=4, 2 bits/wt). Supports fp16, bf16 input types and E4M4/fp32 absmax formats.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 319c24e commit a47ad9f

File tree

5 files changed

+501
-0
lines changed

5 files changed

+501
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,66 @@ def _(
557557
return out
558558

559559

560+
# VQ (Vector Quantization) quantize/dequantize

torch.library.define(
    "bitsandbytes::quantize_vq",
    "(Tensor A, Tensor codebook, int p) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::quantize_vq")
def _(A: torch.Tensor, codebook: torch.Tensor, p: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Fake (meta) implementation: validates arguments and allocates outputs
    # with the same shapes/dtypes the CUDA kernel produces.
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(codebook.shape == (256, p), lambda: f"codebook must be [256, {p}], got {codebook.shape}")
    blocks = (A.numel() + 31) // 32  # ceil(n / 32): one absmax byte per 32-element block
    # Each 32-element block yields 32/p one-byte indices, packed 4 per int32 word
    # (p=2 -> 4 words per block, p=4 -> 2 words per block).
    word_count = blocks * (32 // p // 4)
    packed_out = torch.empty(word_count, device=A.device, dtype=torch.int32)
    absmax_out = torch.empty(blocks, device=A.device, dtype=torch.uint8)
    return packed_out, absmax_out
578+
579+
580+
torch.library.define(
    "bitsandbytes::dequantize_vq",
    "(Tensor packed, Tensor codebook, Tensor absmax, int p, int n, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_vq")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    p: int,
    n: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Fake (meta) implementation: the real kernel writes whole 32-element
    # blocks, so the output is padded up to a multiple of 32.
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    blocks = (n + 31) // 32
    return torch.empty(blocks * 32, device=packed.device, dtype=dtype)
598+
599+
600+
torch.library.define(
    "bitsandbytes::dequantize_vq_",
    "(Tensor packed, Tensor codebook, Tensor absmax, int p, int n, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)",
)


@register_fake("bitsandbytes::dequantize_vq_")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    p: int,
    n: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> torch.Tensor:
    # Fake (meta) implementation of the in-place variant: validates p and
    # returns the caller-supplied `out`, which the schema marks as
    # mutated/aliased via Tensor(a!).
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    return out
618+
619+
560620
# K-bit repack: flat bit-plane layout -> GEMM-tiled layout
561621

562622
torch.library.define(

bitsandbytes/backends/cuda/ops.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,109 @@ def _(
956956
return out
957957

958958

959+
# Maps supported torch dtypes to the suffix used when resolving the native
# kernel symbol names (e.g. cquantize_vq_fp16_p2, cdequantize_vq_bf16_..._p4).
_VQ_DTYPE_SUFFIX = {
    torch.float16: "fp16",
    torch.bfloat16: "bf16",
    torch.float32: "fp32",
}
964+
965+
966+
@register_kernel("bitsandbytes::quantize_vq", "cuda")
def _(A: torch.Tensor, codebook: torch.Tensor, p: int) -> tuple[torch.Tensor, torch.Tensor]:
    # CUDA VQ quantization: validates inputs, allocates zero-initialized
    # outputs, and dispatches to the dtype/p-specialized native kernel.
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(
        A.dtype in _VQ_DTYPE_SUFFIX,
        lambda: f"quantize_vq only supports float16/bfloat16/float32, got {A.dtype}",
    )
    torch._check(codebook.dtype == torch.float16, lambda: f"codebook must be float16, got {codebook.dtype}")

    n = A.numel()
    blocks = (n + 31) // 32  # one absmax byte per 32-element block
    # 32/p byte indices per block, packed 4 per int32 word.
    packed = torch.zeros(blocks * (32 // p // 4), device=A.device, dtype=torch.int32)
    absmax = torch.zeros(blocks, device=A.device, dtype=torch.uint8)

    with _cuda_device_of(A):
        kernel = getattr(lib, f"cquantize_vq_{_VQ_DTYPE_SUFFIX[A.dtype]}_p{p}")
        kernel(
            get_ptr(codebook),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(packed),
            ct.c_int(n),
            _get_tensor_stream(A),
        )

    return packed, absmax
994+
995+
996+
def _dequantize_vq_impl(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    p: int,
    n: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    """Shared driver for the out-of-place and in-place VQ dequantize kernels.

    Validates arguments, normalizes absmax to the encoding the native kernels
    accept, and launches the dtype/absmax/p-specialized kernel, writing the
    result into `out`.
    """
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(
        dtype in _VQ_DTYPE_SUFFIX,
        lambda: f"dequantize_vq only supports float16/bfloat16/float32, got {dtype}",
    )
    torch._check(codebook.dtype == torch.float16, lambda: f"codebook must be float16, got {codebook.dtype}")

    if absmax.dtype == torch.float32:
        # fp32 absmax is re-encoded to E4M4 before dispatch.
        from bitsandbytes.functional import encode_absmax_e4m4

        absmax = encode_absmax_e4m4(absmax)

    symbol = f"cdequantize_vq_{_VQ_DTYPE_SUFFIX[dtype]}_{_KBIT_ABSMAX_SUFFIX[absmax.dtype]}_p{p}"

    with _cuda_device_of(packed):
        getattr(lib, symbol)(
            get_ptr(packed),
            get_ptr(codebook),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int(n),
            _get_tensor_stream(packed),
        )
1031+
1032+
1033+
@register_kernel("bitsandbytes::dequantize_vq", "cuda")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    p: int,
    n: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Out-of-place variant: allocate a block-padded output (multiple of 32
    # elements) and delegate to the shared implementation.
    out = torch.empty(32 * ((n + 31) // 32), device=packed.device, dtype=dtype)
    _dequantize_vq_impl(packed, codebook, absmax, p, n, dtype, out)
    return out
1046+
1047+
1048+
@register_kernel("bitsandbytes::dequantize_vq_", "cuda")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    p: int,
    n: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> torch.Tensor:
    # In-place variant: fills the caller-provided `out` and returns it.
    # NOTE(review): `out` presumably must hold at least ceil(n/32)*32 elements,
    # matching the padded allocation in the out-of-place variant — confirm
    # against the kernel's write pattern.
    _dequantize_vq_impl(packed, codebook, absmax, p, n, dtype, out)
    return out
1060+
1061+
9591062
@register_kernel("bitsandbytes::repack_kbit", "cuda")
9601063
def _(
9611064
packed_flat: torch.Tensor,

bitsandbytes/functional.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,72 @@ def dequantize_kbit(
13101310
return result[:n]
13111311

13121312

1313+
def quantize_vq(
    A: Tensor,
    p: int = 2,
    codebook: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor, Tensor]:
    """Quantize a tensor with VQ codebook quantization (blocksize=32).

    Every run of p consecutive weights is replaced by the index of its
    nearest entry in a 256-entry codebook, giving 8/p bits per weight
    (p=2: 4 bits, p=4: 2 bits).

    Args:
        A: Input tensor. Supports float16, bfloat16, or float32.
        p: VQ dimension (2 or 4). Each byte index maps to p weight values.
        codebook: Optional fp16 codebook tensor of shape [256, p].
            If None, uses precomputed Gaussian codebook.

    Returns:
        Tuple of (packed, absmax, codebook):
            - packed: int32 tensor of packed byte indices.
            - absmax: uint8 tensor of E4M4 per-block absmax values.
            - codebook: The codebook tensor used.
    """
    # Either fall back to the default Gaussian codebook or coerce the
    # user-supplied one onto A's device in fp16, as the kernel requires.
    if codebook is None:
        codebook = create_vq_codebook(p, device=A.device)
    else:
        codebook = codebook.to(device=A.device, dtype=torch.float16)

    flat = A.contiguous().view(-1)
    packed, absmax = torch.ops.bitsandbytes.quantize_vq(flat, codebook, p)
    return packed, absmax, codebook
1343+
1344+
1345+
def dequantize_vq(
    packed: Tensor,
    absmax: Tensor,
    codebook: Tensor,
    p: int,
    n: int,
    dtype: torch.dtype = torch.float16,
    out: Optional[Tensor] = None,
) -> Tensor:
    """Dequantize a VQ codebook quantized tensor.

    Args:
        packed: int32 tensor of packed byte indices (from quantize_vq).
        absmax: Per-block absmax values (uint8 E4M4 or float32).
        codebook: fp16 codebook tensor of shape [256, p].
        p: VQ dimension (2 or 4).
        n: Number of original elements.
        dtype: Output dtype. Defaults to float16.
        out: Optional pre-allocated output tensor.
            NOTE(review): presumably must hold at least ceil(n/32)*32
            elements, since the out-of-place path allocates a block-padded
            buffer — confirm against the kernel.

    Returns:
        Dequantized tensor of shape (n,) with the given dtype.
    """
    # Fix: removed dead locals (`num_blocks`, `padded_n`) that were computed
    # but never used; the ops compute their own padding from `n`.
    if out is not None:
        torch.ops.bitsandbytes.dequantize_vq_(packed, codebook, absmax, p, n, dtype, out)
        return out[:n]

    # The op returns a block-padded result; trim the padding tail.
    result = torch.ops.bitsandbytes.dequantize_vq(packed, codebook, absmax, p, n, dtype)
    return result[:n]
1377+
1378+
13131379
def dequantize_kbit_tiled(
13141380
packed: Tensor,
13151381
absmax: Tensor,

0 commit comments

Comments
 (0)