Test criterion: output matches Stage 3 bit-for-bit.

- `cp_async_wait<1>()` inside the loop waits for the computing stage
- The first tile is prefetched before the loop starts
- `cp_async_wait<0>()` after the loop drains the pipeline
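The loop structure described by these bullets can be sketched as follows. This is a hypothetical outline, not the actual kernel: `prefetch_tile`, `cp_async_commit`, `compute_tile`, and `num_k_tiles` are stand-in names for the kernel's load, commit, and MMA code.

```cuda
// Two-stage cp.async pipeline (sketch). prefetch_tile(stage, t) issues
// cp.async copies for k-tile t into shared-memory buffer `stage`;
// compute_tile(stage) runs the MMAs on that buffer.
prefetch_tile(0, 0);                 // prefetch the first tile before the loop
cp_async_commit();                   // cp.async.commit_group

for (int t = 0; t < num_k_tiles; ++t) {
    const int cur = t & 1, nxt = (t + 1) & 1;
    if (t + 1 < num_k_tiles) {
        prefetch_tile(nxt, t + 1);   // overlap the next load with compute
        cp_async_commit();
    }
    cp_async_wait<1>();              // wait until tile `cur` has landed,
    __syncthreads();                 // leaving at most 1 group in flight
    compute_tile(cur);
    __syncthreads();
}
cp_async_wait<0>();                  // drain all outstanding copies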

---

## 36. Implementation Progress: Stages 4-6 Complete

### Stage 4: cp.async Double-Buffered Pipeline (commit 9b155d3)

Replaces the synchronous global→shared memory loads with `cp.async` double buffering.
The B tile and absmax values are loaded via `cp.async.cg.shared.global` (16-byte copies, cached in L2 only).
The A tile is still loaded synchronously (it needs M/K_dim bounds checking).
Output is bit-for-bit identical to Stage 3 for all K values.
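The 16-byte copy primitive can be sketched as a small wrapper like the one below (the kernel's actual helper may be shaped differently; the name `cp_async_cg_16` is hypothetical):

```cuda
#include <cstdint>

// 16-byte async copy, global -> shared, with the .cg (cache-global)
// qualifier: the line is cached in L2 but bypasses L1.
__device__ inline void cp_async_cg_16(void* smem_dst, const void* gmem_src) {
    const uint32_t dst =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
                 :: "r"(dst), "l"(gmem_src));
}
```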

**Tests:** 13 new tests → 89 total (all pass).

### Stage 5: Split-K GEMM (commit fdcec9c)

Adds split-K support: multiple blocks share an output tile, each handling a
subset of the k-tiles. Partial sums are accumulated via `atomicAdd` into an fp32
workspace. The grid is 2D for `k_chunks=1` and 3D for `k_chunks>1`. The last
contributor (detected via an atomic tile counter) converts the fp32 partial
sums to the fp16 output.

**Tests:** 21 new tests → 110 total (all pass).

### Stage 6: Production Kernel with bf16, ldmatrix, Swizzle, Benchmarks

#### bf16 Support (commit 24406d2)

The new production kernel `kbit_gemm_prod` is templated on `scalar_t` (`half` or
`__nv_bfloat16`) and uses `if constexpr` to select the right MMA PTX instruction:
- fp16: `mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32`
- bf16: `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32`

Helper structs `ScalarOps<T>`, `pack_two<T>`, and `mma_m16n8k16<T>` abstract the
type-specific operations. Eight kernel variants are instantiated (4 K values × 2 dtypes).
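The compile-time dispatch can be sketched as below (operand packing and fragment layout omitted; the real `mma_m16n8k16<T>` helper may be organized differently):

```cuda
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cstdint>
#include <type_traits>

// Selects the fp16 or bf16 m16n8k16 MMA at compile time. a/b hold packed
// pairs of 16-bit values; c/d are the fp32 accumulator fragments.
template <typename scalar_t>
__device__ inline void mma_m16n8k16_sketch(float d[4], const uint32_t a[4],
                                           const uint32_t b[2],
                                           const float c[4]) {
    if constexpr (std::is_same_v<scalar_t, __half>) {
        asm volatile(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
            "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
            : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
            : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
              "r"(b[0]), "r"(b[1]),
              "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    } else {
        static_assert(std::is_same_v<scalar_t, __nv_bfloat16>);
        asm volatile(
            "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
            "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
            : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
            : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
              "r"(b[0]), "r"(b[1]),
              "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    }
}
```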

The fp16 path matches the Stage 5 split-K output bit-for-bit.
The bf16 path matches the Python reference within tolerance for all K values.

**Tests:** 29 new tests → 139 total (all pass).

#### ldmatrix + XOR Swizzle (commit b64bb91)

Replaced the 8 element-by-element shared memory reads per A fragment with a single
`ldmatrix.sync.aligned.m8n8.x4.shared.b16` instruction.

**The bank conflict problem:** Without a swizzle, the A tile is stored in shared
memory with a stride of TILE_K=64 halves (128 bytes), so every row starts at
the same bank (the stride is a multiple of 128 bytes, the bank repeat distance).
This causes 8-way bank conflicts during `ldmatrix`.

**The fix:** An XOR-based swizzle at 8-half (16-byte) granularity:
```
col_group = col / 8
swizzled_group = col_group ^ (row % 8)
swizzled_col = swizzled_group * 8 + (col % 8)
```

The swizzle is applied both when writing the A tile to shared memory and in the
ldmatrix address calculation. The XOR distributes the 8 threads of an ldmatrix
group across different banks (zero conflicts).

Output is mathematically identical (verified by tests).

#### Benchmark Results (commit 27cf6a2)

RTX 4090, K=4 (4-bit), fp16, k_chunks=1:

| M | K_dim | N | kbit (µs) | kbit TFLOPS | cuBLAS (µs) | Speedup |
|---:|------:|------:|----------:|------------:|------------:|--------:|
| 1 | 4096 | 4096 | 109 | 0.31 | 43 | 0.39x |
| 1 | 4096 | 11008 | 82 | 1.10 | 128 | **1.56x** |
| 4 | 4096 | 11008 | 100 | 3.61 | 121 | **1.21x** |
| 4 | 4096 | 4096 | 92 | 1.46 | 22 | 0.24x |

**Analysis:** The kernel wins in the memory-bandwidth-bound regime (M=1, large
N), where reading 4× less weight data matters. It loses in compute-bound cases
because the current tile is small (TILE_M=16, only 2 N-blocks per warp).

### Optimization Opportunities for Further Work

1. **Multi-M-block tiling:** Template on M_BLOCKS (1-4) so TILE_M scales to
   32/48/64. This is the biggest performance lever for M>1.
2. **Larger N_BLOCKS:** Use more of the warp's N-dimension capacity.
3. **C output staging through shared memory:** Coalesce the scattered fragment
   writes to global memory (currently each thread writes to non-contiguous rows).
4. **Persistent kernel:** Replace the 3D grid with a persistent kernel that
   loops over work items, reducing launch overhead and enabling better SM
   utilization for small tile counts.

### Commit History (Stages 4-6)

```
27cf6a2 Add kbit GEMM benchmark script
b64bb91 Add ldmatrix + XOR swizzle for A-fragment loading in production kernel
24406d2 Add Stage 6 production kernel with bf16 support (139 tests pass)
fdcec9c Add Stage 5 split-K GEMM kernel (110 tests pass)
9b155d3 Add Stage 4 pipelined GEMM kernel with cp.async double-buffering (89 tests pass)
```