4545 end
4646end
4747
# `fast_acc=true` is a hint that only applies to FP8 operands; the matmul is still
# expected to lower to `mmaf` (13.3+).
@test @filecheck begin
    @check_label "entry"
    code_tiled(
        Tuple{
            ct.TileArray{Float8_E4M3FN,2,spec2d},
            ct.TileArray{Float8_E4M3FN,2,spec2d},
            ct.TileArray{Float32,2,spec2d},
        };
        bytecode_version=v"13.3",
    ) do a, b, c
        # Two FP8 operand tiles; the accumulator is an all-zero Float32 tile.
        lhs = ct.load(a, (1, 1), (16, 16))
        rhs = ct.load(b, (1, 1), (16, 16))
        @check "mmaf"
        acc = zeros(Float32, (16, 16))
        ct.store(c, (1, 1), muladd(lhs, rhs, acc; fast_acc=true))
        return
    end
end
4960
50- # FP8 (e4m3/e5m2) conversions and matmul need Hopper (sm_90+).
51- @testset " execution" begin
52- if capability (device ()) >= v " 9"
61+ end
5362
54- # Round-trip Float32 → FP8 → Float32 on values exactly representable in
55- # the target FP8 type — result must match input bit-for-bit.
63+ # Execution kernels are plain top-level functions, each defined next to the
64+ # test that exercises it. Kernels parametric on accumulator dtype must stay at
65+ # top level — defining them inside a testset scope boxes them into closures.
66+
67+ # Round-trip Float32 → FP8 → Float32 on values exactly representable in the
68+ # target FP8 type — result must match input bit-for-bit.
5669function rt_e4m3 (a:: ct.TileArray{Float32,1} , b:: ct.TileArray{Float32,1} )
5770 pid = ct. bid (1 )
5871 tile = ct. load (a, pid, (16 ,))
@@ -65,19 +78,8 @@ function rt_e5m2(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1})
6578 ct. store (b, pid, convert (ct. Tile{Float32}, convert (ct. Tile{Float8_E5M2}, tile)))
6679 return
6780end
68-
69- representable = Float32[0.0 , 0.5 , 1.0 , 1.5 , 2.0 , 3.0 , 4.0 , 8.0 ,
70- 16.0 , 32.0 , 64.0 , 128.0 , 256.0 , - 1.0 , - 2.0 , - 0.5 ]
71- let a = CuArray (representable), b = CUDA. zeros (Float32, length (representable))
72- @cuda backend= cuTile blocks= 1 rt_e4m3 (a, b)
73- @test Array (b) == representable
74- @cuda backend= cuTile blocks= 1 rt_e5m2 (a, b)
75- @test Array (b) == representable
76- end
77-
7881# FMA in FP8: load Float32, convert to FP8, multiply-add in FP8, convert back.
79- # Uses inputs whose products and sums also stay representable, so the result
80- # is exact.
# Inputs are chosen so that their products and sums also stay representable, so the
# result is exact.
8183function fma_e4m3 (a:: ct.TileArray{Float32,1} , b:: ct.TileArray{Float32,1} ,
8284 c:: ct.TileArray{Float32,1} , d:: ct.TileArray{Float32,1} )
8385 pid = ct. bid (1 )
@@ -87,6 +89,33 @@ function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
8789 ct. store (d, pid, convert (ct. Tile{Float32}, muladd .(ta, tb, tc)))
8890 return
8991end
# Non-scaled FP8 matmul kernel covering both allowed accumulator dtypes: `Tacc` may
# be Float16 or Float32. Computes `muladd(a, b, c)` on the (16, 16) tiles loaded at
# index (1, 1) of A, B and C, converts the result to Float32 and stores it into D.
function mma_dl_fp8(
    A::ct.TileArray{Float8_E4M3FN,2},
    B::ct.TileArray{Float8_E4M3FN,2},
    C::ct.TileArray{Tacc,2},
    D::ct.TileArray{Float32,2},
) where {Tacc<:Union{Float16,Float32}}
    lhs = ct.load(A, (1, 1), (16, 16))
    rhs = ct.load(B, (1, 1), (16, 16))
    acc = ct.load(C, (1, 1), (16, 16))
    # The accumulator tile has element type Tacc; D is always written as Float32.
    product = muladd(lhs, rhs, acc)
    ct.store(D, (1, 1), convert(ct.Tile{Float32}, product))
    return
end
# FP8 matmul kernel with a Float32 accumulator that passes `fast_acc=true` to
# `muladd`. Operates on the (16, 16) tiles loaded at index (1, 1) of A, B and C and
# stores the Float32 result into D.
function mma_dl_fast(
    A::ct.TileArray{Float8_E4M3FN,2},
    B::ct.TileArray{Float8_E4M3FN,2},
    C::ct.TileArray{Float32,2},
    D::ct.TileArray{Float32,2},
)
    lhs = ct.load(A, (1, 1), (16, 16))
    rhs = ct.load(B, (1, 1), (16, 16))
    acc = ct.load(C, (1, 1), (16, 16))
    product = muladd(lhs, rhs, acc; fast_acc=true)
    ct.store(D, (1, 1), convert(ct.Tile{Float32}, product))
    return
end
105+
106+ # FP8 (e4m3/e5m2) conversions and matmul need Hopper (sm_90+).
107+ @testset " execution" begin
108+ if capability (device ()) >= v " 9"
109+
# Sixteen Float32 values used as round-trip input; each launch below must hand them
# back unchanged after the Float32 → FP8 → Float32 conversion.
representable = Float32[
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0,
    16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5,
]
let src = CuArray(representable),
    dst = CUDA.zeros(Float32, length(representable))

    # E4M3 round trip.
    @cuda backend=cuTile blocks=1 rt_e4m3(src, dst)
    @test Array(dst) == representable

    # E5M2 round trip, reusing the same output buffer.
    @cuda backend=cuTile blocks=1 rt_e5m2(src, dst)
    @test Array(dst) == representable
end
118+
90119let av = Float32[1.0 , 2.0 , 0.5 , 4.0 , 1.5 , 2.0 , - 1.0 , - 0.5 , 3.0 , 0.5 , 1.0 , 2.0 , - 2.0 , 1.0 , 0.5 , 4.0 ],
91120 bv = Float32[2.0 , 1.0 , 4.0 , 0.5 , 2.0 , 3.0 , 2.0 , 4.0 , 1.0 , 2.0 , 1.0 , 0.5 , 2.0 , 1.0 , 2.0 , 1.0 ],
92121 cv = Float32[0.0 , 1.0 , 0.0 , 0.0 , 1.0 , 1.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 1.0 , 0.0 ]
@@ -96,29 +125,30 @@ let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -
96125 @test Array (d) == av .* bv .+ cv
97126end
98127
99- # Non-scaled FP8 matmul with both allowed accumulator dtypes (f16 and f32).
100- function mma_dl_f32 (A:: ct.TileArray{Float8_E4M3FN,2} , B:: ct.TileArray{Float8_E4M3FN,2} ,
101- C:: ct.TileArray{Float32,2} , D:: ct.TileArray{Float32,2} )
102- a = ct. load (A, (1 , 1 ), (16 , 16 )); b = ct. load (B, (1 , 1 ), (16 , 16 )); c = ct. load (C, (1 , 1 ), (16 , 16 ))
103- ct. store (D, (1 , 1 ), convert (ct. Tile{Float32}, muladd (a, b, c)))
104- return
105- end
106- function mma_dl_f16 (A:: ct.TileArray{Float8_E4M3FN,2} , B:: ct.TileArray{Float8_E4M3FN,2} ,
107- C:: ct.TileArray{Float16,2} , D:: ct.TileArray{Float32,2} )
108- a = ct. load (A, (1 , 1 ), (16 , 16 )); b = ct. load (B, (1 , 1 ), (16 , 16 )); c = ct. load (C, (1 , 1 ), (16 , 16 ))
109- ct. store (D, (1 , 1 ), convert (ct. Tile{Float32}, muladd (a, b, c)))
110- return
111- end
112- @testset " mma → $Tacc acc" for (Tacc, kern) in ((Float32, mma_dl_f32), (Float16, mma_dl_f16))
@testset "mma → $Tacc acc" for Tacc in (Float32, Float16)
    M = 16
    # A and B entries are drawn from {0, 0.5, 1}; C entries from {0, 1, 2}.
    a_host = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2)
    b_host = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2)
    c_host = Tacc.(Float32.(rand(0:2, M, M)))

    # Host reference, computed entirely in Float32.
    expected = Float32.(a_host) * Float32.(b_host) .+ Float32.(c_host)

    d_dev = CUDA.zeros(Float32, M, M)
    @cuda backend=cuTile blocks=1 mma_dl_fp8(
        CuArray(a_host), CuArray(b_host), CuArray(c_host), d_dev
    )
    @test Array(d_dev) == expected
end
122138
# `fast_acc` only changes behaviour on Hopper (sm_90) and is ignored on other
# devices. Off Hopper the flag must therefore pass through without perturbing the
# output, so the exact result is asserted; on Hopper no numeric claim is made.
@testset "mma fast_acc (exact off Hopper)" begin
    M = 16
    a_host = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2)
    b_host = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2)
    c_host = Float32.(rand(0:2, M, M))

    # Host reference in Float32.
    expected = Float32.(a_host) * Float32.(b_host) .+ c_host

    d_dev = CUDA.zeros(Float32, M, M)
    @cuda backend=cuTile blocks=1 mma_dl_fast(
        CuArray(a_host), CuArray(b_host), CuArray(c_host), d_dev
    )
    # The comparison is evaluated first so the result is always read back from the
    # device; the capability range (9 ≤ cc < 10) only excuses a mismatch on Hopper.
    @test (Array(d_dev) == expected) || (v"9" <= capability(device()) < v"10")
end
152+
123153end
124154end
0 commit comments