Skip to content

Commit 1afdde6

Browse files
TimDettmers and claude committed
test: Add minimal MMA mxf4nvf4 block_scale test for SM_120a
Standalone test verifies the PTX instruction works on RTX PRO 6000: mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X .m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 Key finding: requires compute_120a (not compute_120) arch target. All 128 outputs = 64.0 when A=B=1.0, scales=1.0. PASS. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d0da58d commit 1afdde6

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

csrc/test_mma_nvfp4.cu

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// Minimal test: verify mma.sync.aligned.kind::mxf4nvf4 works on SM_120
2+
// Compile: nvcc -arch=sm_120a -o test_mma_nvfp4 test_mma_nvfp4.cu
// (requires the family-specific 120a target; plain sm_120 does not expose this MMA kind)
3+
// Run: ./test_mma_nvfp4
4+
5+
#include <cstdio>
6+
#include <cstdint>
7+
#include <cuda_runtime.h>
8+
9+
// MMA instruction: m16n8k64, E2M1 x E2M1 -> F32, with UE4M3 block scales
10+
// One warp (32 threads) processes:
11+
// A: 16x64 E2M1 tile (4 regs per thread, 8 nibbles per reg)
12+
// B: 8x64 E2M1 tile (2 regs per thread, 8 nibbles per reg)
13+
// SFA: 4 UE4M3 scale factors for A (packed in 1 uint32)
14+
// SFB: 4 UE4M3 scale factors for B (packed in 1 uint32)
15+
// D/C: 16x8 F32 tile (4 floats per thread)
16+
17+
// Issue one m16n8k64 block-scaled MMA: D = (A * SFA) x (B * SFB) + C.
//
// Warp-collective: all 32 threads of the warp must call this together, each
// supplying its register fragment of the operand tiles:
//   a0..a3  - A fragment, 4 regs of 8 packed E2M1 nibbles each
//   b0..b1  - B fragment, 2 regs of 8 packed E2M1 nibbles each
//   c0..c3  - C accumulator fragment (f32)
//   d0..d3  - D output fragment (f32), written by the instruction
//   sfa/sfb - 4 UE4M3 scale-factor bytes packed into one uint32 each
// Requires the sm_120a target (per the commit finding, the mxf4nvf4 kind is
// rejected under plain compute_120).
__device__ void mma_nvfp4_16x8x64(
    float &d0, float &d1, float &d2, float &d3,
    uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3,
    uint32_t b0, uint32_t b1,
    float c0, float c1, float c2, float c3,
    uint32_t sfa, uint32_t sfb
) {
    // Scale-factor selector operands ("h" = u16 regs). Presumably the
    // byte-id / thread-id selectors that pick which byte of the packed
    // scale word applies (zero here) -- verify against the PTX ISA entry
    // for mma ... block_scale before reusing with non-uniform scales.
    uint16_t bidA = 0, tidA = 0, bidB = 0, tidB = 0;

    asm volatile(
        "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
        "{%0, %1, %2, %3},"      // D fragment (f32 x4, outputs)
        "{%4, %5, %6, %7},"      // A fragment (packed e2m1, 4 regs)
        "{%8, %9},"              // B fragment (packed e2m1, 2 regs)
        "{%10, %11, %12, %13},"  // C fragment (f32 x4)
        "{%14},"                 // SFA packed scale bytes
        "{%15, %16},"            // SFA selectors (bidA, tidA)
        "{%17},"                 // SFB packed scale bytes
        "{%18, %19};\n"          // SFB selectors (bidB, tidB)
        : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
        : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
        "r"(b0), "r"(b1),
        "f"(c0), "f"(c1), "f"(c2), "f"(c3),
        "r"(sfa), "h"(bidA), "h"(tidA),
        "r"(sfb), "h"(bidB), "h"(tidB)
    );
}
44+
45+
// One-warp smoke test for the block-scaled NVFP4 MMA.
// All A and B elements are 1.0 and every scale factor is 1.0, so each of
// the 16x8 output elements must equal the reduction length K = 64.
// Launch with exactly one warp (32 threads); thread t writes its 4-float
// D fragment to output[4*t .. 4*t+3].
__global__ void test_mma_kernel(float* output) {
    // E2M1 code for 1.0: sign=0, exp=1, mant=0 -> nibble 0b0010 = 0x2.
    // Eight such nibbles packed per 32-bit register.
    const uint32_t ones_e2m1 = 0x22222222u;

    // UE4M3 (unsigned, 4 exp bits, 3 mantissa bits, bias 7) code for 1.0:
    // e=7, m=0 -> 2^(7-7) * 1.0 = 1.0, byte 0b0111000 = 0x38.
    // Four such bytes packed per 32-bit scale word.
    const uint32_t unit_scales = 0x38383838u;

    // D fragment; initialized to zero (the C operand below is also zero).
    float frag[4] = {0.0f, 0.0f, 0.0f, 0.0f};

    mma_nvfp4_16x8x64(
        frag[0], frag[1], frag[2], frag[3],
        ones_e2m1, ones_e2m1, ones_e2m1, ones_e2m1,  // A tile: all 1.0
        ones_e2m1, ones_e2m1,                        // B tile: all 1.0
        0.0f, 0.0f, 0.0f, 0.0f,                      // C accumulator: zero
        unit_scales, unit_scales);

    // Store this lane's four results contiguously.
    const int lane = threadIdx.x;
    for (int v = 0; v < 4; ++v) {
        output[lane * 4 + v] = frag[v];
    }
}
77+
78+
// Abort-on-error checker for CUDA runtime calls. Expands to `return 1`
// on failure, so it is only valid inside main(); the process exit
// reclaims the device allocation on the error paths.
#define CUDA_CHECK(call)                                                \
    do {                                                                \
        cudaError_t err_ = (call);                                      \
        if (err_ != cudaSuccess) {                                      \
            printf("CUDA error at %s:%d: %s\n", __FILE__, __LINE__,     \
                   cudaGetErrorString(err_));                           \
            return 1;                                                   \
        }                                                               \
    } while (0)

// Host driver: launches one warp of test_mma_kernel, copies back the
// 32*4 = 128 output floats, and verifies every value equals 64.0.
// Returns 0 on PASS, 1 on FAIL or on any CUDA runtime error.
int main() {
    const int kOutputs = 128;  // 32 threads * 4 values each
    float* d_output = nullptr;
    float h_output[kOutputs];

    // Check allocation/initialization: an unchecked cudaMalloc failure
    // would otherwise launch the kernel on a garbage pointer.
    CUDA_CHECK(cudaMalloc(&d_output, kOutputs * sizeof(float)));
    CUDA_CHECK(cudaMemset(d_output, 0, kOutputs * sizeof(float)));

    // Launch 1 warp.
    test_mma_kernel<<<1, 32>>>(d_output);
    CUDA_CHECK(cudaGetLastError());       // launch-configuration errors
    CUDA_CHECK(cudaDeviceSynchronize());  // asynchronous execution errors

    CUDA_CHECK(cudaMemcpy(h_output, d_output, kOutputs * sizeof(float),
                          cudaMemcpyDeviceToHost));

    // Expected: all A=1.0, all B=1.0, all scales=1.0
    // D[i][j] = sum_k (A[i][k] * SFA[i][k/16]) * (B[j][k] * SFB[j][k/16])
    //         = sum_k=0..63 (1.0 * 1.0) * (1.0 * 1.0) = 64.0
    printf("MMA NVFP4 m16n8k64 test (all ones, scales=1.0):\n");
    printf("Expected: 64.0 for all outputs\n\n");

    int pass = 1;
    for (int t = 0; t < 32; t++) {
        for (int v = 0; v < 4; v++) {
            if (h_output[t * 4 + v] != 64.0f) pass = 0;
        }
    }

    // Print the first and last few threads for a quick visual check.
    for (int t = 0; t < 4; t++) {
        printf("  Thread %2d: d0=%.1f d1=%.1f d2=%.1f d3=%.1f\n",
               t, h_output[t*4], h_output[t*4+1], h_output[t*4+2], h_output[t*4+3]);
    }
    printf("  ...\n");
    for (int t = 28; t < 32; t++) {
        printf("  Thread %2d: d0=%.1f d1=%.1f d2=%.1f d3=%.1f\n",
               t, h_output[t*4], h_output[t*4+1], h_output[t*4+2], h_output[t*4+3]);
    }

    printf("\n%s\n", pass ? "PASS: All outputs are 64.0" : "FAIL: Some outputs incorrect");

    cudaFree(d_output);
    return pass ? 0 : 1;
}

0 commit comments

Comments
 (0)