Skip to content

Commit 6fd6ccc

Browse files
committed
Update bytecode type.py
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent 24deef9 commit 6fd6ccc

2 files changed

Lines changed: 104 additions & 55 deletions

File tree

src/cuda/tile/_bytecode/type.py

Lines changed: 60 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -3,33 +3,14 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import enum
6-
from dataclasses import dataclass
76
from typing import Sequence
87

9-
from .basic import encode_varint, encode_int_list, Table
8+
from .basic import encode_varint, encode_int_list
9+
from .type_base import TypeId, _TypeTableBase, encode_sized_typeid_seq, PaddingValue
10+
from .type_base import encode_typeid as encode_typeid # noqa: F401
1011
from .version import BytecodeVersion
1112

1213

13-
@dataclass(frozen=True)
14-
class TypeId:
15-
type_id: int
16-
17-
18-
def encode_typeid(type_id: TypeId, buf: bytearray):
19-
encode_varint(type_id.type_id, buf)
20-
21-
22-
def encode_sized_typeid_seq(type_ids: Sequence[TypeId], buf: bytearray):
23-
encode_varint(len(type_ids), buf)
24-
for i in type_ids:
25-
encode_varint(i.type_id, buf)
26-
27-
28-
# For simplicity, we always add these to the type table
29-
I1_TYPE_ID = TypeId(0)
30-
I32_TYPE_ID = TypeId(1)
31-
32-
3314
class SimpleType(enum.Enum):
3415
I1 = b"\x00"
3516
I8 = b"\x01"
@@ -46,28 +27,25 @@ class SimpleType(enum.Enum):
4627
Token = b"\x11"
4728
F8E8M0FNU = b"\x12" # since 13.2
4829
F4E2M1FN = b"\x13" # since 13.3
30+
I4 = b"\x16" # since 13.3
4931

5032

5133
class _CompositeType(enum.Enum):
5234
Pointer = b"\x0c"
5335
Tile = b"\x0d"
5436
TensorView = b"\x0e"
5537
PartitionView = b"\x0f"
56-
Func = b"\x10"
57-
38+
Function = b"\x10"
39+
GatherScatterView = b"\x14" # since 13.3
40+
StridedView = b"\x15" # since 13.3
5841

59-
class PaddingValue(enum.Enum):
60-
Missing = b""
61-
Zero = b"\x00"
62-
NegZero = b"\x01"
63-
Nan = b"\x02"
64-
PosInf = b"\x03"
65-
NegInf = b"\x04"
6642

43+
# Predefined type IDs
44+
I1_TYPE_ID = TypeId(0)
45+
I32_TYPE_ID = TypeId(1)
6746

68-
class TypeTable(Table[bytes, TypeId]):
69-
_wrapper_type = TypeId
7047

48+
class TypeTable(_TypeTableBase):
7149
def __init__(self, version: BytecodeVersion):
7250
super().__init__()
7351
self.version = version
@@ -97,23 +75,23 @@ def F32(self) -> TypeId:
9775
def Token(self) -> TypeId:
9876
return self.simple(SimpleType.Token)
9977

100-
def tile(self, dtype: TypeId, shape: Sequence[int]) -> TypeId:
101-
buf = bytearray(_CompositeType.Tile._value_)
102-
encode_varint(dtype.type_id, buf)
103-
encode_int_list(shape, 8, buf)
78+
def pointer(self, pointeeType: TypeId) -> TypeId:
79+
buf = bytearray(_CompositeType.Pointer._value_)
80+
encode_varint(pointeeType.type_id, buf)
10481
return self[bytes(buf)]
10582

106-
def pointer(self, pointee: TypeId) -> TypeId:
107-
buf = bytearray(_CompositeType.Pointer._value_)
108-
encode_varint(pointee.type_id, buf)
83+
def tile(self, elementType: TypeId, shape: Sequence[int]) -> TypeId:
84+
buf = bytearray(_CompositeType.Tile._value_)
85+
encode_varint(elementType.type_id, buf)
86+
encode_int_list(shape, 8, buf)
10987
return self[bytes(buf)]
11088

11189
def tensor_view(self,
112-
dtype: TypeId,
90+
elementType: TypeId,
11391
shape: Sequence[int],
11492
strides: Sequence[int]) -> TypeId:
11593
buf = bytearray(_CompositeType.TensorView._value_)
116-
encode_varint(dtype.type_id, buf)
94+
encode_varint(elementType.type_id, buf)
11795
encode_int_list(shape, 8, buf)
11896
encode_int_list(strides, 8, buf)
11997
return self[bytes(buf)]
@@ -124,32 +102,59 @@ def partition_view(self,
124102
dim_map: Sequence[int],
125103
padding_value: PaddingValue) -> TypeId:
126104
buf = bytearray(_CompositeType.PartitionView._value_)
127-
if self.version >= BytecodeVersion.V_13_3:
128-
# Unified bitfield: encode optional flags before parameters
105+
use_unified_bitfield = self.version >= BytecodeVersion.V_13_3
106+
if use_unified_bitfield:
129107
optional_flags = 0
130108
if padding_value != PaddingValue.Missing:
131109
optional_flags |= (1 << 0)
132110
encode_varint(optional_flags, buf)
133-
134111
encode_int_list(tile_shape, 4, buf)
135112
encode_varint(tensor_view.type_id, buf)
136113
encode_int_list(dim_map, 4, buf)
137-
if self.version >= BytecodeVersion.V_13_3:
138-
buf.extend(padding_value._value_)
114+
if use_unified_bitfield:
115+
if padding_value != PaddingValue.Missing:
116+
buf.extend(padding_value._value_)
139117
else:
140118
encode_varint(int(padding_value != PaddingValue.Missing), buf)
141119
buf.extend(padding_value._value_)
142120
return self[bytes(buf)]
143121

122+
def gather_scatter_view(self,
                        tile_shape: Sequence[int],
                        tensor_view: TypeId,
                        sparse_dim: int,
                        padding_value: PaddingValue) -> TypeId:
    """Build the encoded form of a gather/scatter view type and look it up.

    The encoding is, in order: the ``GatherScatterView`` tag byte, a varint
    of optional flags, the tile shape, the tensor view's type id, the
    sparse dimension, and finally the padding value's byte (if any).

    Args:
        tile_shape: Tile shape, written with ``encode_int_list``.
        tensor_view: Type id of the underlying tensor view.
        sparse_dim: Written as a varint after the tensor view id.
        padding_value: Optional padding; ``PaddingValue.Missing`` means
            no padding is encoded.

    Returns:
        The result of indexing this table with the encoded bytes.
    """
    # NOTE(review): the GatherScatterView tag is marked "since 13.3" but,
    # unlike partition_view, nothing here checks self.version -- confirm
    # that callers gate on the bytecode version.
    buf = bytearray(_CompositeType.GatherScatterView._value_)

    # Bit 0 of the flags records whether a padding value follows.
    optional_flags = 0
    if padding_value != PaddingValue.Missing:
        optional_flags |= (1 << 0)
    encode_varint(optional_flags, buf)

    encode_int_list(tile_shape, 4, buf)
    encode_varint(tensor_view.type_id, buf)
    encode_varint(sparse_dim, buf)

    # PaddingValue.Missing carries an empty byte string, so this appends
    # nothing when no padding was requested.
    buf.extend(padding_value._value_)
    return self[bytes(buf)]
137+
138+
def strided_view(self,
                 tile_shape: Sequence[int],
                 traversal_strides: Sequence[int],
                 tensor_view: TypeId,
                 dim_map: Sequence[int],
                 padding_value: PaddingValue) -> TypeId:
    """Build the encoded form of a strided view type and look it up.

    The encoding is, in order: the ``StridedView`` tag byte, a varint of
    optional flags, the tile shape, the traversal strides, the tensor
    view's type id, the dimension map, and finally the padding value's
    byte (if any).

    Args:
        tile_shape: Tile shape, written with ``encode_int_list``.
        traversal_strides: Written with ``encode_int_list`` right after
            the tile shape.
        tensor_view: Type id of the underlying tensor view.
        dim_map: Written with ``encode_int_list`` after the tensor view id.
        padding_value: Optional padding; ``PaddingValue.Missing`` means
            no padding is encoded.

    Returns:
        The result of indexing this table with the encoded bytes.
    """
    # NOTE(review): the StridedView tag is marked "since 13.3" but, unlike
    # partition_view, nothing here checks self.version -- confirm that
    # callers gate on the bytecode version.
    buf = bytearray(_CompositeType.StridedView._value_)

    # Bit 0 of the flags records whether a padding value follows.
    optional_flags = 0
    if padding_value != PaddingValue.Missing:
        optional_flags |= (1 << 0)
    encode_varint(optional_flags, buf)

    encode_int_list(tile_shape, 4, buf)
    encode_int_list(traversal_strides, 4, buf)
    encode_varint(tensor_view.type_id, buf)
    encode_int_list(dim_map, 4, buf)

    # PaddingValue.Missing carries an empty byte string, so this appends
    # nothing when no padding was requested.
    buf.extend(padding_value._value_)
    return self[bytes(buf)]
155+
144156
def function(self, parameter_types: Sequence[TypeId], result_types: Sequence[TypeId]) -> TypeId:
145-
buf = bytearray(_CompositeType.Func._value_)
157+
buf = bytearray(_CompositeType.Function._value_)
146158
encode_sized_typeid_seq(parameter_types, buf)
147159
encode_sized_typeid_seq(result_types, buf)
148160
return self[bytes(buf)]
149-
150-
def _predefine(self, tag: bytes, expected_id: TypeId):
151-
if self[tag].type_id != expected_id.type_id:
152-
raise RuntimeError("Wrong type registration order")
153-
154-
def _unwrap_id(self, id: TypeId) -> int:
155-
return id.type_id
src/cuda/tile/_bytecode/type_base.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,44 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import enum
6+
from dataclasses import dataclass
7+
from typing import Sequence
8+
9+
from .basic import Table, encode_varint
10+
11+
12+
@dataclass(frozen=True)
class TypeId:
    """Immutable wrapper around the integer id of a type.

    Being a frozen dataclass, instances compare by value and are hashable.
    """

    # Integer id; this is the value the encode_* helpers write as a varint.
    type_id: int
15+
16+
17+
class PaddingValue(enum.Enum):
    """Padding choices; each member's value is the bytes that encode it.

    ``Missing`` has an empty byte string, so extending a buffer with its
    value appends nothing.
    """

    Missing = b""
    Zero = b"\x00"
    NegZero = b"\x01"
    Nan = b"\x02"
    PosInf = b"\x03"
    NegInf = b"\x04"
24+
25+
26+
def encode_typeid(type_id: TypeId, buf: bytearray):
    """Append the integer value of ``type_id`` to ``buf`` as a varint."""
    encode_varint(type_id.type_id, buf)
28+
29+
30+
def encode_sized_typeid_seq(type_ids: Sequence[TypeId], buf: bytearray):
    """Append a length-prefixed sequence of type ids to ``buf``.

    The element count is written first as a varint, followed by each
    id's integer value as a varint, in order.
    """
    encode_varint(len(type_ids), buf)
    for type_id in type_ids:
        encode_varint(type_id.type_id, buf)
34+
35+
36+
class _TypeTableBase(Table[bytes, TypeId]):
    """Base for type tables: a ``Table`` from encoded bytes to ``TypeId``."""

    # Presumably the class ``Table`` uses to wrap raw ids -- confirm
    # against ``Table`` in ``.basic``.
    _wrapper_type = TypeId

    def _predefine(self, tag: bytes, expected_id: TypeId):
        """Check that looking up ``tag`` yields ``expected_id``.

        Raises:
            RuntimeError: If the id obtained for ``tag`` differs from
                ``expected_id``.
        """
        # Presumably ``self[tag]`` registers the tag when it is new, which
        # would make the resulting id depend on registration order -- see
        # ``Table`` to confirm.
        if self[tag].type_id != expected_id.type_id:
            raise RuntimeError("Wrong type registration order")

    # The parameter is named ``id`` (shadowing the builtin) in the original;
    # it is kept so keyword callers and the base-class contract still match.
    def _unwrap_id(self, id: TypeId) -> int:
        """Return the plain integer stored in ``id``."""
        return id.type_id

0 commit comments

Comments
 (0)