Skip to content

Commit 4aecedf

Browse files
TimDettmers and claude
committed
test: Add comprehensive VQ correctness tests for scalar GEMV and dispatch
- test_scalar_gemv.py: VQ roundtrip MSE (p=2,4), flat vs tiled dequant, scalar GEMV for all (M=1-4, p=2,4, fp16/bf16), edge K dimensions (32,64,2048,5120), flat vs tiled GEMV consistency, large shapes - test_kbit_gemm.py: VQ dequant+cuBLAS path (M=8-64, p=2,4), vq_linear dispatch across full M range, dtype, real model shapes, workspace/preallocated output paths, MMA kernel stubs (skipped) - All 266 tests pass (50 VQ scalar GEMV + 42 VQ dispatch + 174 kbit) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f08f614 commit 4aecedf

File tree

2 files changed

+498
-0
lines changed

2 files changed

+498
-0
lines changed

tests/test_kbit_gemm.py

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,3 +929,281 @@ def test_prod_mblock1_matches_reference(self):
929929
f"M_BLOCKS=1 regression: prod does not match reference.\n"
930930
f"Max diff: {(C_prod_cpu - C_direct).abs().max().item():.6f}"
931931
)
932+
933+
934+
# ===========================================================================
935+
# VQ Codebook Tests: dequant+cuBLAS path and vq_linear dispatch
936+
# ===========================================================================
937+
938+
939+
def _vq_dequant_matmul_ref(A, W, p, codebook=None):
    """Reference: quantize W with VQ, dequantize, then matmul in float32.

    Returns (C_ref, packed_flat, absmax_flat, codebook).
    """
    from bitsandbytes.functional import create_vq_codebook, dequantize_vq, quantize_vq

    # Build a default codebook on the GPU when the caller did not supply one.
    if codebook is None:
        codebook = create_vq_codebook(p, device="cuda")

    W_device = W.half().cuda()
    packed, absmax, codebook = quantize_vq(W_device, p=p, codebook=codebook)

    # Round-trip through the flat dequantizer so the reference uses exactly
    # the values a consumer of the packed representation would see.
    W_roundtrip = dequantize_vq(
        packed, absmax, codebook, p=p, n=W.numel(), dtype=torch.float16
    ).reshape(W.shape)

    # Accumulate the matmul in fp32 for a tighter reference, then move to CPU.
    C_ref = (A.float().cuda() @ W_roundtrip.float().T).cpu()
    return C_ref, packed, absmax, codebook
956+
957+
958+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestVQDequantCublas:
    """Test VQ dequantize + cuBLAS matmul path (M > 4 fallback in vq_linear)."""

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("M", [8, 16, 32, 64])
    def test_dequant_cublas_correctness(self, p, M):
        """Tiled dequant + matmul matches flat dequant + matmul reference."""
        from bitsandbytes.functional import (
            create_vq_codebook,
            dequantize_vq,
            quantize_vq,
            repack_vq,
        )

        K_dim, N = 512, 256
        torch.manual_seed(42)

        weight = torch.randn(N, K_dim)
        codebook = create_vq_codebook(p, device="cuda")
        packed_flat, absmax_flat, _ = quantize_vq(weight.half().cuda(), p=p, codebook=codebook)
        packed_tiled, absmax_tiled = repack_vq(packed_flat, absmax_flat, K_dim, N, p=p)

        # Dequantize through the tiled layout used by the cuBLAS fallback.
        W_tiled = torch.ops.bitsandbytes.dequantize_vq_tiled(
            packed_tiled, codebook, absmax_tiled, p, K_dim, N, torch.float16,
        ).reshape(N, K_dim)

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        C_tiled = (A.float() @ W_tiled.float().T).half()

        # Reference: flat dequant of the very same quantized payload.
        W_flat = dequantize_vq(
            packed_flat, absmax_flat, codebook, p=p, n=N * K_dim
        ).reshape(N, K_dim)
        C_ref = (A.float() @ W_flat.float().T).half()

        # Should be bit-identical since same dequant values
        abs_diff = (C_tiled.float() - C_ref.float()).abs()
        denom = C_ref.float().abs().clamp(min=1.0)
        rel_err = (abs_diff / denom).max().item()
        assert rel_err < 0.01, (
            f"p={p}, M={M}: tiled dequant+matmul vs flat dequant+matmul mismatch. "
            f"Max rel err: {rel_err:.6f}"
        )

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_dequant_cublas_dtype(self, p, dtype):
        """Tiled dequant works with both fp16 and bf16."""
        from bitsandbytes.functional import create_vq_codebook, quantize_vq, repack_vq

        K_dim, N, M = 512, 256, 16
        torch.manual_seed(42)

        weight = torch.randn(N, K_dim)
        codebook = create_vq_codebook(p, device="cuda")
        packed_flat, absmax_flat, _ = quantize_vq(weight.to(dtype).cuda(), p=p, codebook=codebook)
        packed_tiled, absmax_tiled = repack_vq(packed_flat, absmax_flat, K_dim, N, p=p)

        W_tiled = torch.ops.bitsandbytes.dequantize_vq_tiled(
            packed_tiled, codebook, absmax_tiled, p, K_dim, N, torch.float16,
        )
        # Output should be fp16 (codebook is fp16)
        assert W_tiled.dtype == torch.float16, f"Expected fp16, got {W_tiled.dtype}"
1022+
1023+
1024+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestVQLinearDispatch:
    """Test the vq_linear dispatch function across the full M range.

    The quantize/repack setup and the flat-dequant reference were previously
    copy-pasted into every test method; they are factored into the private
    static helpers below. Tolerances, seeds, shapes, and assertion messages
    are unchanged.
    """

    @staticmethod
    def _pack_weight(W_gpu, p, K_dim, N):
        """Quantize a GPU-resident weight and repack it into the tiled layout.

        Returns (packed_tiled, absmax_tiled, packed_flat, absmax_flat, codebook).
        The flat tensors are kept so callers can build the dequant reference.
        """
        from bitsandbytes.functional import create_vq_codebook, quantize_vq, repack_vq

        codebook = create_vq_codebook(p, device="cuda")
        packed_flat, absmax_flat, _ = quantize_vq(W_gpu, p=p, codebook=codebook)
        packed_tiled, absmax_tiled = repack_vq(packed_flat, absmax_flat, K_dim, N, p=p)
        return packed_tiled, absmax_tiled, packed_flat, absmax_flat, codebook

    @staticmethod
    def _reference(A, packed_flat, absmax_flat, codebook, p, K_dim, N):
        """Reference output: flat dequant + fp32 matmul, cast back to A's dtype."""
        from bitsandbytes.functional import dequantize_vq

        W_deq = dequantize_vq(
            packed_flat, absmax_flat, codebook, p=p, n=N * K_dim
        ).reshape(N, K_dim)
        return (A.float() @ W_deq.float().T).to(A.dtype)

    @staticmethod
    def _max_rel_err(C, C_ref):
        """Max elementwise |C - C_ref| / max(|C_ref|, 1), as a Python float."""
        diff = (C.float() - C_ref.float()).abs()
        scale = C_ref.float().abs().clamp(min=1.0)
        return (diff / scale).max().item()

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("M", [1, 2, 3, 4])
    def test_vq_linear_scalar_gemv_path(self, p, M):
        """vq_linear dispatches to scalar GEMV for M<=4."""
        from bitsandbytes.functional import vq_linear

        K_dim, N = 2048, 512
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        packed_tiled, absmax_tiled, packed_flat, absmax_flat, codebook = self._pack_weight(
            W.half().cuda(), p, K_dim, N
        )

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)

        C_ref = self._reference(A, packed_flat, absmax_flat, codebook, p, K_dim, N)
        rel_err = self._max_rel_err(C, C_ref)
        assert rel_err < 0.10, (
            f"p={p}, M={M}: vq_linear scalar GEMV path mismatch. Max rel err: {rel_err:.6f}"
        )

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("M", [8, 16, 32, 64])
    def test_vq_linear_cublas_path(self, p, M):
        """vq_linear dispatches to dequant+cuBLAS for M>4."""
        from bitsandbytes.functional import vq_linear

        K_dim, N = 512, 256
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        packed_tiled, absmax_tiled, packed_flat, absmax_flat, codebook = self._pack_weight(
            W.half().cuda(), p, K_dim, N
        )

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)

        C_ref = self._reference(A, packed_flat, absmax_flat, codebook, p, K_dim, N)
        rel_err = self._max_rel_err(C, C_ref)
        assert rel_err < 0.05, (
            f"p={p}, M={M}: vq_linear cuBLAS path mismatch. Max rel err: {rel_err:.6f}"
        )

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_vq_linear_output_dtype(self, p, dtype):
        """vq_linear output has same dtype as input A."""
        from bitsandbytes.functional import vq_linear

        K_dim, N, M = 512, 256, 2
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        packed_tiled, absmax_tiled, _, _, codebook = self._pack_weight(
            W.to(dtype).cuda(), p, K_dim, N
        )

        A = torch.randn(M, K_dim, dtype=dtype, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)
        assert C.dtype == dtype, f"Expected {dtype}, got {C.dtype}"

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize(
        "K_dim,N",
        [
            (2048, 5120),
            (5120, 2048),
            (2048, 4096),
        ],
    )
    def test_vq_linear_real_shapes(self, p, K_dim, N):
        """vq_linear works with shapes from real model projections."""
        from bitsandbytes.functional import vq_linear

        M = 1
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        packed_tiled, absmax_tiled, packed_flat, absmax_flat, codebook = self._pack_weight(
            W.half().cuda(), p, K_dim, N
        )

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)

        C_ref = self._reference(A, packed_flat, absmax_flat, codebook, p, K_dim, N)
        rel_err = self._max_rel_err(C, C_ref)
        assert rel_err < 0.10, (
            f"p={p}, ({K_dim},{N}): vq_linear mismatch. Max rel err: {rel_err:.6f}"
        )

    @pytest.mark.parametrize("p", [2, 4])
    def test_vq_linear_workspace(self, p):
        """vq_linear works with pre-allocated workspace."""
        from bitsandbytes.functional import vq_linear, vq_linear_workspace

        K_dim, N, M = 512, 256, 32
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        packed_tiled, absmax_tiled, _, _, codebook = self._pack_weight(
            W.half().cuda(), p, K_dim, N
        )

        workspace = vq_linear_workspace(M, K_dim, N, p, torch.float16, torch.device("cuda"))

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N, workspace=workspace)

        # Reference (without workspace) — the two paths must be bit-identical.
        C_ref = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)

        assert torch.equal(C, C_ref), (
            f"p={p}: workspace path differs from non-workspace path. "
            f"Max diff: {(C.float() - C_ref.float()).abs().max().item()}"
        )

    @pytest.mark.parametrize("p", [2, 4])
    def test_vq_linear_preallocated_output(self, p):
        """vq_linear works with pre-allocated output tensor."""
        from bitsandbytes.functional import vq_linear

        K_dim, N, M = 512, 256, 2
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        packed_tiled, absmax_tiled, _, _, codebook = self._pack_weight(
            W.half().cuda(), p, K_dim, N
        )

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        out = torch.empty(M, N, dtype=torch.float16, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N, out=out)

        # Verify it used the pre-allocated buffer
        assert C.data_ptr() == out.data_ptr(), "vq_linear didn't use pre-allocated output"

        # Verify correctness
        C_ref = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)
        assert torch.equal(C, C_ref), "Pre-allocated output differs from fresh output"

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("M", [5, 8, 16, 32])
    @pytest.mark.skip(reason="Task 5 (VQ MMA kernel) not yet implemented")
    def test_vq_mma_kernel(self, p, M):
        """VQ MMA kernel correctness (placeholder for Task 5)."""
        pass

0 commit comments

Comments
 (0)