@@ -2162,6 +2162,278 @@ void kbitScalarGemvTiled(
21622162#undef LAUNCH_SCALAR_GEMV_TILED
21632163}
21642164
// ---- Tiled Scalar GEMV v2 ----
// Cooperative tile loading into shared memory with split-K for occupancy.
// Grid = n_tiles * k_splits, Block = 128 threads (4 warps).
// Each thread handles one column within an N-tile.
// Double-buffered cp.async pipeline for B + absmax tiles (cp.async requires
// SM80+; see cp_async_cg_16 / cp_async_fence / cp_async_wait helpers).
// A loaded directly from global memory (L1 broadcast across columns).
//
// Preconditions (assumed, not checked here):
//  - N is a multiple of TILE_N (128): n_tiles truncates, so tail columns of a
//    non-multiple N would never be written. TODO(review): confirm callers pad N.
//  - Every A row is readable for a full 32-element quant block at each k_base
//    this kernel touches -- presumably K_dim is padded to BS; confirm upstream.
//  - B_packed / B_absmax are tile-contiguous (K-major tile order) and each
//    tile stage is 16-byte aligned, since both are streamed with 16-byte
//    cp.async copies.
//  - C_workspace and tile_counters must be zero-initialized before the FIRST
//    launch; this kernel re-zeroes both on completion so later launches can
//    reuse them without a host-side reset.
template <int K_BITS, int M_VAL, typename scalar_t = half, typename ABSMAX_T = unsigned char>
__global__ void __launch_bounds__(128, 8) kbit_scalar_gemv_tiled_v2(
    const scalar_t* __restrict__ A,
    const unsigned int* __restrict__ B_packed,
    const ABSMAX_T* __restrict__ B_absmax,
    const float* __restrict__ codebook,
    scalar_t* __restrict__ C,
    float* __restrict__ C_workspace,
    int* __restrict__ tile_counters,
    const int M, const int K_dim, const int N, const int k_splits
) {
    constexpr int BS = 32;          // quantization block size
    constexpr int TILE_K = 64;      // K elements per pipeline stage
    constexpr int TILE_N = 128;     // columns per tile (one per thread)
    constexpr int BLOCK_DIM = 128;  // threads per block
    constexpr int KB_PER_TILE = TILE_K / BS;           // 2 quant blocks per K tile
    constexpr int B_COL_WORDS = KB_PER_TILE * K_BITS;  // packed words per column per tile
    constexpr int B_STAGE_WORDS = TILE_N * B_COL_WORDS;
    constexpr int B_STAGE_BYTES = B_STAGE_WORDS * (int)sizeof(unsigned int);
    constexpr int ABS_STAGE_ELEMS = TILE_N * KB_PER_TILE;
    constexpr int ABS_STAGE_BYTES = ABS_STAGE_ELEMS * (int)sizeof(ABSMAX_T);
    constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;  // keep next stage 16B-aligned
    constexpr int STAGE_BYTES = B_STAGE_BYTES + ABS_STAGE_ALIGNED;
    // Both stages are streamed as whole int4s; sizes must be exact multiples
    // of 16 bytes or cp.async would read past the end of the source tile.
    static_assert(B_STAGE_BYTES % 16 == 0, "B stage must be int4-granular");
    static_assert(ABS_STAGE_BYTES % 16 == 0, "absmax stage must be int4-granular");

    const int n_tiles = N / TILE_N;
    const int k_tiles = (K_dim + TILE_K - 1) / TILE_K;
    const int tiles_per_split = (k_tiles + k_splits - 1) / k_splits;

    // Work item: blockIdx.x encodes (N-tile, K-split).
    const int work_id = blockIdx.x;
    const int n_tile = work_id / k_splits;
    const int ks_id = work_id % k_splits;
    const int n_base = n_tile * TILE_N;

    const int kt_start = ks_id * tiles_per_split;
    const int kt_end = min(kt_start + tiles_per_split, k_tiles);
    // Defensive only: the launcher re-derives k_splits so no split is empty
    // (an empty split would skip the tile_counters arrival below and hang the
    // last-block detection for this tile).
    if (kt_start >= k_tiles) return;

    // This thread's column within the tile.
    const int col_in_tile = threadIdx.x;  // 0..127
    const int col = n_base + col_in_tile;

    const int lane_id = threadIdx.x % 32;

    // Codebook lives in registers: lane i of each warp holds codebook[i] for
    // i < 2^K_BITS; lookups are done with __shfl_sync in the inner loop.
    float cb = (lane_id < (1 << K_BITS)) ? codebook[lane_id] : 0.0f;

    // Double-buffered dynamic shared memory:
    // [stage0 B planes | stage0 absmax (padded) | stage1 B planes | stage1 absmax]
    extern __shared__ char smem[];
    auto sh_b = [&](int stage) -> unsigned int* {
        return reinterpret_cast<unsigned int*>(smem + stage * STAGE_BYTES);
    };
    auto sh_abs = [&](int stage) -> ABSMAX_T* {
        return reinterpret_cast<ABSMAX_T*>(smem + stage * STAGE_BYTES + B_STAGE_BYTES);
    };

    // Per-row accumulators kept in fp32 regardless of scalar_t.
    float acc[M_VAL];
#pragma unroll
    for (int m = 0; m < M_VAL; m++) acc[m] = 0.0f;

    // Fetch tile: all 128 threads cooperatively issue cp.async copies of the
    // B bit-planes and the per-block absmax values for K-tile `kt`.
    auto fetch_tile = [&](int stage, int kt) {
        const int tile_idx = kt * n_tiles + n_tile;  // K-major tile ordering

        // B tile via cp.async.
        const int b_global_base = tile_idx * B_STAGE_WORDS;
        constexpr int B_INT4S = B_STAGE_BYTES / 16;
        const int4* b_src = reinterpret_cast<const int4*>(B_packed + b_global_base);
        int4* b_dst = reinterpret_cast<int4*>(sh_b(stage));
        for (int i = threadIdx.x; i < B_INT4S; i += BLOCK_DIM)
            cp_async_cg_16(&b_dst[i], &b_src[i]);

        // Absmax via cp.async (exact int4 count; see static_assert above).
        const int abs_global_base = tile_idx * ABS_STAGE_ELEMS;
        constexpr int ABS_INT4S = ABS_STAGE_BYTES / 16;
        const int4* abs_src = reinterpret_cast<const int4*>(B_absmax + abs_global_base);
        int4* abs_dst = reinterpret_cast<int4*>(sh_abs(stage));
        for (int i = threadIdx.x; i < ABS_INT4S; i += BLOCK_DIM)
            cp_async_cg_16(&abs_dst[i], &abs_src[i]);
    };

    // Compute tile: each thread dequantizes its own column from shared memory
    // and FMAs against the A rows.
    auto compute_tile = [&](int stage, int kt) {
        unsigned int* b_ptr = sh_b(stage);
        ABSMAX_T* abs_ptr = sh_abs(stage);
        const int k_base = kt * TILE_K;

        // Process KB_PER_TILE (=2) quant blocks within this K tile.
#pragma unroll
        for (int kb = 0; kb < KB_PER_TILE; kb++) {
            const int block_k_base = k_base + kb * BS;
            if (block_k_base >= K_dim) continue;  // tail tile may cover only one block

            // Read this column's K_BITS bit-planes with the widest load available.
            int b_addr = col_in_tile * B_COL_WORDS + kb * K_BITS;
            unsigned int planes[K_BITS];
            if constexpr (K_BITS == 2) {
                uint2 pv = *reinterpret_cast<const uint2*>(&b_ptr[b_addr]);
                planes[0] = pv.x; planes[1] = pv.y;
            } else if constexpr (K_BITS == 4) {
                int4 pv = *reinterpret_cast<const int4*>(&b_ptr[b_addr]);
                planes[0] = (unsigned int)pv.x; planes[1] = (unsigned int)pv.y;
                planes[2] = (unsigned int)pv.z; planes[3] = (unsigned int)pv.w;
            } else {
#pragma unroll
                for (int b = 0; b < K_BITS; b++)
                    planes[b] = b_ptr[b_addr + b];
            }

            // Per-quant-block scale from shared memory.
            float amax = load_absmax(abs_ptr, col_in_tile * KB_PER_TILE + kb);

            // Dequant-once loop: decode each weight once, FMA across M rows.
#pragma unroll
            for (int sub = 0; sub < 4; sub++) {
                // Load A for all M rows (one int4 = 8 half/bf16 values).
                int4 av[M_VAL];
#pragma unroll
                for (int m = 0; m < M_VAL; m++)
                    av[m] = *reinterpret_cast<const int4*>(&A[m * K_dim + block_k_base + sub * 8]);

#pragma unroll
                for (int j = 0; j < 8; j++) {
                    // Re-assemble the code index from the bit-planes.
                    int idx = 0;
#pragma unroll
                    for (int b = 0; b < K_BITS; b++)
                        idx |= ((planes[b] >> (sub * 8 + j)) & 1) << b;
                    // Codebook lookup via warp shuffle; the whole warp executes
                    // this line, so the full mask is correct.
                    float w = __shfl_sync(0xFFFFFFFF, cb, idx) * amax;

#pragma unroll
                    for (int m = 0; m < M_VAL; m++) {
                        const scalar_t* ap = reinterpret_cast<const scalar_t*>(&av[m]);
                        acc[m] += w * ScalarOps<scalar_t>::to_float(ap[j]);
                    }
                }
            }
        }
    };

    // Double-buffered cp.async pipeline: prefetch stage 0, then overlap the
    // fetch of tile kt+1 with the compute of tile kt.
    fetch_tile(0, kt_start);
    cp_async_fence();

    for (int kt = kt_start; kt < kt_end; kt++) {
        int cur = (kt - kt_start) % 2;
        if (kt + 1 < kt_end) {
            fetch_tile((kt + 1 - kt_start) % 2, kt + 1);
            cp_async_fence();
            cp_async_wait<1>();  // wait for tile kt; kt+1 stays in flight
        } else {
            cp_async_wait<0>();  // last tile: drain everything
        }
        __syncthreads();         // copies for `cur` visible to all threads
        compute_tile(cur, kt);
        __syncthreads();         // done reading `cur` before it is overwritten
    }

    // Write output.
    if (k_splits == 1) {
        // Direct write — this block owns the full K reduction.
#pragma unroll
        for (int m = 0; m < M_VAL; m++) {
            if (m < M && col < N)
                C[m * N + col] = ScalarOps<scalar_t>::from_float(acc[m]);
        }
    } else {
        // Partial K — accumulate into the fp32 workspace.
#pragma unroll
        for (int m = 0; m < M_VAL; m++) {
            if (m < M && col < N)
                atomicAdd(&C_workspace[m * N + col], acc[m]);
        }

        // Publish our partial sums before signaling arrival on the counter.
        __threadfence();

        // Last-arriving split converts workspace to output (classic
        // threadfence-reduction pattern).
        __shared__ int is_last;
        if (threadIdx.x == 0) {
            int done = atomicAdd(&tile_counters[n_tile], 1);
            is_last = (done == k_splits - 1) ? 1 : 0;
        }
        __syncthreads();

        if (is_last) {
            for (int i = threadIdx.x; i < M_VAL * TILE_N; i += BLOCK_DIM) {
                int m = i / TILE_N;
                int c = n_base + i % TILE_N;
                if (m < M && c < N) {
                    C[m * N + c] = ScalarOps<scalar_t>::from_float(C_workspace[m * N + c]);
                    // Re-zero the consumed entry so the workspace is ready for
                    // the next launch without a host-side memset.
                    C_workspace[m * N + c] = 0.0f;
                }
            }
            // Reset the arrival counter for the same reason.
            if (threadIdx.x == 0) tile_counters[n_tile] = 0;
        }
    }
}
2370+
// ---- Tiled GEMV v2 launcher ----
// Computes grid size, split-K factor, and dynamic shared-memory size for
// kbit_scalar_gemv_tiled_v2 and launches it on `stream`.
// The tile-geometry constants below must stay in sync with the kernel's.
template <int K, int MV, typename scalar_t, typename ABSMAX_T>
static void kbitScalarGemvTiledV2Launch(
    const scalar_t* A, const unsigned int* B_packed, const ABSMAX_T* B_absmax,
    const float* codebook, scalar_t* C, float* C_workspace, int* tile_counters,
    int M, int K_dim, int N, int num_sms, cudaStream_t stream
) {
    constexpr int TILE_N = 128;
    constexpr int TILE_K = 64;
    constexpr int BLOCK_DIM = 128;
    constexpr int BS = 32;
    constexpr int KB_PER_TILE = TILE_K / BS;
    constexpr int B_COL_WORDS = KB_PER_TILE * K;
    constexpr int B_STAGE_BYTES = TILE_N * B_COL_WORDS * (int)sizeof(unsigned int);
    constexpr int ABS_STAGE_BYTES = TILE_N * KB_PER_TILE * (int)sizeof(ABSMAX_T);
    constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;
    constexpr int STAGE_BYTES = B_STAGE_BYTES + ABS_STAGE_ALIGNED;

    int n_tiles = N / TILE_N;
    int k_tiles = (K_dim + TILE_K - 1) / TILE_K;

    // Degenerate problem (N < TILE_N or K_dim == 0): previously this divided
    // by zero below and/or launched a zero-dimension grid. The tiled kernel
    // has no work it can do, so bail out explicitly.
    if (n_tiles <= 0 || k_tiles <= 0)
        return;

    // Choose k_splits to achieve ~4 blocks per SM.
    // Recompute from tiles_per_split to guarantee no empty splits
    // (an empty split would skip the tile_counters atomicAdd, breaking the
    // last-block check in the kernel).
    int target_blocks = num_sms * 4;
    int k_splits = max(1, (target_blocks + n_tiles - 1) / n_tiles);
    k_splits = min(k_splits, k_tiles);
    int tiles_per_split = (k_tiles + k_splits - 1) / k_splits;
    k_splits = (k_tiles + tiles_per_split - 1) / tiles_per_split;  // no empty splits

    int grid_size = n_tiles * k_splits;
    // Double-buffered stages; stays under the 48 KB default limit for K <= 5,
    // so no cudaFuncSetAttribute opt-in is needed.
    int smem_size = 2 * STAGE_BYTES;

    kbit_scalar_gemv_tiled_v2<K, MV, scalar_t, ABSMAX_T>
        <<<grid_size, BLOCK_DIM, smem_size, stream>>>(
            A, B_packed, B_absmax, codebook, C, C_workspace, tile_counters,
            M, K_dim, N, k_splits
        );
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
2411+
// Public entry point: selects the M_VAL register-tile variant and queries the
// SM count internally (it drives the ~4-blocks-per-SM split-K heuristic in
// the launcher).
//
// NOTE(review): M > 4 dispatches to the M_VAL=4 kernel, which computes only
// the first 4 rows of C — callers appear to guarantee M <= 4 on this GEMV
// path; confirm upstream dispatch.
template <int K, typename scalar_t, typename ABSMAX_T>
void kbitScalarGemvTiledV2(
    const scalar_t* A, const unsigned int* B_packed, const ABSMAX_T* B_absmax,
    const float* codebook, scalar_t* C, float* C_workspace, int* tile_counters,
    int M, int K_dim, int N, cudaStream_t stream
) {
    int dev;
    CUDA_CHECK_RETURN(cudaGetDevice(&dev));
    int num_sms;
    CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev));

#define LAUNCH_GEMV_V2(MV) \
    kbitScalarGemvTiledV2Launch<K, MV, scalar_t, ABSMAX_T>( \
        A, B_packed, B_absmax, codebook, C, C_workspace, tile_counters, \
        M, K_dim, N, num_sms, stream)

    if (M <= 1) { LAUNCH_GEMV_V2(1); }
    else if (M <= 2) { LAUNCH_GEMV_V2(2); }
    else if (M <= 3) { LAUNCH_GEMV_V2(3); }
    else { LAUNCH_GEMV_V2(4); }

#undef LAUNCH_GEMV_V2
}
2436+
21652437// ---- Debug: Simple MMA test kernel ----
21662438// Takes fp16 A[16,16] and fp16 B[16,8] (B stored row-major), outputs fp32 C[16,8].
21672439__global__ void test_mma_kernel (const half* __restrict__ A, const half* __restrict__ B, float * __restrict__ C) {
@@ -2439,3 +2711,31 @@ INSTANTIATE_KBIT_SCALAR_GEMV_TILED_FP16(2)
24392711INSTANTIATE_KBIT_SCALAR_GEMV_TILED_FP16(3 )
24402712INSTANTIATE_KBIT_SCALAR_GEMV_TILED_FP16(4 )
24412713INSTANTIATE_KBIT_SCALAR_GEMV_TILED_FP16(5 )
// Scalar GEMV v2 (tiled with shared memory) explicit instantiations.
// uint8 (E4M4-coded) absmax variants, for half and bf16 activations.
// Note: the macro parameter list must follow the macro name with no space,
// otherwise this defines an object-like macro and the instantiations break.
#define INSTANTIATE_KBIT_SCALAR_GEMV_V2_U8(K) \
    template void kbitScalarGemvTiledV2<K, half, unsigned char>( \
        const half*, const unsigned int*, const unsigned char*, const float*, half*, float*, int*, \
        int, int, int, cudaStream_t \
    ); \
    template void kbitScalarGemvTiledV2<K, __nv_bfloat16, unsigned char>( \
        const __nv_bfloat16*, const unsigned int*, const unsigned char*, const float*, __nv_bfloat16*, float*, int*, \
        int, int, int, cudaStream_t \
    );
INSTANTIATE_KBIT_SCALAR_GEMV_V2_U8(2)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_U8(3)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_U8(4)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_U8(5)
// fp16 absmax variants (one half scale per quant block instead of uint8 E4M4),
// for half and bf16 activations.
#define INSTANTIATE_KBIT_SCALAR_GEMV_V2_FP16(K) \
    template void kbitScalarGemvTiledV2<K, half, half>( \
        const half*, const unsigned int*, const half*, const float*, half*, float*, int*, \
        int, int, int, cudaStream_t \
    ); \
    template void kbitScalarGemvTiledV2<K, __nv_bfloat16, half>( \
        const __nv_bfloat16*, const unsigned int*, const half*, const float*, __nv_bfloat16*, float*, int*, \
        int, int, int, cudaStream_t \
    );
INSTANTIATE_KBIT_SCALAR_GEMV_V2_FP16(2)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_FP16(3)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_FP16(4)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_FP16(5)
0 commit comments