7272from .type import (
7373 MemorySpace ,
7474 Type ,
75- make_tile_ty ,
7675 make_vector_ty ,
7776 is_vector_ty ,
7877 ArrayTy ,
@@ -219,8 +218,8 @@ def require_any_pointer_var(var: Var) -> TileTy:
219218
220219
221220def _array_base_pointer_type (array_ty : ArrayTy ) -> TileTy :
222- return make_tile_ty (
223- PointerTy (make_tile_ty (array_ty .dtype , ()), array_ty .memory_space ), ( )
221+ return TileTy (
222+ PointerTy (TileTy (array_ty .dtype , ()), array_ty .memory_space )
224223 )
225224
226225
@@ -240,7 +239,7 @@ def _get_array_base_pointer(array: Var) -> Var:
240239
241240def _array_linear_offset (array : Var , indices : tuple [Var , ...]) -> Var :
242241 array_val = array .get_aggregate ()
243- zero = strictly_typed_const (0 , make_tile_ty (datatype .uint64 , () ))
242+ zero = strictly_typed_const (0 , TileTy (datatype .uint64 ))
244243 offset = zero
245244 if len (indices ) != len (array_val .strides ):
246245 raise TileTypeError (
@@ -299,7 +298,7 @@ def array_getitem(object: Var, key: Var) -> Var:
299298 pointer = _array_get_element_pointer (object , indices )
300299 [result ] = add_operation (
301300 LoadPointer ,
302- (make_tile_ty (array_ty .dtype , () ),),
301+ (TileTy (array_ty .dtype ),),
303302 pointer = pointer ,
304303 alignment = None ,
305304 volatile = False ,
@@ -377,7 +376,7 @@ def atomic_rmw_dispatch_impl(kind: AtomicRMWKind, A: Var, idx: Var, val: Var) ->
377376 array_ty , _ = require_matching_array_value_type (A , val )
378377 indices = require_array_indices (A , idx )
379378 pointer = _array_get_element_pointer (A , indices )
380- result_ty = make_tile_ty (array_ty .dtype , () )
379+ result_ty = TileTy (array_ty .dtype )
381380 memory_order = mlir .llvm .AtomicOrdering .acq_rel
382381 return add_operation (
383382 AtomicRMW ,
@@ -394,7 +393,7 @@ def atomic_exch_impl(A: Var, idx: Var, val: Var) -> Var:
394393 array_ty , _ = require_matching_array_value_type (A , val )
395394 indices = require_array_indices (A , idx )
396395 pointer = _array_get_element_pointer (A , indices )
397- result_ty = make_tile_ty (array_ty .dtype , () )
396+ result_ty = TileTy (array_ty .dtype )
398397 memory_order = mlir .llvm .AtomicOrdering .acq_rel
399398 return add_operation (
400399 AtomicExchange ,
@@ -415,7 +414,7 @@ def atomic_cas_impl(A: Var, idx: Var, old: Var, val: Var) -> Var:
415414 )
416415 indices = require_array_indices (A , idx )
417416 pointer = _array_get_element_pointer (A , indices )
418- result_ty = make_tile_ty (array_ty .dtype , () )
417+ result_ty = TileTy (array_ty .dtype )
419418 success_memory_order = mlir .llvm .AtomicOrdering .acq_rel
420419 failure_memory_order = mlir .llvm .AtomicOrdering .monotonic
421420 return add_operation (
@@ -548,12 +547,11 @@ def address_space_cast_impl(value: Var, memory_space: Var) -> Var:
548547 memory_space = require_constant_enum (memory_space , MemorySpace )
549548 match pointer_tile_ty .dtype :
550549 case PointerTy ():
551- result_ty = make_tile_ty (
552- dataclasses .replace (pointer_tile_ty .dtype , memory_space = memory_space ),
553- (),
550+ result_ty = TileTy (
551+ dataclasses .replace (pointer_tile_ty .dtype , memory_space = memory_space )
554552 )
555553 case OpaquePointerTy ():
556- result_ty = make_tile_ty (OpaquePointerTy (memory_space ), ( ))
554+ result_ty = TileTy (OpaquePointerTy (memory_space ))
557555 case _:
558556 raise TileTypeError (f"Expected a pointer type, got { pointer_tile_ty } " )
559557 return add_operation (
@@ -574,7 +572,7 @@ def reinterpret_pointer_as_array_impl(pointer: Var, dtype: Var, shape: Var, stri
574572 shape = require_constant_int_tuple (shape , allow_single_int = True )
575573 dtype = require_dtype_spec (dtype )
576574 strides = _contiguous_strides (shape )
577- element_ty = make_tile_ty (dtype , () )
575+ element_ty = TileTy (dtype )
578576 memory_space = pointer_tile_ty .dtype .memory_space
579577
580578 typed_pointer_ty = TileTy (PointerTy (element_ty , memory_space = memory_space ), ())
@@ -1102,7 +1100,7 @@ def inline_ptx_impl(ptx_code: Var, constraint_pairs: tuple) -> tuple:
11021100 ptx_code = require_constant_str (ptx_code )
11031101 mlir_ptx_code , ro_args , rw_args , wo_args = require_inline_ptx_constraint_pairs (
11041102 ptx_code , constraint_pairs )
1105- result_types = tuple (make_tile_ty (dtype , () ) for dtype in wo_args )
1103+ result_types = tuple (TileTy (dtype ) for dtype in wo_args )
11061104 results = add_operation (
11071105 InlinePTX ,
11081106 result_types ,
@@ -1157,7 +1155,7 @@ def shfl_sync_impl(mode: str, mask: Var, value: Var, lane_mask: Var, width: Var)
11571155 clamp = 0 if mode == 'up' else 0x1F
11581156 mask_and_clamp = strictly_typed_const (
11591157 ((WARP_SIZE - width ) << 8 ) | clamp ,
1160- make_tile_ty (datatype .int32 , () ),
1158+ TileTy (datatype .int32 ),
11611159 )
11621160
11631161 suffix = "i32" if datatype .is_integral (value_ty .dtype ) else "f32"
@@ -1250,11 +1248,11 @@ def require_constant_result_dtype(dtype: Var) -> Type:
12501248 if const_dtype == datatype .any_opaque_ptr :
12511249 raise TileTypeError ("Result type cannot have no memory space" )
12521250 memory_space = datatype .MemorySpace (const_dtype .value )
1253- return make_tile_ty (OpaquePointerTy (memory_space = memory_space ), ( ))
1251+ return TileTy (OpaquePointerTy (memory_space = memory_space ))
12541252 elif is_vector_ty (const_dtype ):
12551253 return const_dtype
12561254 elif isinstance (const_dtype , datatype .DType ):
1257- return make_tile_ty (const_dtype , () )
1255+ return TileTy (const_dtype )
12581256 else :
12591257 raise TileTypeError (f"Expected a type spec but got { dtype } " )
12601258
0 commit comments