2323 add_operation , Builder ,
2424 enter_nested_block , nested_block , PhiState , LoopVarState ,
2525 TupleValue , make_aggregate , RangeValue , BoundMethodValue , ArrayValue , ConstantState ,
26- ListValue , ClosureValue , MemoryEffect , attribute , operand , BlockRestriction
26+ ListValue , TiledViewValue , ClosureValue , MemoryEffect , attribute , operand , BlockRestriction
2727)
2828from .type import PointerTy
2929from . import hir
3333 require_signed_integer_0d_tile_type ,
3434 require_tile_type , normalize_axis , require_dtype_spec ,
3535 require_constant_bool , require_optional_constant_enum ,
36- require_constant_str , require_array_type , require_tuple_type , require_constant_slice ,
37- require_list_type , require_0d_tile_type ,
36+ require_constant_str , require_array_type , require_tiled_view_type , require_tuple_type ,
37+ require_constant_slice , require_list_type , require_0d_tile_type ,
3838 require_index_or_index_tuple_type , require_constant_shape , require_constant_axis_order ,
3939 require_constant_enum , require_optional_constant_int , require_optional_constant_bool ,
4040 require_optional_constant_str , PrintfValidator , require_tile_maybe_loose_type ,
4343 require_callable_type )
4444from .ops_utils import (
4545 BINOP_REGISTRY , UNARYOP_REGISTRY ,
46- check_rd_and_ftz , PaddingMode ,
46+ check_rd_and_ftz , PaddingMode , get_default_order ,
4747 rounding_mode_to_bytecode , get_default_rounding_mode , get_dtype ,
4848 change_dtype , memory_order_to_bytecode ,
4949 memory_scope_to_bytecode , broadcast_shapes2 , is_shape_broadcastable_to , BroadcastError ,
5757 PartitionViewTy , TupleTy , TileTy , NoneType , BoundMethodTy , ArrayTy ,
5858 ListTy , make_tile_ty , SliceType , DTypeConstructor , RangeIterType , Type ,
5959 NONE , ModuleTy , TypeTy , LooselyTypedScalar , DTypeSpec , StringTy , InvalidType ,
60- array_size_type , ClosureTy , LiveCapturedScope , TokenTy ,
60+ array_size_type , ClosureTy , LiveCapturedScope , TokenTy , TiledViewTy
6161)
6262from cuda .tile ._datatype import (
6363 DType , is_integral , is_float , is_signed , is_boolean , is_restricted_float ,
@@ -1552,6 +1552,7 @@ def getattr_impl(object: Var, name: Var) -> Var:
15521552 case ArrayTy (), "shape" : return build_tuple (object .get_aggregate ().shape )
15531553 case ArrayTy (), "strides" : return build_tuple (object .get_aggregate ().strides )
15541554 case ArrayTy (), "slice" : return bind_method (object , ct ._m_array_slice )
1555+ case ArrayTy (), "tiled_view" : return bind_method (object , ct ._m_array_tiled_view )
15551556
15561557 case TileTy (), "dtype" : return loosely_typed_const (ty .dtype )
15571558 case TileTy (), "shape" : return loosely_typed_const (ty .shape )
@@ -1564,6 +1565,15 @@ def getattr_impl(object: Var, name: Var) -> Var:
15641565 case TileTy (), "transpose" : return bind_method (object , ct .transpose )
15651566 case TileTy (), "item" : return bind_method (object , ct ._m_tile_item )
15661567
1568+ case TiledViewTy (), "dtype" : return loosely_typed_const (ty .dtype )
1569+ case TiledViewTy (), "tile_shape" : return loosely_typed_const (ty .tile_shape )
1570+ case TiledViewTy (), "num_tiles" :
1571+ [array ] = object .get_aggregate ().as_tuple ()
1572+ return build_tuple (num_tiles (array , ty .tile_shape , get_default_order (ty .ndim )))
1573+
1574+ case TiledViewTy (), "load" : return bind_method (object , ct ._m_tiled_view_load )
1575+ case TiledViewTy (), "store" : return bind_method (object , ct ._m_tiled_view_store )
1576+
15671577 case ModuleTy (), _:
15681578 try :
15691579 return loosely_typed_const (getattr (ty .py_mod , attr_name ))
@@ -2087,32 +2097,38 @@ def generate_bytecode(self, ctx: BytecodeContext) -> tuple[bc.Value, bc.Value]:
20872097 return res , res_token
20882098
20892099
def _tile_load_impl_inner(array: Var, index_items: tuple[Var, ...], shape: Sequence[int],
                          order: Sequence[int], padding_mode: PaddingMode,
                          latency: Var, allow_tma: Var) -> Var:
    """Emit a ``TileLoad`` on a partition view of ``array`` and return the loaded tile.

    ``shape``, ``order`` and ``padding_mode`` arrive as plain Python values;
    ``latency`` and ``allow_tma`` are still ``Var`` objects and are resolved to
    optional constants here before being checked with ``_check_load_store_hints``.

    The result is reshaped to ``shape`` before it is returned.
    """
    src_ty = require_array_type(array)

    # An empty shape is loaded as an all-ones tile of the array's rank; the
    # final reshape brings it back to the requested (empty) shape.
    if len(shape) == 0:
        view_shape = (1,) * src_ty.ndim
    else:
        view_shape = shape

    latency_hint = require_optional_constant_int(latency)
    tma_hint = require_optional_constant_bool(allow_tma)
    _check_load_store_hints(latency_hint, tma_hint)

    partition = make_partition_view(array, view_shape, order, padding_mode)
    # The operation yields the tile plus a token; the token is not used here.
    loaded, _token = add_operation(
        TileLoad, (make_tile_ty(src_ty.dtype, view_shape), TokenTy()),
        view=partition, index=index_items, latency=latency_hint, allow_tma=tma_hint)
    return reshape(loaded, shape)
2116+
@impl(ct.load)
def tile_load_impl(array: Var, index: Var, shape: Var, order: Var,
                   padding_mode: Var, latency: Var, allow_tma: Var) -> Var:
    """Implement ``ct.load``: validate the arguments and delegate to ``_tile_load_impl_inner``.

    ``index`` may be a single index or a tuple of indices; either way the number
    of index items must equal the array rank, otherwise ``TileTypeError`` is
    raised. ``shape``, ``order`` and ``padding_mode`` are resolved to constants
    here; ``latency`` and ``allow_tma`` are passed through as ``Var`` objects.
    """
    src_ty = require_array_type(array)
    idx_ty = require_index_or_index_tuple_type(index)

    # Normalise a lone index into a one-element tuple.
    if isinstance(idx_ty, TupleTy):
        index_items = index.get_aggregate().items
    else:
        index_items = (index,)

    if len(index_items) != src_ty.ndim:
        raise TileTypeError(f"Index size {len(index_items)}"
                            f" does not match the array rank {src_ty.ndim}")

    tile_shape = require_constant_shape(shape, allow_single_int=True,
                                        expected_rank=src_ty.ndim, allow_0d_shape=True)
    axis_order = require_constant_axis_order(order, src_ty.ndim)
    padding = require_constant_enum(padding_mode, PaddingMode)
    return _tile_load_impl_inner(array, index_items, tile_shape, axis_order, padding,
                                 latency, allow_tma)
21162132
21172133
21182134@dataclass (eq = False )
@@ -2150,29 +2166,34 @@ def _implicit_cast(src: Var, target_dtype: DType, error_context: str) -> Var:
21502166 return astype (src , target_dtype )
21512167
21522168
def _tile_store_impl_inner(array: Var, index_items: tuple[Var, ...], tile: Var,
                           order: Sequence[int], latency: Var, allow_tma: Var):
    """Emit a ``TileStore`` that writes ``tile`` into a partition view of ``array``.

    ``order`` arrives as a plain sequence; ``latency`` and ``allow_tma`` are
    still ``Var`` objects and are resolved to optional constants here before
    being checked with ``_check_load_store_hints``. No dtype conversion happens
    in this helper. Nothing is returned.
    """
    dst_ty = require_array_type(array)
    src_tile_ty = require_tile_type(tile)

    # A tile with an empty shape is stored as an all-ones tile of the array's rank.
    if len(src_tile_ty.shape) == 0:
        view_shape = (1,) * dst_ty.ndim
    else:
        view_shape = src_tile_ty.shape

    latency_hint = require_optional_constant_int(latency)
    tma_hint = require_optional_constant_bool(allow_tma)
    _check_load_store_hints(latency_hint, tma_hint)

    reshaped = reshape(tile, view_shape)
    partition = make_partition_view(array, view_shape, order, PaddingMode.UNDETERMINED)
    # The operation yields a single token, which is not used here.
    [_token] = add_operation(TileStore, (TokenTy(),), view=partition, index=index_items,
                             tile=reshaped, latency=latency_hint, allow_tma=tma_hint)
2183+
@impl(ct.store)
def tile_store_impl(array: Var, index: Var, tile: Var, order: Var,
                    latency: Var, allow_tma: Var):
    """Implement ``ct.store``: validate the arguments and delegate to ``_tile_store_impl_inner``.

    ``index`` may be a single index or a tuple of indices; either way the number
    of index items must equal the array rank, otherwise ``TileTypeError`` is
    raised. The tile is implicitly cast to the array's dtype before it is
    handed to the helper.
    """
    dst_ty = require_array_type(array)
    idx_ty = require_index_or_index_tuple_type(index)

    # Normalise a lone index into a one-element tuple.
    if isinstance(idx_ty, TupleTy):
        index_items = index.get_aggregate().items
    else:
        index_items = (index,)

    if len(index_items) != dst_ty.ndim:
        raise TileTypeError(f"Index size {len(index_items)}"
                            f" does not match the array rank {dst_ty.ndim}")

    # Keep this sequence: the dtype cast runs before the axis order is resolved.
    casted = _implicit_cast(tile, dst_ty.dtype, "Stored tile is incompatible with array's dtype")
    axis_order = require_constant_axis_order(order, dst_ty.ndim)
    _tile_store_impl_inner(array, index_items, casted, axis_order, latency, allow_tma)
21762197
21772198
21782199@dataclass (eq = False )
@@ -2646,16 +2667,22 @@ def join_tokens(tokens: Tuple[Var, ...], *, block: Block, res: Var, loc: Loc) ->
26462667
@dataclass(eq=False)
class NumTiles(Operation, opcode="num_tiles"):
    """Operation that queries the index-space shape of a partition view.

    It produces one value per entry of the view's ``tile_shape``.
    """

    # The partition view whose index space is queried.
    view: Var = operand()

    @override
    def generate_bytecode(self, ctx: BytecodeContext):
        """Encode a ``GetIndexSpaceShapeOp`` and return all of its result values."""
        partition_ty: PartitionViewTy = self.view.get_type()
        # Every result has the same type: an I32 tile with an empty shape.
        i32_scalar = ctx.type_table.tile(ctx.type_table.I32, ())
        result_types = [i32_scalar] * len(partition_ty.tile_shape)
        return bc.encode_GetIndexSpaceShapeOp(ctx.builder, result_types,
                                              ctx.get_value(self.view))
2678+
2679+
def num_tiles(array: Var, shape: Sequence[int], order: Sequence[int]) -> Tuple[Var, ...]:
    """Add a ``NumTiles`` operation for ``array`` partitioned by ``shape`` and ``order``.

    Returns the operation's results: one value of the default integer type
    (with an empty shape) per dimension of the partition shape.
    """
    src_ty = require_array_type(array)

    # An empty shape is treated as an all-ones tile of the array's rank.
    if len(shape) == 0:
        view_shape = (1,) * src_ty.ndim
    else:
        view_shape = shape

    partition = make_partition_view(array, view_shape, order, PaddingMode.UNDETERMINED)
    count_tys = tuple(make_tile_ty(datatype.default_int_type, ())
                      for _ in range(len(view_shape)))
    return add_operation(NumTiles, count_tys, view=partition)
26592686
26602687
26612688@impl (ct .num_tiles )
@@ -2665,12 +2692,9 @@ def num_tiles_impl(array: Var, axis: Var, shape: Var, order: Var) -> Var:
26652692 axis = normalize_axis (axis , array_ty .ndim )
26662693 shape = require_constant_shape (shape , allow_single_int = True , expected_rank = array_ty .ndim ,
26672694 allow_0d_shape = True )
2668- broadcasted_shape = (1 ,) * array_ty .ndim if len (shape ) == 0 else shape
26692695 order = require_constant_axis_order (order , array_ty .ndim )
2670-
2671- view = make_partition_view (array , broadcasted_shape , order , PaddingMode .UNDETERMINED )
2672- return add_operation (NumTiles , make_tile_ty (datatype .default_int_type , ()), view = view ,
2673- axis = axis )
2696+ space_shape = num_tiles (array , shape , order )
2697+ return space_shape [axis ]
26742698
26752699
26762700def full_const (shape : Sequence [int ], fill_value : int | float , dtype : DType ) -> Var :
@@ -4009,6 +4033,54 @@ def tile_item(tile: Var) -> Var:
40094033 return reshape (tile , ())
40104034
40114035
@impl(ct._m_array_tiled_view)
def array_tiled_view_impl(array: Var, tile_shape: Var, padding_mode: Var) -> Var:
    """Implement the array ``tiled_view`` method: wrap ``array`` in a tiled-view aggregate.

    ``tile_shape`` must be a constant shape (a single int or an empty shape is
    accepted) matching the array rank, and ``padding_mode`` a constant
    ``PaddingMode``. Both are recorded in the resulting ``TiledViewTy``; the
    aggregate value itself only holds the array.
    """
    src_ty = require_array_type(array)
    const_shape = require_constant_shape(
        tile_shape,
        allow_single_int=True,
        expected_rank=src_ty.ndim,
        allow_0d_shape=True,
    )
    const_padding = require_constant_enum(padding_mode, PaddingMode)
    tiled_ty = TiledViewTy(src_ty, const_shape, const_padding)
    return make_aggregate(TiledViewValue(array), tiled_ty)
4045+
4046+
@impl(ct._m_tiled_view_load)
def tiled_view_load_impl(tiled_view: Var, index: Var, latency: Var, allow_tma: Var) -> Var:
    """Implement the tiled view ``load`` method by delegating to ``_tile_load_impl_inner``.

    The tile shape and padding mode come from the view's type, and the axis
    order is the default order for the view's rank. Raises ``TileTypeError``
    when the number of index items differs from the view's rank.
    """
    tiled_ty = require_tiled_view_type(tiled_view)
    idx_ty = require_index_or_index_tuple_type(index)

    # Normalise a lone index into a one-element tuple.
    if isinstance(idx_ty, TupleTy):
        index_items = index.get_aggregate().items
    else:
        index_items = (index,)

    if len(index_items) != tiled_ty.ndim:
        raise TileTypeError(f"Index size {len(index_items)}"
                            f" does not match the tiled view rank {tiled_ty.ndim}")

    # The aggregate holds exactly one member: the underlying array.
    [backing_array] = tiled_view.get_aggregate().as_tuple()
    default_order = get_default_order(tiled_ty.ndim)
    return _tile_load_impl_inner(backing_array, index_items, tiled_ty.tile_shape,
                                 default_order, tiled_ty.padding_mode, latency, allow_tma)
4060+
4061+
@impl(ct._m_tiled_view_store)
def tiled_view_store_impl(tiled_view: Var, index: Var, tile: Var, latency: Var, allow_tma: Var):
    """Implement the tiled view ``store`` method by delegating to ``_tile_store_impl_inner``.

    The tile is broadcast to the view's tile shape and then implicitly cast to
    the view's dtype before being stored with the default axis order for the
    view's rank. Raises ``TileTypeError`` when the number of index items
    differs from the view's rank, or when the tile's shape cannot be broadcast
    to the view's tile shape.
    """
    tiled_ty = require_tiled_view_type(tiled_view)
    idx_ty = require_index_or_index_tuple_type(index)

    # Normalise a lone index into a one-element tuple.
    if isinstance(idx_ty, TupleTy):
        index_items = index.get_aggregate().items
    else:
        index_items = (index,)

    if len(index_items) != tiled_ty.ndim:
        raise TileTypeError(f"Index size {len(index_items)}"
                            f" does not match the tiled view rank {tiled_ty.ndim}")

    src_tile_ty = require_tile_type(tile)
    target_shape = tiled_ty.tile_shape
    if not is_shape_broadcastable_to(src_tile_ty.shape, target_shape):
        raise TileTypeError(f"Tile shape {src_tile_ty.shape} is not broadcastable"
                            f" to the tiled view's tile shape {tiled_ty.tile_shape}")

    # Broadcast first, then cast, matching the established sequence.
    widened = broadcast_to(tile, target_shape)
    casted = _implicit_cast(widened, tiled_ty.dtype,
                            "Stored tile is incompatible with tiled view's dtype")

    # The aggregate holds exactly one member: the underlying array.
    [backing_array] = tiled_view.get_aggregate().as_tuple()
    default_order = get_default_order(tiled_ty.ndim)
    _tile_store_impl_inner(backing_array, index_items, casted, default_order,
                           latency, allow_tma)
4082+
4083+
40124084def store_var (local_idx : int , value : Var , loc : Loc | None = None ):
40134085 scope = Scope .get_current ()
40144086 new_var = scope .local .redefine (local_idx , loc or Builder .get_current ().loc )
0 commit comments