# Scalar GEMV Kernel Implementation Guide

**Location:** `agents/scalar_gemv_guide.md`
**Referenced from:** `optimization.md` Section 4 (P0: Scalar Kernel)
**Key files to modify:**
- `csrc/ops.cu` — CUDA kernel + launcher
- `csrc/pythonInterface.cpp` — C wrappers
- `bitsandbytes/_ops.py` — torch.library op definitions
- `bitsandbytes/backends/cuda/ops.py` — Python dispatch
- `tests/test_kbit_quantization.py` — correctness tests

**Context documents:** `progress.md` (full dev record), `optimization.md` (kernel strategy)

---

## 1. What This Kernel Does

Computes `C[M, N] = A[M, K_dim] * W_kbit[N, K_dim]^T` for M = 1-4 using
scalar FMA instead of tensor core MMA. Supports both:
- **Single-matrix** (dense layers): one weight matrix
- **Grouped** (MoE experts): multiple expert weight matrices in one launch

Uses the same tiled kbit data format as the existing MMA kernels — no
repack changes needed.

### Why it's needed

At M=1, the MMA kernel wastes 93.75% of tensor core work (TILE_M=16,
only 1 row has data). cuBLAS uses an optimized GEMV at M=1, achieving
69% of peak DRAM bandwidth; our MMA kernel achieves only 31%. The scalar
kernel eliminates the MMA waste entirely and should reach ~50-60% bandwidth
efficiency, translating the 3.6x data compression into a 2.5-3.5x speedup
over cuBLAS.

---

## 2. Architecture

### Thread/block organization

- **Block size:** 256 threads (8 warps), same as the MMA kernel
- **TILE_N:** 128 output columns per block (same as the MMA kernel)
- **TILE_K:** 64 (same as the MMA kernel; matches the tiled data format)
- **No TILE_M concept** — M is a runtime parameter (1-4), not tiled

Thread assignment for M=1:
- 256 threads, 128 columns → 2 threads per column
- Threads `t` and `t+128` split the K-dimension reduction
- Thread `t` handles even k_tiles, thread `t+128` handles odd k_tiles
- After all k_tiles, the two partial sums are reduced. Note: with the
  `t`/`t+128` mapping the partners sit in different warps, so the reduction
  must go through shared memory; `__shfl_xor_sync` only works if the pair
  is remapped into the same warp (e.g. adjacent threads sharing a column)

Thread assignment for M=2-4:
- Each thread owns one column and processes all M rows
- 256 threads / 128 columns = 2 threads per column (split K)
- Each thread maintains M accumulators (`float acc[M_VAL]`)
- Dequant is done once per element; the weight is reused across all M rows
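
The M=2-4 scheme can be sketched in NumPy (names and the flat index/absmax
layout are illustrative — the real kernel works on the tiled format):

```python
import numpy as np

def column_gemv(A, idx, absmax, codebook, col):
    """One thread's work: dequantize each weight of column `col` once and
    FMA it into all M row accumulators."""
    M, K = A.shape
    acc = np.zeros(M, dtype=np.float32)
    for k in range(K):
        # dequant once per weight element...
        w = codebook[idx[col, k]] * absmax[col, k // 32]
        for m in range(M):          # ...then reuse across all M rows
            acc[m] += w * A[m, k]
    return acc
```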

### Data flow per k_tile

1. **Load B tile** (kbit packed + absmax) into shared memory via cp.async
   - Same cp.async pipeline as the MMA kernel (double-buffered)
   - B data: TILE_N × KB_PER_TILE × K_BITS uint32 words = 1024 words for K=4
   - Absmax: TILE_N × KB_PER_TILE = 256 bytes
2. **Load A values** directly into registers from global memory
   - M × TILE_K × sizeof(half) = 128-512 bytes (tiny; no shared memory needed)
   - Simple coalesced load, no XOR swizzle needed
3. **Dequant + FMA** in registers:
   - Read bit-plane words from shared memory
   - Extract the K-bit index using bit manipulation
   - Codebook lookup via `__shfl_sync`
   - Scale by absmax
   - FMA: `acc[m] += weight * A_reg[m][k]`
4. **Store output** directly to global memory

### Codebook lookup

Same technique as the dequant kernel: codebook entries are stored in lane
registers, and lookup is a `__shfl_sync`:

```cuda
// At kernel start: load the codebook into lane registers
float cb = (lane_id < (1 << K_BITS)) ? codebook[lane_id] : 0.0f;

// During dequant: look up by index
float val = __shfl_sync(0xFFFFFFFF, cb, idx);
float weight = val * amax;
```

This is register-to-register (~5 cycles); no shared memory is needed for
the codebook.
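
The lane setup can be sanity-checked on the host with a toy model of
`__shfl_sync` (codebook values here are placeholders, assuming K_BITS=4):

```python
K_BITS = 4
codebook = [i / 15.0 for i in range(1 << K_BITS)]  # placeholder entries

# "At kernel start": each of the 32 lanes holds one register copy of cb;
# lanes >= 2^K_BITS hold 0.0f
cb_regs = [codebook[lane] if lane < (1 << K_BITS) else 0.0 for lane in range(32)]

def shfl_sync(regs, src_lane):
    """Model the warp shuffle: read lane src_lane's register."""
    return regs[src_lane]

idx, amax = 5, 2.0
weight = shfl_sync(cb_regs, idx) * amax  # lookup + absmax scale
```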

### Shared memory budget

Per stage (one of two double-buffer slots):
- B tile: 128 × 2 × K × 4 bytes = 4096 bytes (K=4)
- Absmax: 256 bytes (aligned to 272)
- A tile: NOT in shared memory (loaded directly to registers)
- Total per stage: ~4368 bytes
- Double-buffered: ~8736 bytes

Much less than the MMA kernel (~15-20 KB), so occupancy will be higher.
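
The budget arithmetic checks out with a few lines (TILE_N=128,
KB_PER_TILE=2, absmax padded from 256 to 272 bytes, per the numbers above):

```python
def smem_per_stage(k_bits, tile_n=128, kb_per_tile=2, absmax_padded=272):
    """Bytes of shared memory per double-buffer stage."""
    b_tile = tile_n * kb_per_tile * k_bits * 4  # uint32 bit-plane words
    return b_tile + absmax_padded               # 256 B of absmax, aligned to 272

print(smem_per_stage(4), 2 * smem_per_stage(4))  # per-stage, double-buffered
```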

---

## 3. Inner Loop Detail

For each k_tile, each thread processes the k-blocks of its assigned column:

```cuda
// Thread owns column 'col', handles k-blocks based on thread assignment.
// For M=1 with K-split: thread t handles k_blocks 0,2,4,...
//                       thread t+128 handles k_blocks 1,3,5,...
// (or split by k_tile: thread t does even k_tiles, t+128 odd k_tiles)

const int col = threadIdx.x % 128;        // output column
const int k_split_id = threadIdx.x / 128; // 0 or 1

// After shared memory is ready for this k_tile:
unsigned int* b_ptr = sh_b(stage);
unsigned char* abs_ptr = sh_abs(stage);

#pragma unroll
for (int kb = 0; kb < KB_PER_TILE; kb++) { // KB_PER_TILE = 2
    // Load the K bit-plane words for this column's k-block
    unsigned int planes[K_BITS];
    int b_addr = col * B_COL_WORDS + kb * K_BITS;
    #pragma unroll
    for (int b = 0; b < K_BITS; b++)
        planes[b] = b_ptr[b_addr + b];

    float amax = decode_e4m4_absmax_branchless(abs_ptr[col * KB_PER_TILE + kb]);

    int k_base_local = kb * 32; // within the k_tile
    int k_global = kt * TILE_K + k_base_local;

    #pragma unroll
    for (int j = 0; j < 32; j++) {
        // Extract the K-bit index
        int idx = 0;
        #pragma unroll
        for (int b = 0; b < K_BITS; b++)
            idx |= ((planes[b] >> j) & 1) << b;

        float w = __shfl_sync(0xFFFFFFFF, cb, idx) * amax;

        // FMA for each M row (dequant done once, reused)
        #pragma unroll
        for (int m = 0; m < M_VAL; m++)
            acc[m] += w * A_vals[m][k_global + j];
    }
}
```
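
The bit-plane convention (bit `j` of plane `b` holds bit `b` of index `j`)
can be verified with a host-side round trip:

```python
K_BITS = 4

def pack_planes(indices):
    """Pack 32 K_BITS-bit indices into K_BITS uint32 bit-plane words."""
    planes = [0] * K_BITS
    for j, idx in enumerate(indices):
        for b in range(K_BITS):
            planes[b] |= ((idx >> b) & 1) << j
    return planes

def extract_index(planes, j):
    """Mirror of the kernel's inner loop: rebuild index j from the planes."""
    idx = 0
    for b in range(K_BITS):
        idx |= ((planes[b] >> j) & 1) << b
    return idx
```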

### A value loading strategy

For M=1-4, the A values are tiny. Two options:

**Option A (simpler, recommended for the first version):**
Pre-load ALL A values for the full K_dim into registers at kernel start.
For M=1, K=2048: 2048 fp16 values = 4 KB. At M=4: 16 KB. This exceeds
register file capacity, so it spills to local memory (L1-cached, effectively
free for sequential access). Access pattern: `A_vals[m][k]`.

**Option B (more efficient):**
Load A values per k_tile into registers. For M=1, TILE_K=64: 64 fp16
values = 128 bytes = 32 registers. Fits easily. Load from global memory at
the start of each k_tile iteration (while waiting for the cp.async of the
B data).

Option B is better for register pressure. Implementation sketch:
```cuda
// At the start of each k_tile iteration:
half A_local[M_VAL][TILE_K];
for (int m = 0; m < M_VAL; m++)
    for (int i = 0; i < TILE_K; i += 8) {
        // Vectorized load: 8 halves = 16 bytes
        int k = kt * TILE_K + i;
        if (k + 7 < K_dim)
            *(int4*)&A_local[m][i] = *(const int4*)&A[m * K_dim + k];
        else
            // Ragged tail: element-wise with bounds check, zero-fill past K_dim
            for (int t = 0; t < 8; t++)
                A_local[m][i + t] = (k + t < K_dim) ? A[m * K_dim + k + t]
                                                    : __float2half(0.0f);
    }
```

---

## 4. Work Distribution

### Single-matrix (dense layers)

Grid: one block per n_tile. For N=5120: 40 blocks. Each block processes
all of K_dim for its 128 output columns.

For shapes with few n_tiles (N=512 → 4 blocks), use K-splitting: launch
more blocks, each handling a subset of k_tiles, and atomicAdd the partial
results into a workspace. Same split-K logic as the production MMA kernel.
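
One possible sizing policy — an assumption for illustration, not
necessarily what the launcher will ship — is to double k_splits until the
grid covers the SM count, capped by the number of k_tiles:

```python
def pick_k_splits(n_tiles, k_tiles, num_sms=128):
    """Smallest power-of-two k_splits with n_tiles * k_splits >= num_sms,
    never exceeding k_tiles (illustrative heuristic)."""
    k_splits = 1
    while n_tiles * k_splits < num_sms and k_splits < k_tiles:
        k_splits *= 2
    return min(k_splits, k_tiles)
```

For N=512 and K_dim=2048 (4 n_tiles, 32 k_tiles) this yields 32 splits,
i.e. 128 blocks; for a grid that already fills the SMs it stays at 1.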

### Grouped (MoE experts)

Same as `kbit_grouped_gemm_prod`: a persistent kernel with work_offsets and
a binary search to find expert_id. Each work item is one (expert, n_tile)
pair. No split-K (grouping provides enough parallelism).

The launcher computes work_offsets on the CPU side (tiny: num_experts+1
ints copied from the device), same pattern as the existing grouped GEMM.
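
A host-side sketch of this bookkeeping (names are illustrative; the real
launcher derives per-expert tile counts from expert_offsets on the device):

```python
import bisect

def build_work_offsets(n_tiles_per_expert):
    """work_offsets[e] = first work-item id of expert e; one work item per
    (expert, n_tile) pair. Length is num_experts + 1."""
    offsets = [0]
    for n_tiles in n_tiles_per_expert:
        offsets.append(offsets[-1] + n_tiles)
    return offsets

def find_expert(work_offsets, work_id):
    """Binary search: the expert whose half-open range contains work_id."""
    return bisect.bisect_right(work_offsets, work_id) - 1
```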

---

## 5. Template Parameters

```cuda
template <int K_BITS, int M_VAL, typename scalar_t>
__global__ void kbit_scalar_gemv(
    const scalar_t* __restrict__ A,
    const unsigned int* __restrict__ B_packed,
    const unsigned char* __restrict__ B_absmax,
    const float* __restrict__ codebook,
    scalar_t* __restrict__ C,
    float* __restrict__ C_workspace,  // for split-K
    int* __restrict__ tile_counters,  // for split-K
    const int M, const int K_dim, const int N,
    const int k_splits, const int total_work
);
```

- `K_BITS`: 2, 3, 4, 5 (compile-time, same as the MMA kernel)
- `M_VAL`: 1, 2, 3, 4 (compile-time, controls unrolling)
- `scalar_t`: half, __nv_bfloat16

Grouped variant:
```cuda
template <int K_BITS, int M_VAL, typename scalar_t>
__global__ void kbit_grouped_scalar_gemv(
    const scalar_t* __restrict__ A_concat,
    const unsigned int* __restrict__ B_packed_all,
    const unsigned char* __restrict__ B_absmax_all,
    const float* __restrict__ codebook,
    scalar_t* __restrict__ C_concat,
    const int* __restrict__ expert_offsets,
    const int* __restrict__ work_offsets,
    const int K_dim, const int N,
    const int num_experts, const int total_work
);
```

### Instantiations needed

For each K in {2,3,4,5} × M_VAL in {1,2,4} × scalar_t in {half, bf16}:
- 4 × 3 × 2 = 24 instantiations per kernel variant
- Start with K=4, M=1, fp16 only for initial testing (1 instantiation)
- Add the rest after correctness is verified
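
The count enumerates mechanically, which is also handy for generating the
wrapper names (the `_m{m}` suffix here is an assumed naming scheme, not the
existing one):

```python
from itertools import product

K_VALS, M_VALS, DTYPES = (2, 3, 4, 5), (1, 2, 4), ("fp16", "bf16")
names = [f"kbit_scalar_gemv_{dt}_k{k}_m{m}"  # hypothetical suffix scheme
         for k, m, dt in product(K_VALS, M_VALS, DTYPES)]
print(len(names))  # 24 instantiations per kernel variant
```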

---

## 6. Implementation Steps

### Step 1: CUDA kernel (`csrc/ops.cu`)

Add after the grouped GEMM code (around line 2547):

1. `kbit_scalar_gemv` kernel function (single-matrix, with split-K)
2. `kbit_grouped_scalar_gemv` kernel function (grouped, no split-K)
3. `kbitScalarGemvLaunch` launcher (handles split-K grid sizing)
4. `kbitScalarGemv` public entry (M_VAL dispatch + SM query)
5. `kbitGroupedScalarGemv` public entry (M_VAL dispatch + work_offsets)
6. Template instantiations at the end of the file

### Step 2: C interface (`csrc/pythonInterface.cpp`)

Add forward declarations and extern "C" wrappers:
```cpp
// Forward declarations
#define MAKE_KBIT_SCALAR_GEMV_DECL(K) \
    void kbit_scalar_gemv_fp16_k##K(...); \
    void kbit_scalar_gemv_bf16_k##K(...);

// Extern C wrappers
#define MAKE_CKBIT_SCALAR_GEMV(K) \
    void ckbit_scalar_gemv_fp16_k##K(...) { \
        kbit_scalar_gemv_fp16_k##K(...); \
    }
```

Same pattern for the grouped variant.

### Step 3: Python op registration (`bitsandbytes/_ops.py`)

Register two new ops:
```python
torch.library.define("bitsandbytes::kbit_scalar_gemv",
    "(Tensor A, Tensor B_packed, Tensor B_absmax, Tensor codebook, "
    "int K_dim, int N, int k) -> Tensor")

torch.library.define("bitsandbytes::kbit_grouped_scalar_gemv",
    "(Tensor A, Tensor B_packed_all, Tensor B_absmax_all, Tensor codebook, "
    "Tensor expert_offsets, int K_dim, int N, int k, int num_experts) -> Tensor")
```

### Step 4: Python dispatch (`bitsandbytes/backends/cuda/ops.py`)

Implement the CUDA backend kernels. Key point: auto-select the M_VAL
template based on the actual M:
```python
@register_kernel("bitsandbytes::kbit_scalar_gemv", "cuda")
def _(A, B_packed, B_absmax, codebook, K_dim, N, k):
    M = A.shape[0]
    assert M <= 4
    # Allocate output, workspace, tile_counters
    # Call ckbit_scalar_gemv_{dtype}_k{k}
```

### Step 5: Correctness test

Add to `tests/test_kbit_quantization.py`:
```python
@pytest.mark.parametrize("K_dim,N", [(2048, 512), (2048, 5120), (5120, 2048)])
@pytest.mark.parametrize("M", [1, 2, 4])
@pytest.mark.parametrize("k", [4])
def test_scalar_gemv_correctness(K_dim, N, M, k):
    # Quantize weight, compute reference via dequant + torch.mm
    # Compare against kbit_scalar_gemv output
    # Tolerance: same as existing GEMM tests
```
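
The "dequant + matmul" oracle can be expressed along these lines in NumPy
(hypothetical flat layout — a per-column index matrix plus one absmax per
32 elements; the real tiled format differs):

```python
import numpy as np

def reference_gemv(A, idx, absmax, codebook):
    """Dequantize, then matmul: the oracle the kernel output is checked against."""
    scales = np.repeat(absmax, 32, axis=1)  # one scale per 32-element k-block
    W = codebook[idx] * scales              # dequantized weight, shape [N, K_dim]
    return A @ W.T                          # C[M, N]
```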

### Step 6: Benchmark

Extend `benchmarks/bench_crossover.py` to include the scalar GEMV in the
comparison table. Key comparison: scalar GEMV vs cuBLAS at M=1, 2, 4.

---

## 7. Expected Performance

Based on the roofline analysis (see `optimization.md` Section 6):

| Shape | Scalar est (M=1) | cuBLAS (M=1) | Projected speedup |
|-------|-----------------:|-------------:|------------------:|
| gate/up 2048×5120 | ~5 us | ~25 us | ~5x |
| down 5120×2048 | ~5 us | ~25 us | ~5x |
| Q proj 2048×4096 | ~4 us | ~17 us | ~4x |
| shared gate/up 2048×10240 | ~10 us | ~55 us | ~5.5x |
| MoE expert 2048×512 (×8) | ~4 us | ~17 us | ~4x |

Full model per-layer (all projections combined):
- Qwen3 batch=1: ~27 us kbit vs ~141 us cuBLAS = **5.3x**
- GLM4.7 batch=1: ~37 us kbit vs ~157 us cuBLAS = **4.3x**

These estimates use a 1.8x overhead factor over the theoretical bandwidth
minimum. The actual speedup depends on the achieved bandwidth efficiency.
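
The projection is just weight traffic over bandwidth, times the overhead
factor. A sketch of the arithmetic — the ~2 TB/s figure is an assumed
bandwidth for illustration, not a measurement:

```python
def gemv_time_us(K_dim, N, k_bits, peak_bw_gbs, overhead=1.8):
    """Estimated M=1 kernel time in microseconds: weight bytes (bit-planes
    plus one absmax byte per 32 weights) over peak bandwidth, times the
    empirical overhead factor."""
    weight_bytes = K_dim * N * k_bits / 8 + K_dim * N / 32
    return weight_bytes / (peak_bw_gbs * 1e9) * overhead * 1e6

print(round(gemv_time_us(2048, 5120, 4, 2000), 1))  # ~5.0 us for gate/up
```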

---

## 8. Key Differences from the MMA Kernel

| Aspect | MMA kernel | Scalar kernel |
|--------|-----------|---------------|
| Inner compute | `mma.sync.aligned.m16n8k16` | Scalar FMA loop |
| A data | Shared memory + ldmatrix + XOR swizzle | Registers (direct global load) |
| B dequant output | Packed into MMA fragments (uint32) | Float value, used directly |
| Thread→output mapping | Complex (gid/tid fragment layout) | Simple (thread % 128 = column) |
| M handling | TILE_M=16, zero-padded | M_VAL template, no padding |
| Registers/thread | ~128 (MMA fragments) | ~30-40 |
| Occupancy | Low (register-limited) | High |
| Shared memory | A tile + B tile + absmax (~15-20 KB) | B tile + absmax only (~9 KB) |

---

## 9. Risks and Mitigations

1. **Shared memory bank conflicts on B reads.** Multiple threads read the
   same column's bit-plane words from shared memory. Mitigation: with 2
   threads per column (K-split), it is at worst a 2-way conflict — and
   when both threads read the exact same word it is a broadcast, not a
   conflict. Acceptable.

2. **Codebook shuffle across warp boundaries.** `__shfl_sync` only works
   within a warp. Threads in different warps processing the same column
   need independent codebook registers. This is already handled: each
   thread loads `cb = codebook[lane_id]` at kernel start, so every warp
   holds its own copy.

3. **Register spill at M=4.** Each thread needs 4 accumulators + A values
   + packed words + temporaries. Estimate: ~40 registers. Fine for sm_89
   (255 max registers per thread).

4. **Split-K reduction overhead.** For single-matrix with N=512 (4 blocks),
   split-K is needed to fill 128 SMs. The atomicAdd overhead of the split-K
   reduction adds ~5-10 us, still much faster than the MMA kernel. For
   grouped dispatch, split-K is unnecessary (enough experts to fill the SMs).