Skip to content

Commit 72a374c

Browse files
TimDettmers and claude committed
Add tiled scalar GEMV v2 with shared memory + split-K
New kernel kbit_scalar_gemv_tiled_v2 that cooperatively loads full tiles into double-buffered shared memory via cp.async, then each thread reads its column from shared memory. 128 threads per block, each handling one column within the TILE_N=128 N-tile. Split-K for SM occupancy: grid = n_tiles * k_splits with atomicAdd to float32 workspace. Includes fix for empty-split bug where trailing k-splits with no work would prevent the last-block output conversion. Registered as bitsandbytes::kbit_scalar_gemv_v2_ with explicit workspace and tile_counters parameters for CUDA graph compatibility. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d834367 commit 72a374c

File tree

4 files changed

+474
-0
lines changed

4 files changed

+474
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,3 +811,37 @@ def _(
811811
torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
812812
torch._check(out.dtype == A.dtype, lambda: f"out dtype {out.dtype} must match A dtype {A.dtype}")
813813
return out
814+
815+
816+
# K-bit scalar GEMV v2: tiled with shared memory + split-K (CUDA graph compatible)
# The workspace and tile_counters are explicit tensor parameters (rather than
# allocated inside the op) so the op can be captured into a CUDA graph.
torch.library.define(
    "bitsandbytes::kbit_scalar_gemv_v2_",
    "(Tensor A, Tensor B_packed_tiled, Tensor B_absmax_tiled, Tensor codebook, int K_dim, int N, int k, "
    "Tensor(a!) out, Tensor C_workspace, Tensor tile_counters) -> Tensor(a!)",
)


@register_fake("bitsandbytes::kbit_scalar_gemv_v2_")
def _(
    A: torch.Tensor,
    B_packed_tiled: torch.Tensor,
    B_absmax_tiled: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
    out: torch.Tensor,
    C_workspace: torch.Tensor,
    tile_counters: torch.Tensor,
) -> torch.Tensor:
    """Fake (meta) implementation: validates argument metadata only.

    Checks mirror the CUDA kernel's preconditions: 2 <= k <= 5 bits,
    A is a [M, K_dim] fp16/bf16 matrix with M <= 4, out matches A's dtype,
    the float32 split-K accumulation workspace, and the int32 per-tile
    completion counters. Returns `out` unchanged (it is mutated in-place
    by the real kernel).
    """
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
    torch._check(A.shape[0] <= 4, lambda: f"kbit_scalar_gemv_v2_ supports M<=4, got {A.shape[0]}")
    torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
    torch._check(out.dtype == A.dtype, lambda: f"out dtype {out.dtype} must match A dtype {A.dtype}")
    torch._check(C_workspace.dtype == torch.float32, lambda: f"C_workspace must be float32, got {C_workspace.dtype}")
    torch._check(
        tile_counters.dtype == torch.int32, lambda: f"tile_counters must be int32, got {tile_counters.dtype}"
    )
    return out

bitsandbytes/backends/cuda/ops.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,3 +1375,48 @@ def _(
13751375
_get_tensor_stream(A),
13761376
)
13771377
return out
1378+
1379+
1380+
@register_kernel("bitsandbytes::kbit_scalar_gemv_v2_", "cuda")
def _(
    A: torch.Tensor,
    B_packed_tiled: torch.Tensor,
    B_absmax_tiled: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
    out: torch.Tensor,
    C_workspace: torch.Tensor,
    tile_counters: torch.Tensor,
) -> torch.Tensor:
    """CUDA implementation of the tiled scalar GEMV v2 (shared memory + split-K).

    Dispatches to the `ckbit_scalar_gemv_v2_{fp16|bf16}[_fp16abs]_k{k}` C entry
    point selected by A's dtype, B_absmax_tiled's dtype, and the bit width k.
    Mutates `out` in-place and returns it. `C_workspace` (float32) and
    `tile_counters` (int32) are zeroed here on every call because the kernel
    accumulates via atomicAdd; passing them in explicitly keeps the op CUDA
    graph compatible.
    """
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(
        A.dtype in (torch.float16, torch.bfloat16),
        lambda: f"kbit_scalar_gemv_v2_ supports float16 and bfloat16, got {A.dtype}",
    )

    M = A.shape[0]
    # The native launcher clamps its M template parameter to 4; rows beyond
    # that would be silently dropped, so reject M > 4 here (mirrors the fake impl).
    torch._check(M <= 4, lambda: f"kbit_scalar_gemv_v2_ supports M<=4, got {M}")
    dtype_suffix = "fp16" if A.dtype == torch.float16 else "bf16"
    abs_suffix = "_fp16abs" if B_absmax_tiled.dtype == torch.float16 else ""

    # Zero workspace and counters (required by atomicAdd accumulation)
    C_workspace.zero_()
    tile_counters.zero_()

    with _cuda_device_of(A):
        fn = getattr(lib, f"ckbit_scalar_gemv_v2_{dtype_suffix}{abs_suffix}_k{k}")
        fn(
            get_ptr(A),
            get_ptr(B_packed_tiled),
            get_ptr(B_absmax_tiled),
            get_ptr(codebook),
            get_ptr(out),
            get_ptr(C_workspace),
            get_ptr(tile_counters),
            ct.c_int(M),
            ct.c_int(K_dim),
            ct.c_int(N),
            _get_tensor_stream(A),
        )
    return out

csrc/ops.cu

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,6 +2162,278 @@ void kbitScalarGemvTiled(
21622162
#undef LAUNCH_SCALAR_GEMV_TILED
21632163
}
21642164

2165+
// ---- Tiled Scalar GEMV v2 ----
// Cooperative tile loading into shared memory with split-K for occupancy.
// Grid = n_tiles * k_splits, Block = 128 threads (4 warps).
// Each thread handles one column within an N-tile.
// Double-buffered cp.async pipeline for B + absmax tiles.
// A loaded directly from global memory (L1 broadcast across columns).
//
// Preconditions:
//   - C_workspace and tile_counters must be zeroed before launch when
//     k_splits > 1 (the Python wrapper zeroes them on every call).
//   - n_tiles = N / TILE_N truncates, so N is presumably a multiple of
//     TILE_N = 128 — columns beyond n_tiles*TILE_N would never be computed.
//     TODO(review): confirm at call sites.
//   - A rows are read with 16-byte int4 loads, so A must be 16-byte aligned
//     and K_dim a multiple of the 32-element quantization block — TODO confirm.

template <int K_BITS, int M_VAL, typename scalar_t = half, typename ABSMAX_T = unsigned char>
__global__ void __launch_bounds__(128, 8) kbit_scalar_gemv_tiled_v2(
    const scalar_t* __restrict__ A,
    const unsigned int* __restrict__ B_packed,
    const ABSMAX_T* __restrict__ B_absmax,
    const float* __restrict__ codebook,
    scalar_t* __restrict__ C,
    float* __restrict__ C_workspace,
    int* __restrict__ tile_counters,
    const int M, const int K_dim, const int N, const int k_splits
) {
    constexpr int BS = 32; // quantization block size
    constexpr int TILE_K = 64;
    constexpr int TILE_N = 128;
    constexpr int BLOCK_DIM = 128; // threads per block
    constexpr int NUM_WARPS = 4;   // (currently unused)
    constexpr int M_MAX = 4;       // (currently unused)
    constexpr int KB_PER_TILE = TILE_K / BS; // 2 quantization blocks per K-tile
    // Per-column packed words: K_BITS bit-plane words per quantization block.
    constexpr int B_COL_WORDS = KB_PER_TILE * K_BITS;
    constexpr int B_STAGE_WORDS = TILE_N * B_COL_WORDS;
    constexpr int B_STAGE_BYTES = B_STAGE_WORDS * (int)sizeof(unsigned int);
    constexpr int ABS_STAGE_ELEMS = TILE_N * KB_PER_TILE;
    constexpr int ABS_STAGE_BYTES = ABS_STAGE_ELEMS * (int)sizeof(ABSMAX_T);
    // Round absmax region up to 16 bytes so each stage stays int4-aligned.
    constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;
    constexpr int STAGE_BYTES = B_STAGE_BYTES + ABS_STAGE_ALIGNED;

    const int n_tiles = N / TILE_N;
    const int k_tiles = (K_dim + TILE_K - 1) / TILE_K;
    const int tiles_per_split = (k_tiles + k_splits - 1) / k_splits;

    // Work item: which N-tile and K-split
    const int work_id = blockIdx.x;
    const int n_tile = work_id / k_splits;
    const int ks_id = work_id % k_splits;
    const int n_base = n_tile * TILE_N;

    const int kt_start = ks_id * tiles_per_split;
    const int kt_end = min(kt_start + tiles_per_split, k_tiles);
    // Defensive: blocks with no K-tiles exit before any barrier (uniform per
    // block). The launcher recomputes k_splits so this should never trigger —
    // an exiting block would skip the tile_counters increment below and the
    // last-block output conversion would never run.
    if (kt_start >= k_tiles) return;

    // This thread's column within the tile
    const int col_in_tile = threadIdx.x; // 0..127
    const int col = n_base + col_in_tile;

    const int warp_id = threadIdx.x / 32; // (currently unused)
    const int lane_id = threadIdx.x % 32;

    // Codebook in registers (shuffle-based lookup): lane i of every warp holds
    // codebook[i]; decoded indices are looked up with __shfl_sync below.
    // Requires K_BITS <= 5 so all 2^K_BITS entries fit in one warp.
    float cb = (lane_id < (1 << K_BITS)) ? codebook[lane_id] : 0.0f;

    // Double-buffered shared memory: two stages of [B tile | absmax tile].
    extern __shared__ char smem[];
    auto sh_b = [&](int stage) -> unsigned int* {
        return reinterpret_cast<unsigned int*>(smem + stage * STAGE_BYTES);
    };
    auto sh_abs = [&](int stage) -> ABSMAX_T* {
        return reinterpret_cast<ABSMAX_T*>(smem + stage * STAGE_BYTES + B_STAGE_BYTES);
    };

    // Per-row accumulators (kept in float for precision).
    float acc[M_VAL];
#pragma unroll
    for (int m = 0; m < M_VAL; m++) acc[m] = 0.0f;

    // Fetch tile: cooperative cp.async loading of B + absmax into `stage`.
    auto fetch_tile = [&](int stage, int kt) {
        const int tile_idx = kt * n_tiles + n_tile; // K-major tile ordering

        // B tile via cp.async (all 128 threads cooperatively load 16B chunks)
        const int b_global_base = tile_idx * B_STAGE_WORDS;
        constexpr int B_INT4S = B_STAGE_BYTES / 16;
        const int4* b_src = reinterpret_cast<const int4*>(B_packed + b_global_base);
        int4* b_dst = reinterpret_cast<int4*>(sh_b(stage));
        for (int i = threadIdx.x; i < B_INT4S; i += BLOCK_DIM)
            cp_async_cg_16(&b_dst[i], &b_src[i]);

        // Absmax via cp.async
        const int abs_global_base = tile_idx * ABS_STAGE_ELEMS;
        constexpr int ABS_INT4S = (ABS_STAGE_BYTES + 15) / 16;
        const int4* abs_src = reinterpret_cast<const int4*>(B_absmax + abs_global_base);
        int4* abs_dst = reinterpret_cast<int4*>(sh_abs(stage));
        for (int i = threadIdx.x; i < ABS_INT4S; i += BLOCK_DIM)
            cp_async_cg_16(&abs_dst[i], &abs_src[i]);
    };

    // Compute tile: each thread decodes its own column from shared memory and
    // FMAs against the A rows (read directly from global memory).
    auto compute_tile = [&](int stage, int kt) {
        unsigned int* b_ptr = sh_b(stage);
        ABSMAX_T* abs_ptr = sh_abs(stage);
        const int k_base = kt * TILE_K;

        // Process KB_PER_TILE (=2) K-blocks within this tile
#pragma unroll
        for (int kb = 0; kb < KB_PER_TILE; kb++) {
            const int block_k_base = k_base + kb * BS;
            // Guard the partial last K-tile (K_dim need not be a TILE_K multiple).
            if (block_k_base >= K_dim) continue;

            // Read K_BITS bit-plane words from shared memory for this column,
            // using the widest aligned load available for the common widths.
            int b_addr = col_in_tile * B_COL_WORDS + kb * K_BITS;
            unsigned int planes[K_BITS];
            if constexpr (K_BITS == 2) {
                uint2 pv = *reinterpret_cast<const uint2*>(&b_ptr[b_addr]);
                planes[0] = pv.x; planes[1] = pv.y;
            } else if constexpr (K_BITS == 4) {
                int4 pv = *reinterpret_cast<const int4*>(&b_ptr[b_addr]);
                planes[0] = (unsigned int)pv.x; planes[1] = (unsigned int)pv.y;
                planes[2] = (unsigned int)pv.z; planes[3] = (unsigned int)pv.w;
            } else {
#pragma unroll
                for (int b = 0; b < K_BITS; b++)
                    planes[b] = b_ptr[b_addr + b];
            }

            // Load absmax scale for this (column, K-block) from shared memory.
            float amax = load_absmax(abs_ptr, col_in_tile * KB_PER_TILE + kb);

            // Dequant-once loop: decode each weight once, FMA across M rows.
            // Each `sub` covers 8 of the 32 elements in the quantization block.
#pragma unroll
            for (int sub = 0; sub < 4; sub++) {
                // Load A for all M rows (int4 = 8 fp16/bf16 values)
                int4 av[M_VAL];
#pragma unroll
                for (int m = 0; m < M_VAL; m++)
                    av[m] = *reinterpret_cast<const int4*>(&A[m * K_dim + block_k_base + sub * 8]);

                // Dequant each element once, then FMA across M rows
#pragma unroll
                for (int j = 0; j < 8; j++) {
                    // Reassemble the K_BITS-bit code from the bit planes.
                    int idx = 0;
#pragma unroll
                    for (int b = 0; b < K_BITS; b++)
                        idx |= ((planes[b] >> (sub * 8 + j)) & 1) << b;
                    // Codebook lookup via warp shuffle (lane idx holds entry idx).
                    float w = __shfl_sync(0xFFFFFFFF, cb, idx) * amax;

#pragma unroll
                    for (int m = 0; m < M_VAL; m++) {
                        const scalar_t* ap = reinterpret_cast<const scalar_t*>(&av[m]);
                        acc[m] += w * ScalarOps<scalar_t>::to_float(ap[j]);
                    }
                }
            }
        }
    };

    // Pipeline: double-buffered cp.async. While computing stage `cur`, the
    // next tile is in flight into the other stage; cp_async_wait<1> leaves
    // exactly that newest group outstanding.
    fetch_tile(0, kt_start);
    cp_async_fence();

    for (int kt = kt_start; kt < kt_end; kt++) {
        int cur = (kt - kt_start) % 2;
        if (kt + 1 < kt_end) {
            fetch_tile((kt + 1 - kt_start) % 2, kt + 1);
            cp_async_fence();
            cp_async_wait<1>();
        } else {
            cp_async_wait<0>();
        }
        __syncthreads(); // fetched stage visible to all threads
        compute_tile(cur, kt);
        __syncthreads(); // stage fully consumed before it is overwritten
    }

    // Write output
    if (k_splits == 1) {
        // Direct write — this block owns the full K reduction
#pragma unroll
        for (int m = 0; m < M_VAL; m++) {
            if (m < M && col < N)
                C[m * N + col] = ScalarOps<scalar_t>::from_float(acc[m]);
        }
    } else {
        // Partial K — atomicAdd to workspace
#pragma unroll
        for (int m = 0; m < M_VAL; m++) {
            if (m < M && col < N)
                atomicAdd(&C_workspace[m * N + col], acc[m]);
        }

        // Publish the workspace writes before signaling completion.
        __threadfence();

        // Last-arriving split converts workspace to output
        __shared__ int is_last;
        if (threadIdx.x == 0) {
            int done = atomicAdd(&tile_counters[n_tile], 1);
            is_last = (done == k_splits - 1) ? 1 : 0;
        }
        __syncthreads(); // broadcast is_last to the whole block

        if (is_last) {
            // All splits have contributed; convert float32 partials to scalar_t.
            for (int i = threadIdx.x; i < M_VAL * TILE_N; i += BLOCK_DIM) {
                int m = i / TILE_N;
                int c = n_base + i % TILE_N;
                if (m < M && c < N)
                    C[m * N + c] = ScalarOps<scalar_t>::from_float(C_workspace[m * N + c]);
            }
        }
    }
}
2370+
2371+
// ---- Tiled GEMV v2 launcher ----
// Computes the launch geometry (grid, dynamic shared memory, k_splits) for
// kbit_scalar_gemv_tiled_v2 and launches it on `stream`. The tile/stage
// constants here must mirror the ones inside the kernel exactly, since they
// size the double-buffered shared-memory stages.
template <int K, int MV, typename scalar_t, typename ABSMAX_T>
static void kbitScalarGemvTiledV2Launch(
    const scalar_t* A, const unsigned int* B_packed, const ABSMAX_T* B_absmax,
    const float* codebook, scalar_t* C, float* C_workspace, int* tile_counters,
    int M, int K_dim, int N, int num_sms, cudaStream_t stream
) {
    constexpr int TILE_N = 128;
    constexpr int TILE_K = 64;
    constexpr int BLOCK_DIM = 128;
    constexpr int BS = 32; // quantization block size
    constexpr int KB_PER_TILE = TILE_K / BS;
    constexpr int B_COL_WORDS = KB_PER_TILE * K;
    constexpr int B_STAGE_BYTES = TILE_N * B_COL_WORDS * (int)sizeof(unsigned int);
    constexpr int ABS_STAGE_BYTES = TILE_N * KB_PER_TILE * (int)sizeof(ABSMAX_T);
    constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15; // keep stages int4-aligned
    constexpr int STAGE_BYTES = B_STAGE_BYTES + ABS_STAGE_ALIGNED;

    int n_tiles = N / TILE_N; // truncates — assumes N is a multiple of TILE_N
    int k_tiles = (K_dim + TILE_K - 1) / TILE_K;

    // Choose k_splits to achieve ~4 blocks per SM.
    // Recompute from tiles_per_split to guarantee no empty splits
    // (empty splits would skip the tile_counters atomicAdd, breaking the last-block check).
    int target_blocks = num_sms * 4;
    int k_splits = max(1, (target_blocks + n_tiles - 1) / n_tiles);
    k_splits = min(k_splits, k_tiles); // at most one split per K-tile
    int tiles_per_split = (k_tiles + k_splits - 1) / k_splits;
    k_splits = (k_tiles + tiles_per_split - 1) / tiles_per_split; // no empty splits

    int grid_size = n_tiles * k_splits;
    int smem_size = 2 * STAGE_BYTES; // two stages: double-buffered pipeline

    kbit_scalar_gemv_tiled_v2<K, MV, scalar_t, ABSMAX_T>
        <<<grid_size, BLOCK_DIM, smem_size, stream>>>(
            A, B_packed, B_absmax, codebook, C, C_workspace, tile_counters,
            M, K_dim, N, k_splits
        );
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
2411+
2412+
// Public entry point: selects the M_VAL template parameter from the runtime M
// (clamped to 4) and queries the device SM count for the split-K heuristic.
// Matches the file's convention of checking CUDA API return codes via
// CUDA_CHECK_RETURN (previously cudaGetDevice/cudaDeviceGetAttribute were
// unchecked, so a failure would launch with garbage num_sms).
template <int K, typename scalar_t, typename ABSMAX_T>
void kbitScalarGemvTiledV2(
    const scalar_t* A, const unsigned int* B_packed, const ABSMAX_T* B_absmax,
    const float* codebook, scalar_t* C, float* C_workspace, int* tile_counters,
    int M, int K_dim, int N, cudaStream_t stream
) {
    int dev;
    CUDA_CHECK_RETURN(cudaGetDevice(&dev));
    int num_sms;
    CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev));

#define LAUNCH_GEMV_V2(MV) \
    kbitScalarGemvTiledV2Launch<K, MV, scalar_t, ABSMAX_T>( \
        A, B_packed, B_absmax, codebook, C, C_workspace, tile_counters, \
        M, K_dim, N, num_sms, stream)

    // Dispatch on M so the kernel unrolls exactly M accumulators/loads.
    if (M <= 1) { LAUNCH_GEMV_V2(1); }
    else if (M <= 2) { LAUNCH_GEMV_V2(2); }
    else if (M <= 3) { LAUNCH_GEMV_V2(3); }
    else { LAUNCH_GEMV_V2(4); }

#undef LAUNCH_GEMV_V2
}
2436+
21652437
// ---- Debug: Simple MMA test kernel ----
21662438
// Takes fp16 A[16,16] and fp16 B[16,8] (B stored row-major), outputs fp32 C[16,8].
21672439
__global__ void test_mma_kernel(const half* __restrict__ A, const half* __restrict__ B, float* __restrict__ C) {
@@ -2439,3 +2711,31 @@ INSTANTIATE_KBIT_SCALAR_GEMV_TILED_FP16(2)
24392711
INSTANTIATE_KBIT_SCALAR_GEMV_TILED_FP16(3)
24402712
INSTANTIATE_KBIT_SCALAR_GEMV_TILED_FP16(4)
24412713
INSTANTIATE_KBIT_SCALAR_GEMV_TILED_FP16(5)
2714+
// Scalar GEMV v2 (tiled with shared memory) instantiations — uint8 E4M4 absmax
// One half and one __nv_bfloat16 activation instantiation per bit width K.
// The exported C symbols in pythonInterface are expected to forward to these.
#define INSTANTIATE_KBIT_SCALAR_GEMV_V2_U8(K) \
    template void kbitScalarGemvTiledV2<K, half, unsigned char>( \
        const half*, const unsigned int*, const unsigned char*, const float*, half*, float*, int*, \
        int, int, int, cudaStream_t \
    ); \
    template void kbitScalarGemvTiledV2<K, __nv_bfloat16, unsigned char>( \
        const __nv_bfloat16*, const unsigned int*, const unsigned char*, const float*, __nv_bfloat16*, float*, int*, \
        int, int, int, cudaStream_t \
    );
// Bit widths 2-5, matching the k range validated on the Python side.
INSTANTIATE_KBIT_SCALAR_GEMV_V2_U8(2)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_U8(3)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_U8(4)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_U8(5)
// fp16 absmax variants (ABSMAX_T = half) — selected via the "_fp16abs" suffix.
#define INSTANTIATE_KBIT_SCALAR_GEMV_V2_FP16(K) \
    template void kbitScalarGemvTiledV2<K, half, half>( \
        const half*, const unsigned int*, const half*, const float*, half*, float*, int*, \
        int, int, int, cudaStream_t \
    ); \
    template void kbitScalarGemvTiledV2<K, __nv_bfloat16, half>( \
        const __nv_bfloat16*, const unsigned int*, const half*, const float*, __nv_bfloat16*, float*, int*, \
        int, int, int, cudaStream_t \
    );
INSTANTIATE_KBIT_SCALAR_GEMV_V2_FP16(2)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_FP16(3)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_FP16(4)
INSTANTIATE_KBIT_SCALAR_GEMV_V2_FP16(5)

0 commit comments

Comments
 (0)