Skip to content

Commit 52fe6af

Browse files
TimDettmers and claude committed
feat: Template hand-written NVFP4 GEMM for BF16/FP32 output
Template kGemmNVFP4_smem on output type (float, __nv_bfloat16, half). For split-K: accumulates in FP32 workspace via atomicAdd, then runs a tiny conversion kernel. Non-split-K stores directly with type conversion. New C entry points: cgemm_nvfp4_bf16, cgemm_nvfp4_bf16_splitk New Python wrapper: _gemm_nvfp4_hw_bf16_raw Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 240d9af commit 52fe6af

File tree

2 files changed

+154
-45
lines changed

2 files changed

+154
-45
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,8 +1049,12 @@ def _(scales: torch.Tensor, H: int, W: int) -> torch.Tensor:
10491049
# Hand-written NVFP4 GEMM (SM_120+)
10501050
#
10511051
# Uses mma.sync.aligned.block_scale instructions for small-M decode.
1052-
# Expects flat (non-swizzled) row-major scales. Output is FP32.
1052+
# Expects flat (non-swizzled) row-major scales.
10531053
# Uses automatic split-K when tile count is low relative to SM count.
1054+
#
1055+
# Output variants:
1056+
# _gemm_nvfp4_hw_raw — FP32 output (cgemm_nvfp4)
1057+
# _gemm_nvfp4_hw_bf16_raw — BF16 output (cgemm_nvfp4_bf16), needs FP32 workspace for split-K
10541058
def _gemm_nvfp4_hw_raw(
10551059
A_packed: torch.Tensor,
10561060
B_packed: torch.Tensor,
@@ -1061,7 +1065,7 @@ def _gemm_nvfp4_hw_raw(
10611065
N: int,
10621066
K: int,
10631067
) -> None:
1064-
"""Raw hand-written NVFP4 GEMM — zero allocations, CUDA-graph-safe.
1068+
"""Raw hand-written NVFP4 GEMM (FP32 output) — zero allocations, CUDA-graph-safe.
10651069
10661070
All buffers must be pre-allocated. D_out must be FP32 of shape (M, N).
10671071
Scales are flat row-major (not swizzled). Uses auto split-K internally
@@ -1080,6 +1084,37 @@ def _gemm_nvfp4_hw_raw(
10801084
)
10811085

10821086

1087+
def _gemm_nvfp4_hw_bf16_raw(
1088+
A_packed: torch.Tensor,
1089+
B_packed: torch.Tensor,
1090+
A_scales: torch.Tensor,
1091+
B_scales: torch.Tensor,
1092+
D_out: torch.Tensor,
1093+
workspace: torch.Tensor,
1094+
M: int,
1095+
N: int,
1096+
K: int,
1097+
) -> None:
1098+
"""Raw hand-written NVFP4 GEMM (BF16 output) — zero allocations, CUDA-graph-safe.
1099+
1100+
All buffers must be pre-allocated. D_out must be BF16 of shape (M, N).
1101+
workspace must be FP32 of shape (M, N) — used for split-K accumulation.
1102+
Scales are flat row-major (not swizzled).
1103+
"""
1104+
lib.cgemm_nvfp4_bf16(
1105+
get_ptr(A_packed),
1106+
get_ptr(B_packed),
1107+
get_ptr(A_scales),
1108+
get_ptr(B_scales),
1109+
get_ptr(D_out),
1110+
get_ptr(workspace),
1111+
ct.c_int(M),
1112+
ct.c_int(N),
1113+
ct.c_int(K),
1114+
_get_tensor_stream(A_packed),
1115+
)
1116+
1117+
10831118
def _gemm_nvfp4_hw_splitk_raw(
10841119
A_packed: torch.Tensor,
10851120
B_packed: torch.Tensor,
@@ -1091,7 +1126,7 @@ def _gemm_nvfp4_hw_splitk_raw(
10911126
K: int,
10921127
split_k: int,
10931128
) -> None:
1094-
"""Raw hand-written NVFP4 GEMM with explicit split-K — CUDA-graph-safe."""
1129+
"""Raw hand-written NVFP4 GEMM with explicit split-K (FP32 output)."""
10951130
lib.cgemm_nvfp4_splitk(
10961131
get_ptr(A_packed),
10971132
get_ptr(B_packed),

csrc/kernels_nvfp4_sm120.cu

Lines changed: 116 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <cuda_bf16.h>
1616
#include <cuda_fp16.h>
1717
#include <cuda_runtime.h>
18+
#include <type_traits>
1819

1920
// ============================================================================
2021
// MMA wrapper: m16n8k64 E2M1 x E2M1 -> F32 with UE4M3 block scales
@@ -96,13 +97,36 @@ __device__ __forceinline__ uint32_t
9697
#define SMEM_SFB_BYTES (BLOCK_N_DIM * 4) // 512
9798
#define SMEM_TOTAL (SMEM_A_BYTES + SMEM_B_BYTES + SMEM_SFA_BYTES + SMEM_SFB_BYTES)
9899

100+
// ============================================================================
101+
// Output conversion helpers
102+
// ============================================================================
103+
template <typename T> __device__ __forceinline__ T float_to_out(float v);
104+
105+
template <> __device__ __forceinline__ float float_to_out<float>(float v) { return v; }
106+
107+
template <> __device__ __forceinline__ __nv_bfloat16 float_to_out<__nv_bfloat16>(float v) {
108+
return __float2bfloat16(v);
109+
}
110+
111+
template <> __device__ __forceinline__ half float_to_out<half>(float v) { return __float2half(v); }
112+
113+
// Tiny kernel: convert FP32 workspace to OutT after split-K reduction
114+
template <typename OutT> __global__ void kConvertOutput(const float* __restrict__ src, OutT* __restrict__ dst, int n) {
115+
int idx = blockIdx.x * blockDim.x + threadIdx.x;
116+
if (idx < n) {
117+
dst[idx] = float_to_out<OutT>(src[idx]);
118+
}
119+
}
120+
99121
// 256 threads, target 4 blocks/SM for occupancy
122+
template <typename OutT>
100123
__global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 4) void kGemmNVFP4_smem(
101124
const unsigned char* __restrict__ A, // M x K/2 packed FP4 (row-major)
102125
const unsigned char* __restrict__ B, // N x K/2 packed FP4 (B transposed, row-major)
103126
const unsigned char* __restrict__ SFA, // M x K/16 UE4M3 scales
104127
const unsigned char* __restrict__ SFB, // N x K/16 UE4M3 scales
105-
float* __restrict__ D, // M x N output (F32)
128+
OutT* __restrict__ D, // M x N output
129+
float* __restrict__ D_splitk, // M x N FP32 workspace (only used when split-K > 1)
106130
int M, int N, int K
107131
) {
108132
// Split-K: compute this block's K-range from blockIdx.z / gridDim.z
@@ -321,39 +345,41 @@ __global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 4) void kGemmNVFP4_smem(
321345
#undef COMPUTE_STEP
322346

323347
// ---- Write output ----
324-
// Use atomicAdd when split-K is active (gridDim.z > 1) to accumulate
325-
// partial results from different K-slices
348+
// split-K (gridDim.z > 1): atomicAdd to FP32 workspace, host converts later
349+
// no split-K: convert and store directly to typed output
326350
int octet = lane_id / 4;
327351
int quad = lane_id % 4;
328352
int out_row0 = tile_m + octet * 2;
329353
int out_row1 = out_row0 + 1;
330354
int out_col_base = quad * 2;
331-
const bool use_atomic = (gridDim.z > 1);
355+
const bool use_splitk = (gridDim.z > 1);
332356

333357
#pragma unroll
334358
for (int nt = 0; nt < N_TILES_PER_WARP; nt++) {
335359
int this_tile_n = warp_n_base + nt * 8;
336360
int c0 = this_tile_n + out_col_base;
337361
int c1 = c0 + 1;
338362

339-
if (use_atomic) {
363+
if (use_splitk) {
364+
// Accumulate partial sums in FP32 workspace via atomicAdd
340365
if (out_row0 < M && c0 < N)
341-
atomicAdd(&D[out_row0 * N + c0], acc[nt][0]);
366+
atomicAdd(&D_splitk[out_row0 * N + c0], acc[nt][0]);
342367
if (out_row0 < M && c1 < N)
343-
atomicAdd(&D[out_row0 * N + c1], acc[nt][1]);
368+
atomicAdd(&D_splitk[out_row0 * N + c1], acc[nt][1]);
344369
if (out_row1 < M && c0 < N)
345-
atomicAdd(&D[out_row1 * N + c0], acc[nt][2]);
370+
atomicAdd(&D_splitk[out_row1 * N + c0], acc[nt][2]);
346371
if (out_row1 < M && c1 < N)
347-
atomicAdd(&D[out_row1 * N + c1], acc[nt][3]);
372+
atomicAdd(&D_splitk[out_row1 * N + c1], acc[nt][3]);
348373
} else {
374+
// Direct store with type conversion (no split-K)
349375
if (out_row0 < M && c0 < N)
350-
D[out_row0 * N + c0] = acc[nt][0];
376+
D[out_row0 * N + c0] = float_to_out<OutT>(acc[nt][0]);
351377
if (out_row0 < M && c1 < N)
352-
D[out_row0 * N + c1] = acc[nt][1];
378+
D[out_row0 * N + c1] = float_to_out<OutT>(acc[nt][1]);
353379
if (out_row1 < M && c0 < N)
354-
D[out_row1 * N + c0] = acc[nt][2];
380+
D[out_row1 * N + c0] = float_to_out<OutT>(acc[nt][2]);
355381
if (out_row1 < M && c1 < N)
356-
D[out_row1 * N + c1] = acc[nt][3];
382+
D[out_row1 * N + c1] = float_to_out<OutT>(acc[nt][3]);
357383
}
358384
}
359385
}
@@ -515,68 +541,116 @@ __global__ void kGemmNVFP4_simple(
515541
// RTX PRO 6000: 84 SMs
516542
static const int NUM_SMS = 84;
517543

518-
extern "C" void cgemm_nvfp4(
519-
const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, float* D, int M,
520-
int N, int K, cudaStream_t stream
521-
) {
522-
int num_m_blocks = (M + BLOCK_M_DIM - 1) / BLOCK_M_DIM;
523-
int num_n_blocks = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
524-
int base_blocks = num_m_blocks * num_n_blocks;
525-
int threads_per_block = WARPS_PER_BLOCK * 32; // 256
526-
527-
// Auto split-K: split along K to fill the GPU when M/N tiles are sparse
528-
// Two-tier heuristic based on GPU occupancy:
529-
// - Very sparse (<1 block/SM): aggressive split to 4 blocks/SM
530-
// - Moderate (<2 blocks/SM): gentle split to 2 blocks/SM
531-
// - Sufficient (>=2 blocks/SM): no split
544+
// ============================================================================
545+
// Auto split-K heuristic (shared by all launchers)
546+
// ============================================================================
547+
static int compute_split_k(int base_blocks, int K) {
532548
int max_k_splits = K / 64;
533549
int split_k = 1;
534550
if (base_blocks < NUM_SMS && max_k_splits > 1) {
535-
// Very sparse: target 4 blocks/SM for full occupancy
536551
int target = NUM_SMS * 4;
537552
split_k = (target + base_blocks - 1) / base_blocks;
538553
if (split_k > max_k_splits)
539554
split_k = max_k_splits;
540555
if (split_k > 16)
541556
split_k = 16;
542557
} else if (base_blocks < NUM_SMS * 2 && max_k_splits > 1) {
543-
// Moderate: target 2 blocks/SM
544558
int target = NUM_SMS * 2;
545559
split_k = (target + base_blocks - 1) / base_blocks;
546560
if (split_k > max_k_splits)
547561
split_k = max_k_splits;
548562
if (split_k > 4)
549-
split_k = 4; // limit atomicAdd overhead for larger outputs
563+
split_k = 4;
550564
}
565+
return split_k;
566+
}
567+
568+
// ============================================================================
569+
// Generic typed launcher: works for float, __nv_bfloat16, half
570+
// ============================================================================
571+
template <typename OutT>
572+
static void launch_gemm_nvfp4(
573+
const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, OutT* D,
574+
float* workspace, int M, int N, int K, int split_k, cudaStream_t stream
575+
) {
576+
int num_m_blocks = (M + BLOCK_M_DIM - 1) / BLOCK_M_DIM;
577+
int num_n_blocks = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
578+
int threads_per_block = WARPS_PER_BLOCK * 32;
551579

552-
// Zero output when using split-K (atomicAdd requires zeroed buffer)
553580
if (split_k > 1) {
554-
cudaMemsetAsync(D, 0, (size_t)M * N * sizeof(float), stream);
581+
// Split-K: accumulate in FP32 workspace, then convert to OutT
582+
cudaMemsetAsync(workspace, 0, (size_t)M * N * sizeof(float), stream);
583+
dim3 grid(num_n_blocks, num_m_blocks, split_k);
584+
kGemmNVFP4_smem<OutT><<<grid, threads_per_block, 0, stream>>>(A, B, SFA, SFB, D, workspace, M, N, K);
585+
586+
// Convert FP32 workspace → OutT output (skip for FP32 when workspace == (float*)D)
587+
if constexpr (!std::is_same_v<OutT, float>) {
588+
int n_elem = M * N;
589+
int conv_threads = 256;
590+
int conv_blocks = (n_elem + conv_threads - 1) / conv_threads;
591+
kConvertOutput<OutT><<<conv_blocks, conv_threads, 0, stream>>>(workspace, D, n_elem);
592+
}
593+
} else {
594+
// No split-K: direct typed output
595+
dim3 grid(num_n_blocks, num_m_blocks, 1);
596+
kGemmNVFP4_smem<OutT><<<grid, threads_per_block, 0, stream>>>(A, B, SFA, SFB, D, nullptr, M, N, K);
555597
}
598+
}
556599

557-
dim3 grid(num_n_blocks, num_m_blocks, split_k);
558-
kGemmNVFP4_smem<<<grid, threads_per_block, 0, stream>>>(A, B, SFA, SFB, D, M, N, K);
600+
// ============================================================================
601+
// C entry points — FP32 output (backward compatible)
602+
// ============================================================================
603+
extern "C" void cgemm_nvfp4(
604+
const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, float* D, int M,
605+
int N, int K, cudaStream_t stream
606+
) {
607+
int num_m_blocks = (M + BLOCK_M_DIM - 1) / BLOCK_M_DIM;
608+
int num_n_blocks = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
609+
int base_blocks = num_m_blocks * num_n_blocks;
610+
int split_k = compute_split_k(base_blocks, K);
611+
612+
// FP32 output: D serves as both output and workspace for split-K
613+
launch_gemm_nvfp4<float>(A, B, SFA, SFB, D, D, M, N, K, split_k, stream);
559614
}
560615

561-
// Overload: caller specifies split-K explicitly (for benchmarking)
562616
extern "C" void cgemm_nvfp4_splitk(
563617
const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, float* D, int M,
564618
int N, int K, int split_k, cudaStream_t stream
619+
) {
620+
if (split_k < 1)
621+
split_k = 1;
622+
int max_k_splits = K / 64;
623+
if (split_k > max_k_splits)
624+
split_k = max_k_splits;
625+
626+
// FP32 output: D serves as both output and workspace
627+
launch_gemm_nvfp4<float>(A, B, SFA, SFB, D, D, M, N, K, split_k, stream);
628+
}
629+
630+
// ============================================================================
631+
// C entry points — BF16 output
632+
// ============================================================================
633+
extern "C" void cgemm_nvfp4_bf16(
634+
const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB,
635+
__nv_bfloat16* D, float* workspace, int M, int N, int K, cudaStream_t stream
565636
) {
566637
int num_m_blocks = (M + BLOCK_M_DIM - 1) / BLOCK_M_DIM;
567638
int num_n_blocks = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
568-
int threads_per_block = WARPS_PER_BLOCK * 32;
639+
int base_blocks = num_m_blocks * num_n_blocks;
640+
int split_k = compute_split_k(base_blocks, K);
641+
642+
launch_gemm_nvfp4<__nv_bfloat16>(A, B, SFA, SFB, D, workspace, M, N, K, split_k, stream);
643+
}
569644

645+
extern "C" void cgemm_nvfp4_bf16_splitk(
646+
const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB,
647+
__nv_bfloat16* D, float* workspace, int M, int N, int K, int split_k, cudaStream_t stream
648+
) {
570649
if (split_k < 1)
571650
split_k = 1;
572651
int max_k_splits = K / 64;
573652
if (split_k > max_k_splits)
574653
split_k = max_k_splits;
575654

576-
if (split_k > 1) {
577-
cudaMemsetAsync(D, 0, (size_t)M * N * sizeof(float), stream);
578-
}
579-
580-
dim3 grid(num_n_blocks, num_m_blocks, split_k);
581-
kGemmNVFP4_smem<<<grid, threads_per_block, 0, stream>>>(A, B, SFA, SFB, D, M, N, K);
655+
launch_gemm_nvfp4<__nv_bfloat16>(A, B, SFA, SFB, D, workspace, M, N, K, split_k, stream);
582656
}

0 commit comments

Comments (0)