@@ -1131,6 +1131,278 @@ void kbitGemmMinimal(
11311131 CUDA_CHECK_RETURN (cudaPeekAtLastError ());
11321132}
11331133
// ---- Stage 4: Pipelined fused kbit dequant + GEMM kernel ----
// Double-buffered cp.async pipeline overlapping loads with compute.
// Same math as Stage 3 but with async global->shared memory copies for B and absmax,
// and synchronous A loads (small tile, needs bounds checking).
// B tile stored WITHOUT +1 padding (simpler cp.async, bank conflicts deferred to Stage 6).

// cp.async helpers (sm_80+)

// Issue one 16-byte cp.async.cg (global -> shared, L2-cached, bypassing L1) copy.
// `smem` is a generic pointer that must point into shared memory; both `smem`
// and `gmem` must be 16-byte aligned. The copy completes only after a
// commit_group / wait_group pair (see cp_async_fence / cp_async_wait).
__device__ __forceinline__ void cp_async_cg_16(void* __restrict__ smem, const void* __restrict__ gmem) {
    // PTX cp.async takes a 32-bit shared-memory address, not a generic pointer.
    uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem));
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem_addr), "l"(gmem));
}
1145+
// Commit all cp.async copies issued so far by this thread into a new group,
// so a later cp.async.wait_group can wait on them as a unit.
__device__ __forceinline__ void cp_async_fence() {
    asm volatile("cp.async.commit_group;\n" ::);
}
1149+
// Block this thread until at most N cp.async groups remain in flight.
// N = 0 waits for everything; N = 1 leaves the most recent group pending
// (the idiom used for double buffering). N must be a compile-time constant
// because the PTX instruction takes an immediate operand.
template <int N>
__device__ __forceinline__ void cp_async_wait() {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}
1154+
// Fused k-bit dequantize + fp16 GEMM with a two-stage cp.async pipeline (requires SM80+).
//
// Grid:  x = N / TILE_N column tiles, y = ceil(M / TILE_M) row tiles.
// Block: 256 threads = 8 warps; each warp owns 16 output columns (2 m16n8k16 MMAs).
// Dynamic shared memory: 2 * STAGE_BYTES (double buffer), each stage holding the
// A tile (half), the B tile (packed bit-planes, no padding) and the per-block
// absmax bytes, padded so stages stay 16-byte aligned.
//
// Preconditions (not checked): N % TILE_N == 0; B_packed / B_absmax are laid out
// tile-contiguously by the repack kernel so every async copy is in-bounds;
// codebook has 2^K_BITS entries. A is bounds-checked (M and K_dim may be ragged).
template <int K_BITS>
__global__ void kbit_gemm_pipelined(
    const half* __restrict__ A, const unsigned int* __restrict__ B_packed, const unsigned char* __restrict__ B_absmax,
    const float* __restrict__ codebook, half* __restrict__ C, const int M, const int K_dim, const int N
) {
    constexpr int TILE_M = 16;
    constexpr int TILE_K = 64;
    constexpr int TILE_N = 128;
    constexpr int BS = 32;                             // quantization block size along K
    constexpr int KB_PER_TILE = TILE_K / BS;           // 2
    constexpr int B_COL_WORDS = KB_PER_TILE * K_BITS;  // words per column (no padding)
    constexpr int N_BLOCKS = 2;                        // 16 cols per warp / 8 cols per MMA

    // Per-stage sizes in elements
    constexpr int A_STAGE_ELEMS = TILE_M * TILE_K;        // half elements
    constexpr int B_STAGE_WORDS = TILE_N * B_COL_WORDS;   // uint32 elements
    constexpr int ABS_STAGE_BYTES = TILE_N * KB_PER_TILE; // uint8 elements

    // Per-stage sizes in bytes (all naturally 16-byte aligned)
    constexpr int A_STAGE_BYTES = A_STAGE_ELEMS * sizeof(half);
    constexpr int B_STAGE_BYTES = B_STAGE_WORDS * sizeof(unsigned int);
    // Round absmax up to 16-byte boundary so the following stage stays aligned
    constexpr int ABS_STAGE_BYTES_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;

    constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES + ABS_STAGE_BYTES_ALIGNED;

    const int n_tile = blockIdx.x;
    const int m_tile = blockIdx.y;
    const int n_tiles = N / TILE_N;
    const int k_tiles = (K_dim + TILE_K - 1) / TILE_K;
    const int warp_id = threadIdx.x / 32;
    const int lane_id = threadIdx.x % 32;
    const int gid = lane_id / 4; // row group within the MMA fragment (0..7)
    const int tid = lane_id % 4; // column pair within the group (0..3)

    const int warp_n_base = warp_id * (TILE_N / 8);
    const int m_base = m_tile * TILE_M;

    // Double-buffered shared memory: 2 stages
    extern __shared__ char smem[];

    // Helper lambdas for stage-indexed shared memory pointers
    auto sh_a = [&](int stage) -> half* {
        return reinterpret_cast<half*>(smem + stage * STAGE_BYTES);
    };
    auto sh_b = [&](int stage) -> unsigned int* {
        return reinterpret_cast<unsigned int*>(smem + stage * STAGE_BYTES + A_STAGE_BYTES);
    };
    auto sh_abs = [&](int stage) -> unsigned char* {
        return reinterpret_cast<unsigned char*>(smem + stage * STAGE_BYTES + A_STAGE_BYTES + B_STAGE_BYTES);
    };

    // Codebook in a register: lane i holds codebook[i]; values are broadcast
    // between lanes later via __shfl_sync.
    half cb_h = (lane_id < (1 << K_BITS)) ? __float2half(codebook[lane_id]) : __float2half(0.0f);

    // Accumulators: one fp32 m16n8 fragment per n-block
    float frag_c[N_BLOCKS][4];
#pragma unroll
    for (int nb = 0; nb < N_BLOCKS; nb++)
        frag_c[nb][0] = frag_c[nb][1] = frag_c[nb][2] = frag_c[nb][3] = 0.0f;

    // ---- Tile fetch function (inlined via lambda) ----
    // B and absmax: cp.async (contiguous, always in-bounds from repack)
    // A: synchronous with bounds checking
    auto fetch_tile = [&](int stage, int kt) {
        const int k_base = kt * TILE_K;
        const int tile_idx = kt * n_tiles + n_tile;

        // B tile: contiguous cp.async (16-byte / int4 granularity)
        const int b_global_base = tile_idx * B_STAGE_WORDS;
        constexpr int B_INT4S = B_STAGE_BYTES / 16;
        const int4* b_src = reinterpret_cast<const int4*>(B_packed + b_global_base);
        int4* b_dst = reinterpret_cast<int4*>(sh_b(stage));
        for (int i = threadIdx.x; i < B_INT4S; i += blockDim.x)
            cp_async_cg_16(&b_dst[i], &b_src[i]);

        // Absmax tile: contiguous cp.async
        const int abs_global_base = tile_idx * ABS_STAGE_BYTES;
        constexpr int ABS_INT4S = (ABS_STAGE_BYTES + 15) / 16;
        const int4* abs_src = reinterpret_cast<const int4*>(B_absmax + abs_global_base);
        int4* abs_dst = reinterpret_cast<int4*>(sh_abs(stage));
        if (threadIdx.x < ABS_INT4S)
            cp_async_cg_16(&abs_dst[threadIdx.x], &abs_src[threadIdx.x]);

        // A tile: synchronous with bounds checking (zero-fill the ragged edge)
        half* a_dst = sh_a(stage);
        for (int i = threadIdx.x; i < A_STAGE_ELEMS; i += blockDim.x) {
            int row = i / TILE_K;
            int col = i % TILE_K;
            int gr = m_base + row;
            int gc = k_base + col;
            a_dst[row * TILE_K + col] = (gr < M && gc < K_dim) ? A[gr * K_dim + gc] : __float2half(0.0f);
        }
    };

    // ---- Compute function for one k-tile ----
    // Four m16n8k16 MMA steps cover TILE_K = 64: k_block = ks/2 picks the 32-wide
    // quantization block, half_idx = ks%2 picks which 16 of its 32 bit positions.
    auto compute_tile = [&](int stage) {
        half* a_ptr = sh_a(stage);
        unsigned int* b_ptr = sh_b(stage);
        unsigned char* abs_ptr = sh_abs(stage);

#pragma unroll
        for (int ks = 0; ks < 4; ks++) {
            const int k_block = ks / 2;
            const int half_idx = ks % 2;

            // Load A fragment (same lane mapping as Stage 3)
            uint32_t frag_a[4];
            {
                const int kc0 = ks * 16 + tid * 2;
                const int kc1 = ks * 16 + tid * 2 + 8;
                const int r0 = gid;
                const int r1 = gid + 8;
                half2 h_rlo_klo = __halves2half2(
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc0] : __float2half(0.0f),
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc0 + 1] : __float2half(0.0f));
                half2 h_rhi_klo = __halves2half2(
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc0] : __float2half(0.0f),
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc0 + 1] : __float2half(0.0f));
                half2 h_rlo_khi = __halves2half2(
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc1] : __float2half(0.0f),
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc1 + 1] : __float2half(0.0f));
                half2 h_rhi_khi = __halves2half2(
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc1] : __float2half(0.0f),
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc1 + 1] : __float2half(0.0f));
                frag_a[0] = *reinterpret_cast<uint32_t*>(&h_rlo_klo);
                frag_a[1] = *reinterpret_cast<uint32_t*>(&h_rhi_klo);
                frag_a[2] = *reinterpret_cast<uint32_t*>(&h_rlo_khi);
                frag_a[3] = *reinterpret_cast<uint32_t*>(&h_rhi_khi);
            }

#pragma unroll
            for (int nb = 0; nb < N_BLOCKS; nb++) {
                int col = warp_n_base + nb * 8 + gid;

                // B: read K_BITS bit-planes from the non-padded layout
                unsigned int planes[K_BITS];
                int b_addr = col * B_COL_WORDS + k_block * K_BITS;
#pragma unroll
                for (int b = 0; b < K_BITS; b++)
                    planes[b] = b_ptr[b_addr + b];

                half scale = __float2half(decode_e4m4_absmax(abs_ptr[col * KB_PER_TILE + k_block]));

                // Reassemble each weight's K_BITS-wide index from the bit-planes,
                // then look up the codebook value held by the owning lane.
                const int bit_offset = half_idx * 16;
                const int rows[4] = {2 * tid, 2 * tid + 1, 2 * tid + 8, 2 * tid + 9};
                half vals[4];
#pragma unroll
                for (int r = 0; r < 4; r++) {
                    int bit_pos = bit_offset + rows[r];
                    int idx = 0;
#pragma unroll
                    for (int b = 0; b < K_BITS; b++)
                        idx |= ((planes[b] >> bit_pos) & 1) << b;
                    vals[r] = __hmul(__shfl_sync(0xFFFFFFFF, cb_h, idx), scale);
                }

                uint32_t frag_b[2];
                {
                    half2 b0 = __halves2half2(vals[0], vals[1]);
                    half2 b1 = __halves2half2(vals[2], vals[3]);
                    frag_b[0] = *reinterpret_cast<uint32_t*>(&b0);
                    frag_b[1] = *reinterpret_cast<uint32_t*>(&b1);
                }

                asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                             "{%0, %1, %2, %3}, "
                             "{%4, %5, %6, %7}, "
                             "{%8, %9}, "
                             "{%10, %11, %12, %13};\n"
                             : "=f"(frag_c[nb][0]), "=f"(frag_c[nb][1]), "=f"(frag_c[nb][2]),
                               "=f"(frag_c[nb][3])
                             : "r"(frag_a[0]), "r"(frag_a[1]), "r"(frag_a[2]), "r"(frag_a[3]),
                               "r"(frag_b[0]), "r"(frag_b[1]),
                               "f"(frag_c[nb][0]), "f"(frag_c[nb][1]), "f"(frag_c[nb][2]),
                               "f"(frag_c[nb][3]));
            }
        }
    };

    // ---- Double-buffered pipeline ----
    // Fetch first tile
    fetch_tile(0, 0);
    cp_async_fence();

    for (int kt = 0; kt < k_tiles; kt++) {
        int cur = kt % 2;

        // Prefetch next tile into the other buffer. Safe without a barrier:
        // that buffer was last read in iteration kt-1, which ended in __syncthreads().
        if (kt + 1 < k_tiles) {
            fetch_tile((kt + 1) % 2, kt + 1);
            cp_async_fence();
            cp_async_wait<1>(); // wait for current tile, allow next pending
        } else {
            cp_async_wait<0>(); // last tile: wait for everything
        }
        __syncthreads(); // make async-copied data visible to all threads

        // Compute on current tile
        compute_tile(cur);
        __syncthreads(); // all reads done before this buffer is refilled next iteration
    }

    // ---- Write output (same lane-to-element mapping as Stage 3) ----
#pragma unroll
    for (int nb = 0; nb < N_BLOCKS; nb++) {
        int c_col = n_tile * TILE_N + warp_n_base + nb * 8 + tid * 2;
        int m_row0 = m_base + gid;
        int m_row1 = m_base + gid + 8;
        if (m_row0 < M) {
            C[m_row0 * N + c_col] = __float2half(frag_c[nb][0]);
            C[m_row0 * N + c_col + 1] = __float2half(frag_c[nb][1]);
        }
        if (m_row1 < M) {
            C[m_row1 * N + c_col] = __float2half(frag_c[nb][2]);
            C[m_row1 * N + c_col + 1] = __float2half(frag_c[nb][3]);
        }
    }
}
1374+
// Stage 4 GEMM launcher.
// Derives the tile grid and the dynamic shared-memory footprint for
// kbit_gemm_pipelined<K> and launches it with 256 threads (8 warps) per block.
// Shared memory is two stages (double buffer); for K <= 5 this stays well under
// the default 48 KB limit, so no cudaFuncSetAttribute opt-in is needed.
// Assumes N is a multiple of TILE_N (128), mirroring the kernel's precondition.
template <int K>
void kbitGemmPipelined(
    const half* A, const unsigned int* B_packed, const unsigned char* B_absmax, const float* codebook, half* C, int M,
    int K_dim, int N
) {
    // Tile geometry — must match the constants inside kbit_gemm_pipelined.
    constexpr int TILE_M = 16;
    constexpr int TILE_K = 64;
    constexpr int TILE_N = 128;
    constexpr int BS = 32;
    constexpr int KB_PER_TILE = TILE_K / BS;
    constexpr int B_COL_WORDS = KB_PER_TILE * K;

    // Bytes consumed by one pipeline stage: A tile + B tile + 16B-padded absmax.
    constexpr int A_STAGE_BYTES = TILE_M * TILE_K * sizeof(half);
    constexpr int B_STAGE_BYTES = TILE_N * B_COL_WORDS * sizeof(unsigned int);
    constexpr int ABS_STAGE_BYTES = TILE_N * KB_PER_TILE;
    constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;
    constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES + ABS_STAGE_ALIGNED;

    const int row_tiles = (M + TILE_M - 1) / TILE_M; // ceil-div over M
    const int col_tiles = N / TILE_N;

    const dim3 grid(col_tiles, row_tiles);
    const dim3 block(256);
    const int dyn_smem = 2 * STAGE_BYTES; // double buffer

    kbit_gemm_pipelined<K><<<grid, block, dyn_smem>>>(A, B_packed, B_absmax, codebook, C, M, K_dim, N);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
1405+
11341406// ---- Debug: Simple MMA test kernel ----
11351407// Takes fp16 A[16,16] and fp16 B[16,8] (B stored row-major), outputs fp32 C[16,8].
11361408__global__ void test_mma_kernel (const half* __restrict__ A, const half* __restrict__ B, float * __restrict__ C) {
@@ -1249,8 +1521,10 @@ INSTANTIATE_KBIT_REPACK(3)
12491521INSTANTIATE_KBIT_REPACK(4 )
12501522INSTANTIATE_KBIT_REPACK(5 )
12511523
1252- // GEMM instantiations: one per K value (fp16 only for Stage 3)
1253- #define INSTANTIATE_KBIT_GEMM (K ) template void kbitGemmMinimal<K>(const half*, const unsigned int *, const unsigned char *, const float *, half*, int , int , int );
// GEMM instantiations: one per K value (fp16 only).
// NOTE: the macro must be function-like — no space between the name and '(',
// otherwise it becomes an object-like macro and every use expands to garbage.
#define INSTANTIATE_KBIT_GEMM(K)                                                                                       \
    template void kbitGemmMinimal<K>(const half*, const unsigned int*, const unsigned char*, const float*, half*, int, \
                                     int, int);                                                                        \
    template void kbitGemmPipelined<K>(const half*, const unsigned int*, const unsigned char*, const float*, half*,    \
                                       int, int, int);
12541528
12551529INSTANTIATE_KBIT_GEMM (2 )
12561530INSTANTIATE_KBIT_GEMM(3 )
0 commit comments