Skip to content

Commit cf2f64a

Browse files
TimDettmers and claude committed
feat: Add VQ repack kernel (flat → tiled layout) and tiled GEMV support
- kRepackVQ<P_VAL> kernel: maps flat VQ byte layout to tile-interleaved layout - Same tile geometry as kbit repack (TILE_K=64, TILE_N=128, BS=32) - Full registration chain: ops.cu → pythonInterface → _ops.py → backends → functional.py - repack_vq() Python wrapper in functional.py - Verified: flat vs tiled GEMV produces bit-identical results for all (p,K,N,M) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 01f7b5f commit cf2f64a

File tree

5 files changed

+194
-0
lines changed

5 files changed

+194
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,34 @@ def _(
744744
return packed_tiled, absmax_tiled
745745

746746

747+
# VQ repack: flat VQ byte layout -> tiled layout
torch.library.define(
    "bitsandbytes::repack_vq",
    "(Tensor packed_flat, Tensor absmax_flat, int K_dim, int N, int p) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::repack_vq")
def _(
    packed_flat: torch.Tensor, absmax_flat: torch.Tensor, K_dim: int, N: int, p: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation: validate arguments and allocate outputs of the tiled size."""
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    TILE_K, TILE_N, BLOCKSIZE = 64, 128, 32
    torch._check(N % TILE_N == 0, lambda: f"N ({N}) must be divisible by {TILE_N}")
    torch._check(K_dim % BLOCKSIZE == 0, lambda: f"K_dim ({K_dim}) must be divisible by {BLOCKSIZE}")
    # Round K up to whole tiles (ceil-div); one absmax byte per 32-wide block,
    # and BLOCKSIZE // (p * 4) packed int32 words per block.
    tiles_k = -(-K_dim // TILE_K)
    tiles_n = N // TILE_N
    blocks_per_tile_col = TILE_K // BLOCKSIZE
    words_per_block = BLOCKSIZE // (p * 4)
    blocks_total = tiles_k * tiles_n * TILE_N * blocks_per_tile_col
    packed_tiled = torch.empty(blocks_total * words_per_block, device=packed_flat.device, dtype=torch.int32)
    absmax_tiled = torch.empty(blocks_total, device=packed_flat.device, dtype=torch.uint8)
    return packed_tiled, absmax_tiled
773+
774+
747775
# Hadamard rotation (in-place, for kbit quantization outlier spreading)
748776

749777
torch.library.define(

bitsandbytes/backends/cuda/ops.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,51 @@ def _(
12101210
return packed_tiled, absmax_tiled
12111211

12121212

1213+
@register_kernel("bitsandbytes::repack_vq", "cuda")
def _(
    packed_flat: torch.Tensor,
    absmax_flat: torch.Tensor,
    K_dim: int,
    N: int,
    p: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """CUDA implementation: rearrange flat VQ packed words and absmax into the tiled layout."""
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(packed_flat.dtype == torch.int32, lambda: f"packed_flat must be int32, got {packed_flat.dtype}")
    torch._check(
        absmax_flat.dtype == torch.uint8, lambda: f"absmax_flat must be uint8 (E4M4), got {absmax_flat.dtype}"
    )

    TILE_K, TILE_N, BLOCKSIZE = 64, 128, 32
    torch._check(N % TILE_N == 0, lambda: f"N ({N}) must be divisible by {TILE_N}")
    torch._check(K_dim % BLOCKSIZE == 0, lambda: f"K_dim ({K_dim}) must be divisible by {BLOCKSIZE}")

    # Output geometry: K rounded up to whole tiles (ceil-div); one absmax byte per
    # 32-wide block and BLOCKSIZE // (p * 4) packed int32 words per block.
    tiles_k = -(-K_dim // TILE_K)
    tiles_n = N // TILE_N
    blocks_per_tile_col = TILE_K // BLOCKSIZE
    words_per_block = BLOCKSIZE // (p * 4)
    blocks_total = tiles_k * tiles_n * TILE_N * blocks_per_tile_col

    # zeros (not empty): regions past K_dim in a padded final tile must read as zero.
    packed_tiled = torch.zeros(blocks_total * words_per_block, device=packed_flat.device, dtype=torch.int32)
    absmax_tiled = torch.zeros(blocks_total, device=packed_flat.device, dtype=torch.uint8)

    with _cuda_device_of(packed_flat):
        # Dispatch to the p-specialized C entry point (crepack_vq_p2 / crepack_vq_p4).
        kernel = getattr(lib, f"crepack_vq_p{p}")
        kernel(
            get_ptr(packed_flat),
            get_ptr(absmax_flat),
            get_ptr(packed_tiled),
            get_ptr(absmax_tiled),
            ct.c_int(K_dim),
            ct.c_int(N),
            _get_tensor_stream(packed_flat),
        )

    return packed_tiled, absmax_tiled
1256+
1257+
12131258
@register_kernel("bitsandbytes::hadamard_rotate_", "cuda")
12141259
def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> torch.Tensor:
12151260
torch._check(

bitsandbytes/functional.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,6 +1376,31 @@ def dequantize_vq(
13761376
return result[:n]
13771377

13781378

1379+
def repack_vq(
    packed_flat: Tensor,
    absmax_flat: Tensor,
    K_dim: int,
    N: int,
    p: int = 2,
) -> tuple[Tensor, Tensor]:
    """Convert flat VQ quantized weights into the tile-interleaved layout.

    The flat per-column packing of byte indices and absmax values is rearranged
    into the tiled ordering consumed by vq_scalar_gemv_tiled and vq_gemm_prod.

    Args:
        packed_flat: int32 tensor of packed byte indices (from quantize_vq).
        absmax_flat: uint8 E4M4 per-block absmax values.
        K_dim: Reduction dimension.
        N: Output dimension (must be multiple of 128).
        p: VQ dimension (2 or 4).

    Returns:
        Tuple of (packed_tiled, absmax_tiled).
    """
    return torch.ops.bitsandbytes.repack_vq(packed_flat, absmax_flat, K_dim, N, p)
1402+
1403+
13791404
def dequantize_kbit_tiled(
13801405
packed: Tensor,
13811406
absmax: Tensor,

csrc/ops.cu

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,67 @@ void repackKbit(
11801180
CUDA_CHECK_RETURN(cudaPeekAtLastError());
11811181
}
11821182

1183+
// ---- VQ Repack (flat VQ bytes -> tiled layout) ----
// Same tile geometry as kbit repack but with VQ byte words instead of bit planes.
// words_per_block = BS / (P_VAL * 4): p=2→4, p=4→2
//
// One thread handles one 32-element quant block: it copies that block's packed
// int32 words and its single absmax byte from the flat (per-column) ordering to
// the tile-interleaved ordering. KBIT_TILE_K / KBIT_TILE_N are defined elsewhere
// in this file; the Python side uses TILE_K=64, TILE_N=128 — presumably these
// match (TODO confirm against the macro definitions).

template <int P_VAL>
__global__ void kRepackVQ(
    const unsigned int* __restrict__ packed_flat, const unsigned char* __restrict__ absmax_flat,
    unsigned int* __restrict__ packed_tiled, unsigned char* __restrict__ absmax_tiled, const int K_dim, const int N
) {
    constexpr int BS = 32;
    constexpr int WORDS_PER_BLOCK = BS / (P_VAL * 4); // p=2: 4, p=4: 2
    const int total_k_blocks = K_dim / BS;
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= N * total_k_blocks)
        return;

    // Decompose the global thread index into (output column, k-block within column).
    const int n_idx = idx / total_k_blocks;
    const int k_block_idx = idx % total_k_blocks;
    const int k_start = k_block_idx * BS;

    // Source: flat layout — blocks stored column-major, one column's k-blocks contiguous.
    const int flat_block_id = n_idx * total_k_blocks + k_block_idx;

    // Destination: tiled layout — locate the (k_tile, n_tile) tile, the column
    // within the tile, and the k-block slot (kb) within that tile column.
    const int k_tile = k_start / KBIT_TILE_K;
    const int n_tile = n_idx / KBIT_TILE_N;
    const int col = n_idx % KBIT_TILE_N;
    const int kb = (k_start % KBIT_TILE_K) / BS;

    const int n_tiles = N / KBIT_TILE_N;
    constexpr int KB_PER_TILE = KBIT_TILE_K / BS; // 2
    constexpr int WORDS_PER_TILE = KBIT_TILE_N * KB_PER_TILE * WORDS_PER_BLOCK;
    constexpr int ABS_PER_TILE = KBIT_TILE_N * KB_PER_TILE;

    // Tiles are laid out k_tile-major, then n_tile; within a tile, words are
    // grouped by (col, kb) so a block's words stay contiguous.
    const int tile_base = k_tile * n_tiles + n_tile;
    const int dst_word_base = tile_base * WORDS_PER_TILE + (col * KB_PER_TILE + kb) * WORDS_PER_BLOCK;
    const int src_word_base = flat_block_id * WORDS_PER_BLOCK;

#pragma unroll
    for (int w = 0; w < WORDS_PER_BLOCK; w++)
        packed_tiled[dst_word_base + w] = packed_flat[src_word_base + w];

    // One absmax byte per block, indexed with the same (tile, col, kb) ordering.
    const int dst_abs_idx = tile_base * ABS_PER_TILE + col * KB_PER_TILE + kb;
    absmax_tiled[dst_abs_idx] = absmax_flat[flat_block_id];
}
1228+
1229+
// VQ Repack launcher: one thread per 32-element quant block, 256 threads per CTA.
template <int P_VAL>
void repackVQ(
    const unsigned int* packed_flat, const unsigned char* absmax_flat, unsigned int* packed_tiled,
    unsigned char* absmax_tiled, int K_dim, int N, cudaStream_t stream
) {
    const int num_items = N * (K_dim / 32); // total quant blocks to relocate
    const int threads = 256;
    const int blocks = (num_items + threads - 1) / threads;
    kRepackVQ<P_VAL><<<blocks, threads, 0, stream>>>(packed_flat, absmax_flat, packed_tiled, absmax_tiled, K_dim, N);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
1243+
11831244
// ===========================================================================
11841245
// Hadamard rotation kernel (in-place, blocksize-templated)
11851246
//
@@ -3500,6 +3561,14 @@ INSTANTIATE_KBIT_REPACK(3)
35003561
INSTANTIATE_KBIT_REPACK(4)
35013562
INSTANTIATE_KBIT_REPACK(5)
35023563

3564+
// VQ repack: P_VAL
// Explicit instantiations of the templated launcher for the supported VQ
// dimensions (p = 2 and p = 4) — the same values the Python-side checks accept.
#define INSTANTIATE_VQ_REPACK(P) \
    template void repackVQ<P>( \
        const unsigned int*, const unsigned char*, unsigned int*, unsigned char*, int, int, cudaStream_t \
    );
INSTANTIATE_VQ_REPACK(2)
INSTANTIATE_VQ_REPACK(4)
3571+
35033572
// Production kernel instantiations — uint8 E4M4 absmax (default)
35043573
#define INSTANTIATE_KBIT_GEMM_PROD_U8(K) \
35053574
template void kbitGemmProd<K, half, unsigned char>( \

csrc/pythonInterface.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,21 @@ MAKE_KBIT_REPACK(3)
570570
MAKE_KBIT_REPACK(4)
571571
MAKE_KBIT_REPACK(5)
572572

573+
// Forward declaration of VQ repack launcher
template <int P>
void repackVQ(const unsigned int*, const unsigned char*, unsigned int*, unsigned char*, int, int, cudaStream_t);

// Generates a non-template wrapper (repack_vq_p2 / repack_vq_p4) forwarding to
// the templated launcher, so each p-specialization has a plain function name.
#define MAKE_VQ_REPACK(P) \
    void repack_vq_p##P( \
        const unsigned int* packed_flat, const unsigned char* absmax_flat, unsigned int* packed_tiled, \
        unsigned char* absmax_tiled, int K_dim, int N, cudaStream_t stream \
    ) { \
        repackVQ<P>(packed_flat, absmax_flat, packed_tiled, absmax_tiled, K_dim, N, stream); \
    }

MAKE_VQ_REPACK(2)
MAKE_VQ_REPACK(4)
587+
573588
// Forward declarations of GEMM launchers
574589
template <int K, typename scalar_t, typename ABSMAX_T>
575590
void kbitGemmProd(
@@ -1519,6 +1534,18 @@ MAKE_CKBIT_REPACK(3)
15191534
MAKE_CKBIT_REPACK(4)
15201535
MAKE_CKBIT_REPACK(5)
15211536

1537+
// VQ repack extern C wrappers
// Generates crepack_vq_p2 / crepack_vq_p4 — the exact symbol names the Python
// backend resolves via getattr(lib, f"crepack_vq_p{p}") — each forwarding to
// the C++ wrapper of the same VQ dimension.
#define MAKE_CREPACK_VQ(P) \
    void crepack_vq_p##P( \
        const unsigned int* packed_flat, const unsigned char* absmax_flat, unsigned int* packed_tiled, \
        unsigned char* absmax_tiled, int K_dim, int N, cudaStream_t stream \
    ) { \
        repack_vq_p##P(packed_flat, absmax_flat, packed_tiled, absmax_tiled, K_dim, N, stream); \
    }

MAKE_CREPACK_VQ(2)
MAKE_CREPACK_VQ(4)
1548+
15221549
// fp16 absmax - all output types
15231550
MAKE_CKBIT_DEQUANT(fp16, half, fp16abs, half, 2)
15241551
MAKE_CKBIT_DEQUANT(fp16, half, fp16abs, half, 3)

0 commit comments

Comments
 (0)