
Commit cbb157d

TimDettmers and claude committed
Fix RDC device linking: move kernels to ops.cu, all 157 tests pass
The "invalid device function" error was caused by mismatched kernel declarations in kernels.cuh (without __restrict__) vs definitions in ops.cu (with __restrict__). With CUDA separable compilation (-rdc=true), this created conflicting host stubs in the function registration.

Fix: remove forward declarations from kernels.cuh, keep kernel definitions and launch wrappers together in ops.cu. Also added CUDA_RESOLVE_DEVICE_SYMBOLS ON to CMakeLists.txt.

All 157 tests now pass: Stage 0 (Python ref), Stages 1-3 (CUDA test kernels), Stage 4 (quantize), Stage 5 (dequantize) -- covering K=2-5, fp16/bf16/fp32, various tensor sizes, and analytical error bounds.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 38e8642 commit cbb157d


5 files changed, +238 -375 lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -312,6 +312,7 @@ if(BUILD_CUDA)
 set_target_properties(bitsandbytes
 PROPERTIES
 CUDA_SEPARABLE_COMPILATION ON
+CUDA_RESOLVE_DEVICE_SYMBOLS ON
 )
 endif()
 if(BUILD_HIP)

KBIT_PROGRESS.md

Lines changed: 59 additions & 65 deletions
@@ -1,94 +1,88 @@
 # K-Bit Quantization Implementation Progress

 **Branch**: `feature/kbit-quantization` (worktree at `~/git/bitsandbytes-kbit`)
-**Spec files**: `cuda-spec.md`, `cuda-spec-additions.md` (in main repo, gitignored)
+**Spec files**: `cuda-spec.md`, `cuda-spec-additions.md` (in main repo root, gitignored)

-## Completed
+## Status: Stages 0-5 COMPLETE, 157/157 tests passing

-### Stage 0: Pure Python Reference -- DONE
-- File: `tests/test_kbit_quantization.py`
-- Functions: `create_normal_float_codebook()`, `quantize_kbit_ref()`, `dequantize_kbit_ref()`, `pack_kbit_ref()`, `unpack_kbit_ref()`
-- 57 tests pass (codebook generation, round-trip, MSE ordering, error bounds, pack/unpack)
-- Serves as permanent ground truth for all CUDA validation
+All CUDA kernels are working. The full quantize/dequantize pipeline runs on GPU, validated against the Python reference.

-### Stages 1-5: CUDA Kernels -- CODE WRITTEN, BUILD ISSUE
+## What's Done

-All CUDA kernel code is written and compiles, but there's a **device linker issue** preventing the kernels from appearing in the final `.so`.
+### Stage 0: Pure Python Reference
+- File: `tests/test_kbit_quantization.py` (top half)
+- `create_normal_float_codebook(k)` -- generates 2^k NF codebook from N(0,1) quantiles
+- `quantize_kbit_ref(A, codebook)` -- pure PyTorch blockwise quantize (blocksize=32)
+- `dequantize_kbit_ref(indices, absmax, codebook)` -- pure PyTorch dequantize
+- `pack_kbit_ref(indices, k)` / `unpack_kbit_ref(packed, k, n)` -- bit-plane packing reference
+- Tests: `TestCodebook`, `TestQuantizeRef`, `TestPackUnpackRef`
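Reviewer note: the Stage 0 reference is described here only by its function names, so the following is a minimal pure-Python sketch of the idea, not the code in `tests/test_kbit_quantization.py`. It assumes the codebook is built from evenly spaced N(0,1) quantiles rescaled to [-1, 1], and that quantization divides each 32-element block by its absmax and picks the nearest codebook entry. The names `nf_codebook_sketch`, `quantize_ref_sketch`, and `dequantize_ref_sketch` are hypothetical, and the real reference works on PyTorch tensors rather than lists.

```python
from statistics import NormalDist

BLOCKSIZE = 32  # block size stated in the progress notes


def nf_codebook_sketch(k):
    """Hypothetical stand-in for create_normal_float_codebook(k).

    Assumption: 2**k evenly spaced N(0,1) quantiles, rescaled to [-1, 1].
    """
    n = 2 ** k
    normal = NormalDist()
    # Midpoint probabilities avoid the infinite quantiles at p=0 and p=1.
    q = [normal.inv_cdf((i + 0.5) / n) for i in range(n)]
    scale = max(abs(v) for v in q)
    return [v / scale for v in q]


def quantize_ref_sketch(values, codebook):
    """Blockwise absmax quantize: returns (indices, one absmax per block)."""
    indices, absmax = [], []
    for start in range(0, len(values), BLOCKSIZE):
        block = values[start:start + BLOCKSIZE]
        amax = max(abs(v) for v in block)
        absmax.append(amax)
        for v in block:
            x = v / amax if amax > 0 else 0.0
            # Nearest codebook entry by absolute distance.
            indices.append(
                min(range(len(codebook)), key=lambda i: abs(codebook[i] - x))
            )
    return indices, absmax


def dequantize_ref_sketch(indices, absmax, codebook):
    """Look up each index and rescale by the absmax of its block."""
    return [codebook[idx] * absmax[i // BLOCKSIZE] for i, idx in enumerate(indices)]
```

Under these assumptions the reconstruction error of each element is at most its block's absmax times half the largest gap between adjacent codebook entries, which is one way an analytical error bound could be stated.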

-#### Files modified:
+### Stages 1-3: CUDA Test Kernels (temporary scaffolding)
+- `kTestPackUnpack_kbit<K>` -- in-warp __ballot_sync pack / bit-extract unpack round-trip
+- `kTestPackWrite_kbit<K>` / `kTestReadUnpack_kbit<K>` -- persistent memory format
+- `kTestCodebookLookup_kbit<K>` -- __shfl_sync codebook lookup
+- Tests: `TestStage1PackUnpackCUDA`, `TestStage2PackMemoryCUDA`, `TestStage3CodebookLookupCUDA`
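Reviewer note: bit-plane packing is easier to follow with a host-side model. The sketch below emulates in plain Python what a warp-wide `__ballot_sync` on bit `b` of each lane's index would produce: one 32-bit word per bit plane, so 32 K-bit indices occupy K words. The function names and the plane-major word order are assumptions for illustration; the commit does not show the actual memory layout.

```python
WARP = 32  # lanes per warp, matching the 32-element block size in the notes


def pack_bitplanes(indices, k):
    """Pack 32 k-bit indices into k 32-bit words, one word per bit plane.

    Bit `lane` of word b is set when bit b of indices[lane] is set, which
    mirrors the result of a warp ballot on the predicate (idx >> b) & 1.
    """
    assert len(indices) == WARP
    words = []
    for b in range(k):
        word = 0
        for lane, idx in enumerate(indices):
            word |= ((idx >> b) & 1) << lane
        words.append(word)
    return words


def unpack_bitplanes(words, k):
    """Rebuild each lane's index by extracting its bit from every plane."""
    out = []
    for lane in range(WARP):
        idx = 0
        for b in range(k):
            idx |= ((words[b] >> lane) & 1) << b
        out.append(idx)
    return out
```

A round trip such as `unpack_bitplanes(pack_bitplanes(idx, 3), 3) == idx` is the property the Stage 1 test kernel is described as checking on the GPU.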

-1. **`csrc/kernels.cu`** (appended at end, ~200 lines):
-- `warp_reduce_absmax()` -- device helper for warp-level max reduction
-- `pack_kbit_warp<K>()` -- device helper, __ballot_sync bit-plane packing
-- `unpack_kbit_warp<K>()` -- device helper, bit extraction unpacking
-- `kTestPackUnpack_kbit<K>` -- Stage 1 test kernel (in-warp round-trip)
-- `kTestPackWrite_kbit<K>` -- Stage 2 test kernel (pack to global memory)
-- `kTestReadUnpack_kbit<K>` -- Stage 2 test kernel (read from global memory)
-- `kTestCodebookLookup_kbit<K>` -- Stage 3 test kernel (shfl_sync codebook)
-- `kQuantizeBlockwise_kbit<T, K>` -- Stage 4 production quantize kernel
-- `kDequantizeBlockwise_kbit<T, K>` -- Stage 5 production dequantize kernel
-- Template instantiation macros for K=2,3,4,5 x T=half,bf16,float
+### Stage 4: Full Quantize Kernel
+- `kQuantizeBlockwise_kbit<T, K>` -- warp-level absmax reduction, branchless codebook search, ballot_sync bit-plane packing
+- CUDA indices match Python reference exactly
+- Tests: `TestStage4QuantizeCUDA` (absmax correctness, indices match ref, all dtypes, various sizes)
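Reviewer note: the commit does not say how the "branchless codebook search" is implemented. One common way to find the nearest entry of a sorted codebook without data-dependent branches is to count how many midpoints between adjacent entries lie below the value; the sketch below shows that technique in Python. Treat both the function name and the choice of technique as assumptions rather than a description of `kQuantizeBlockwise_kbit`.

```python
def nearest_index_branchless(x, codebook):
    """Index of the nearest entry in a sorted codebook, with no if/else.

    Each comparison contributes 0 or 1, so the sum is the number of
    decision boundaries (midpoints) that x exceeds. Exact ties at a
    midpoint resolve to the lower index.
    """
    idx = 0
    for lo, hi in zip(codebook, codebook[1:]):
        idx += x > (lo + hi) / 2  # bool adds as 0 or 1
    return idx
```

With a fixed codebook size the loop has a constant trip count, so on a GPU it could be fully unrolled into a short chain of compare-and-add instructions.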

-2. **`csrc/kernels.cuh`** (appended before `#endif`):
-- Forward declarations of all kernel templates
+### Stage 5: Full Dequantize Kernel
+- `kDequantizeBlockwise_kbit<T, K>` -- bit-plane unpacking, shfl_sync codebook lookup, absmax scaling
+- Round-trip error within analytical bounds for all K
+- Tests: `TestStage5DequantizeCUDA` (matches ref, all dtypes, various sizes, error bounds)

-3. **`csrc/ops.cu`** (appended at end, ~100 lines):
-- Launch wrappers: `test_pack_unpack_kbit<K>()`, `test_pack_write_kbit<K>()`, etc.
-- Launch wrappers: `quantizeBlockwise_kbit<T,K>()`, `dequantizeBlockwise_kbit<T,K>()`
-- Grid calculation: `ceil(n/32)/8` CUDA blocks, 256 threads per block
-- Template instantiation macros
+## Files Modified (relative to main branch)

-4. **`csrc/pythonInterface.cpp`** (two sections added):
-- Unmangled wrappers (inside `#if BUILD_CUDA || BUILD_HIP`): `test_pack_unpack_k{K}()`, `quantize_kbit_{fp16,bf16,fp32}_k{K}()`, etc.
-- extern "C" wrappers: `ctest_pack_unpack_k{K}()`, `cquantize_kbit_{tname}_k{K}()`, `cdequantize_kbit_{tname}_k{K}()`, etc.
+| File | What changed |
+|------|-------------|
+| `csrc/ops.cu` | Kernel definitions + device helpers + launch wrappers (~280 lines appended) |
+| `csrc/kernels.cu` | Removed: just a comment pointing to ops.cu |
+| `csrc/kernels.cuh` | Removed stale forward declarations (was causing "invalid device function") |
+| `csrc/pythonInterface.cpp` | Unmangled wrappers + extern "C" exports for all kbit functions |
+| `CMakeLists.txt` | Added `CUDA_RESOLVE_DEVICE_SYMBOLS ON` |
+| `tests/test_kbit_quantization.py` | Full test file: Python ref + CUDA tests + ctypes wrappers |

-5. **`tests/test_kbit_quantization.py`** (comprehensive test file):
-- Python reference tests (Stage 0): `TestCodebook`, `TestQuantizeRef`, `TestPackUnpackRef`
-- CUDA ctypes wrappers: `_cuda_test_pack_unpack()`, `_cuda_quantize_kbit()`, `_cuda_dequantize_kbit()`, etc.
-- CUDA tests (Stages 1-5): `TestStage1PackUnpackCUDA`, `TestStage2PackMemoryCUDA`, `TestStage3CodebookLookupCUDA`, `TestStage4QuantizeCUDA`, `TestStage5DequantizeCUDA`
+### Key Architecture Decision During Implementation

-## Current Blocker: RDC Device Linking
+Kernel definitions MUST live in `ops.cu` (same file as launch wrappers), not in `kernels.cu`. The project uses CUDA separable compilation (`-rdc=true`), and having forward declarations in `kernels.cuh` (without `__restrict__`) alongside definitions in a different TU (with `__restrict__`) caused mismatched CUDA function registration. Keeping everything in one compilation unit avoids this entirely.

-### Problem
-The compiled kernels exist in the `.o` object files (verified via `nm`), and the C-level symbols are exported in the final `.so` (verified via `nm -D`), but the **CUDA device code** (fatbinary) does not contain the new kernel functions. Running any kernel gives "invalid device function".
+## C Interface (exported symbols)

-### Root Cause
-The project uses `-rdc=true` (relocatable device code) for separate compilation. The device link step (`cmake_device_link.o`) needs to resolve all device-side references. The template instantiations in `kernels.cu` produce weak symbols in the object file, but the device linker may not be pulling them in because they're not referenced from the device link compilation unit.
+Test kernels (prefix `ctest_`):
+- `ctest_pack_unpack_k{2,3,4,5}(indices, recovered, n)`
+- `ctest_pack_write_k{2,3,4,5}(indices, packed_out, n)`
+- `ctest_read_unpack_k{2,3,4,5}(packed_in, indices_out, n)`
+- `ctest_codebook_lookup_k{2,3,4,5}(indices, codebook, out, n)`

-### How to Fix (options)
+Production kernels:
+- `cquantize_kbit_{fp16,bf16,fp32}_k{2,3,4,5}(codebook, A, absmax, packed_out, n)`
+- `cdequantize_kbit_{fp16,bf16,fp32}_k{2,3,4,5}(packed_in, codebook, absmax, out, n, stream)`
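Reviewer note: the brace patterns in this list expand to a fixed set of exported names, and since the original failure was kernels missing from the built library, a cheap sanity check is to enumerate the expected names and compare them with what the `.so` exports. The helper below is a hypothetical sketch that only builds the name list from the patterns in this section; it does not load the library.

```python
from itertools import product

K_VALUES = (2, 3, 4, 5)
DTYPES = ("fp16", "bf16", "fp32")
TEST_KERNELS = ("pack_unpack", "pack_write", "read_unpack", "codebook_lookup")


def expected_symbols():
    """Expand the documented name patterns into concrete symbol names."""
    # Test kernels: ctest_<kernel>_k<K>
    names = [f"ctest_{t}_k{k}" for t, k in product(TEST_KERNELS, K_VALUES)]
    # Production kernels: c<op>_kbit_<dtype>_k<K>
    names += [
        f"c{op}_kbit_{d}_k{k}"
        for op, d, k in product(("quantize", "dequantize"), DTYPES, K_VALUES)
    ]
    return names
```

Each name could then be checked with `hasattr` on a `ctypes.CDLL` handle for the built library, or matched against the output of `nm -D`.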

-1. **Add `__global__` function declarations to the device link file**: Check how CMake generates the device link step and ensure it sees all `.cu` object files.
-
-2. **Use `--relocatable-device-code=false` for the kbit kernels**: If the kbit kernels don't need cross-file device calls, they could be compiled without RDC. But this requires CMake changes.
-
-3. **Move kernel definitions to the same file as the launch wrappers**: Instead of splitting between `kernels.cu` (kernel definitions) and `ops.cu` (launch wrappers), put everything in a single `.cu` file. This is the simplest fix -- add the kernel bodies directly to `ops.cu` or create a new `kbit_kernels.cu` that contains both kernels and launch wrappers.
-
-4. **Check CMakeLists.txt for device link configuration**: The CMake `CUDA_SEPARABLE_COMPILATION` property or `CUDA_RESOLVE_DEVICE_SYMBOLS` might need adjustment.
-
-**Recommended fix**: Option 3 -- move all kbit kernel code from `kernels.cu` into `ops.cu` (or a new self-contained file). This sidesteps the RDC linking issue entirely since the kernel and its launch site would be in the same compilation unit.
-
-## Build Instructions
+## Build & Test

 ```bash
 cd ~/git/bitsandbytes-kbit
 cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="89;90" -S . -B build
 make -C build -j$(nproc)
 ln -sf libbitsandbytes_cuda124.so bitsandbytes/libbitsandbytes_cuda128.so
+python -m pytest tests/test_kbit_quantization.py -p no:randomly -v # 157 pass
 ```

-## Test Instructions
+## Not Yet Implemented

-```bash
-# Python-only tests (all pass)
-python -m pytest tests/test_kbit_quantization.py -k "not CUDA" -v
+### Stages 6-8 (test scripts only, no new kernels needed)
+- **Stage 6**: Round-trip error analysis (analytical bounds, empirical MSE on large tensors)
+- **Stage 7**: Cross-validate K=4 against existing NF4 dequant
+- **Stage 8**: Performance benchmarking (measure HBM bandwidth utilization, target 60-80%)

-# CUDA tests (currently fail due to device link issue)
-python -m pytest tests/test_kbit_quantization.py -k "CUDA" -v
-```
-
-## Not Yet Implemented
+### Python API
+- `bitsandbytes/functional.py`: `quantize_kbit()` and `dequantize_kbit()` public functions
+- `bitsandbytes/_ops.py`: `torch.library` registration
+- Codebook caching/registration system (precomputed NF codebooks for K=2..5)

-- Stages 6-8: Error analysis, NF4 cross-validation, performance benchmarking (test code not written)
-- Python API in `bitsandbytes/functional.py` (quantize_kbit, dequantize_kbit)
-- `torch.library` registration in `bitsandbytes/_ops.py`
-- Codebook caching/registration system
+### Cleanup
+- Remove temporary test kernels (Stages 1-3) after confirming Stages 4+5 are solid
+- Remove `ctest_*` exports from pythonInterface.cpp
+- Update KBIT_PROGRESS.md or remove it
