Commit a91c313

TimDettmers and claude committed

docs: Update progress report with Stages 4-6 completion

Documents the cp.async pipeline, split-K, bf16 support, ldmatrix swizzle, and benchmark results. Includes optimization opportunities for further work (multi-M-block tiling, larger N_BLOCKS, C output staging, persistent kernel).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

1 parent 27cf6a2 · commit a91c313

1 file changed: progress.md (+98 −0)
@@ -1970,3 +1970,101 @@ Test criterion: output matches Stage 3 bit-for-bit.

- `cp_async_wait<1>()` inside the loop waits for the computing stage
- The first tile is prefetched before the loop starts
- `cp_async_wait<0>()` after the loop drains the pipeline
---

## 36. Implementation Progress: Stages 4-6 Complete

### Stage 4: cp.async Double-Buffered Pipeline (commit 9b155d3)

Replaces synchronous global→shared memory loads with `cp.async` double buffering. The B tile and absmax values are loaded via `cp.async.cg.shared.global` (16-byte copies, cached in L2 only); the A tile is still loaded synchronously because it needs M/K_dim bounds checking. Output is bit-exact with Stage 3 for all K values.

**Tests:** 13 new tests → 89 total (all pass).
### Stage 5: Split-K GEMM (commit fdcec9c)

Adds split-K support: multiple blocks share an output tile, each handling a subset of k-tiles. Partial sums are accumulated via `atomicAdd` into an fp32 workspace. The grid is 2D for k_chunks=1 and 3D for k_chunks>1. The last contributor to each tile (detected via an atomic tile counter) converts the fp32 partial sums to fp16 output.

**Tests:** 21 new tests → 110 total (all pass).
### Stage 6: Production Kernel with bf16, ldmatrix, Swizzle, Benchmarks

#### bf16 Support (commit 24406d2)

The new production kernel `kbit_gemm_prod` is templated on `scalar_t` (`half` or `__nv_bfloat16`) and uses `if constexpr` to select the matching MMA PTX instruction:

- fp16: `mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32`
- bf16: `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32`

Helper structs `ScalarOps<T>`, `pack_two<T>`, and `mma_m16n8k16<T>` abstract the type-specific operations. Eight kernel variants are instantiated (4 K values × 2 dtypes).

The fp16 path matches Stage 5 split-K output bit-for-bit; the bf16 path matches the Python reference within tolerance for all K values.

**Tests:** 29 new tests → 139 total (all pass).
2013+
#### ldmatrix + XOR Swizzle (commit b64bb91)

Replaced 8 element-by-element shared memory reads per A fragment with a single `ldmatrix.sync.aligned.m8n8.x4.shared.b16` instruction.

**The bank conflict problem:** Without swizzling, the A tile is stored in shared memory with a row stride of TILE_K = 64 halves (128 bytes), so every row starts at the same bank (the stride is a multiple of 128 bytes, the bank repeat distance). This causes 8-way bank conflicts during ldmatrix.

**The fix:** an XOR-based swizzle at 8-half (16-byte) granularity:

```
col_group = col / 8
swizzled_group = col_group ^ (row % 8)
swizzled_col = swizzled_group * 8 + (col % 8)
```

The swizzle is applied both when writing the A tile to shared memory and in the ldmatrix address calculation. The XOR spreads the 8 threads of an ldmatrix group across 8 different banks (zero conflicts).

Output is mathematically identical (verified by tests).
2036+
#### Benchmark Results (commit 27cf6a2)

RTX 4090, K=4 (4-bit), fp16, k_chunks=1:

| M | K_dim | N | kbit (µs) | kbit TFLOPS | cuBLAS (µs) | Speedup |
|---:|------:|------:|----------:|------------:|------------:|--------:|
| 1 | 4096 | 4096 | 109 | 0.31 | 43 | 0.39x |
| 1 | 4096 | 11008 | 82 | 1.10 | 128 | **1.56x** |
| 4 | 4096 | 11008 | 100 | 3.61 | 121 | **1.21x** |
| 4 | 4096 | 4096 | 92 | 1.46 | 22 | 0.24x |

**Analysis:** The kernel wins in the memory-bandwidth-bound regime (M=1, large N), where reading 4x less weight data matters. It loses in compute-bound cases because the current tile is small (TILE_M=16, only 2 N-blocks per warp).
2051+
### Optimization Opportunities for Further Work

1. **Multi-M-block tiling:** Template on M_BLOCKS (1-4) so TILE_M scales to 32/48/64. This is the biggest performance lever for M>1.
2. **Larger N_BLOCKS:** Use more of the warp's N-dimension capacity.
3. **C output staging through shared memory:** Coalesce the scattered fragment writes to global memory (currently each thread writes to non-contiguous rows).
4. **Persistent kernel:** Replace the 3D grid with a persistent kernel that loops over work items, reducing launch overhead and enabling better SM utilization for small tile counts.
### Commit History (Stages 4-6)

```
27cf6a2 Add kbit GEMM benchmark script
b64bb91 Add ldmatrix + XOR swizzle for A-fragment loading in production kernel
24406d2 Add Stage 6 production kernel with bf16 support (139 tests pass)
fdcec9c Add Stage 5 split-K GEMM kernel (110 tests pass)
9b155d3 Add Stage 4 pipelined GEMM kernel with cp.async double-buffering (89 tests pass)
```
