Skip to content

Commit 2ac37c2

Browse files
committed
Add support for 13.2 bytecode
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent b6c61c5 commit 2ac37c2

9 files changed

Lines changed: 179 additions & 58 deletions

File tree

src/cuda/tile/_bytecode/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
from .code_builder import CodeBuilder, Value # noqa: F401
1313
from .float import float_to_bits, float_bit_size # noqa: F401
1414
from .encodings import * # noqa: F401 F403
15+
from .version import BytecodeVersion # noqa: F401
1516

1617
DYNAMIC_SHAPE = -1 << 63 # INT64_MIN

src/cuda/tile/_bytecode/code_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .constant import ConstantTable
1212
from .debug_info import DebugAttrId
1313
from .type import TypeId, encode_typeid
14+
from .version import BytecodeVersion
1415

1516

1617
@dataclass
@@ -57,6 +58,7 @@ def done(self) -> Tuple[Value, ...]:
5758
@dataclass
5859
class CodeBuilder:
5960
buf: bytearray
61+
version: BytecodeVersion
6062
string_table: StringTable
6163
constant_table: ConstantTable
6264
debug_attr_per_op: List[DebugAttrId]

src/cuda/tile/_bytecode/encodings.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

@@ -12,6 +12,7 @@
1212
encode_unsized_variadic_operands, encode_sized_variadic_operands, encode_operand
1313
)
1414
from .type import encode_typeid, encode_sized_typeid_seq, TypeId
15+
from .version import BytecodeVersion
1516

1617

1718
class AtomicRMWMode(enum.Enum):
@@ -200,6 +201,23 @@ def encode_AssumeOp(
200201
return code_builder.new_op()
201202

202203

204+
def encode_Atan2Op(  # since 13.2
    code_builder: CodeBuilder,
    result_type: TypeId,  # since 13.2
    x: Value,  # since 13.2
    y: Value,  # since 13.2
) -> Value:
    """Append an Atan2Op (opcode 110) to the builder's buffer.

    Writes, in order: the opcode varint, the result type id, then the
    operands ``x`` and ``y``.

    Args:
        code_builder: Builder whose ``buf`` receives the encoded op.
        result_type: Type id of the op's single result.
        x: First operand.
        y: Second operand.

    Returns:
        The value returned by ``code_builder.new_op()`` for this op.

    Raises:
        AssertionError: If the builder targets a bytecode version older
            than 13.2, which does not have this op.
    """
    # The whole op is new in 13.2. Guard it the same way the since-13.2
    # fields of the other encoders in this file are guarded, so nothing is
    # written to the buffer for an older target version.
    assert code_builder.version >= BytecodeVersion.V_13_2
    _buf = code_builder.buf
    # Opcode
    encode_varint(110, _buf)
    # Result types
    encode_typeid(result_type, _buf)
    # Operands
    encode_operand(x, _buf)
    encode_operand(y, _buf)
    return code_builder.new_op()
219+
220+
203221
def encode_AtomicCASTkoOp(
204222
code_builder: CodeBuilder,
205223
result_type: TypeId,
@@ -676,12 +694,18 @@ def encode_ForOp(
676694
upperBound: Value,
677695
step: Value,
678696
initValues: Sequence[Value],
697+
unsignedCmp: bool, # since 13.2
679698
) -> NestedBlockBuilder:
680699
_buf = code_builder.buf
681700
# Opcode
682701
encode_varint(41, _buf)
683702
# Variadic result types
684703
encode_sized_typeid_seq(result_types, _buf)
704+
# Flags
705+
_flag_bits = bool(unsignedCmp)
706+
assert _flag_bits < 1 or code_builder.version >= BytecodeVersion.V_13_2
707+
if code_builder.version >= BytecodeVersion.V_13_2:
708+
encode_varint(_flag_bits, _buf)
685709
# Operands
686710
encode_varint(3 + len(initValues), _buf)
687711
encode_operand(lowerBound, _buf)
@@ -1243,12 +1267,18 @@ def encode_NegIOp(
12431267
code_builder: CodeBuilder,
12441268
result_type: TypeId,
12451269
source: Value,
1270+
overflow: IntegerOverflow, # since 13.2
12461271
) -> Value:
12471272
_buf = code_builder.buf
12481273
# Opcode
12491274
encode_varint(80, _buf)
12501275
# Result types
12511276
encode_typeid(result_type, _buf)
1277+
# Attributes
1278+
if code_builder.version >= BytecodeVersion.V_13_2:
1279+
code_builder.encode_opattr_enum(IntegerOverflow, overflow)
1280+
else:
1281+
assert overflow == IntegerOverflow.NONE
12521282
# Operands
12531283
encode_operand(source, _buf)
12541284
return code_builder.new_op()
@@ -1323,22 +1353,37 @@ def encode_PowOp(
13231353
return code_builder.new_op()
13241354

13251355

1326-
def encode_PrintOp(
1356+
def encode_PrintTkoOp(
    code_builder: CodeBuilder,
    result_token_type: Optional[TypeId],  # since 13.2
    args: Sequence[Value],
    token: Optional[Value],  # since 13.2
    str: str,
) -> Optional[Value]:
    """Append a PrintTkoOp (opcode 85) to the builder's buffer.

    Written in order: opcode varint, sized result-type sequence, a flags
    varint (13.2 and later only), the string attribute, the sized ``args``
    operands, and the optional ``token`` operand.

    For builder versions older than 13.2, ``result_token_type`` and
    ``token`` must both be ``None`` (asserted); the result-type sequence is
    then empty, no flags varint is written, and ``None`` is returned.

    Args:
        code_builder: Builder whose ``buf`` receives the encoded op.
        result_token_type: Type id of the result token (13.2+), else None.
        args: Operands to print.
        token: Optional input token (13.2+).
        str: String attribute of the op.

    Returns:
        The result token value on 13.2 and later, otherwise ``None``.
    """
    out = code_builder.buf
    is_13_2 = code_builder.version >= BytecodeVersion.V_13_2

    # Opcode
    encode_varint(85, out)

    # Variadic result types
    # NOTE(review): on the 13.2 path result_token_type is appended without a
    # None check -- confirm callers always pass a type id there.
    if is_13_2:
        result_types = [result_token_type]
        result_token_idx = 0
    else:
        assert result_token_type is None
        result_types = []
        result_token_idx = None
    encode_sized_typeid_seq(result_types, out)

    # Flags: a single bit recording whether the optional token is present.
    has_token = token is not None
    assert is_13_2 or not has_token
    if is_13_2:
        encode_varint(has_token, out)

    # Attributes
    code_builder.encode_opattr_str(str)

    # Operands
    encode_sized_variadic_operands(args, out)
    encode_optional_operand(token, out)

    new_values = code_builder.new_op(len(result_types))
    if result_token_idx is None:
        return None
    return new_values[result_token_idx]
13421387

13431388

13441389
def encode_PtrToIntOp(
@@ -1726,12 +1771,18 @@ def encode_TanHOp(
17261771
code_builder: CodeBuilder,
17271772
result_type: TypeId,
17281773
source: Value,
1774+
rounding_mode: RoundingMode, # since 13.2
17291775
) -> Value:
17301776
_buf = code_builder.buf
17311777
# Opcode
17321778
encode_varint(106, _buf)
17331779
# Result types
17341780
encode_typeid(result_type, _buf)
1781+
# Attributes
1782+
if code_builder.version >= BytecodeVersion.V_13_2:
1783+
code_builder.encode_opattr_enum(RoundingMode, rounding_mode)
1784+
else:
1785+
assert rounding_mode == RoundingMode.FULL
17351786
# Operands
17361787
encode_operand(source, _buf)
17371788
return code_builder.new_op()
@@ -1818,6 +1869,7 @@ def encode_YieldOp(
18181869
'encode_AndIOp',
18191870
'encode_AssertOp',
18201871
'encode_AssumeOp',
1872+
'encode_Atan2Op',
18211873
'encode_AtomicCASTkoOp',
18221874
'encode_AtomicRMWTkoOp',
18231875
'encode_BitcastOp',
@@ -1878,7 +1930,7 @@ def encode_YieldOp(
18781930
'encode_OrIOp',
18791931
'encode_PermuteOp',
18801932
'encode_PowOp',
1881-
'encode_PrintOp',
1933+
'encode_PrintTkoOp',
18821934
'encode_PtrToIntOp',
18831935
'encode_PtrToPtrOp',
18841936
'encode_ReduceOp',

src/cuda/tile/_bytecode/version.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import enum
6+
7+
8+
class BytecodeVersion(enum.IntEnum):
    """Supported bytecode versions, ordered by their integer value.

    Each member's value packs the version as
    ``major * 10000 + minor * 100 + tag``; the accessor methods below
    unpack the individual components.
    """

    V_13_1 = 130100
    V_13_2 = 130200

    def major(self) -> int:
        """Return the major component (value divided by 10000)."""
        return int(self) // 10_000

    def minor(self) -> int:
        """Return the minor component (the hundreds pair of the value)."""
        return int(self) // 100 % 100

    def tag(self) -> int:
        """Return the tag component (the last two decimal digits)."""
        return int(self) % 100

src/cuda/tile/_bytecode/writer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
from .constant import ConstantTable
1313
from .debug_info import DebugAttrId, DebugAttrTable
1414
from .type import TypeTable, TypeId, encode_typeid
15-
16-
17-
_BYTECODE_VERSION = (13, 1, 0)
15+
from .version import BytecodeVersion
1816

1917

2018
class FunctionBuilder(NamedTuple):
@@ -44,7 +42,7 @@ def is_defined(self, global_name: str):
4442

4543

4644
class BytecodeWriter:
47-
def __init__(self, buf: bytearray):
45+
def __init__(self, buf: bytearray, version: BytecodeVersion):
4846
self._num_functions = 0
4947
self.debug_info = []
5048
self._buf = buf
@@ -53,6 +51,7 @@ def __init__(self, buf: bytearray):
5351
self._constant_table = ConstantTable()
5452
self._type_table = TypeTable()
5553
self._global_section = GlobalSection(self._string_table, self._constant_table)
54+
self.version = version
5655

5756
@property
5857
def debug_attr_table(self) -> DebugAttrTable:
@@ -90,6 +89,7 @@ def function(self,
9089
make_entry_hints(hints).encode_tagged(self._string_table, self._buf)
9190

9291
builder = CodeBuilder(buf=bytearray(),
92+
version=self.version,
9393
string_table=self._string_table,
9494
constant_table=self._constant_table,
9595
debug_attr_per_op=self.debug_info[-1])
@@ -99,12 +99,13 @@ def function(self,
9999

100100

101101
@contextmanager
102-
def write_bytecode(num_functions: int, buf: bytearray) -> Iterator[BytecodeWriter]:
103-
_write_header(buf)
102+
def write_bytecode(num_functions: int, buf: bytearray,
103+
version: BytecodeVersion) -> Iterator[BytecodeWriter]:
104+
_write_header(buf, version)
104105

105106
with _section(_Section.Func, 8, buf) as section_buf:
106107
encode_varint(num_functions, section_buf)
107-
w = BytecodeWriter(section_buf)
108+
w = BytecodeWriter(section_buf, version)
108109
yield w
109110
assert w._num_functions == num_functions
110111

@@ -124,12 +125,11 @@ def write_bytecode(num_functions: int, buf: bytearray) -> Iterator[BytecodeWrite
124125
buf.append(_Section.EndOfBytecode._value_)
125126

126127

127-
def _write_header(buf: bytearray):
128+
def _write_header(buf: bytearray, version: BytecodeVersion):
    """Append the bytecode header (magic number, then version) to ``buf``.

    Layout written: the 8-byte magic number, one byte for
    ``version.major()``, one byte for ``version.minor()``, and
    ``version.tag()`` as two little-endian bytes.
    """
    buf.extend(b"\x7fTileIR\x00")  # magic number
    # Major and minor each occupy a single byte, in that order.
    for component in (version.major, version.minor):
        buf.append(component())
    tag_bytes = version.tag().to_bytes(2, "little")
    buf.extend(tag_bytes)
133133

134134

135135
class _Section(enum.IntEnum):

0 commit comments

Comments
 (0)