Skip to content

Commit f1aaa51

Browse files
committed
Add muladd_scaled
1 parent 48f29b0 commit f1aaa51

9 files changed

Lines changed: 787 additions & 235 deletions

File tree

ext/DLFP8TypesExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ function ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2})
1212
return ct.F8E5M2(table)
1313
end
1414

15+
# Non-scaled `mma`/`matmul` (`cuda_tile.mmaf`) accepts f8e4m3fn and f8e5m2
16+
# operands with an f16 or f32 accumulator (f16 first/preferred), mirroring
17+
# cuda-tile's mmaf type table and cutile-python's `_mma_supported_dtypes`.
18+
ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32)
19+
ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32)
20+
1521
# Float ↔ FP8 scalar constructor overlays (for map/convert dispatch)
1622
const FP8Types = (Float8_E4M3FN, Float8_E5M2)
1723
const StandardFloats = (Float16, ct.BFloat16, Float32, ct.TFloat32, Float64)

ext/MicrofloatsExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ ct.bitwidth(::Type{T}) where {T<:Microfloats.Microfloat} = Microfloats.bitwidth(
2626
# nearest-even on f32→E8M0FNU (only `zero` and `positive_inf` are valid).
2727
ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero
2828

29+
# Non-scaled `mma`/`matmul` (`cuda_tile.mmaf`) accepts f8e4m3fn and f8e5m2
30+
# operands with an f16 or f32 accumulator — f16 first/preferred, mirroring
31+
# cuda-tile's mmaf type table and cutile-python's `_mma_supported_dtypes`. The
32+
# other Microfloats formats (E8M0FNU, Float4_E2M1FN) are only valid as
33+
# scaled-mma operands/scales, so they stay absent and fall through to `mma`'s
34+
# unsupported-dtype error.
35+
ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32)
36+
ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32)
37+
2938
# Float ↔ microfloat scalar constructor overlays (for map/convert dispatch).
3039
# Mirrors DLFP8TypesExt: route to `Intrinsics.ftof` so kernel-side conversions
3140
# lower to the FToFOp Tile IR intrinsic instead of Microfloats' Float32-fallback

src/bytecode/encodings.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ module Opcode
9898
const Atan2Op = 110 # since 13.2
9999
const PackOp = 111 # since 13.3
100100
const UnpackOp = 112 # since 13.3
101+
# 113 (AllocaOp) not implemented
102+
const MmaFScaledOp = 114 # since 13.3
101103
end
102104

103105
# Enums for operation attributes
@@ -1310,6 +1312,30 @@ function encode_MmaIOp!(cb::CodeBuilder, result_type::TypeId,
13101312
return new_op!(cb)
13111313
end
13121314

1315+
"""
1316+
encode_MmaFScaledOp!(cb, result_type, lhs, rhs, acc, lhs_scale, rhs_scale) -> Value
1317+
1318+
Block-scaled float matrix multiply-accumulate (`acc + (lhs ⊙ lhs_scale) @
1319+
(rhs ⊙ rhs_scale)`) for low-precision (f8/f4) inputs. Each scale element scales
1320+
a contiguous block of K-dimension elements in its operand. Requires Tile IR
1321+
v13.3+.
1322+
Opcode: 114
1323+
"""
1324+
function encode_MmaFScaledOp!(cb::CodeBuilder, result_type::TypeId,
1325+
lhs::Value, rhs::Value, acc::Value,
1326+
lhs_scale::Value, rhs_scale::Value)
1327+
cb.version >= v"13.3" ||
1328+
throw(IRError("MmaFScaledOp requires Tile IR v13.3+, got v$(cb.version)"))
1329+
encode_varint!(cb.buf, Opcode.MmaFScaledOp)
1330+
encode_typeid!(cb.buf, result_type)
1331+
encode_operand!(cb.buf, lhs)
1332+
encode_operand!(cb.buf, rhs)
1333+
encode_operand!(cb.buf, acc)
1334+
encode_operand!(cb.buf, lhs_scale)
1335+
encode_operand!(cb.buf, rhs_scale)
1336+
return new_op!(cb)
1337+
end
1338+
13131339
#=============================================================================
13141340
Integer arithmetic operations
13151341
=============================================================================#

src/compiler/intrinsics/core.jl

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,10 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args)
458458
# the table in cuTile Python's mma implementation.
459459
lhs_elem === rhs_elem ||
460460
throw(IRError("mma: float lhs and rhs must share dtype, got lhs=$lhs_elem, rhs=$rhs_elem"))
461-
allowed_acc = mma_allowed_acc_dtypes(lhs_elem)
461+
# `invokelatest` so extension-defined acc-dtype tables (e.g. FP8 via the
462+
# Microfloats/DLFP8Types exts) are visible from codegen's world age,
463+
# mirroring `lookup_bitwidth`.
464+
allowed_acc = Base.invokelatest(mma_allowed_acc_dtypes, lhs_elem)
462465
allowed_acc === nothing &&
463466
throw(IRError("mma: unsupported float input dtype $lhs_elem"))
464467
acc_elem in allowed_acc ||
@@ -480,6 +483,57 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args)
480483
CGVal(result, acc.type_id, acc.jltype, acc.shape)
481484
end
482485

486+
"""
487+
Intrinsics.mma_scaled(lhs, lhs_scale, rhs, rhs_scale, acc) -> typeof(acc)
488+
489+
Block-scaled matrix-multiply-accumulate computing `(lhs ⊙ lhs_scale) * (rhs ⊙
490+
rhs_scale) + acc`, where each scale element multiplies a contiguous block of
491+
`lhs`/`rhs` elements along the K dimension. Lowers to `cuda_tile.mmaf_scaled`
492+
(Tile IR v13.3+).
493+
494+
`lhs`/`rhs` are low-precision floats (`f8e4m3fn`, `f8e5m2`, or `f4e2m1fn`),
495+
`lhs_scale`/`rhs_scale` are `f8e8m0fnu` or `f8e4m3fn`, and `acc`/result are
496+
`f32`. The block size `K ÷ K_s` and the exact (operand, scale) dtype pairing are
497+
validated by tileiras (see its `mmaf_scaled` verifier for the supported table).
498+
"""
499+
@intrinsic mma_scaled(lhs, lhs_scale, rhs, rhs_scale, acc)
500+
tfunc(𝕃, ::typeof(Intrinsics.mma_scaled), @nospecialize(lhs), @nospecialize(lhs_scale),
501+
@nospecialize(rhs), @nospecialize(rhs_scale), @nospecialize(acc)) = CC.widenconst(acc)
502+
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma_scaled), args)
503+
cb = ctx.cb
504+
505+
lhs = emit_value!(ctx, args[1])
506+
lhs_scale = emit_value!(ctx, args[2])
507+
rhs = emit_value!(ctx, args[3])
508+
rhs_scale = emit_value!(ctx, args[4])
509+
acc = emit_value!(ctx, args[5])
510+
511+
(lhs === nothing || lhs_scale === nothing || rhs === nothing ||
512+
rhs_scale === nothing || acc === nothing) &&
513+
throw(IRError("Cannot resolve operands for mma_scaled()"))
514+
515+
lhs_elem = eltype(CC.widenconst(lhs.jltype))
516+
rhs_elem = eltype(CC.widenconst(rhs.jltype))
517+
lhs_scale_elem = eltype(CC.widenconst(lhs_scale.jltype))
518+
rhs_scale_elem = eltype(CC.widenconst(rhs_scale.jltype))
519+
acc_elem = eltype(CC.widenconst(acc.jltype))
520+
521+
# Structural invariants (cheap, with clear errors); the operand/scale dtype
522+
# pairing and block-size constraints are checked by the tileiras verifier,
523+
# whose messages are already precise.
524+
lhs_elem === rhs_elem ||
525+
throw(IRError("mma_scaled: lhs and rhs must share dtype, got lhs=$lhs_elem, rhs=$rhs_elem"))
526+
lhs_scale_elem === rhs_scale_elem ||
527+
throw(IRError("mma_scaled: lhs_scale and rhs_scale must share dtype, " *
528+
"got $lhs_scale_elem and $rhs_scale_elem"))
529+
acc_elem === Float32 ||
530+
throw(IRError("mma_scaled: acc must be Float32, got $acc_elem"))
531+
532+
result = encode_MmaFScaledOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v,
533+
lhs_scale.v, rhs_scale.v)
534+
CGVal(result, acc.type_id, acc.jltype, acc.shape)
535+
end
536+
483537
# TODO: cuda_tile.module
484538

485539
"""

src/language/operations.jl

Lines changed: 125 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,10 +1144,21 @@ end
11441144
Matrix multiplication
11451145
=============================================================================#
11461146

1147-
# Matrix multiply-accumulate: muladd(a, b, acc) = a * b + acc
1148-
# Handles 1D promotion, type promotion, and batched dims (≥3D).
1149-
# Note: SA, SB, SC type parameters required to avoid ambiguity with scalar methods during codegen
1147+
"""
1148+
muladd(a::Tile, b::Tile, acc::Tile) -> Tile
1149+
1150+
Matrix multiply-accumulate `a * b + acc` over tiles, lowering to
1151+
`cuda_tile.mmaf` (float) or `cuda_tile.mmai` (`i8 × i8 → i32`).
1152+
1153+
`a`/`b` are 2-D matrices `(M, K)` × `(K, N)` → `(M, N)`; a 1-D operand is
1154+
promoted (vec-mat / mat-vec) and any trailing dimensions (≥3-D) are treated as
1155+
broadcast batch dims, lifting `Base.muladd`'s shape rules to tiles. `acc`
1156+
carries the result dtype, which must be one tileiras allows for the input dtype
1157+
(f16/f32 for f16 and f8; f32 for bf16/tf32; f64 for f64; i32 for i8).
1158+
"""
11501159
@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC}
1160+
# SA, SB, SC type parameters avoid ambiguity with the scalar `muladd`
1161+
# methods during codegen.
11511162
_muladd(a, b, acc, Val(ndims(a)), Val(ndims(b)))
11521163
end
11531164

@@ -1228,6 +1239,117 @@ end
12281239
end
12291240
end
12301241

1242+
#=============================================================================
1243+
Block-scaled matrix multiply-accumulate
1244+
=============================================================================#
1245+
1246+
"""
1247+
muladd_scaled(a, a_scale, b, b_scale, acc) -> Tile
1248+
1249+
Block-scaled matrix multiply-accumulate `(a ⊙ a_scale) * (b ⊙ b_scale) + acc`,
1250+
lowering to `cuda_tile.mmaf_scaled` (Tile IR v13.3+, Blackwell). Each scale
1251+
element multiplies a contiguous block of `B = K ÷ K_s` elements along the K
1252+
dimension of its operand, so `a_scale`/`b_scale` match `a`/`b` in every
1253+
dimension except K, where they have `K_s ≤ K` entries.
1254+
1255+
`a`/`b` are low-precision floats (`f8e4m3fn`, `f8e5m2`, or `f4e2m1fn`),
1256+
`a_scale`/`b_scale` are `f8e8m0fnu` or `f8e4m3fn`, and `acc` is `f32`. Shapes
1257+
follow [`muladd`](@ref): 2-D `(M, K)` × `(K, N)`, mat-vec, and trailing batch
1258+
dims; vec-mat is unsupported (it would collapse K, leaving nothing to scale).
1259+
"""
1260+
@inline function muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile, b::Tile{Tb, SB}, b_scale::Tile,
1261+
acc::Tile) where {Ta, Tb, SA, SB}
1262+
_muladd_scaled(a, a_scale, b, b_scale, acc, Val(ndims(a)), Val(ndims(b)))
1263+
end
1264+
1265+
# 2D × 2D: swap operands (and their scales) for row-major Tile IR, exactly as
1266+
# `_muladd` swaps for `mma`.
1267+
@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile,
1268+
::Val{2}, ::Val{2})
1269+
Intrinsics.mma_scaled(b, b_scale, a, a_scale, acc)
1270+
end
1271+
1272+
# Mat-vec (2D × 1D): the K-vector `b` (and its scale) gain a trailing N=1 dim;
1273+
# `acc` becomes (M, 1); then squeeze back to (M,). K — the scaled dimension —
1274+
# is preserved, so block scaling is well-defined.
1275+
@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile,
1276+
::Val{2}, ::Val{1})
1277+
M, K, Ks = size(a, 1), size(b, 1), size(b_scale, 1)
1278+
b2d = reshape(b, (K, 1))
1279+
b_scale2d = reshape(b_scale, (Ks, 1))
1280+
acc2d = reshape(acc, (M, 1))
1281+
result = _muladd_scaled(a, a_scale, b2d, b_scale2d, acc2d, Val(2), Val(2))
1282+
reshape(result, (M,))
1283+
end
1284+
1285+
# Vec-mat (1D × 2D): promoting `a` to (M, 1) collapses K to 1, leaving no K
1286+
# dimension to block-scale. Unsupported — reshape to 2D and supply a matching
1287+
# K_s scale instead.
1288+
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{2})
1289+
return :(throw(ArgumentError("Scaled vec-mat is not supported (the K dimension collapses to 1, which cannot be block-scaled).")))
1290+
end
1291+
1292+
# Vec-vec (1D × 1D): not supported.
1293+
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1})
1294+
return :(throw(ArgumentError("Scaled vector-vector multiply-accumulate is not supported.")))
1295+
end
1296+
1297+
# Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported.
1298+
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB}
1299+
return :(throw(ArgumentError("Batched scaled vec-mat is not supported.")))
1300+
end
1301+
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA}
1302+
return :(throw(ArgumentError("Batched scaled mat-vec is not supported.")))
1303+
end
1304+
1305+
# Batched (≥3D × ≥3D): trailing batch dims with broadcast, mirroring `_muladd`.
1306+
# Scales carry the same batch dims as their operands; a_scale's batch must match
1307+
# a's batch (likewise b_scale/b), then both broadcast to the common batch shape.
1308+
@generated function _muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile{Tas, SAS},
1309+
b::Tile{Tb, SB}, b_scale::Tile{Tbs, SBS},
1310+
acc::Tile{Tc, SC},
1311+
::Val{NA}, ::Val{NB}) where {Ta, Tas, Tb, Tbs, Tc,
1312+
SA, SAS, SB, SBS, SC, NA, NB}
1313+
sa = Tuple(SA.parameters); sas = Tuple(SAS.parameters)
1314+
sb = Tuple(SB.parameters); sbs = Tuple(SBS.parameters)
1315+
1316+
# Matrix dims are first two; batch dims are trailing.
1317+
M = sa[1]; K = sa[2]; N = sb[2]
1318+
Ksa = sas[2] # a_scale K_s (a_scale is (M, K_s, batch...))
1319+
Ksb = sbs[1] # b_scale K_s (b_scale is (K_s, N, batch...))
1320+
a_batch = sa[3:end]; b_batch = sb[3:end]
1321+
as_batch = sas[3:end]; bs_batch = sbs[3:end]
1322+
1323+
# Broadcast batch dims (pad shorter with trailing 1s, then broadcast).
1324+
n_batch = max(length(a_batch), length(b_batch))
1325+
a_batch_padded = (a_batch..., ntuple(Returns(1), n_batch - length(a_batch))...)
1326+
b_batch_padded = (b_batch..., ntuple(Returns(1), n_batch - length(b_batch))...)
1327+
as_batch_padded = (as_batch..., ntuple(Returns(1), n_batch - length(as_batch))...)
1328+
bs_batch_padded = (bs_batch..., ntuple(Returns(1), n_batch - length(bs_batch))...)
1329+
batch_shape = map(max, a_batch_padded, b_batch_padded)
1330+
B_flat = prod(batch_shape)
1331+
1332+
quote
1333+
# Reshape + broadcast to align batch dims (still trailing).
1334+
a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...)))
1335+
b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...)))
1336+
as_bc = broadcast_to(reshape(a_scale, $((M, Ksa, as_batch_padded...))), $((M, Ksa, batch_shape...)))
1337+
bs_bc = broadcast_to(reshape(b_scale, $((Ksb, N, bs_batch_padded...))), $((Ksb, N, batch_shape...)))
1338+
acc_bc = broadcast_to(acc, $((M, N, batch_shape...)))
1339+
# Flatten batch dims to one — no permute needed since row-major Tile IR
1340+
# already has batch as the leading (slowest) dimension.
1341+
a_3d = reshape(a_bc, $((M, K, B_flat)))
1342+
b_3d = reshape(b_bc, $((K, N, B_flat)))
1343+
as_3d = reshape(as_bc, $((M, Ksa, B_flat)))
1344+
bs_3d = reshape(bs_bc, $((Ksb, N, B_flat)))
1345+
acc_3d = reshape(acc_bc, $((M, N, B_flat)))
1346+
# mmaf_scaled with swapped operands for row-major convention.
1347+
result_3d = Intrinsics.mma_scaled(b_3d, bs_3d, a_3d, as_3d, acc_3d)
1348+
# Unflatten batch dims.
1349+
reshape(result_3d, $((M, N, batch_shape...)))
1350+
end
1351+
end
1352+
12311353
# Matrix multiplication: A * B = muladd(A, B, zeros)
12321354
# Note: SA, SB type parameters required to avoid ambiguity with scalar*tile methods during codegen
12331355
#

test/extensions/DLFP8Types.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using CUDA
22
using DLFP8Types: Float8_E4M3FN, Float8_E5M2
33

44
spec1d = ct.ArraySpec{1}(16, true)
5+
spec2d = ct.ArraySpec{2}(16, true)
56

67
@testset "codegen" begin
78

@@ -31,11 +32,24 @@ end
3132
end
3233
end
3334

35+
# Non-scaled f8 matmul lowers to `mmaf` (f8 operands, f32 accumulator).
36+
@test @filecheck begin
37+
@check_label "entry"
38+
code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d},
39+
ct.TileArray{Float32,2,spec2d}}) do a, b, c
40+
ta = ct.load(a, (1, 1), (16, 16))
41+
tb = ct.load(b, (1, 1), (16, 16))
42+
@check "mmaf"
43+
ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16))))
44+
return
45+
end
46+
end
47+
3448
end
3549

36-
# FP8 types are Blackwell-only
50+
# FP8 (e4m3/e5m2) conversions and matmul need Hopper (sm_90+).
3751
@testset "execution" begin
38-
if capability(device()) >= v"10"
52+
if capability(device()) >= v"9"
3953

4054
# Round-trip Float32 → FP8 → Float32 on values exactly representable in
4155
# the target FP8 type — result must match input bit-for-bit.
@@ -82,5 +96,29 @@ let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -
8296
@test Array(d) == av .* bv .+ cv
8397
end
8498

99+
# Non-scaled FP8 matmul with both allowed accumulator dtypes (f16 and f32).
100+
function mma_dl_f32(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2},
101+
C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2})
102+
a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16))
103+
ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c)))
104+
return
105+
end
106+
function mma_dl_f16(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2},
107+
C::ct.TileArray{Float16,2}, D::ct.TileArray{Float32,2})
108+
a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16))
109+
ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c)))
110+
return
111+
end
112+
@testset "mma → $Tacc acc" for (Tacc, kern) in ((Float32, mma_dl_f32), (Float16, mma_dl_f16))
113+
M = 16
114+
ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2)
115+
bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2)
116+
ch = Tacc.(Float32.(rand(0:2, M, M)))
117+
ref = Float32.(ah) * Float32.(bh) .+ Float32.(ch)
118+
D = CUDA.zeros(Float32, M, M)
119+
@cuda backend=cuTile blocks=1 kern(CuArray(ah), CuArray(bh), CuArray(ch), D)
120+
@test Array(D) == ref
121+
end
122+
85123
end
86124
end

0 commit comments

Comments
 (0)