Skip to content

Commit 88b8eed

Browse files
committed
Remove make_tile_ty() helper
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent ee6bda7 commit 88b8eed

11 files changed

Lines changed: 93 additions & 92 deletions

File tree

experimental/cuda-lang/src/cuda/lang/_ir/ops.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@
7272
from .type import (
7373
MemorySpace,
7474
Type,
75-
make_tile_ty,
7675
make_vector_ty,
7776
is_vector_ty,
7877
ArrayTy,
@@ -219,8 +218,8 @@ def require_any_pointer_var(var: Var) -> TileTy:
219218

220219

221220
def _array_base_pointer_type(array_ty: ArrayTy) -> TileTy:
222-
return make_tile_ty(
223-
PointerTy(make_tile_ty(array_ty.dtype, ()), array_ty.memory_space), ()
221+
return TileTy(
222+
PointerTy(TileTy(array_ty.dtype, ()), array_ty.memory_space)
224223
)
225224

226225

@@ -240,7 +239,7 @@ def _get_array_base_pointer(array: Var) -> Var:
240239

241240
def _array_linear_offset(array: Var, indices: tuple[Var, ...]) -> Var:
242241
array_val = array.get_aggregate()
243-
zero = strictly_typed_const(0, make_tile_ty(datatype.uint64, ()))
242+
zero = strictly_typed_const(0, TileTy(datatype.uint64))
244243
offset = zero
245244
if len(indices) != len(array_val.strides):
246245
raise TileTypeError(
@@ -299,7 +298,7 @@ def array_getitem(object: Var, key: Var) -> Var:
299298
pointer = _array_get_element_pointer(object, indices)
300299
[result] = add_operation(
301300
LoadPointer,
302-
(make_tile_ty(array_ty.dtype, ()),),
301+
(TileTy(array_ty.dtype),),
303302
pointer=pointer,
304303
alignment=None,
305304
volatile=False,
@@ -377,7 +376,7 @@ def atomic_rmw_dispatch_impl(kind: AtomicRMWKind, A: Var, idx: Var, val: Var) ->
377376
array_ty, _ = require_matching_array_value_type(A, val)
378377
indices = require_array_indices(A, idx)
379378
pointer = _array_get_element_pointer(A, indices)
380-
result_ty = make_tile_ty(array_ty.dtype, ())
379+
result_ty = TileTy(array_ty.dtype)
381380
memory_order = mlir.llvm.AtomicOrdering.acq_rel
382381
return add_operation(
383382
AtomicRMW,
@@ -394,7 +393,7 @@ def atomic_exch_impl(A: Var, idx: Var, val: Var) -> Var:
394393
array_ty, _ = require_matching_array_value_type(A, val)
395394
indices = require_array_indices(A, idx)
396395
pointer = _array_get_element_pointer(A, indices)
397-
result_ty = make_tile_ty(array_ty.dtype, ())
396+
result_ty = TileTy(array_ty.dtype)
398397
memory_order = mlir.llvm.AtomicOrdering.acq_rel
399398
return add_operation(
400399
AtomicExchange,
@@ -415,7 +414,7 @@ def atomic_cas_impl(A: Var, idx: Var, old: Var, val: Var) -> Var:
415414
)
416415
indices = require_array_indices(A, idx)
417416
pointer = _array_get_element_pointer(A, indices)
418-
result_ty = make_tile_ty(array_ty.dtype, ())
417+
result_ty = TileTy(array_ty.dtype)
419418
success_memory_order = mlir.llvm.AtomicOrdering.acq_rel
420419
failure_memory_order = mlir.llvm.AtomicOrdering.monotonic
421420
return add_operation(
@@ -548,12 +547,11 @@ def address_space_cast_impl(value: Var, memory_space: Var) -> Var:
548547
memory_space = require_constant_enum(memory_space, MemorySpace)
549548
match pointer_tile_ty.dtype:
550549
case PointerTy():
551-
result_ty = make_tile_ty(
552-
dataclasses.replace(pointer_tile_ty.dtype, memory_space=memory_space),
553-
(),
550+
result_ty = TileTy(
551+
dataclasses.replace(pointer_tile_ty.dtype, memory_space=memory_space)
554552
)
555553
case OpaquePointerTy():
556-
result_ty = make_tile_ty(OpaquePointerTy(memory_space), ())
554+
result_ty = TileTy(OpaquePointerTy(memory_space))
557555
case _:
558556
raise TileTypeError(f"Expected a pointer type, got {pointer_tile_ty}")
559557
return add_operation(
@@ -574,7 +572,7 @@ def reinterpret_pointer_as_array_impl(pointer: Var, dtype: Var, shape: Var, stri
574572
shape = require_constant_int_tuple(shape, allow_single_int=True)
575573
dtype = require_dtype_spec(dtype)
576574
strides = _contiguous_strides(shape)
577-
element_ty = make_tile_ty(dtype, ())
575+
element_ty = TileTy(dtype)
578576
memory_space = pointer_tile_ty.dtype.memory_space
579577

580578
typed_pointer_ty = TileTy(PointerTy(element_ty, memory_space=memory_space), ())
@@ -1102,7 +1100,7 @@ def inline_ptx_impl(ptx_code: Var, constraint_pairs: tuple) -> tuple:
11021100
ptx_code = require_constant_str(ptx_code)
11031101
mlir_ptx_code, ro_args, rw_args, wo_args = require_inline_ptx_constraint_pairs(
11041102
ptx_code, constraint_pairs)
1105-
result_types = tuple(make_tile_ty(dtype, ()) for dtype in wo_args)
1103+
result_types = tuple(TileTy(dtype) for dtype in wo_args)
11061104
results = add_operation(
11071105
InlinePTX,
11081106
result_types,
@@ -1157,7 +1155,7 @@ def shfl_sync_impl(mode: str, mask: Var, value: Var, lane_mask: Var, width: Var)
11571155
clamp = 0 if mode == 'up' else 0x1F
11581156
mask_and_clamp = strictly_typed_const(
11591157
((WARP_SIZE - width) << 8) | clamp,
1160-
make_tile_ty(datatype.int32, ()),
1158+
TileTy(datatype.int32),
11611159
)
11621160

11631161
suffix = "i32" if datatype.is_integral(value_ty.dtype) else "f32"
@@ -1250,11 +1248,11 @@ def require_constant_result_dtype(dtype: Var) -> Type:
12501248
if const_dtype == datatype.any_opaque_ptr:
12511249
raise TileTypeError("Result type cannot have no memory space")
12521250
memory_space = datatype.MemorySpace(const_dtype.value)
1253-
return make_tile_ty(OpaquePointerTy(memory_space=memory_space), ())
1251+
return TileTy(OpaquePointerTy(memory_space=memory_space))
12541252
elif is_vector_ty(const_dtype):
12551253
return const_dtype
12561254
elif isinstance(const_dtype, datatype.DType):
1257-
return make_tile_ty(const_dtype, ())
1255+
return TileTy(const_dtype)
12581256
else:
12591257
raise TileTypeError(f"Expected a type spec but got {dtype}")
12601258

experimental/cuda-lang/src/cuda/lang/_ir/type.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
TokenTy,
2020
TypeTy,
2121
EnumTy,
22-
make_tile_ty,
2322
ContextManagerTy,
2423
ContextManagerState,
2524
MemorySpace,
@@ -54,7 +53,7 @@ def make_vector_ty(dtype: DType, length: int) -> TileTy:
5453
raise TileTypeError(
5554
f"Expected vector length to be a positive power of two, got {length}"
5655
)
57-
return make_tile_ty(dtype, (length,))
56+
return TileTy(dtype, (length,))
5857

5958

6059
@dataclass(frozen=True, eq=True)
@@ -127,7 +126,6 @@ class TensorMapTy(Type):
127126
"TokenTy",
128127
"TypeTy",
129128
"EnumTy",
130-
"make_tile_ty",
131129
"make_vector_ty",
132130
"is_vector_ty",
133131
"MemorySpace",

experimental/cuda-lang/src/cuda/lang/_passes/canonicalize_parameters.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
 
 from typing import Iterable
 
-from cuda.lang._ir.type import ArrayTy, MemorySpace, PointerTy, make_tile_ty, ArrayValue
+from cuda.lang._ir.type import ArrayTy, MemorySpace, PointerTy, ArrayValue, TileTy
 from cuda.lang._ir.ops import MakeTensorView, ReinterpretPointer
 
 
@@ -24,12 +24,11 @@ def _rewrite_make_tensor_view(builder, op, array_parameter_names) -> Iterable:
2424
strides=array_ty.strides,
2525
memory_space=MemorySpace.GENERIC,
2626
)
27-
base_ptr_ty = make_tile_ty(
27+
base_ptr_ty = TileTy(
2828
PointerTy(
29-
pointee=make_tile_ty(array_ty.dtype, ()),
29+
pointee=TileTy(array_ty.dtype),
3030
memory_space=MemorySpace.GENERIC,
31-
),
32-
(),
31+
)
3332
)
3433
base_ptr = op.base_ptr
3534
if base_ptr.get_type() != base_ptr_ty:

experimental/cuda-lang/src/cuda/lang/_passes/flatten_cfg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from cuda.lang._ir.ops import IfElse, EndBranch, Loop, Continue, Break
1010
from cuda.lang._ir.ir import IRContext, Region, Block, TileBlock, Operation, Var
11-
from cuda.lang._ir.type import make_tile_ty
11+
from cuda.lang._ir.type import TileTy
1212
from cuda.lang._datatype import bool_
1313
from cuda.lang._ir.ops import (
1414
Branch,
@@ -129,7 +129,7 @@ def flatten_for_loop(self, op: Loop, current: Block) -> Block:
129129
)
130130

131131
cv = self.ctx.make_temp(op.loc)
132-
cv.set_type(make_tile_ty(bool_, ()))
132+
cv.set_type(TileTy(bool_))
133133

134134
# Tile's range object requires the step to be positive so
135135
# we can always use "lt" here.

experimental/cuda-lang/src/cuda/lang/_passes/ir2mlir/pass_definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,7 @@ def _lower_intrinsic_result_type(
916916
self, result_types: Sequence[ir_type.Type]
917917
) -> Sequence[mlir.Type]:
918918
for result_type in result_types:
919-
if result_type == ir_type.make_tile_ty(datatype.bool_, ()):
919+
if result_type == ir_type.TileTy(datatype.bool_):
920920
yield T.i1()
921921
else:
922922
yield ir_type_to_mlir_type(result_type)

src/cuda/tile/_compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
 )
 from cuda.tile._ir import ir, hir
 from cuda.tile._ir.ops import loosely_typed_const, flatten_block_parameters, tile_impl_registry
-from cuda.tile._ir.type import TileTy, ArrayTy, ListTy, make_tile_ty
+from cuda.tile._ir.type import TileTy, ArrayTy, ListTy
 from cuda.tile._passes.ast2hir import get_function_hir
 from cuda.tile._passes.code_motion import hoist_loop_invariants
 from cuda.tile._passes.unhoist_partition_views import unhoist_partition_views
@@ -165,7 +165,7 @@ def _get_array_ty(param: ArrayConstraint):
165165
raise NotImplementedError("Negative strides are currently not supported:"
166166
" please specify stride_lower_bound_incl=0")
167167

168-
return ArrayTy(make_tile_ty(param.dtype, ()),
168+
return ArrayTy(TileTy(param.dtype),
169169
shape=(None,) * param.ndim,
170170
strides=param.stride_constant,
171171
index_dtype=param.index_dtype)

0 commit comments

Comments
 (0)