# NVFP4 GEMM Benchmark Results

## Hardware
- GPU: NVIDIA RTX PRO 6000 Blackwell Workstation Edition (SM_120, 96GB GDDR7)
- CUDA: 13.1 (nvcc), PyTorch 2.9.1+cu130
- Driver: 580.95.05

## Kernel Implementation
- **NVFP4**: Correctness-first kernel (`kGemmNVFP4_simple`), one warp per m16n8 output tile,
  global memory loads, no shared memory, no software pipelining.
  Uses the `mma.sync.aligned.block_scale` PTX instruction.
- **FP16**: cuBLAS via `torch.matmul` (highly optimized baseline)

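As a rough illustration of the one-warp-per-m16n8-tile mapping: the grid must cover (M/16)×(N/8) output tiles, so large GEMMs launch a very large number of warps, each of which streams the entire K dimension from global memory. A minimal sketch (the helper name is illustrative, not from the kernel source):

```python
# Illustrative launch math for a one-warp-per-m16n8-tile kernel
# (hypothetical helper, not the actual kGemmNVFP4_simple source).
def warp_tile_count(M: int, N: int, tile_m: int = 16, tile_n: int = 8) -> int:
    """Number of warps needed to cover an MxN output, one warp per tile."""
    tiles_m = (M + tile_m - 1) // tile_m   # ceil-divide over rows
    tiles_n = (N + tile_n - 1) // tile_n   # ceil-divide over columns
    return tiles_m * tiles_n

print(warp_tile_count(4096, 4096))  # 131072 warps, each re-reading all of K
```

Each of those warps re-reads its slices of A and B from global memory with no reuse, which is the root cause of the performance numbers below.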
## Results

| Shape | NVFP4 (ms) | FP16 (ms) | Speedup | NVFP4 TFLOPS | FP16 TFLOPS |
|-------|-----------|----------|---------|--------------|-------------|
| 128x128x128 | 0.012 | 0.005 | 0.43x | 0.4 | 0.8 |
| 256x256x256 | 0.012 | 0.005 | 0.43x | 2.9 | 6.8 |
| 512x512x512 | 0.023 | 0.005 | 0.22x | 11.9 | 53.1 |
| 1024x1024x1024 | 0.124 | 0.010 | 0.08x | 17.4 | 208.4 |
| 2048x2048x2048 | 0.965 | 0.053 | 0.06x | 17.8 | 322.7 |
| 4096x4096x4096 | 7.571 | 0.347 | 0.05x | 18.2 | 396.5 |
| 1x4096x4096 | 0.092 | 0.010 | 0.11x | 5.8 | 3.3 |
| 8x4096x4096 | 0.090 | 0.010 | 0.11x | 6.0 | 25.9 |
| 32x4096x4096 | 0.111 | 0.012 | 0.11x | 9.7 | 86.9 |
| 128x4096x4096 | 0.267 | 0.019 | 0.07x | 16.1 | 231.6 |
| 32x4096x11008 | 0.260 | 0.023 | 0.09x | 11.1 | 127.0 |
| 128x4096x11008 | 0.621 | 0.041 | 0.07x | 18.6 | 280.6 |

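The TFLOPS columns follow from the standard 2·M·N·K FLOP count for a GEMM divided by wall time. A small sketch of that conversion (the function name is illustrative):

```python
def gemm_tflops(m: int, n: int, k: int, ms: float) -> float:
    """Effective TFLOPS for an MxNxK GEMM (2*M*N*K FLOPs) timed in milliseconds."""
    return 2 * m * n * k / (ms * 1e-3) / 1e12

# NVFP4 4096x4096x4096 row: 7.571 ms
print(round(gemm_tflops(4096, 4096, 4096, 7.571), 1))  # 18.2
```

Applying the same formula to the FP16 timings reproduces the FP16 column to within the rounding of the reported milliseconds.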
## Memory Savings

| Weight Shape | FP16 | NVFP4 | Compression |
|-------------|------|-------|-------------|
| 4096x4096 | 32.0 MB | 9.0 MB | 3.6x |
| 4096x11008 | 86.0 MB | 24.1 MB | 3.6x |

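The 3.6x figure can be reproduced from the storage layout, assuming 4-bit values plus one 1-byte FP8 scale per 16-element block (and ignoring the small per-tensor scale):

```python
def nvfp4_bytes(rows: int, cols: int, block: int = 16) -> int:
    """NVFP4 storage: two 4-bit values per byte plus one FP8 scale per block.
    Sketch under the assumptions above; ignores the per-tensor scale."""
    elems = rows * cols
    data = elems // 2        # packed 4-bit values
    scales = elems // block  # one 1-byte scale per 16-element block
    return data + scales

def fp16_bytes(rows: int, cols: int) -> int:
    return rows * cols * 2

MB = 1024 * 1024
print(fp16_bytes(4096, 4096) / MB)   # 32.0
print(nvfp4_bytes(4096, 4096) / MB)  # 9.0
print(round(fp16_bytes(4096, 4096) / nvfp4_bytes(4096, 4096), 1))  # 3.6
```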
## Analysis

The NVFP4 GEMM kernel peaks at ~18 TFLOPS, while cuBLAS FP16 reaches ~400 TFLOPS on
the RTX PRO 6000. The current kernel is **~20x slower** than cuBLAS at large matrix sizes.

### Why the NVFP4 kernel is slow

This is a **correctness-first implementation** with no performance optimization:
1. **Per-element global memory loads**: Each thread loads individual nibbles from global memory
   with manual bit manipulation (shifts and masks); no coalesced loads.
2. **No shared memory**: Data is loaded directly from global memory into registers.
   A tiled kernel would stage data in shared memory for reuse.
3. **No software pipelining**: The K-dimension loop has no overlap between compute and memory.
4. **One warp per m16n8 tile**: Poor utilization of the SM's resources. A proper kernel
   would use multiple warps per threadblock with a larger tile (128x128x128).
5. **Per-element packing**: The nibble extraction loop is serial (8 iterations per register).

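The serial extraction in point 5 can be sketched on the CPU: each 32-bit register holds eight packed 4-bit values, and the current kernel walks them one shift-and-mask at a time (the helper below is illustrative, not kernel source):

```python
def unpack_nibbles(word: int) -> list[int]:
    """Serially extract the eight 4-bit values packed in one 32-bit word,
    mirroring the per-register extraction loop described above."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]

print(unpack_nibbles(0x87654321))  # [1, 2, 3, 4, 5, 6, 7, 8]
```

On the GPU, a vectorized kernel would instead load whole uint32/uint64 words and feed them to the MMA operands directly, avoiding the eight-iteration loop entirely.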
### Performance optimization path

To close the gap with cuBLAS FP16, the kernel would need:
1. Shared memory tiling (128x128x128 threadblock tile)
2. Coalesced global → shared memory loads (cp.async or vectorized loads)
3. 2-3 stage software pipelining for the K loop
4. Multiple warps per threadblock (e.g., 4 warps computing a 128x128 output)
5. Vectorized nibble packing (load uint32/uint64 instead of byte-by-byte)

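The reuse pattern behind item 1 can be sketched on the CPU: each output tile accumulates partial products over K in tile-sized chunks, which is exactly the data a GPU kernel would stage in shared memory so every element is loaded once per tile rather than once per multiply. A minimal sketch (loop structure only, not a performance model):

```python
def tiled_matmul(A, B, tile: int = 4):
    """CPU sketch of threadblock tiling: each (bi, bj) output tile walks K in
    tile-sized chunks -- the reuse a shared-memory GEMM exploits on the GPU."""
    M, K = len(A), len(A[0])
    N = len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for bi in range(0, M, tile):                  # output tile rows
        for bj in range(0, N, tile):              # output tile cols
            for bk in range(0, K, tile):          # K-loop over staged chunks
                for i in range(bi, min(bi + tile, M)):
                    for j in range(bj, min(bj + tile, N)):
                        for k in range(bk, min(bk + tile, K)):
                            C[i][j] += A[i][k] * B[k][j]
    return C

print(tiled_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19.0, 22.0], [43.0, 50.0]]
```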
The theoretical speedup of NVFP4 over FP16 on Blackwell is ~2x (double the FLOPs per
cycle). Merely matching cuBLAS FP16 throughput therefore requires the NVFP4 kernel to
reach ~50% of cuBLAS's FP16 efficiency; capturing the full ~2x requires matching it.

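That efficiency requirement is simple arithmetic: with NVFP4 peak at twice the FP16 peak, the realized speedup over cuBLAS is 2 × (NVFP4 efficiency / cuBLAS FP16 efficiency). A quick check (function name and the 0.70 efficiency figure are illustrative):

```python
def nvfp4_speedup(nvfp4_eff: float, fp16_eff: float, flops_ratio: float = 2.0) -> float:
    """Speedup over cuBLAS FP16, given each kernel's fraction of its own peak,
    assuming NVFP4 peak = flops_ratio x FP16 peak."""
    return flops_ratio * nvfp4_eff / fp16_eff

print(nvfp4_speedup(0.35, 0.70))  # parity at half of cuBLAS's efficiency
print(nvfp4_speedup(0.70, 0.70))  # the full 2x needs matching efficiency
```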
### Current value

Despite the performance gap, the implementation provides:
- **3.6x memory savings**: Enables larger models in GPU memory
- **Correct GEMM output**: Verified against `torch.matmul` on dequantized inputs
  with 0.000000 relative error (both paths consume the same quantized data and
  differ only in FP32 rounding)
- **Full Python API**: quantize/dequantize/GEMM/LinearNVFP4 all working end-to-end
- **NVFP4 output epilogue**: GEMM → quantize chain for layer chaining