Skip to content

Commit f20e5cd

Browse files
committed
Properly catch restricted dtype in simple scan and reduce op
Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent 708a44b commit f20e5cd

File tree

11 files changed

+57
-23
lines changed

11 files changed

+57
-23
lines changed

changelog.d/restricted-dtype.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Fix a bug where simple reduce and scan operations on restricted float dtypes did not raise a proper `TileTypeError`.

src/cuda/tile/_datatype.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@ def get_signedness(t: DType) -> bc.Signedness:
278278

279279

280280
def is_float(t: DType) -> bool:
281+
return t in NumericDTypeCategories.Float or t in NumericDTypeCategories.RestrictedFloat
282+
283+
284+
def is_unrestricted_float(t: DType) -> bool:
281285
return t in NumericDTypeCategories.Float
282286

283287

src/cuda/tile/_ir/op_impl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Optional, NamedTuple, Tuple, Sequence, Any, Union, Callable
1111

1212
from cuda.tile._datatype import (
13-
is_integral, is_float, is_restricted_float,
13+
is_integral, is_float,
1414
is_boolean, is_signed, DType)
1515
from cuda.tile._bytecode.version import BytecodeVersion
1616
from cuda.tile._exception import TileTypeError, TileUnsupportedFeatureError
@@ -430,7 +430,7 @@ class PrintfValidator:
430430
def infer_format(cls, dtype: DType) -> str:
431431
if is_boolean(dtype) or is_integral(dtype):
432432
return '%d'
433-
elif is_float(dtype) or is_restricted_float(dtype):
433+
elif is_float(dtype):
434434
return '%f'
435435
else:
436436
raise TileTypeError(f"print(): cannot infer format for dtype {dtype}")
@@ -439,7 +439,7 @@ def infer_format(cls, dtype: DType) -> str:
439439
def validate_dtype(cls, dtype: DType, specifier: str) -> bool:
440440
if is_boolean(dtype) or is_integral(dtype):
441441
return specifier in cls.int_specifiers
442-
elif is_float(dtype) or is_restricted_float(dtype):
442+
elif is_float(dtype):
443443
return specifier in cls.float_specifiers
444444
else:
445445
return False

src/cuda/tile/_ir/ops.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
StringFormat, FormattedPiece,
6464
)
6565
from cuda.tile._datatype import (
66-
DType, is_integral, is_float, is_signed, is_boolean, is_restricted_float,
66+
DType, is_integral, is_float, is_signed, is_boolean,
6767
)
6868
from cuda.tile._ir2bytecode import (
6969
BytecodeContext, typeid,
@@ -1026,16 +1026,20 @@ def binary_arithmetic(fn: str, x: Var, y: Var, rounding_mode: Optional[RoundingM
10261026
x_ty = require_tile_maybe_loose_type(x)
10271027
y_ty = require_tile_maybe_loose_type(y)
10281028

1029-
if get_dtype(x_ty) == get_dtype(y_ty) == datatype.bool_:
1030-
raise TileTypeError(f'Binary arithmetic op `{fn}` does not support bool, '
1031-
f'please cast bool to int')
1032-
10331029
if isinstance(x_ty, LooselyTypedScalar) and isinstance(y_ty, LooselyTypedScalar):
10341030
return _binop_propagate_constant(fn, x_ty.value, y_ty.value, None)
10351031

10361032
force_float = (fn == "truediv")
10371033
common_ty = promote_types(x_ty, y_ty, force_float=force_float)
10381034

1035+
common_dtype = get_dtype(common_ty)
1036+
if common_dtype == datatype.bool_:
1037+
raise TileTypeError(f'Binary arithmetic op `{fn}` does not support bool, '
1038+
f'please cast bool to int')
1039+
if datatype.is_restricted_float(common_dtype):
1040+
raise TileTypeError(
1041+
f'Binary arithmetic op `{fn}` does not support restricted float dtype {common_dtype}')
1042+
10391043
x = _promote_and_broadcast_to(x, common_ty)
10401044
y = _promote_and_broadcast_to(y, common_ty)
10411045

@@ -1386,7 +1390,7 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
13861390
flush_to_zero = self.flush_to_zero
13871391
input_type = ctx.typeof(self.operand)
13881392
input_dtype = get_dtype(input_type)
1389-
flt = is_float(input_dtype) or is_restricted_float(input_dtype)
1393+
flt = is_float(input_dtype)
13901394
res_type_id = ctx.typeid_of(self.result_var)
13911395

13921396
match self.fn, flt:
@@ -1467,7 +1471,7 @@ def unary(fn: str, behavior: _UnaryBehavior, x: Var,
14671471
if behavior.int_handler is None:
14681472
raise TileTypeError("Integer inputs are not supported")
14691473
x = behavior.int_handler(x)
1470-
elif is_float(input_dtype) or is_restricted_float(input_dtype):
1474+
elif is_float(input_dtype):
14711475
if behavior.float_handler is None:
14721476
raise TileTypeError("Float inputs are not supported")
14731477
x = behavior.float_handler(x)
@@ -1569,7 +1573,7 @@ def isnan_impl(x: Var) -> Var:
15691573
return loosely_typed_const(res)
15701574

15711575
ty = x.get_type()
1572-
if isinstance(x_type, TileTy) and (is_float(ty.dtype) or is_restricted_float(ty.dtype)):
1576+
if isinstance(x_type, TileTy) and is_float(ty.dtype):
15731577
if x.is_constant():
15741578
res = math.isnan(x.get_constant())
15751579
return strictly_typed_const(res, make_tile_ty(datatype.bool_, ty.shape))
@@ -3240,8 +3244,8 @@ async def reduce_simple(fn: str, x: Var, axis: int | None | tuple[int, ...], kee
32403244

32413245
async def body(lhs: tuple[Var], rhs: tuple[Var]) -> tuple[Var]:
32423246
[lhs], [rhs] = lhs, rhs
3243-
ret = raw_binary_arithmetic(fn, lhs, rhs,
3244-
rounding_mode=rounding_mode, flush_to_zero=flush_to_zero)
3247+
ret = binary_arithmetic(fn, lhs, rhs,
3248+
rounding_mode=rounding_mode, flush_to_zero=flush_to_zero)
32453249
return (ret,)
32463250

32473251
[ret] = await reduce((x,), (id_val,), axis, keepdims, body)
@@ -3453,8 +3457,8 @@ async def scan_simple(fn: str, x: Var, axis: int, reverse: bool,
34533457

34543458
async def body(lhs: tuple[Var], rhs: tuple[Var]) -> tuple[Var]:
34553459
[lhs], [rhs] = lhs, rhs
3456-
ret = raw_binary_arithmetic(fn, lhs, rhs,
3457-
rounding_mode=rounding_mode, flush_to_zero=flush_to_zero)
3460+
ret = binary_arithmetic(fn, lhs, rhs,
3461+
rounding_mode=rounding_mode, flush_to_zero=flush_to_zero)
34583462
return (ret,)
34593463

34603464
[ret] = await raw_scan((x,), (id_val,), axis, reverse, body)

src/cuda/tile/_ir/ops_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ def check_rd_and_ftz(fn: str, rounding_mode: Optional[RoundingMode], flush_to_ze
161161
f'{fn} rounding_mode={rounding_mode.value} requires tileiras '
162162
f'{min_version.major()}.{min_version.minor()} or later. '
163163
f'Current version is {cur_version.major()}.{cur_version.minor()}.')
164-
if not datatype.is_float(dtype):
164+
if not datatype.is_unrestricted_float(dtype):
165165
raise TileTypeError(
166-
f'Rounding mode can only be used for float types, '
166+
f'Rounding mode can only be used for unrestricted float types, '
167167
f'but got {dtype}')
168168
if rounding_mode in [RoundingMode.APPROX, RoundingMode.FULL]:
169169
if dtype != datatype.float32:

src/cuda/tile/_ir2bytecode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def _constant_to_bytes(value: int | float, dtype: DType) -> bytes:
8181
return b"\xff" if value else b"\x00"
8282
elif datatype.is_integral(dtype):
8383
return int(value).to_bytes((dtype.bitwidth + 7) // 8, "little", signed=value < 0)
84-
elif datatype.is_float(dtype) or datatype.is_restricted_float(dtype):
84+
elif datatype.is_float(dtype):
8585
# Note that TF32 is stored as 3 bytes despite the "32" in its name.
8686
# Its float_bit_size() is 19 bits, which is rounded up to 24 bits.
8787
bits = bc.float_to_bits(value, dtype._bytecode_type)
@@ -94,7 +94,7 @@ def _constant_to_bytes(value: int | float, dtype: DType) -> bytes:
9494
def _get_type_conversion_encoder(from_dtype: Type, to_dtype: Type):
9595

9696
def kind(t):
97-
if datatype.is_float(t) or datatype.is_restricted_float(t):
97+
if datatype.is_float(t):
9898
return 'f'
9999
if datatype.is_integral(t) or datatype.is_boolean(t):
100100
return 'si' if datatype.is_signed(t) else 'ui'

src/cuda/tile/_passes/rewrite_patterns.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def match_float_mul(op: RawBinaryArithmeticOperation,
8484
ctx: MatchContext) -> RawBinaryArithmeticOperation:
8585
if op.fn != "mul":
8686
raise NoMatch("not a mul binop")
87-
if not datatype.is_float(get_dtype(ctx.typeof(op.result_var))):
88-
raise NoMatch("not a float mul")
87+
if not datatype.is_unrestricted_float(get_dtype(ctx.typeof(op.result_var))):
88+
raise NoMatch("not an unrestricted float mul")
8989
return op
9090

9191

test/test_binary_elementwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def test_array_core_arithmetic_rounding_mode(
208208
launch_binary(kernel, x, y, z, tile)
209209
elif should_raise_dtype:
210210
with pytest.raises(TileTypeError,
211-
match=r"Rounding mode can only be used for float types"):
211+
match=r"Rounding mode can only be used for unrestricted float types"):
212212
launch_binary(kernel, x, y, z, tile)
213213
else:
214214
bytecode = get_bytecode(kernel, (x, y, z, tile))

test/test_ir_types.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
int64, int32, int16, int8,
1818
uint64, uint32, uint16, uint8, bfloat16,
1919
tfloat32, float8_e4m3fn, float8_e5m2,
20-
is_boolean, is_integral, is_float, is_restricted_float, is_signed,
20+
is_boolean, is_integral, is_float, is_unrestricted_float, is_restricted_float, is_signed,
2121
)
2222
from cuda.tile._ir.ops_utils import promote_dtypes, check_implicit_cast
2323
from cuda.tile._ir.typing_support import to_dtype, typeof_pyval
@@ -51,6 +51,9 @@ def test_builtin_types():
5151
assert is_boolean(bool_)
5252
assert is_float(bfloat16)
5353
assert not is_float(uint32)
54+
assert is_unrestricted_float(bfloat16)
55+
assert not is_unrestricted_float(tfloat32)
56+
assert is_float(tfloat32)
5457
assert is_restricted_float(tfloat32)
5558
assert is_restricted_float(float8_e4m3fn)
5659
assert is_restricted_float(float8_e5m2)

test/test_reduction.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,17 @@ def kernel(x):
190190
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x,))
191191

192192

193+
def test_reduce_sum_restricted_dtype_error():
194+
@ct.kernel
195+
def kernel(x):
196+
tx = ct.load(x, (0,), (16,))
197+
ct.sum(tx, axis=0)
198+
199+
x = torch.rand((16,), dtype=torch.float32, device="cuda").to(torch.float8_e4m3fn)
200+
with pytest.raises(TileTypeError, match="does not support restricted float dtype"):
201+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x,))
202+
203+
193204
sumprod_cases = [
194205
pytest.param(ct.sum, torch.sum, id="sum"),
195206
pytest.param(ct.prod, torch.prod, id="prod"),

0 commit comments

Comments
 (0)