@@ -1310,7 +1310,7 @@ carries the result dtype, which must be one tileiras allows for the input dtype
13101310`fast_acc` enables fast accumulation (lower accumulator precision for
13111311throughput); it is valid only for FP8 inputs and requires Tile IR v13.3+.
13121312"""
1313- @inline function Base. muladd (a:: Tile{T1, SA} , b:: Tile{T2, SB} , acc:: Tile{T3, SC} ;
1313+ function Base. muladd (a:: Tile{T1, SA} , b:: Tile{T2, SB} , acc:: Tile{T3, SC} ;
13141314 fast_acc:: Bool = false ) where {T1, T2, T3, SA, SB, SC}
13151315 # SA, SB, SC type parameters avoid ambiguity with the scalar `muladd`
13161316 # methods during codegen.
@@ -1319,18 +1319,18 @@ end
13191319
13201320# 2D × 2D: MmaFOp with swapped operands for row-major Tile IR
13211321# Julia (M,K)*(K,N) → TileIR (K,M)*(N,K) → mmaf(b,a,acc) → TileIR (N,M) → Julia (M,N)
1322- @inline function _muladd (a:: Tile , b:: Tile , acc:: Tile , :: Val{2} , :: Val{2} , fast_acc:: Bool )
1322+ function _muladd (a:: Tile , b:: Tile , acc:: Tile , :: Val{2} , :: Val{2} , fast_acc:: Bool )
13231323 Intrinsics. mma (b, a, acc, fast_acc)
13241324end
13251325
13261326# Vec-mat (1D × 2D): reshape (M,) → (M, 1), MmaFOp, acc is already (M, N)
1327- @inline function _muladd (a:: Tile , b:: Tile , acc:: Tile , :: Val{1} , :: Val{2} , fast_acc:: Bool )
1327+ function _muladd (a:: Tile , b:: Tile , acc:: Tile , :: Val{1} , :: Val{2} , fast_acc:: Bool )
13281328 a2d = reshape (a, (size (a, 1 ), 1 ))
13291329 _muladd (a2d, b, acc, Val (2 ), Val (2 ), fast_acc)
13301330end
13311331
13321332# Mat-vec (2D × 1D): reshape b (K,) → (K, 1), acc (M,) → (M, 1), MmaFOp, squeeze back
1333- @inline function _muladd (a:: Tile , b:: Tile , acc:: Tile , :: Val{2} , :: Val{1} , fast_acc:: Bool )
1333+ function _muladd (a:: Tile , b:: Tile , acc:: Tile , :: Val{2} , :: Val{1} , fast_acc:: Bool )
13341334 M, K = size (a, 1 ), size (b, 1 )
13351335 b2d = reshape (b, (K, 1 ))
13361336 acc2d = reshape (acc, (M, 1 ))
@@ -1339,18 +1339,18 @@ end
13391339end
13401340
13411341# Vec-vec (1D × 1D): not supported
1342- @generated function _muladd (:: Tile , :: Tile , :: Tile , :: Val{1} , :: Val{1} , :: Bool )
1343- return :( throw (ArgumentError (" Vector-vector multiply-accumulate is not supported." ) ))
1342+ function _muladd (:: Tile , :: Tile , :: Tile , :: Val{1} , :: Val{1} , :: Bool )
1343+ throw (ArgumentError (" Vector-vector multiply-accumulate is not supported." ))
13441344end
13451345
13461346# Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported, unsqueeze manually
1347- @generated function _muladd (:: Tile , :: Tile , :: Tile , :: Val{1} , :: Val{NB} , :: Bool ) where {NB}
1348- NB >= 3 || return :( throw (ArgumentError (" unreachable" ) ))
1349- return :( throw (ArgumentError (" Batched vec-mat is not supported. Reshape the 1D operand to 2D first." ) ))
1347+ function _muladd (:: Tile , :: Tile , :: Tile , :: Val{1} , :: Val{NB} , :: Bool ) where {NB}
1348+ NB >= 3 || throw (ArgumentError (" unreachable" ))
1349+ throw (ArgumentError (" Batched vec-mat is not supported. Reshape the 1D operand to 2D first." ))
13501350end
1351- @generated function _muladd (:: Tile , :: Tile , :: Tile , :: Val{NA} , :: Val{1} , :: Bool ) where {NA}
1352- NA >= 3 || return :( throw (ArgumentError (" unreachable" ) ))
1353- return :( throw (ArgumentError (" Batched mat-vec is not supported. Reshape the 1D operand to 2D first." ) ))
1351+ function _muladd (:: Tile , :: Tile , :: Tile , :: Val{NA} , :: Val{1} , :: Bool ) where {NA}
1352+ NA >= 3 || throw (ArgumentError (" unreachable" ))
1353+ throw (ArgumentError (" Batched mat-vec is not supported. Reshape the 1D operand to 2D first." ))
13541354end
13551355
13561356# Batched matmul (≥3D × ≥3D): trailing batch dims with broadcast
@@ -1360,10 +1360,10 @@ end
13601360# 2. Flatten batch dims into one via reshape (no permute needed!)
13611361# 3. MmaFOp with swapped operands: mmaf(b, a, acc)
13621362# 4. Unflatten batch dims via reshape
1363- @generated function _muladd (a:: Tile{T1, SA} , b:: Tile{T2, SB} , acc:: Tile{T3, SC} ,
1364- :: Val{NA} , :: Val{NB} , fast_acc:: Bool ) where {T1, T2, T3, SA, SB, SC, NA, NB}
1365- sa = Tuple (SA . parameters )
1366- sb = Tuple (SB . parameters )
1363+ function _muladd (a:: Tile{T1, SA} , b:: Tile{T2, SB} , acc:: Tile{T3, SC} ,
1364+ :: Val{NA} , :: Val{NB} , fast_acc:: Bool ) where {T1, T2, T3, SA, SB, SC, NA, NB}
1365+ sa = size (a )
1366+ sb = size (b )
13671367
13681368 # Matrix dims are first two; batch dims are trailing
13691369 M = sa[1 ]; K = sa[2 ]; N = sb[2 ]
@@ -1374,24 +1374,22 @@ end
13741374 n_batch = max (length (a_batch), length (b_batch))
13751375 a_batch_padded = (a_batch... , ntuple (Returns (1 ), n_batch - length (a_batch))... )
13761376 b_batch_padded = (b_batch... , ntuple (Returns (1 ), n_batch - length (b_batch))... )
1377- batch_shape = map ( max, a_batch_padded, b_batch_padded)
1377+ batch_shape = max .( a_batch_padded, b_batch_padded)
13781378 B_flat = prod (batch_shape)
13791379
1380- quote
1381- # Reshape + broadcast to align batch dims (still trailing)
1382- a_bc = broadcast_to (reshape (a, $ ((M, K, a_batch_padded... ))), $ ((M, K, batch_shape... )))
1383- b_bc = broadcast_to (reshape (b, $ ((K, N, b_batch_padded... ))), $ ((K, N, batch_shape... )))
1384- acc_bc = broadcast_to (acc, $ ((M, N, batch_shape... )))
1385- # Flatten batch dims to one — no permute needed since row-major Tile IR
1386- # already has batch as the leading (slowest) dimension
1387- a_3d = reshape (a_bc, $ ((M, K, B_flat)))
1388- b_3d = reshape (b_bc, $ ((K, N, B_flat)))
1389- acc_3d = reshape (acc_bc, $ ((M, N, B_flat)))
1390- # MmaFOp with swapped operands for row-major convention
1391- result_3d = Intrinsics. mma (b_3d, a_3d, acc_3d, fast_acc)
1392- # Unflatten batch dims
1393- reshape (result_3d, $ ((M, N, batch_shape... )))
1394- end
1380+ # Reshape + broadcast to align batch dims (still trailing)
1381+ a_bc = broadcast_to (reshape (a, (M, K, a_batch_padded... )), (M, K, batch_shape... ))
1382+ b_bc = broadcast_to (reshape (b, (K, N, b_batch_padded... )), (K, N, batch_shape... ))
1383+ acc_bc = broadcast_to (acc, (M, N, batch_shape... ))
1384+ # Flatten batch dims to one — no permute needed since row-major Tile IR
1385+ # already has batch as the leading (slowest) dimension
1386+ a_3d = reshape (a_bc, (M, K, B_flat))
1387+ b_3d = reshape (b_bc, (K, N, B_flat))
1388+ acc_3d = reshape (acc_bc, (M, N, B_flat))
1389+ # MmaFOp with swapped operands for row-major convention
1390+ result_3d = Intrinsics. mma (b_3d, a_3d, acc_3d, fast_acc)
1391+ # Unflatten batch dims
1392+ reshape (result_3d, (M, N, batch_shape... ))
13951393end
13961394
13971395#= ============================================================================
@@ -1414,22 +1412,22 @@ dimension except K, where they have `K_s ≤ K` entries.
14141412follow [`muladd`](@ref): 2-D `(M, K)` × `(K, N)`, mat-vec, and trailing batch
14151413dims; vec-mat is unsupported (it would collapse K, leaving nothing to scale).
14161414"""
1417- @inline function muladd_scaled (a:: Tile{Ta, SA} , a_scale:: Tile , b:: Tile{Tb, SB} , b_scale:: Tile ,
1415+ function muladd_scaled (a:: Tile{Ta, SA} , a_scale:: Tile , b:: Tile{Tb, SB} , b_scale:: Tile ,
14181416 acc:: Tile ) where {Ta, Tb, SA, SB}
14191417 _muladd_scaled (a, a_scale, b, b_scale, acc, Val (ndims (a)), Val (ndims (b)))
14201418end
14211419
14221420# 2D × 2D: swap operands (and their scales) for row-major Tile IR, exactly as
14231421# `_muladd` swaps for `mma`.
1424- @inline function _muladd_scaled (a:: Tile , a_scale:: Tile , b:: Tile , b_scale:: Tile , acc:: Tile ,
1422+ function _muladd_scaled (a:: Tile , a_scale:: Tile , b:: Tile , b_scale:: Tile , acc:: Tile ,
14251423 :: Val{2} , :: Val{2} )
14261424 Intrinsics. mma_scaled (b, b_scale, a, a_scale, acc)
14271425end
14281426
14291427# Mat-vec (2D × 1D): the K-vector `b` (and its scale) gain a trailing N=1 dim;
14301428# `acc` becomes (M, 1); then squeeze back to (M,). K — the scaled dimension —
14311429# is preserved, so block scaling is well-defined.
1432- @inline function _muladd_scaled (a:: Tile , a_scale:: Tile , b:: Tile , b_scale:: Tile , acc:: Tile ,
1430+ function _muladd_scaled (a:: Tile , a_scale:: Tile , b:: Tile , b_scale:: Tile , acc:: Tile ,
14331431 :: Val{2} , :: Val{1} )
14341432 M, K, Ks = size (a, 1 ), size (b, 1 ), size (b_scale, 1 )
14351433 b2d = reshape (b, (K, 1 ))
@@ -1442,33 +1440,33 @@ end
14421440# Vec-mat (1D × 2D): promoting `a` to (M, 1) collapses K to 1, leaving no K
14431441# dimension to block-scale. Unsupported — reshape to 2D and supply a matching
14441442# K_s scale instead.
1445- @generated function _muladd_scaled (:: Tile , :: Tile , :: Tile , :: Tile , :: Tile , :: Val{1} , :: Val{2} )
1446- return :( throw (ArgumentError (" Scaled vec-mat is not supported (the K dimension collapses to 1, which cannot be block-scaled)." ) ))
1443+ function _muladd_scaled (:: Tile , :: Tile , :: Tile , :: Tile , :: Tile , :: Val{1} , :: Val{2} )
1444+ throw (ArgumentError (" Scaled vec-mat is not supported (the K dimension collapses to 1, which cannot be block-scaled)." ))
14471445end
14481446
14491447# Vec-vec (1D × 1D): not supported.
1450- @generated function _muladd_scaled (:: Tile , :: Tile , :: Tile , :: Tile , :: Tile , :: Val{1} , :: Val{1} )
1451- return :( throw (ArgumentError (" Scaled vector-vector multiply-accumulate is not supported." ) ))
1448+ function _muladd_scaled (:: Tile , :: Tile , :: Tile , :: Tile , :: Tile , :: Val{1} , :: Val{1} )
1449+ throw (ArgumentError (" Scaled vector-vector multiply-accumulate is not supported." ))
14521450end
14531451
14541452# Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported.
1455- @generated function _muladd_scaled (:: Tile , :: Tile , :: Tile , :: Tile , :: Tile , :: Val{1} , :: Val{NB} ) where {NB}
1456- return :( throw (ArgumentError (" Batched scaled vec-mat is not supported." ) ))
1453+ function _muladd_scaled (:: Tile , :: Tile , :: Tile , :: Tile , :: Tile , :: Val{1} , :: Val{NB} ) where {NB}
1454+ throw (ArgumentError (" Batched scaled vec-mat is not supported." ))
14571455end
1458- @generated function _muladd_scaled (:: Tile , :: Tile , :: Tile , :: Tile , :: Tile , :: Val{NA} , :: Val{1} ) where {NA}
1459- return :( throw (ArgumentError (" Batched scaled mat-vec is not supported." ) ))
1456+ function _muladd_scaled (:: Tile , :: Tile , :: Tile , :: Tile , :: Tile , :: Val{NA} , :: Val{1} ) where {NA}
1457+ throw (ArgumentError (" Batched scaled mat-vec is not supported." ))
14601458end
14611459
14621460# Batched (≥3D × ≥3D): trailing batch dims with broadcast, mirroring `_muladd`.
14631461# Scales carry the same batch dims as their operands; a_scale's batch must match
14641462# a's batch (likewise b_scale/b), then both broadcast to the common batch shape.
1465- @generated function _muladd_scaled (a:: Tile{Ta, SA} , a_scale:: Tile{Tas, SAS} ,
1466- b:: Tile{Tb, SB} , b_scale:: Tile{Tbs, SBS} ,
1467- acc:: Tile{Tc, SC} ,
1468- :: Val{NA} , :: Val{NB} ) where {Ta, Tas, Tb, Tbs, Tc,
1469- SA, SAS, SB, SBS, SC, NA, NB}
1470- sa = Tuple (SA . parameters ); sas = Tuple (SAS . parameters )
1471- sb = Tuple (SB . parameters ); sbs = Tuple (SBS . parameters )
1463+ function _muladd_scaled (a:: Tile{Ta, SA} , a_scale:: Tile{Tas, SAS} ,
1464+ b:: Tile{Tb, SB} , b_scale:: Tile{Tbs, SBS} ,
1465+ acc:: Tile{Tc, SC} ,
1466+ :: Val{NA} , :: Val{NB} ) where {Ta, Tas, Tb, Tbs, Tc,
1467+ SA, SAS, SB, SBS, SC, NA, NB}
1468+ sa = size (a ); sas = size (a_scale )
1469+ sb = size (b ); sbs = size (b_scale )
14721470
14731471 # Matrix dims are first two; batch dims are trailing.
14741472 M = sa[1 ]; K = sa[2 ]; N = sb[2 ]
@@ -1483,28 +1481,26 @@ end
14831481 b_batch_padded = (b_batch... , ntuple (Returns (1 ), n_batch - length (b_batch))... )
14841482 as_batch_padded = (as_batch... , ntuple (Returns (1 ), n_batch - length (as_batch))... )
14851483 bs_batch_padded = (bs_batch... , ntuple (Returns (1 ), n_batch - length (bs_batch))... )
1486- batch_shape = map ( max, a_batch_padded, b_batch_padded)
1484+ batch_shape = max .( a_batch_padded, b_batch_padded)
14871485 B_flat = prod (batch_shape)
14881486
1489- quote
1490- # Reshape + broadcast to align batch dims (still trailing).
1491- a_bc = broadcast_to (reshape (a, $ ((M, K, a_batch_padded... ))), $ ((M, K, batch_shape... )))
1492- b_bc = broadcast_to (reshape (b, $ ((K, N, b_batch_padded... ))), $ ((K, N, batch_shape... )))
1493- as_bc = broadcast_to (reshape (a_scale, $ ((M, Ksa, as_batch_padded... ))), $ ((M, Ksa, batch_shape... )))
1494- bs_bc = broadcast_to (reshape (b_scale, $ ((Ksb, N, bs_batch_padded... ))), $ ((Ksb, N, batch_shape... )))
1495- acc_bc = broadcast_to (acc, $ ((M, N, batch_shape... )))
1496- # Flatten batch dims to one — no permute needed since row-major Tile IR
1497- # already has batch as the leading (slowest) dimension.
1498- a_3d = reshape (a_bc, $ ((M, K, B_flat)))
1499- b_3d = reshape (b_bc, $ ((K, N, B_flat)))
1500- as_3d = reshape (as_bc, $ ((M, Ksa, B_flat)))
1501- bs_3d = reshape (bs_bc, $ ((Ksb, N, B_flat)))
1502- acc_3d = reshape (acc_bc, $ ((M, N, B_flat)))
1503- # mmaf_scaled with swapped operands for row-major convention.
1504- result_3d = Intrinsics. mma_scaled (b_3d, bs_3d, a_3d, as_3d, acc_3d)
1505- # Unflatten batch dims.
1506- reshape (result_3d, $ ((M, N, batch_shape... )))
1507- end
1487+ # Reshape + broadcast to align batch dims (still trailing).
1488+ a_bc = broadcast_to (reshape (a, (M, K, a_batch_padded... )), (M, K, batch_shape... ))
1489+ b_bc = broadcast_to (reshape (b, (K, N, b_batch_padded... )), (K, N, batch_shape... ))
1490+ as_bc = broadcast_to (reshape (a_scale, (M, Ksa, as_batch_padded... )), (M, Ksa, batch_shape... ))
1491+ bs_bc = broadcast_to (reshape (b_scale, (Ksb, N, bs_batch_padded... )), (Ksb, N, batch_shape... ))
1492+ acc_bc = broadcast_to (acc, (M, N, batch_shape... ))
1493+ # Flatten batch dims to one — no permute needed since row-major Tile IR
1494+ # already has batch as the leading (slowest) dimension.
1495+ a_3d = reshape (a_bc, (M, K, B_flat))
1496+ b_3d = reshape (b_bc, (K, N, B_flat))
1497+ as_3d = reshape (as_bc, (M, Ksa, B_flat))
1498+ bs_3d = reshape (bs_bc, (Ksb, N, B_flat))
1499+ acc_3d = reshape (acc_bc, (M, N, B_flat))
1500+ # mmaf_scaled with swapped operands for row-major convention.
1501+ result_3d = Intrinsics. mma_scaled (b_3d, bs_3d, a_3d, as_3d, acc_3d)
1502+ # Unflatten batch dims.
1503+ reshape (result_3d, (M, N, batch_shape... ))
15081504end
15091505
15101506# Matrix multiplication: A * B = muladd(A, B, zeros)
0 commit comments