@@ -100,6 +100,17 @@ def from_numpy(cls, dtype: np.dtype) -> DataType:
100100 return DataType .FLOAT4E2M1
101101 raise TypeError (f"Unsupported numpy data type: { dtype } " )
102102
103+ @classmethod
104+ def from_short_name (cls , short_name : str ) -> DataType :
105+ """Returns the ONNX data type for the short name.
106+
107+ Raises:
108+ TypeError: If the short name is not available for the data type.
109+ """
110+ if short_name not in _SHORT_NAME_TO_DATA_TYPE :
111+ raise TypeError (f"Unknown short name: { short_name } " )
112+ return cls (_SHORT_NAME_TO_DATA_TYPE [short_name ])
113+
103114 @property
104115 def itemsize (self ) -> float :
105116 """Returns the size of the data type in bytes."""
@@ -115,6 +126,22 @@ def numpy(self) -> np.dtype:
115126 raise TypeError (f"Numpy does not support ONNX data type: { self } " )
116127 return _DATA_TYPE_TO_NP_TYPE [self ]
117128
129+ def short_name (self ) -> str :
130+ """Returns the short name of the data type.
131+
132+ The short name is a string that is used to represent the data type in a more
133+ compact form. For example, the short name for `DataType.FLOAT` is "f32".
134+ To get the corresponding data type back, call ``from_short_name`` on a string.
135+
136+ Naming reference: https://github.com/pytorch/pytorch/blob/4bead7b85ea4160243c74109e0ce9bb80686d016/torch/utils/_dtype_abbrs.py
137+
138+ Raises:
139+ TypeError: If the short name is not available for the data type.
140+ """
141+ if self not in _DATA_TYPE_TO_SHORT_NAME :
142+ raise TypeError (f"Short name not available for ONNX data type: { self } " )
143+ return _DATA_TYPE_TO_SHORT_NAME [self ]
144+
118145 def __repr__ (self ) -> str :
119146 return self .name
120147
@@ -184,3 +211,32 @@ def __str__(self) -> str:
184211
185212# ONNX DataType to Numpy dtype.
186213_DATA_TYPE_TO_NP_TYPE = {v : k for k , v in _NP_TYPE_TO_DATA_TYPE .items ()}
214+
215+ _DATA_TYPE_TO_SHORT_NAME = {
216+ DataType .UNDEFINED : "undefined" ,
217+ DataType .BFLOAT16 : "bf16" ,
218+ DataType .DOUBLE : "f64" ,
219+ DataType .FLOAT : "f32" ,
220+ DataType .FLOAT16 : "f16" ,
221+ DataType .FLOAT8E4M3FN : "f8e4m3fn" ,
222+ DataType .FLOAT8E5M2 : "f8e5m2" ,
223+ DataType .FLOAT8E4M3FNUZ : "f8e4m3fnuz" ,
224+ DataType .FLOAT8E5M2FNUZ : "f8e5m2fnuz" ,
225+ DataType .FLOAT4E2M1 : "f4e2m1" ,
226+ DataType .COMPLEX64 : "c64" ,
227+ DataType .COMPLEX128 : "c128" ,
228+ DataType .INT4 : "i4" ,
229+ DataType .INT8 : "i8" ,
230+ DataType .INT16 : "i16" ,
231+ DataType .INT32 : "i32" ,
232+ DataType .INT64 : "i64" ,
233+ DataType .BOOL : "b8" ,
234+ DataType .UINT4 : "u4" ,
235+ DataType .UINT8 : "u8" ,
236+ DataType .UINT16 : "u16" ,
237+ DataType .UINT32 : "u32" ,
238+ DataType .UINT64 : "u64" ,
239+ DataType .STRING : "s" ,
240+ }
241+
242+ _SHORT_NAME_TO_DATA_TYPE = {v : k for k , v in _DATA_TYPE_TO_SHORT_NAME .items ()}
0 commit comments