Skip to content

Commit d95c3ed

Browse files
committed
perf: Shared memory tiling for NVFP4 GEMM kernel
Replace per-thread scattered global loads with cooperative tile loading into shared memory. All 256 threads cooperatively load A/B/SFA/SFB tiles with coalesced access (uint32 for A, uint4 for B), then each thread reads its MMA registers from fast shared memory. Data reuse: A shared across N_WARPS (4x), B shared across M_WARPS (2x). SFA/SFB packed registers loaded as single uint32 (proven to be consecutive in memory). Total bandwidth reduction: ~2.4x vs previous per-warp global loads.
1 parent 940ac12 commit d95c3ed

File tree

1 file changed

+175
-155
lines changed

1 file changed

+175
-155
lines changed

csrc/kernels_nvfp4_sm120.cu

Lines changed: 175 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -63,27 +63,26 @@ __device__ __forceinline__ uint32_t pack_8_nibbles_slow(const unsigned char* dat
6363
}
6464

6565
// ============================================================================
66-
// Optimized NVFP4 GEMM kernel
66+
// Shared-memory NVFP4 GEMM kernel
6767
//
68-
// Key optimizations over kGemmNVFP4_simple:
69-
// 1. Vectorized uint32 loads: Each MMA register's 8 nibbles map to 4 consecutive
70-
// bytes in memory. Load as uint32 instead of 8 individual nibble extractions.
71-
// 2. Multi-N per warp: Each warp computes m16 x nN_TILE_PER_WARP (4 MMA
72-
// instructions per K-step), reusing A registers across N-slices.
73-
// 3. Shared memory A tile: All warps in a block share the same m16 tile of A.
74-
// A is loaded cooperatively into shared memory, then each warp reads its
75-
// registers from smem. This gives N_WARPS x reuse of A bandwidth.
68+
// Key optimizations:
69+
// 1. Cooperative tiling: All threads cooperatively load A/B/SFA/SFB tiles from
70+
// global memory into shared memory with coalesced access patterns.
71+
// 2. Data reuse: A tile shared across N_WARPS (4x saving), B tile shared
72+
// across M_WARPS (2x saving). Total ~2.4x bandwidth reduction.
73+
// 3. Fast register packing: MMA registers read from smem as uint32 loads.
74+
// SFA/SFB packed registers loaded as single uint32 (all 4 bytes are
75+
// consecutive in the same row — proven by CuTE layout analysis).
76+
// 4. Vectorized global loads: A uses uint32 (4B), B uses uint4 (16B).
7677
//
77-
// Register layout (derived from CuTE ALayout/BLayout analysis):
78+
// Register layout (from CuTE ALayout/BLayout):
7879
// A reg[i] = A[tile_m + 2*t1 + (i&1), k_start + t0*8 + (i>>1)*32 .. +7]
7980
// B reg[i] = B[tile_n + t1, k_start + t0*8 + i*32 .. +7]
80-
// where t0 = lane%4, t1 = lane/4 (CuTE thread decomposition)
81-
// Each register's 8 nibbles = 4 consecutive packed bytes in memory.
81+
// SFA packed = SFA[actual_m, k_blk 0..3] (consecutive in memory)
82+
// SFB packed = SFB[tile_n + t1, k_blk 0..3] (consecutive in memory)
8283
//
83-
// Block/warp configuration:
84-
// 4 warps per block, block tile = m16 x n32
85-
// Each warp handles a different n8 slice, all share same m16
86-
// Shared memory: A tile (512 bytes) + SFA (64 bytes) per K-step
84+
// Block tile: m32 x n128 (M_WARPS=2, N_WARPS=4, N_TILES_PER_WARP=4)
85+
// Shared memory per K-step: 1024 + 4096 + 128 + 512 = 5760 bytes
8786
// ============================================================================
8887

8988
// N-tiles per warp: each warp computes m16 x (N_TILES_PER_WARP * 8)
@@ -93,173 +92,197 @@ __device__ __forceinline__ uint32_t pack_8_nibbles_slow(const unsigned char* dat
9392
#define N_WARPS 4
9493
#define WARPS_PER_BLOCK (M_WARPS * N_WARPS) // 8
9594

96-
// 256 threads, target 4 blocks/SM (limit regs to 48 via maxrregcount)
97-
__global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 4) void kGemmNVFP4_opt(
95+
// Block tile dimensions
96+
#define BLOCK_M_DIM (M_WARPS * 16) // 32
97+
#define BLOCK_N_DIM (N_WARPS * N_TILES_PER_WARP * 8) // 128
98+
99+
// Shared memory sizes (bytes per K-step)
100+
#define SMEM_A_BYTES (BLOCK_M_DIM * 32) // 1024
101+
#define SMEM_B_BYTES (BLOCK_N_DIM * 32) // 4096
102+
#define SMEM_SFA_BYTES (BLOCK_M_DIM * 4) // 128
103+
#define SMEM_SFB_BYTES (BLOCK_N_DIM * 4) // 512
104+
#define SMEM_TOTAL (SMEM_A_BYTES + SMEM_B_BYTES + SMEM_SFA_BYTES + SMEM_SFB_BYTES)
105+
106+
// 256 threads, target 4 blocks/SM for occupancy
107+
__global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 4) void kGemmNVFP4_smem(
98108
const unsigned char* __restrict__ A, // M x K/2 packed FP4 (row-major)
99109
const unsigned char* __restrict__ B, // N x K/2 packed FP4 (B transposed, row-major)
100110
const unsigned char* __restrict__ SFA, // M x K/16 UE4M3 scales
101111
const unsigned char* __restrict__ SFB, // N x K/16 UE4M3 scales
102112
float* __restrict__ D, // M x N output (F32)
103113
int M, int N, int K
104114
) {
105-
// Block tile: m(M_WARPS*16) x n(N_WARPS * N_TILES_PER_WARP * 8)
106-
// = m32 x n128 for 2x4 warps
107-
const int BLOCK_M = M_WARPS * 16;
108-
const int BLOCK_N = N_WARPS * N_TILES_PER_WARP * 8;
109-
110-
int warp_in_block = threadIdx.x / 32;
111-
int lane_id = threadIdx.x % 32;
112-
113-
// 2D warp mapping: m_warp along M, n_warp along N
114-
int m_warp = warp_in_block / N_WARPS; // 0..(M_WARPS-1)
115-
int n_warp = warp_in_block % N_WARPS; // 0..(N_WARPS-1)
116-
117-
// Block-level tile position
118-
int tile_m = blockIdx.y * BLOCK_M + m_warp * 16;
119-
int tile_n_base = blockIdx.x * BLOCK_N;
120-
121-
if (tile_m >= M)
122-
return;
123-
124-
// This warp's N offset within the block
125-
int warp_n_base = tile_n_base + n_warp * N_TILES_PER_WARP * 8;
126-
127-
// CuTE thread decomposition: t0 = lane%4 (0-3), t1 = lane/4 (0-7)
128-
int t0 = lane_id % 4;
129-
int t1 = lane_id / 4;
130-
131-
// Precompute A row indices for this thread's registers
132-
// reg[0,2] → row0 = tile_m + 2*t1, reg[1,3] → row1 = tile_m + 2*t1 + 1
133-
int a_row0 = tile_m + 2 * t1;
134-
int a_row1 = a_row0 + 1;
135-
136-
int half_K = K / 2;
137-
int scale_stride_K = K / 16;
138-
139-
// Accumulators: N_TILES_PER_WARP * 4 floats per thread
115+
// Shared memory: 16-byte aligned for uint4 stores
116+
__shared__ __align__(16) unsigned char smem[SMEM_TOTAL]; // 5760 bytes
117+
unsigned char* smem_A = smem;
118+
unsigned char* smem_B = smem + SMEM_A_BYTES;
119+
unsigned char* smem_SFA = smem + SMEM_A_BYTES + SMEM_B_BYTES;
120+
unsigned char* smem_SFB = smem + SMEM_A_BYTES + SMEM_B_BYTES + SMEM_SFA_BYTES;
121+
122+
const int tid = threadIdx.x;
123+
const int warp_in_block = tid / 32;
124+
const int lane_id = tid % 32;
125+
const int m_warp = warp_in_block / N_WARPS; // 0..1
126+
const int n_warp = warp_in_block % N_WARPS; // 0..3
127+
128+
const int block_m = blockIdx.y * BLOCK_M_DIM;
129+
const int block_n = blockIdx.x * BLOCK_N_DIM;
130+
const int tile_m = block_m + m_warp * 16;
131+
const int warp_n_base = block_n + n_warp * N_TILES_PER_WARP * 8;
132+
133+
const int t0 = lane_id % 4;
134+
const int t1 = lane_id / 4;
135+
const int half_K = K / 2;
136+
const int scale_K = K / 16;
137+
138+
// Accumulators
140139
float acc[N_TILES_PER_WARP][4];
141140
#pragma unroll
142141
for (int nt = 0; nt < N_TILES_PER_WARP; nt++) {
143-
acc[nt][0] = 0.0f;
144-
acc[nt][1] = 0.0f;
145-
acc[nt][2] = 0.0f;
146-
acc[nt][3] = 0.0f;
142+
acc[nt][0] = acc[nt][1] = acc[nt][2] = acc[nt][3] = 0.0f;
147143
}
148144

145+
// Precompute smem row indices for A register reads
146+
const int a_local_row0 = m_warp * 16 + 2 * t1;
147+
const int a_local_row1 = a_local_row0 + 1;
148+
149+
// Precompute SFA row for this thread (all 4 bytes come from the same row)
150+
// sf_tidx = (lane%2)*8 + lane/4; cute_m_0 = sf_tidx % 16
151+
// actual_m = (cute_m_0 % 8)*2 + cute_m_0/8
152+
const int sf_tidx = (lane_id % 2) * 8 + (lane_id / 4);
153+
const int cute_sf_m0 = sf_tidx % 16;
154+
const int sfa_local_row = m_warp * 16 + (cute_sf_m0 % 8) * 2 + cute_sf_m0 / 8;
155+
149156
// K-loop
150157
for (int k_start = 0; k_start < K; k_start += 64) {
151-
// ---- Load A registers (4 x uint32) ----
152-
// reg[0] = A[row0, k_start + t0*8 + 0..7] → 4 bytes at row0*K/2 + (k_start+t0*8)/2
153-
// reg[1] = A[row1, k_start + t0*8 + 0..7]
154-
// reg[2] = A[row0, k_start + t0*8 + 32..39]
155-
// reg[3] = A[row1, k_start + t0*8 + 32..39]
156-
uint32_t a_regs[4];
157-
int k_col_lo = k_start + t0 * 8;
158-
int k_col_hi = k_col_lo + 32;
159-
160-
// Fast path: no boundary check needed
161-
bool a_row0_ok = (a_row0 < M);
162-
bool a_row1_ok = (a_row1 < M);
163-
bool k_lo_ok = (k_col_lo + 7 < K);
164-
bool k_hi_ok = (k_col_hi + 7 < K);
165-
166-
if (a_row0_ok && k_lo_ok) {
167-
a_regs[0] = *(const uint32_t*)(A + a_row0 * half_K + k_col_lo / 2);
168-
} else {
169-
a_regs[0] = pack_8_nibbles_slow(A, a_row0, k_col_lo, K, M, K);
170-
}
171-
if (a_row1_ok && k_lo_ok) {
172-
a_regs[1] = *(const uint32_t*)(A + a_row1 * half_K + k_col_lo / 2);
173-
} else {
174-
a_regs[1] = pack_8_nibbles_slow(A, a_row1, k_col_lo, K, M, K);
175-
}
176-
if (a_row0_ok && k_hi_ok) {
177-
a_regs[2] = *(const uint32_t*)(A + a_row0 * half_K + k_col_hi / 2);
178-
} else {
179-
a_regs[2] = pack_8_nibbles_slow(A, a_row0, k_col_hi, K, M, K);
180-
}
181-
if (a_row1_ok && k_hi_ok) {
182-
a_regs[3] = *(const uint32_t*)(A + a_row1 * half_K + k_col_hi / 2);
183-
} else {
184-
a_regs[3] = pack_8_nibbles_slow(A, a_row1, k_col_hi, K, M, K);
185-
}
158+
const int k_byte = k_start / 2;
159+
const int k_scale = k_start / 16;
186160

187-
// ---- Load SFA ----
188-
// SFA layout: sf_thread_idx = (lane%2)*8 + (lane/4)
189-
// Scale coord = sf_thread_idx + v*16 → cute_m = coord%16, k_blk = coord/16
190-
// Remap: actual_m = (cute_m%8)*2 + cute_m/8
191-
uint32_t sfa_packed = 0;
161+
// ================================================================
162+
// Phase 1: Cooperative load from global → shared memory
163+
// ================================================================
164+
165+
// ---- A tile: BLOCK_M×32 = 1024 bytes, 256 threads × 4 bytes each ----
192166
{
193-
int sf_tidx = (lane_id % 2) * 8 + (lane_id / 4);
194-
for (int sv = 0; sv < 4; sv++) {
195-
int sfe = sf_tidx + sv * 16;
196-
int cute_sf_m = sfe % 16;
197-
int sf_col = sfe / 16;
198-
int sf_row = (cute_sf_m % 8) * 2 + cute_sf_m / 8;
199-
int gm = tile_m + sf_row;
200-
int gkb = k_start / 16 + sf_col;
201-
unsigned char sf_val = 0;
202-
if (gm < M && gkb < scale_stride_K) {
203-
sf_val = SFA[gm * scale_stride_K + gkb];
167+
const int off = tid * 4; // byte offset in smem_A (0..1020)
168+
const int row = off >> 5; // off / 32 → local row (0..31)
169+
const int col = off & 31; // off % 32 → byte col (0,4,...,28)
170+
const int gm = block_m + row;
171+
172+
uint32_t val = 0;
173+
if (gm < M) {
174+
const int gaddr = gm * half_K + k_byte + col;
175+
if (k_byte + col + 3 < half_K) {
176+
val = *(const uint32_t*)(A + gaddr);
177+
} else {
178+
// K-boundary: byte-by-byte
179+
for (int b = 0; b < 4; b++) {
180+
if (k_byte + col + b < half_K)
181+
val |= ((uint32_t)A[gaddr + b]) << (b * 8);
182+
}
204183
}
205-
sfa_packed |= ((uint32_t)sf_val << (sv * 8));
206184
}
185+
*(uint32_t*)(smem_A + off) = val;
207186
}
208187

209-
// ---- For each N-tile in this warp ----
210-
#pragma unroll
211-
for (int nt = 0; nt < N_TILES_PER_WARP; nt++) {
212-
int this_tile_n = warp_n_base + nt * 8;
213-
if (this_tile_n >= N)
214-
break;
215-
216-
// Load B registers (2 x uint32)
217-
// reg[0] = B[this_tile_n + t1, k_start + t0*8 + 0..7]
218-
// reg[1] = B[this_tile_n + t1, k_start + t0*8 + 32..39]
219-
uint32_t b_regs[2];
220-
int b_row = this_tile_n + t1;
221-
bool b_row_ok = (b_row < N);
222-
223-
if (b_row_ok && k_lo_ok) {
224-
b_regs[0] = *(const uint32_t*)(B + b_row * half_K + k_col_lo / 2);
188+
// ---- B tile: BLOCK_N×32 = 4096 bytes, 256 threads × 16 bytes each ----
189+
{
190+
const int off = tid * 16; // byte offset in smem_B (0..4080)
191+
const int row = off >> 5; // local row (0..127)
192+
const int col = off & 31; // 0 or 16
193+
const int gn = block_n + row;
194+
195+
if (gn < N) {
196+
const int gaddr = gn * half_K + k_byte + col;
197+
if (k_byte + col + 15 < half_K) {
198+
*(uint4*)(smem_B + off) = *(const uint4*)(B + gaddr);
199+
} else {
200+
// K-boundary: byte-by-byte
201+
for (int b = 0; b < 16; b++) {
202+
smem_B[off + b] = (k_byte + col + b < half_K) ? B[gaddr + b] : 0;
203+
}
204+
}
225205
} else {
226-
b_regs[0] = pack_8_nibbles_slow(B, b_row, k_col_lo, K, N, K);
206+
// N-boundary: zero-fill
207+
*(uint4*)(smem_B + off) = make_uint4(0, 0, 0, 0);
227208
}
228-
if (b_row_ok && k_hi_ok) {
229-
b_regs[1] = *(const uint32_t*)(B + b_row * half_K + k_col_hi / 2);
230-
} else {
231-
b_regs[1] = pack_8_nibbles_slow(B, b_row, k_col_hi, K, N, K);
209+
}
210+
211+
// ---- SFA: BLOCK_M×4 = 128 bytes. First 32 threads load 4 bytes each ----
212+
if (tid < BLOCK_M_DIM) {
213+
const int gm = block_m + tid;
214+
uint32_t val = 0;
215+
if (gm < M) {
216+
const int base = gm * scale_K + k_scale;
217+
if (k_scale + 3 < scale_K) {
218+
val = *(const uint32_t*)(SFA + base);
219+
} else {
220+
for (int b = 0; b < 4; b++) {
221+
if (k_scale + b < scale_K)
222+
val |= ((uint32_t)SFA[base + b]) << (b * 8);
223+
}
224+
}
232225
}
226+
*(uint32_t*)(smem_SFA + tid * 4) = val;
227+
}
233228

234-
// Load SFB for this N-tile
235-
// SFB layout: sf_thread_idx = lane/4 = t1
236-
// coord = t1 + v*8, n = coord%8, k_blk = coord/8
237-
uint32_t sfb_packed = 0;
238-
{
239-
for (int sv = 0; sv < 4; sv++) {
240-
int sfe = t1 + sv * 8;
241-
int sf_n = sfe % 8;
242-
int sf_col = sfe / 8;
243-
int gn = this_tile_n + sf_n;
244-
int gkb = k_start / 16 + sf_col;
245-
unsigned char sf_val = 0;
246-
if (gn < N && gkb < scale_stride_K) {
247-
sf_val = SFB[gn * scale_stride_K + gkb];
229+
// ---- SFB: BLOCK_N×4 = 512 bytes. First 128 threads load 4 bytes each ----
230+
if (tid < BLOCK_N_DIM) {
231+
const int gn = block_n + tid;
232+
uint32_t val = 0;
233+
if (gn < N) {
234+
const int base = gn * scale_K + k_scale;
235+
if (k_scale + 3 < scale_K) {
236+
val = *(const uint32_t*)(SFB + base);
237+
} else {
238+
for (int b = 0; b < 4; b++) {
239+
if (k_scale + b < scale_K)
240+
val |= ((uint32_t)SFB[base + b]) << (b * 8);
248241
}
249-
sfb_packed |= ((uint32_t)sf_val << (sv * 8));
250242
}
251243
}
244+
*(uint32_t*)(smem_SFB + tid * 4) = val;
245+
}
246+
247+
__syncthreads();
248+
249+
// ================================================================
250+
// Phase 2: Read MMA registers from smem and compute
251+
// ================================================================
252+
253+
// A registers (shared across all N-tiles in this warp)
254+
uint32_t a_regs[4];
255+
a_regs[0] = *(const uint32_t*)(smem_A + a_local_row0 * 32 + t0 * 4);
256+
a_regs[1] = *(const uint32_t*)(smem_A + a_local_row1 * 32 + t0 * 4);
257+
a_regs[2] = *(const uint32_t*)(smem_A + a_local_row0 * 32 + t0 * 4 + 16);
258+
a_regs[3] = *(const uint32_t*)(smem_A + a_local_row1 * 32 + t0 * 4 + 16);
259+
260+
// SFA: single uint32 load (all 4 bytes are in the same row)
261+
uint32_t sfa_packed = *(const uint32_t*)(smem_SFA + sfa_local_row * 4);
262+
263+
// Per-N-tile: read B and SFB from smem, execute MMA
264+
#pragma unroll
265+
for (int nt = 0; nt < N_TILES_PER_WARP; nt++) {
266+
int local_n = n_warp * N_TILES_PER_WARP * 8 + nt * 8;
267+
int b_row = local_n + t1;
268+
269+
// B registers: 2 × uint32 from smem
270+
uint32_t b_regs[2];
271+
b_regs[0] = *(const uint32_t*)(smem_B + b_row * 32 + t0 * 4);
272+
b_regs[1] = *(const uint32_t*)(smem_B + b_row * 32 + t0 * 4 + 16);
273+
274+
// SFB: single uint32 load (4 consecutive bytes at row t1)
275+
uint32_t sfb_packed = *(const uint32_t*)(smem_SFB + (local_n + t1) * 4);
252276

253-
// Execute MMA: accumulate into this N-tile's accumulators
254277
mma_nvfp4_m16n8k64(acc[nt][0], acc[nt][1], acc[nt][2], acc[nt][3], a_regs[0], a_regs[1], a_regs[2], a_regs[3],
255278
b_regs[0], b_regs[1], acc[nt][0], acc[nt][1], acc[nt][2], acc[nt][3], sfa_packed, sfb_packed);
256279
}
280+
281+
__syncthreads(); // Barrier before next K-step's smem writes
257282
}
258283

259284
// ---- Write output ----
260285
// SM80_16x8_Row: octet = lane/4, quad = lane%4
261-
// d[0] = C[octet*2, quad*2], d[1] = C[octet*2, quad*2+1]
262-
// d[2] = C[octet*2+1, quad*2], d[3] = C[octet*2+1, quad*2+1]
263286
int octet = lane_id / 4;
264287
int quad = lane_id % 4;
265288
int out_row0 = tile_m + octet * 2;
@@ -434,20 +457,17 @@ __global__ void kGemmNVFP4_simple(
434457
}
435458

436459
// ============================================================================
437-
// Host-side launcher — uses optimized kernel
460+
// Host-side launcher — uses shared memory kernel
438461
// ============================================================================
439462
extern "C" void cgemm_nvfp4(
440463
const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, float* D, int M,
441464
int N, int K
442465
) {
443-
const int BLOCK_M = M_WARPS * 16;
444-
const int BLOCK_N = N_WARPS * N_TILES_PER_WARP * 8;
445-
446-
int num_m_blocks = (M + BLOCK_M - 1) / BLOCK_M;
447-
int num_n_blocks = (N + BLOCK_N - 1) / BLOCK_N;
466+
int num_m_blocks = (M + BLOCK_M_DIM - 1) / BLOCK_M_DIM;
467+
int num_n_blocks = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
448468

449469
dim3 grid(num_n_blocks, num_m_blocks);
450470
int threads_per_block = WARPS_PER_BLOCK * 32; // 256
451471

452-
kGemmNVFP4_opt<<<grid, threads_per_block>>>(A, B, SFA, SFB, D, M, N, K);
472+
kGemmNVFP4_smem<<<grid, threads_per_block>>>(A, B, SFA, SFB, D, M, N, K);
453473
}

0 commit comments

Comments
 (0)