66import operator
77from contextlib import contextmanager
88from dataclasses import dataclass
9- from typing import Sequence , Tuple , Optional , Union , Any , List , Callable , Iterator
9+ from typing import Sequence , Tuple , Optional , Union , Any , List , Callable , Iterator , Iterable
1010
1111from typing_extensions import override
1212
6161 DType , is_integral , is_float , is_signed , is_boolean , is_restricted_float ,
6262)
6363from cuda .tile ._ir2bytecode import (
64- lower_reduce ,
65- lower_reduce_argmax_argmin , lower_scan ,
66- BytecodeContext , typeid ,
64+ lower_scan , BytecodeContext , typeid ,
6765 generate_bytecode_for_block , convert_dtype , get_list_item_repr_size_in_words ,
6866 get_list_partition_view_tile_size , tensor_view_typeid , tensor_view_typeid_for_list
6967)
@@ -2894,40 +2892,153 @@ def matmul(x: Var, y: Var) -> Var:
28942892
28952893
28962894class TileReduce (TypedOperation ):
2897- def __init__ (self , fn : str , x : Var , axis : int ,
2898- rounding_mode : Optional [RoundingMode ], flush_to_zero : bool ,
2899- result_var : Var , loc : Loc ):
2895+ def __init__ (self , xs : tuple [Var , ...], identities : tuple [bool | int | float , ...], axis : int ,
2896+ body : Block , result_vars : tuple [Var , ...], loc : Loc ):
29002897 super ().__init__ (
29012898 "tile_reduce" ,
2902- operands = {"x" : x },
2903- attributes = {
2904- "fn" : fn , "axis" : axis ,
2905- "rounding_mode" : rounding_mode ,
2906- "flush_to_zero" : flush_to_zero ,
2907- },
2908- result_vars = [result_var ],
2899+ operands = {"xs" : xs },
2900+ attributes = {"identities" : identities , "axis" : axis },
2901+ nested_blocks = [body ],
2902+ result_vars = result_vars ,
29092903 loc = loc ,
29102904 )
29112905
2906+ @property
2907+ def body (self ):
2908+ return self .nested_blocks [0 ]
2909+
29122910 @override
2913- def generate_bytecode (self , ctx : BytecodeContext ) -> bc .Value :
2914- x_type = ctx .typeof (self .x )
2915- x_value = ctx .get_value (self .x )
2916- res_type = ctx .typeof (self .result_var )
2917- return lower_reduce (
2918- ctx , x_value , x_type , self .axis , res_type , self .fn ,
2919- self .rounding_mode , self .flush_to_zero
2911+ def _to_string_block_prefixes (self ) -> List [str ]:
2912+ return ["do" ]
2913+
2914+ @override
2915+ def generate_bytecode (self , ctx : BytecodeContext ) -> tuple [bc .Value , ...]:
2916+ xs = tuple (ctx .get_value (x ) for x in self .xs )
2917+ res_typeids = tuple (ctx .typeid_of (v ) for v in self .result_vars )
2918+
2919+ identities = []
2920+ param_type_ids = []
2921+ for id_val , x in zip (self .identities , self .xs , strict = True ):
2922+ x_dtype = get_dtype (x .get_type ())
2923+ x_dtype_id = typeid (ctx .type_table , x_dtype , wrap_scalars = False )
2924+ if datatype .is_float (x_dtype ):
2925+ x_dtype_bc = x_dtype ._bytecode_type
2926+ attr = bc .Float (float (id_val ), x_dtype_bc , ctx .type_table )
2927+ elif datatype .is_boolean (x_dtype ):
2928+ attr = bc .Bool (bool (id_val ))
2929+ else :
2930+ assert datatype .is_integral (x_dtype )
2931+ attr = bc .Integer (x_dtype_id , x_dtype .bitwidth , int (id_val ))
2932+ identities .append (attr )
2933+
2934+ x_tile_typeid = ctx .type_table .tile (x_dtype_id , ())
2935+ param_type_ids .append (x_tile_typeid )
2936+ param_type_ids .append (x_tile_typeid )
2937+
2938+ nested_builder = bc .encode_ReduceOp (
2939+ ctx .builder ,
2940+ result_types = res_typeids ,
2941+ operands = xs ,
2942+ dim = self .axis ,
2943+ identities = identities
29202944 )
29212945
2946+ with nested_builder .new_block (param_type_ids ) as block_args :
2947+ for var , value in zip (self .body .params , block_args , strict = True ):
2948+ ctx .set_value (var , value )
2949+ generate_bytecode_for_block (ctx , self .body )
2950+
2951+ return nested_builder .done ()
2952+
2953+
2954+ def raw_reduce (xs : tuple [Var , ...], identities : tuple [bool | int | float ], axis : int ,
2955+ body : Callable [[tuple [Var , ...], tuple [Var , ...]], tuple [Var , ...]]
2956+ ) -> tuple [Var , ...]:
2957+ builder = Builder .get_current ()
2958+
2959+ block_params = []
2960+ lhs_vars = []
2961+ rhs_vars = []
2962+ input_shape = ()
2963+ for i , x in enumerate (xs ):
2964+ x_ty = x .get_type ()
2965+ assert isinstance (x_ty , TileTy )
2966+ if i == 0 :
2967+ input_shape = x_ty .shape_value
2968+ else :
2969+ assert input_shape == x_ty .shape_value
2970+ tile_0d_ty = make_tile_ty (x_ty .dtype , ())
2971+ for _ in range (2 ):
2972+ var = builder .ir_ctx .make_temp (builder .loc )
2973+ var .set_type (tile_0d_ty )
2974+ block_params .append (var )
2975+ lhs_vars .append (block_params [- 2 ])
2976+ rhs_vars .append (block_params [- 1 ])
2977+
2978+ assert 0 <= axis < len (input_shape )
2979+ result_shape = input_shape [:axis ] + input_shape [axis + 1 :]
2980+ result_types = tuple (make_tile_ty (x .get_type ().dtype , result_shape ) for x in xs )
2981+
2982+ assert len (xs ) == len (identities )
2983+
2984+ with nested_block (builder .loc ) as body_block :
2985+ body_block .params = tuple (block_params )
2986+ body_results = body (tuple (lhs_vars ), tuple (rhs_vars ))
2987+ for body_res , x in zip (body_results , xs , strict = True ):
2988+ body_res_ty = body_res .get_type ()
2989+ assert isinstance (body_res_ty , TileTy )
2990+ assert body_res_ty .shape_value == ()
2991+ assert body_res_ty .dtype == x .get_type ().dtype
2992+
2993+ add_operation (EndBranch , (), outputs = body_results )
2994+
2995+ return add_operation (TileReduce , result_types , xs = xs , identities = identities , axis = axis ,
2996+ body = body_block )
2997+
2998+
2999+ def reduce (xs : tuple [Var , ...], identities : tuple [bool | int | float , ...],
3000+ axis : int | None | Iterable [int ], keepdims : bool ,
3001+ body : Callable [[tuple [Var , ...], tuple [Var , ...]], tuple [Var , ...]]
3002+ ) -> tuple [Var , ...]:
3003+ if len (xs ) == 0 :
3004+ raise TileTypeError ("Need at least one input value to reduce" )
3005+
3006+ if len (xs ) != len (identities ):
3007+ raise TileTypeError (f"Number of input values ({ len (xs )} ) doesn't match the"
3008+ f" number of identities ({ len (identities )} )" )
3009+
3010+ common_input_shape = ()
3011+
3012+ x_types = tuple (require_tile_type (x ) for x in xs )
3013+ for x_ty in x_types :
3014+ try :
3015+ common_input_shape = broadcast_shapes2 (common_input_shape , x_ty .shape_value )
3016+ except BroadcastError :
3017+ all_shapes = ", " .join (str (ty .shape_value ) for ty in x_types )
3018+ raise TileTypeError (f"Input shapes { all_shapes } "
3019+ f" are not broadcastable to a common shape" )
3020+
3021+ if axis is None :
3022+ axis = tuple (range (len (common_input_shape )))
3023+ else :
3024+ if isinstance (axis , int ):
3025+ axis = (axis ,)
3026+ axis = sorted (normalize_axis (a , len (common_input_shape )) for a in axis )
3027+ for a1 , a2 in zip (axis , axis [1 :]):
3028+ if a1 == a2 :
3029+ raise TileTypeError (f"Repeated reduction axis { a1 } " )
3030+
3031+ xs = tuple (broadcast_to (x , common_input_shape ) for x in xs )
3032+ for i , a in enumerate (axis ):
3033+ xs = raw_reduce (xs , identities , a - i , body )
3034+
3035+ result_shape = _get_reduction_shape (common_input_shape , axis , keepdims )
3036+ return tuple (reshape (x , result_shape ) for x in xs )
3037+
29223038
29233039def _get_reduction_shape (shape : Tuple [int , ...],
2924- normalized_axis : int | Tuple [int , ...] | None ,
3040+ normalized_axis : Tuple [int , ...],
29253041 keepdims : bool ) -> Tuple [int , ...]:
2926- if normalized_axis is None :
2927- normalized_axis = tuple (range (len (shape )))
2928- if isinstance (normalized_axis , int ):
2929- normalized_axis = (normalized_axis ,)
2930- normalized_axis = set (normalized_axis )
29313042 ret = []
29323043 for i , size in enumerate (shape ):
29333044 if i in normalized_axis :
@@ -2938,29 +3049,46 @@ def _get_reduction_shape(shape: Tuple[int, ...],
29383049 return tuple (ret )
29393050
29403051
2941- def reduce (fn : str , x : Var , axis : Optional [ tuple [int , ...] ], keepdims : bool ,
2942- rounding_mode : Optional [RoundingMode ] = None ,
2943- flush_to_zero : bool = False ) -> Var :
3052+ def reduce_simple (fn : str , x : Var , axis : int | None | tuple [int , ...], keepdims : bool ,
3053+ rounding_mode : Optional [RoundingMode ] = None ,
3054+ flush_to_zero : bool = False ) -> Var :
29443055 x_type = require_tile_type (x )
29453056 check_rd_and_ftz (fn , rounding_mode , flush_to_zero , x_type .dtype )
2946- x_shape = x_type .shape
2947- rank = len (x_shape )
2948- if axis is None :
2949- axis = tuple (range (rank ))
2950- else :
2951- axis = tuple ([normalize_axis (axis_value , rank ) for axis_value in axis ])
29523057
2953- x_dtype = datatype .default_int_type if datatype .is_boolean (x_type .dtype ) else x_type .dtype
2954- x = _promote_and_broadcast_to (x , TileTy (x_dtype , x_shape ))
2955- for i , axis_value in enumerate (axis ):
2956- axis_value -= i
2957- x_shape = x_shape [:axis_value ] + x_shape [axis_value + 1 :]
2958- x = add_operation (
2959- TileReduce , TileTy (x_dtype , TupleTy (x_shape )),
2960- fn = fn , x = x , axis = axis_value ,
2961- rounding_mode = rounding_mode , flush_to_zero = flush_to_zero
2962- )
2963- return reshape (x , _get_reduction_shape (x_type .shape_value , axis , keepdims ))
3058+ if datatype .is_boolean (x_type .dtype ):
3059+ x = astype (x , datatype .default_int_type )
3060+
3061+ match fn :
3062+ case "add" : id_val = 0
3063+ case "mul" : id_val = 1
3064+ case "min" : id_val = _get_min_max (x_type .dtype )[1 ]
3065+ case "max" : id_val = _get_min_max (x_type .dtype )[0 ]
3066+ case _: assert False
3067+
3068+ def body (lhs : tuple [Var ], rhs : tuple [Var ]) -> tuple [Var ]:
3069+ [lhs ], [rhs ] = lhs , rhs
3070+ ret = raw_binary_arithmetic (fn , lhs , rhs ,
3071+ rounding_mode = rounding_mode , flush_to_zero = flush_to_zero )
3072+ return (ret ,)
3073+
3074+ [ret ] = reduce ((x ,), (id_val ,), axis , keepdims , body )
3075+ return ret
3076+
3077+
3078+ Limits = Tuple [float , float ] | Tuple [int , int ]
3079+
3080+
3081+ def _get_min_max (dtype : datatype .DType ) -> Limits :
3082+ use_float = datatype .is_float (dtype )
3083+ if use_float :
3084+ if dtype in [datatype .float16 , datatype .bfloat16 , datatype .float32 , datatype .float64 ]:
3085+ return - float ("inf" ), float ("inf" )
3086+ else :
3087+ raise NotImplementedError (f"Unsupported float dtype: { dtype } " )
3088+ elif datatype .is_signed (dtype ):
3089+ return - (1 << (dtype .bitwidth - 1 )), (1 << (dtype .bitwidth - 1 )) - 1
3090+ else :
3091+ return 0 , (1 << dtype .bitwidth ) - 1
29643092
29653093
29663094def _parse_reduce_axis (axis : Var ) -> Optional [tuple [int , ...]]:
@@ -2981,7 +3109,8 @@ def reduce_impl_with_rd_and_ftz(fn: str, x: Var, axis: Var, keepdims: Var, round
29813109 keepdims = require_constant_bool (keepdims )
29823110 rounding_mode = require_optional_constant_enum (rounding_mode , RoundingMode )
29833111 flush_to_zero = require_constant_bool (flush_to_zero )
2984- return reduce (fn , x , axis , keepdims , rounding_mode = rounding_mode , flush_to_zero = flush_to_zero )
3112+ return reduce_simple (fn , x , axis , keepdims ,
3113+ rounding_mode = rounding_mode , flush_to_zero = flush_to_zero )
29853114
29863115
29873116@impl (ct .max , fixed_args = ["max" ])
@@ -2990,53 +3119,63 @@ def reduce_impl_with_ftz(fn: str, x: Var, axis: Var, keepdims: Var, flush_to_zer
29903119 axis = _parse_reduce_axis (axis )
29913120 keepdims = require_constant_bool (keepdims )
29923121 flush_to_zero = require_constant_bool (flush_to_zero )
2993- return reduce (fn , x , axis , keepdims , flush_to_zero = flush_to_zero )
3122+ return reduce_simple (fn , x , axis , keepdims , flush_to_zero = flush_to_zero )
29943123
29953124
2996- class TileArgReduce (TypedOperation ):
2997- def __init__ (self , fn : str , x : Var , axis : Optional [int ],
2998- result_var : Var , loc : Loc ):
2999- super ().__init__ (
3000- "tile_arg_reduce" ,
3001- operands = {"x" : x },
3002- attributes = {"fn" : fn , "axis" : axis },
3003- result_vars = [result_var ],
3004- loc = loc ,
3005- )
3006-
3007- @override
3008- def generate_bytecode (self , ctx : BytecodeContext ) -> bc .Value :
3009- x_type = ctx .typeof (self .x )
3010- x_value = ctx .get_value (self .x )
3011- res_type = ctx .typeof (self .result_var )
3012- return lower_reduce_argmax_argmin (
3013- ctx , x_value , x_type , self .axis , res_type , self .fn
3014- )
3125+ def argmax_argmin (fn : str , x : Var , axis : Optional [int ], keepdims : bool ) -> Var :
3126+ require_tile_type (x )
3127+ final_shape = None
3128+ if axis is None :
3129+ if keepdims :
3130+ final_shape = (1 ,) * x .get_type ().ndim
3131+ keepdims = False
3132+ x = reshape (x , (- 1 ,))
3133+ axis = 0
3134+ else :
3135+ axis = normalize_axis (axis , x .get_type ().ndim )
30153136
3137+ if datatype .is_boolean (x .get_type ().dtype ):
3138+ x = astype (x , datatype .default_int_type )
30163139
3017- def argreduce (fn : str , x : Var , axis : Optional [int ], keepdims : bool ) -> Var :
3018- x_type = require_tile_type (x )
3019- x_shape = x_type .shape
3020- if axis is not None :
3021- axis = normalize_axis (axis , len (x_shape ))
3140+ x_type = x .get_type ()
3141+ indices = arange (x_type .shape_value [axis ], datatype .default_int_type )
3142+ indices = reshape (indices , tuple (- 1 if i == axis else 1 for i in range (x_type .ndim )))
3143+
3144+ match fn :
3145+ case "argmin" :
3146+ id_val = _get_min_max (x_type .dtype )[1 ]
3147+ cmp = "lt"
3148+ case "argmax" :
3149+ id_val = _get_min_max (x_type .dtype )[0 ]
3150+ cmp = "gt"
3151+ case _: assert False
3152+
3153+ def body (lhs : tuple [Var , Var ], rhs : tuple [Var , Var ]) -> tuple [Var , Var ]:
3154+ lhs_val , lhs_idx = lhs
3155+ rhs_val , rhs_idx = rhs
3156+ val_strict = raw_comparison (cmp , lhs_val , rhs_val )
3157+ val_equal = raw_comparison ("eq" , lhs_val , rhs_val )
3158+ index_lt = raw_comparison ("lt" , lhs_idx , rhs_idx )
3159+ val_equal_and_index_lt = raw_binary_bitwise ("and_" , val_equal , index_lt )
3160+ cond = raw_binary_bitwise ("or_" , val_strict , val_equal_and_index_lt )
3161+ res = raw_where (cond , lhs_val , rhs_val )
3162+ idx = raw_where (cond , lhs_idx , rhs_idx )
3163+ return res , idx
3164+
3165+ [_ , ret ] = reduce ((x , indices ), (id_val , 0 ), axis , keepdims , body )
3166+
3167+ if final_shape is not None :
3168+ ret = reshape (ret , final_shape )
30223169
3023- x_dtype = datatype .default_int_type if datatype .is_boolean (x_type .dtype ) else x_type .dtype
3024- x = _promote_and_broadcast_to (x , TileTy (x_dtype , x_shape ))
3025- output_dtype = datatype .default_int_type
3026- output_shape = TupleTy ([]) if axis is None else TupleTy (x_shape [:axis ] + x_shape [axis + 1 :])
3027- x = add_operation (
3028- TileArgReduce , TileTy (output_dtype , output_shape ),
3029- fn = fn , x = x , axis = axis
3030- )
3031- return reshape (x , _get_reduction_shape (x_type .shape_value , axis , keepdims ))
3170+ return ret
30323171
30333172
30343173@impl (ct .argmax , fixed_args = ["argmax" ])
30353174@impl (ct .argmin , fixed_args = ["argmin" ])
3036- def argreduce_impl (fn : str , x : Var , axis : Var , keepdims : Var ) -> Var :
3175+ def argmax_argmin_impl (fn : str , x : Var , axis : Var , keepdims : Var ) -> Var :
30373176 axis = require_optional_constant_int (axis )
30383177 keepdims = require_constant_bool (keepdims )
3039- return argreduce (fn , x , axis , keepdims )
3178+ return argmax_argmin (fn , x , axis , keepdims )
30403179
30413180
30423181class TileScan (TypedOperation ):
0 commit comments