Skip to content

Commit ee8e49d

Browse files
committed
Add FP8 muladd with fast_acc argument; add i8/u8 mma tests
1 parent cae78e8 commit ee8e49d

9 files changed

Lines changed: 412 additions & 206 deletions

File tree

README.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,23 @@ Tile IR operations.
196196
| Operation | Description |
197197
|-----------|-------------|
198198
| `a * b` | Matrix multiplication: `a @ b` |
199-
| `muladd(a, b, acc)` | Matrix multiply-accumulate: `a * b + acc` |
199+
| `muladd(a, b, acc; fast_acc=false)` | Matrix multiply-accumulate: `a * b + acc` |
200+
| `ct.muladd_scaled(a, a_scale, b, b_scale, acc)` | Block-scaled multiply-accumulate |
201+
202+
Each operation follows the shape rules of `Base.:*` / `Base.muladd`, and additionally allows trailing batch dimensions.
203+
204+
`fast_acc=true` enables fast accumulation for FP8 inputs. It only has an effect on Hopper (sm_90) and is silently
205+
ignored on other architectures; it requires Tile IR v13.3+.
206+
207+
`ct.muladd_scaled` multiplies each operand by a low-precision block scale before the matmul:
208+
each scale element covers a contiguous block of `B = K ÷ K_s` elements along the K dimension (`K_s` being the number of scale elements along K).
209+
Requires Blackwell. The supported operand/scale/accumulator dtypes and block sizes are:
210+
211+
| Input (`a`/`b`) | Scale | Acc/Output | B |
212+
|-----------------|-------|------------|--------|
213+
| `Float8_E4M3FN`, `Float8_E5M2` | `Float8_E8M0FNU` | `Float32` | 32 |
214+
| `Float4_E2M1FN` | `Float8_E8M0FNU` | `Float32` | 16, 32 |
215+
| `Float4_E2M1FN` | `Float8_E4M3FN` | `Float32` | 16 |
200216

201217
### Higher-Order Functions
202218
| Operation | Description |

ext/DLFP8TypesExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ end
1818
ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32)
1919
ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32)
2020

21+
# `fast_acc` (lower-precision MMA accumulation) is an FP8-only throughput hint.
22+
ct.mma_supports_fast_acc(::Type{Float8_E4M3FN}) = true
23+
ct.mma_supports_fast_acc(::Type{Float8_E5M2}) = true
24+
2125
# Float ↔ FP8 scalar constructor overlays (for map/convert dispatch)
2226
const FP8Types = (Float8_E4M3FN, Float8_E5M2)
2327
const StandardFloats = (Float16, ct.BFloat16, Float32, ct.TFloat32, Float64)

ext/MicrofloatsExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero
3535
ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32)
3636
ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32)
3737

38+
# `fast_acc` (lower-precision MMA accumulation) is an FP8-only throughput hint.
39+
ct.mma_supports_fast_acc(::Type{Float8_E4M3FN}) = true
40+
ct.mma_supports_fast_acc(::Type{Float8_E5M2}) = true
41+
3842
# Float ↔ microfloat scalar constructor overlays (for map/convert dispatch).
3943
# Mirrors DLFP8TypesExt: route to `Intrinsics.ftof` so kernel-side conversions
4044
# lower to the FToFOp Tile IR intrinsic instead of Microfloats' Float32-fallback

src/compiler/intrinsics/core.jl

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,17 @@ mma_allowed_acc_dtypes(::Type{TFloat32}) = (Float32,)
415415
mma_allowed_acc_dtypes(::Type{Float64}) = (Float64,)
416416
mma_allowed_acc_dtypes(@nospecialize(::Type)) = nothing
417417

418+
"""
419+
mma_supports_fast_acc(input_T) -> Bool
420+
421+
Whether `mma`'s `fast_acc` hint (lower-precision accumulation for throughput)
422+
is valid for operand element type `input_T`. Only the FP8 dtypes qualify;
423+
extensions (Microfloats/DLFP8Types) register them by overloading this, like
424+
[`mma_allowed_acc_dtypes`](@ref). Mirrors cuTile Python's
425+
`use_fast_acc is only supported for fp8 input dtypes` check.
426+
"""
427+
mma_supports_fast_acc(@nospecialize(::Type)) = false
428+
418429
# First (preferred) accumulator dtype for a given input dtype. Used by
419430
# matmul to pick `acc = zeros(first_allowed_acc(T), …)` so that the input
420431
# dtype constraint and tileiras's acc-dtype constraint stay consistent
@@ -425,7 +436,7 @@ first_allowed_acc_dtype(::Type{T}) where {T} =
425436
end
426437

427438
"""
428-
Intrinsics.mma(a::Tile, b::Tile, acc::Tile) -> typeof(acc)
439+
Intrinsics.mma(a::Tile, b::Tile, acc::Tile, fast_acc::Bool=false) -> typeof(acc)
429440
430441
Matrix-multiply-accumulate computing `a*b + acc`. Dispatches at codegen
431442
based on element types:
@@ -436,15 +447,23 @@ based on element types:
436447
- `i8` `a`/`b` with `i32` `acc` lower to `cuda_tile.mmai`. Per-input
437448
signedness is derived from the Julia type (`Int8` → signed,
438449
`UInt8` → unsigned); `acc` and the result are always signed `i32`.
450+
451+
`fast_acc` enables fast accumulation (trading accumulator precision for
452+
throughput). It is only valid for FP8 inputs (see
453+
[`mma_supports_fast_acc`](@ref)) and requires Tile IR v13.3+.
439454
"""
440-
@intrinsic mma(a::Tile, b::Tile, acc::Tile)
441-
tfunc(𝕃, ::typeof(Intrinsics.mma), @nospecialize(a), @nospecialize(b), @nospecialize(acc)) = CC.widenconst(acc)
455+
@intrinsic mma(a::Tile, b::Tile, acc::Tile, fast_acc::Bool=false)
456+
tfunc(𝕃, ::typeof(Intrinsics.mma), @nospecialize(a), @nospecialize(b), @nospecialize(acc),
457+
@nospecialize(rest...)) = CC.widenconst(acc)
442458
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args)
443459
cb = ctx.cb
444460

445461
lhs = emit_value!(ctx, args[1])
446462
rhs = emit_value!(ctx, args[2])
447463
acc = emit_value!(ctx, args[3])
464+
fast_acc = length(args) >= 4 ?
465+
(@something get_constant(ctx, args[4]) throw(IRError("mma: fast_acc must be a compile-time constant"))) :
466+
false
448467

449468
(lhs === nothing || rhs === nothing || acc === nothing) && throw(IRError("Cannot resolve operands for mma()"))
450469

@@ -467,9 +486,15 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args)
467486
acc_elem in allowed_acc ||
468487
throw(IRError("mma: acc dtype $acc_elem is not allowed for input dtype " *
469488
"$lhs_elem; tileiras requires acc ∈ $allowed_acc"))
470-
encode_MmaFOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v)
489+
# fast_acc is an FP8-only throughput hint; reject it elsewhere with a
490+
# clear error rather than emitting IR tileiras would reject.
491+
fast_acc && !Base.invokelatest(mma_supports_fast_acc, lhs_elem) &&
492+
throw(IRError("mma: fast_acc is only supported for fp8 input dtypes " *
493+
"(f8e4m3fn, f8e5m2), got $lhs_elem"))
494+
encode_MmaFOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v; fast_acc)
471495
elseif lhs_elem <: Union{Int8, UInt8} && rhs_elem <: Union{Int8, UInt8} &&
472496
acc_elem === Int32
497+
fast_acc && throw(IRError("mma: fast_acc is not supported for integer (mmai) inputs"))
473498
s_lhs = lhs_elem <: Signed ? Signedness.Signed : Signedness.Unsigned
474499
s_rhs = rhs_elem <: Signed ? Signedness.Signed : Signedness.Unsigned
475500
encode_MmaIOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v;

src/language/operations.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,7 @@ end
11451145
=============================================================================#
11461146

11471147
"""
1148-
muladd(a::Tile, b::Tile, acc::Tile) -> Tile
1148+
muladd(a::Tile, b::Tile, acc::Tile; fast_acc::Bool=false) -> Tile
11491149
11501150
Matrix multiply-accumulate `a * b + acc` over tiles, lowering to
11511151
`cuda_tile.mmaf` (float) or `cuda_tile.mmai` (`i8 × i8 → i32`).
@@ -1155,45 +1155,49 @@ promoted (vec-mat / mat-vec) and any trailing dimensions (≥3-D) are treated as
11551155
broadcast batch dims, lifting `Base.muladd`'s shape rules to tiles. `acc`
11561156
carries the result dtype, which must be one tileiras allows for the input dtype
11571157
(f16/f32 for f16 and f8; f32 for bf16/tf32; f64 for f64; i32 for i8).
1158+
1159+
`fast_acc` enables fast accumulation (lower accumulator precision for
1160+
throughput); it is valid only for FP8 inputs and requires Tile IR v13.3+.
11581161
"""
1159-
@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC}
1162+
@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC};
1163+
fast_acc::Bool=false) where {T1, T2, T3, SA, SB, SC}
11601164
# SA, SB, SC type parameters avoid ambiguity with the scalar `muladd`
11611165
# methods during codegen.
1162-
_muladd(a, b, acc, Val(ndims(a)), Val(ndims(b)))
1166+
_muladd(a, b, acc, Val(ndims(a)), Val(ndims(b)), fast_acc)
11631167
end
11641168

11651169
# 2D × 2D: MmaFOp with swapped operands for row-major Tile IR
11661170
# Julia (M,K)*(K,N) → TileIR (K,M)*(N,K) → mmaf(b,a,acc) → TileIR (N,M) → Julia (M,N)
1167-
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2})
1168-
Intrinsics.mma(b, a, acc)
1171+
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}, fast_acc::Bool)
1172+
Intrinsics.mma(b, a, acc, fast_acc)
11691173
end
11701174

11711175
# Vec-mat (1D × 2D): reshape (M,) → (M, 1), MmaFOp, acc is already (M, N)
1172-
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2})
1176+
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2}, fast_acc::Bool)
11731177
a2d = reshape(a, (size(a, 1), 1))
1174-
_muladd(a2d, b, acc, Val(2), Val(2))
1178+
_muladd(a2d, b, acc, Val(2), Val(2), fast_acc)
11751179
end
11761180

11771181
# Mat-vec (2D × 1D): reshape b (K,) → (K, 1), acc (M,) → (M, 1), MmaFOp, squeeze back
1178-
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1})
1182+
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1}, fast_acc::Bool)
11791183
M, K = size(a, 1), size(b, 1)
11801184
b2d = reshape(b, (K, 1))
11811185
acc2d = reshape(acc, (M, 1))
1182-
result = _muladd(a, b2d, acc2d, Val(2), Val(2))
1186+
result = _muladd(a, b2d, acc2d, Val(2), Val(2), fast_acc)
11831187
reshape(result, (M,))
11841188
end
11851189

11861190
# Vec-vec (1D × 1D): not supported
1187-
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1})
1191+
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}, ::Bool)
11881192
return :(throw(ArgumentError("Vector-vector multiply-accumulate is not supported.")))
11891193
end
11901194

11911195
# Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported, unsqueeze manually
1192-
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB}
1196+
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}, ::Bool) where {NB}
11931197
NB >= 3 || return :(throw(ArgumentError("unreachable")))
11941198
return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first.")))
11951199
end
1196-
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA}
1200+
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}, ::Bool) where {NA}
11971201
NA >= 3 || return :(throw(ArgumentError("unreachable")))
11981202
return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first.")))
11991203
end
@@ -1206,7 +1210,7 @@ end
12061210
# 3. MmaFOp with swapped operands: mmaf(b, a, acc)
12071211
# 4. Unflatten batch dims via reshape
12081212
@generated function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC},
1209-
::Val{NA}, ::Val{NB}) where {T1, T2, T3, SA, SB, SC, NA, NB}
1213+
::Val{NA}, ::Val{NB}, fast_acc::Bool) where {T1, T2, T3, SA, SB, SC, NA, NB}
12101214
sa = Tuple(SA.parameters)
12111215
sb = Tuple(SB.parameters)
12121216

@@ -1233,7 +1237,7 @@ end
12331237
b_3d = reshape(b_bc, $((K, N, B_flat)))
12341238
acc_3d = reshape(acc_bc, $((M, N, B_flat)))
12351239
# MmaFOp with swapped operands for row-major convention
1236-
result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d)
1240+
result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d, fast_acc)
12371241
# Unflatten batch dims
12381242
reshape(result_3d, $((M, N, batch_shape...)))
12391243
end

test/device/integration.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,40 @@ using CUDA
3333

3434
@test c_cpu ≈ c_ref
3535
end
36+
37+
@testset "i8/u8 matmul (mmai)" begin
38+
# i8/u8 × i8/u8 → i32 lowers to `cuda_tile.mmai`. The accumulator must be
39+
# Int32 (the `*` operator would pick the input dtype as acc and fail), so we
40+
# use `muladd` with an explicit i32 acc. Per-operand signedness is derived
41+
# from the Julia element type, so we feed values that differ under signed vs
42+
# unsigned interpretation (negatives, and magnitudes > 127) to confirm each
43+
# operand is interpreted correctly. With K = 16, |accumulated sum| ≤ 255² · 16 ≈ 1.0e6,
44+
# well within Int32, so the result is exact.
45+
function mmai(a::ct.TileArray{T1,2}, b::ct.TileArray{T2,2}, c::ct.TileArray{Int32,2}) where {T1,T2}
46+
ta = ct.load(a, (1, 1), (16, 16))
47+
tb = ct.load(b, (1, 1), (16, 16))
48+
tc = muladd(ta, tb, zeros(Int32, (16, 16)))
49+
ct.store(c, (1, 1), tc)
50+
return
51+
end
52+
53+
M = K = N = 16
54+
@testset "signed × signed" begin
55+
a = rand(Int8(-128):Int8(127), M, K); b = rand(Int8(-128):Int8(127), K, N)
56+
c = CUDA.zeros(Int32, M, N)
57+
@cuda backend=cuTile blocks=1 mmai(CuArray(a), CuArray(b), c)
58+
@test Array(c) == Int32.(a) * Int32.(b)
59+
end
60+
@testset "unsigned × unsigned" begin
61+
a = rand(UInt8(0):UInt8(255), M, K); b = rand(UInt8(0):UInt8(255), K, N)
62+
c = CUDA.zeros(Int32, M, N)
63+
@cuda backend=cuTile blocks=1 mmai(CuArray(a), CuArray(b), c)
64+
@test Array(c) == Int32.(a) * Int32.(b)
65+
end
66+
@testset "unsigned × signed" begin
67+
a = rand(UInt8(0):UInt8(255), M, K); b = rand(Int8(-128):Int8(127), K, N)
68+
c = CUDA.zeros(Int32, M, N)
69+
@cuda backend=cuTile blocks=1 mmai(CuArray(a), CuArray(b), c)
70+
@test Array(c) == Int32.(a) * Int32.(b)
71+
end
72+
end

test/extensions/DLFP8Types.jl

Lines changed: 62 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,27 @@ end
4545
end
4646
end
4747

48+
# `fast_acc=true` is an FP8-only hint; it still lowers to `mmaf` (13.3+).
49+
@test @filecheck begin
50+
@check_label "entry"
51+
code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d},
52+
ct.TileArray{Float32,2,spec2d}}; bytecode_version=v"13.3") do a, b, c
53+
ta = ct.load(a, (1, 1), (16, 16))
54+
tb = ct.load(b, (1, 1), (16, 16))
55+
@check "mmaf"
56+
ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)); fast_acc=true))
57+
return
58+
end
4859
end
4960

50-
# FP8 (e4m3/e5m2) conversions and matmul need Hopper (sm_90+).
51-
@testset "execution" begin
52-
if capability(device()) >= v"9"
61+
end
5362

54-
# Round-trip Float32 → FP8 → Float32 on values exactly representable in
55-
# the target FP8 type — result must match input bit-for-bit.
63+
# Execution kernels are plain top-level functions, each defined next to the
64+
# test that exercises it. Kernels parametric on accumulator dtype must stay at
65+
# top level — defining them inside a testset scope boxes them into closures.
66+
67+
# Round-trip Float32 → FP8 → Float32 on values exactly representable in the
68+
# target FP8 type — result must match input bit-for-bit.
5669
function rt_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1})
5770
pid = ct.bid(1)
5871
tile = ct.load(a, pid, (16,))
@@ -65,19 +78,8 @@ function rt_e5m2(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1})
6578
ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E5M2}, tile)))
6679
return
6780
end
68-
69-
representable = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0,
70-
16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5]
71-
let a = CuArray(representable), b = CUDA.zeros(Float32, length(representable))
72-
@cuda backend=cuTile blocks=1 rt_e4m3(a, b)
73-
@test Array(b) == representable
74-
@cuda backend=cuTile blocks=1 rt_e5m2(a, b)
75-
@test Array(b) == representable
76-
end
77-
7881
# FMA in FP8: load Float32, convert to FP8, multiply-add in FP8, convert back.
79-
# Uses inputs whose products and sums also stay representable, so the result
80-
# is exact.
82+
# Inputs whose products and sums also stay representable, so the result is exact.
8183
function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
8284
c::ct.TileArray{Float32,1}, d::ct.TileArray{Float32,1})
8385
pid = ct.bid(1)
@@ -87,6 +89,33 @@ function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
8789
ct.store(d, pid, convert(ct.Tile{Float32}, muladd.(ta, tb, tc)))
8890
return
8991
end
92+
# Non-scaled FP8 matmul with both allowed accumulator dtypes (f16 and f32).
93+
function mma_dl_fp8(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2},
94+
C::ct.TileArray{Tacc,2}, D::ct.TileArray{Float32,2}) where {Tacc<:Union{Float16,Float32}}
95+
a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16))
96+
ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c)))
97+
return
98+
end
99+
function mma_dl_fast(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2},
100+
C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2})
101+
a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16))
102+
ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c; fast_acc=true)))
103+
return
104+
end
105+
106+
# FP8 (e4m3/e5m2) conversions and matmul need Hopper (sm_90+).
107+
@testset "execution" begin
108+
if capability(device()) >= v"9"
109+
110+
representable = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0,
111+
16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5]
112+
let a = CuArray(representable), b = CUDA.zeros(Float32, length(representable))
113+
@cuda backend=cuTile blocks=1 rt_e4m3(a, b)
114+
@test Array(b) == representable
115+
@cuda backend=cuTile blocks=1 rt_e5m2(a, b)
116+
@test Array(b) == representable
117+
end
118+
90119
let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -2.0, 1.0, 0.5, 4.0],
91120
bv = Float32[2.0, 1.0, 4.0, 0.5, 2.0, 3.0, 2.0, 4.0, 1.0, 2.0, 1.0, 0.5, 2.0, 1.0, 2.0, 1.0],
92121
cv = Float32[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0]
@@ -96,29 +125,30 @@ let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -
96125
@test Array(d) == av .* bv .+ cv
97126
end
98127

99-
# Non-scaled FP8 matmul with both allowed accumulator dtypes (f16 and f32).
100-
function mma_dl_f32(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2},
101-
C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2})
102-
a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16))
103-
ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c)))
104-
return
105-
end
106-
function mma_dl_f16(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2},
107-
C::ct.TileArray{Float16,2}, D::ct.TileArray{Float32,2})
108-
a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16))
109-
ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c)))
110-
return
111-
end
112-
@testset "mma → $Tacc acc" for (Tacc, kern) in ((Float32, mma_dl_f32), (Float16, mma_dl_f16))
128+
@testset "mma → $Tacc acc" for Tacc in (Float32, Float16)
113129
M = 16
114130
ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2)
115131
bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2)
116132
ch = Tacc.(Float32.(rand(0:2, M, M)))
117133
ref = Float32.(ah) * Float32.(bh) .+ Float32.(ch)
118134
D = CUDA.zeros(Float32, M, M)
119-
@cuda backend=cuTile blocks=1 kern(CuArray(ah), CuArray(bh), CuArray(ch), D)
135+
@cuda backend=cuTile blocks=1 mma_dl_fp8(CuArray(ah), CuArray(bh), CuArray(ch), D)
120136
@test Array(D) == ref
121137
end
122138

139+
# fast_acc only has an effect on Hopper (sm_90); ignored elsewhere. So off
140+
# Hopper we assert the exact result (the flag must ride through without
141+
# perturbing the output); on Hopper we make no numeric claim.
142+
@testset "mma fast_acc (exact off Hopper)" begin
143+
M = 16
144+
ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2)
145+
bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2)
146+
ch = Float32.(rand(0:2, M, M))
147+
ref = Float32.(ah) * Float32.(bh) .+ ch
148+
D = CUDA.zeros(Float32, M, M)
149+
@cuda backend=cuTile blocks=1 mma_dl_fast(CuArray(ah), CuArray(bh), CuArray(ch), D)
150+
@test (Array(D) == ref) || (v"9" <= capability(device()) < v"10")
151+
end
152+
123153
end
124154
end

0 commit comments

Comments
 (0)