@@ -63,27 +63,26 @@ __device__ __forceinline__ uint32_t pack_8_nibbles_slow(const unsigned char* dat
6363}
6464
6565// ============================================================================
66- // Optimized NVFP4 GEMM kernel
66+ // Shared-memory NVFP4 GEMM kernel
6767//
68- // Key optimizations over kGemmNVFP4_simple:
69- // 1. Vectorized uint32 loads: Each MMA register's 8 nibbles map to 4 consecutive
70- // bytes in memory. Load as uint32 instead of 8 individual nibble extractions.
71- // 2. Multi-N per warp: Each warp computes m16 x nN_TILE_PER_WARP (4 MMA
72- // instructions per K-step), reusing A registers across N-slices.
73- // 3. Shared memory A tile: All warps in a block share the same m16 tile of A.
74- // A is loaded cooperatively into shared memory, then each warp reads its
75- // registers from smem. This gives N_WARPS x reuse of A bandwidth.
68+ // Key optimizations:
69+ // 1. Cooperative tiling: All threads cooperatively load A/B/SFA/SFB tiles from
70+ // global memory into shared memory with coalesced access patterns.
71+ // 2. Data reuse: A tile shared across N_WARPS (4x saving), B tile shared
72+ // across M_WARPS (2x saving). Total ~2.4x bandwidth reduction.
73+ // 3. Fast register packing: MMA registers read from smem as uint32 loads.
74+ // SFA/SFB packed registers loaded as single uint32 (all 4 bytes are
75+ // consecutive in the same row — proven by CuTE layout analysis).
76+ // 4. Vectorized global loads: A uses uint32 (4B), B uses uint4 (16B).
7677//
77- // Register layout (derived from CuTE ALayout/BLayout analysis ):
78+ // Register layout (from CuTE ALayout/BLayout):
7879// A reg[i] = A[tile_m + 2*t1 + (i&1), k_start + t0*8 + (i>>1)*32 .. +7]
7980// B reg[i] = B[tile_n + t1, k_start + t0*8 + i*32 .. +7]
80- // where t0 = lane%4, t1 = lane/4 (CuTE thread decomposition )
81- // Each register's 8 nibbles = 4 consecutive packed bytes in memory.
81+ // SFA packed = SFA[actual_m, k_blk 0..3] (consecutive in memory)
82+ // SFB packed = SFB[tile_n + t1, k_blk 0..3] (consecutive in memory)
8283//
83- // Block/warp configuration:
84- // 4 warps per block, block tile = m16 x n32
85- // Each warp handles a different n8 slice, all share same m16
86- // Shared memory: A tile (512 bytes) + SFA (64 bytes) per K-step
84+ // Block tile: m32 x n128 (M_WARPS=2, N_WARPS=4, N_TILES_PER_WARP=4)
85+ // Shared memory per K-step: 1024 + 4096 + 128 + 512 = 5760 bytes
8786// ============================================================================
8887
8988// N-tiles per warp: each warp computes m16 x (N_TILES_PER_WARP * 8)
@@ -93,173 +92,197 @@ __device__ __forceinline__ uint32_t pack_8_nibbles_slow(const unsigned char* dat
9392#define N_WARPS 4
9493#define WARPS_PER_BLOCK (M_WARPS * N_WARPS) // 8
9594
96- // 256 threads, target 4 blocks/SM (limit regs to 48 via maxrregcount)
97- __global__ __launch_bounds__ (WARPS_PER_BLOCK * 32 , 4 ) void kGemmNVFP4_opt(
95+ // Block tile dimensions
96+ #define BLOCK_M_DIM (M_WARPS * 16 ) // 32
97+ #define BLOCK_N_DIM (N_WARPS * N_TILES_PER_WARP * 8 ) // 128
98+
99+ // Shared memory sizes (bytes per K-step)
100+ #define SMEM_A_BYTES (BLOCK_M_DIM * 32 ) // 1024
101+ #define SMEM_B_BYTES (BLOCK_N_DIM * 32 ) // 4096
102+ #define SMEM_SFA_BYTES (BLOCK_M_DIM * 4 ) // 128
103+ #define SMEM_SFB_BYTES (BLOCK_N_DIM * 4 ) // 512
104+ #define SMEM_TOTAL (SMEM_A_BYTES + SMEM_B_BYTES + SMEM_SFA_BYTES + SMEM_SFB_BYTES)
105+
106+ // 256 threads, target 4 blocks/SM for occupancy
107+ __global__ __launch_bounds__ (WARPS_PER_BLOCK * 32 , 4 ) void kGemmNVFP4_smem(
98108 const unsigned char * __restrict__ A, // M x K/2 packed FP4 (row-major)
99109 const unsigned char * __restrict__ B, // N x K/2 packed FP4 (B transposed, row-major)
100110 const unsigned char * __restrict__ SFA, // M x K/16 UE4M3 scales
101111 const unsigned char * __restrict__ SFB, // N x K/16 UE4M3 scales
102112 float * __restrict__ D, // M x N output (F32)
103113 int M, int N, int K
104114) {
105- // Block tile: m(M_WARPS*16) x n(N_WARPS * N_TILES_PER_WARP * 8)
106- // = m32 x n128 for 2x4 warps
107- const int BLOCK_M = M_WARPS * 16 ;
108- const int BLOCK_N = N_WARPS * N_TILES_PER_WARP * 8 ;
109-
110- int warp_in_block = threadIdx .x / 32 ;
111- int lane_id = threadIdx .x % 32 ;
112-
113- // 2D warp mapping: m_warp along M, n_warp along N
114- int m_warp = warp_in_block / N_WARPS; // 0..(M_WARPS-1)
115- int n_warp = warp_in_block % N_WARPS; // 0..(N_WARPS-1)
116-
117- // Block-level tile position
118- int tile_m = blockIdx .y * BLOCK_M + m_warp * 16 ;
119- int tile_n_base = blockIdx .x * BLOCK_N;
120-
121- if (tile_m >= M)
122- return ;
123-
124- // This warp's N offset within the block
125- int warp_n_base = tile_n_base + n_warp * N_TILES_PER_WARP * 8 ;
126-
127- // CuTE thread decomposition: t0 = lane%4 (0-3), t1 = lane/4 (0-7)
128- int t0 = lane_id % 4 ;
129- int t1 = lane_id / 4 ;
130-
131- // Precompute A row indices for this thread's registers
132- // reg[0,2] → row0 = tile_m + 2*t1, reg[1,3] → row1 = tile_m + 2*t1 + 1
133- int a_row0 = tile_m + 2 * t1;
134- int a_row1 = a_row0 + 1 ;
135-
136- int half_K = K / 2 ;
137- int scale_stride_K = K / 16 ;
138-
139- // Accumulators: N_TILES_PER_WARP * 4 floats per thread
115+ // Shared memory: 16-byte aligned for uint4 stores
116+ __shared__ __align__ (16 ) unsigned char smem[SMEM_TOTAL]; // 5760 bytes
117+ unsigned char * smem_A = smem;
118+ unsigned char * smem_B = smem + SMEM_A_BYTES;
119+ unsigned char * smem_SFA = smem + SMEM_A_BYTES + SMEM_B_BYTES;
120+ unsigned char * smem_SFB = smem + SMEM_A_BYTES + SMEM_B_BYTES + SMEM_SFA_BYTES;
121+
122+ const int tid = threadIdx .x ;
123+ const int warp_in_block = tid / 32 ;
124+ const int lane_id = tid % 32 ;
125+ const int m_warp = warp_in_block / N_WARPS; // 0..1
126+ const int n_warp = warp_in_block % N_WARPS; // 0..3
127+
128+ const int block_m = blockIdx .y * BLOCK_M_DIM;
129+ const int block_n = blockIdx .x * BLOCK_N_DIM;
130+ const int tile_m = block_m + m_warp * 16 ;
131+ const int warp_n_base = block_n + n_warp * N_TILES_PER_WARP * 8 ;
132+
133+ const int t0 = lane_id % 4 ;
134+ const int t1 = lane_id / 4 ;
135+ const int half_K = K / 2 ;
136+ const int scale_K = K / 16 ;
137+
138+ // Accumulators
140139 float acc[N_TILES_PER_WARP][4 ];
141140#pragma unroll
142141 for (int nt = 0 ; nt < N_TILES_PER_WARP; nt++) {
143- acc[nt][0 ] = 0 .0f ;
144- acc[nt][1 ] = 0 .0f ;
145- acc[nt][2 ] = 0 .0f ;
146- acc[nt][3 ] = 0 .0f ;
142+ acc[nt][0 ] = acc[nt][1 ] = acc[nt][2 ] = acc[nt][3 ] = 0 .0f ;
147143 }
148144
145+ // Precompute smem row indices for A register reads
146+ const int a_local_row0 = m_warp * 16 + 2 * t1;
147+ const int a_local_row1 = a_local_row0 + 1 ;
148+
149+ // Precompute SFA row for this thread (all 4 bytes come from the same row)
150+ // sf_tidx = (lane%2)*8 + lane/4; cute_m_0 = sf_tidx % 16
151+ // actual_m = (cute_m_0 % 8)*2 + cute_m_0/8
152+ const int sf_tidx = (lane_id % 2 ) * 8 + (lane_id / 4 );
153+ const int cute_sf_m0 = sf_tidx % 16 ;
154+ const int sfa_local_row = m_warp * 16 + (cute_sf_m0 % 8 ) * 2 + cute_sf_m0 / 8 ;
155+
149156 // K-loop
150157 for (int k_start = 0 ; k_start < K; k_start += 64 ) {
151- // ---- Load A registers (4 x uint32) ----
152- // reg[0] = A[row0, k_start + t0*8 + 0..7] → 4 bytes at row0*K/2 + (k_start+t0*8)/2
153- // reg[1] = A[row1, k_start + t0*8 + 0..7]
154- // reg[2] = A[row0, k_start + t0*8 + 32..39]
155- // reg[3] = A[row1, k_start + t0*8 + 32..39]
156- uint32_t a_regs[4 ];
157- int k_col_lo = k_start + t0 * 8 ;
158- int k_col_hi = k_col_lo + 32 ;
159-
160- // Fast path: no boundary check needed
161- bool a_row0_ok = (a_row0 < M);
162- bool a_row1_ok = (a_row1 < M);
163- bool k_lo_ok = (k_col_lo + 7 < K);
164- bool k_hi_ok = (k_col_hi + 7 < K);
165-
166- if (a_row0_ok && k_lo_ok) {
167- a_regs[0 ] = *(const uint32_t *)(A + a_row0 * half_K + k_col_lo / 2 );
168- } else {
169- a_regs[0 ] = pack_8_nibbles_slow (A, a_row0, k_col_lo, K, M, K);
170- }
171- if (a_row1_ok && k_lo_ok) {
172- a_regs[1 ] = *(const uint32_t *)(A + a_row1 * half_K + k_col_lo / 2 );
173- } else {
174- a_regs[1 ] = pack_8_nibbles_slow (A, a_row1, k_col_lo, K, M, K);
175- }
176- if (a_row0_ok && k_hi_ok) {
177- a_regs[2 ] = *(const uint32_t *)(A + a_row0 * half_K + k_col_hi / 2 );
178- } else {
179- a_regs[2 ] = pack_8_nibbles_slow (A, a_row0, k_col_hi, K, M, K);
180- }
181- if (a_row1_ok && k_hi_ok) {
182- a_regs[3 ] = *(const uint32_t *)(A + a_row1 * half_K + k_col_hi / 2 );
183- } else {
184- a_regs[3 ] = pack_8_nibbles_slow (A, a_row1, k_col_hi, K, M, K);
185- }
158+ const int k_byte = k_start / 2 ;
159+ const int k_scale = k_start / 16 ;
186160
187- // ---- Load SFA ----
188- // SFA layout: sf_thread_idx = (lane%2)*8 + (lane/4)
189- // Scale coord = sf_thread_idx + v*16 → cute_m = coord%16, k_blk = coord/16
190- // Remap: actual_m = (cute_m%8)*2 + cute_m/8
191- uint32_t sfa_packed = 0 ;
161+ // ================================================================
162+ // Phase 1: Cooperative load from global → shared memory
163+ // ================================================================
164+
165+ // ---- A tile: BLOCK_M×32 = 1024 bytes, 256 threads × 4 bytes each ----
192166 {
193- int sf_tidx = (lane_id % 2 ) * 8 + (lane_id / 4 );
194- for (int sv = 0 ; sv < 4 ; sv++) {
195- int sfe = sf_tidx + sv * 16 ;
196- int cute_sf_m = sfe % 16 ;
197- int sf_col = sfe / 16 ;
198- int sf_row = (cute_sf_m % 8 ) * 2 + cute_sf_m / 8 ;
199- int gm = tile_m + sf_row;
200- int gkb = k_start / 16 + sf_col;
201- unsigned char sf_val = 0 ;
202- if (gm < M && gkb < scale_stride_K) {
203- sf_val = SFA[gm * scale_stride_K + gkb];
167+ const int off = tid * 4 ; // byte offset in smem_A (0..1020)
168+ const int row = off >> 5 ; // off / 32 → local row (0..31)
169+ const int col = off & 31 ; // off % 32 → byte col (0,4,...,28)
170+ const int gm = block_m + row;
171+
172+ uint32_t val = 0 ;
173+ if (gm < M) {
174+ const int gaddr = gm * half_K + k_byte + col;
175+ if (k_byte + col + 3 < half_K) {
176+ val = *(const uint32_t *)(A + gaddr);
177+ } else {
178+ // K-boundary: byte-by-byte
179+ for (int b = 0 ; b < 4 ; b++) {
180+ if (k_byte + col + b < half_K)
181+ val |= ((uint32_t )A[gaddr + b]) << (b * 8 );
182+ }
204183 }
205- sfa_packed |= ((uint32_t )sf_val << (sv * 8 ));
206184 }
185+ *(uint32_t *)(smem_A + off) = val;
207186 }
208187
209- // ---- For each N-tile in this warp ----
210- #pragma unroll
211- for (int nt = 0 ; nt < N_TILES_PER_WARP; nt++) {
212- int this_tile_n = warp_n_base + nt * 8 ;
213- if (this_tile_n >= N)
214- break ;
215-
216- // Load B registers (2 x uint32)
217- // reg[0] = B[this_tile_n + t1, k_start + t0*8 + 0..7]
218- // reg[1] = B[this_tile_n + t1, k_start + t0*8 + 32..39]
219- uint32_t b_regs[2 ];
220- int b_row = this_tile_n + t1;
221- bool b_row_ok = (b_row < N);
222-
223- if (b_row_ok && k_lo_ok) {
224- b_regs[0 ] = *(const uint32_t *)(B + b_row * half_K + k_col_lo / 2 );
188+ // ---- B tile: BLOCK_N×32 = 4096 bytes, 256 threads × 16 bytes each ----
189+ {
190+ const int off = tid * 16 ; // byte offset in smem_B (0..4080)
191+ const int row = off >> 5 ; // local row (0..127)
192+ const int col = off & 31 ; // 0 or 16
193+ const int gn = block_n + row;
194+
195+ if (gn < N) {
196+ const int gaddr = gn * half_K + k_byte + col;
197+ if (k_byte + col + 15 < half_K) {
198+ *(uint4 *)(smem_B + off) = *(const uint4 *)(B + gaddr);
199+ } else {
200+ // K-boundary: byte-by-byte
201+ for (int b = 0 ; b < 16 ; b++) {
202+ smem_B[off + b] = (k_byte + col + b < half_K) ? B[gaddr + b] : 0 ;
203+ }
204+ }
225205 } else {
226- b_regs[0 ] = pack_8_nibbles_slow (B, b_row, k_col_lo, K, N, K);
206+ // N-boundary: zero-fill
207+ *(uint4 *)(smem_B + off) = make_uint4 (0 , 0 , 0 , 0 );
227208 }
228- if (b_row_ok && k_hi_ok) {
229- b_regs[1 ] = *(const uint32_t *)(B + b_row * half_K + k_col_hi / 2 );
230- } else {
231- b_regs[1 ] = pack_8_nibbles_slow (B, b_row, k_col_hi, K, N, K);
209+ }
210+
211+ // ---- SFA: BLOCK_M×4 = 128 bytes. First 32 threads load 4 bytes each ----
212+ if (tid < BLOCK_M_DIM) {
213+ const int gm = block_m + tid;
214+ uint32_t val = 0 ;
215+ if (gm < M) {
216+ const int base = gm * scale_K + k_scale;
217+ if (k_scale + 3 < scale_K) {
218+ val = *(const uint32_t *)(SFA + base);
219+ } else {
220+ for (int b = 0 ; b < 4 ; b++) {
221+ if (k_scale + b < scale_K)
222+ val |= ((uint32_t )SFA[base + b]) << (b * 8 );
223+ }
224+ }
232225 }
226+ *(uint32_t *)(smem_SFA + tid * 4 ) = val;
227+ }
233228
234- // Load SFB for this N-tile
235- // SFB layout: sf_thread_idx = lane/4 = t1
236- // coord = t1 + v*8, n = coord%8, k_blk = coord/8
237- uint32_t sfb_packed = 0 ;
238- {
239- for (int sv = 0 ; sv < 4 ; sv++) {
240- int sfe = t1 + sv * 8 ;
241- int sf_n = sfe % 8 ;
242- int sf_col = sfe / 8 ;
243- int gn = this_tile_n + sf_n;
244- int gkb = k_start / 16 + sf_col;
245- unsigned char sf_val = 0 ;
246- if (gn < N && gkb < scale_stride_K) {
247- sf_val = SFB[gn * scale_stride_K + gkb];
229+ // ---- SFB: BLOCK_N×4 = 512 bytes. First 128 threads load 4 bytes each ----
230+ if (tid < BLOCK_N_DIM) {
231+ const int gn = block_n + tid;
232+ uint32_t val = 0 ;
233+ if (gn < N) {
234+ const int base = gn * scale_K + k_scale;
235+ if (k_scale + 3 < scale_K) {
236+ val = *(const uint32_t *)(SFB + base);
237+ } else {
238+ for (int b = 0 ; b < 4 ; b++) {
239+ if (k_scale + b < scale_K)
240+ val |= ((uint32_t )SFB[base + b]) << (b * 8 );
248241 }
249- sfb_packed |= ((uint32_t )sf_val << (sv * 8 ));
250242 }
251243 }
244+ *(uint32_t *)(smem_SFB + tid * 4 ) = val;
245+ }
246+
247+ __syncthreads ();
248+
249+ // ================================================================
250+ // Phase 2: Read MMA registers from smem and compute
251+ // ================================================================
252+
253+ // A registers (shared across all N-tiles in this warp)
254+ uint32_t a_regs[4 ];
255+ a_regs[0 ] = *(const uint32_t *)(smem_A + a_local_row0 * 32 + t0 * 4 );
256+ a_regs[1 ] = *(const uint32_t *)(smem_A + a_local_row1 * 32 + t0 * 4 );
257+ a_regs[2 ] = *(const uint32_t *)(smem_A + a_local_row0 * 32 + t0 * 4 + 16 );
258+ a_regs[3 ] = *(const uint32_t *)(smem_A + a_local_row1 * 32 + t0 * 4 + 16 );
259+
260+ // SFA: single uint32 load (all 4 bytes are in the same row)
261+ uint32_t sfa_packed = *(const uint32_t *)(smem_SFA + sfa_local_row * 4 );
262+
263+ // Per-N-tile: read B and SFB from smem, execute MMA
264+ #pragma unroll
265+ for (int nt = 0 ; nt < N_TILES_PER_WARP; nt++) {
266+ int local_n = n_warp * N_TILES_PER_WARP * 8 + nt * 8 ;
267+ int b_row = local_n + t1;
268+
269+ // B registers: 2 × uint32 from smem
270+ uint32_t b_regs[2 ];
271+ b_regs[0 ] = *(const uint32_t *)(smem_B + b_row * 32 + t0 * 4 );
272+ b_regs[1 ] = *(const uint32_t *)(smem_B + b_row * 32 + t0 * 4 + 16 );
273+
274+ // SFB: single uint32 load (4 consecutive bytes at row t1)
275+ uint32_t sfb_packed = *(const uint32_t *)(smem_SFB + (local_n + t1) * 4 );
252276
253- // Execute MMA: accumulate into this N-tile's accumulators
254277 mma_nvfp4_m16n8k64 (acc[nt][0 ], acc[nt][1 ], acc[nt][2 ], acc[nt][3 ], a_regs[0 ], a_regs[1 ], a_regs[2 ], a_regs[3 ],
255278 b_regs[0 ], b_regs[1 ], acc[nt][0 ], acc[nt][1 ], acc[nt][2 ], acc[nt][3 ], sfa_packed, sfb_packed);
256279 }
280+
281+ __syncthreads (); // Barrier before next K-step's smem writes
257282 }
258283
259284 // ---- Write output ----
260285 // SM80_16x8_Row: octet = lane/4, quad = lane%4
261- // d[0] = C[octet*2, quad*2], d[1] = C[octet*2, quad*2+1]
262- // d[2] = C[octet*2+1, quad*2], d[3] = C[octet*2+1, quad*2+1]
263286 int octet = lane_id / 4 ;
264287 int quad = lane_id % 4 ;
265288 int out_row0 = tile_m + octet * 2 ;
@@ -434,20 +457,17 @@ __global__ void kGemmNVFP4_simple(
434457}
435458
436459// ============================================================================
437- // Host-side launcher — uses optimized kernel
460+ // Host-side launcher — uses shared memory kernel
438461// ============================================================================
439462extern " C" void cgemm_nvfp4 (
440463 const unsigned char * A, const unsigned char * B, const unsigned char * SFA, const unsigned char * SFB, float * D, int M,
441464 int N, int K
442465) {
443- const int BLOCK_M = M_WARPS * 16 ;
444- const int BLOCK_N = N_WARPS * N_TILES_PER_WARP * 8 ;
445-
446- int num_m_blocks = (M + BLOCK_M - 1 ) / BLOCK_M;
447- int num_n_blocks = (N + BLOCK_N - 1 ) / BLOCK_N;
466+ int num_m_blocks = (M + BLOCK_M_DIM - 1 ) / BLOCK_M_DIM;
467+ int num_n_blocks = (N + BLOCK_N_DIM - 1 ) / BLOCK_N_DIM;
448468
449469 dim3 grid (num_n_blocks, num_m_blocks);
450470 int threads_per_block = WARPS_PER_BLOCK * 32 ; // 256
451471
452- kGemmNVFP4_opt <<<grid, threads_per_block>>> (A, B, SFA, SFB, D, M, N, K);
472+ kGemmNVFP4_smem <<<grid, threads_per_block>>> (A, B, SFA, SFB, D, M, N, K);
453473}
0 commit comments