5353 BYTE_BITWIDTH , typeof_pyval , dtype_registry , loose_type_of_pyval , get_constant_value
5454)
5555from .type import (
56- TupleTy , TileTy , NoneType , BoundMethodTy , ArrayTy ,
56+ PartitionViewTy , TupleTy , TileTy , NoneType , BoundMethodTy , ArrayTy ,
5757 ListTy , make_tile_ty , SliceType , DTypeConstructor , RangeIterType , Type ,
5858 NONE , ModuleTy , TypeTy , LooselyTypedScalar , DTypeSpec , StringTy , InvalidType ,
5959 array_size_type , ClosureTy , LiveCapturedScope , TokenTy ,
@@ -2012,26 +2012,43 @@ def _check_load_store_hints(latency_value: int | None, allow_tma_value: bool | N
20122012 raise TileTypeError (f"Allow TMA must be a boolean, got { allow_tma_value } " )
20132013
20142014
@dataclass(eq=False)
class MakePartitionView(Operation, opcode="make_partition_view"):
    """Operation whose result is a partition view built over an array operand.

    The type of the produced view is not stored on the operation itself; it is
    read back from the operation's result variable when bytecode is generated.
    """

    # The array variable the partition view is created from.
    array: Var = operand()

    @override
    def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
        """Encode this operation as a ``MakePartitionViewOp``.

        Args:
            ctx: Bytecode context supplying the builder, the type table and
                the mapping from variables to already-emitted values.

        Returns:
            The value returned by ``bc.encode_MakePartitionViewOp``.
        """
        # Resolve the result type to its id in the type table first, then
        # look up the value of the array operand.
        view_type_id = typeid(ctx.type_table, self.result_var.get_type())
        array_value = ctx.get_value(self.array)
        return bc.encode_MakePartitionViewOp(ctx.builder, view_type_id, array_value)
2025+
2026+
def make_partition_view(array: Var, tile_shape: Sequence[int],
                        order: Sequence[int],
                        padding_mode: PaddingMode) -> Var:
    """Add a ``MakePartitionView`` operation for ``array`` and return its result.

    Args:
        array: Variable whose type must be an ``ArrayTy``.
        tile_shape: Tile shape recorded in the view type (stored as a tuple).
        order: Axis order recorded in the view type (stored as a tuple).
        padding_mode: Padding mode recorded in the view type.

    Returns:
        The result of ``add_operation`` for the new ``MakePartitionView``
        operation, typed as the constructed ``PartitionViewTy``.
    """
    source_ty = array.get_type()
    # Internal invariant: only array-typed variables may be viewed.
    assert isinstance(source_ty, ArrayTy)
    result_ty = PartitionViewTy(
        source_ty,
        tuple(tile_shape),
        tuple(order),
        padding_mode,
    )
    return add_operation(MakePartitionView, result_ty, array=array)
2034+
2035+
20152036@dataclass (eq = False )
20162037class TileLoad (Operation , opcode = "tile_load" , memory_effect = MemoryEffect .LOAD ):
2017- order : tuple [int , ...] = attribute ()
2018- padding_mode : PaddingMode = attribute ()
20192038 latency : Optional [int ] = attribute ()
20202039 allow_tma : Optional [bool ] = attribute ()
2021- array : Var = operand ()
2040+ view : Var = operand ()
20222041 index : tuple [Var , ...] = operand ()
20232042 token : Optional [Var ] = operand (default = None )
20242043
20252044 @override
20262045 def generate_bytecode (self , ctx : BytecodeContext ) -> tuple [bc .Value , bc .Value ]:
20272046 tile_type : TileTy = self .result_vars [0 ].get_type ()
2028- shape = tile_type .shape
2029- partition = ctx .make_partition_view (self .array , self .order , shape , self .padding_mode )
20302047 res , res_token = bc .encode_LoadViewTkoOp (
20312048 ctx .builder ,
20322049 tile_type = typeid (ctx .type_table , tile_type ),
20332050 result_token_type = ctx .type_table .Token ,
2034- view = partition ,
2051+ view = ctx . get_value ( self . view ) ,
20352052 index = ctx .index_tuple (self .index ),
20362053 token = None if self .token is None else ctx .get_value (self .token ),
20372054 memory_ordering_semantics = bc .MemoryOrderingSemantics .WEAK ,
@@ -2041,16 +2058,6 @@ def generate_bytecode(self, ctx: BytecodeContext) -> tuple[bc.Value, bc.Value]:
20412058 return res , res_token
20422059
20432060
2044- def tile_load (array : Var , index : tuple [Var , ...], shape : Sequence [int ], order : Sequence [int ],
2045- padding_mode : PaddingMode , latency : Optional [int ],
2046- allow_tma : Optional [bool ]) -> tuple [Var , Var ]:
2047- res_ty = make_tile_ty (array .get_type ().dtype , shape )
2048- return add_operation (TileLoad , (res_ty , TokenTy ()),
2049- array = array , index = index , order = tuple (order ),
2050- padding_mode = padding_mode , latency = latency ,
2051- allow_tma = allow_tma )
2052-
2053-
20542061@impl (ct .load )
20552062def tile_load_impl (array : Var , index : Var , shape : Var , order : Var ,
20562063 padding_mode : Var , latency : Var , allow_tma : Var ) -> Var :
@@ -2070,32 +2077,31 @@ def tile_load_impl(array: Var, index: Var, shape: Var, order: Var,
20702077 latency = require_optional_constant_int (latency )
20712078 allow_tma = require_optional_constant_bool (allow_tma )
20722079 _check_load_store_hints (latency , allow_tma )
2073- result , _token = tile_load (array , index_items , broadcasted_shape , order , padding_mode , latency ,
2074- allow_tma )
2080+
2081+ view = make_partition_view (array , broadcasted_shape , order , padding_mode )
2082+ res_ty = make_tile_ty (array_ty .dtype , broadcasted_shape )
2083+ result , _token = add_operation (TileLoad , (res_ty , TokenTy ()),
2084+ view = view , index = index_items , latency = latency ,
2085+ allow_tma = allow_tma )
20752086 return reshape (result , shape )
20762087
20772088
20782089@dataclass (eq = False )
20792090class TileStore (Operation , opcode = "tile_store" , memory_effect = MemoryEffect .STORE ):
2080- order : tuple [int , ...] = attribute ()
20812091 latency : Optional [int ] = attribute ()
20822092 allow_tma : Optional [bool ] = attribute ()
2083- array : Var = operand ()
2093+ view : Var = operand ()
20842094 index : tuple [Var , ...] = operand ()
20852095 tile : Var = operand ()
20862096 token : Optional [Var ] = operand (default = None )
20872097
20882098 @override
20892099 def generate_bytecode (self , ctx : BytecodeContext ) -> bc .Value :
2090- tile_ty = self .tile .get_type ()
2091- tile_shape = tile_ty .shape
2092- partition = ctx .make_partition_view (self .array , self .order , tile_shape ,
2093- padding_mode = PaddingMode .UNDETERMINED )
20942100 return bc .encode_StoreViewTkoOp (
20952101 ctx .builder ,
20962102 result_token_type = ctx .type_table .Token ,
20972103 tile = ctx .get_value (self .tile ),
2098- view = partition ,
2104+ view = ctx . get_value ( self . view ) ,
20992105 index = ctx .index_tuple (self .index ),
21002106 token = None if self .token is None else ctx .get_value (self .token ),
21012107 memory_ordering_semantics = bc .MemoryOrderingSemantics .WEAK ,
@@ -2104,17 +2110,6 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
21042110 )
21052111
21062112
2107- def tile_store (array : Var , index : tuple [Var , ...], tile : Var , order : Sequence [int ],
2108- latency : Optional [int ], allow_tma : Optional [bool ]) -> Var :
2109- array_ty = array .get_type ()
2110- ty = require_tile_type (tile )
2111- tile = astype (tile , array_ty .dtype )
2112- if ty .ndim == 0 :
2113- tile = reshape (tile , (1 ,) * array_ty .ndim )
2114- return add_operation (TileStore , (TokenTy (),), array = array , index = index , tile = tile ,
2115- order = tuple (order ), latency = latency , allow_tma = allow_tma )
2116-
2117-
21182113def _implicit_cast (src : Var , target_dtype : DType , error_context : str ) -> Var :
21192114 ty = require_tile_maybe_loose_type (src )
21202115 try :
@@ -2130,20 +2125,25 @@ def _implicit_cast(src: Var, target_dtype: DType, error_context: str) -> Var:
def tile_store_impl(array: Var, index: Var, tile: Var, order: Var,
                    latency: Var, allow_tma: Var):
    """Lower a tile store into a partition view plus a ``TileStore`` operation.

    The array, tile and index arguments are type-checked, the number of index
    items is required to equal the array rank, and the order/latency/allow_tma
    arguments are required to be constants. The tile is implicitly cast to the
    array's dtype and reshaped to the partition tile shape; a 0-d tile uses a
    shape of all ones with the array's rank. The view is created with
    ``PaddingMode.UNDETERMINED``.

    Raises:
        TileTypeError: If the number of index items differs from the array
            rank. The ``require_*`` helpers and ``_check_load_store_hints``
            called here may raise as well.
    """
    array_ty = require_array_type(array)
    # NOTE(review): this check runs before `_implicit_cast`, which itself
    # validates its input with `require_tile_maybe_loose_type`. Confirm that
    # loosely typed values are still meant to be accepted by a store.
    tile_ty = require_tile_type(tile)
    index_ty = require_index_or_index_tuple_type(index)

    # A tuple index is unpacked into its items; a single index becomes a 1-tuple.
    if isinstance(index_ty, TupleTy):
        index_items = index.get_aggregate().items
    else:
        index_items = (index,)

    rank = array_ty.ndim
    if rank != len(index_items):
        raise TileTypeError(f"Index size {len(index_items)}"
                            f" does not match the array rank {array_ty.ndim}")

    # A 0-d tile is stored as a tile of ones with the array's rank.
    tile_shape = tile_ty.shape
    if len(tile_shape) == 0:
        broadcasted_shape = (1,) * array_ty.ndim
    else:
        broadcasted_shape = tile_shape

    axis_order = require_constant_axis_order(order, array_ty.ndim)
    latency_hint = require_optional_constant_int(latency)
    tma_hint = require_optional_constant_bool(allow_tma)
    _check_load_store_hints(latency_hint, tma_hint)

    casted = _implicit_cast(tile, array_ty.dtype,
                            "Stored tile is incompatible with array's dtype")
    reshaped = reshape(casted, broadcasted_shape)
    view = make_partition_view(array, broadcasted_shape, axis_order,
                               PaddingMode.UNDETERMINED)
    # The operation has exactly one result (the token), which is not used here.
    (_token,) = add_operation(
        TileStore,
        (TokenTy(),),
        view=view,
        index=index_items,
        tile=reshaped,
        latency=latency_hint,
        allow_tma=tma_hint,
    )
21472147
21482148
21492149@dataclass (eq = False )
@@ -2618,35 +2618,30 @@ def join_tokens(tokens: Tuple[Var, ...], *, block: Block, res: Var, loc: Loc) ->
@dataclass(eq=False)
class NumTiles(Operation, opcode="num_tiles"):
    """Operation yielding the index-space extent of a partition view along one axis."""

    # Position of the requested extent within the encoded op's results.
    axis: int = attribute()
    # Partition view variable that is queried.
    view: Var = operand()

    @override
    def generate_bytecode(self, ctx: BytecodeContext):
        """Encode a ``GetIndexSpaceShapeOp`` and return the result for ``axis``.

        One 0-d I32 tile result type is requested per entry of the view type's
        ``tile_shape``; the value at position ``self.axis`` is returned.
        """
        view_ty: PartitionViewTy = self.view.get_type()
        scalar_i32 = ctx.type_table.tile(ctx.type_table.I32, ())
        rank = len(view_ty.tile_shape)
        extents = bc.encode_GetIndexSpaceShapeOp(
            ctx.builder,
            [scalar_i32] * rank,
            src=ctx.get_value(self.view),
        )
        return extents[self.axis]
26322630
26332631
2634- def num_tiles (array : Var , axis : int , shape : Sequence [int ], order : Sequence [int ]) -> Var :
2635- return add_operation (NumTiles , make_tile_ty (datatype .default_int_type , ()), array = array ,
2636- axis = axis , shape = tuple (shape ), order = tuple (order ))
2637-
2638-
@impl(ct.num_tiles)
def num_tiles_impl(array: Var, axis: Var, shape: Var, order: Var) -> Var:
    """Lower ``ct.num_tiles`` into a partition view plus a ``NumTiles`` operation.

    ``axis`` must be a constant int and is normalized against the array rank.
    ``shape`` must be a constant shape (a single int or a 0-d shape is
    allowed); a 0-d shape is replaced by a shape of all ones with the array's
    rank. ``order`` must be a constant axis order. The partition view is
    created with ``PaddingMode.UNDETERMINED``.

    Returns:
        The result of the added ``NumTiles`` operation, typed as a 0-d tile of
        ``datatype.default_int_type``.
    """
    array_ty = require_array_type(array)
    rank = array_ty.ndim

    axis_index = normalize_axis(require_constant_int(axis), rank)
    tile_shape = require_constant_shape(shape, allow_single_int=True,
                                        expected_rank=array_ty.ndim,
                                        allow_0d_shape=True)
    if len(tile_shape) == 0:
        broadcasted_shape = (1,) * array_ty.ndim
    else:
        broadcasted_shape = tile_shape
    axis_order = require_constant_axis_order(order, array_ty.ndim)

    view = make_partition_view(array, broadcasted_shape, axis_order,
                               PaddingMode.UNDETERMINED)
    count_ty = make_tile_ty(datatype.default_int_type, ())
    return add_operation(NumTiles, count_ty, view=view, axis=axis_index)
26502645
26512646
26522647def full_const (shape : Sequence [int ], fill_value : int | float , dtype : DType ) -> Var :
0 commit comments