@@ -63,6 +63,18 @@ class MemoryScope(enum.Enum):
6363 SYS = b"\x02 "
6464
6565
66+ class ProgramIDDim (enum .Enum ):
67+ X = b"\x00 "
68+ Y = b"\x01 "
69+ Z = b"\x02 "
70+
71+
72+ class PtrAttr (enum .Enum ):
73+ NONE = b"\x00 "
74+ UNICAST = b"\x01 "
75+ MULTICAST = b"\x02 "
76+
77+
6678class RoundingMode (enum .Enum ):
6779 NEAREST_EVEN = b"\x00 "
6880 ZERO = b"\x01 "
@@ -313,8 +325,8 @@ def encode_AtomicRedViewTkoOp( # since 13.3
313325 code_builder : CodeBuilder ,
314326 result_token_type : TypeId , # since 13.3
315327 view : Value , # since 13.3
328+ index : Sequence [Value ], # since 13.3
316329 value : Value , # since 13.3
317- mask : Optional [Value ], # since 13.3
318330 token : Optional [Value ], # since 13.3
319331 memory_ordering_semantics : MemoryOrderingSemantics , # since 13.3
320332 memory_scope : MemoryScope , # since 13.3
@@ -323,19 +335,18 @@ def encode_AtomicRedViewTkoOp( # since 13.3
323335 _buf = code_builder .buf
324336 # Opcode
325337 encode_varint (117 , _buf )
326- # Result types
327- encode_typeid ( result_token_type , _buf )
338+ # Variadic result types
339+ encode_sized_typeid_seq (( result_token_type ,) , _buf )
328340 # Flags
329- encode_varint ((mask is not None )
330- | ((token is not None ) << 1 ), _buf )
341+ encode_varint ((token is not None ), _buf )
331342 # Attributes
332343 code_builder .encode_opattr_enum (MemoryOrderingSemantics , memory_ordering_semantics )
333344 code_builder .encode_opattr_enum (MemoryScope , memory_scope )
334345 code_builder .encode_opattr_enum (AtomicRMWMode , mode )
335346 # Operands
336347 encode_operand (view , _buf )
348+ encode_sized_variadic_operands (index , _buf )
337349 encode_operand (value , _buf )
338- encode_optional_operand (mask , _buf )
339350 encode_optional_operand (token , _buf )
340351 return code_builder .new_op ()
341352
@@ -1242,12 +1253,18 @@ def encode_MmaFOp(
12421253 lhs : Value ,
12431254 rhs : Value ,
12441255 acc : Value ,
1256+ fast_acc : bool , # since 13.3
12451257) -> Value :
12461258 _buf = code_builder .buf
12471259 # Opcode
12481260 encode_varint (73 , _buf )
12491261 # Result types
12501262 encode_typeid (result_type , _buf )
1263+ # Flags
1264+ _flag_bits = bool (fast_acc )
1265+ assert _flag_bits < 1 or code_builder .version >= BytecodeVersion .V_13_3
1266+ if code_builder .version >= BytecodeVersion .V_13_3 :
1267+ encode_varint (_flag_bits , _buf )
12511268 # Operands
12521269 encode_operand (lhs , _buf )
12531270 encode_operand (rhs , _buf )
@@ -2024,6 +2041,8 @@ def encode_YieldOp(
20242041 'IntegerOverflow' ,
20252042 'MemoryOrderingSemantics' ,
20262043 'MemoryScope' ,
2044+ 'ProgramIDDim' ,
2045+ 'PtrAttr' ,
20272046 'RoundingMode' ,
20282047 'Signedness' ,
20292048 'SymbolVisibility' ,
0 commit comments