Skip to content

Commit cba914c

Browse files
committed
Make Var generic over Type to improve type hints
This is (hopefully) a non-functional change that would help with refactoring a bunch of operation implementations etc. For example, it enables annotations like this: def astype(x: Var[TileTy], dtype: DType) -> Var[TileTy]: ... Which was previously just def astype(x: Var, dtype: DType) -> Var: ... Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 7bf1722 commit cba914c

4 files changed

Lines changed: 156 additions & 145 deletions

File tree

experimental/cuda-lang/src/cuda/lang/_ir/ops.py

Lines changed: 54 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
Unary, implicit_cast, address_space_cast, reinterpret_pointer, AddrSpaceCast,
6363
ReinterpretPointer,
6464
)
65-
from cuda.tile._ir.ir import MemoryEffect, make_aggregate
65+
from cuda.tile._ir.ir import MemoryEffect, make_aggregate, add_operation_variadic
6666
from cuda.lang._exception import TileCompilerError, TileTypeError
6767
import cuda.lang._datatype as datatype
6868
from cuda.tile._datatype import (
@@ -324,14 +324,13 @@ def array_getitem(object: Var, key: Var) -> Var:
324324
array_ty = require_array_type(object)
325325
indices = require_array_indices(object, key)
326326
pointer = _array_get_element_pointer(object, indices)
327-
[result] = add_operation(
327+
return add_operation(
328328
LoadPointer,
329-
(TileTy(array_ty.dtype),),
329+
TileTy(array_ty.dtype),
330330
pointer=pointer,
331331
alignment=None,
332332
volatile=False,
333333
)
334-
return result
335334

336335

337336
@impl(operator.setitem, overload=(ArrayTy, WILDCARD, WILDCARD))
@@ -341,7 +340,7 @@ def array_setitem(object: Var, key: Var, value: Var):
341340
value = astype(value, array_ty.dtype)
342341
indices = require_array_indices(object, key)
343342
pointer = _array_get_element_pointer(object, indices)
344-
add_operation(
343+
add_operation_variadic(
345344
StorePointer,
346345
(),
347346
pointer=pointer,
@@ -485,15 +484,14 @@ def _pointer_load(
485484
result_ty = TileTy(pointee_dtype)
486485
else:
487486
result_ty = make_vector_ty(pointee_dtype, count)
488-
[result] = add_operation(
487+
return add_operation(
489488
LoadPointer,
490-
(result_ty,),
489+
result_ty,
491490
pointer=pointer,
492491
volatile=volatile,
493492
alignment=alignment,
494493
ordering=ordering,
495494
)
496-
return result
497495

498496

499497
def _pointer_store(
@@ -514,7 +512,7 @@ def _pointer_store(
514512
value = implicit_cast(value, pointee_dtype,
515513
"Stored value type is incompatible with pointer type")
516514

517-
add_operation(
515+
add_operation_variadic(
518516
StorePointer,
519517
(),
520518
pointer=pointer,
@@ -669,8 +667,8 @@ def _to_string_rhs(self) -> str:
669667
return f"{self.op} ^{self.target._name}({', '.join(format_var(arg) for arg in self.args)})"
670668

671669

672-
def branch(target: Block, args: tuple[Var, ...]) -> Branch:
673-
return add_operation(Branch, (), target=target, args=args)
670+
def branch(target: Block, args: tuple[Var, ...]) -> None:
671+
add_operation_variadic(Branch, (), target=target, args=args)
674672

675673

676674
@dataclass(eq=False)
@@ -699,8 +697,8 @@ def cond_branch(
699697
false_args: tuple[Var, ...],
700698
true_target: Block,
701699
false_target: Block,
702-
) -> CondBranch:
703-
return add_operation(
700+
) -> None:
701+
add_operation_variadic(
704702
CondBranch,
705703
(),
706704
cond=cond,
@@ -795,7 +793,7 @@ def enter_context_local_array_impl(manager: Var):
795793
)
796794

797795
def exit_callback():
798-
add_operation(DeallocLocalMemory, (), ptr=base_ptr)
796+
add_operation_variadic(DeallocLocalMemory, (), ptr=base_ptr)
799797

800798
mgr_ty.state.exit_callback = exit_callback
801799

@@ -912,19 +910,19 @@ class SyncThreads(Operation, opcode="syncthreads", memory_effect=MemoryEffect.ST
912910

913911

914912
@impl(stub.syncthreads)
915-
def syncthreads_impl() -> Operation:
916-
return add_operation(SyncThreads, None,)
913+
def syncthreads_impl() -> None:
914+
add_operation_variadic(SyncThreads, (),)
917915

918916

919917
@impl(stub.elect_sync)
920918
def elect_sync_impl(membermask) -> Var:
921919
mask = require_constant_int(membermask)
922920
mask = strictly_typed_const(mask & 0xffffffff, I32_TY)
923921

924-
_, is_elected = add_operation(RawNVVMIntrinsic,
925-
(I32_TY, BOOL_TY),
926-
intrinsic="llvm.nvvm.elect.sync",
927-
operands_=(mask,))
922+
_, is_elected = add_operation_variadic(RawNVVMIntrinsic,
923+
(I32_TY, BOOL_TY),
924+
intrinsic="llvm.nvvm.elect.sync",
925+
operands_=(mask,))
928926
return is_elected
929927

930928

@@ -1110,12 +1108,12 @@ def rewrite(match: re.Match[str]) -> str:
11101108

11111109

11121110
@impl(stub.inline_ptx)
1113-
def inline_ptx_impl(ptx_code: Var, constraint_pairs: tuple) -> tuple:
1111+
def inline_ptx_impl(ptx_code: Var, constraint_pairs: tuple) -> Var[TupleTy]:
11141112
ptx_code = require_constant_str(ptx_code)
11151113
mlir_ptx_code, ro_args, rw_args, wo_args = require_inline_ptx_constraint_pairs(
11161114
ptx_code, constraint_pairs)
11171115
result_types = tuple(TileTy(dtype) for dtype in wo_args)
1118-
results = add_operation(
1116+
results = add_operation_variadic(
11191117
InlinePTX,
11201118
result_types,
11211119
ptx_code=mlir_ptx_code,
@@ -1307,7 +1305,7 @@ def clusterlaunchcontrol_try_cancel_impl(addr: Var, mbar: Var, multicast: Var) -
13071305
intrinsic += ".multicast"
13081306
intrinsic += ".shared"
13091307

1310-
add_operation(
1308+
add_operation_variadic(
13111309
RawNVVMIntrinsic,
13121310
(),
13131311
intrinsic=intrinsic,
@@ -1357,21 +1355,26 @@ class ForeignFunction(Operation, opcode="foreign_function", memory_effect=Memory
13571355

13581356

13591357
@impl(stub.foreign_function._call_foreign_function)
1360-
def _call_foreign_function_impl(func: Var, return_type: Var, parameters: Var) -> Operation:
1358+
def _call_foreign_function_impl(func: Var, return_type: Var, parameters: Var):
13611359
function_name = require_constant_str(func)
1360+
require_tuple_type(parameters)
1361+
parameters = parameters.get_aggregate().items
13621362
if return_type.is_constant() and return_type.get_constant() is None:
1363-
result_type = tuple()
1363+
add_operation_variadic(
1364+
ForeignFunction,
1365+
(),
1366+
function_name=function_name,
1367+
operands_=parameters,
1368+
)
1369+
return None
13641370
else:
13651371
result_type = require_constant_result_dtype(return_type)
1366-
require_tuple_type(parameters)
1367-
parameters = parameters.get_aggregate().items
1368-
result = add_operation(
1369-
ForeignFunction,
1370-
result_type,
1371-
function_name=function_name,
1372-
operands_=parameters,
1373-
)
1374-
return result if result_type else None
1372+
return add_operation(
1373+
ForeignFunction,
1374+
result_type,
1375+
function_name=function_name,
1376+
operands_=parameters,
1377+
)
13751378

13761379

13771380
def require_mbarrier_ptr(
@@ -1397,7 +1400,7 @@ def require_mbarrier_ptr(
13971400
def mbarrier_init_impl(mbar: Var, participants: Var) -> Var:
13981401
require_mbarrier_ptr(mbar)
13991402
participants = astype(participants, datatype.int32)
1400-
add_operation(
1403+
add_operation_variadic(
14011404
RawNVVMIntrinsic,
14021405
tuple(),
14031406
intrinsic="llvm.nvvm.mbarrier.init.shared",
@@ -1408,7 +1411,7 @@ def mbarrier_init_impl(mbar: Var, participants: Var) -> Var:
14081411
@impl(stub.mbarrier_invalidate)
14091412
def mbarrier_invalidate_impl(mbar: Var) -> Var:
14101413
require_mbarrier_ptr(mbar)
1411-
add_operation(
1414+
add_operation_variadic(
14121415
RawNVVMIntrinsic,
14131416
tuple(),
14141417
intrinsic="llvm.nvvm.mbarrier.inval.shared",
@@ -1470,7 +1473,7 @@ def mbarrier_arrive_impl(
14701473
intrinsic += _mbar_space_scope_suffix(scope, space)
14711474

14721475
return_type = (TileTy(datatype.uint64),) if space is MemorySpace.SHARED else ()
1473-
results = add_operation(
1476+
results = add_operation_variadic(
14741477
RawNVVMIntrinsic,
14751478
return_type,
14761479
intrinsic=intrinsic,
@@ -1501,7 +1504,7 @@ def mbarrier_arrive_expect_tx_impl(
15011504
intrinsic += _mbar_space_scope_suffix(scope, space)
15021505

15031506
return_type = (TileTy(datatype.uint64),) if space is MemorySpace.SHARED else ()
1504-
results = add_operation(
1507+
results = add_operation_variadic(
15051508
RawNVVMIntrinsic,
15061509
return_type,
15071510
intrinsic=intrinsic,
@@ -1511,13 +1514,13 @@ def mbarrier_arrive_expect_tx_impl(
15111514

15121515

15131516
@impl(stub.mbarrier_expect_tx)
1514-
def mbarrier_expect_tx_impl(mbar: Var, bytes: Var, scope: Var) -> Var:
1517+
def mbarrier_expect_tx_impl(mbar: Var, bytes: Var, scope: Var):
15151518
space = require_mbarrier_ptr(mbar).memory_space
15161519
bytes = astype(bytes, datatype.int32)
15171520
scope = require_constant_enum(scope, MbarrierScope)
15181521
intrinsic = "llvm.nvvm.mbarrier.expect.tx"
15191522
intrinsic += _mbar_space_scope_suffix(scope, space)
1520-
add_operation(
1523+
add_operation_variadic(
15211524
RawNVVMIntrinsic,
15221525
(),
15231526
intrinsic=intrinsic,
@@ -1532,7 +1535,7 @@ def mbarrier_complete_tx_impl(mbar: Var, bytes: Var, scope: Var) -> Var:
15321535
scope = require_constant_enum(scope, MbarrierScope)
15331536
intrinsic = "llvm.nvvm.mbarrier.complete.tx"
15341537
intrinsic += _mbar_space_scope_suffix(scope, space)
1535-
add_operation(
1538+
add_operation_variadic(
15361539
RawNVVMIntrinsic,
15371540
(),
15381541
intrinsic=intrinsic,
@@ -1552,13 +1555,12 @@ def mbarrier_test_wait_impl(
15521555
if ordering is MemoryOrder.RELAXED:
15531556
intrinsic += ".relaxed"
15541557
intrinsic += _mbar_space_scope_suffix(scope, MemorySpace.SHARED)
1555-
results = add_operation(
1558+
return add_operation(
15561559
RawNVVMIntrinsic,
1557-
(TileTy(datatype.bool_),),
1560+
TileTy(datatype.bool_),
15581561
intrinsic=intrinsic,
15591562
operands_=(mbar, state),
15601563
)
1561-
return results[0]
15621564

15631565

15641566
@impl(stub.mbarrier_test_wait_parity)
@@ -1573,13 +1575,12 @@ def mbarrier_test_wait_parity_impl(
15731575
if ordering is MemoryOrder.RELAXED:
15741576
intrinsic += ".relaxed"
15751577
intrinsic += _mbar_space_scope_suffix(scope, MemorySpace.SHARED)
1576-
results = add_operation(
1578+
return add_operation(
15771579
RawNVVMIntrinsic,
1578-
(TileTy(datatype.bool_),),
1580+
TileTy(datatype.bool_),
15791581
intrinsic=intrinsic,
15801582
operands_=(mbar, parity),
15811583
)
1582-
return results[0]
15831584

15841585

15851586
def _is_none(var: Var):
@@ -1607,13 +1608,12 @@ def mbarrier_try_wait_impl(
16071608
if ordering is MemoryOrder.RELAXED:
16081609
intrinsic += ".relaxed"
16091610
intrinsic += _mbar_space_scope_suffix(scope, MemorySpace.SHARED)
1610-
results = add_operation(
1611+
return add_operation(
16111612
RawNVVMIntrinsic,
1612-
(TileTy(datatype.bool_),),
1613+
TileTy(datatype.bool_),
16131614
intrinsic=intrinsic,
16141615
operands_=args,
16151616
)
1616-
return results[0]
16171617

16181618

16191619
@impl(stub.mbarrier_try_wait_parity)
@@ -1637,13 +1637,12 @@ def mbarrier_try_wait_parity_impl(
16371637
if ordering is MemoryOrder.RELAXED:
16381638
intrinsic += ".relaxed"
16391639
intrinsic += _mbar_space_scope_suffix(scope, MemorySpace.SHARED)
1640-
results = add_operation(
1640+
return add_operation(
16411641
RawNVVMIntrinsic,
1642-
(TileTy(datatype.bool_),),
1642+
TileTy(datatype.bool_),
16431643
intrinsic=intrinsic,
16441644
operands_=args,
16451645
)
1646-
return results[0]
16471646

16481647

16491648
@impl(stub.map_shared_to_cluster)
@@ -1660,13 +1659,12 @@ def map_shared_to_cluster_impl(ptr: Var, rank: Var):
16601659
else:
16611660
result_scalar_ty = pointer_dtype(info.pointee_dtype, MemorySpace.SHARED_CLUSTER)
16621661
result_ty = TileTy(result_scalar_ty)
1663-
results = add_operation(
1662+
return add_operation(
16641663
RawNVVMIntrinsic,
1665-
(result_ty,),
1664+
result_ty,
16661665
intrinsic="llvm.nvvm.mapa.shared.cluster",
16671666
operands_=(ptr, rank),
16681667
)
1669-
return results[0]
16701668

16711669

16721670
__all__ = (

experimental/cuda-lang/src/cuda/lang/_stub/_nvvm_support.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from cuda.tile._ir.op_impl import require_integer_0d_tile_type, require_scalar_pointer_type, \
1313
require_scalar_type, require_vector_type, require_any_vector_type, \
1414
require_any_scalar_or_vector_type
15-
from cuda.tile._ir.ir import Var, add_operation
15+
from cuda.tile._ir.ir import Var, add_operation_variadic
1616
from cuda.tile._ir.ops import implicit_cast, build_tuple
1717
from cuda.tile import _datatype as datatype
1818
from cuda.tile._memory_model import MemorySpace
@@ -47,15 +47,12 @@ def _raw_nvvm_intrinsic_impl(stub, *args: Var):
4747
if stub_sig.return_annotation is None:
4848
ret_type_hints = []
4949
def make_retval(_): return None
50-
result_type_tuple = True
5150
elif typing.get_origin(stub_sig.return_annotation) is tuple:
5251
ret_type_hints = typing.get_args(stub_sig.return_annotation)
5352
def make_retval(result_vars): return build_tuple(result_vars)
54-
result_type_tuple = True
5553
else:
5654
ret_type_hints = [stub_sig.return_annotation]
57-
def make_retval(result_var): return result_var
58-
result_type_tuple = False
55+
def make_retval(result_vars): return result_vars[0]
5956

6057
result_types = []
6158
for h in ret_type_hints:
@@ -64,9 +61,9 @@ def make_retval(result_var): return result_var
6461
shape = () if ann.vector_length is None else (ann.vector_length,)
6562
result_types.append(TileTy(ann.dtype, shape))
6663

67-
return make_retval(add_operation(
64+
return make_retval(add_operation_variadic(
6865
RawNVVMIntrinsic,
69-
tuple(result_types) if result_type_tuple else result_types[0],
66+
tuple(result_types),
7067
intrinsic="llvm.nvvm." + name,
7168
operands_=tuple(prepared_operands),
7269
))

0 commit comments

Comments
 (0)