|
| 1 | +// Minimal test: verify mma.sync.aligned.kind::mxf4nvf4 works on SM_120 |
// Compile: nvcc -arch=sm_120a -o test_mma_nvfp4 test_mma_nvfp4.cu
//          (kind::mxf4nvf4 block-scaled MMA needs the arch-specific sm_120a target)
| 3 | +// Run: ./test_mma_nvfp4 |
| 4 | + |
| 5 | +#include <cstdio> |
| 6 | +#include <cstdint> |
| 7 | +#include <cuda_runtime.h> |
| 8 | + |
| 9 | +// MMA instruction: m16n8k64, E2M1 x E2M1 -> F32, with UE4M3 block scales |
| 10 | +// One warp (32 threads) processes: |
| 11 | +// A: 16x64 E2M1 tile (4 regs per thread, 8 nibbles per reg) |
| 12 | +// B: 8x64 E2M1 tile (2 regs per thread, 8 nibbles per reg) |
| 13 | +// SFA: 4 UE4M3 scale factors for A (packed in 1 uint32) |
| 14 | +// SFB: 4 UE4M3 scale factors for B (packed in 1 uint32) |
| 15 | +// D/C: 16x8 F32 tile (4 floats per thread) |
| 16 | + |
// Issues one warp-wide block-scaled MMA via inline PTX:
//   mma.sync.aligned.kind::mxf4nvf4 ... m16n8k64, row-major A, col-major B,
//   E2M1 (FP4) A/B fragments, F32 C/D accumulators, UE4M3 scale factors
//   applied with scale_vec::4X granularity. Computes D = (A*SFA)x(B*SFB) + C.
// Must be called by all 32 threads of a warp (mma.sync is warp-collective).
//
// Per-thread register mapping (standard m16n8k64 fragment layout):
//   d0..d3   - out: 4 F32 fragments of the 16x8 D tile
//   a0..a3   - in:  A fragment, 8 packed E2M1 nibbles per 32-bit register
//   b0,b1    - in:  B fragment, 8 packed E2M1 nibbles per 32-bit register
//   c0..c3   - in:  C accumulator fragments
//   sfa,sfb  - in:  4 packed UE4M3 scale factors each (one 32-bit word)
__device__ void mma_nvfp4_16x8x64(
    float &d0, float &d1, float &d2, float &d3,
    uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3,
    uint32_t b0, uint32_t b1,
    float c0, float c1, float c2, float c3,
    uint32_t sfa, uint32_t sfb
) {
    // Byte-id / thread-id selectors for the scale-factor metadata operands;
    // zeros select the first byte / base thread mapping for the 4X scale vector.
    uint16_t bidA = 0, tidA = 0, bidB = 0, tidB = 0;

    asm volatile(
        "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
        "{%0, %1, %2, %3},"     // D fragments (f32)
        "{%4, %5, %6, %7},"     // A fragments (packed e2m1)
        "{%8, %9},"             // B fragments (packed e2m1)
        "{%10, %11, %12, %13}," // C fragments (f32)
        "{%14},"                // SFA data (packed ue4m3)
        "{%15, %16},"           // SFA byte-id, thread-id selectors
        "{%17},"                // SFB data (packed ue4m3)
        "{%18, %19};\n"         // SFB byte-id, thread-id selectors
        : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
        : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
        "r"(b0), "r"(b1),
        "f"(c0), "f"(c1), "f"(c2), "f"(c3),
        "r"(sfa), "h"(bidA), "h"(tidA),
        "r"(sfb), "h"(bidB), "h"(tidB)
    );
}
| 44 | + |
// All-ones smoke test for the NVFP4 MMA: every E2M1 element of A and B
// encodes 1.0 and every UE4M3 scale factor encodes 1.0, so each F32 output
// of the 16x8 D tile must equal the K extent (64.0). Launch with exactly
// one warp (32 threads); thread t stores its 4 accumulator fragments to
// output[t*4 .. t*4+3].
__global__ void test_mma_kernel(float* output) {
    // E2M1(1.0): sign=0, exp=01, mant=0 -> nibble 0x2; eight packed per word.
    const uint32_t onesE2M1 = 0x22222222u;
    // UE4M3(1.0): unsigned, exp=7 (bias 7 -> 2^0), mant=0 -> byte 0x38;
    // four packed per word.
    const uint32_t onesUE4M3 = 0x38383838u;

    // Accumulator fragments, initialized to zero.
    float acc[4] = {0.0f, 0.0f, 0.0f, 0.0f};

    mma_nvfp4_16x8x64(
        acc[0], acc[1], acc[2], acc[3],
        onesE2M1, onesE2M1, onesE2M1, onesE2M1, // A fragment: all 1.0
        onesE2M1, onesE2M1,                     // B fragment: all 1.0
        0.0f, 0.0f, 0.0f, 0.0f,                 // C accumulator = 0
        onesUE4M3, onesUE4M3                    // SFA / SFB: all 1.0
    );

    // Each thread writes its 4 output fragments.
    float* dst = output + threadIdx.x * 4;
    for (int v = 0; v < 4; ++v) {
        dst[v] = acc[v];
    }
}
| 77 | + |
// Host driver: launches one warp of test_mma_kernel, copies back the 16x8
// F32 tile (32 threads x 4 values) and verifies every element equals 64.0.
// Returns 0 on PASS, 1 on FAIL or on any CUDA error.
//
// Fixes over the previous revision: cudaMalloc/cudaMemset/cudaMemcpy return
// codes are now checked, cudaDeviceSynchronize's own return value is checked
// (it reports async execution errors directly), and d_output is freed on
// every error path instead of leaking.
int main() {
    // Report-and-fail helper for CUDA runtime calls; message format matches
    // the original "<what> error: <string>" output.
    auto ok = [](cudaError_t err, const char* what) -> bool {
        if (err != cudaSuccess) {
            printf("%s error: %s\n", what, cudaGetErrorString(err));
            return false;
        }
        return true;
    };

    const int kNumValues = 32 * 4;  // 32 threads * 4 values each
    float* d_output = nullptr;
    float h_output[32 * 4];

    if (!ok(cudaMalloc(&d_output, kNumValues * sizeof(float)), "cudaMalloc")) {
        return 1;
    }
    if (!ok(cudaMemset(d_output, 0, kNumValues * sizeof(float)), "cudaMemset")) {
        cudaFree(d_output);
        return 1;
    }

    // Launch 1 warp.
    test_mma_kernel<<<1, 32>>>(d_output);

    // Launch-configuration errors surface via cudaGetLastError();
    // asynchronous execution errors surface via cudaDeviceSynchronize().
    if (!ok(cudaGetLastError(), "Kernel launch")) {
        cudaFree(d_output);
        return 1;
    }
    if (!ok(cudaDeviceSynchronize(), "Kernel execution")) {
        cudaFree(d_output);
        return 1;
    }

    if (!ok(cudaMemcpy(h_output, d_output, kNumValues * sizeof(float),
                       cudaMemcpyDeviceToHost), "cudaMemcpy")) {
        cudaFree(d_output);
        return 1;
    }

    // Expected: all A=1.0, all B=1.0, all scales=1.0
    // D[i][j] = sum_k (A[i][k] * SFA[i][k/16]) * (B[j][k] * SFB[j][k/16])
    //         = sum_k=0..63 (1.0 * 1.0) * (1.0 * 1.0) = 64.0
    printf("MMA NVFP4 m16n8k64 test (all ones, scales=1.0):\n");
    printf("Expected: 64.0 for all outputs\n\n");

    int pass = 1;
    for (int t = 0; t < 32; t++) {
        for (int v = 0; v < 4; v++) {
            if (h_output[t * 4 + v] != 64.0f) pass = 0;
        }
    }

    // Print the first and last few threads for eyeballing.
    for (int t = 0; t < 4; t++) {
        printf("  Thread %2d: d0=%.1f d1=%.1f d2=%.1f d3=%.1f\n",
               t, h_output[t*4], h_output[t*4+1], h_output[t*4+2], h_output[t*4+3]);
    }
    printf("  ...\n");
    for (int t = 28; t < 32; t++) {
        printf("  Thread %2d: d0=%.1f d1=%.1f d2=%.1f d3=%.1f\n",
               t, h_output[t*4], h_output[t*4+1], h_output[t*4+2], h_output[t*4+3]);
    }

    printf("\n%s\n", pass ? "PASS: All outputs are 64.0" : "FAIL: Some outputs incorrect");

    cudaFree(d_output);
    return pass ? 0 : 1;
}
0 commit comments