1717 "int8" , "int16" , "int32" , "int64" ,
1818 "float16" , "float32" , "float64" ,
1919 "bfloat16" , "tfloat32" , "float8_e4m3fn" , "float8_e5m2" ,
20+ "float8_e8m0fnu" , "float4_e2m1fn" ,
2021 "DType" , "NumericDType" , "ArithmeticDType" ,
2122 "NumericDTypeCategories" ]
2223
@@ -139,17 +140,25 @@ class ArithmeticDType(NumericDType):
139140 and 7 mantissa bits."""
140141
141142tfloat32 = NumericDType ("tfloat32" , 32 , float , bc .SimpleType .TF32 )
142- tfloat32 .__doc__ = """A 32-bit tensor floating-point |arithmetic dtype| with 1 sign \
143+ tfloat32 .__doc__ = """A 32-bit tensor floating-point |numeric dtype| with 1 sign \
143144 bit, 8 exponent bits, and 10 mantissa bits (19-bit representation stored in 32-bit container)."""
144145
145146float8_e4m3fn = NumericDType ("float8_e4m3fn" , 8 , float , bc .SimpleType .F8E4M3FN )
146- float8_e4m3fn .__doc__ = """A 8-bit floating-point |arithmetic dtype| with 1 sign bit, \
147+ float8_e4m3fn .__doc__ = """An 8-bit floating-point |numeric dtype| with 1 sign bit, \
147148 4 exponent bits, and 3 mantissa bits."""
148149
149150float8_e5m2 = NumericDType ("float8_e5m2" , 8 , float , bc .SimpleType .F8E5M2 )
150- float8_e5m2 .__doc__ = """A 8-bit floating-point |arithmetic dtype| with 1 sign bit, \
151+ float8_e5m2 .__doc__ = """An 8-bit floating-point |numeric dtype| with 1 sign bit, \
151152 5 exponent bits, and 2 mantissa bits."""
152153
154+ float8_e8m0fnu = NumericDType ("float8_e8m0fnu" , 8 , float , bc .SimpleType .F8E8M0FNU )
155+ float8_e8m0fnu .__doc__ = """An 8-bit floating-point |numeric dtype| with no sign bit, \
156+ 8 exponent bits, and 0 mantissa bits."""
157+
158+ float4_e2m1fn = NumericDType ("float4_e2m1fn" , 4 , float , bc .SimpleType .F4E2M1FN )
159+ float4_e2m1fn .__doc__ = """A 4-bit floating-point |numeric dtype| with 1 sign bit, \
160+ 2 exponent bits, and 1 mantissa bit."""
161+
153162
154163class DTypeEnum (IntEnum ):
155164 B1 = 0
@@ -168,6 +177,8 @@ class DTypeEnum(IntEnum):
168177 TF32 = 13
169178 F8E4M3FN = 14
170179 F8E5M2 = 15
180+ F8E8M0FNU = 16
181+ F4E2M1FN = 17
171182
172183
173184dtype_to_enum = {
@@ -187,6 +198,8 @@ class DTypeEnum(IntEnum):
187198 tfloat32 : DTypeEnum .TF32 ,
188199 float8_e4m3fn : DTypeEnum .F8E4M3FN ,
189200 float8_e5m2 : DTypeEnum .F8E5M2 ,
201+ float8_e8m0fnu : DTypeEnum .F8E8M0FNU ,
202+ float4_e2m1fn : DTypeEnum .F4E2M1FN ,
190203}
191204_enum_to_dtype = dict ((i , t ) for t , i in dtype_to_enum .items ())
192205
@@ -209,7 +222,7 @@ class NumericDTypeCategories:
209222 Boolean = [bool_ ]
210223 Integral = [uint8 , uint16 , uint32 , uint64 , int8 , int16 , int32 , int64 ]
211224 Float = [float16 , float32 , float64 , bfloat16 ]
212- RestrictedFloat = [tfloat32 , float8_e4m3fn , float8_e5m2 ]
225+ RestrictedFloat = [tfloat32 , float8_e4m3fn , float8_e5m2 , float8_e8m0fnu , float4_e2m1fn ]
213226
214227 @classmethod
215228 def get_category (cls , t : DType ) -> NumericDTypeCategory :
@@ -323,6 +336,8 @@ class _DTypePromotionImpl:
323336 tf = DTypeEnum .TF32
324337 f8e4m3fn = DTypeEnum .F8E4M3FN
325338 f8e5m2 = DTypeEnum .F8E5M2
339+ f8e8m0fnu = DTypeEnum .F8E8M0FNU
340+ f4e2m1fn = DTypeEnum .F4E2M1FN
326341 na = None
327342
328343 # Entries for restricted arithmetic dtypes will never be reached, but we need to keep them
@@ -337,23 +352,25 @@ class _DTypePromotionImpl:
337352 # Restricted floats requires explicit type cast
338353 # Float16 and BFloat 16 requires explicit type cast
339354 _common_dtype_table = [
340- # b1, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64, bf, tf, f8e4m3fn, f8e5m2
341- [b1 , u8 , u16 , u32 , u64 , i8 , i16 , i32 , i64 , f16 , f32 , f64 , bf , na , na , na ], # b1
342- [u8 , u8 , u16 , u32 , u64 , na , na , na , na , f16 , f32 , f64 , bf , na , na , na ], # u8
343- [u16 , u16 , u16 , u32 , u64 , na , na , na , na , f16 , f32 , f64 , bf , na , na , na ], # u16
344- [u32 , u32 , u32 , u32 , u64 , na , na , na , na , f16 , f32 , f64 , bf , na , na , na ], # u32
345- [u64 , u64 , u64 , u64 , u64 , na , na , na , na , f16 , f32 , f64 , bf , na , na , na ], # u64
346- [i8 , na , na , na , na , i8 , i16 , i32 , i64 , f16 , f32 , f64 , bf , na , na , na ], # i8
347- [i16 , na , na , na , na , i16 , i16 , i32 , i64 , f16 , f32 , f64 , bf , na , na , na ], # i16
348- [i32 , na , na , na , na , i32 , i32 , i32 , i64 , f16 , f32 , f64 , bf , na , na , na ], # i32
349- [i64 , na , na , na , na , i64 , i64 , i64 , i64 , f16 , f32 , f64 , bf , na , na , na ], # i64
350- [f16 , f16 , f16 , f16 , f16 , f16 , f16 , f16 , f16 , f16 , f32 , f64 , na , na , na , na ], # f16
351- [f32 , f32 , f32 , f32 , f32 , f32 , f32 , f32 , f32 , f32 , f32 , f64 , f32 , na , na , na ], # f32
352- [f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , na , na , na ], # f64
353- [bf , bf , bf , bf , bf , bf , bf , bf , bf , na , f32 , f64 , bf , na , na , na ], # bf
354- [na , na , na , na , na , na , na , na , na , na , na , na , na , tf , na , na ], # tf
355- [na , na , na , na , na , na , na , na , na , na , na , na , na , na , f8e4m3fn , na ], # f8e4m3fn # noqa
356- [na , na , na , na , na , na , na , na , na , na , na , na , na , na , na , f8e5m2 ], # f8e5m2 # noqa
355+ # b1, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64, bf, tf, f8e4m3fn, f8e5m2, f8e8m0fnu, f4e2m1fn # noqa
356+ [b1 , u8 , u16 , u32 , u64 , i8 , i16 , i32 , i64 , f16 , f32 , f64 , bf , na , na , na , na , na ], # b1 # noqa
357+ [u8 , u8 , u16 , u32 , u64 , na , na , na , na , f16 , f32 , f64 , bf , na , na , na , na , na ], # u8 # noqa
358+ [u16 , u16 , u16 , u32 , u64 , na , na , na , na , f16 , f32 , f64 , bf , na , na , na , na , na ], # u16 # noqa
359+ [u32 , u32 , u32 , u32 , u64 , na , na , na , na , f16 , f32 , f64 , bf , na , na , na , na , na ], # u32 # noqa
360+ [u64 , u64 , u64 , u64 , u64 , na , na , na , na , f16 , f32 , f64 , bf , na , na , na , na , na ], # u64 # noqa
361+ [i8 , na , na , na , na , i8 , i16 , i32 , i64 , f16 , f32 , f64 , bf , na , na , na , na , na ], # i8 # noqa
362+ [i16 , na , na , na , na , i16 , i16 , i32 , i64 , f16 , f32 , f64 , bf , na , na , na , na , na ], # i16 # noqa
363+ [i32 , na , na , na , na , i32 , i32 , i32 , i64 , f16 , f32 , f64 , bf , na , na , na , na , na ], # i32 # noqa
364+ [i64 , na , na , na , na , i64 , i64 , i64 , i64 , f16 , f32 , f64 , bf , na , na , na , na , na ], # i64 # noqa
365+ [f16 , f16 , f16 , f16 , f16 , f16 , f16 , f16 , f16 , f16 , f32 , f64 , na , na , na , na , na , na ], # f16 # noqa
366+ [f32 , f32 , f32 , f32 , f32 , f32 , f32 , f32 , f32 , f32 , f32 , f64 , f32 , na , na , na , na , na ], # f32 # noqa
367+ [f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , f64 , na , na , na , na , na ], # f64 # noqa
368+ [bf , bf , bf , bf , bf , bf , bf , bf , bf , na , f32 , f64 , bf , na , na , na , na , na ], # bf # noqa
369+ [na , na , na , na , na , na , na , na , na , na , na , na , na , tf , na , na , na , na ], # tf # noqa
370+ [na , na , na , na , na , na , na , na , na , na , na , na , na , na , f8e4m3fn , na , na , na ], # f8e4m3fn # noqa
371+ [na , na , na , na , na , na , na , na , na , na , na , na , na , na , na , f8e5m2 , na , na ], # f8e5m2 # noqa
372+ [na , na , na , na , na , na , na , na , na , na , na , na , na , na , na , na , f8e8m0fnu , na ], # f8e8m0fnu # noqa
373+ [na , na , na , na , na , na , na , na , na , na , na , na , na , na , na , na , na , f4e2m1fn ], # f4e2m1fn # noqa
357374 ]
358375
359376 @classmethod
@@ -423,14 +440,22 @@ def _resolve_mma_supported_dtype(x_dtype: DType,
423440
424441def _generate_rst_dtype_promotion_table () -> str :
425442 """Generate an RST table representation of the dtype promotion rules."""
426- return _generate_rst_table (_DTypePromotionImpl ._common_dtype_table )
443+ import cuda .tile
444+ # Skip dtypes not exposed in cuda.tile yet. Promomotion table is append only.
445+ n = sum (1 for dtype in _enum_to_dtype .values () if hasattr (cuda .tile , dtype .name ))
446+ table = _DTypePromotionImpl ._common_dtype_table
447+ return _generate_rst_table ([row [:n ] for row in table [:n ]])
427448
428449
429450def _generate_rst_numeric_dtypes () -> str :
430451 """Generate RST documentation for numeric datatypes."""
452+ import cuda .tile
431453 content = []
432454
433455 for dtype in numeric_dtypes :
456+ # Skip dtypes not exposed in cuda.tile yet
457+ if not hasattr (cuda .tile , dtype .name ):
458+ continue
434459 content .append (f".. autodata:: cuda.tile.{ dtype .name } " )
435460 content .append (" :annotation:" )
436461 content .append ("" ) # Empty line between types
0 commit comments