
Commit 38e8642

TimDettmers and claude committed
Add k-bit quantization kernels (K=2-5, blocksize=32) -- WIP
Implements Stages 0-5 of the k-bit quantization plan from cuda-spec.md:

- Pure Python reference (quantize_kbit_ref, dequantize_kbit_ref) with 57 passing tests
- CUDA kernels using __ballot_sync bit-plane packing and __shfl_sync codebook lookup
- Test kernels (pack/unpack, memory format, codebook lookup) and production kernels
- All C interface symbols exported and loadable via ctypes

CUDA kernels compile but are not yet executable due to an RDC device linking issue where template instantiations in kernels.cu are not pulled into the final fatbinary. See KBIT_PROGRESS.md for diagnosis and recommended fix (move kernel bodies into ops.cu or a new self-contained file).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5ea3c89 commit 38e8642

File tree

6 files changed: +1278 −1 lines changed


KBIT_PROGRESS.md

Lines changed: 94 additions & 0 deletions
# K-Bit Quantization Implementation Progress

**Branch**: `feature/kbit-quantization` (worktree at `~/git/bitsandbytes-kbit`)
**Spec files**: `cuda-spec.md`, `cuda-spec-additions.md` (in main repo, gitignored)

## Completed

### Stage 0: Pure Python Reference -- DONE

- File: `tests/test_kbit_quantization.py`
- Functions: `create_normal_float_codebook()`, `quantize_kbit_ref()`, `dequantize_kbit_ref()`, `pack_kbit_ref()`, `unpack_kbit_ref()`
- 57 tests pass (codebook generation, round-trip, MSE ordering, error bounds, pack/unpack)
- Serves as the permanent ground truth for all CUDA validation
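The reference pipeline can be sketched in a few lines. This is an illustrative stand-in (not the actual test-file code): it mirrors the per-32-block absmax scaling and nearest-codebook search, but uses a simple uniform codebook rather than the normal-float one that `create_normal_float_codebook()` produces.

```python
import numpy as np

def quantize_kbit_sketch(x, codebook, blocksize=32):
    """Per-block absmax scaling + nearest-codebook index (illustrative)."""
    n = len(x)
    nblocks = (n + blocksize - 1) // blocksize
    absmax = np.empty(nblocks, dtype=np.float32)
    idx = np.empty(n, dtype=np.uint8)
    for b in range(nblocks):
        blk = x[b * blocksize:(b + 1) * blocksize]
        amax = np.abs(blk).max()
        absmax[b] = amax
        normalized = blk / max(amax, 1e-8)  # normalize block to [-1, 1]
        # nearest codebook entry for each element
        idx[b * blocksize:b * blocksize + len(blk)] = np.abs(
            normalized[:, None] - codebook[None, :]).argmin(axis=1)
    return idx, absmax

def dequantize_kbit_sketch(idx, absmax, codebook, blocksize=32):
    """Codebook lookup, then rescale each block by its stored absmax."""
    out = codebook[idx].astype(np.float32)
    for b in range(len(absmax)):
        out[b * blocksize:(b + 1) * blocksize] *= absmax[b]
    return out

# K=4 round trip with a uniform stand-in codebook
K = 4
codebook = np.linspace(-1.0, 1.0, 2 ** K).astype(np.float32)
x = np.random.default_rng(0).standard_normal(100).astype(np.float32)
idx, absmax = quantize_kbit_sketch(x, codebook)
xh = dequantize_kbit_sketch(idx, absmax, codebook)
```

With a uniform codebook the per-element error is bounded by half the codebook spacing times the block absmax; the real tests check analogous round-trip and error-bound properties against the normal-float codebook.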
### Stages 1-5: CUDA Kernels -- CODE WRITTEN, BUILD ISSUE

All CUDA kernel code is written and compiles, but a **device linker issue** prevents the kernels from appearing in the final `.so`.
#### Files modified

1. **`csrc/kernels.cu`** (appended at end, ~200 lines):
   - `warp_reduce_absmax()` -- device helper for warp-level max reduction
   - `pack_kbit_warp<K>()` -- device helper, `__ballot_sync` bit-plane packing
   - `unpack_kbit_warp<K>()` -- device helper, bit-extraction unpacking
   - `kTestPackUnpack_kbit<K>` -- Stage 1 test kernel (in-warp round-trip)
   - `kTestPackWrite_kbit<K>` -- Stage 2 test kernel (pack to global memory)
   - `kTestReadUnpack_kbit<K>` -- Stage 2 test kernel (read from global memory)
   - `kTestCodebookLookup_kbit<K>` -- Stage 3 test kernel (`__shfl_sync` codebook)
   - `kQuantizeBlockwise_kbit<T, K>` -- Stage 4 production quantize kernel
   - `kDequantizeBlockwise_kbit<T, K>` -- Stage 5 production dequantize kernel
   - Template instantiation macros for K=2,3,4,5 x T=half,bf16,float
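The `__ballot_sync` bit-plane format these helpers produce can be emulated on the host. A standalone sketch (hypothetical helper names, not code from the commit) that packs 32 K-bit values into K uint32 words, one bit plane per word:

```python
def pack_kbit_planes(indices, K):
    """Pack 32 K-bit values into K 32-bit bit-plane words.

    Word `bit` collects bit `bit` of every lane's value, with lane i at bit
    position i -- the same layout __ballot_sync produces on the device.
    """
    assert len(indices) == 32
    words = []
    for bit in range(K):
        w = 0
        for lane, v in enumerate(indices):
            w |= ((v >> bit) & 1) << lane
        words.append(w)
    return words

def unpack_kbit_planes(words, K):
    """Recover the 32 K-bit values from K bit-plane words."""
    out = [0] * 32
    for bit in range(K):
        for lane in range(32):
            out[lane] |= ((words[bit] >> lane) & 1) << bit
    return out

# Deterministic round-trip check for K=3
K = 3
vals = [(lane * 5) % (2 ** K) for lane in range(32)]
packed = pack_kbit_planes(vals, K)
assert unpack_kbit_planes(packed, K) == vals
```

Note the storage economy: 32 values take exactly K words regardless of K, so a 3-bit format needs no byte-boundary padding.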
2. **`csrc/kernels.cuh`** (appended before `#endif`):
   - Forward declarations of all kernel templates

3. **`csrc/ops.cu`** (appended at end, ~100 lines):
   - Launch wrappers: `test_pack_unpack_kbit<K>()`, `test_pack_write_kbit<K>()`, etc.
   - Launch wrappers: `quantizeBlockwise_kbit<T,K>()`, `dequantizeBlockwise_kbit<T,K>()`
   - Grid calculation: `ceil(ceil(n/32) / 8)` CUDA blocks, 256 threads (8 warps) per block
   - Template instantiation macros
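The grid calculation in item 3 can be made concrete with a small sketch (hypothetical helper, assuming one warp per 32-element quantization block and 8 warps of 32 threads per CUDA block, as described above):

```python
def kbit_grid(n, blocksize=32, warps_per_block=8):
    """CUDA launch configuration for the k-bit kernels.

    One warp handles one quantization block, so we need ceil(n / 32) warps,
    grouped 8 to a CUDA block of 256 threads.
    """
    num_qblocks = (n + blocksize - 1) // blocksize                      # ceil(n / 32)
    num_cuda_blocks = (num_qblocks + warps_per_block - 1) // warps_per_block
    return num_cuda_blocks, warps_per_block * 32                        # (grid, threads)
```

For example, 256 elements need 8 warps, which fit in a single 256-thread block; 257 elements spill into a second block.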
4. **`csrc/pythonInterface.cpp`** (two sections added):
   - Unmangled wrappers (inside `#if BUILD_CUDA || BUILD_HIP`): `test_pack_unpack_k{K}()`, `quantize_kbit_{fp16,bf16,fp32}_k{K}()`, etc.
   - `extern "C"` wrappers: `ctest_pack_unpack_k{K}()`, `cquantize_kbit_{tname}_k{K}()`, `cdequantize_kbit_{tname}_k{K}()`, etc.
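Since the exported symbols follow a regular naming scheme, the ctypes side can construct names programmatically. A sketch with a hypothetical helper -- the symbol patterns come from the list above, and the actual `ctypes.CDLL` load is shown only as a comment because it requires the built `.so`:

```python
def kbit_symbol(op, tname=None, K=4):
    """Build the extern "C" symbol name for a k-bit entry point.

    Test kernels:       c{op}_k{K}, e.g. ctest_pack_unpack_k3
    Production kernels: c{op}_kbit_{tname}_k{K} with tname in {fp16, bf16, fp32}
    """
    if tname is None:
        return f"c{op}_k{K}"
    return f"c{op}_kbit_{tname}_k{K}"

# With the compiled library, resolution would look roughly like
# (path and usage are assumptions, not verified against the build):
#   import ctypes
#   lib = ctypes.CDLL("bitsandbytes/libbitsandbytes_cuda124.so")
#   fn = getattr(lib, kbit_symbol("quantize", "fp16", K=4))
```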
5. **`tests/test_kbit_quantization.py`** (comprehensive test file):
   - Python reference tests (Stage 0): `TestCodebook`, `TestQuantizeRef`, `TestPackUnpackRef`
   - CUDA ctypes wrappers: `_cuda_test_pack_unpack()`, `_cuda_quantize_kbit()`, `_cuda_dequantize_kbit()`, etc.
   - CUDA tests (Stages 1-5): `TestStage1PackUnpackCUDA`, `TestStage2PackMemoryCUDA`, `TestStage3CodebookLookupCUDA`, `TestStage4QuantizeCUDA`, `TestStage5DequantizeCUDA`
## Current Blocker: RDC Device Linking

### Problem

The compiled kernels exist in the `.o` object files (verified via `nm`), and the C-level symbols are exported in the final `.so` (verified via `nm -D`), but the **CUDA device code** (fatbinary) does not contain the new kernel functions. Launching any kernel fails with "invalid device function".

### Root Cause

The project uses `-rdc=true` (relocatable device code) for separate compilation. The device link step (`cmake_device_link.o`) needs to resolve all device-side references. The template instantiations in `kernels.cu` produce weak symbols in the object file, but the device linker may not be pulling them in because they are not referenced from the device link compilation unit.

### How to Fix (options)

1. **Add `__global__` function declarations to the device link file**: Check how CMake generates the device link step and ensure it sees all `.cu` object files.
2. **Use `--relocatable-device-code=false` for the kbit kernels**: If the kbit kernels don't need cross-file device calls, they could be compiled without RDC, but this requires CMake changes.
3. **Move kernel definitions into the same file as the launch wrappers**: Instead of splitting between `kernels.cu` (kernel definitions) and `ops.cu` (launch wrappers), put everything in a single `.cu` file. This is the simplest fix -- add the kernel bodies directly to `ops.cu`, or create a new `kbit_kernels.cu` containing both the kernels and the launch wrappers.
4. **Check CMakeLists.txt for the device link configuration**: The CMake `CUDA_SEPARABLE_COMPILATION` property or `CUDA_RESOLVE_DEVICE_SYMBOLS` might need adjustment.

**Recommended fix**: Option 3 -- move all kbit kernel code from `kernels.cu` into `ops.cu` (or a new self-contained file). This sidesteps the RDC linking issue entirely, since the kernel and its launch site end up in the same compilation unit.
## Build Instructions

```bash
cd ~/git/bitsandbytes-kbit
cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="89;90" -S . -B build
make -C build -j$(nproc)
ln -sf libbitsandbytes_cuda124.so bitsandbytes/libbitsandbytes_cuda128.so
```
## Test Instructions

```bash
# Python-only tests (all pass)
python -m pytest tests/test_kbit_quantization.py -k "not CUDA" -v

# CUDA tests (currently fail due to the device link issue)
python -m pytest tests/test_kbit_quantization.py -k "CUDA" -v
```
## Not Yet Implemented

- Stages 6-8: error analysis, NF4 cross-validation, performance benchmarking (test code not written)
- Python API in `bitsandbytes/functional.py` (`quantize_kbit`, `dequantize_kbit`)
- `torch.library` registration in `bitsandbytes/_ops.py`
- Codebook caching/registration system

csrc/kernels.cu

Lines changed: 294 additions & 0 deletions
Appended after the existing `MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, ...)` instantiations (hunk `@@ -2601,3 +2601,297 @@`):

```cuda
// ===========================================================================
// K-bit blockwise quantization/dequantization kernels (blocksize=32, K=2..5)
//
// Uses bit-plane packing via __ballot_sync and codebook lookup via __shfl_sync.
// One warp (32 threads) per quantization block. 8 warps per CUDA block.
// ===========================================================================

// ---- Device helpers ----

// Warp-level max reduction (32 threads). Returns the max broadcast to all lanes.
__device__ __forceinline__ float warp_reduce_absmax(float val) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    return __shfl_sync(0xFFFFFFFF, val, 0);
}

// Pack one K-bit value per lane into K bit-plane uint32 words via __ballot_sync.
// packed_words[0..K-1] are written with the bit-plane representation.
// All lanes in the warp must call this simultaneously.
template <int K>
__device__ __forceinline__ void pack_kbit_warp(unsigned char qval, unsigned int* packed_words) {
#pragma unroll
    for (int bit = 0; bit < K; bit++)
        packed_words[bit] = __ballot_sync(0xFFFFFFFF, (qval >> bit) & 1);
}

// Unpack one K-bit value for this lane from K bit-plane uint32 words.
template <int K>
__device__ __forceinline__ unsigned char unpack_kbit_warp(const unsigned int* packed_words, int lane_id) {
    unsigned char val = 0;
#pragma unroll
    for (int bit = 0; bit < K; bit++)
        val |= ((packed_words[bit] >> lane_id) & 1) << bit;
    return val;
}

// ---- Stage 1: Pack/unpack round-trip test kernel ----
// Input: uint8 indices[n]. Output: uint8 recovered[n].
template <int K>
__global__ void kTestPackUnpack_kbit(
    const unsigned char* __restrict__ indices,
    unsigned char* __restrict__ recovered,
    const int n
) {
    const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    const int lane_id = threadIdx.x % 32;
    const int block_start = warp_id * 32;

    if (block_start >= n) return;

    // Load index (with bounds guard for partial last block)
    unsigned char qval = 0;
    if (block_start + lane_id < n)
        qval = indices[block_start + lane_id];

    // Pack into bit planes
    unsigned int packed[K];
    pack_kbit_warp<K>(qval, packed);

    // Unpack
    unsigned char recovered_val = unpack_kbit_warp<K>(packed, lane_id);

    // Store
    if (block_start + lane_id < n)
        recovered[block_start + lane_id] = recovered_val;
}

// ---- Stage 2: Pack-write and read-unpack test kernels ----

// Pack indices and write bit-plane words to global memory
template <int K>
__global__ void kTestPackWrite_kbit(
    const unsigned char* __restrict__ indices,
    unsigned int* __restrict__ packed_out,
    const int n
) {
    const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    const int lane_id = threadIdx.x % 32;
    const int block_start = warp_id * 32;

    if (block_start >= n) return;

    unsigned char qval = 0;
    if (block_start + lane_id < n)
        qval = indices[block_start + lane_id];

    unsigned int packed[K];
    pack_kbit_warp<K>(qval, packed);

    // Lanes 0..K-1 each write one word
    if (lane_id < K)
        packed_out[warp_id * K + lane_id] = packed[lane_id];
}

// Read bit-plane words from global memory and unpack to indices
template <int K>
__global__ void kTestReadUnpack_kbit(
    const unsigned int* __restrict__ packed_in,
    unsigned char* __restrict__ indices_out,
    const int n
) {
    const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    const int lane_id = threadIdx.x % 32;
    const int block_start = warp_id * 32;

    if (block_start >= n) return;

    // Load K words, broadcast to all lanes
    unsigned int packed[K];
#pragma unroll
    for (int bit = 0; bit < K; bit++) {
        unsigned int word = 0;
        if (lane_id == bit)
            word = packed_in[warp_id * K + bit];
        packed[bit] = __shfl_sync(0xFFFFFFFF, word, bit);
    }

    unsigned char val = unpack_kbit_warp<K>(packed, lane_id);

    if (block_start + lane_id < n)
        indices_out[block_start + lane_id] = val;
}

// ---- Stage 3: Codebook shuffle lookup test kernel ----

template <int K>
__global__ void kTestCodebookLookup_kbit(
    const unsigned char* __restrict__ indices,
    const float* __restrict__ codebook,
    float* __restrict__ out,
    const int n
) {
    const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    const int lane_id = threadIdx.x % 32;
    const int block_start = warp_id * 32;

    if (block_start >= n) return;

    // Load codebook into warp lanes
    float cb = (lane_id < (1 << K)) ? codebook[lane_id] : 0.0f;

    // Load index
    unsigned char idx = 0;
    if (block_start + lane_id < n)
        idx = indices[block_start + lane_id];

    // Shuffle lookup
    float val = __shfl_sync(0xFFFFFFFF, cb, idx);

    if (block_start + lane_id < n)
        out[block_start + lane_id] = val;
}

// ---- Stage 4: Full quantize kernel ----

template <typename T, int K>
__global__ void kQuantizeBlockwise_kbit(
    const float* __restrict__ codebook,
    const T* __restrict__ A,
    float* __restrict__ absmax,
    unsigned int* __restrict__ packed_out,
    const int n
) {
    const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    const int lane_id = threadIdx.x % 32;
    const int block_start = warp_id * 32;

    if (block_start >= n) return;

    // 1. Load input value
    float val = 0.0f;
    if (block_start + lane_id < n)
        val = (float)A[block_start + lane_id];

    // 2. Warp-level absmax reduction
    float amax = warp_reduce_absmax(fabsf(val));
    float amax_safe = fmaxf(amax, 1e-8f);

    // 3. Lane 0 stores absmax
    if (lane_id == 0)
        absmax[warp_id] = amax;

    // 4. Normalize to [-1, 1]
    float normalized = val / amax_safe;

    // 5. Load codebook into warp lanes
    float cb = (lane_id < (1 << K)) ? codebook[lane_id] : 0.0f;

    // 6. Branchless nearest-codebook search
    unsigned char best_idx = 0;
    float best_dist = 1e10f;
#pragma unroll
    for (int i = 0; i < (1 << K); i++) {
        float cb_val = __shfl_sync(0xFFFFFFFF, cb, i);
        float dist = fabsf(normalized - cb_val);
        bool closer = (dist < best_dist);
        best_dist = closer ? dist : best_dist;
        best_idx = closer ? (unsigned char)i : best_idx;
    }

    // 7. Pack into bit planes
    unsigned int packed[K];
    pack_kbit_warp<K>(best_idx, packed);

    // 8. Write K packed words
    if (lane_id < K)
        packed_out[warp_id * K + lane_id] = packed[lane_id];
}

// ---- Stage 5: Full dequantize kernel ----

template <typename T, int K>
__global__ void kDequantizeBlockwise_kbit(
    const unsigned int* __restrict__ packed_in,
    const float* __restrict__ codebook,
    const float* __restrict__ absmax,
    T* __restrict__ out,
    const int n
) {
    const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    const int lane_id = threadIdx.x % 32;
    const int block_start = warp_id * 32;

    if (block_start >= n) return;

    // 1. Load codebook into warp lanes
    float cb = (lane_id < (1 << K)) ? codebook[lane_id] : 0.0f;

    // 2. Load absmax for this block
    float amax = absmax[warp_id];

    // 3. Load K packed words, broadcast to all lanes
    unsigned int packed[K];
#pragma unroll
    for (int bit = 0; bit < K; bit++) {
        unsigned int word = 0;
        if (lane_id == bit)
            word = packed_in[warp_id * K + bit];
        packed[bit] = __shfl_sync(0xFFFFFFFF, word, bit);
    }

    // 4. Unpack this thread's K-bit index
    unsigned char idx = unpack_kbit_warp<K>(packed, lane_id);

    // 5. Codebook lookup via shuffle
    float val = __shfl_sync(0xFFFFFFFF, cb, idx);

    // 6. Scale by absmax
    val *= amax;

    // 7. Store
    if (block_start + lane_id < n)
        out[block_start + lane_id] = (T)val;
}

// ---- Template instantiations ----

// Test kernels (Stages 1-3)
#define INSTANTIATE_TEST_KBIT(K) \
    template __global__ void kTestPackUnpack_kbit<K>( \
        const unsigned char*, unsigned char*, const int); \
    template __global__ void kTestPackWrite_kbit<K>( \
        const unsigned char*, unsigned int*, const int); \
    template __global__ void kTestReadUnpack_kbit<K>( \
        const unsigned int*, unsigned char*, const int); \
    template __global__ void kTestCodebookLookup_kbit<K>( \
        const unsigned char*, const float*, float*, const int);

INSTANTIATE_TEST_KBIT(2)
INSTANTIATE_TEST_KBIT(3)
INSTANTIATE_TEST_KBIT(4)
INSTANTIATE_TEST_KBIT(5)

// Production kernels (Stages 4-5)
#define INSTANTIATE_KBIT_QUANT(T, K) \
    template __global__ void kQuantizeBlockwise_kbit<T, K>( \
        const float*, const T*, float*, unsigned int*, const int); \
    template __global__ void kDequantizeBlockwise_kbit<T, K>( \
        const unsigned int*, const float*, const float*, T*, const int);

INSTANTIATE_KBIT_QUANT(half, 2)
INSTANTIATE_KBIT_QUANT(half, 3)
INSTANTIATE_KBIT_QUANT(half, 4)
INSTANTIATE_KBIT_QUANT(half, 5)
INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 2)
INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 3)
INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 4)
INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 5)
INSTANTIATE_KBIT_QUANT(float, 2)
INSTANTIATE_KBIT_QUANT(float, 3)
INSTANTIATE_KBIT_QUANT(float, 4)
INSTANTIATE_KBIT_QUANT(float, 5)
```
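For intuition about the storage format these kernels produce (K uint32 bit-plane words plus one fp32 absmax per 32-value quantization block), a host-side size calculation -- an illustrative sketch, not part of the commit:

```python
def kbit_storage_bytes(n, K, blocksize=32):
    """Packed size: each 32-value block stores K uint32 words + 1 fp32 absmax."""
    nblocks = (n + blocksize - 1) // blocksize
    packed = nblocks * K * 4   # K uint32 bit-plane words per block
    absmax = nblocks * 4       # one float32 scale per block
    return packed + absmax

# 4-bit, n=1024: 32 blocks -> 32*4*4 + 32*4 = 512 + 128 = 640 bytes
```

For large n this works out to K + 1 bits per value: K bits of payload plus 32 bits of absmax amortized over 32 values.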
