33# SPDX-License-Identifier: Apache-2.0
44
55import enum
6- from dataclasses import dataclass
76from typing import Sequence
87
9- from .basic import encode_varint , encode_int_list , Table
8+ from .basic import encode_varint , encode_int_list
9+ from .type_base import TypeId , _TypeTableBase , encode_sized_typeid_seq , PaddingValue
10+ from .type_base import encode_typeid as encode_typeid # noqa: F401
1011from .version import BytecodeVersion
1112
1213
13- @dataclass (frozen = True )
14- class TypeId :
15- type_id : int
16-
17-
18- def encode_typeid (type_id : TypeId , buf : bytearray ):
19- encode_varint (type_id .type_id , buf )
20-
21-
22- def encode_sized_typeid_seq (type_ids : Sequence [TypeId ], buf : bytearray ):
23- encode_varint (len (type_ids ), buf )
24- for i in type_ids :
25- encode_varint (i .type_id , buf )
26-
27-
28- # For simplicity, we always add these to the type table
29- I1_TYPE_ID = TypeId (0 )
30- I32_TYPE_ID = TypeId (1 )
31-
32-
3314class SimpleType (enum .Enum ):
3415 I1 = b"\x00 "
3516 I8 = b"\x01 "
@@ -46,28 +27,25 @@ class SimpleType(enum.Enum):
4627 Token = b"\x11 "
4728 F8E8M0FNU = b"\x12 " # since 13.2
4829 F4E2M1FN = b"\x13 " # since 13.3
30+ I4 = b"\x16 " # since 13.3
4931
5032
5133class _CompositeType (enum .Enum ):
5234 Pointer = b"\x0c "
5335 Tile = b"\x0d "
5436 TensorView = b"\x0e "
5537 PartitionView = b"\x0f "
56- Func = b"\x10 "
57-
38+ Function = b"\x10 "
39+ GatherScatterView = b"\x14 " # since 13.3
40+ StridedView = b"\x15 " # since 13.3
5841
59- class PaddingValue (enum .Enum ):
60- Missing = b""
61- Zero = b"\x00 "
62- NegZero = b"\x01 "
63- Nan = b"\x02 "
64- PosInf = b"\x03 "
65- NegInf = b"\x04 "
6642
43+ # Predefined type IDs
44+ I1_TYPE_ID = TypeId (0 )
45+ I32_TYPE_ID = TypeId (1 )
6746
68- class TypeTable (Table [bytes , TypeId ]):
69- _wrapper_type = TypeId
7047
48+ class TypeTable (_TypeTableBase ):
7149 def __init__ (self , version : BytecodeVersion ):
7250 super ().__init__ ()
7351 self .version = version
@@ -97,23 +75,23 @@ def F32(self) -> TypeId:
9775 def Token (self ) -> TypeId :
9876 return self .simple (SimpleType .Token )
9977
100- def tile (self , dtype : TypeId , shape : Sequence [int ]) -> TypeId :
101- buf = bytearray (_CompositeType .Tile ._value_ )
102- encode_varint (dtype .type_id , buf )
103- encode_int_list (shape , 8 , buf )
78+ def pointer (self , pointeeType : TypeId ) -> TypeId :
79+ buf = bytearray (_CompositeType .Pointer ._value_ )
80+ encode_varint (pointeeType .type_id , buf )
10481 return self [bytes (buf )]
10582
106- def pointer (self , pointee : TypeId ) -> TypeId :
107- buf = bytearray (_CompositeType .Pointer ._value_ )
108- encode_varint (pointee .type_id , buf )
83+ def tile (self , elementType : TypeId , shape : Sequence [int ]) -> TypeId :
84+ buf = bytearray (_CompositeType .Tile ._value_ )
85+ encode_varint (elementType .type_id , buf )
86+ encode_int_list (shape , 8 , buf )
10987 return self [bytes (buf )]
11088
11189 def tensor_view (self ,
112- dtype : TypeId ,
90+ elementType : TypeId ,
11391 shape : Sequence [int ],
11492 strides : Sequence [int ]) -> TypeId :
11593 buf = bytearray (_CompositeType .TensorView ._value_ )
116- encode_varint (dtype .type_id , buf )
94+ encode_varint (elementType .type_id , buf )
11795 encode_int_list (shape , 8 , buf )
11896 encode_int_list (strides , 8 , buf )
11997 return self [bytes (buf )]
@@ -124,32 +102,59 @@ def partition_view(self,
124102 dim_map : Sequence [int ],
125103 padding_value : PaddingValue ) -> TypeId :
126104 buf = bytearray (_CompositeType .PartitionView ._value_ )
127- if self .version >= BytecodeVersion .V_13_3 :
128- # Unified bitfield: encode optional flags before parameters
105+ use_unified_bitfield = self .version >= BytecodeVersion .V_13_3
106+ if use_unified_bitfield :
129107 optional_flags = 0
130108 if padding_value != PaddingValue .Missing :
131109 optional_flags |= (1 << 0 )
132110 encode_varint (optional_flags , buf )
133-
134111 encode_int_list (tile_shape , 4 , buf )
135112 encode_varint (tensor_view .type_id , buf )
136113 encode_int_list (dim_map , 4 , buf )
137- if self .version >= BytecodeVersion .V_13_3 :
138- buf .extend (padding_value ._value_ )
114+ if use_unified_bitfield :
115+ if padding_value != PaddingValue .Missing :
116+ buf .extend (padding_value ._value_ )
139117 else :
140118 encode_varint (int (padding_value != PaddingValue .Missing ), buf )
141119 buf .extend (padding_value ._value_ )
142120 return self [bytes (buf )]
143121
122+ def gather_scatter_view (self ,
123+ tile_shape : Sequence [int ],
124+ tensor_view : TypeId ,
125+ sparse_dim : int ,
126+ padding_value : PaddingValue ) -> TypeId :
127+ buf = bytearray (_CompositeType .GatherScatterView ._value_ )
128+ optional_flags = 0
129+ if padding_value != PaddingValue .Missing :
130+ optional_flags |= (1 << 0 )
131+ encode_varint (optional_flags , buf )
132+ encode_int_list (tile_shape , 4 , buf )
133+ encode_varint (tensor_view .type_id , buf )
134+ encode_varint (sparse_dim , buf )
135+ buf .extend (padding_value ._value_ )
136+ return self [bytes (buf )]
137+
138+ def strided_view (self ,
139+ tile_shape : Sequence [int ],
140+ traversal_strides : Sequence [int ],
141+ tensor_view : TypeId ,
142+ dim_map : Sequence [int ],
143+ padding_value : PaddingValue ) -> TypeId :
144+ buf = bytearray (_CompositeType .StridedView ._value_ )
145+ optional_flags = 0
146+ if padding_value != PaddingValue .Missing :
147+ optional_flags |= (1 << 0 )
148+ encode_varint (optional_flags , buf )
149+ encode_int_list (tile_shape , 4 , buf )
150+ encode_int_list (traversal_strides , 4 , buf )
151+ encode_varint (tensor_view .type_id , buf )
152+ encode_int_list (dim_map , 4 , buf )
153+ buf .extend (padding_value ._value_ )
154+ return self [bytes (buf )]
155+
144156 def function (self , parameter_types : Sequence [TypeId ], result_types : Sequence [TypeId ]) -> TypeId :
145- buf = bytearray (_CompositeType .Func ._value_ )
157+ buf = bytearray (_CompositeType .Function ._value_ )
146158 encode_sized_typeid_seq (parameter_types , buf )
147159 encode_sized_typeid_seq (result_types , buf )
148160 return self [bytes (buf )]
149-
150- def _predefine (self , tag : bytes , expected_id : TypeId ):
151- if self [tag ].type_id != expected_id .type_id :
152- raise RuntimeError ("Wrong type registration order" )
153-
154- def _unwrap_id (self , id : TypeId ) -> int :
155- return id .type_id
0 commit comments