Skip to content

Commit 3e14bf8

Browse files
committed
Remove muladd @inline and @generated
1 parent 674c5bb commit 3e14bf8

2 files changed

Lines changed: 67 additions & 74 deletions

File tree

ext/MicrofloatsExt.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,7 @@ ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero
2525

2626
# Non-scaled `mma`/`matmul` (`cuda_tile.mmaf`) accepts f8e4m3fn and f8e5m2
2727
# operands with an f16 or f32 accumulator — f16 first/preferred, mirroring
28-
# cuda-tile's mmaf type table and cutile-python's `_mma_supported_dtypes`. The
29-
# other Microfloats formats (E8M0FNU, Float4_E2M1FN) are only valid as
30-
# scaled-mma operands/scales, so they stay absent and fall through to `mma`'s
31-
# unsupported-dtype error.
28+
# cuda-tile's mmaf type table and cutile-python's `_mma_supported_dtypes`.
3229
ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32)
3330
ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32)
3431

src/language/operations.jl

Lines changed: 66 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,7 +1310,7 @@ carries the result dtype, which must be one tileiras allows for the input dtype
13101310
`fast_acc` enables fast accumulation (lower accumulator precision for
13111311
throughput); it is valid only for FP8 inputs and requires Tile IR v13.3+.
13121312
"""
1313-
@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC};
1313+
function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC};
13141314
fast_acc::Bool=false) where {T1, T2, T3, SA, SB, SC}
13151315
# SA, SB, SC type parameters avoid ambiguity with the scalar `muladd`
13161316
# methods during codegen.
@@ -1319,18 +1319,18 @@ end
13191319

13201320
# 2D × 2D: MmaFOp with swapped operands for row-major Tile IR
13211321
# Julia (M,K)*(K,N) → TileIR (K,M)*(N,K) → mmaf(b,a,acc) → TileIR (N,M) → Julia (M,N)
1322-
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}, fast_acc::Bool)
1322+
function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}, fast_acc::Bool)
13231323
Intrinsics.mma(b, a, acc, fast_acc)
13241324
end
13251325

13261326
# Vec-mat (1D × 2D): reshape a (M,) → (M, 1), then MmaFOp; acc is already (M, N)
1327-
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2}, fast_acc::Bool)
1327+
function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2}, fast_acc::Bool)
13281328
a2d = reshape(a, (size(a, 1), 1))
13291329
_muladd(a2d, b, acc, Val(2), Val(2), fast_acc)
13301330
end
13311331

13321332
# Mat-vec (2D × 1D): reshape b (K,) → (K, 1), acc (M,) → (M, 1), MmaFOp, squeeze back
1333-
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1}, fast_acc::Bool)
1333+
function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1}, fast_acc::Bool)
13341334
M, K = size(a, 1), size(b, 1)
13351335
b2d = reshape(b, (K, 1))
13361336
acc2d = reshape(acc, (M, 1))
@@ -1339,18 +1339,18 @@ end
13391339
end
13401340

13411341
# Vec-vec (1D × 1D): not supported
1342-
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}, ::Bool)
1343-
return :(throw(ArgumentError("Vector-vector multiply-accumulate is not supported.")))
1342+
function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}, ::Bool)
1343+
throw(ArgumentError("Vector-vector multiply-accumulate is not supported."))
13441344
end
13451345

13461346
# Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported; the caller must reshape the 1D operand to 2D first
1347-
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}, ::Bool) where {NB}
1348-
NB >= 3 || return :(throw(ArgumentError("unreachable")))
1349-
return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first.")))
1347+
function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}, ::Bool) where {NB}
1348+
NB >= 3 || throw(ArgumentError("unreachable"))
1349+
throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first."))
13501350
end
1351-
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}, ::Bool) where {NA}
1352-
NA >= 3 || return :(throw(ArgumentError("unreachable")))
1353-
return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first.")))
1351+
function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}, ::Bool) where {NA}
1352+
NA >= 3 || throw(ArgumentError("unreachable"))
1353+
throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first."))
13541354
end
13551355

13561356
# Batched matmul (≥3D × ≥3D): trailing batch dims with broadcast
@@ -1360,10 +1360,10 @@ end
13601360
# 2. Flatten batch dims into one via reshape (no permute needed!)
13611361
# 3. MmaFOp with swapped operands: mmaf(b, a, acc)
13621362
# 4. Unflatten batch dims via reshape
1363-
@generated function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC},
1364-
::Val{NA}, ::Val{NB}, fast_acc::Bool) where {T1, T2, T3, SA, SB, SC, NA, NB}
1365-
sa = Tuple(SA.parameters)
1366-
sb = Tuple(SB.parameters)
1363+
function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC},
1364+
::Val{NA}, ::Val{NB}, fast_acc::Bool) where {T1, T2, T3, SA, SB, SC, NA, NB}
1365+
sa = size(a)
1366+
sb = size(b)
13671367

13681368
# Matrix dims are first two; batch dims are trailing
13691369
M = sa[1]; K = sa[2]; N = sb[2]
@@ -1374,24 +1374,22 @@ end
13741374
n_batch = max(length(a_batch), length(b_batch))
13751375
a_batch_padded = (a_batch..., ntuple(Returns(1), n_batch - length(a_batch))...)
13761376
b_batch_padded = (b_batch..., ntuple(Returns(1), n_batch - length(b_batch))...)
1377-
batch_shape = map(max, a_batch_padded, b_batch_padded)
1377+
batch_shape = max.(a_batch_padded, b_batch_padded)
13781378
B_flat = prod(batch_shape)
13791379

1380-
quote
1381-
# Reshape + broadcast to align batch dims (still trailing)
1382-
a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...)))
1383-
b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...)))
1384-
acc_bc = broadcast_to(acc, $((M, N, batch_shape...)))
1385-
# Flatten batch dims to one — no permute needed since row-major Tile IR
1386-
# already has batch as the leading (slowest) dimension
1387-
a_3d = reshape(a_bc, $((M, K, B_flat)))
1388-
b_3d = reshape(b_bc, $((K, N, B_flat)))
1389-
acc_3d = reshape(acc_bc, $((M, N, B_flat)))
1390-
# MmaFOp with swapped operands for row-major convention
1391-
result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d, fast_acc)
1392-
# Unflatten batch dims
1393-
reshape(result_3d, $((M, N, batch_shape...)))
1394-
end
1380+
# Reshape + broadcast to align batch dims (still trailing)
1381+
a_bc = broadcast_to(reshape(a, (M, K, a_batch_padded...)), (M, K, batch_shape...))
1382+
b_bc = broadcast_to(reshape(b, (K, N, b_batch_padded...)), (K, N, batch_shape...))
1383+
acc_bc = broadcast_to(acc, (M, N, batch_shape...))
1384+
# Flatten batch dims to one — no permute needed since row-major Tile IR
1385+
# already has batch as the leading (slowest) dimension
1386+
a_3d = reshape(a_bc, (M, K, B_flat))
1387+
b_3d = reshape(b_bc, (K, N, B_flat))
1388+
acc_3d = reshape(acc_bc, (M, N, B_flat))
1389+
# MmaFOp with swapped operands for row-major convention
1390+
result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d, fast_acc)
1391+
# Unflatten batch dims
1392+
reshape(result_3d, (M, N, batch_shape...))
13951393
end
13961394

13971395
#=============================================================================
@@ -1414,22 +1412,22 @@ dimension except K, where they have `K_s ≤ K` entries.
14141412
follow [`muladd`](@ref): 2-D `(M, K)` × `(K, N)`, mat-vec, and trailing batch
14151413
dims; vec-mat is unsupported (it would collapse K, leaving nothing to scale).
14161414
"""
1417-
@inline function muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile, b::Tile{Tb, SB}, b_scale::Tile,
1415+
function muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile, b::Tile{Tb, SB}, b_scale::Tile,
14181416
acc::Tile) where {Ta, Tb, SA, SB}
14191417
_muladd_scaled(a, a_scale, b, b_scale, acc, Val(ndims(a)), Val(ndims(b)))
14201418
end
14211419

14221420
# 2D × 2D: swap operands (and their scales) for row-major Tile IR, exactly as
14231421
# `_muladd` swaps for `mma`.
1424-
@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile,
1422+
function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile,
14251423
::Val{2}, ::Val{2})
14261424
Intrinsics.mma_scaled(b, b_scale, a, a_scale, acc)
14271425
end
14281426

14291427
# Mat-vec (2D × 1D): the K-vector `b` (and its scale) gain a trailing N=1 dim;
14301428
# `acc` becomes (M, 1); then squeeze back to (M,). K — the scaled dimension —
14311429
# is preserved, so block scaling is well-defined.
1432-
@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile,
1430+
function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile,
14331431
::Val{2}, ::Val{1})
14341432
M, K, Ks = size(a, 1), size(b, 1), size(b_scale, 1)
14351433
b2d = reshape(b, (K, 1))
@@ -1442,33 +1440,33 @@ end
14421440
# Vec-mat (1D × 2D): promoting `a` to (M, 1) collapses K to 1, leaving no K
14431441
# dimension to block-scale. Unsupported — reshape to 2D and supply a matching
14441442
# K_s scale instead.
1445-
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{2})
1446-
return :(throw(ArgumentError("Scaled vec-mat is not supported (the K dimension collapses to 1, which cannot be block-scaled).")))
1443+
function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{2})
1444+
throw(ArgumentError("Scaled vec-mat is not supported (the K dimension collapses to 1, which cannot be block-scaled)."))
14471445
end
14481446

14491447
# Vec-vec (1D × 1D): not supported.
1450-
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1})
1451-
return :(throw(ArgumentError("Scaled vector-vector multiply-accumulate is not supported.")))
1448+
function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1})
1449+
throw(ArgumentError("Scaled vector-vector multiply-accumulate is not supported."))
14521450
end
14531451

14541452
# Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported.
1455-
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB}
1456-
return :(throw(ArgumentError("Batched scaled vec-mat is not supported.")))
1453+
function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB}
1454+
throw(ArgumentError("Batched scaled vec-mat is not supported."))
14571455
end
1458-
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA}
1459-
return :(throw(ArgumentError("Batched scaled mat-vec is not supported.")))
1456+
function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA}
1457+
throw(ArgumentError("Batched scaled mat-vec is not supported."))
14601458
end
14611459

14621460
# Batched (≥3D × ≥3D): trailing batch dims with broadcast, mirroring `_muladd`.
14631461
# Scales carry the same batch dims as their operands; a_scale's batch must match
14641462
# a's batch (likewise b_scale/b), then both broadcast to the common batch shape.
1465-
@generated function _muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile{Tas, SAS},
1466-
b::Tile{Tb, SB}, b_scale::Tile{Tbs, SBS},
1467-
acc::Tile{Tc, SC},
1468-
::Val{NA}, ::Val{NB}) where {Ta, Tas, Tb, Tbs, Tc,
1469-
SA, SAS, SB, SBS, SC, NA, NB}
1470-
sa = Tuple(SA.parameters); sas = Tuple(SAS.parameters)
1471-
sb = Tuple(SB.parameters); sbs = Tuple(SBS.parameters)
1463+
function _muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile{Tas, SAS},
1464+
b::Tile{Tb, SB}, b_scale::Tile{Tbs, SBS},
1465+
acc::Tile{Tc, SC},
1466+
::Val{NA}, ::Val{NB}) where {Ta, Tas, Tb, Tbs, Tc,
1467+
SA, SAS, SB, SBS, SC, NA, NB}
1468+
sa = size(a); sas = size(a_scale)
1469+
sb = size(b); sbs = size(b_scale)
14721470

14731471
# Matrix dims are first two; batch dims are trailing.
14741472
M = sa[1]; K = sa[2]; N = sb[2]
@@ -1483,28 +1481,26 @@ end
14831481
b_batch_padded = (b_batch..., ntuple(Returns(1), n_batch - length(b_batch))...)
14841482
as_batch_padded = (as_batch..., ntuple(Returns(1), n_batch - length(as_batch))...)
14851483
bs_batch_padded = (bs_batch..., ntuple(Returns(1), n_batch - length(bs_batch))...)
1486-
batch_shape = map(max, a_batch_padded, b_batch_padded)
1484+
batch_shape = max.(a_batch_padded, b_batch_padded)
14871485
B_flat = prod(batch_shape)
14881486

1489-
quote
1490-
# Reshape + broadcast to align batch dims (still trailing).
1491-
a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...)))
1492-
b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...)))
1493-
as_bc = broadcast_to(reshape(a_scale, $((M, Ksa, as_batch_padded...))), $((M, Ksa, batch_shape...)))
1494-
bs_bc = broadcast_to(reshape(b_scale, $((Ksb, N, bs_batch_padded...))), $((Ksb, N, batch_shape...)))
1495-
acc_bc = broadcast_to(acc, $((M, N, batch_shape...)))
1496-
# Flatten batch dims to one — no permute needed since row-major Tile IR
1497-
# already has batch as the leading (slowest) dimension.
1498-
a_3d = reshape(a_bc, $((M, K, B_flat)))
1499-
b_3d = reshape(b_bc, $((K, N, B_flat)))
1500-
as_3d = reshape(as_bc, $((M, Ksa, B_flat)))
1501-
bs_3d = reshape(bs_bc, $((Ksb, N, B_flat)))
1502-
acc_3d = reshape(acc_bc, $((M, N, B_flat)))
1503-
# mmaf_scaled with swapped operands for row-major convention.
1504-
result_3d = Intrinsics.mma_scaled(b_3d, bs_3d, a_3d, as_3d, acc_3d)
1505-
# Unflatten batch dims.
1506-
reshape(result_3d, $((M, N, batch_shape...)))
1507-
end
1487+
# Reshape + broadcast to align batch dims (still trailing).
1488+
a_bc = broadcast_to(reshape(a, (M, K, a_batch_padded...)), (M, K, batch_shape...))
1489+
b_bc = broadcast_to(reshape(b, (K, N, b_batch_padded...)), (K, N, batch_shape...))
1490+
as_bc = broadcast_to(reshape(a_scale, (M, Ksa, as_batch_padded...)), (M, Ksa, batch_shape...))
1491+
bs_bc = broadcast_to(reshape(b_scale, (Ksb, N, bs_batch_padded...)), (Ksb, N, batch_shape...))
1492+
acc_bc = broadcast_to(acc, (M, N, batch_shape...))
1493+
# Flatten batch dims to one — no permute needed since row-major Tile IR
1494+
# already has batch as the leading (slowest) dimension.
1495+
a_3d = reshape(a_bc, (M, K, B_flat))
1496+
b_3d = reshape(b_bc, (K, N, B_flat))
1497+
as_3d = reshape(as_bc, (M, Ksa, B_flat))
1498+
bs_3d = reshape(bs_bc, (Ksb, N, B_flat))
1499+
acc_3d = reshape(acc_bc, (M, N, B_flat))
1500+
# mmaf_scaled with swapped operands for row-major convention.
1501+
result_3d = Intrinsics.mma_scaled(b_3d, bs_3d, a_3d, as_3d, acc_3d)
1502+
# Unflatten batch dims.
1503+
reshape(result_3d, (M, N, batch_shape...))
15081504
end
15091505

15101506
# Matrix multiplication: A * B = muladd(A, B, zeros)

0 commit comments

Comments
 (0)