@@ -2546,6 +2546,208 @@ void kbitGroupedGemmProd(
25462546 CUDA_CHECK_RETURN (cudaFree (d_work_offsets));
25472547}
25482548
// ===================================================================
// K-bit scalar GEMV: C[M,N] = A[M,K] * W_kbit^T (M=1..4)
// ===================================================================
//
// C=1 architecture: 1 output column per block, 4 warps split K.
// Grid = N (direct mapping). No split-K, no workspace.
// Element-at-a-time inner loop for low register pressure (~30 regs).
// Two-phase reduction (warp shuffle + shared memory).
// B_packed and B_absmax are in flat (quantize_kbit) layout, no repack needed.
//
// Launch contract: gridDim.x == N, blockDim.x == 128 (4 warps).
// Preconditions: K_dim % 32 == 0, M <= M_VAL <= 4; A rows must be
// 16-byte aligned (int4 loads). Requires SM80+ (cp.async).

template <int K_BITS, int M_VAL, typename scalar_t>
__global__ void __launch_bounds__(128, 12)
kbit_scalar_gemv(
    const scalar_t* __restrict__ A,
    const unsigned int* __restrict__ B_packed,  // flat: [N * num_k_blocks * K_BITS] uint32
    const float* __restrict__ B_absmax,         // flat: [N * num_k_blocks] float32
    const float* __restrict__ codebook,         // 2^K_BITS dequantization entries
    scalar_t* __restrict__ C,
    const int M, const int K_dim, const int N
) {
    constexpr int BS = 32;        // quantization block size
    constexpr int NUM_WARPS = 4;
    constexpr int M_MAX = 4;

    const int warp_id = threadIdx.x / 32;
    const int lane_id = threadIdx.x % 32;
    const int col = blockIdx.x;

    const int num_k_blocks = K_dim / BS;

    // Codebook held in registers, one entry per lane; lookups are done
    // with __shfl_sync. Only the first 2^K_BITS lanes carry real values.
    float cb = (lane_id < (1 << K_BITS)) ? codebook[lane_id] : 0.0f;

    // Column base pointers (flat layout)
    const unsigned int* B_col = B_packed + col * num_k_blocks * K_BITS;
    const float* abs_col = B_absmax + col * num_k_blocks;

    // Per-row accumulators
    float acc[M_VAL];
#pragma unroll
    for (int m = 0; m < M_VAL; m++) acc[m] = 0.0f;

    // Double-buffered staging for async A loads. Each thread owns one
    // 16-byte slot per M row per buffer: 2 * 4 * 128 * 16 = 16 KB max.
    // Threads only ever read the slots they themselves wrote, so no
    // __syncthreads() is needed around this buffer.
    __shared__ int4 sh_A[2][M_MAX][128];

    // All 128 threads stride through K blocks: thread t handles blocks
    // t, t+128, t+256, ... max_iters makes every lane run the same
    // number of iterations (no warp divergence at __shfl_sync).
    const int max_iters = (num_k_blocks + 127) / 128;

    for (int iter = 0; iter < max_iters; iter++) {
        const int block_idx = threadIdx.x + iter * 128;
        const bool valid = (block_idx < num_k_blocks);

        // Load K_BITS bit-plane words (guarded; invalid threads get 0)
        unsigned int planes[K_BITS];
#pragma unroll
        for (int b = 0; b < K_BITS; b++)
            planes[b] = valid ? B_col[block_idx * K_BITS + b] : 0u;

        // Load absmax (guarded; invalid threads get 0)
        float amax = valid ? abs_col[block_idx] : 0.0f;

        const int k_base = block_idx * BS;

        // Prefetch sub-iteration 0 before the loop.
        // FIX: the cp.async destination operand must be a shared-state-
        // space address. Truncating a generic pointer to 32 bits is not
        // equivalent (generic shared addresses live in a windowed range);
        // convert explicitly with __cvta_generic_to_shared.
#pragma unroll
        for (int m = 0; m < M_VAL; m++) {
            if (valid) {
                const int4* a_ptr = reinterpret_cast<const int4*>(
                    &A[m * K_dim + k_base + 0 * 8]);
                unsigned dst = static_cast<unsigned>(
                    __cvta_generic_to_shared(&sh_A[0][m][threadIdx.x]));
                asm volatile("cp.async.ca.shared.global [%0], [%1], 16;"
                             : : "r"(dst), "l"(a_ptr));
            }
        }
        asm volatile("cp.async.commit_group;");

        // 4 sub-iterations of 8 elements each cover the 32-element block.
#pragma unroll
        for (int sub = 0; sub < 4; sub++) {
            const int buf = sub % 2;
            const int next_buf = (sub + 1) % 2;

            // Wait for the outstanding prefetch to land
            asm volatile("cp.async.wait_group %0;" : : "n"(0));

            // Prefetch the next sub-iteration (overlaps with compute below)
            if (sub + 1 < 4) {
#pragma unroll
                for (int m = 0; m < M_VAL; m++) {
                    if (valid) {
                        const int4* a_ptr = reinterpret_cast<const int4*>(
                            &A[m * K_dim + k_base + (sub + 1) * 8]);
                        unsigned dst = static_cast<unsigned>(
                            __cvta_generic_to_shared(&sh_A[next_buf][m][threadIdx.x]));
                        asm volatile("cp.async.ca.shared.global [%0], [%1], 16;"
                                     : : "r"(dst), "l"(a_ptr));
                    }
                }
                asm volatile("cp.async.commit_group;");
            }

            // Dequantize 8 weights (shared across all M rows).
            // Pre-shift the planes once so per-weight extraction is cheap.
            unsigned int sp[K_BITS];
#pragma unroll
            for (int b = 0; b < K_BITS; b++)
                sp[b] = planes[b] >> (sub * 8);

            float w8[8];
#pragma unroll
            for (int j = 0; j < 8; j++) {
                // Gather bit j of each plane into a codebook index
                int idx = 0;
#pragma unroll
                for (int b = 0; b < K_BITS; b++)
                    idx |= ((sp[b] >> j) & 1) << b;
                w8[j] = __shfl_sync(0xFFFFFFFF, cb, idx) * amax;
            }

            // FMA against the prefetched A slice. The whole read+compute
            // is guarded so invalid threads never touch their (never
            // written, hence uninitialized) sh_A slots.
            if (valid) {
#pragma unroll
                for (int m = 0; m < M_VAL; m++) {
                    const int4 av = sh_A[buf][m][threadIdx.x];
                    const scalar_t* ap = reinterpret_cast<const scalar_t*>(&av);
#pragma unroll
                    for (int j = 0; j < 8; j++)
                        acc[m] += w8[j] * ScalarOps<scalar_t>::to_float(ap[j]);
                }
            }
        }
    }

    // Phase 1: intra-warp reduction via shuffle
#pragma unroll
    for (int m = 0; m < M_VAL; m++) {
#pragma unroll
        for (int offset = 16; offset >= 1; offset /= 2)
            acc[m] += __shfl_down_sync(0xFFFFFFFF, acc[m], offset);
    }

    // Phase 2: inter-warp reduction via shared memory
    __shared__ float s_partial[NUM_WARPS * M_MAX];

    if (lane_id == 0) {
#pragma unroll
        for (int m = 0; m < M_VAL; m++)
            s_partial[warp_id * M_MAX + m] = acc[m];
    }
    __syncthreads();

    // Thread 0 sums the 4 warp partials and writes the output column
    if (threadIdx.x == 0) {
#pragma unroll
        for (int m = 0; m < M_VAL; m++) {
            if (m < M) {
                float sum = s_partial[0 * M_MAX + m] + s_partial[1 * M_MAX + m]
                          + s_partial[2 * M_MAX + m] + s_partial[3 * M_MAX + m];
                C[m * N + col] = ScalarOps<scalar_t>::from_float(sum);
            }
        }
    }
}
2712+
// ---- Scalar GEMV launcher ----
// Launches kbit_scalar_gemv with one 128-thread block (4 warps) per
// output column; grid = N. No split-K, so no workspace is required.
template <int K, int MV, typename scalar_t>
static void kbitScalarGemvLaunch(
    const scalar_t* A, const unsigned int* B_packed,
    const float* B_absmax, const float* codebook,
    scalar_t* C, int M, int K_dim, int N
) {
    constexpr int BLOCK_SIZE = 128;  // 4 warps

    // Guard degenerate shapes: a zero grid dimension is a CUDA launch
    // error (invalid configuration), and there is nothing to compute.
    if (N <= 0 || M <= 0 || K_dim <= 0) return;

    int grid_size = N;  // C=1: one block per output column

    kbit_scalar_gemv<K, MV, scalar_t><<<grid_size, BLOCK_SIZE>>>(
        A, B_packed, B_absmax, codebook, C, M, K_dim, N);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
2727+
// Public entry point: selects the M_VAL template instantiation covering
// the runtime row count M. The kernel supports M = 1..4; values of M
// outside that range are clamped to the nearest supported variant
// (M <= 1 uses the 1-row kernel, M >= 4 the 4-row kernel).
template <int K, typename scalar_t>
void kbitScalarGemv(
    const scalar_t* A, const unsigned int* B_packed,
    const float* B_absmax, const float* codebook,
    scalar_t* C, int M, int K_dim, int N
) {
    if (M <= 1) {
        kbitScalarGemvLaunch<K, 1, scalar_t>(A, B_packed, B_absmax, codebook, C, M, K_dim, N);
    } else if (M == 2) {
        kbitScalarGemvLaunch<K, 2, scalar_t>(A, B_packed, B_absmax, codebook, C, M, K_dim, N);
    } else if (M == 3) {
        kbitScalarGemvLaunch<K, 3, scalar_t>(A, B_packed, B_absmax, codebook, C, M, K_dim, N);
    } else {
        kbitScalarGemvLaunch<K, 4, scalar_t>(A, B_packed, B_absmax, codebook, C, M, K_dim, N);
    }
}
2746+
// ===================================================================
// Grouped scalar GEMV: MoE expert dispatch (no split-K needed)
// ===================================================================

25492751// ---- Debug: Simple MMA test kernel ----
25502752// Takes fp16 A[16,16] and fp16 B[16,8] (B stored row-major), outputs fp32 C[16,8].
25512753__global__ void test_mma_kernel (const half* __restrict__ A, const half* __restrict__ B, float * __restrict__ C) {
@@ -2694,3 +2896,13 @@ INSTANTIATE_KBIT_GROUPED_GEMM_PROD(2)
26942896INSTANTIATE_KBIT_GROUPED_GEMM_PROD(3 )
26952897INSTANTIATE_KBIT_GROUPED_GEMM_PROD(4 )
26962898INSTANTIATE_KBIT_GROUPED_GEMM_PROD(5 )
2899+
// Explicit scalar-GEMV instantiations: fp16 and bf16 activations for
// each supported quantization bit width (2..5 bits).
#define INSTANTIATE_SCALAR_GEMV(K)                                              \
    template void kbitScalarGemv<K, half>(                                      \
        const half*, const unsigned int*, const float*, const float*,           \
        half*, int, int, int);                                                  \
    template void kbitScalarGemv<K, __nv_bfloat16>(                             \
        const __nv_bfloat16*, const unsigned int*, const float*, const float*,  \
        __nv_bfloat16*, int, int, int);

INSTANTIATE_SCALAR_GEMV(2)
INSTANTIATE_SCALAR_GEMV(3)
INSTANTIATE_SCALAR_GEMV(4)
INSTANTIATE_SCALAR_GEMV(5)
0 commit comments