Commit aba14e8 (parent 96ba7a2) — TimDettmers and claude committed

Add scalar GEMV implementation guide

Detailed guide for implementing the scalar (non-MMA) GEMV kernel for M=1-4 decode. Covers thread mapping, inner loop design, data flow, template parameters, work distribution, and all files to modify. Referenced from `optimization.md` P0 section.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

2 files changed: +391 −5 lines

agents/scalar_gemv_guide.md — 383 additions, 0 deletions
# Scalar GEMV Kernel Implementation Guide

**Location:** `agents/scalar_gemv_guide.md`
**Referenced from:** `optimization.md` Section 4 (P0: Scalar Kernel)
**Key files to modify:**

- `csrc/ops.cu` — CUDA kernel + launcher
- `csrc/pythonInterface.cpp` — C wrappers
- `bitsandbytes/_ops.py` — torch.library op definitions
- `bitsandbytes/backends/cuda/ops.py` — Python dispatch
- `tests/test_kbit_quantization.py` — correctness tests

**Context documents:** `progress.md` (full dev record), `optimization.md` (kernel strategy)

---
## 1. What This Kernel Does

Computes `C[M, N] = A[M, K_dim] * W_kbit[K_dim, N]^T` for M=1-4 using
scalar FMA instead of tensor core MMA. Supports both:

- **Single-matrix** (dense layers): one weight matrix
- **Grouped** (MoE experts): multiple expert weight matrices in one launch

Uses the same tiled kbit data format as the existing MMA kernels — no
repack changes needed.
### Why it's needed

At M=1, the MMA kernel wastes 93.75% of tensor core work (TILE_M=16, but
only 1 row holds real data). cuBLAS switches to an optimized GEMV at M=1,
achieving 69% of peak DRAM bandwidth; our MMA kernel achieves only 31%.
The scalar kernel eliminates the MMA waste entirely and should reach
~50-60% bandwidth efficiency, translating the 3.6x data compression into
a 2.5-3.5x speedup over cuBLAS.
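The arithmetic behind those numbers can be sanity-checked in a few lines. This is an illustration only, using the figures stated above (3.6x compression, 69% cuBLAS efficiency, 50-60% target efficiency), not a measurement:

```python
# Back-of-envelope check of the speedup claim (numbers from the text).
tile_m = 16
rows_used = 1                        # M=1 decode
mma_waste = 1 - rows_used / tile_m   # fraction of tensor-core work wasted
assert abs(mma_waste - 0.9375) < 1e-9   # 93.75%

compression = 3.6             # kbit weight bytes vs fp16
cublas_eff = 0.69             # cuBLAS GEMV bandwidth efficiency at M=1
scalar_eff_lo, scalar_eff_hi = 0.50, 0.60   # target for the scalar kernel

# A memory-bound GEMV's time ~ bytes_moved / achieved_bandwidth, so the
# expected speedup is (bytes ratio) x (efficiency ratio):
speedup_lo = compression * scalar_eff_lo / cublas_eff   # ~2.6x
speedup_hi = compression * scalar_eff_hi / cublas_eff   # ~3.1x
assert 2.5 < speedup_lo < speedup_hi < 3.5
```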
---

## 2. Architecture

### Thread/block organization

- **Block size:** 256 threads (8 warps), same as MMA kernel
- **TILE_N:** 128 output columns per block (same as MMA kernel)
- **TILE_K:** 64 (same as MMA kernel, matches tiled data format)
- **No TILE_M concept** — M is a runtime parameter (1-4), not tiled

Thread assignment for M=1:

- 256 threads, 128 columns → 2 threads per column
- Thread `t` and thread `t+128` split the K-dimension reduction
- Thread `t` handles even k_tiles, thread `t+128` handles odd k_tiles
- After all k_tiles: reduce the two partial sums per column (note that
  threads `t` and `t+128` sit in different warps, so this final step needs
  a small shared-memory exchange; warp shuffles such as `__shfl_xor_sync`
  only reduce within a warp)
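The even/odd K-split can be modeled in plain Python to make the reduction step concrete (no CUDA semantics implied; the per-tile contributions are stand-in values):

```python
# Partner threads accumulate even and odd k_tiles separately; their two
# partial sums are then reduced into the final dot product.
num_k_tiles = 8
contribs = [float(kt + 1) for kt in range(num_k_tiles)]  # stand-in per-tile dot products

acc_even = sum(c for kt, c in enumerate(contribs) if kt % 2 == 0)  # thread t
acc_odd  = sum(c for kt, c in enumerate(contribs) if kt % 2 == 1)  # thread t+128

total = acc_even + acc_odd   # the cross-warp reduction step
assert total == sum(contribs)
```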
Thread assignment for M=2-4:

- Each thread owns one column, processes all M rows
- 256 threads / 128 columns = 2 threads per column (split K)
- Each thread maintains M accumulators (`float acc[M_VAL]`)
- Dequant done once per element, weight reused across M rows

### Data flow per k_tile

1. **Load B tile** (kbit packed + absmax) into shared memory via cp.async
   - Same cp.async pipeline as MMA kernel (double-buffered)
   - B data: TILE_N × KB_PER_TILE × K_BITS uint32 words = 1024 words for K=4
   - Absmax: TILE_N × KB_PER_TILE = 256 bytes
2. **Load A values** directly into registers from global memory
   - M × TILE_K × sizeof(half) = 128-512 bytes (tiny, no shared memory needed)
   - Simple coalesced load, no XOR swizzle needed
3. **Dequant + FMA** in registers:
   - Read bit-plane words from shared memory
   - Extract K-bit index using bit manipulation
   - Codebook lookup via `__shfl_sync`
   - Scale by absmax
   - FMA: `acc[m] += weight * A_reg[m][k]`
4. **Store output** directly to global memory
### Codebook lookup

Same technique as the dequant kernel: codebook entries are stored in lane
registers and looked up via `__shfl_sync`:

```cuda
// At kernel start: load codebook into lane registers
float cb = (lane_id < (1 << K_BITS)) ? codebook[lane_id] : 0.0f;

// During dequant: lookup by index
float val = __shfl_sync(0xFFFFFFFF, cb, idx);
float weight = val * amax;
```

This is register-to-register (~5 cycles); no shared memory is needed for
the codebook.
### Shared memory budget

Per stage (one of two double-buffer slots):

- B tile: 128 × 2 × K × 4 bytes = 4096 bytes (K=4)
- Absmax: 256 bytes (aligned to 272)
- A tile: NOT in shared memory (loaded directly to registers)
- Total per stage: ~4368 bytes
- Double-buffered: ~8736 bytes

Much less than the MMA kernel (~15-20 KB), so occupancy will be higher.
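The budget above, spelled out (all values taken from the list; the 272-byte alignment figure is the one stated in the text):

```python
# Shared-memory budget arithmetic for one double-buffer stage.
TILE_N, KB_PER_TILE, K_BITS = 128, 2, 4
WORD = 4                                         # bytes per uint32

b_tile = TILE_N * KB_PER_TILE * K_BITS * WORD    # packed bit-plane words
absmax = TILE_N * KB_PER_TILE                    # one absmax byte per k-block
absmax_aligned = 272                             # padded, per the text
per_stage = b_tile + absmax_aligned
double_buffered = 2 * per_stage

assert b_tile == 4096 and absmax == 256
assert per_stage == 4368 and double_buffered == 8736
```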
---

## 3. Inner Loop Detail

For each k_tile, each thread processes its assigned column's k-blocks:

```cuda
// Thread owns column 'col', handles k-blocks based on thread assignment.
// For M=1 with K-split: thread t handles k_blocks 0,2,4,...
//                       thread t+128 handles k_blocks 1,3,5,...
// (or split by k_tile: thread t does even k_tiles, t+128 odd k_tiles)

const int col = threadIdx.x % 128;        // output column
const int k_split_id = threadIdx.x / 128; // 0 or 1

// After shared memory is ready for this k_tile:
unsigned int* b_ptr = sh_b(stage);
unsigned char* abs_ptr = sh_abs(stage);

#pragma unroll
for (int kb = 0; kb < KB_PER_TILE; kb++) { // KB_PER_TILE = 2
    // Load K bit-plane words for this column's k-block
    unsigned int planes[K_BITS];
    int b_addr = col * B_COL_WORDS + kb * K_BITS;
    #pragma unroll
    for (int b = 0; b < K_BITS; b++)
        planes[b] = b_ptr[b_addr + b];

    float amax = decode_e4m4_absmax_branchless(abs_ptr[col * KB_PER_TILE + kb]);

    int k_base_local = kb * 32; // within the k_tile
    int k_global = kt * TILE_K + k_base_local;

    #pragma unroll
    for (int j = 0; j < 32; j++) {
        // Extract K-bit index
        int idx = 0;
        #pragma unroll
        for (int b = 0; b < K_BITS; b++)
            idx |= ((planes[b] >> j) & 1) << b;

        float w = __shfl_sync(0xFFFFFFFF, cb, idx) * amax;

        // FMA for each M row (dequant done once, reused)
        #pragma unroll
        for (int m = 0; m < M_VAL; m++)
            acc[m] += w * A_vals[m][k_global + j];
    }
}
```
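The bit-plane index extraction in the inner loop can be checked against a pure-Python reference model. Here 32 K-bit indices are packed so that plane word `b` holds bit `b` of every index, and the unpack loop mirrors the kernel's `idx |= ((planes[b] >> j) & 1) << b`:

```python
# Reference model of the bit-plane layout used in the inner loop.
K_BITS = 4
indices = [(7 * j + 3) % (1 << K_BITS) for j in range(32)]  # arbitrary test data

# Pack: plane b collects bit b of all 32 indices.
planes = [0] * K_BITS
for j, idx in enumerate(indices):
    for b in range(K_BITS):
        planes[b] |= ((idx >> b) & 1) << j

# Unpack: recover index j from bit j of each plane, as the kernel does.
recovered = []
for j in range(32):
    idx = 0
    for b in range(K_BITS):
        idx |= ((planes[b] >> j) & 1) << b
    recovered.append(idx)

assert recovered == indices
```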
### A value loading strategy

For M=1-4, A values are tiny. Two options:

**Option A (simpler, recommended for first version):**
Pre-load ALL A values for the full K_dim into a per-thread array at kernel
start. For M=1, K=2048: 2048 fp16 = 4 KB. At M=4: 16 KB. This exceeds
register file capacity, so the array lives in local memory (L1-cached,
effectively free for sequential access). Access pattern: `A_vals[m][k]`.

**Option B (more efficient):**
Load A values per k_tile into registers. For M=1, TILE_K=64: 64 fp16 =
128 bytes = 32 registers. Fits easily. Load from global memory at the
start of each k_tile iteration (while waiting for cp.async of B data).

Option B is better for register pressure. Implementation sketch:

```cuda
// At start of each k_tile iteration:
half A_local[M_VAL][TILE_K];
for (int m = 0; m < M_VAL; m++)
    for (int i = 0; i < TILE_K; i += 8) {
        // Vectorized load: 8 halves = 16 bytes
        int k = kt * TILE_K + i;
        if (k + 7 < K_dim)
            *(int4*)&A_local[m][i] = *(const int4*)&A[m * K_dim + k];
    }
```
---

## 4. Work Distribution

### Single-matrix (dense layers)

Grid: one block per n_tile. For N=5120: 40 blocks. Each block processes
all of K_dim for its 128 output columns.

For shapes with few n_tiles (N=512 → 4 blocks), use K-splitting: launch
more blocks, each handling a subset of k_tiles, and atomicAdd partial
results into a workspace. Same split-K logic as the production MMA kernel.
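One possible split-K sizing heuristic for the case above can be sketched as follows. The actual logic lives in the production launcher; `pick_k_splits`, `target_blocks`, and the rounding policy are assumptions made for illustration:

```python
import math

def pick_k_splits(n_tiles: int, num_k_tiles: int, num_sms: int = 128) -> int:
    """Choose how many K-segments each n_tile is split into."""
    target_blocks = num_sms                    # aim to give every SM a block
    splits = math.ceil(target_blocks / n_tiles)
    return max(1, min(splits, num_k_tiles))    # can't split finer than k_tiles

# N=512 -> 4 n_tiles; K_dim=2048 -> 32 k_tiles of TILE_K=64
print(pick_k_splits(4, 32))    # → 32 (4 * 32 = 128 blocks, fills the GPU)
print(pick_k_splits(40, 32))   # → 4
```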
### Grouped (MoE experts)

Same as `kbit_grouped_gemm_prod`: a persistent kernel with work_offsets and
a binary search to find expert_id. Each work item is one (expert, n_tile).
No split-K (grouping provides enough parallelism).

The launcher computes work_offsets on the CPU side (tiny: num_experts+1
ints copied from device), same pattern as the existing grouped GEMM.
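The work_offsets lookup described above can be modeled in a few lines: `work_offsets[e]` is the first work item belonging to expert `e` (num_experts+1 entries, last entry = total work), and a work item's expert is found by binary search. Illustrative Python, not the kernel; the example offsets are made up:

```python
from bisect import bisect_right

work_offsets = [0, 4, 4, 10]   # 3 experts; expert 1 contributes no work items

def locate(work_id: int):
    """Map a flat work item id to (expert_id, n_tile within that expert)."""
    expert_id = bisect_right(work_offsets, work_id) - 1
    n_tile = work_id - work_offsets[expert_id]
    return expert_id, n_tile

print(locate(0))   # → (0, 0)
print(locate(5))   # → (2, 1)   (empty expert 1 is skipped automatically)
print(locate(9))   # → (2, 5)
```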
---

## 5. Template Parameters

```cuda
template <int K_BITS, int M_VAL, typename scalar_t>
__global__ void kbit_scalar_gemv(
    const scalar_t* __restrict__ A,
    const unsigned int* __restrict__ B_packed,
    const unsigned char* __restrict__ B_absmax,
    const float* __restrict__ codebook,
    scalar_t* __restrict__ C,
    float* __restrict__ C_workspace,  // for split-K
    int* __restrict__ tile_counters,  // for split-K
    const int M, const int K_dim, const int N,
    const int k_splits, const int total_work
);
```
- `K_BITS`: 2, 3, 4, 5 (compile-time, same as MMA kernel)
- `M_VAL`: 1, 2, 3, 4 (compile-time, controls unrolling)
- `scalar_t`: half, __nv_bfloat16

Grouped variant:

```cuda
template <int K_BITS, int M_VAL, typename scalar_t>
__global__ void kbit_grouped_scalar_gemv(
    const scalar_t* __restrict__ A_concat,
    const unsigned int* __restrict__ B_packed_all,
    const unsigned char* __restrict__ B_absmax_all,
    const float* __restrict__ codebook,
    scalar_t* __restrict__ C_concat,
    const int* __restrict__ expert_offsets,
    const int* __restrict__ work_offsets,
    const int K_dim, const int N,
    const int num_experts, const int total_work
);
```

### Instantiations needed

For each K in {2,3,4,5} × M_VAL in {1,2,4} × scalar_t in {half, bf16}:

- 4 × 3 × 2 = 24 instantiations per kernel variant
- Start with K=4, M=1, fp16 only for initial testing (1 instantiation)
- Add the remaining instantiations once correctness is verified
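The instantiation grid above is small enough to enumerate directly, e.g. if the explicit template instantiation lines at the end of `ops.cu` are ever generated by a script (the script itself is hypothetical):

```python
from itertools import product

K_BITS = [2, 3, 4, 5]
M_VALS = [1, 2, 4]
DTYPES = ["half", "__nv_bfloat16"]

grid = list(product(K_BITS, M_VALS, DTYPES))
assert len(grid) == 24   # per kernel variant; two variants double this
```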
---

## 6. Implementation Steps

### Step 1: CUDA kernel (`csrc/ops.cu`)

Add after the grouped GEMM code (around line 2547):

1. `kbit_scalar_gemv` kernel function (single-matrix with split-K)
2. `kbit_grouped_scalar_gemv` kernel function (grouped, no split-K)
3. `kbitScalarGemvLaunch` launcher (handles split-K grid sizing)
4. `kbitScalarGemv` public entry (M_VAL dispatch + SM query)
5. `kbitGroupedScalarGemv` public entry (M_VAL dispatch + work_offsets)
6. Template instantiations at end of file

### Step 2: C interface (`csrc/pythonInterface.cpp`)

Add forward declarations and extern C wrappers:

```cpp
// Forward declarations
#define MAKE_KBIT_SCALAR_GEMV_DECL(K) \
    void kbit_scalar_gemv_fp16_k##K(...); \
    void kbit_scalar_gemv_bf16_k##K(...);

// Extern C wrappers
#define MAKE_CKBIT_SCALAR_GEMV(K) \
    void ckbit_scalar_gemv_fp16_k##K(...) { \
        kbit_scalar_gemv_fp16_k##K(...); \
    }
```
Same pattern for the grouped variant.

### Step 3: Python op registration (`bitsandbytes/_ops.py`)

Register two new ops:

```python
torch.library.define("bitsandbytes::kbit_scalar_gemv",
    "(Tensor A, Tensor B_packed, Tensor B_absmax, Tensor codebook, "
    "int K_dim, int N, int k) -> Tensor")

torch.library.define("bitsandbytes::kbit_grouped_scalar_gemv",
    "(Tensor A, Tensor B_packed_all, Tensor B_absmax_all, Tensor codebook, "
    "Tensor expert_offsets, int K_dim, int N, int k, int num_experts) -> Tensor")
```
### Step 4: Python dispatch (`bitsandbytes/backends/cuda/ops.py`)

Implement the CUDA backend kernels. Key: auto-select the M_VAL template
based on the actual M:

```python
@register_kernel("bitsandbytes::kbit_scalar_gemv", "cuda")
def _(A, B_packed, B_absmax, codebook, K_dim, N, k):
    M = A.shape[0]
    assert M <= 4
    # Allocate output, workspace, tile_counters
    # Call ckbit_scalar_gemv_{dtype}_k{k}
```
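One way the M → M_VAL selection could look. Since only M_VAL ∈ {1, 2, 4} is instantiated, mapping M=3 onto the M_VAL=4 instantiation (with one padded row) is an assumption here, not confirmed design:

```python
def select_m_val(M: int) -> int:
    """Pick the compiled M_VAL instantiation for a runtime M (assumed policy)."""
    assert 1 <= M <= 4
    return M if M in (1, 2, 4) else 4   # M=3 rounds up to the M_VAL=4 kernel

assert [select_m_val(m) for m in (1, 2, 3, 4)] == [1, 2, 4, 4]
```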
### Step 5: Correctness test

Add to `tests/test_kbit_quantization.py`:

```python
@pytest.mark.parametrize("K_dim,N", [(2048, 512), (2048, 5120), (5120, 2048)])
@pytest.mark.parametrize("M", [1, 2, 4])
@pytest.mark.parametrize("k", [4])
def test_scalar_gemv_correctness(K_dim, N, M, k):
    # Quantize weight, compute reference via dequant + torch.mm
    # Compare against kbit_scalar_gemv output
    # Tolerance: same as existing GEMM tests
```

### Step 6: Benchmark

Extend `benchmarks/bench_crossover.py` to include scalar GEMV in the
comparison table. Key comparison: scalar GEMV vs cuBLAS at M=1,2,4.

---
## 7. Expected Performance

Based on roofline analysis (see `optimization.md` Section 6):

| Shape | Scalar est (M=1) | cuBLAS (M=1) | Projected speedup |
|-------|-----------------:|-------------:|------------------:|
| gate/up 2048×5120 | ~5us | ~25us | ~5x |
| down 5120×2048 | ~5us | ~25us | ~5x |
| Q proj 2048×4096 | ~4us | ~17us | ~4x |
| shared gate/up 2048×10240 | ~10us | ~55us | ~5.5x |
| MoE expert 2048×512 (×8) | ~4us | ~17us | ~4x |

Full model per-layer (all projections combined):

- Qwen3 batch=1: ~27us kbit vs ~141us cuBLAS = **5.3x**
- GLM4.7 batch=1: ~37us kbit vs ~157us cuBLAS = **4.3x**

These estimates apply a 1.8x overhead factor on top of the theoretical
bandwidth minimum; the actual speedup depends on achieved bandwidth
efficiency.
---

## 8. Key Differences from MMA Kernel

| Aspect | MMA kernel | Scalar kernel |
|--------|------------|---------------|
| Inner compute | `mma.sync.aligned.m16n8k16` | Scalar FMA loop |
| A data | Shared memory + ldmatrix + XOR swizzle | Registers (direct global load) |
| B dequant output | Packed into MMA fragments (uint32) | Float value, used directly |
| Thread→output mapping | Complex (gid/tid fragment layout) | Simple (thread % 128 = column) |
| M handling | TILE_M=16, zero-padded | M_VAL template, no padding |
| Registers/thread | ~128 (MMA fragments) | ~30-40 |
| Occupancy | Low (register-limited) | High |
| Shared memory | A tile + B tile + absmax (~15-20 KB) | B tile + absmax only (~9 KB) |
---

## 9. Risks and Mitigations

1. **Shared memory bank conflicts on B reads.** Multiple threads read the
   same column's bit-plane words from shared memory. Mitigation: with 2
   threads per column (K-split), this is only a 2-way conflict. Acceptable.

2. **Codebook shuffle across warp boundaries.** `__shfl_sync` only works
   within a warp, so threads in different warps processing the same column
   need independent codebook registers. This is already handled: each
   thread loads `cb = codebook[lane_id]` at kernel start.

3. **Register spill for M=4.** Each thread needs 4 accumulators + A values
   + packed words + temporaries. Estimate: ~40 registers. Fine for sm_89
   (255 registers per thread maximum).

4. **K-split reduction overhead.** For single-matrix with N=512 (4 blocks),
   split-K is needed to fill 128 SMs. atomicAdd overhead for the split-K
   reduction adds ~5-10us, still much faster than the MMA kernel. For
   grouped dispatch, split-K is unnecessary (enough experts to fill SMs).