@@ -1144,10 +1144,21 @@ end
11441144 Matrix multiplication
11451145=============================================================================#
11461146
1147- # Matrix multiply-accumulate: muladd(a, b, acc) = a * b + acc
1148- # Handles 1D promotion, type promotion, and batched dims (≥3D).
1149- # Note: SA, SB, SC type parameters required to avoid ambiguity with scalar methods during codegen
"""
    muladd(a::Tile, b::Tile, acc::Tile) -> Tile

Compute the tile matrix multiply-accumulate `a * b + acc`. The operation lowers
to `cuda_tile.mmaf` for floating-point inputs and to `cuda_tile.mmai` for
`i8 × i8 → i32`.

# Shapes
- 2-D operands: `(M, K)` × `(K, N)` → `(M, N)`.
- A 1-D operand is promoted (vec-mat / mat-vec).
- Trailing dimensions beyond the second (≥3-D) act as broadcast batch dims,
  lifting `Base.muladd`'s shape rules to tiles.

# Element types
`acc` carries the result dtype, which must be one tileiras allows for the input
dtype: f16 or f32 for f16 and f8 inputs; f32 for bf16/tf32; f64 for f64; i32
for i8.
"""
@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC}
    # The shape parameters SA, SB, SC are spelled out in the signature to avoid
    # ambiguity with the scalar `muladd` methods during codegen.
    #
    # Lift each operand's rank into the type domain so `_muladd` can dispatch
    # on it.
    rank_a = Val(ndims(a))
    rank_b = Val(ndims(b))
    return _muladd(a, b, acc, rank_a, rank_b)
end
11531164
@@ -1228,6 +1239,117 @@ end
12281239 end
12291240end
12301241
1242+ #= ============================================================================
1243+ Block-scaled matrix multiply-accumulate
1244+ =============================================================================#
1245+
"""
    muladd_scaled(a, a_scale, b, b_scale, acc) -> Tile

Compute the block-scaled matrix multiply-accumulate
`(a ⊙ a_scale) * (b ⊙ b_scale) + acc`. Lowers to `cuda_tile.mmaf_scaled`
(Tile IR v13.3+, Blackwell).

# Scaling
Every scale element applies to a contiguous run of `B = K ÷ K_s` elements along
the K dimension of its operand. Consequently `a_scale` / `b_scale` have the
same extents as `a` / `b` in all dimensions but K, where they hold `K_s ≤ K`
entries.

# Element types
- `a`, `b`: low-precision floats (`f8e4m3fn`, `f8e5m2`, or `f4e2m1fn`).
- `a_scale`, `b_scale`: `f8e8m0fnu` or `f8e4m3fn`.
- `acc`: `f32`.

# Shapes
Same rules as [`muladd`](@ref): 2-D `(M, K)` × `(K, N)`, mat-vec, and trailing
batch dims. Vec-mat is not supported, because it would collapse K and leave
nothing to scale.
"""
@inline function muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile, b::Tile{Tb, SB}, b_scale::Tile,
                               acc::Tile) where {Ta, Tb, SA, SB}
    # Lift the ranks of `a` and `b` into the type domain; `_muladd_scaled`
    # selects the 2-D, mat-vec, batched, or unsupported path from them.
    rank_a = Val(ndims(a))
    rank_b = Val(ndims(b))
    return _muladd_scaled(a, a_scale, b, b_scale, acc, rank_a, rank_b)
end
1264+
# 2D × 2D. Tile IR is row-major, so the two operands trade places — each taking
# its scale along with it — before reaching the intrinsic. `_muladd` performs
# the same swap for `mma`.
@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile,
                                ::Val{2}, ::Val{2})
    # Operand order as the intrinsic sees it: `b` first, `a` second.
    first_op, first_scale = b, b_scale
    second_op, second_scale = a, a_scale
    return Intrinsics.mma_scaled(first_op, first_scale, second_op, second_scale, acc)
end
1271+
# Mat-vec (2D × 1D). The K-vector `b` and its scale each receive a trailing
# N = 1 dimension and `acc` is viewed as (M, 1); the 2D × 2D method does the
# work and the (M, 1) result is squeezed back to (M,). K, the dimension being
# scaled, is left intact, so block scaling stays well-defined.
@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile,
                                ::Val{2}, ::Val{1})
    rows = size(a, 1)
    # Promote the vector operand, its scale, and the accumulator to columns.
    b_col = reshape(b, (size(b, 1), 1))
    b_scale_col = reshape(b_scale, (size(b_scale, 1), 1))
    acc_col = reshape(acc, (rows, 1))
    out_col = _muladd_scaled(a, a_scale, b_col, b_scale_col, acc_col, Val(2), Val(2))
    # Drop the singleton N dimension again.
    return reshape(out_col, (rows,))
end
1284+
# Vec-mat (1D × 2D): promoting `a` to (M, 1) collapses K to 1, leaving no K
# dimension to block-scale. Unsupported — reshape to 2D and supply a matching
# K_s scale instead.
#
# The generator ignores its arguments and always returns an expression that
# throws an `ArgumentError`, so any call with these ranks fails with that
# message.
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{2})
    # Message is assembled at generation time and spliced in as a literal;
    # it carries no leading whitespace.
    msg = "Scaled vec-mat is not supported " *
          "(the K dimension collapses to 1, which cannot be block-scaled)."
    return :(throw(ArgumentError($msg)))
end
1291+
# Vec-vec (1D × 1D): not supported.
#
# The generator ignores its arguments and always returns an expression that
# throws an `ArgumentError`, so any call with these ranks fails with that
# message.
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1})
    # Spliced in as a literal; it carries no leading whitespace.
    msg = "Scaled vector-vector multiply-accumulate is not supported."
    return :(throw(ArgumentError($msg)))
end
1296+
# Batched vec-mat (1D × ≥3D): not supported.
#
# Catch-all for a 1-D `a` paired with a `b` of any rank `NB`; the 1D × 1D and
# 1D × 2D cases have their own, more specific methods. The generator always
# returns an expression that throws an `ArgumentError`.
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB}
    # Spliced in as a literal; it carries no leading whitespace.
    msg = "Batched scaled vec-mat is not supported."
    return :(throw(ArgumentError($msg)))
end
# Batched mat-vec (≥3D × 1D): not supported.
#
# Catch-all for a 1-D `b` paired with an `a` of any rank `NA`; the 2D × 1D and
# 1D × 1D cases have their own, more specific methods. The generator always
# returns an expression that throws an `ArgumentError`.
@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA}
    # Spliced in as a literal; it carries no leading whitespace.
    msg = "Batched scaled mat-vec is not supported."
    return :(throw(ArgumentError($msg)))
end
1304+
# Batched (≥3D × ≥3D): batch dims trail the two matrix dims and broadcast against
# each other, mirroring `_muladd`. Each scale carries the batch dims of its
# operand (a_scale's batch must match a's, b_scale's must match b's); all four
# are then broadcast to the common batch shape.
#
# All shape arithmetic happens in the generator from the tiles' static shape
# parameters; the returned expression only contains reshape / broadcast_to calls
# with literal shapes plus the intrinsic call.
@generated function _muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile{Tas, SAS},
                                   b::Tile{Tb, SB}, b_scale::Tile{Tbs, SBS},
                                   acc::Tile{Tc, SC},
                                   ::Val{NA}, ::Val{NB}) where {Ta, Tas, Tb, Tbs, Tc,
                                                                 SA, SAS, SB, SBS, SC, NA, NB}
    a_dims = Tuple(SA.parameters)
    a_scale_dims = Tuple(SAS.parameters)
    b_dims = Tuple(SB.parameters)
    b_scale_dims = Tuple(SBS.parameters)

    # The first two dims are the matrix dims.
    M = a_dims[1]
    K = a_dims[2]
    N = b_dims[2]
    Ks_a = a_scale_dims[2]  # a_scale is (M, K_s, batch...)
    Ks_b = b_scale_dims[1]  # b_scale is (K_s, N, batch...)

    # Everything after the matrix dims is batch.
    a_tail = a_dims[3:end]
    b_tail = b_dims[3:end]
    a_scale_tail = a_scale_dims[3:end]
    b_scale_tail = b_scale_dims[3:end]

    # Pad every batch tuple with trailing 1s up to the longer of a's and b's
    # batch rank, then take the elementwise max as the broadcast batch shape.
    batch_rank = max(length(a_tail), length(b_tail))
    a_tail_full = (a_tail..., ntuple(Returns(1), batch_rank - length(a_tail))...)
    b_tail_full = (b_tail..., ntuple(Returns(1), batch_rank - length(b_tail))...)
    a_scale_tail_full =
        (a_scale_tail..., ntuple(Returns(1), batch_rank - length(a_scale_tail))...)
    b_scale_tail_full =
        (b_scale_tail..., ntuple(Returns(1), batch_rank - length(b_scale_tail))...)
    batch_dims = map(max, a_tail_full, b_tail_full)
    batch_len = prod(batch_dims)

    # Literal shapes spliced into the generated body.
    a_padded_shape = (M, K, a_tail_full...)
    b_padded_shape = (K, N, b_tail_full...)
    a_scale_padded_shape = (M, Ks_a, a_scale_tail_full...)
    b_scale_padded_shape = (Ks_b, N, b_scale_tail_full...)

    a_full_shape = (M, K, batch_dims...)
    b_full_shape = (K, N, batch_dims...)
    a_scale_full_shape = (M, Ks_a, batch_dims...)
    b_scale_full_shape = (Ks_b, N, batch_dims...)
    out_full_shape = (M, N, batch_dims...)

    a_flat_shape = (M, K, batch_len)
    b_flat_shape = (K, N, batch_len)
    a_scale_flat_shape = (M, Ks_a, batch_len)
    b_scale_flat_shape = (Ks_b, N, batch_len)
    out_flat_shape = (M, N, batch_len)

    return quote
        # Align the batch dims (still trailing): pad via reshape, then broadcast.
        a_aligned = broadcast_to(reshape(a, $a_padded_shape), $a_full_shape)
        b_aligned = broadcast_to(reshape(b, $b_padded_shape), $b_full_shape)
        a_scale_aligned =
            broadcast_to(reshape(a_scale, $a_scale_padded_shape), $a_scale_full_shape)
        b_scale_aligned =
            broadcast_to(reshape(b_scale, $b_scale_padded_shape), $b_scale_full_shape)
        acc_aligned = broadcast_to(acc, $out_full_shape)
        # Collapse the batch dims into a single one. No permute is needed since
        # row-major Tile IR already has batch as the leading (slowest) dimension.
        a_flat = reshape(a_aligned, $a_flat_shape)
        b_flat = reshape(b_aligned, $b_flat_shape)
        a_scale_flat = reshape(a_scale_aligned, $a_scale_flat_shape)
        b_scale_flat = reshape(b_scale_aligned, $b_scale_flat_shape)
        acc_flat = reshape(acc_aligned, $out_flat_shape)
        # mmaf_scaled with operands swapped for the row-major convention.
        out_flat = Intrinsics.mma_scaled(b_flat, b_scale_flat, a_flat, a_scale_flat, acc_flat)
        # Restore the individual batch dims.
        reshape(out_flat, $out_full_shape)
    end
end
1352+
12311353# Matrix multiplication: A * B = muladd(A, B, zeros)
12321354# Note: SA, SB type parameters required to avoid ambiguity with scalar*tile methods during codegen
12331355#
0 commit comments