|
15 | 15 | #include <cuda_bf16.h> |
16 | 16 | #include <cuda_fp16.h> |
17 | 17 | #include <cuda_runtime.h> |
| 18 | +#include <type_traits> |
18 | 19 |
|
19 | 20 | // ============================================================================ |
20 | 21 | // MMA wrapper: m16n8k64 E2M1 x E2M1 -> F32 with UE4M3 block scales |
@@ -96,13 +97,36 @@ __device__ __forceinline__ uint32_t |
96 | 97 | #define SMEM_SFB_BYTES (BLOCK_N_DIM * 4) // 512 |
97 | 98 | #define SMEM_TOTAL (SMEM_A_BYTES + SMEM_B_BYTES + SMEM_SFA_BYTES + SMEM_SFB_BYTES) |
98 | 99 |
|
// ============================================================================
// Output conversion helpers
// ============================================================================
// float_to_out<T>: narrow an FP32 accumulator to the requested output type.
// Only the three specializations below exist; no generic definition is
// provided, so instantiating an unsupported type fails at link time.
template <typename T> __device__ __forceinline__ T float_to_out(float x);

// FP32: passthrough, no conversion needed.
template <> __device__ __forceinline__ float float_to_out<float>(float x) {
    return x;
}

// BF16: convert with the CUDA intrinsic from <cuda_bf16.h>.
template <> __device__ __forceinline__ __nv_bfloat16 float_to_out<__nv_bfloat16>(float x) {
    return __float2bfloat16(x);
}

// FP16: convert with the CUDA intrinsic from <cuda_fp16.h>.
template <> __device__ __forceinline__ half float_to_out<half>(float x) {
    return __float2half(x);
}
| 112 | + |
// Conversion kernel: copy the FP32 split-K workspace into the typed output
// buffer after the atomicAdd reduction has completed.
//
// IMPROVED: grid-stride loop — correctness no longer depends on the launcher
// providing exactly ceil(n / blockDim.x) blocks, so the grid may be capped
// without silently dropping elements. Index arithmetic is done in 64-bit so
// the final `i += stride` cannot overflow a signed int when n is near INT_MAX.
template <typename OutT>
__global__ void kConvertOutput(const float* __restrict__ src, OutT* __restrict__ dst, int n) {
    long long stride = (long long)gridDim.x * blockDim.x;
    for (long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
        dst[i] = float_to_out<OutT>(src[i]);
    }
}
| 120 | + |
99 | 121 | // 256 threads, target 4 blocks/SM for occupancy |
| 122 | +template <typename OutT> |
100 | 123 | __global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 4) void kGemmNVFP4_smem( |
101 | 124 | const unsigned char* __restrict__ A, // M x K/2 packed FP4 (row-major) |
102 | 125 | const unsigned char* __restrict__ B, // N x K/2 packed FP4 (B transposed, row-major) |
103 | 126 | const unsigned char* __restrict__ SFA, // M x K/16 UE4M3 scales |
104 | 127 | const unsigned char* __restrict__ SFB, // N x K/16 UE4M3 scales |
105 | | - float* __restrict__ D, // M x N output (F32) |
| 128 | + OutT* __restrict__ D, // M x N output |
| 129 | + float* __restrict__ D_splitk, // M x N FP32 workspace (only used when split-K > 1) |
106 | 130 | int M, int N, int K |
107 | 131 | ) { |
108 | 132 | // Split-K: compute this block's K-range from blockIdx.z / gridDim.z |
@@ -321,39 +345,41 @@ __global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 4) void kGemmNVFP4_smem( |
321 | 345 | #undef COMPUTE_STEP |
322 | 346 |
|
323 | 347 | // ---- Write output ---- |
324 | | - // Use atomicAdd when split-K is active (gridDim.z > 1) to accumulate |
325 | | - // partial results from different K-slices |
| 348 | + // split-K (gridDim.z > 1): atomicAdd to FP32 workspace, host converts later |
| 349 | + // no split-K: convert and store directly to typed output |
326 | 350 | int octet = lane_id / 4; |
327 | 351 | int quad = lane_id % 4; |
328 | 352 | int out_row0 = tile_m + octet * 2; |
329 | 353 | int out_row1 = out_row0 + 1; |
330 | 354 | int out_col_base = quad * 2; |
331 | | - const bool use_atomic = (gridDim.z > 1); |
| 355 | + const bool use_splitk = (gridDim.z > 1); |
332 | 356 |
|
333 | 357 | #pragma unroll |
334 | 358 | for (int nt = 0; nt < N_TILES_PER_WARP; nt++) { |
335 | 359 | int this_tile_n = warp_n_base + nt * 8; |
336 | 360 | int c0 = this_tile_n + out_col_base; |
337 | 361 | int c1 = c0 + 1; |
338 | 362 |
|
339 | | - if (use_atomic) { |
| 363 | + if (use_splitk) { |
| 364 | + // Accumulate partial sums in FP32 workspace via atomicAdd |
340 | 365 | if (out_row0 < M && c0 < N) |
341 | | - atomicAdd(&D[out_row0 * N + c0], acc[nt][0]); |
| 366 | + atomicAdd(&D_splitk[out_row0 * N + c0], acc[nt][0]); |
342 | 367 | if (out_row0 < M && c1 < N) |
343 | | - atomicAdd(&D[out_row0 * N + c1], acc[nt][1]); |
| 368 | + atomicAdd(&D_splitk[out_row0 * N + c1], acc[nt][1]); |
344 | 369 | if (out_row1 < M && c0 < N) |
345 | | - atomicAdd(&D[out_row1 * N + c0], acc[nt][2]); |
| 370 | + atomicAdd(&D_splitk[out_row1 * N + c0], acc[nt][2]); |
346 | 371 | if (out_row1 < M && c1 < N) |
347 | | - atomicAdd(&D[out_row1 * N + c1], acc[nt][3]); |
| 372 | + atomicAdd(&D_splitk[out_row1 * N + c1], acc[nt][3]); |
348 | 373 | } else { |
| 374 | + // Direct store with type conversion (no split-K) |
349 | 375 | if (out_row0 < M && c0 < N) |
350 | | - D[out_row0 * N + c0] = acc[nt][0]; |
| 376 | + D[out_row0 * N + c0] = float_to_out<OutT>(acc[nt][0]); |
351 | 377 | if (out_row0 < M && c1 < N) |
352 | | - D[out_row0 * N + c1] = acc[nt][1]; |
| 378 | + D[out_row0 * N + c1] = float_to_out<OutT>(acc[nt][1]); |
353 | 379 | if (out_row1 < M && c0 < N) |
354 | | - D[out_row1 * N + c0] = acc[nt][2]; |
| 380 | + D[out_row1 * N + c0] = float_to_out<OutT>(acc[nt][2]); |
355 | 381 | if (out_row1 < M && c1 < N) |
356 | | - D[out_row1 * N + c1] = acc[nt][3]; |
| 382 | + D[out_row1 * N + c1] = float_to_out<OutT>(acc[nt][3]); |
357 | 383 | } |
358 | 384 | } |
359 | 385 | } |
@@ -515,68 +541,116 @@ __global__ void kGemmNVFP4_simple( |
515 | 541 | // RTX PRO 6000: 84 SMs |
516 | 542 | static const int NUM_SMS = 84; |
517 | 543 |
|
518 | | -extern "C" void cgemm_nvfp4( |
519 | | - const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, float* D, int M, |
520 | | - int N, int K, cudaStream_t stream |
521 | | -) { |
522 | | - int num_m_blocks = (M + BLOCK_M_DIM - 1) / BLOCK_M_DIM; |
523 | | - int num_n_blocks = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM; |
524 | | - int base_blocks = num_m_blocks * num_n_blocks; |
525 | | - int threads_per_block = WARPS_PER_BLOCK * 32; // 256 |
526 | | - |
527 | | - // Auto split-K: split along K to fill the GPU when M/N tiles are sparse |
528 | | - // Two-tier heuristic based on GPU occupancy: |
529 | | - // - Very sparse (<1 block/SM): aggressive split to 4 blocks/SM |
530 | | - // - Moderate (<2 blocks/SM): gentle split to 2 blocks/SM |
531 | | - // - Sufficient (>=2 blocks/SM): no split |
// ============================================================================
// Auto split-K heuristic (shared by all launchers)
// ============================================================================
// Decide how many K-slices to split the GEMM into, based on how well the
// M x N tile grid alone fills the GPU:
//   - fewer than 1 block/SM  -> aggressive split, target 4 blocks/SM (cap 16)
//   - fewer than 2 blocks/SM -> gentle split, target 2 blocks/SM (cap 4, to
//                               limit atomicAdd overhead on larger outputs)
//   - otherwise              -> no split
// The result never exceeds K/64, so each slice covers at least 64 K-elements.
static int compute_split_k(int base_blocks, int K) {
    const int max_k_splits = K / 64;
    if (max_k_splits <= 1)
        return 1;  // K too small to slice at all

    int split_k = 1;
    if (base_blocks < NUM_SMS) {
        // Very sparse tile grid: push occupancy toward 4 blocks/SM.
        split_k = (NUM_SMS * 4 + base_blocks - 1) / base_blocks;
        if (split_k > max_k_splits)
            split_k = max_k_splits;
        if (split_k > 16)
            split_k = 16;
    } else if (base_blocks < NUM_SMS * 2) {
        // Moderately sparse: aim for 2 blocks/SM.
        split_k = (NUM_SMS * 2 + base_blocks - 1) / base_blocks;
        if (split_k > max_k_splits)
            split_k = max_k_splits;
        if (split_k > 4)
            split_k = 4;
    }
    return split_k;
}
| 567 | + |
// ============================================================================
// Generic typed launcher: works for float, __nv_bfloat16, half
// ============================================================================
// Launches the tiled GEMM with the given split-K factor.
//  - split_k > 1: partial sums are atomicAdd-accumulated in the FP32
//    `workspace` (zeroed here first), then converted into `D`. For
//    OutT == float the callers pass workspace == D, so the conversion
//    pass is skipped.
//  - split_k <= 1: the GEMM kernel converts and stores OutT directly into
//    `D`; `workspace` is not used (nullptr is passed to the kernel).
template <typename OutT>
static void launch_gemm_nvfp4(
    const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, OutT* D,
    float* workspace, int M, int N, int K, int split_k, cudaStream_t stream
) {
    int num_m_blocks = (M + BLOCK_M_DIM - 1) / BLOCK_M_DIM;
    int num_n_blocks = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
    int threads_per_block = WARPS_PER_BLOCK * 32;

    if (split_k > 1) {
        // Split-K: zero the FP32 workspace first (atomicAdd accumulation
        // requires it), one K-slice per grid.z layer.
        cudaMemsetAsync(workspace, 0, (size_t)M * N * sizeof(float), stream);
        dim3 grid(num_n_blocks, num_m_blocks, split_k);
        kGemmNVFP4_smem<OutT><<<grid, threads_per_block, 0, stream>>>(A, B, SFA, SFB, D, workspace, M, N, K);

        // Convert FP32 workspace -> OutT (skipped for float, where the
        // callers alias workspace == D and no conversion is needed).
        if constexpr (!std::is_same_v<OutT, float>) {
            // FIX: compute the element count in size_t — the previous
            // `int n_elem = M * N` overflowed for outputs > 2^31 elements,
            // even though the memset above already used the size_t product.
            size_t n_elem = (size_t)M * (size_t)N;
            int conv_threads = 256;
            unsigned int conv_blocks = (unsigned int)((n_elem + conv_threads - 1) / conv_threads);
            // NOTE(review): kConvertOutput still takes `int n`, so the
            // non-FP32 split-K path assumes M*N <= INT_MAX — widen the
            // kernel's count type if larger outputs are ever needed.
            kConvertOutput<OutT><<<conv_blocks, conv_threads, 0, stream>>>(workspace, D, (int)n_elem);
        }
    } else {
        // No split-K: the GEMM kernel writes typed output directly.
        dim3 grid(num_n_blocks, num_m_blocks, 1);
        kGemmNVFP4_smem<OutT><<<grid, threads_per_block, 0, stream>>>(A, B, SFA, SFB, D, nullptr, M, N, K);
    }
}
556 | 599 |
|
557 | | - dim3 grid(num_n_blocks, num_m_blocks, split_k); |
558 | | - kGemmNVFP4_smem<<<grid, threads_per_block, 0, stream>>>(A, B, SFA, SFB, D, M, N, K); |
// ============================================================================
// C entry points — FP32 output (backward compatible)
// ============================================================================
// Auto-tuned FP32 GEMM: picks a split-K factor from the tile-grid occupancy
// heuristic, then launches. D doubles as the split-K accumulation workspace,
// so no extra buffer is needed for FP32 output.
extern "C" void cgemm_nvfp4(
    const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, float* D, int M,
    int N, int K, cudaStream_t stream
) {
    const int tiles_m = (M + BLOCK_M_DIM - 1) / BLOCK_M_DIM;
    const int tiles_n = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
    const int split_k = compute_split_k(tiles_m * tiles_n, K);

    launch_gemm_nvfp4<float>(A, B, SFA, SFB, D, /*workspace=*/D, M, N, K, split_k, stream);
}
560 | 615 |
|
561 | | -// Overload: caller specifies split-K explicitly (for benchmarking) |
// Explicit split-K variant (for benchmarking): the caller supplies split_k,
// which is clamped into the valid range [1, max(1, K/64)].
extern "C" void cgemm_nvfp4_splitk(
    const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, float* D, int M,
    int N, int K, int split_k, cudaStream_t stream
) {
    if (split_k < 1)
        split_k = 1;
    int max_k_splits = K / 64;
    if (max_k_splits < 1)
        max_k_splits = 1;  // FIX: for K < 64 the old clamp drove split_k to 0
    if (split_k > max_k_splits)
        split_k = max_k_splits;

    // FP32 output: D doubles as the split-K accumulation workspace.
    launch_gemm_nvfp4<float>(A, B, SFA, SFB, D, D, M, N, K, split_k, stream);
}
| 629 | + |
// ============================================================================
// C entry points — BF16 output
// ============================================================================
// Auto-tuned BF16 GEMM. Unlike the FP32 entry point, split-K partials are
// accumulated in the caller-provided FP32 `workspace` (M x N floats) and
// converted to BF16 afterwards; `workspace` is only touched when the
// heuristic chooses split_k > 1.
extern "C" void cgemm_nvfp4_bf16(
    const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB,
    __nv_bfloat16* D, float* workspace, int M, int N, int K, cudaStream_t stream
) {
    const int tiles_m = (M + BLOCK_M_DIM - 1) / BLOCK_M_DIM;
    const int tiles_n = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
    const int split_k = compute_split_k(tiles_m * tiles_n, K);

    launch_gemm_nvfp4<__nv_bfloat16>(A, B, SFA, SFB, D, workspace, M, N, K, split_k, stream);
}
569 | 644 |
|
// Explicit split-K BF16 variant (for benchmarking): the caller supplies
// split_k, clamped into [1, max(1, K/64)], plus an M x N FP32 `workspace`
// used for accumulation whenever split_k > 1.
extern "C" void cgemm_nvfp4_bf16_splitk(
    const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB,
    __nv_bfloat16* D, float* workspace, int M, int N, int K, int split_k, cudaStream_t stream
) {
    if (split_k < 1)
        split_k = 1;
    int max_k_splits = K / 64;
    if (max_k_splits < 1)
        max_k_splits = 1;  // FIX: for K < 64 the old clamp drove split_k to 0
    if (split_k > max_k_splits)
        split_k = max_k_splits;

    launch_gemm_nvfp4<__nv_bfloat16>(A, B, SFA, SFB, D, workspace, M, N, K, split_k, stream);
}
0 commit comments