Skip to content

Commit 8ce0189

Browse files
committed
mma_scaled
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent f4f0a36 commit 8ce0189

File tree

5 files changed

+492
-2
lines changed

5 files changed

+492
-2
lines changed

changelog.d/future/mma-scaled.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- New `ct.mma_scaled()` operation for block-scaled matrix multiply-accumulate.
5+
Supported input dtypes: `float8_e4m3fn`, `float8_e5m2`, `float4_e2m1fn`.
6+
Supported scale dtypes: `float8_e8m0fnu`, `float8_e4m3fn` (f4 inputs only).

src/cuda/tile/_datatype.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,53 @@ def _resolve_mma_supported_dtype(x_dtype: DType,
436436
return acc_dtype
437437

438438

439+
_mma_scaled_supported_dtypes = {
440+
# operand dtype -> {scale dtype: (result dtype, scaling block sizes)}
441+
float8_e4m3fn: {float8_e8m0fnu: (float32, (32,))},
442+
float8_e5m2: {float8_e8m0fnu: (float32, (32,))},
443+
float4_e2m1fn: {float8_e8m0fnu: (float32, (16, 32)),
444+
float8_e4m3fn: (float32, (16,))},
445+
}
446+
447+
448+
def _resolve_mma_scaled_supported_dtype(x_dtype: DType,
449+
x_scale_dtype: DType,
450+
y_dtype: DType,
451+
y_scale_dtype: DType,
452+
acc_dtype: DType):
453+
if x_dtype != y_dtype:
454+
raise TileTypeError(
455+
f"x and y must have the same dtype, got {x_dtype} and {y_dtype}")
456+
if x_scale_dtype != y_scale_dtype:
457+
raise TileTypeError(
458+
f"x_scale and y_scale must have the same dtype, "
459+
f"got {x_scale_dtype} and {y_scale_dtype}")
460+
if x_dtype not in _mma_scaled_supported_dtypes:
461+
candidates = ", ".join(str(d) for d in _mma_scaled_supported_dtypes.keys())
462+
raise TileTypeError(
463+
f"Unsupported input dtype {x_dtype} for mma_scaled, "
464+
f"supported input dtypes are {candidates}")
465+
scale_candidates = _mma_scaled_supported_dtypes[x_dtype]
466+
if x_scale_dtype not in scale_candidates:
467+
candidate_names = ", ".join(str(s) for s in scale_candidates.keys())
468+
raise TileTypeError(
469+
f"Unsupported scale dtype {x_scale_dtype} for input dtype {x_dtype}, "
470+
f"supported scale dtypes are {candidate_names}")
471+
expected_acc, _ = scale_candidates[x_scale_dtype]
472+
if acc_dtype != expected_acc:
473+
raise TileTypeError(
474+
f"Unsupported acc dtype {acc_dtype} for mma_scaled, "
475+
f"expected {expected_acc}")
476+
477+
478+
def _get_mma_scaled_scaling_block_sizes(data_dtype, scale_dtype) -> Tuple[int, ...]:
479+
assert data_dtype in _mma_scaled_supported_dtypes
480+
scale_candidates = _mma_scaled_supported_dtypes[data_dtype]
481+
assert scale_dtype in scale_candidates
482+
_, scaling_block_sizes = scale_candidates[scale_dtype]
483+
return scaling_block_sizes
484+
485+
439486
# =============== Documentation Generator ================
440487

441488
def _generate_rst_dtype_promotion_table() -> str:

src/cuda/tile/_ir/ops.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2786,7 +2786,7 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
27862786

27872787

27882788
@impl(ct.mma)
2789-
def mma(x: Var, y: Var, acc: Var) -> Var:
2789+
def mma_impl(x: Var, y: Var, acc: Var) -> Var:
27902790
x_tile_type = require_tile_type(x)
27912791
y_tile_type = require_tile_type(y)
27922792
acc_tile_type = require_tile_type(acc)
@@ -2808,7 +2808,7 @@ def mma(x: Var, y: Var, acc: Var) -> Var:
28082808

28092809
@impl(ct.matmul)
28102810
@impl(operator.matmul)
2811-
def matmul(x: Var, y: Var) -> Var:
2811+
def matmul_impl(x: Var, y: Var) -> Var:
28122812
x_tile_type = require_tile_type(x)
28132813
y_tile_type = require_tile_type(y)
28142814
x_shape_orig = x_tile_type.shape
@@ -2833,6 +2833,89 @@ def matmul(x: Var, y: Var) -> Var:
28332833
return ret
28342834

28352835

2836+
@dataclass(eq=False)
2837+
class TileMmaScaled(Operation, opcode="tile_mma_scaled"):
2838+
x: Var = operand()
2839+
x_scale: Var = operand()
2840+
y: Var = operand()
2841+
y_scale: Var = operand()
2842+
acc: Var = operand()
2843+
2844+
@override
2845+
def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
2846+
x_value = ctx.get_value(self.x)
2847+
x_scale_value = ctx.get_value(self.x_scale)
2848+
y_value = ctx.get_value(self.y)
2849+
y_scale_value = ctx.get_value(self.y_scale)
2850+
acc_value = ctx.get_value(self.acc)
2851+
res_typeid = ctx.typeid_of(self.result_var)
2852+
return bc.encode_MmaFScaledOp(ctx.builder, res_typeid, x_value, y_value,
2853+
acc_value, x_scale_value, y_scale_value)
2854+
2855+
2856+
def _verify_scaling_block_size(ty: TileTy, scale_ty: TileTy, k_axis: int,
2857+
name: str, scale_name: str):
2858+
shape = ty.shape_value
2859+
dtype = ty.dtype
2860+
scale_shape = scale_ty.shape_value
2861+
scale_dtype = scale_ty.dtype
2862+
k_axis = normalize_axis(k_axis, len(shape))
2863+
if any(x != y for i, (x, y) in enumerate(zip(shape, scale_shape, strict=True)) if i != k_axis):
2864+
raise TileTypeError(
2865+
f"{scale_name} shape {scale_shape} is not compatible with {name} shape {shape}. "
2866+
f"All dimensions except K axis {k_axis} must match")
2867+
2868+
allowed = datatype._get_mma_scaled_scaling_block_sizes(ty.dtype, scale_ty.dtype)
2869+
scaling_block_size, rem = divmod(shape[k_axis], scale_shape[k_axis])
2870+
if rem != 0 or scaling_block_size not in allowed:
2871+
raise TileTypeError(
2872+
f"For mma_scaled with dtype={dtype}, scale_dtype={scale_dtype}: "
2873+
f"{name}.shape[{k_axis}] must be an exact multiple of {scale_name}.shape[{k_axis}] "
2874+
f"with scaling block size B = K // K_s in {set(allowed)}, "
2875+
f"got {name}.shape[{k_axis}] = {shape[k_axis]} and "
2876+
f"{scale_name}.shape[{k_axis}] = {scale_shape[k_axis]}")
2877+
2878+
2879+
@impl(ct.mma_scaled, min_version=BytecodeVersion.V_13_3)
2880+
def mma_scaled_impl(x: Var, x_scale: Var, y: Var, y_scale: Var, acc: Var) -> Var:
2881+
x_ty = require_tile_type(x)
2882+
y_ty = require_tile_type(y)
2883+
acc_ty = require_tile_type(acc)
2884+
x_scale_ty = require_tile_type(x_scale)
2885+
y_scale_ty = require_tile_type(y_scale)
2886+
2887+
for name, shape in [("x", x_ty.shape), ("y", y_ty.shape),
2888+
("acc", acc_ty.shape),
2889+
("x_scale", x_scale_ty.shape),
2890+
("y_scale", y_scale_ty.shape)]:
2891+
if len(shape) not in [2, 3]:
2892+
raise TileTypeError(
2893+
f'Expect shape of `{name}` to be 2D or 3D, got {shape}')
2894+
2895+
datatype._resolve_mma_scaled_supported_dtype(
2896+
x_ty.dtype, x_scale_ty.dtype,
2897+
y_ty.dtype, y_scale_ty.dtype,
2898+
acc_ty.dtype)
2899+
_verify_scaling_block_size(x_ty, x_scale_ty, k_axis=-1, name="x", scale_name="x_scale")
2900+
_verify_scaling_block_size(y_ty, y_scale_ty, k_axis=-2, name="y", scale_name="y_scale")
2901+
2902+
x_shape, y_shape, _, output_shape = _matmul_broadcast_shape(x_ty.shape, y_ty.shape)
2903+
if acc_ty.shape != output_shape:
2904+
raise TileTypeError(f'Expect acc shape to be {output_shape}, got {acc_ty.shape}')
2905+
2906+
# Broadcast scale batch dims to match the broadcasted x/y batch dims
2907+
batch = x_shape.value_types[:-2]
2908+
x_scale_shape = TupleTy(batch + x_scale_ty.shape.value_types[-2:])
2909+
y_scale_shape = TupleTy(batch + y_scale_ty.shape.value_types[-2:])
2910+
2911+
x = _promote_and_broadcast_to(x, TileTy(x_ty.dtype, x_shape))
2912+
y = _promote_and_broadcast_to(y, TileTy(y_ty.dtype, y_shape))
2913+
x_scale = _promote_and_broadcast_to(x_scale, TileTy(x_scale_ty.dtype, x_scale_shape))
2914+
y_scale = _promote_and_broadcast_to(y_scale, TileTy(y_scale_ty.dtype, y_scale_shape))
2915+
return add_operation(TileMmaScaled, acc_ty,
2916+
x=x, x_scale=x_scale, y=y, y_scale=y_scale, acc=acc)
2917+
2918+
28362919
@dataclass(eq=False)
28372920
class TileReduce(Operation, opcode="tile_reduce"):
28382921
identities: tuple[bool | int | float, ...] = attribute()

src/cuda/tile/_stub.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,61 @@ def mma(x, y, /, acc) -> Tile:
10051005
"""
10061006

10071007

1008+
@function
1009+
def mma_scaled(x, x_scale, y, y_scale, /, acc) -> Tile:
1010+
"""Block-scaled matrix multiply-accumulate.
1011+
1012+
Computes a matrix multiply-accumulate where inputs are scaled by block scales
1013+
along the K dimension before the mma::
1014+
1015+
result[i, j] = sum(x[i, k] * x_scale[i, k // B] * y[k, j] * y_scale[k // B, j]
1016+
for k in range(K)) + acc[i, j]
1017+
1018+
The scaling block size is ``B = K // K_s``, where ``K_s`` is the K dimension of the scale tile.
1019+
``K`` must be divisible by ``K_s``, and ``B`` must be one of the allowed values listed
1020+
in the table below.
1021+
1022+
Args:
1023+
x (Tile): LHS input, 2D or 3D ``[..., M, K]``.
1024+
x_scale (Tile): Scale factors for x, shape ``[..., M, K_s]``.
1025+
All dimensions except K_s must match x exactly.
1026+
y (Tile): RHS input, 2D or 3D ``[..., K, N]``.
1027+
y_scale (Tile): Scale factors for y, shape ``[..., K_s, N]``.
1028+
All dimensions except K_s must match y exactly.
1029+
acc (Tile): Accumulator ``[..., M, N]``.
1030+
1031+
Supported datatypes and scaling block sizes:
1032+
1033+
+----------------------------+------------+---------+--------+
1034+
| Input (x/y) | Scale | Acc/Out | B |
1035+
+============================+============+=========+========+
1036+
| f8e4m3fn, f8e5m2 | f8e8m0fnu | f32 | 32 |
1037+
+----------------------------+------------+---------+--------+
1038+
| f4e2m1fn | f8e8m0fnu | f32 | 16, 32 |
1039+
+----------------------------+------------+---------+--------+
1040+
| f4e2m1fn | f8e4m3fn | f32 | 16 |
1041+
+----------------------------+------------+---------+--------+
1042+
1043+
Batch dimensions of x and y are broadcast against each other (same as
1044+
:func:`mma`). x_scale's batch dimension must match x's batch exactly,
1045+
and y_scale's batch dimension must match y's batch exactly; both are
1046+
then broadcast to the output batch shape.
1047+
1048+
Returns:
1049+
Tile:
1050+
1051+
Example:
1052+
1053+
>>> # B = K // K_s = 64 // 2 = 32
1054+
>>> tx = ct.ones((16, 64), ct.float8_e4m3fn)
1055+
>>> sx = ct.ones((16, 2), ct.float8_e8m0fnu)
1056+
>>> ty = ct.ones((64, 16), ct.float8_e4m3fn)
1057+
>>> sy = ct.ones((2, 16), ct.float8_e8m0fnu)
1058+
>>> acc = ct.zeros((16, 16), ct.float32)
1059+
>>> tz = ct.mma_scaled(tx, sx, ty, sy, acc)
1060+
"""
1061+
1062+
10081063
@function
10091064
def matmul(x, y, /) -> Tile:
10101065
"""Performs matrix multiply on the given tiles.

0 commit comments

Comments
 (0)