@@ -1403,6 +1403,297 @@ void kbitGemmPipelined(
14031403 CUDA_CHECK_RETURN (cudaPeekAtLastError ());
14041404}
14051405
// ---- Stage 5: Split-K fused kbit dequant + GEMM kernel ----
// Extends Stage 4 with split-K: multiple blocks share one output tile, each handling a
// subset of k-tiles. Partial sums are accumulated via atomicAdd into an fp32 workspace
// (pre-zeroed by the host), and the last block to finish converts the tile to fp16.
//
// Grid:  (n_tiles, m_tiles)            when k_chunks == 1 (direct fp16 write, no workspace)
//        (n_tiles, m_tiles, k_chunks)  when k_chunks  > 1 (split-K via workspace + counters)
// Block: 256 threads (8 warps). Dynamic smem: 2 * STAGE_BYTES (double-buffered stages).
// Requires SM80+ (cp.async via cp_async_* helpers, mma.sync m16n8k16 f32.f16.f16.f32).
// Preconditions: N % TILE_N == 0; for k_chunks > 1 the host must zero C_workspace
// (M*N floats) and tile_counters (m_tiles*n_tiles ints) before launch.
//
// BUGFIX vs. previous revision: a chunk that owned no k-tiles returned early and never
// incremented tile_counters, so `done == k_chunks - 1` was unreachable and the final
// fp32->fp16 write-out was skipped (e.g. k_tiles=5, k_chunks=4 leaves chunk 3 empty).
// Empty chunks now skip the pipeline but still run the epilogue: their accumulators are
// zero, so the atomicAdds are no-ops, and the completion counter stays consistent.

template <int K_BITS>
__global__ void kbit_gemm_splitk(
    const half* __restrict__ A, const unsigned int* __restrict__ B_packed, const unsigned char* __restrict__ B_absmax,
    const float* __restrict__ codebook, half* __restrict__ C, float* __restrict__ C_workspace,
    int* __restrict__ tile_counters, const int M, const int K_dim, const int N, const int k_chunks
) {
    constexpr int TILE_M = 16;
    constexpr int TILE_K = 64;
    constexpr int TILE_N = 128;
    constexpr int BS = 32;                             // elements sharing one absmax byte
    constexpr int KB_PER_TILE = TILE_K / BS;           // absmax blocks per column per k-tile
    constexpr int B_COL_WORDS = KB_PER_TILE * K_BITS;  // packed bit-plane words per column per k-tile
    constexpr int N_BLOCKS = 2;                        // m16n8k16 n-blocks per warp (warp covers 16 cols)

    constexpr int A_STAGE_ELEMS = TILE_M * TILE_K;
    constexpr int B_STAGE_WORDS = TILE_N * B_COL_WORDS;
    constexpr int ABS_STAGE_BYTES = TILE_N * KB_PER_TILE;

    constexpr int A_STAGE_BYTES = A_STAGE_ELEMS * sizeof(half);
    constexpr int B_STAGE_BYTES_VAL = B_STAGE_WORDS * sizeof(unsigned int);
    constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;  // keep stages 16B-aligned for cp.async
    constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES_VAL + ABS_STAGE_ALIGNED;

    const int n_tile = blockIdx.x;
    const int m_tile = blockIdx.y;
    const int k_chunk_id = (k_chunks > 1) ? blockIdx.z : 0;
    const int n_tiles = N / TILE_N;
    const int k_tiles = (K_dim + TILE_K - 1) / TILE_K;
    const int tiles_per_chunk = (k_tiles + k_chunks - 1) / k_chunks;
    const int kt_start = k_chunk_id * tiles_per_chunk;
    const int kt_end = min(kt_start + tiles_per_chunk, k_tiles);

    const int warp_id = threadIdx.x / 32;
    const int lane_id = threadIdx.x % 32;
    const int gid = lane_id / 4;   // lane group id (row selector in the mma fragment layout)
    const int tid = lane_id % 4;   // id within the group of 4 (column selector)
    const int warp_n_base = warp_id * (TILE_N / 8);
    const int m_base = m_tile * TILE_M;

    // Double-buffered shared memory; each stage is [A fp16 tile][B packed words][absmax bytes].
    extern __shared__ char smem[];
    auto sh_a = [&](int stage) -> half* {
        return reinterpret_cast<half*>(smem + stage * STAGE_BYTES);
    };
    auto sh_b = [&](int stage) -> unsigned int* {
        return reinterpret_cast<unsigned int*>(smem + stage * STAGE_BYTES + A_STAGE_BYTES);
    };
    auto sh_abs = [&](int stage) -> unsigned char* {
        return reinterpret_cast<unsigned char*>(smem + stage * STAGE_BYTES + A_STAGE_BYTES + B_STAGE_BYTES_VAL);
    };

    // Each lane caches one codebook entry; lookups are done with __shfl_sync on the index.
    half cb_h = (lane_id < (1 << K_BITS)) ? __float2half(codebook[lane_id]) : __float2half(0.0f);

    float frag_c[N_BLOCKS][4];
#pragma unroll
    for (int nb = 0; nb < N_BLOCKS; nb++)
        frag_c[nb][0] = frag_c[nb][1] = frag_c[nb][2] = frag_c[nb][3] = 0.0f;

    // NOTE: no early return for empty chunks — they must still run the split-K epilogue
    // below so tile_counters reaches k_chunks and the last block performs the write-out.
    const bool has_work = (kt_start < kt_end);

    // Async-copy one k-tile (A fp16, B packed planes, absmax bytes) into a smem stage.
    // Same layout as Stage 4; B/absmax are tile-contiguous, A is bounds-checked per element.
    auto fetch_tile = [&](int stage, int kt) {
        const int k_base = kt * TILE_K;
        const int tile_idx = kt * n_tiles + n_tile;

        const int b_global_base = tile_idx * B_STAGE_WORDS;
        constexpr int B_INT4S = B_STAGE_BYTES_VAL / 16;
        const int4* b_src = reinterpret_cast<const int4*>(B_packed + b_global_base);
        int4* b_dst = reinterpret_cast<int4*>(sh_b(stage));
        for (int i = threadIdx.x; i < B_INT4S; i += blockDim.x)
            cp_async_cg_16(&b_dst[i], &b_src[i]);

        const int abs_global_base = tile_idx * ABS_STAGE_BYTES;
        constexpr int ABS_INT4S = (ABS_STAGE_BYTES + 15) / 16;
        const int4* abs_src = reinterpret_cast<const int4*>(B_absmax + abs_global_base);
        int4* abs_dst = reinterpret_cast<int4*>(sh_abs(stage));
        if (threadIdx.x < ABS_INT4S)
            cp_async_cg_16(&abs_dst[threadIdx.x], &abs_src[threadIdx.x]);

        half* a_dst = sh_a(stage);
        for (int i = threadIdx.x; i < A_STAGE_ELEMS; i += blockDim.x) {
            int row = i / TILE_K;
            int col = i % TILE_K;
            int gr = m_base + row;
            int gc = k_base + col;
            // Zero-pad out-of-range A so partial M/K tails contribute nothing.
            a_dst[row * TILE_K + col] = (gr < M && gc < K_dim) ? A[gr * K_dim + gc] : __float2half(0.0f);
        }
    };

    // Consume one staged k-tile: four m16n8k16 k-slices, dequantizing B on the fly.
    auto compute_tile = [&](int stage) {
        half* a_ptr = sh_a(stage);
        unsigned int* b_ptr = sh_b(stage);
        unsigned char* abs_ptr = sh_abs(stage);

#pragma unroll
        for (int ks = 0; ks < 4; ks++) {
            const int k_block = ks / 2;   // which 32-element absmax block this slice is in
            const int half_idx = ks % 2;  // low/high 16 bits of the plane word

            // Assemble the A fragment for mma m16n8k16 manually from smem
            // (row/col pairs per lane follow the standard fragment layout).
            uint32_t frag_a[4];
            {
                const int kc0 = ks * 16 + tid * 2;
                const int kc1 = ks * 16 + tid * 2 + 8;
                const int r0 = gid;
                const int r1 = gid + 8;
                half2 h_rlo_klo = __halves2half2(
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc0] : __float2half(0.0f),
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc0 + 1] : __float2half(0.0f));
                half2 h_rhi_klo = __halves2half2(
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc0] : __float2half(0.0f),
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc0 + 1] : __float2half(0.0f));
                half2 h_rlo_khi = __halves2half2(
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc1] : __float2half(0.0f),
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc1 + 1] : __float2half(0.0f));
                half2 h_rhi_khi = __halves2half2(
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc1] : __float2half(0.0f),
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc1 + 1] : __float2half(0.0f));
                frag_a[0] = *reinterpret_cast<uint32_t*>(&h_rlo_klo);
                frag_a[1] = *reinterpret_cast<uint32_t*>(&h_rhi_klo);
                frag_a[2] = *reinterpret_cast<uint32_t*>(&h_rlo_khi);
                frag_a[3] = *reinterpret_cast<uint32_t*>(&h_rhi_khi);
            }

#pragma unroll
            for (int nb = 0; nb < N_BLOCKS; nb++) {
                int col = warp_n_base + nb * 8 + gid;

                // Load the K_BITS bit-planes covering this column's current k-block.
                unsigned int planes[K_BITS];
                int b_addr = col * B_COL_WORDS + k_block * K_BITS;
#pragma unroll
                for (int b = 0; b < K_BITS; b++)
                    planes[b] = b_ptr[b_addr + b];

                // Per-block scale (decode_e4m4_absmax is a project helper defined elsewhere).
                half scale = __float2half(decode_e4m4_absmax(abs_ptr[col * KB_PER_TILE + k_block]));

                // Reconstruct the 4 B values this lane feeds into the mma B fragment:
                // gather one bit per plane, shuffle the codebook value, apply the scale.
                const int bit_offset = half_idx * 16;
                const int rows[4] = {2 * tid, 2 * tid + 1, 2 * tid + 8, 2 * tid + 9};
                half vals[4];
#pragma unroll
                for (int r = 0; r < 4; r++) {
                    int bit_pos = bit_offset + rows[r];
                    int idx = 0;
#pragma unroll
                    for (int b = 0; b < K_BITS; b++)
                        idx |= ((planes[b] >> bit_pos) & 1) << b;
                    vals[r] = __hmul(__shfl_sync(0xFFFFFFFF, cb_h, idx), scale);
                }

                uint32_t frag_b[2];
                {
                    half2 b0 = __halves2half2(vals[0], vals[1]);
                    half2 b1 = __halves2half2(vals[2], vals[3]);
                    frag_b[0] = *reinterpret_cast<uint32_t*>(&b0);
                    frag_b[1] = *reinterpret_cast<uint32_t*>(&b1);
                }

                asm volatile(
                    "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                    "{%0, %1, %2, %3}, "
                    "{%4, %5, %6, %7}, "
                    "{%8, %9}, "
                    "{%10, %11, %12, %13};\n"
                    : "=f"(frag_c[nb][0]), "=f"(frag_c[nb][1]), "=f"(frag_c[nb][2]),
                      "=f"(frag_c[nb][3])
                    : "r"(frag_a[0]), "r"(frag_a[1]), "r"(frag_a[2]), "r"(frag_a[3]),
                      "r"(frag_b[0]), "r"(frag_b[1]),
                      "f"(frag_c[nb][0]), "f"(frag_c[nb][1]), "f"(frag_c[nb][2]),
                      "f"(frag_c[nb][3]));
            }
        }
    };

    // ---- Double-buffered pipeline over [kt_start, kt_end) ----
    // Empty chunks skip this entirely but fall through to the epilogue below.
    if (has_work) {
        fetch_tile(0, kt_start);
        cp_async_fence();
    }

    for (int kt = kt_start; kt < kt_end; kt++) {
        int cur = (kt - kt_start) % 2;
        if (kt + 1 < kt_end) {
            // Prefetch next stage, then wait until only that prefetch is in flight.
            fetch_tile((kt + 1 - kt_start) % 2, kt + 1);
            cp_async_fence();
            cp_async_wait<1>();
        } else {
            cp_async_wait<0>();  // last tile: drain all outstanding copies
        }
        __syncthreads();
        compute_tile(cur);
        __syncthreads();  // protect the stage about to be overwritten by the next fetch
    }

    // ---- Write output ----
    if (k_chunks == 1) {
        // No split-K: write fp16 directly (same epilogue as Stage 4).
#pragma unroll
        for (int nb = 0; nb < N_BLOCKS; nb++) {
            int c_col = n_tile * TILE_N + warp_n_base + nb * 8 + tid * 2;
            int m_row0 = m_base + gid;
            int m_row1 = m_base + gid + 8;
            if (m_row0 < M) {
                C[m_row0 * N + c_col] = __float2half(frag_c[nb][0]);
                C[m_row0 * N + c_col + 1] = __float2half(frag_c[nb][1]);
            }
            if (m_row1 < M) {
                C[m_row1 * N + c_col] = __float2half(frag_c[nb][2]);
                C[m_row1 * N + c_col + 1] = __float2half(frag_c[nb][3]);
            }
        }
    } else {
        // Split-K: atomicAdd partial sums into the fp32 workspace (pre-zeroed by host).
        // Empty chunks add zeros here, which is a harmless no-op numerically.
#pragma unroll
        for (int nb = 0; nb < N_BLOCKS; nb++) {
            int c_col = n_tile * TILE_N + warp_n_base + nb * 8 + tid * 2;
            int m_row0 = m_base + gid;
            int m_row1 = m_base + gid + 8;
            if (m_row0 < M) {
                atomicAdd(&C_workspace[m_row0 * N + c_col], frag_c[nb][0]);
                atomicAdd(&C_workspace[m_row0 * N + c_col + 1], frag_c[nb][1]);
            }
            if (m_row1 < M) {
                atomicAdd(&C_workspace[m_row1 * N + c_col], frag_c[nb][2]);
                atomicAdd(&C_workspace[m_row1 * N + c_col + 1], frag_c[nb][3]);
            }
        }

        // Make this block's atomicAdds visible device-wide before signaling completion.
        __threadfence();

        // Every chunk — including empty ones — increments the per-tile counter exactly
        // once, so `done == k_chunks - 1` is always eventually observed by one block.
        __shared__ int is_last;
        if (threadIdx.x == 0) {
            int mn_id = m_tile * n_tiles + n_tile;
            int done = atomicAdd(&tile_counters[mn_id], 1);
            is_last = (done == k_chunks - 1) ? 1 : 0;
        }
        __syncthreads();

        // Last contributor converts the fp32 workspace to fp16 output for this tile.
        if (is_last) {
            for (int i = threadIdx.x; i < TILE_M * TILE_N; i += blockDim.x) {
                int row = m_base + i / TILE_N;
                int col = n_tile * TILE_N + i % TILE_N;
                if (row < M)
                    C[row * N + col] = __float2half(C_workspace[row * N + col]);
            }
        }
    }
}
1659+
// Stage 5 split-K GEMM launcher.
//
// BUGFIX: the requested k_chunks is clamped to the number of chunks that actually
// receive work. With tiles_per_chunk = ceil(k_tiles / k_chunks), only
// ceil(k_tiles / tiles_per_chunk) chunks are non-empty (e.g. k_tiles=5, k_chunks=4
// yields chunks [0,2) [2,4) [4,5) and one empty chunk). Launching empty chunks breaks
// the kernel's per-tile completion counter — the terminal count is never reached and
// the final fp32->fp16 write-out is skipped — so we never launch them.
//
// Preconditions: N % 128 == 0. When the effective chunk count is > 1, the caller must
// zero C_workspace (M*N floats) and tile_counters (m_tiles*n_tiles ints) beforehand.
template <int K>
void kbitGemmSplitK(
    const half* A, const unsigned int* B_packed, const unsigned char* B_absmax, const float* codebook, half* C,
    float* C_workspace, int* tile_counters, int M, int K_dim, int N, int k_chunks
) {
    constexpr int TILE_M = 16;
    constexpr int TILE_K = 64;
    constexpr int TILE_N = 128;
    constexpr int BS = 32;
    constexpr int KB_PER_TILE = TILE_K / BS;
    constexpr int B_COL_WORDS = KB_PER_TILE * K;

    // Must match the stage layout computed inside kbit_gemm_splitk.
    constexpr int A_STAGE_BYTES = TILE_M * TILE_K * sizeof(half);
    constexpr int B_STAGE_BYTES = TILE_N * B_COL_WORDS * sizeof(unsigned int);
    constexpr int ABS_STAGE_BYTES = TILE_N * KB_PER_TILE;
    constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;
    constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES + ABS_STAGE_ALIGNED;

    int m_tiles = (M + TILE_M - 1) / TILE_M;
    int n_tiles = N / TILE_N;
    int k_tiles = (K_dim + TILE_K - 1) / TILE_K;

    // Effective chunk count: clamp to k_tiles, then recompute so every launched chunk
    // owns at least one k-tile. (ceil(k_tiles/chunks) tiles per chunk covers all tiles
    // with ceil(k_tiles/tiles_per_chunk) non-empty chunks.)
    int chunks = (k_chunks < 1) ? 1 : k_chunks;
    if (chunks > k_tiles)
        chunks = (k_tiles > 0) ? k_tiles : 1;
    if (chunks > 1) {
        int tiles_per_chunk = (k_tiles + chunks - 1) / chunks;
        chunks = (k_tiles + tiles_per_chunk - 1) / tiles_per_chunk;
    }

    dim3 block(256);
    int smem_size = 2 * STAGE_BYTES;  // double-buffered stages

    if (chunks <= 1) {
        // Single chunk: kernel writes fp16 directly; workspace/counters unused.
        dim3 grid(n_tiles, m_tiles);
        kbit_gemm_splitk<K><<<grid, block, smem_size>>>(
            A, B_packed, B_absmax, codebook, C, nullptr, nullptr, M, K_dim, N, 1);
    } else {
        dim3 grid(n_tiles, m_tiles, chunks);
        kbit_gemm_splitk<K><<<grid, block, smem_size>>>(
            A, B_packed, B_absmax, codebook, C, C_workspace, tile_counters, M, K_dim, N, chunks);
    }
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
1696+
14061697// ---- Debug: Simple MMA test kernel ----
14071698// Takes fp16 A[16,16] and fp16 B[16,8] (B stored row-major), outputs fp32 C[16,8].
14081699__global__ void test_mma_kernel (const half* __restrict__ A, const half* __restrict__ B, float * __restrict__ C) {
@@ -1524,7 +1815,8 @@ INSTANTIATE_KBIT_REPACK(5)
15241815// GEMM instantiations: one per K value (fp16 only)
15251816#define INSTANTIATE_KBIT_GEMM (K ) \
15261817 template void kbitGemmMinimal<K>(const half*, const unsigned int *, const unsigned char *, const float *, half*, int , int , int ); \
1527- template void kbitGemmPipelined<K>(const half*, const unsigned int *, const unsigned char *, const float *, half*, int , int , int );
1818+ template void kbitGemmPipelined<K>(const half*, const unsigned int *, const unsigned char *, const float *, half*, int , int , int ); \
1819+ template void kbitGemmSplitK<K>(const half*, const unsigned int *, const unsigned char *, const float *, half*, float *, int *, int , int , int , int );
15281820
15291821INSTANTIATE_KBIT_GEMM (2 )
15301822INSTANTIATE_KBIT_GEMM(3 )
0 commit comments