Skip to content

Commit 33349dd

Browse files
TimDettmers and claude committed
Add optimized k-bit scalar GEMV kernel
Implements C=1 architecture with: - 1 column per block, 4 warps split K, grid=N - cp.async for A loads with double-buffered shared memory - Cached dequantized weights (reuse across M rows) - Pre-shifted bit planes for efficient index extraction Performance improvements: - k=4 M=1: 365 → 427 GB/s (17% faster, 42% DRAM) - k=5 M=1: 427 → 520 GB/s (22% faster, 52% DRAM) - k=4 M=4: 153 → 278 GB/s (82% faster, 28% DRAM) Supports k=2,3,4,5 and M=1,2,3,4 with fp16/bf16. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent aba14e8 commit 33349dd

File tree

3 files changed

+271
-0
lines changed

3 files changed

+271
-0
lines changed

csrc/ops.cu

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,6 +2546,208 @@ void kbitGroupedGemmProd(
25462546
CUDA_CHECK_RETURN(cudaFree(d_work_offsets));
25472547
}
25482548

2549+
// ===================================================================
// K-bit scalar GEMV: C[M,N] = A[M,K] * W_kbit^T (M=1..4)
// ===================================================================
//
// C=1 architecture: 1 output column per block, 4 warps split K.
// Grid = N (direct mapping). No split-K, no workspace.
// Element-at-a-time inner loop for low register pressure (~30 regs).
// Two-phase reduction (warp shuffle, then shared memory).
// B_packed and B_absmax are in flat (quantize_kbit) layout, no repack needed.
//
// Requires SM80+ (cp.async). Assumes K_dim % 32 == 0 (num_k_blocks is a
// truncating division, so a K tail would be silently dropped) and that A is
// 16-byte aligned per row (each cp.async moves 8 fp16/bf16 elements).
template <int K_BITS, int M_VAL, typename scalar_t>
__global__ void __launch_bounds__(128, 12)
kbit_scalar_gemv(
    const scalar_t* __restrict__ A,            // [M, K_dim] row-major activations
    const unsigned int* __restrict__ B_packed, // flat: [N * num_k_blocks * K_BITS] uint32 bit-planes
    const float* __restrict__ B_absmax,        // flat: [N * num_k_blocks] per-block scales
    const float* __restrict__ codebook,        // [2^K_BITS] dequantization values
    scalar_t* __restrict__ C,                  // [M, N] row-major output
    const int M, const int K_dim, const int N
) {
    constexpr int BS = 32;       // quantization block size (elements per packed block)
    constexpr int NUM_WARPS = 4;
    constexpr int M_MAX = 4;

    const int warp_id = threadIdx.x / 32;
    const int lane_id = threadIdx.x % 32;
    const int col = blockIdx.x;  // one output column per block

    const int num_k_blocks = K_dim / BS;

    // Codebook lives in registers; lookups below are warp shuffles.
    float cb = (lane_id < (1 << K_BITS)) ? codebook[lane_id] : 0.0f;

    // Column base pointers (flat layout).
    const unsigned int* B_col = B_packed + col * num_k_blocks * K_BITS;
    const float* abs_col = B_absmax + col * num_k_blocks;

    // Accumulators, one per batch row.
    float acc[M_VAL];
    #pragma unroll
    for (int m = 0; m < M_VAL; m++) acc[m] = 0.0f;

    // Double-buffered staging for async A loads.
    // Each thread owns one int4 slot per buffer per M row:
    // 2 * 4 * 128 * 16B = 16 KB max. Only the owning thread reads its slot,
    // so no __syncthreads() is needed around sh_A.
    __shared__ int4 sh_A[2][M_MAX][128];

    // All 128 threads stride through K blocks: thread t handles blocks
    // t, t+128, t+256, ... max_iters keeps every lane iterating the same
    // number of times so __shfl_sync never executes divergently.
    const int max_iters = (num_k_blocks + 127) / 128;

    for (int iter = 0; iter < max_iters; iter++) {
        const int block_idx = threadIdx.x + iter * 128;
        const bool valid = (block_idx < num_k_blocks);

        // Load k bit-plane words (guarded; tail threads get 0).
        unsigned int planes[K_BITS];
        #pragma unroll
        for (int b = 0; b < K_BITS; b++)
            planes[b] = valid ? B_col[block_idx * K_BITS + b] : 0u;

        // Per-block scale; 0 on the tail so those products vanish.
        float amax = valid ? abs_col[block_idx] : 0.0f;

        const int k_base = block_idx * BS;

        // Prefetch sub-iteration 0 before the loop.
        // FIX: cp.async's destination must be a shared-state-space address.
        // Casting a generic pointer to unsigned merely truncates a 64-bit
        // generic address; __cvta_generic_to_shared performs the conversion.
        #pragma unroll
        for (int m = 0; m < M_VAL; m++) {
            if (valid) {
                const scalar_t* a_ptr = &A[m * K_dim + k_base + 0 * 8];
                unsigned dst = static_cast<unsigned>(
                    __cvta_generic_to_shared(&sh_A[0][m][threadIdx.x]));
                asm volatile("cp.async.ca.shared.global [%0], [%1], 16;" :
                             : "r"(dst), "l"(reinterpret_cast<unsigned long long>(a_ptr)));
            }
        }
        asm volatile("cp.async.commit_group;");

        #pragma unroll
        for (int sub = 0; sub < 4; sub++) {
            const int buf = sub % 2;
            const int next_buf = (sub + 1) % 2;

            // Wait for the previous prefetch; it targeted `buf`, which we
            // read below. The next prefetch is issued after this wait, so
            // it still overlaps with the compute phase of this iteration.
            asm volatile("cp.async.wait_group %0;" : : "n"(0));

            // Prefetch the next 8 elements per row (if not last sub-iteration).
            if (sub + 1 < 4) {
                #pragma unroll
                for (int m = 0; m < M_VAL; m++) {
                    if (valid) {
                        const scalar_t* a_ptr = &A[m * K_dim + k_base + (sub + 1) * 8];
                        unsigned dst = static_cast<unsigned>(
                            __cvta_generic_to_shared(&sh_A[next_buf][m][threadIdx.x]));
                        asm volatile("cp.async.ca.shared.global [%0], [%1], 16;" :
                                     : "r"(dst), "l"(reinterpret_cast<unsigned long long>(a_ptr)));
                    }
                }
                asm volatile("cp.async.commit_group;");
            }

            // Dequantize 8 weights once; reused across all M rows.
            // Pre-shift each plane so per-element index extraction is a
            // single shift+mask per bit.
            unsigned int sp[K_BITS];
            #pragma unroll
            for (int b = 0; b < K_BITS; b++)
                sp[b] = planes[b] >> (sub * 8);

            float w8[8];
            #pragma unroll
            for (int j = 0; j < 8; j++) {
                // Extract j-th bit from each plane and combine into an index.
                int idx = 0;
                #pragma unroll
                for (int b = 0; b < K_BITS; b++)
                    idx |= ((sp[b] >> j) & 1) << b;
                // All 32 lanes participate; tail lanes contribute amax == 0.
                w8[j] = __shfl_sync(0xFFFFFFFF, cb, idx) * amax;
            }

            // FMA with the staged A fragment from shared memory.
            #pragma unroll
            for (int m = 0; m < M_VAL; m++) {
                const int4 av = sh_A[buf][m][threadIdx.x];
                const scalar_t* ap = reinterpret_cast<const scalar_t*>(&av);

                #pragma unroll
                for (int j = 0; j < 8; j++) {
                    if (valid)
                        acc[m] += w8[j] * ScalarOps<scalar_t>::to_float(ap[j]);
                }
            }
        }
    }

    // Phase 1: intra-warp tree reduction via shuffle.
    #pragma unroll
    for (int m = 0; m < M_VAL; m++) {
        #pragma unroll
        for (int offset = 16; offset >= 1; offset /= 2)
            acc[m] += __shfl_down_sync(0xFFFFFFFF, acc[m], offset);
    }

    // Phase 2: inter-warp reduction via shared memory.
    __shared__ float s_partial[NUM_WARPS * M_MAX];

    if (lane_id == 0) {
        #pragma unroll
        for (int m = 0; m < M_VAL; m++)
            s_partial[warp_id * M_MAX + m] = acc[m];
    }
    __syncthreads();

    // Thread 0 folds the four warp partials and writes the output column.
    if (threadIdx.x == 0) {
        #pragma unroll
        for (int m = 0; m < M_VAL; m++) {
            if (m < M) {
                float sum = s_partial[0 * M_MAX + m] + s_partial[1 * M_MAX + m]
                          + s_partial[2 * M_MAX + m] + s_partial[3 * M_MAX + m];
                C[m * N + col] = ScalarOps<scalar_t>::from_float(sum);
            }
        }
    }
}
2712+
2713+
// ---- Scalar GEMV launcher ----
// Launches kbit_scalar_gemv with a fixed 128-thread (4-warp) block and one
// block per output column (grid = N).
// FIX: early-out on empty problems — a zero-sized grid is an invalid launch
// configuration (cudaErrorInvalidConfiguration), not a no-op.
template <int K, int MV, typename scalar_t>
static void kbitScalarGemvLaunch(
    const scalar_t* A, const unsigned int* B_packed,
    const float* B_absmax, const float* codebook,
    scalar_t* C, int M, int K_dim, int N
) {
    if (N <= 0 || K_dim <= 0 || M <= 0)
        return; // nothing to compute

    constexpr int BLOCK_SIZE = 128; // 4 warps
    const int grid_size = N;        // C=1: one block per output column

    kbit_scalar_gemv<K, MV, scalar_t><<<grid_size, BLOCK_SIZE>>>(
        A, B_packed, B_absmax, codebook, C, M, K_dim, N);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
2727+
2728+
// Public entry point: picks the M_VAL template instantiation for the batch
// size. M is clamped into the supported range [1, 4]; the buckets are
// identical to an if/else chain of M<=1 -> 1, M<=2 -> 2, M<=3 -> 3, else 4.
template <int K, typename scalar_t>
void kbitScalarGemv(
    const scalar_t* A, const unsigned int* B_packed,
    const float* B_absmax, const float* codebook,
    scalar_t* C, int M, int K_dim, int N
) {
    const int m_val = (M < 1) ? 1 : ((M > 4) ? 4 : M);

    switch (m_val) {
    case 1:
        kbitScalarGemvLaunch<K, 1, scalar_t>(A, B_packed, B_absmax, codebook, C, M, K_dim, N);
        break;
    case 2:
        kbitScalarGemvLaunch<K, 2, scalar_t>(A, B_packed, B_absmax, codebook, C, M, K_dim, N);
        break;
    case 3:
        kbitScalarGemvLaunch<K, 3, scalar_t>(A, B_packed, B_absmax, codebook, C, M, K_dim, N);
        break;
    default:
        kbitScalarGemvLaunch<K, 4, scalar_t>(A, B_packed, B_absmax, codebook, C, M, K_dim, N);
        break;
    }
}
2746+
2747+
// ===================================================================
2748+
// Grouped scalar GEMV: MoE expert dispatch (no split-K needed)
2749+
// ===================================================================
2750+
25492751
// ---- Debug: Simple MMA test kernel ----
25502752
// Takes fp16 A[16,16] and fp16 B[16,8] (B stored row-major), outputs fp32 C[16,8].
25512753
__global__ void test_mma_kernel(const half* __restrict__ A, const half* __restrict__ B, float* __restrict__ C) {
@@ -2694,3 +2896,13 @@ INSTANTIATE_KBIT_GROUPED_GEMM_PROD(2)
26942896
INSTANTIATE_KBIT_GROUPED_GEMM_PROD(3)
26952897
INSTANTIATE_KBIT_GROUPED_GEMM_PROD(4)
26962898
INSTANTIATE_KBIT_GROUPED_GEMM_PROD(5)
2899+
2900+
// Scalar GEMV explicit instantiations (fp16 and bf16).
// One pair of template instantiations per supported bit width K so the
// symbols exist for the extern "C" wrappers in pythonInterface.cpp.
#define INSTANTIATE_SCALAR_GEMV(K) \
    template void kbitScalarGemv<K, half>(const half*, const unsigned int*, const float*, const float*, half*, int, int, int); \
    template void kbitScalarGemv<K, __nv_bfloat16>(const __nv_bfloat16*, const unsigned int*, const float*, const float*, __nv_bfloat16*, int, int, int);

INSTANTIATE_SCALAR_GEMV(2)
INSTANTIATE_SCALAR_GEMV(3)
INSTANTIATE_SCALAR_GEMV(4)
INSTANTIATE_SCALAR_GEMV(5)

csrc/ops.cuh

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,41 @@ void gemm_4bit_inference_naive(
187187

188188
template <typename T, int FUNC> void func(T* A, T* B, T value, long n);
189189

190+
// ===================================================================
191+
// K-bit scalar GEMV: C[M,N] = A[M,K] * W_kbit^T (M=1..4)
192+
// ===================================================================
193+
//
194+
// Computes M rows of GEMV where weights are k-bit quantized.
195+
// No split-K, no workspace. Grid=N direct mapping.
196+
//
197+
// Parameters:
198+
// K_BITS: 2, 3, 4, or 5 (quantization bit width)
199+
// M_VAL: 1, 2, 3, or 4 (batch size per call)
200+
// scalar_t: half or __nv_bfloat16
201+
//
202+
// Layout:
203+
// A: [M, K_dim] row-major fp16/bf16
204+
// B_packed: [N * num_k_blocks * K_BITS] uint32 (flat, no repack needed)
205+
// B_absmax: [N * num_k_blocks] float32
206+
// codebook: [2^K_BITS] float32
207+
// C: [M, N] row-major fp16/bf16
208+
209+
template <int K, typename scalar_t>
210+
void kbitScalarGemv(
211+
const scalar_t* A, const unsigned int* B_packed,
212+
const float* B_absmax, const float* codebook,
213+
scalar_t* C, int M, int K_dim, int N);
214+
215+
// Grouped scalar GEMV for MoE expert dispatch
216+
template <int K, typename scalar_t>
217+
void kbitGroupedScalarGemv(
218+
const scalar_t* A_concat,
219+
const unsigned int* B_packed_all,
220+
const unsigned char* B_absmax_all,
221+
const float* codebook,
222+
scalar_t* C_concat,
223+
const int* expert_offsets,
224+
const int* work_offsets,
225+
int K_dim, int N, int num_experts, int total_work);
226+
190227
#endif

csrc/pythonInterface.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,28 @@ MAKE_KBIT_GROUPED_GEMM_PROD(3)
550550
MAKE_KBIT_GROUPED_GEMM_PROD(4)
551551
MAKE_KBIT_GROUPED_GEMM_PROD(5)
552552

553+
// Forward declaration of the scalar GEMV entry point defined in ops.cu.
// FIX: the declaration must carry the same template parameter list as the
// definition, <int K, typename scalar_t>. The previous per-type declarations
// (template <int K> void kbitScalarGemv(const half*, ...) and a bf16 twin)
// declared *different*, more-specialized function templates; they win
// overload resolution for kbitScalarGemv<K>(...) but have no definition,
// producing unresolved symbols at link time.
template <int K, typename scalar_t>
void kbitScalarGemv(
    const scalar_t* A, const unsigned int* B_packed,
    const float* B_absmax, const float* codebook,
    scalar_t* C, int M, int K_dim, int N);

// Scalar GEMV extern C wrappers (fp16 and bf16) — flat layout, no workspace.
// One fp16 + one bf16 wrapper per bit width K; scalar_t is deduced from A.
#define MAKE_KBIT_SCALAR_GEMV(K) \
    void kbit_scalar_gemv_fp16_##K( \
        const half* A, const unsigned int* B_packed, const float* B_absmax, \
        const float* codebook, half* C, int M, int K_dim, int N) { \
        kbitScalarGemv<K>(A, B_packed, B_absmax, codebook, C, M, K_dim, N); \
    } \
    void kbit_scalar_gemv_bf16_##K( \
        const __nv_bfloat16* A, const unsigned int* B_packed, const float* B_absmax, \
        const float* codebook, __nv_bfloat16* C, int M, int K_dim, int N) { \
        kbitScalarGemv<K>(A, B_packed, B_absmax, codebook, C, M, K_dim, N); \
    }

MAKE_KBIT_SCALAR_GEMV(2)
MAKE_KBIT_SCALAR_GEMV(3)
MAKE_KBIT_SCALAR_GEMV(4)
MAKE_KBIT_SCALAR_GEMV(5)
574+
553575
// Debug MMA test
554576
void testMMA(const half*, const half*, float*);
555577

0 commit comments

Comments
 (0)