Skip to content

Commit c25d7af

Browse files
TimDettmersclaude
andcommitted
Merge feature/qutlass-nvfp4-gemm into QLORA-2
Combines the latest inference kernel work from qutlass (NVFP4, VQ, MoE optimizations for SM100/B200) with the QLORA-2 training infrastructure (autograd, LoRA, weight streaming, pipeline parallelism). Resolved conflicts by: - Using qutlass versions as base for most files (newer kernel work) - Appending QLORA-2 training code (kernels, bindings, classes) to preserve training capabilities - Removing obsolete code that qutlass intentionally cleaned up Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2 parents eea8518 + 1895945 commit c25d7af

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+24777
-7336
lines changed

CLAUDE.md

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,36 @@ Do NOT run the full test suite — it takes 10+ minutes. Instead, run only the t
4040
pytest tests/test_relevant_file.py -v --tb=short -k "relevant_test_name"
4141
```
4242

43-
The full suite will be run separately. Best practices and known issues: `agents/testing_guide.md`
43+
The full suite will be run separately. Best practices, benchmark data, and known architecture-specific issues: `agents/testing_guide.md`
44+
45+
# Benchmarking
46+
47+
Benchmark scripts live in `benchmarks/`. The two kbit-specific ones:
48+
49+
- `bench_hadamard.py` — Hadamard rotation kernel + M=1 pipeline (rotation + scalar GEMV) vs cuBLAS FP16. Quick focused benchmark for the decode path.
50+
- `bench_kbit_vlm.py` — Comprehensive sweep across all VLM-relevant M values (1 to 1024), all kernel variants (scalar GEMV, MMA, dequant+cuBLAS), all k values (2-5), with and without Hadamard rotation. GLM-4.7 shapes (see `spec.md` § Target Model for layer dimensions).
51+
52+
```bash
53+
# Quick M=1 decode benchmark
54+
python benchmarks/bench_hadamard.py
55+
56+
# Full VLM sweep (all M, all k)
57+
python benchmarks/bench_kbit_vlm.py
58+
59+
# Single k value, subset of M
60+
python benchmarks/bench_kbit_vlm.py --k 4 --m 1,4,16,256,1024
61+
62+
# Higher accuracy (more iterations)
63+
python benchmarks/bench_kbit_vlm.py --inner 1000 --outer 30
64+
```
65+
66+
## CUDA graph benchmarking methodology
67+
68+
Single graph replay has a ~14 us timing floor (on RTX 4090) that masks sub-14 us kernel differences. The benchmarks use **batched graph replay**: replay the graph N times within one event-timed region, then divide. This amortizes the per-replay overhead to ~14/N us per iteration.
69+
70+
The `--inner` flag controls N (replays per measurement). Default 500 gives ~0.03 us amortized overhead. Use `--inner 1000` for the highest accuracy when comparing kernels that differ by < 1 us.
71+
72+
`--outer` controls the number of measurements (default 15). The median is reported to reject outliers.
4473

4574
# Agent Dispatch (the "Dispatcher" role)
4675

CMakeLists.txt

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,9 @@ if(BUILD_CUDA)
248248
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.8" AND EXISTS "${CMAKE_SOURCE_DIR}/third_party/cutlass/include")
249249
list(APPEND _NVFP4_SM120_SOURCES
250250
csrc/qutlass/gemm_nvfp4_sm120.cu
251-
csrc/qutlass/scale_reorder.cu
252-
csrc/qutlass/fused_quantize_nv.cu
253251
)
254252
set(_HAS_CUTLASS_NVFP4 TRUE)
255-
message(STATUS "CUTLASS NVFP4 SM_120a GEMM + fused quantize enabled")
253+
message(STATUS "CUTLASS NVFP4 SM_120a GEMM enabled")
256254
else()
257255
set(_HAS_CUTLASS_NVFP4 FALSE)
258256
message(STATUS "CUTLASS NVFP4 GEMM disabled (needs CUDA >= 12.8 and third_party/cutlass)")
@@ -286,6 +284,88 @@ if(BUILD_CUDA)
286284
message(STATUS "NVFP4 SM_120a GEMM kernel enabled")
287285
endif()
288286

287+
# SM_100a NVFP4 GEMM kernel: requires compute_100a for block-scaled MMA
288+
# Only include if 100 or 101 is in the target architectures
289+
set(_HAS_SM100 FALSE)
290+
foreach(_cap IN LISTS COMPUTE_CAPABILITY)
291+
if(_cap MATCHES "^10[01]$")
292+
set(_HAS_SM100 TRUE)
293+
endif()
294+
endforeach()
295+
if(_LATEST_CAPABILITY MATCHES "^10[01]$")
296+
set(_HAS_SM100 TRUE)
297+
endif()
298+
if(_HAS_SM100)
299+
# CUTLASS-based NVFP4 GEMM for SM_100 (requires CUDA 12.8+)
300+
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.8" AND EXISTS "${CMAKE_SOURCE_DIR}/third_party/cutlass/include")
301+
set(_NVFP4_SM100_SOURCES
302+
csrc/qutlass/gemm_nvfp4_sm100.cu
303+
csrc/qutlass/gemm_nvfp4_moe_sm100.cu
304+
)
305+
306+
add_library(nvfp4_sm100a OBJECT ${_NVFP4_SM100_SOURCES})
307+
set_target_properties(nvfp4_sm100a PROPERTIES
308+
CUDA_ARCHITECTURES "100a"
309+
POSITION_INDEPENDENT_CODE ON
310+
CUDA_SEPARABLE_COMPILATION OFF
311+
)
312+
target_compile_options(nvfp4_sm100a PRIVATE
313+
$<$<COMPILE_LANGUAGE:CUDA>:--use_fast_math>
314+
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
315+
$<$<COMPILE_LANGUAGE:CUDA>:-std=c++17>
316+
$<$<COMPILE_LANGUAGE:CUDA>:-O3>
317+
$<$<COMPILE_LANGUAGE:CUDA>:-DNDEBUG>
318+
$<$<COMPILE_LANGUAGE:CUDA>:-DQUTLASS_DISABLE_PYBIND>
319+
)
320+
target_include_directories(nvfp4_sm100a PRIVATE
321+
"${CMAKE_SOURCE_DIR}/third_party/cutlass/include"
322+
"${CMAKE_SOURCE_DIR}/third_party/cutlass/tools/util/include"
323+
"${CMAKE_SOURCE_DIR}/csrc/qutlass/include"
324+
)
325+
message(STATUS "CUTLASS NVFP4 SM_100a GEMM enabled")
326+
else()
327+
message(STATUS "CUTLASS NVFP4 SM_100a GEMM disabled (needs CUDA >= 12.8 and third_party/cutlass)")
328+
endif()
329+
endif()
330+
331+
# Common CUTLASS utilities (scale_reorder, fused_quantize) compiled for
332+
# ALL enabled Blackwell architectures. Kept in one library to avoid
333+
# duplicate C symbol errors from the extern "C" wrappers.
334+
if(_HAS_CUTLASS_NVFP4 OR _HAS_SM100)
335+
set(_NVFP4_COMMON_ARCHS "")
336+
if(_HAS_SM120 AND _HAS_CUTLASS_NVFP4)
337+
list(APPEND _NVFP4_COMMON_ARCHS "120a")
338+
endif()
339+
if(_HAS_SM100)
340+
list(APPEND _NVFP4_COMMON_ARCHS "100a")
341+
endif()
342+
343+
add_library(nvfp4_common OBJECT
344+
csrc/qutlass/scale_reorder.cu
345+
csrc/qutlass/fused_quantize_nv.cu
346+
csrc/qutlass/moe_scatter_gather.cu
347+
)
348+
set_target_properties(nvfp4_common PROPERTIES
349+
CUDA_ARCHITECTURES "${_NVFP4_COMMON_ARCHS}"
350+
POSITION_INDEPENDENT_CODE ON
351+
CUDA_SEPARABLE_COMPILATION OFF
352+
)
353+
target_include_directories(nvfp4_common PRIVATE
354+
"${CMAKE_SOURCE_DIR}/third_party/cutlass/include"
355+
"${CMAKE_SOURCE_DIR}/third_party/cutlass/tools/util/include"
356+
"${CMAKE_SOURCE_DIR}/csrc/qutlass/include"
357+
)
358+
target_compile_options(nvfp4_common PRIVATE
359+
$<$<COMPILE_LANGUAGE:CUDA>:--use_fast_math>
360+
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
361+
$<$<COMPILE_LANGUAGE:CUDA>:-std=c++17>
362+
$<$<COMPILE_LANGUAGE:CUDA>:-O3>
363+
$<$<COMPILE_LANGUAGE:CUDA>:-DNDEBUG>
364+
$<$<COMPILE_LANGUAGE:CUDA>:-DQUTLASS_DISABLE_PYBIND>
365+
)
366+
message(STATUS "CUTLASS common utilities (scale_reorder, fused_quantize) for archs: ${_NVFP4_COMMON_ARCHS}")
367+
endif()
368+
289369
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
290370
add_compile_definitions(BUILD_CUDA)
291371
elseif(BUILD_HIP)
@@ -379,6 +459,16 @@ if(TARGET nvfp4_sm120a)
379459
target_sources(bitsandbytes PRIVATE $<TARGET_OBJECTS:nvfp4_sm120a>)
380460
endif()
381461

462+
# Link NVFP4 SM_100a object library if available
463+
if(TARGET nvfp4_sm100a)
464+
target_sources(bitsandbytes PRIVATE $<TARGET_OBJECTS:nvfp4_sm100a>)
465+
endif()
466+
467+
# Link common CUTLASS utilities (scale_reorder, fused_quantize) if available
468+
if(TARGET nvfp4_common)
469+
target_sources(bitsandbytes PRIVATE $<TARGET_OBJECTS:nvfp4_common>)
470+
endif()
471+
382472
if (BUILD_CPU)
383473
if (OpenMP_CXX_FOUND)
384474
target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX)

PROGRESS.md

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Absmax format migration: float32 -> uint8 E4M4 (default) + float16 (option)
2+
3+
Branch: `experiment/scalar-gemv-int8-absmax`
4+
Worktree: `/home/tim/git/bnb-kbit-gemm-int8-absmax`
5+
Base: `23f92e5` (feature/kbit-gemv-v8)
6+
7+
## Motivation
8+
9+
Benchmarking shows uint8 E4M4 absmax has identical performance to float32
10+
absmax in the scalar GEMV kernel, and adds at most ~4.5% to mean absolute
11+
error (at k=5; negligible at k=2-3) on top of the existing kbit quantization
12+
error. Switching to uint8 halves absmax storage (4 bytes -> 1 byte per quant
13+
block) and unifies the format across all kernels.
14+
15+
## Current absmax formats (before this branch)
16+
17+
| Kernel | Absmax type | Layout |
18+
|---------------------|---------------|--------|
19+
| MMA (dense) | uint8 E4M4 | tiled |
20+
| MMA (grouped/MoE) | uint8 E4M4 | tiled |
21+
| Scalar GEMV (dense) | **float32** | flat |
22+
| Scalar GEMV (grouped/MoE) | **float32** | flat |
23+
| Dequantize | templated (both) | flat/tiled |
24+
25+
**Target**: all kernels use uint8 E4M4 by default, with float16 as alternative.
26+
Remove float32 absmax path entirely.
27+
28+
## Current status
29+
30+
### Code changes DONE (uncommitted, in working tree):
31+
32+
**CUDA kernels (`csrc/ops.cu`)**:
33+
- Moved E4M4 encode/decode functions before quantize kernel (eliminated forward declaration issue)
34+
- `kQuantizeBlockwise_kbit`: writes `unsigned char*` absmax via `encode_e4m4_absmax(amax)`
35+
- `kRepackKbit`: accepts `unsigned char*` absmax input, copies bytes directly (no re-encode)
36+
- `kbitScalarGemv` / `kbitGroupedScalarGemv`: `unsigned char*` absmax + `load_absmax()` decode
37+
- All launchers, entry points, and template instantiations updated
38+
39+
**C++ interface (`csrc/pythonInterface.cpp`)**:
40+
- All forward declarations, wrappers, and extern C macros updated for `unsigned char*`
41+
- Added extern C wrappers for fp16abs scalar GEMV + grouped scalar GEMV (16 new symbols)
42+
43+
**Python (`bitsandbytes/`)**:
44+
- `backends/cuda/ops.py`: quantize_kbit allocates uint8, repack_kbit expects uint8
45+
- `backends/cuda/ops.py`: scalar GEMV + grouped GEMV dispatch routes by absmax dtype (uint8 default, fp16 via `_fp16abs` suffix)
46+
- `_ops.py`: quantize_kbit fake op returns uint8
47+
- `functional.py`: removed redundant Python-side E4M4 encode (kernel does it natively)
48+
49+
**Tests**:
50+
- `test_scalar_gemv.py`: added `decode_e4m4_absmax`, updated `dequant_reference`
51+
- `test_kbit_gemm.py`: `quantize_kbit_ref` returns uint8 E4M4, updated dequant/repack refs
52+
53+
**Benchmarks**:
54+
- `ncu_driver.py`: updated comments, removed stale `.cuda()` call; all 4 kernel modes verified
55+
56+
### Bug: illegal memory access at runtime — FIXED
57+
58+
Root cause: stale build artifact. The previous session's `make` command
59+
didn't actually recompile `ops.cu` after source changes. The `.so` still
60+
had the old `float*` absmax signature while `pythonInterface.cpp` was
61+
passing `unsigned char*` via ctypes — causing out-of-bounds reads (the
62+
kernel read 4 bytes per absmax element instead of 1).
63+
64+
Fix: clean rebuild (`rm -rf build && cmake -B build ... && make`).
65+
66+
## Work items
67+
68+
### 1. Scalar GEMV (dense) — float32 -> uint8 E4M4
69+
- [x] Baseline benchmark (current float32)
70+
- [x] Change kernel to use `unsigned char*` + `load_absmax<unsigned char>`
71+
- [x] Update pythonInterface.cpp, backends/cuda/ops.py
72+
- [x] **FIX BUG**: stale build — clean rebuild fixed it
73+
- [x] Post-change benchmark
74+
- [x] Record results below — **no regression**
75+
76+
### 2. Grouped scalar GEMV (MoE) — float32 -> uint8 E4M4
77+
- [x] Baseline benchmark (current float32)
78+
- [x] Change kernel to use `unsigned char*` + `load_absmax<unsigned char>`
79+
- [x] Update pythonInterface.cpp, backends/cuda/ops.py
80+
- [x] **FIX BUG**: same stale build issue
81+
- [x] Post-change benchmark
82+
- [x] Record results below — **within noise for M=4, slight regression for M=1**
83+
84+
### 3. quantize_kbit — return uint8 E4M4 by default
85+
- [x] Add E4M4 encode to quantize kernel (`encode_e4m4_absmax` in kQuantizeBlockwise_kbit)
86+
- [x] Update Python op return type (`_ops.py` allocates uint8, `backends/cuda/ops.py` allocates uint8)
87+
- [x] Remove Python-side double-encode in `functional.py::quantize_kbit` (kernel does it natively)
88+
- [x] Update repack_kbit: kernel accepts `unsigned char*` input, just copies bytes (no re-encode)
89+
- [x] Move E4M4 encode/decode definitions before quantize kernel (was forward-declared, caused issues)
90+
- [x] **BUG FIXED**: Previous session's forward declaration of `encode_e4m4_absmax` before `E4M4_BIAS`
91+
was defined compiled but produced wrong results. Moved all E4M4 functions before quantize kernel.
92+
- [x] **BUG FIXED**: `functional.py::quantize_kbit` applied Python-side E4M4 encode on top of the
93+
already-encoded kernel output (double encoding). Removed the redundant Python encode.
94+
95+
### 4. Add float16 absmax alternative path — DONE
96+
- [x] Generic `load_absmax<ABSMAX_T>` already handles `half` (casts to float)
97+
- [x] Templated scalar GEMV + grouped scalar GEMV on `ABSMAX_T` (default = `unsigned char`)
98+
- [x] Added fp16 absmax template instantiations in ops.cu
99+
- [x] Added fp16abs C++ wrappers in pythonInterface.cpp (unmangled functions ready)
100+
- [x] Added extern C wrappers for fp16abs scalar GEMV + grouped scalar GEMV (in pythonInterface.cpp)
101+
- [x] Added Python dispatch: absmax dtype routing via `_fp16abs` suffix in `backends/cuda/ops.py`
102+
- [x] `_ops.py` — no changes needed, torch op defs use generic `Tensor` type
103+
- [x] Build compiles, all 31 scalar GEMV tests pass, all 195 GEMM tests pass
104+
- [x] Verified fp16abs path produces identical results to uint8 path (when E4M4→fp16 is lossless)
105+
106+
### 5. Tests
107+
- [x] Updated test_scalar_gemv.py: added `decode_e4m4_absmax`, updated `dequant_reference`
108+
- [x] Updated test_kbit_gemm.py: `quantize_kbit_ref` now returns uint8 E4M4, updated dequant/repack refs
109+
- [x] All 31 test_scalar_gemv tests pass
110+
- [x] All 195 test_kbit_gemm tests pass
111+
- [ ] test_grouped_gemm.py has pre-existing failures (missing `max_M` arg, not related)
112+
113+
### 6. Benchmark driver — DONE
114+
- [x] Updated ncu_driver.py: comment fix (uint8 absmax), removed stale `.cuda()` call
115+
- [x] All 4 kernel modes (mma, scalar, grouped, grouped_mma) verified working
116+
117+
### 7. Update _ops.py
118+
- [x] No changes needed — torch op defs use generic `Tensor` type
119+
120+
## Benchmark results
121+
122+
### Scalar GEMV (dense)
123+
124+
#### Baseline (float32 absmax)
125+
126+
CUDA events, WARMUP=50, ITERS=200, fp16, RTX 4090
127+
128+
| shape | k | M | us |
129+
|----------|----|----|-------|
130+
| gateup | 3 | 1 | 87.5 |
131+
| gateup | 3 | 4 | 163.5 |
132+
| gateup | 4 | 1 | 117.1 |
133+
| gateup | 4 | 4 | 172.7 |
134+
| down | 3 | 1 | 80.4 |
135+
| down | 3 | 4 | 165.5 |
136+
| down | 4 | 1 | 118.9 |
137+
| down | 4 | 4 | 186.3 |
138+
| Q | 3 | 1 | 36.7 |
139+
| Q | 3 | 4 | 64.2 |
140+
| Q | 4 | 1 | 38.9 |
141+
| Q | 4 | 4 | 65.7 |
142+
| KV | 3 | 1 | 36.7 |
143+
| KV | 3 | 4 | 35.9 |
144+
| KV | 4 | 1 | 36.1 |
145+
| KV | 4 | 4 | 36.5 |
146+
147+
#### After change (uint8 E4M4 absmax)
148+
149+
CUDA events, WARMUP=100, ITERS=500, fp16, RTX 4090
150+
Baseline and uint8 runs done with proper `pip install -e .` for each worktree.
151+
152+
| shape | k | M | f32(us) | u8(us) | delta |
153+
|----------|----|----|----------|---------|-------|
154+
| gateup | 3 | 1 | 81.6 | 83.5 | +2.3% |
155+
| gateup | 3 | 4 | 164.1 | 168.7 | +2.8% |
156+
| gateup | 4 | 1 | 104.5 | 101.2 | -3.2% |
157+
| gateup | 4 | 4 | 151.9 | 146.9 | -3.3% |
158+
| down | 3 | 1 | 69.2 | 74.2 | +7.2% |
159+
| down | 3 | 4 | 169.1 | 152.9 | -9.6% |
160+
| down | 4 | 1 | 120.6 | 85.6 | -29.0% |
161+
| down | 4 | 4 | 185.4 | 176.6 | -4.7% |
162+
| Q | 3 | 1 | 38.5 | 39.1 | +1.6% |
163+
| Q | 3 | 4 | 60.7 | 72.1 | +18.8% |
164+
| Q | 4 | 1 | 37.4 | 40.1 | +7.2% |
165+
| Q | 4 | 4 | 65.7 | 62.9 | -4.3% |
166+
| KV | 3 | 1 | 38.5 | 37.1 | -3.6% |
167+
| KV | 3 | 4 | 35.5 | 37.2 | +4.8% |
168+
| KV | 4 | 1 | 36.3 | 37.7 | +3.9% |
169+
| KV | 4 | 4 | 35.8 | 39.7 | +10.9% |
170+
171+
**Summary**: High variance between runs (up to ~30% swing on some shapes).
172+
Overall no consistent pattern — performance is essentially equivalent.
173+
The variance dominates any signal from the absmax format change.
174+
175+
### Grouped scalar GEMV (MoE)
176+
177+
#### Baseline (float32 absmax)
178+
179+
CUDA events, WARMUP=100, ITERS=500, fp16, 8 experts, RTX 4090
180+
181+
| shape | k | M | us |
182+
|----------|----|----|-------|
183+
| moe_gu | 3 | 1 | 47.8 |
184+
| moe_gu | 3 | 4 | 101.8 |
185+
| moe_gu | 4 | 1 | 58.3 |
186+
| moe_gu | 4 | 4 | 103.6 |
187+
| moe_dn | 3 | 1 | 47.2 |
188+
| moe_dn | 3 | 4 | 92.7 |
189+
| moe_dn | 4 | 1 | 55.0 |
190+
| moe_dn | 4 | 4 | 94.2 |
191+
192+
#### After change (uint8 E4M4 absmax)
193+
194+
CUDA events, WARMUP=100, ITERS=500, fp16, 8 experts, RTX 4090
195+
196+
| shape | k | M | f32(us) | u8(us) | delta |
197+
|----------|----|----|----------|---------|-------|
198+
| moe_gu | 3 | 1 | 47.8 | 58.3 | +22.0% |
199+
| moe_gu | 3 | 4 | 101.8 | 98.8 | -2.9% |
200+
| moe_gu | 4 | 1 | 58.3 | 61.3 | +5.1% |
201+
| moe_gu | 4 | 4 | 103.6 | 102.0 | -1.5% |
202+
| moe_dn | 3 | 1 | 47.2 | 51.9 | +10.0% |
203+
| moe_dn | 3 | 4 | 92.7 | 91.3 | -1.5% |
204+
| moe_dn | 4 | 1 | 55.0 | 57.6 | +4.7% |
205+
| moe_dn | 4 | 4 | 94.2 | 92.5 | -1.8% |
206+
207+
**Summary**: M=4 cases within noise (~+/-3%). M=1 cases show 5-22% regression,
208+
possibly from E4M4 decode overhead being a larger fraction of work with only
209+
1 row of FMA. But variance is high — the moe_gu k=3 M=1 outlier (+22%) is
210+
likely noise since other M=1 shapes show only +5%.

0 commit comments

Comments
 (0)