6262 Unary , implicit_cast , address_space_cast , reinterpret_pointer , AddrSpaceCast ,
6363 ReinterpretPointer ,
6464)
65- from cuda .tile ._ir .ir import MemoryEffect , make_aggregate
65+ from cuda .tile ._ir .ir import MemoryEffect , make_aggregate , add_operation_variadic
6666from cuda .lang ._exception import TileCompilerError , TileTypeError
6767import cuda .lang ._datatype as datatype
6868from cuda .tile ._datatype import (
@@ -324,14 +324,13 @@ def array_getitem(object: Var, key: Var) -> Var:
324324 array_ty = require_array_type (object )
325325 indices = require_array_indices (object , key )
326326 pointer = _array_get_element_pointer (object , indices )
327- [ result ] = add_operation (
327+ return add_operation (
328328 LoadPointer ,
329- ( TileTy (array_ty .dtype ), ),
329+ TileTy (array_ty .dtype ),
330330 pointer = pointer ,
331331 alignment = None ,
332332 volatile = False ,
333333 )
334- return result
335334
336335
337336@impl (operator .setitem , overload = (ArrayTy , WILDCARD , WILDCARD ))
@@ -341,7 +340,7 @@ def array_setitem(object: Var, key: Var, value: Var):
341340 value = astype (value , array_ty .dtype )
342341 indices = require_array_indices (object , key )
343342 pointer = _array_get_element_pointer (object , indices )
344- add_operation (
343+ add_operation_variadic (
345344 StorePointer ,
346345 (),
347346 pointer = pointer ,
@@ -485,15 +484,14 @@ def _pointer_load(
485484 result_ty = TileTy (pointee_dtype )
486485 else :
487486 result_ty = make_vector_ty (pointee_dtype , count )
488- [ result ] = add_operation (
487+ return add_operation (
489488 LoadPointer ,
490- ( result_ty ,) ,
489+ result_ty ,
491490 pointer = pointer ,
492491 volatile = volatile ,
493492 alignment = alignment ,
494493 ordering = ordering ,
495494 )
496- return result
497495
498496
499497def _pointer_store (
@@ -514,7 +512,7 @@ def _pointer_store(
514512 value = implicit_cast (value , pointee_dtype ,
515513 "Stored value type is incompatible with pointer type" )
516514
517- add_operation (
515+ add_operation_variadic (
518516 StorePointer ,
519517 (),
520518 pointer = pointer ,
@@ -669,8 +667,8 @@ def _to_string_rhs(self) -> str:
669667 return f"{ self .op } ^{ self .target ._name } ({ ', ' .join (format_var (arg ) for arg in self .args )} )"
670668
671669
def branch(target: Block, args: tuple[Var, ...]) -> None:
    """Emit a ``Branch`` operation carrying ``target`` and ``args``.

    The operation is registered with an empty tuple of result types, so
    this helper has nothing to hand back to the caller.
    """
    no_result_types = ()
    add_operation_variadic(Branch, no_result_types, target=target, args=args)
674672
675673
676674@dataclass (eq = False )
@@ -699,8 +697,8 @@ def cond_branch(
699697 false_args : tuple [Var , ...],
700698 true_target : Block ,
701699 false_target : Block ,
702- ) -> CondBranch :
703- return add_operation (
700+ ) -> None :
701+ add_operation_variadic (
704702 CondBranch ,
705703 (),
706704 cond = cond ,
@@ -795,7 +793,7 @@ def enter_context_local_array_impl(manager: Var):
795793 )
796794
797795 def exit_callback ():
798- add_operation (DeallocLocalMemory , (), ptr = base_ptr )
796+ add_operation_variadic (DeallocLocalMemory , (), ptr = base_ptr )
799797
800798 mgr_ty .state .exit_callback = exit_callback
801799
@@ -912,19 +910,19 @@ class SyncThreads(Operation, opcode="syncthreads", memory_effect=MemoryEffect.ST
912910
913911
@impl(stub.syncthreads)
def syncthreads_impl() -> None:
    """Implementation of ``stub.syncthreads``: emit a ``SyncThreads`` operation.

    The operation is added with an empty tuple of result types and no
    attributes; nothing is returned.
    """
    no_result_types = ()
    add_operation_variadic(SyncThreads, no_result_types)
917915
918916
@impl(stub.elect_sync)
def elect_sync_impl(membermask) -> Var:
    """Implementation of ``stub.elect_sync``.

    ``membermask`` must be a compile-time constant integer. Its low 32 bits
    are passed as the single operand of the ``llvm.nvvm.elect.sync``
    intrinsic, which is declared with two results (``I32_TY``, ``BOOL_TY``).
    Only the second (boolean) result is returned.
    """
    mask_value = require_constant_int(membermask)
    # Keep only the low 32 bits before materializing the I32 constant.
    mask_const = strictly_typed_const(mask_value & 0xffffffff, I32_TY)

    intrinsic_results = add_operation_variadic(
        RawNVVMIntrinsic,
        (I32_TY, BOOL_TY),
        intrinsic="llvm.nvvm.elect.sync",
        operands_=(mask_const,),
    )
    # The first (I32) result is not used by this implementation.
    _unused, is_elected = intrinsic_results
    return is_elected
929927
930928
@@ -1110,12 +1108,12 @@ def rewrite(match: re.Match[str]) -> str:
11101108
11111109
11121110@impl (stub .inline_ptx )
1113- def inline_ptx_impl (ptx_code : Var , constraint_pairs : tuple ) -> tuple :
1111+ def inline_ptx_impl (ptx_code : Var , constraint_pairs : tuple ) -> Var [ TupleTy ] :
11141112 ptx_code = require_constant_str (ptx_code )
11151113 mlir_ptx_code , ro_args , rw_args , wo_args = require_inline_ptx_constraint_pairs (
11161114 ptx_code , constraint_pairs )
11171115 result_types = tuple (TileTy (dtype ) for dtype in wo_args )
1118- results = add_operation (
1116+ results = add_operation_variadic (
11191117 InlinePTX ,
11201118 result_types ,
11211119 ptx_code = mlir_ptx_code ,
@@ -1307,7 +1305,7 @@ def clusterlaunchcontrol_try_cancel_impl(addr: Var, mbar: Var, multicast: Var) -
13071305 intrinsic += ".multicast"
13081306 intrinsic += ".shared"
13091307
1310- add_operation (
1308+ add_operation_variadic (
13111309 RawNVVMIntrinsic ,
13121310 (),
13131311 intrinsic = intrinsic ,
@@ -1357,21 +1355,26 @@ class ForeignFunction(Operation, opcode="foreign_function", memory_effect=Memory
13571355
13581356
@impl(stub.foreign_function._call_foreign_function)
def _call_foreign_function_impl(func: Var, return_type: Var, parameters: Var):
    """Emit a ``ForeignFunction`` operation for a call to a named function.

    Args:
        func: Must resolve to a constant string; used as ``function_name``.
        return_type: Either the constant ``None`` (the call produces no
            result) or a value accepted by ``require_constant_result_dtype``.
        parameters: Must be tuple-typed; the items of its aggregate become
            the operation's operands.

    Returns:
        ``None`` when ``return_type`` is the constant ``None``; otherwise
        whatever ``add_operation`` returns for the single result type.
    """
    function_name = require_constant_str(func)
    require_tuple_type(parameters)
    call_operands = parameters.get_aggregate().items

    returns_nothing = return_type.is_constant() and return_type.get_constant() is None
    if returns_nothing:
        # No result: use the variadic form with an empty result-type tuple.
        add_operation_variadic(
            ForeignFunction,
            (),
            function_name=function_name,
            operands_=call_operands,
        )
        return None

    result_type = require_constant_result_dtype(return_type)
    return add_operation(
        ForeignFunction,
        result_type,
        function_name=function_name,
        operands_=call_operands,
    )
13751378
13761379
13771380def require_mbarrier_ptr (
@@ -1397,7 +1400,7 @@ def require_mbarrier_ptr(
13971400def mbarrier_init_impl (mbar : Var , participants : Var ) -> Var :
13981401 require_mbarrier_ptr (mbar )
13991402 participants = astype (participants , datatype .int32 )
1400- add_operation (
1403+ add_operation_variadic (
14011404 RawNVVMIntrinsic ,
14021405 tuple (),
14031406 intrinsic = "llvm.nvvm.mbarrier.init.shared" ,
@@ -1408,7 +1411,7 @@ def mbarrier_init_impl(mbar: Var, participants: Var) -> Var:
14081411@impl (stub .mbarrier_invalidate )
14091412def mbarrier_invalidate_impl (mbar : Var ) -> Var :
14101413 require_mbarrier_ptr (mbar )
1411- add_operation (
1414+ add_operation_variadic (
14121415 RawNVVMIntrinsic ,
14131416 tuple (),
14141417 intrinsic = "llvm.nvvm.mbarrier.inval.shared" ,
@@ -1470,7 +1473,7 @@ def mbarrier_arrive_impl(
14701473 intrinsic += _mbar_space_scope_suffix (scope , space )
14711474
14721475 return_type = (TileTy (datatype .uint64 ),) if space is MemorySpace .SHARED else ()
1473- results = add_operation (
1476+ results = add_operation_variadic (
14741477 RawNVVMIntrinsic ,
14751478 return_type ,
14761479 intrinsic = intrinsic ,
@@ -1501,7 +1504,7 @@ def mbarrier_arrive_expect_tx_impl(
15011504 intrinsic += _mbar_space_scope_suffix (scope , space )
15021505
15031506 return_type = (TileTy (datatype .uint64 ),) if space is MemorySpace .SHARED else ()
1504- results = add_operation (
1507+ results = add_operation_variadic (
15051508 RawNVVMIntrinsic ,
15061509 return_type ,
15071510 intrinsic = intrinsic ,
@@ -1511,13 +1514,13 @@ def mbarrier_arrive_expect_tx_impl(
15111514
15121515
15131516@impl (stub .mbarrier_expect_tx )
1514- def mbarrier_expect_tx_impl (mbar : Var , bytes : Var , scope : Var ) -> Var :
1517+ def mbarrier_expect_tx_impl (mbar : Var , bytes : Var , scope : Var ):
15151518 space = require_mbarrier_ptr (mbar ).memory_space
15161519 bytes = astype (bytes , datatype .int32 )
15171520 scope = require_constant_enum (scope , MbarrierScope )
15181521 intrinsic = "llvm.nvvm.mbarrier.expect.tx"
15191522 intrinsic += _mbar_space_scope_suffix (scope , space )
1520- add_operation (
1523+ add_operation_variadic (
15211524 RawNVVMIntrinsic ,
15221525 (),
15231526 intrinsic = intrinsic ,
@@ -1532,7 +1535,7 @@ def mbarrier_complete_tx_impl(mbar: Var, bytes: Var, scope: Var) -> Var:
15321535 scope = require_constant_enum (scope , MbarrierScope )
15331536 intrinsic = "llvm.nvvm.mbarrier.complete.tx"
15341537 intrinsic += _mbar_space_scope_suffix (scope , space )
1535- add_operation (
1538+ add_operation_variadic (
15361539 RawNVVMIntrinsic ,
15371540 (),
15381541 intrinsic = intrinsic ,
@@ -1552,13 +1555,12 @@ def mbarrier_test_wait_impl(
15521555 if ordering is MemoryOrder .RELAXED :
15531556 intrinsic += ".relaxed"
15541557 intrinsic += _mbar_space_scope_suffix (scope , MemorySpace .SHARED )
1555- results = add_operation (
1558+ return add_operation (
15561559 RawNVVMIntrinsic ,
1557- ( TileTy (datatype .bool_ ), ),
1560+ TileTy (datatype .bool_ ),
15581561 intrinsic = intrinsic ,
15591562 operands_ = (mbar , state ),
15601563 )
1561- return results [0 ]
15621564
15631565
15641566@impl (stub .mbarrier_test_wait_parity )
@@ -1573,13 +1575,12 @@ def mbarrier_test_wait_parity_impl(
15731575 if ordering is MemoryOrder .RELAXED :
15741576 intrinsic += ".relaxed"
15751577 intrinsic += _mbar_space_scope_suffix (scope , MemorySpace .SHARED )
1576- results = add_operation (
1578+ return add_operation (
15771579 RawNVVMIntrinsic ,
1578- ( TileTy (datatype .bool_ ), ),
1580+ TileTy (datatype .bool_ ),
15791581 intrinsic = intrinsic ,
15801582 operands_ = (mbar , parity ),
15811583 )
1582- return results [0 ]
15831584
15841585
15851586def _is_none (var : Var ):
@@ -1607,13 +1608,12 @@ def mbarrier_try_wait_impl(
16071608 if ordering is MemoryOrder .RELAXED :
16081609 intrinsic += ".relaxed"
16091610 intrinsic += _mbar_space_scope_suffix (scope , MemorySpace .SHARED )
1610- results = add_operation (
1611+ return add_operation (
16111612 RawNVVMIntrinsic ,
1612- ( TileTy (datatype .bool_ ), ),
1613+ TileTy (datatype .bool_ ),
16131614 intrinsic = intrinsic ,
16141615 operands_ = args ,
16151616 )
1616- return results [0 ]
16171617
16181618
16191619@impl (stub .mbarrier_try_wait_parity )
@@ -1637,13 +1637,12 @@ def mbarrier_try_wait_parity_impl(
16371637 if ordering is MemoryOrder .RELAXED :
16381638 intrinsic += ".relaxed"
16391639 intrinsic += _mbar_space_scope_suffix (scope , MemorySpace .SHARED )
1640- results = add_operation (
1640+ return add_operation (
16411641 RawNVVMIntrinsic ,
1642- ( TileTy (datatype .bool_ ), ),
1642+ TileTy (datatype .bool_ ),
16431643 intrinsic = intrinsic ,
16441644 operands_ = args ,
16451645 )
1646- return results [0 ]
16471646
16481647
16491648@impl (stub .map_shared_to_cluster )
@@ -1660,13 +1659,12 @@ def map_shared_to_cluster_impl(ptr: Var, rank: Var):
16601659 else :
16611660 result_scalar_ty = pointer_dtype (info .pointee_dtype , MemorySpace .SHARED_CLUSTER )
16621661 result_ty = TileTy (result_scalar_ty )
1663- results = add_operation (
1662+ return add_operation (
16641663 RawNVVMIntrinsic ,
1665- ( result_ty ,) ,
1664+ result_ty ,
16661665 intrinsic = "llvm.nvvm.mapa.shared.cluster" ,
16671666 operands_ = (ptr , rank ),
16681667 )
1669- return results [0 ]
16701668
16711669
16721670__all__ = (
0 commit comments