Skip to content

Commit f4f0a36

Browse files
committed
float4_e2m1fn and float8_e8m0fnu
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent d012284 commit f4f0a36

20 files changed

+537
-114
lines changed

cext/tile_kernel.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ static PyObject* g_default_tile_context;
4646
X(float16, 16, 1, kDLFloat) \
4747
X(float32, 32, 1, kDLFloat) \
4848
X(float64, 64, 1, kDLFloat) \
49+
X(bfloat16, 16, 1, kDLBfloat) \
4950
X(float8_e4m3fn, 8, 1, kDLFloat8_e4m3fn) \
50-
X(float8_e5m2, 8, 1, kDLFloat8_e5m2)
51+
X(float8_e5m2, 8, 1, kDLFloat8_e5m2) \
52+
X(float8_e8m0fnu, 8, 1, kDLFloat8_e8m0fnu)
5153

5254

5355
#define DECLARE_TORCH_DTYPE_GLOBAL(name, bitwidth, lanes, typecode) \
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Add `ct.float8_e8m0fnu` dtype (8-bit, unsigned, 8 exponent bits, 0 mantissa bits). A restricted float type.
5+
- Add `ct.float4_e2m1fn` dtype (4-bit, 1 sign bit, 2 exponent bits, 1 mantissa bit). A restricted float type.
6+
- Compiling `float8_e8m0fnu` or `float4_e2m1fn` operations for the SM80 or SM90 family raises `TileUnsupportedFeatureError`

src/cuda/tile/_bytecode/float.py

Lines changed: 86 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,60 +2,113 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from enum import Enum, auto
56
import math
67
import struct
78
from typing import NamedTuple
89

10+
from cuda.tile._exception import TileInternalError
11+
912
from .type import SimpleType
1013

1114

15+
class NonFiniteBehavior(Enum):
    """How a floating-point format represents (or rejects) infinities and NaNs.

    Used by ``_FloatSpec`` to select the non-finite encoding strategy in
    ``_convert_nonfinite``.
    """
    # Standard IEEE-754: all-ones exponent, payload distinguishes inf/NaN.
    IEEE = auto()
    # No infinities; the single NaN is the all-ones bit pattern (e.g. f8e4m3fn).
    NanOnlyAllOnes = auto()
    # Neither infinities nor NaN exist; non-finite inputs must clamp or fail.
    FiniteOnly = auto()
19+
20+
1221
class _FloatSpec(NamedTuple):
    """Static description of one floating-point encoding.

    Field order matters: ``float_to_bits`` unpacks a spec positionally via
    ``*spec`` into the ``_convert_*`` helpers.
    """
    bitwidth: int    # total number of encoded bits
    emin: int        # smallest unbiased exponent of a normal value
    emax: int        # largest unbiased exponent
    exp_bits: int    # width of the exponent field, in bits
    precision: int   # number of explicit mantissa bits
    # Strategy for encoding infinities / NaN; IEEE unless the format says otherwise.
    non_finite_behavior: NonFiniteBehavior = NonFiniteBehavior.IEEE
1928

2029

2130
# Encoding parameters for every float type handled by the software converters.
# F64 is intentionally absent: float_to_bits packs it directly with struct.
_specs = {
    SimpleType.F16: _FloatSpec(16, -14, 15, 5, 10),
    SimpleType.BF16: _FloatSpec(16, -126, 127, 8, 7),
    SimpleType.F32: _FloatSpec(32, -126, 127, 8, 23),
    # tf32: 19 significant bits (stored in a 32-bit container elsewhere).
    SimpleType.TF32: _FloatSpec(19, -126, 127, 8, 10),
    # f8e4m3fn: no infinities, NaN is the all-ones pattern.
    SimpleType.F8E4M3FN: _FloatSpec(8, -6, 8, 4, 3, NonFiniteBehavior.NanOnlyAllOnes),
    SimpleType.F8E5M2: _FloatSpec(8, -14, 15, 5, 2),
    # f8e8m0fnu: exponent-only, unsigned; NaN is all ones.
    SimpleType.F8E8M0FNU: _FloatSpec(8, -127, 127, 8, 0, NonFiniteBehavior.NanOnlyAllOnes),
    # f4e2m1fn: finite values only.
    SimpleType.F4E2M1FN: _FloatSpec(4, 0, 2, 2, 1, NonFiniteBehavior.FiniteOnly),
}
2940

3041

31-
def float_max_value(ty: SimpleType) -> float:
32-
return _specs[ty].max_value()
33-
34-
3542
def float_bit_size(ty: SimpleType) -> int:
    """Return the bit width of the float encoding *ty* (e.g. 19 for tf32)."""
    if ty == SimpleType.F64:
        # F64 has no _FloatSpec entry; its width is fixed.
        return 64
    return _specs[ty].bitwidth
3744

3845

3946
def float_to_bits(val: float, ty: SimpleType) -> int:
    """Encode *val* as the raw integer bit pattern of float format *ty*."""
    if ty == SimpleType.F64:
        # Reinterpret the IEEE-754 double bits directly; no conversion needed.
        return struct.unpack("<Q", struct.pack("<d", val))[0]
    spec = _specs[ty]
    if ty == SimpleType.F8E8M0FNU:
        # Unsigned exponent-only format needs its dedicated conversion path.
        return _convert_f8e8m0fnu(val, *spec)
    return _convert_float(val, *spec)
4555

4656

57+
def _convert_f8e8m0fnu(val: float,
                       bitwidth: int,
                       emin: int,
                       emax: int,
                       _exp_bits: int,
                       _precision: int,
                       _non_finite_behavior: NonFiniteBehavior) -> int:
    """Encode *val* in the unsigned, exponent-only float8_e8m0fnu format.

    The format stores only a biased exponent: no sign bit and no mantissa
    bits.  The all-ones pattern is the single NaN encoding; all-zeros is the
    smallest representable value, ``2**emin``.  Rounding is to nearest, ties
    to even.  Unused parameters keep the signature compatible with the
    ``float_to_bits(val, *spec)`` call convention.

    Raises:
        TileInternalError: if *val* is negative (including -0.0 and -inf),
            since the format has no sign bit.
    """
    all_ones_nan = (1 << bitwidth) - 1   # the only NaN encoding
    min_code = 0                         # encodes 2**emin; no subnormals exist
    bias = -emin

    if math.copysign(1.0, val) < 0:
        raise TileInternalError("negative values cannot be represented in an unsigned float format")

    if not math.isfinite(val):
        # No infinity encoding: +inf and NaN both map to NaN.
        return all_ones_nan
    if val == 0.0:
        # Zero is not representable; clamp to the smallest value.
        return min_code

    # Decompose val = frac * 2**exp with frac in [1.0, 2.0).
    frac, exp = math.frexp(val)
    frac, exp = frac * 2.0, exp - 1

    if exp > emax:
        return all_ones_nan
    if exp < emin:
        return min_code

    code = exp + bias
    # With zero mantissa bits the implicit significand is always 1 (odd), so
    # round-to-nearest-even ties resolve upward: any fractional part of the
    # significand >= 0.5 bumps the exponent to the next power of two.
    if frac - 1.0 >= 0.5:
        code += 1
        if code > emax + bias:
            return all_ones_nan
    return code
98+
99+
47100
def _convert_float(val: float,
48101
bitwidth: int,
49102
emin: int,
50103
emax: int,
51104
exp_bits: int,
52105
precision: int,
53-
finite_only: bool) -> int:
106+
non_finite_behavior: NonFiniteBehavior) -> int:
54107
if val == 0.0:
55108
sign = math.copysign(1.0, val) < 0.0
56109
return sign << (bitwidth - 1)
57110
elif not math.isfinite(val):
58-
return _convert_nonfinite(val, bitwidth, exp_bits, precision, finite_only)
111+
return _convert_nonfinite(val, bitwidth, exp_bits, precision, non_finite_behavior)
59112

60113
sign, val = (1, -val) if (val < 0) else (0, val)
61114
m, e = math.frexp(val)
@@ -64,7 +117,7 @@ def _convert_float(val: float,
64117

65118
if e > emax:
66119
return _convert_nonfinite(-math.inf if sign else math.inf,
67-
bitwidth, exp_bits, precision, finite_only)
120+
bitwidth, exp_bits, precision, non_finite_behavior)
68121

69122
if e < emin:
70123
m = math.ldexp(m, e - emin)
@@ -73,26 +126,37 @@ def _convert_float(val: float,
73126
m -= 1.0
74127
e += -emin + 1
75128

76-
# Round to nearest, ties to even
129+
# Round to nearest, ties to even (RNE)
130+
# The following RNE implementation breaks when precision is 0
131+
assert precision > 0
77132
m = round(m * (1 << precision))
78133
if m == (1 << precision):
79134
m = 0
80135
e += 1
81136
if e > emax - emin + 1:
82137
return _convert_nonfinite(-math.inf if sign else math.inf,
83-
bitwidth, exp_bits, precision, finite_only)
138+
bitwidth, exp_bits, precision, non_finite_behavior)
84139
bits = (sign << (bitwidth - 1)) | (e << precision) | m
85140
return bits
86141

87142

88-
def _convert_nonfinite(val, bitwidth, exp_bits, precision, finite_only) -> int:
89-
if finite_only:
90-
# NaN is encoded as all ones
91-
sign = math.copysign(1.0, val) < 0.0
92-
return (sign << (bitwidth - 1)) | ((1 << (bitwidth - 1)) - 1)
93-
else:
94-
# Exponent is all ones. Truncate the low bits, preserve the rest of the payload
95-
float64_bits, = struct.unpack("<Q", struct.pack("<d", val))
96-
payload = (float64_bits >> (52 - precision)) & ((1 << precision) - 1)
97-
hi_bits = (float64_bits >> (63 - exp_bits)) << precision
98-
return hi_bits | payload
143+
def _convert_nonfinite(val, bitwidth, exp_bits, precision, non_finite_behavior) -> int:
144+
match non_finite_behavior:
145+
case NonFiniteBehavior.NanOnlyAllOnes:
146+
# NaN is encoded as all ones
147+
sign = math.copysign(1.0, val) < 0.0
148+
return (sign << (bitwidth - 1)) | ((1 << (bitwidth - 1)) - 1)
149+
case NonFiniteBehavior.FiniteOnly:
150+
if math.isnan(val):
151+
raise TileInternalError("NaN cannot be represented in a finite-only float format")
152+
# Clamp to max representable magnitude, preserve sign
153+
sign = math.copysign(1.0, val) < 0.0
154+
return (sign << (bitwidth - 1)) | ((1 << (bitwidth - 1)) - 1)
155+
case NonFiniteBehavior.IEEE:
156+
# Exponent is all ones. Truncate the low bits, preserve the rest of the payload
157+
float64_bits, = struct.unpack("<Q", struct.pack("<d", val))
158+
payload = (float64_bits >> (52 - precision)) & ((1 << precision) - 1)
159+
hi_bits = (float64_bits >> (63 - exp_bits)) << precision
160+
return hi_bits | payload
161+
case _:
162+
assert False

src/cuda/tile/_bytecode/type.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class SimpleType(enum.Enum):
4444
F8E4M3FN = b"\x0a"
4545
F8E5M2 = b"\x0b"
4646
Token = b"\x11"
47-
Unknown = b"\x12"
47+
F8E8M0FNU = b"\x12" # since 13.2
48+
F4E2M1FN = b"\x13" # since 13.3
4849

4950

5051
class _CompositeType(enum.Enum):

src/cuda/tile/_compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
)
5151

5252
from cuda.tile._passes.alias_analysis import alias_analysis_pass
53-
from cuda.tile._passes.check_ampere_fp8 import check_ampere_fp8
53+
from cuda.tile._passes.check_dtype_support import check_dtype_support
5454
from cuda.tile._passes.dce import dead_code_elimination_pass
5555
from cuda.tile._passes.token_order import token_order_pass
5656
from cuda.tile._cache import cache_key, cache_lookup, cache_store, evict_lru
@@ -219,7 +219,7 @@ def compile_tile(pyfunc,
219219
print(f'\n{code}', file=sys.stderr)
220220

221221
sm_arch = get_sm_arch()
222-
check_ampere_fp8(func_ir.body, sm_arch)
222+
check_dtype_support(func_ir.body, sm_arch, bytecode_version)
223223

224224
bytecode_generator = functools.partial(generate_bytecode_for_kernel,
225225
func_ir, compiler_options, sm_arch)

src/cuda/tile/_datatype.py

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"int8", "int16", "int32", "int64",
1818
"float16", "float32", "float64",
1919
"bfloat16", "tfloat32", "float8_e4m3fn", "float8_e5m2",
20+
"float8_e8m0fnu", "float4_e2m1fn",
2021
"DType", "NumericDType", "ArithmeticDType",
2122
"NumericDTypeCategories"]
2223

@@ -139,17 +140,25 @@ class ArithmeticDType(NumericDType):
139140
and 7 mantissa bits."""
140141

141142
tfloat32 = NumericDType("tfloat32", 32, float, bc.SimpleType.TF32)
142-
tfloat32.__doc__ = """A 32-bit tensor floating-point |arithmetic dtype| with 1 sign \
143+
tfloat32.__doc__ = """A 32-bit tensor floating-point |numeric dtype| with 1 sign \
143144
bit, 8 exponent bits, and 10 mantissa bits (19-bit representation stored in 32-bit container)."""
144145

145146
float8_e4m3fn = NumericDType("float8_e4m3fn", 8, float, bc.SimpleType.F8E4M3FN)
146-
float8_e4m3fn.__doc__ = """A 8-bit floating-point |arithmetic dtype| with 1 sign bit, \
147+
float8_e4m3fn.__doc__ = """An 8-bit floating-point |numeric dtype| with 1 sign bit, \
147148
4 exponent bits, and 3 mantissa bits."""
148149

149150
float8_e5m2 = NumericDType("float8_e5m2", 8, float, bc.SimpleType.F8E5M2)
150-
float8_e5m2.__doc__ = """A 8-bit floating-point |arithmetic dtype| with 1 sign bit, \
151+
float8_e5m2.__doc__ = """An 8-bit floating-point |numeric dtype| with 1 sign bit, \
151152
5 exponent bits, and 2 mantissa bits."""
152153

154+
float8_e8m0fnu = NumericDType("float8_e8m0fnu", 8, float, bc.SimpleType.F8E8M0FNU)
155+
float8_e8m0fnu.__doc__ = """An 8-bit floating-point |numeric dtype| with no sign bit, \
156+
8 exponent bits, and 0 mantissa bits."""
157+
158+
float4_e2m1fn = NumericDType("float4_e2m1fn", 4, float, bc.SimpleType.F4E2M1FN)
159+
float4_e2m1fn.__doc__ = """A 4-bit floating-point |numeric dtype| with 1 sign bit, \
160+
2 exponent bits, and 1 mantissa bit."""
161+
153162

154163
class DTypeEnum(IntEnum):
155164
B1 = 0
@@ -168,6 +177,8 @@ class DTypeEnum(IntEnum):
168177
TF32 = 13
169178
F8E4M3FN = 14
170179
F8E5M2 = 15
180+
F8E8M0FNU = 16
181+
F4E2M1FN = 17
171182

172183

173184
dtype_to_enum = {
@@ -187,6 +198,8 @@ class DTypeEnum(IntEnum):
187198
tfloat32: DTypeEnum.TF32,
188199
float8_e4m3fn: DTypeEnum.F8E4M3FN,
189200
float8_e5m2: DTypeEnum.F8E5M2,
201+
float8_e8m0fnu: DTypeEnum.F8E8M0FNU,
202+
float4_e2m1fn: DTypeEnum.F4E2M1FN,
190203
}
191204
_enum_to_dtype = dict((i, t) for t, i in dtype_to_enum.items())
192205

@@ -209,7 +222,7 @@ class NumericDTypeCategories:
209222
Boolean = [bool_]
210223
Integral = [uint8, uint16, uint32, uint64, int8, int16, int32, int64]
211224
Float = [float16, float32, float64, bfloat16]
212-
RestrictedFloat = [tfloat32, float8_e4m3fn, float8_e5m2]
225+
RestrictedFloat = [tfloat32, float8_e4m3fn, float8_e5m2, float8_e8m0fnu, float4_e2m1fn]
213226

214227
@classmethod
215228
def get_category(cls, t: DType) -> NumericDTypeCategory:
@@ -323,6 +336,8 @@ class _DTypePromotionImpl:
323336
tf = DTypeEnum.TF32
324337
f8e4m3fn = DTypeEnum.F8E4M3FN
325338
f8e5m2 = DTypeEnum.F8E5M2
339+
f8e8m0fnu = DTypeEnum.F8E8M0FNU
340+
f4e2m1fn = DTypeEnum.F4E2M1FN
326341
na = None
327342

328343
# Entries for restricted arithmetic dtypes will never be reached, but we need to keep them
@@ -337,23 +352,25 @@ class _DTypePromotionImpl:
337352
# Restricted floats requires explicit type cast
338353
# Float16 and BFloat 16 requires explicit type cast
339354
_common_dtype_table = [
340-
# b1, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64, bf, tf, f8e4m3fn, f8e5m2
341-
[b1, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64, bf, na, na, na], # b1
342-
[u8, u8, u16, u32, u64, na, na, na, na, f16, f32, f64, bf, na, na, na], # u8
343-
[u16, u16, u16, u32, u64, na, na, na, na, f16, f32, f64, bf, na, na, na], # u16
344-
[u32, u32, u32, u32, u64, na, na, na, na, f16, f32, f64, bf, na, na, na], # u32
345-
[u64, u64, u64, u64, u64, na, na, na, na, f16, f32, f64, bf, na, na, na], # u64
346-
[i8, na, na, na, na, i8, i16, i32, i64, f16, f32, f64, bf, na, na, na], # i8
347-
[i16, na, na, na, na, i16, i16, i32, i64, f16, f32, f64, bf, na, na, na], # i16
348-
[i32, na, na, na, na, i32, i32, i32, i64, f16, f32, f64, bf, na, na, na], # i32
349-
[i64, na, na, na, na, i64, i64, i64, i64, f16, f32, f64, bf, na, na, na], # i64
350-
[f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f32, f64, na, na, na, na], # f16
351-
[f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f64, f32, na, na, na], # f32
352-
[f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, na, na, na], # f64
353-
[bf, bf, bf, bf, bf, bf, bf, bf, bf, na, f32, f64, bf, na, na, na], # bf
354-
[na, na, na, na, na, na, na, na, na, na, na, na, na, tf, na, na], # tf
355-
[na, na, na, na, na, na, na, na, na, na, na, na, na, na, f8e4m3fn, na], # f8e4m3fn # noqa
356-
[na, na, na, na, na, na, na, na, na, na, na, na, na, na, na, f8e5m2], # f8e5m2 # noqa
355+
# b1, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64, bf, tf, f8e4m3fn, f8e5m2, f8e8m0fnu, f4e2m1fn # noqa
356+
[b1, u8, u16, u32, u64, i8, i16, i32, i64, f16, f32, f64, bf, na, na, na, na, na], # b1 # noqa
357+
[u8, u8, u16, u32, u64, na, na, na, na, f16, f32, f64, bf, na, na, na, na, na], # u8 # noqa
358+
[u16, u16, u16, u32, u64, na, na, na, na, f16, f32, f64, bf, na, na, na, na, na], # u16 # noqa
359+
[u32, u32, u32, u32, u64, na, na, na, na, f16, f32, f64, bf, na, na, na, na, na], # u32 # noqa
360+
[u64, u64, u64, u64, u64, na, na, na, na, f16, f32, f64, bf, na, na, na, na, na], # u64 # noqa
361+
[i8, na, na, na, na, i8, i16, i32, i64, f16, f32, f64, bf, na, na, na, na, na], # i8 # noqa
362+
[i16, na, na, na, na, i16, i16, i32, i64, f16, f32, f64, bf, na, na, na, na, na], # i16 # noqa
363+
[i32, na, na, na, na, i32, i32, i32, i64, f16, f32, f64, bf, na, na, na, na, na], # i32 # noqa
364+
[i64, na, na, na, na, i64, i64, i64, i64, f16, f32, f64, bf, na, na, na, na, na], # i64 # noqa
365+
[f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f32, f64, na, na, na, na, na, na], # f16 # noqa
366+
[f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f64, f32, na, na, na, na, na], # f32 # noqa
367+
[f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, na, na, na, na, na], # f64 # noqa
368+
[bf, bf, bf, bf, bf, bf, bf, bf, bf, na, f32, f64, bf, na, na, na, na, na], # bf # noqa
369+
[na, na, na, na, na, na, na, na, na, na, na, na, na, tf, na, na, na, na], # tf # noqa
370+
[na, na, na, na, na, na, na, na, na, na, na, na, na, na, f8e4m3fn, na, na, na], # f8e4m3fn # noqa
371+
[na, na, na, na, na, na, na, na, na, na, na, na, na, na, na, f8e5m2, na, na], # f8e5m2 # noqa
372+
[na, na, na, na, na, na, na, na, na, na, na, na, na, na, na, na, f8e8m0fnu, na], # f8e8m0fnu # noqa
373+
[na, na, na, na, na, na, na, na, na, na, na, na, na, na, na, na, na, f4e2m1fn], # f4e2m1fn # noqa
357374
]
358375

359376
@classmethod
@@ -423,14 +440,22 @@ def _resolve_mma_supported_dtype(x_dtype: DType,
423440

424441
def _generate_rst_dtype_promotion_table() -> str:
    """Generate an RST table representation of the dtype promotion rules."""
    import cuda.tile
    # Skip dtypes not exposed in cuda.tile yet.  The promotion table is
    # append-only, so the exposed dtypes always form a prefix of its
    # rows/columns and a square [:n] slice stays consistent.
    n = sum(1 for dtype in _enum_to_dtype.values() if hasattr(cuda.tile, dtype.name))
    table = _DTypePromotionImpl._common_dtype_table
    return _generate_rst_table([row[:n] for row in table[:n]])
427448

428449

429450
def _generate_rst_numeric_dtypes() -> str:
430451
"""Generate RST documentation for numeric datatypes."""
452+
import cuda.tile
431453
content = []
432454

433455
for dtype in numeric_dtypes:
456+
# Skip dtypes not exposed in cuda.tile yet
457+
if not hasattr(cuda.tile, dtype.name):
458+
continue
434459
content.append(f".. autodata:: cuda.tile.{dtype.name}")
435460
content.append(" :annotation:")
436461
content.append("") # Empty line between types

src/cuda/tile/_ir/ir.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from typing import (
1717
List, Optional, Dict, Tuple, Any, TYPE_CHECKING, Sequence, Iterator
1818
)
19-
from .type import Type, InvalidType
19+
20+
from cuda.tile._ir.type import Type, InvalidType
2021
from cuda.tile._exception import (
2122
TileTypeError, Loc, TileInternalError
2223
)

src/cuda/tile/_ir/typing_support.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,9 @@ def get_constant_value(val: Any) -> Any:
249249
# =====CuTile native support ===========
250250
# register cuTile native dtype types
251251
# Register cuTile native dtype types.
for dtype in datatype.dtype_to_enum:
    # Only byte-aligned dtypes may be used as constructors; sub-byte dtypes
    # (e.g. the 4-bit float) cannot be constructed directly.  Pass the flag
    # by keyword for clarity, matching the previous call style.
    register_dtypes({dtype: dtype},
                    usable_as_constructor=(dtype.bitwidth % 8 == 0))
253255

254256

255257
# ========= Numpy support ===========
@@ -322,7 +324,8 @@ def get_constant_value(val: Any) -> Any:
322324
torch.bool: datatype.bool_,
323325
torch.bfloat16: datatype.bfloat16,
324326
torch.float8_e4m3fn: datatype.float8_e4m3fn,
325-
torch.float8_e5m2: datatype.float8_e5m2
327+
torch.float8_e5m2: datatype.float8_e5m2,
328+
torch.float8_e8m0fnu: datatype.float8_e8m0fnu,
326329
})
327330

328331
@register_type_handler(torch.Tensor, allow_subtypes=True)

0 commit comments

Comments
 (0)