@@ -2786,7 +2786,7 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
27862786
27872787
27882788@impl (ct .mma )
2789- def mma (x : Var , y : Var , acc : Var ) -> Var :
2789+ def mma_impl (x : Var , y : Var , acc : Var ) -> Var :
27902790 x_tile_type = require_tile_type (x )
27912791 y_tile_type = require_tile_type (y )
27922792 acc_tile_type = require_tile_type (acc )
@@ -2808,7 +2808,7 @@ def mma(x: Var, y: Var, acc: Var) -> Var:
28082808
28092809@impl (ct .matmul )
28102810@impl (operator .matmul )
2811- def matmul (x : Var , y : Var ) -> Var :
2811+ def matmul_impl (x : Var , y : Var ) -> Var :
28122812 x_tile_type = require_tile_type (x )
28132813 y_tile_type = require_tile_type (y )
28142814 x_shape_orig = x_tile_type .shape
@@ -2833,6 +2833,89 @@ def matmul(x: Var, y: Var) -> Var:
28332833 return ret
28342834
28352835
@dataclass(eq=False)
class TileMmaScaled(Operation, opcode="tile_mma_scaled"):
    """Operation node for a scaled matrix-multiply-accumulate.

    The node carries five operands and is lowered to a single
    ``MmaFScaledOp`` bytecode instruction by :meth:`generate_bytecode`.
    """

    # Left-hand matrix operand and its scale operand.
    x: Var = operand()
    x_scale: Var = operand()
    # Right-hand matrix operand and its scale operand.
    y: Var = operand()
    y_scale: Var = operand()
    # Accumulator operand.
    acc: Var = operand()

    @override
    def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
        """Encode this operation as an ``MmaFScaledOp``.

        Args:
            ctx: Bytecode context used to look up the values bound to the
                operands, the type id of the result variable, and the builder.

        Returns:
            The value returned by ``bc.encode_MmaFScaledOp``.
        """
        operands = (self.x, self.x_scale, self.y, self.y_scale, self.acc)
        lhs, lhs_scale, rhs, rhs_scale, accumulator = [
            ctx.get_value(var) for var in operands
        ]
        result_typeid = ctx.typeid_of(self.result_var)
        # The encoder takes the matrices and accumulator first, then the scales.
        return bc.encode_MmaFScaledOp(
            ctx.builder,
            result_typeid,
            lhs,
            rhs,
            accumulator,
            lhs_scale,
            rhs_scale,
        )
2854+
2855+
def _verify_scaling_block_size(ty: TileTy, scale_ty: TileTy, k_axis: int,
                               name: str, scale_name: str):
    """Validate that a scale tile's shape is compatible with its operand tile.

    Two conditions are checked:

    1. The scale tile has the same rank as the operand tile and every
       dimension other than ``k_axis`` is equal.
    2. Along ``k_axis`` the operand extent is an exact multiple of the scale
       extent, and the quotient (the scaling block size) is one of the values
       returned by ``datatype._get_mma_scaled_scaling_block_sizes`` for this
       (dtype, scale dtype) pair.

    Args:
        ty: Tile type of the operand.
        scale_ty: Tile type of the scale that accompanies the operand.
        k_axis: Axis of the operand treated as K; may be negative, it is
            normalized against the operand's rank.
        name: Name of the operand, used in error messages.
        scale_name: Name of the scale operand, used in error messages.

    Raises:
        TileTypeError: If the ranks differ, a non-K dimension differs, the K
            extents do not divide evenly, or the block size is not allowed.
    """
    shape = ty.shape_value
    dtype = ty.dtype
    scale_shape = scale_ty.shape_value
    scale_dtype = scale_ty.dtype
    k_axis = normalize_axis(k_axis, len(shape))

    # Compare ranks explicitly first: relying on zip(strict=True) alone would
    # surface a rank mismatch as a bare ValueError (or miss it entirely if an
    # earlier dimension already differs) instead of a TileTypeError.
    rank_mismatch = len(scale_shape) != len(shape)
    if rank_mismatch or any(
            dim != scale_dim
            for axis, (dim, scale_dim) in enumerate(zip(shape, scale_shape, strict=True))
            if axis != k_axis):
        raise TileTypeError(
            f"{ scale_name } shape { scale_shape } is not compatible with { name } shape { shape } . "
            f"All dimensions except K axis { k_axis } must match")

    allowed = datatype._get_mma_scaled_scaling_block_sizes(dtype, scale_dtype)
    scaling_block_size, rem = divmod(shape[k_axis], scale_shape[k_axis])
    if rem != 0 or scaling_block_size not in allowed:
        raise TileTypeError(
            f"For mma_scaled with dtype={ dtype } , scale_dtype={ scale_dtype } : "
            f"{ name } .shape[{ k_axis } ] must be an exact multiple of { scale_name } .shape[{ k_axis } ] "
            f"with scaling block size B = K // K_s in { set (allowed )} , "
            f"got { name } .shape[{ k_axis } ] = { shape [k_axis ]} and "
            f"{ scale_name } .shape[{ k_axis } ] = { scale_shape [k_axis ]} ")
2877+
2878+
@impl(ct.mma_scaled, min_version=BytecodeVersion.V_13_3)
def mma_scaled_impl(x: Var, x_scale: Var, y: Var, y_scale: Var, acc: Var) -> Var:
    """Type-check and emit a ``TileMmaScaled`` operation for ``ct.mma_scaled``.

    Steps, in order:

    1. Require every argument to be a tile and to have rank 2 or 3.
    2. Hand the five dtypes to ``datatype._resolve_mma_scaled_supported_dtype``
       (presumably rejects unsupported combinations -- confirm in ``datatype``).
    3. Validate each scale against its operand with
       ``_verify_scaling_block_size`` (K axis ``-1`` for ``x``, ``-2`` for ``y``).
    4. Compute the broadcast shapes of ``x``/``y`` and require ``acc`` to have
       exactly the resulting output shape.
    5. Broadcast ``x``, ``y`` and both scales to the common batch dimensions
       and add the operation, typed like ``acc``.

    Raises:
        TileTypeError: If an argument is not 2D/3D or ``acc`` does not have
            the expected output shape. The helpers called here may raise too.
    """
    lhs_ty = require_tile_type(x)
    rhs_ty = require_tile_type(y)
    acc_ty = require_tile_type(acc)
    lhs_scale_ty = require_tile_type(x_scale)
    rhs_scale_ty = require_tile_type(y_scale)

    # Insertion order fixes which argument is reported first on failure.
    shapes_by_name = {
        "x": lhs_ty.shape,
        "y": rhs_ty.shape,
        "acc": acc_ty.shape,
        "x_scale": lhs_scale_ty.shape,
        "y_scale": rhs_scale_ty.shape,
    }
    for name, shape in shapes_by_name.items():
        if not 2 <= len(shape) <= 3:
            raise TileTypeError(
                f'Expect shape of `{ name } ` to be 2D or 3D, got { shape } ')

    datatype._resolve_mma_scaled_supported_dtype(
        lhs_ty.dtype,
        lhs_scale_ty.dtype,
        rhs_ty.dtype,
        rhs_scale_ty.dtype,
        acc_ty.dtype,
    )
    _verify_scaling_block_size(
        lhs_ty, lhs_scale_ty, k_axis=-1, name="x", scale_name="x_scale")
    _verify_scaling_block_size(
        rhs_ty, rhs_scale_ty, k_axis=-2, name="y", scale_name="y_scale")

    lhs_shape, rhs_shape, _, output_shape = _matmul_broadcast_shape(
        lhs_ty.shape, rhs_ty.shape)
    if acc_ty.shape != output_shape:
        raise TileTypeError(f'Expect acc shape to be { output_shape } , got { acc_ty .shape } ')

    # The scales take the batch dims of the broadcast x shape and keep their
    # own trailing two dims.
    batch_dims = lhs_shape.value_types[:-2]
    lhs_scale_shape = TupleTy(batch_dims + lhs_scale_ty.shape.value_types[-2:])
    rhs_scale_shape = TupleTy(batch_dims + rhs_scale_ty.shape.value_types[-2:])

    # Promotions are issued in the order x, y, x_scale, y_scale.
    promotion_plan = (
        (x, lhs_ty.dtype, lhs_shape),
        (y, rhs_ty.dtype, rhs_shape),
        (x_scale, lhs_scale_ty.dtype, lhs_scale_shape),
        (y_scale, rhs_scale_ty.dtype, rhs_scale_shape),
    )
    lhs, rhs, lhs_scale, rhs_scale = [
        _promote_and_broadcast_to(var, TileTy(dtype, target_shape))
        for var, dtype, target_shape in promotion_plan
    ]
    return add_operation(
        TileMmaScaled,
        acc_ty,
        x=lhs,
        x_scale=lhs_scale,
        y=rhs,
        y_scale=rhs_scale,
        acc=acc,
    )
2917+
2918+
28362919@dataclass (eq = False )
28372920class TileReduce (Operation , opcode = "tile_reduce" ):
28382921 identities : tuple [bool | int | float , ...] = attribute ()
0 commit comments