2525 enter_nested_block , nested_block , PhiState , LoopVarState ,
2626 TupleValue , make_aggregate , RangeValue , BoundMethodValue , ArrayValue , ConstantState ,
2727 ListValue , TiledViewValue , ClosureValue , MemoryEffect , attribute , operand ,
28- BlockRestriction , FormattedStringValue ,
28+ BlockRestriction , FormattedStringValue , RawArrayMemoryValue
2929)
3030from .type import PointerTy
3131from . import hir
4242 require_optional_constant_str , PrintfValidator , require_tile_maybe_loose_type ,
4343 require_0d_tile_maybe_loose_type , require_bool , require_optional_range_type ,
4444 require_tile_or_tile_tuple_type , require_constant_scalar_tuple , require_constant_scalar ,
45- require_callable_type )
45+ require_callable_type , require_raw_array_memory_type )
4646from .ops_utils import (
4747 BINOP_REGISTRY , UNARYOP_REGISTRY ,
4848 check_rd_and_ftz , PaddingMode , get_default_order ,
6060 ListTy , make_tile_ty , SliceType , DTypeConstructor , RangeIterType , Type ,
6161 NONE , ModuleTy , TypeTy , LooselyTypedScalar , DTypeSpec , StringTy , InvalidType ,
6262 array_size_type , ClosureTy , LiveCapturedScope , TokenTy , TiledViewTy , FormattedStringTy ,
63- StringFormat , FormattedPiece ,
63+ StringFormat , FormattedPiece , RawArrayMemoryTy
6464)
6565from cuda .tile ._datatype import (
6666 DType , is_integral , is_float , is_signed , is_boolean ,
@@ -1593,6 +1593,7 @@ def getattr_impl(object: Var, name: Var) -> Var:
15931593 case ArrayTy (), "strides" : return build_tuple (object .get_aggregate ().strides )
15941594 case ArrayTy (), "slice" : return bind_method (object , ct ._m_array_slice )
15951595 case ArrayTy (), "tiled_view" : return bind_method (object , ct ._m_array_tiled_view )
1596+ case ArrayTy (), "get_raw_memory" : return bind_method (object , ct ._m_array_get_raw_memory )
15961597
15971598 case TileTy (), "dtype" : return loosely_typed_const (ty .dtype )
15981599 case TileTy (), "shape" : return loosely_typed_const (ty .shape )
@@ -1614,6 +1615,12 @@ def getattr_impl(object: Var, name: Var) -> Var:
16141615 case TiledViewTy (), "load" : return bind_method (object , ct ._m_tiled_view_load )
16151616 case TiledViewTy (), "store" : return bind_method (object , ct ._m_tiled_view_store )
16161617
1618+ case RawArrayMemoryTy (), "dtype" : return loosely_typed_const (ty .dtype )
1619+ case RawArrayMemoryTy (), "load_offset" : return bind_method (
1620+ object , ct ._m_raw_array_memory_load_offset )
1621+ case RawArrayMemoryTy (), "store_offset" : return bind_method (
1622+ object , ct ._m_raw_array_memory_store_offset )
1623+
16171624 case ModuleTy (), _:
16181625 try :
16191626 return loosely_typed_const (getattr (ty .py_mod , attr_name ))
@@ -2154,6 +2161,83 @@ def _tile_load_impl_inner(array: Var, index_items: tuple[Var, ...], shape: Seque
21542161 return reshape (result , shape )
21552162
21562163
@impl(ct._m_array_get_raw_memory)
def get_raw_memory_impl(array: Var) -> Var:
    """Implement ``ct._m_array_get_raw_memory`` for an array variable.

    Validates that ``array`` has an array type, then builds a value of type
    ``RawArrayMemoryTy`` (parameterised by the array's dtype) from the array
    aggregate's ``base_ptr``.

    Args:
        array: Variable expected to hold an array aggregate.

    Returns:
        The single variable produced by ``unflatten_aggregates`` for the
        raw-memory type.
    """
    array_ty = require_array_type(array)
    aggregate = array.get_aggregate()
    assert isinstance(aggregate, ArrayValue)

    flat_values = (aggregate.base_ptr,)
    result_ty = RawArrayMemoryTy(array_ty.dtype)
    # The same type is passed for both type tuples, as in the original call;
    # exactly one aggregate is expected back.
    (raw_memory,) = unflatten_aggregates(flat_values, (result_ty,), (result_ty,))
    return raw_memory
2173+
2174+
@impl(ct._m_raw_array_memory_load_offset)
def raw_array_memory_load_offset_impl(raw_array_memory: Var, offset: Var, mask: Var,
                                      padding_value: Var, latency: Var) -> Var:
    """Implement ``ct._m_raw_array_memory_load_offset``.

    Forms a pointer by offsetting the raw memory's ``base_ptr`` with ``offset``
    (converted to ``uint64``) and loads through it via ``load_pointer``.

    Args:
        raw_array_memory: Variable that must have a raw-array-memory type.
        offset: Offset value; converted with ``astype(..., datatype.uint64)``.
        mask: Passed to ``_process_custom_mask`` together with the pointer shape.
        padding_value: Either the constant ``None`` (no padding is passed to
            the load) or a tile whose shape must be broadcastable to the
            pointer shape; it is implicitly cast to the memory dtype and
            broadcast to the pointer shape.
        latency: Must be an optional constant int; validated with
            ``_check_load_store_hints``.

    Returns:
        The first result of ``load_pointer``; the second result is discarded.

    Raises:
        TileTypeError: If the padding shape cannot be broadcast to the pointer
            shape.
    """
    raw_mem_ty = require_raw_array_memory_type(raw_array_memory)
    aggregate = raw_array_memory.get_aggregate()
    assert isinstance(aggregate, RawArrayMemoryValue)

    pointer = pointer_offset(aggregate.base_ptr, astype(offset, datatype.uint64))
    pointer_shape = pointer.get_type().shape
    element_dtype = raw_mem_ty.dtype

    final_mask = _process_custom_mask(mask, None, pointer_shape)

    # A constant ``None`` padding value means "no padding"; anything else must
    # be a tile compatible with the pointer shape.
    padding_var: Optional[Var] = None
    padding_omitted = padding_value.is_constant() and padding_value.get_constant() is None
    if not padding_omitted:
        padding_shape = require_tile_type(padding_value).shape
        if not is_shape_broadcastable_to(padding_shape, pointer_shape):
            raise TileTypeError(f"Padding shape { padding_shape } is not broadcastable to the"
                                f" offset shape { pointer_shape } ")
        casted_padding = _implicit_cast(padding_value, element_dtype, "Invalid padding value")
        padding_var = broadcast_to(casted_padding, pointer_shape)

    latency_val = require_optional_constant_int(latency)
    _check_load_store_hints(latency_val)

    loaded, _unused_token = load_pointer(pointer, final_mask, padding_var, latency_val)
    return loaded
2206+
2207+
@impl(ct._m_raw_array_memory_store_offset)
def raw_array_memory_store_offset_impl(raw_array_memory: Var, offset: Var, value: Var,
                                       mask: Var, latency: Var) -> None:
    """Implement ``ct._m_raw_array_memory_store_offset``.

    Forms a pointer by offsetting the raw memory's ``base_ptr`` with ``offset``
    (converted to ``uint64``) and stores ``value`` through it via
    ``store_pointer``.

    Args:
        raw_array_memory: Variable that must have a raw-array-memory type.
        offset: Offset value; converted with ``astype(..., datatype.uint64)``.
        value: Value to store; prepared by ``_get_scatter_value`` against the
            pointer shape and the memory dtype.
        mask: Passed to ``_process_custom_mask`` together with the pointer shape.
        latency: Must be an optional constant int; validated with
            ``_check_load_store_hints``.
    """
    raw_mem_ty = require_raw_array_memory_type(raw_array_memory)
    aggregate = raw_array_memory.get_aggregate()
    assert isinstance(aggregate, RawArrayMemoryValue)

    pointer = pointer_offset(aggregate.base_ptr, astype(offset, datatype.uint64))
    target_shape = pointer.get_type().shape
    element_dtype = raw_mem_ty.dtype

    final_mask = _process_custom_mask(mask, None, target_shape)
    stored_value = _get_scatter_value(value, target_shape, element_dtype, "Value",
                                      array_name="RawArrayMemory")

    latency_val = require_optional_constant_int(latency)
    _check_load_store_hints(latency_val)

    # store_pointer is expected to yield exactly one result, which is not used.
    (_unused_token,) = store_pointer(pointer, stored_value, final_mask, latency_val)
2229+
2230+
def tile_load(array: Var, index: tuple[Var, ...], shape: Sequence[int], order: Sequence[int],
              padding_mode: PaddingMode, latency: Optional[int],
              allow_tma: Optional[bool]) -> tuple[Var, Var]:
    """Add a ``TileLoad`` operation for ``array`` and return its results.

    The operation is created with two result types: a tile type built from the
    array's dtype and ``shape``, and a ``TokenTy``.

    Args:
        array: Array variable to load from; its type supplies the dtype.
        index: Index variables forwarded to the operation.
        shape: Shape used for the result tile type.
        order: Forwarded to the operation as a tuple.
        padding_mode: Forwarded to the operation unchanged.
        latency: Forwarded to the operation unchanged.
        allow_tma: Forwarded to the operation unchanged.

    Returns:
        Whatever ``add_operation`` returns for the two result types.
    """
    element_dtype = array.get_type().dtype
    result_types = (make_tile_ty(element_dtype, shape), TokenTy())
    return add_operation(
        TileLoad,
        result_types,
        array=array,
        index=index,
        order=tuple(order),
        padding_mode=padding_mode,
        latency=latency,
        allow_tma=allow_tma,
    )
2239+
2240+
21572241@impl (ct .load )
21582242def tile_load_impl (array : Var , index : Var , shape : Var , order : Var ,
21592243 padding_mode : Var , latency : Var , allow_tma : Var ) -> Var :
@@ -2369,7 +2453,8 @@ def scatter_impl(array: Var, indices: Var, value: Var, mask: Var,
23692453
23702454
23712455def _get_scatter_value (value : Var , pointer_shape : Tuple [int , ...], array_dtype : DType ,
2372- value_name : str , cast_dtype : bool = True ) -> Var :
2456+ value_name : str , cast_dtype : bool = True ,
2457+ array_name : str = "array" ) -> Var :
23732458 value_ty = require_tile_type (value )
23742459 value_shape = value_ty .shape
23752460
@@ -2379,7 +2464,7 @@ def _get_scatter_value(value: Var, pointer_shape: Tuple[int, ...], array_dtype:
23792464
23802465 if cast_dtype :
23812466 value = _implicit_cast (value , array_dtype ,
2382- "Stored value is incompatible with array 's dtype" )
2467+ f "Stored value is incompatible with { array_name } 's dtype" )
23832468 return broadcast_to (value , pointer_shape )
23842469
23852470
0 commit comments