Skip to content

Commit b63ba43

Browse files
authored
[IR] Introduce short name for dtypes (#2249)
Introduce short name for dtypes as a more compact way of describing the data types in strings. Users can already access the enums by name with e.g. `ir.DataType["DOUBLE"]`.
1 parent e55a1c6 commit b63ba43

2 files changed

Lines changed: 87 additions & 0 deletions

File tree

onnxscript/ir/_enums.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,17 @@ def from_numpy(cls, dtype: np.dtype) -> DataType:
100100
return DataType.FLOAT4E2M1
101101
raise TypeError(f"Unsupported numpy data type: {dtype}")
102102

103+
@classmethod
def from_short_name(cls, short_name: str) -> DataType:
    """Look up the ONNX data type that a short name stands for.

    This is the inverse of ``short_name()``: the string is resolved through
    the module-level ``_SHORT_NAME_TO_DATA_TYPE`` table.

    Args:
        short_name: Compact string form of a data type, e.g. ``"f32"``.

    Returns:
        The data type registered under ``short_name``.

    Raises:
        TypeError: If ``short_name`` is not a known short name.
    """
    # Single lookup; the table never stores None, so None means "absent".
    dtype = _SHORT_NAME_TO_DATA_TYPE.get(short_name)
    if dtype is None:
        raise TypeError(f"Unknown short name: {short_name}")
    return cls(dtype)
113+
103114
@property
104115
def itemsize(self) -> float:
105116
"""Returns the size of the data type in bytes."""
@@ -115,6 +126,22 @@ def numpy(self) -> np.dtype:
115126
raise TypeError(f"Numpy does not support ONNX data type: {self}")
116127
return _DATA_TYPE_TO_NP_TYPE[self]
117128

129+
def short_name(self) -> str:
    """Return the compact string abbreviation of this data type.

    Short names give a terser spelling of a data type for use in strings;
    for instance ``DataType.FLOAT`` is abbreviated as "f32". Pass the
    result to ``from_short_name`` to recover the data type.

    Naming reference: https://github.com/pytorch/pytorch/blob/4bead7b85ea4160243c74109e0ce9bb80686d016/torch/utils/_dtype_abbrs.py

    Returns:
        The abbreviation registered for this data type.

    Raises:
        TypeError: If no short name is registered for this data type.
    """
    # Single lookup; the table never stores None, so None means "absent".
    abbreviation = _DATA_TYPE_TO_SHORT_NAME.get(self)
    if abbreviation is None:
        raise TypeError(f"Short name not available for ONNX data type: {self}")
    return abbreviation
144+
118145
def __repr__(self) -> str:
119146
return self.name
120147

@@ -184,3 +211,32 @@ def __str__(self) -> str:
184211

185212
# ONNX DataType to Numpy dtype.
186213
_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
214+
215+
# ONNX DataType to its compact string abbreviation (see DataType.short_name).
_DATA_TYPE_TO_SHORT_NAME = {
    DataType.UNDEFINED: "undefined",
    # Floating point
    DataType.BFLOAT16: "bf16",
    DataType.DOUBLE: "f64",
    DataType.FLOAT: "f32",
    DataType.FLOAT16: "f16",
    DataType.FLOAT8E4M3FN: "f8e4m3fn",
    DataType.FLOAT8E5M2: "f8e5m2",
    DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
    DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
    DataType.FLOAT4E2M1: "f4e2m1",
    # Complex
    DataType.COMPLEX64: "c64",
    DataType.COMPLEX128: "c128",
    # Signed integers
    DataType.INT4: "i4",
    DataType.INT8: "i8",
    DataType.INT16: "i16",
    DataType.INT32: "i32",
    DataType.INT64: "i64",
    # Boolean
    DataType.BOOL: "b8",
    # Unsigned integers
    DataType.UINT4: "u4",
    DataType.UINT8: "u8",
    DataType.UINT16: "u16",
    DataType.UINT32: "u32",
    DataType.UINT64: "u64",
    # String
    DataType.STRING: "s",
}

# Abbreviation back to ONNX DataType (see DataType.from_short_name).
_SHORT_NAME_TO_DATA_TYPE = {
    abbreviation: dtype for dtype, abbreviation in _DATA_TYPE_TO_SHORT_NAME.items()
}

onnxscript/ir/_enums_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,37 @@ def test_repr_and_str_return_name(self):
122122
self.assertEqual(str(_enums.DataType.DOUBLE), "DOUBLE")
123123
self.assertEqual(repr(_enums.DataType.DOUBLE), "DOUBLE")
124124

125+
def test_short_name_conversion(self):
    """Round-trip every DataType member through its short name.

    Each member runs in its own subTest so that a failure names the
    offending dtype and the remaining members are still checked.
    """
    for dtype in _enums.DataType:
        with self.subTest(dtype=dtype):
            short_name = dtype.short_name()
            self.assertEqual(_enums.DataType.from_short_name(short_name), dtype)
129+
130+
def test_access_by_name(self):
    """Check that ``DataType[name]`` yields the member with that name."""
    member_names = (
        "FLOAT",
        "UINT8",
        "INT8",
        "UINT16",
        "INT16",
        "INT32",
        "INT64",
        "STRING",
        "BOOL",
        "FLOAT16",
        "DOUBLE",
        "UINT32",
        "UINT64",
        "COMPLEX64",
        "COMPLEX128",
        "BFLOAT16",
        "FLOAT8E4M3FN",
        "FLOAT8E4M3FNUZ",
        "FLOAT8E5M2",
        "FLOAT8E5M2FNUZ",
        "UINT4",
        "INT4",
        "FLOAT4E2M1",
        "UNDEFINED",
    )
    for member_name in member_names:
        # Subscript lookup must agree with plain attribute access.
        self.assertEqual(
            _enums.DataType[member_name], getattr(_enums.DataType, member_name)
        )
155+
125156

126157
class AttributeTypeTest(unittest.TestCase):
127158
def test_enums_are_the_same_as_spec(self):

0 commit comments

Comments
 (0)