Skip to content

Commit f08f614

Browse files
TimDettmers and claude committed
feat: Add VQ tiled dequantize kernel and vq_linear dispatch
- kDequantize_VQ_tiled: reads tiled VQ layout, writes flat [N,K] output - Full registration chain for tiled dequant (ops, bindings, Python) - vq_linear() dispatch: M≤4 → vq_scalar_gemv_tiled, M>4 → dequant+cuBLAS - vq_linear_workspace() for CUDA graph compatibility - End-to-end pipeline verified: quantize→repack→vq_linear→correct output for all (p={2,4}, K={64,2048,5120}, N={128,512,5120}, M={1,4,8,32}) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent cf2f64a commit f08f614

File tree

5 files changed

+327
-0
lines changed

5 files changed

+327
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,50 @@ def _(
617617
return out
618618

619619

620+
# VQ tiled dequantize: reads tiled VQ layout, writes flat [N, K_dim] output
torch.library.define(
    "bitsandbytes::dequantize_vq_tiled",
    "(Tensor packed_tiled, Tensor codebook, Tensor absmax_tiled, int p, int K_dim, int N, ScalarType dtype) -> Tensor",
)


# Fake (meta) implementation: allocates an output of the right size/dtype/device
# without touching real data, so the op can be traced (e.g. under FakeTensor).
@register_fake("bitsandbytes::dequantize_vq_tiled")
def _(
    packed_tiled: torch.Tensor,
    codebook: torch.Tensor,
    absmax_tiled: torch.Tensor,
    p: int,  # VQ dimension; only 2 and 4 are accepted
    K_dim: int,  # reduction dimension of the weight matrix
    N: int,  # output dimension of the weight matrix
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    # Flat [N * K_dim] buffer; callers view it as [N, K_dim].
    return torch.empty(N * K_dim, device=packed_tiled.device, dtype=dtype)


# In-place variant: same schema plus a caller-provided `out` buffer, marked
# Tensor(a!) so the dispatcher knows it is mutated and aliased in the return.
torch.library.define(
    "bitsandbytes::dequantize_vq_tiled_",
    "(Tensor packed_tiled, Tensor codebook, Tensor absmax_tiled, int p, int K_dim, int N, ScalarType dtype, "
    "Tensor(a!) out) -> Tensor(a!)",
)


@register_fake("bitsandbytes::dequantize_vq_tiled_")
def _(
    packed_tiled: torch.Tensor,
    codebook: torch.Tensor,
    absmax_tiled: torch.Tensor,
    p: int,
    K_dim: int,
    N: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> torch.Tensor:
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    # In-place op: the result is the caller's buffer, returned unchanged.
    return out
662+
663+
620664
# VQ scalar GEMV: byte-indexed codebook lookup GEMV for M=1-4
621665

622666
torch.library.define(

bitsandbytes/backends/cuda/ops.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,76 @@ def _(
10591059
return out
10601060

10611061

1062+
def _dequantize_vq_tiled_impl(
    packed_tiled: torch.Tensor,
    codebook: torch.Tensor,
    absmax_tiled: torch.Tensor,
    p: int,
    K_dim: int,
    N: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    """Launch the native tiled-VQ dequantize kernel, writing into ``out``.

    The C symbol is picked by output dtype (fp16/bf16), absmax dtype
    (uint8/float32) and VQ dimension ``p``.
    """
    torch._check(codebook.dtype == torch.float16, lambda: f"codebook must be float16, got {codebook.dtype}")
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")

    # Output dtype -> kernel-name suffix.
    dtype_suffix = {torch.float16: "fp16", torch.bfloat16: "bf16"}
    tname = dtype_suffix.get(dtype)
    if tname is None:
        raise ValueError(f"dequantize_vq_tiled only supports float16/bfloat16, got {dtype}")

    # Absmax dtype -> kernel-name suffix.
    absmax_suffix = {torch.uint8: "u8abs", torch.float32: "fp32abs"}
    aname = absmax_suffix.get(absmax_tiled.dtype)
    if aname is None:
        raise ValueError(f"absmax must be uint8 or float32, got {absmax_tiled.dtype}")

    with _cuda_device_of(packed_tiled):
        kernel = getattr(lib, f"cdequantize_vq_tiled_{tname}_{aname}_p{p}")
        kernel(
            get_ptr(packed_tiled),
            get_ptr(codebook),
            get_ptr(absmax_tiled),
            get_ptr(out),
            ct.c_int(K_dim),
            ct.c_int(N),
            _get_tensor_stream(packed_tiled),
        )
1100+
1101+
1102+
@register_kernel("bitsandbytes::dequantize_vq_tiled", "cuda")
def _(
    packed_tiled: torch.Tensor,
    codebook: torch.Tensor,
    absmax_tiled: torch.Tensor,
    p: int,
    K_dim: int,
    N: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Allocate the flat [N * K_dim] destination, then dequantize into it.
    result = torch.empty(N * K_dim, dtype=dtype, device=packed_tiled.device)
    _dequantize_vq_tiled_impl(packed_tiled, codebook, absmax_tiled, p, K_dim, N, dtype, result)
    return result
1115+
1116+
1117+
@register_kernel("bitsandbytes::dequantize_vq_tiled_", "cuda")
def _(
    packed_tiled: torch.Tensor,
    codebook: torch.Tensor,
    absmax_tiled: torch.Tensor,
    p: int,
    K_dim: int,
    N: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> torch.Tensor:
    # In-place flavor: fill the caller-provided buffer and hand it back.
    _dequantize_vq_tiled_impl(packed_tiled, codebook, absmax_tiled, p, K_dim, N, dtype, out)
    return out
1130+
1131+
10621132
def _vq_scalar_gemv_impl(
10631133
A: torch.Tensor,
10641134
B_packed: torch.Tensor,

bitsandbytes/functional.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,6 +1545,90 @@ def kbit_linear_workspace(M: int, K_dim: int, N: int, dtype: torch.dtype, device
15451545
}
15461546

15471547

1548+
def vq_linear(
    A: Tensor,
    B_packed: Tensor,
    B_absmax: Tensor,
    codebook: Tensor,
    p: int,
    K_dim: int,
    N: int,
    out: Optional[Tensor] = None,
    workspace: Optional[dict] = None,
) -> Tensor:
    """Compute C = A @ B^T where B is VQ codebook-quantized (tiled layout).

    Kernel selection depends on the batch dimension M = A.shape[0]:

    * M <= 4: tiled scalar-GEMV kernel (shared-memory codebook lookup).
    * M > 4:  dequantize the tiled weights to A's dtype, then cuBLAS matmul.

    All paths consume the tiled B layout produced by repack_vq.

    Args:
        A: Activations [M, K_dim], fp16 or bf16.
        B_packed: Tiled VQ packed weights (from repack_vq).
        B_absmax: Tiled per-block absmax values (from repack_vq).
        codebook: fp16 codebook tensor [256, p].
        p: VQ dimension (2 or 4).
        K_dim: Reduction dimension of the weight matrix.
        N: Output dimension of the weight matrix.
        out: Optional pre-allocated [M, N] output (CUDA graph compatibility).
        workspace: Optional dict of pre-allocated buffers; key 'dequant_buf'
            holds an fp16/bf16 [N * K_dim] buffer for the dequant+matmul path.

    Returns:
        [M, N] tensor with the same dtype as A.
    """
    M = A.shape[0]

    if M <= 4:
        # Small-batch path: fused GEMV straight from the tiled layout.
        if out is None:
            return torch.ops.bitsandbytes.vq_scalar_gemv_tiled(A, B_packed, B_absmax, codebook, K_dim, N, p)
        return torch.ops.bitsandbytes.vq_scalar_gemv_tiled_(
            A, B_packed, B_absmax, codebook, K_dim, N, p, out[:M]
        )

    # Large-batch path: materialize W = dequant(B) as [N, K_dim], then matmul.
    if workspace is not None and "dequant_buf" in workspace:
        buf = workspace["dequant_buf"]
        torch.ops.bitsandbytes.dequantize_vq_tiled_(
            B_packed, codebook, B_absmax, p, K_dim, N, A.dtype, buf
        )
    else:
        buf = torch.ops.bitsandbytes.dequantize_vq_tiled(B_packed, codebook, B_absmax, p, K_dim, N, A.dtype)
    W = buf[: N * K_dim].view(N, K_dim)

    if out is None:
        return torch.mm(A, W.t())
    torch.mm(A, W.t(), out=out[:M])
    return out[:M]
1608+
1609+
1610+
def vq_linear_workspace(M: int, K_dim: int, N: int, p: int, dtype: torch.dtype, device: torch.device) -> dict:
1611+
"""Pre-allocate workspace buffers for vq_linear (CUDA graph compatibility).
1612+
1613+
Args:
1614+
M: Maximum batch size (must be >= actual M at runtime).
1615+
K_dim: Reduction dimension.
1616+
N: Output dimension.
1617+
p: VQ dimension (2 or 4).
1618+
dtype: Activation dtype (fp16 or bf16).
1619+
device: CUDA device.
1620+
1621+
Returns:
1622+
Dict with 'dequant_buf' tensor.
1623+
"""
1624+
n_total = N * K_dim
1625+
num_blocks = -(n_total // -32)
1626+
1627+
return {
1628+
"dequant_buf": torch.empty(num_blocks * 32, device=device, dtype=dtype),
1629+
}
1630+
1631+
15481632
def kbit_expert_linear(
15491633
A_concat: Tensor,
15501634
B_packed_all: Tensor,

csrc/ops.cu

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,65 @@ __global__ void kDequantize_VQ(
980980
}
981981

982982

983+
// ---- VQ tiled dequantize kernel ----
984+
// Reads from tiled VQ layout (from repack_vq output), writes flat [N, K_dim] row-major.
985+
986+
template <int P_VAL, typename T, typename ABSMAX_T>
987+
__global__ void kDequantize_VQ_tiled(
988+
const unsigned int* __restrict__ packed_tiled,
989+
const half* __restrict__ codebook,
990+
const ABSMAX_T* __restrict__ absmax_tiled,
991+
T* __restrict__ out,
992+
const int K_dim, const int N
993+
) {
994+
constexpr int BS = 32;
995+
constexpr int TILE_K = 64;
996+
constexpr int TILE_N = 128;
997+
constexpr int KB_PER_TILE = TILE_K / BS;
998+
constexpr int WORDS_PER_BLOCK = BS / (P_VAL * 4);
999+
constexpr int WORDS_PER_TILE = TILE_N * KB_PER_TILE * WORDS_PER_BLOCK;
1000+
constexpr int ABS_PER_TILE = TILE_N * KB_PER_TILE;
1001+
constexpr int GROUPS_PER_BLOCK = BS / P_VAL;
1002+
1003+
const int num_k_blocks = K_dim / BS;
1004+
const int n_tiles = N / TILE_N;
1005+
1006+
// Each thread handles one element in the [N, K_dim] output
1007+
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
1008+
const int total = N * K_dim;
1009+
if (idx >= total)
1010+
return;
1011+
1012+
const int n_idx = idx / K_dim;
1013+
const int k_idx = idx % K_dim;
1014+
const int k_block = k_idx / BS;
1015+
const int elem_in_block = k_idx % BS;
1016+
1017+
// Tiled addressing
1018+
const int k_tile = k_block / KB_PER_TILE;
1019+
const int kb = k_block % KB_PER_TILE;
1020+
const int n_tile = n_idx / TILE_N;
1021+
const int col_in_tile = n_idx % TILE_N;
1022+
const int tile_base = k_tile * n_tiles + n_tile;
1023+
1024+
// Load absmax
1025+
const int abs_idx = tile_base * ABS_PER_TILE + col_in_tile * KB_PER_TILE + kb;
1026+
float amax = load_absmax(absmax_tiled, abs_idx);
1027+
1028+
// Find the byte index for this element
1029+
const int group = elem_in_block / P_VAL;
1030+
const int component = elem_in_block % P_VAL;
1031+
const int word_in_block = group / 4;
1032+
const int byte_in_word = group % 4;
1033+
1034+
const int word_idx = tile_base * WORDS_PER_TILE + (col_in_tile * KB_PER_TILE + kb) * WORDS_PER_BLOCK + word_in_block;
1035+
unsigned char byte_idx = (packed_tiled[word_idx] >> (byte_in_word * 8)) & 0xFF;
1036+
1037+
// Codebook lookup
1038+
float val = __half2float(codebook[byte_idx * P_VAL + component]) * amax;
1039+
out[idx] = (T)val;
1040+
}
1041+
9831042
// ---- Launch wrappers ----
9841043

9851044
#define KBIT_WARPS_PER_BLOCK 8
@@ -1115,6 +1174,19 @@ void dequantize_vq(
11151174
CUDA_CHECK_RETURN(cudaPeekAtLastError());
11161175
}
11171176

1177+
template <int P_VAL, typename T, typename ABSMAX_T>
1178+
void dequantize_vq_tiled(
1179+
const unsigned int* packed_tiled, const half* codebook, const ABSMAX_T* absmax_tiled,
1180+
T* out, int K_dim, int N, cudaStream_t stream
1181+
) {
1182+
int total = N * K_dim;
1183+
int threads = 256;
1184+
int blocks = (total + threads - 1) / threads;
1185+
kDequantize_VQ_tiled<P_VAL, T, ABSMAX_T>
1186+
<<<blocks, threads, 0, stream>>>(packed_tiled, codebook, absmax_tiled, out, K_dim, N);
1187+
CUDA_CHECK_RETURN(cudaPeekAtLastError());
1188+
}
1189+
11181190
// ---- Stage 2: Repack kernel (flat bit-plane -> GEMM-tiled layout) ----
11191191

11201192
// Tile sizes matching the GEMM kernel design (compile-time constants).
@@ -3737,6 +3809,23 @@ INSTANTIATE_VQ_QUANT(4)
37373809
INSTANTIATE_VQ_DEQUANT(2)
37383810
INSTANTIATE_VQ_DEQUANT(4)
37393811

3812+
// dequantize_vq_tiled: P_VAL × T × ABSMAX_T
// Explicit instantiations of the launcher template so its symbols exist for the
// plain wrappers in pythonInterface.cpp: every combination of output dtype
// (half / __nv_bfloat16) and absmax dtype (unsigned char / float), for p = 2, 4.
// NOTE: no comments inside the macro body — a `//` before a `\` continuation
// would swallow the splice.
#define INSTANTIATE_VQ_DEQUANT_TILED(P) \
    template void dequantize_vq_tiled<P, half, unsigned char>( \
        const unsigned int*, const half*, const unsigned char*, half*, int, int, cudaStream_t \
    ); \
    template void dequantize_vq_tiled<P, __nv_bfloat16, unsigned char>( \
        const unsigned int*, const half*, const unsigned char*, __nv_bfloat16*, int, int, cudaStream_t \
    ); \
    template void dequantize_vq_tiled<P, half, float>( \
        const unsigned int*, const half*, const float*, half*, int, int, cudaStream_t \
    ); \
    template void dequantize_vq_tiled<P, __nv_bfloat16, float>( \
        const unsigned int*, const half*, const float*, __nv_bfloat16*, int, int, cudaStream_t \
    );
INSTANTIATE_VQ_DEQUANT_TILED(2)
INSTANTIATE_VQ_DEQUANT_TILED(4)
3828+
37403829
// vq_scalar_gemv: P_VAL × scalar_t × ABSMAX_T (flat + tiled)
37413830
#define INSTANTIATE_VQ_SCALAR_GEMV_U8(P) \
37423831
template void vqScalarGemv<P, half, unsigned char>( \

csrc/pythonInterface.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,28 @@ MAKE_VQ_DEQUANT(fp16, half, fp32abs, float, 4)
552552
MAKE_VQ_DEQUANT(bf16, __nv_bfloat16, fp32abs, float, 2)
553553
MAKE_VQ_DEQUANT(bf16, __nv_bfloat16, fp32abs, float, 4)
554554

555+
// Forward declaration of VQ tiled dequant launcher (template defined and
// explicitly instantiated in csrc/ops.cu).
template <int P_VAL, typename T, typename ABSMAX_T>
void dequantize_vq_tiled(const unsigned int*, const half*, const ABSMAX_T*, T*, int, int, cudaStream_t);

// Unmangled VQ tiled dequant wrappers: one plain (non-template) function per
// (output dtype, absmax dtype, p) combination, binding the template arguments.
#define MAKE_VQ_DEQUANT_TILED(tname, T, aname, ABSMAX_T, P) \
    void dequantize_vq_tiled_##tname##_##aname##_p##P( \
        const unsigned int* packed_tiled, const half* codebook, const ABSMAX_T* absmax_tiled, T* out, int K_dim, \
        int N, cudaStream_t stream \
    ) { \
        dequantize_vq_tiled<P, T, ABSMAX_T>(packed_tiled, codebook, absmax_tiled, out, K_dim, N, stream); \
    }

MAKE_VQ_DEQUANT_TILED(fp16, half, u8abs, unsigned char, 2)
MAKE_VQ_DEQUANT_TILED(fp16, half, u8abs, unsigned char, 4)
MAKE_VQ_DEQUANT_TILED(bf16, __nv_bfloat16, u8abs, unsigned char, 2)
MAKE_VQ_DEQUANT_TILED(bf16, __nv_bfloat16, u8abs, unsigned char, 4)
MAKE_VQ_DEQUANT_TILED(fp16, half, fp32abs, float, 2)
MAKE_VQ_DEQUANT_TILED(fp16, half, fp32abs, float, 4)
MAKE_VQ_DEQUANT_TILED(bf16, __nv_bfloat16, fp32abs, float, 2)
MAKE_VQ_DEQUANT_TILED(bf16, __nv_bfloat16, fp32abs, float, 4)
576+
555577
// Forward declaration of repack launcher
556578
template <int K>
557579
void repackKbit(const unsigned int*, const unsigned char*, unsigned int*, unsigned char*, int, int, cudaStream_t);
@@ -1646,6 +1668,24 @@ MAKE_CVQ_DEQUANT(fp16, half, fp32abs, float, 4)
16461668
MAKE_CVQ_DEQUANT(bf16, __nv_bfloat16, fp32abs, float, 2)
16471669
MAKE_CVQ_DEQUANT(bf16, __nv_bfloat16, fp32abs, float, 4)
16481670

1671+
// VQ tiled dequant extern C wrappers: the `c`-prefixed symbols that the Python
// side resolves by name (getattr(lib, f"cdequantize_vq_tiled_{tname}_{aname}_p{p}")
// in the CUDA backend). Each one forwards to the unmangled wrapper above.
#define MAKE_CVQ_DEQUANT_TILED(tname, T, aname, ABSMAX_T, P) \
    void cdequantize_vq_tiled_##tname##_##aname##_p##P( \
        const unsigned int* packed_tiled, const half* codebook, const ABSMAX_T* absmax_tiled, T* out, int K_dim, \
        int N, cudaStream_t stream \
    ) { \
        dequantize_vq_tiled_##tname##_##aname##_p##P(packed_tiled, codebook, absmax_tiled, out, K_dim, N, stream); \
    }

MAKE_CVQ_DEQUANT_TILED(fp16, half, u8abs, unsigned char, 2)
MAKE_CVQ_DEQUANT_TILED(fp16, half, u8abs, unsigned char, 4)
MAKE_CVQ_DEQUANT_TILED(bf16, __nv_bfloat16, u8abs, unsigned char, 2)
MAKE_CVQ_DEQUANT_TILED(bf16, __nv_bfloat16, u8abs, unsigned char, 4)
MAKE_CVQ_DEQUANT_TILED(fp16, half, fp32abs, float, 2)
MAKE_CVQ_DEQUANT_TILED(fp16, half, fp32abs, float, 4)
MAKE_CVQ_DEQUANT_TILED(bf16, __nv_bfloat16, fp32abs, float, 2)
MAKE_CVQ_DEQUANT_TILED(bf16, __nv_bfloat16, fp32abs, float, 4)
1688+
16491689
// VQ scalar GEMV extern C wrappers (flat layout)
16501690
#define MAKE_CVQ_SCALAR_GEMV(P) \
16511691
void cvq_scalar_gemv_fp16_p##P( \

0 commit comments

Comments (0)