|
63 | 63 | StringFormat, FormattedPiece, |
64 | 64 | ) |
65 | 65 | from cuda.tile._datatype import ( |
66 | | - DType, is_integral, is_float, is_signed, is_boolean, is_restricted_float, |
| 66 | + DType, is_integral, is_float, is_signed, is_boolean, |
67 | 67 | ) |
68 | 68 | from cuda.tile._ir2bytecode import ( |
69 | 69 | BytecodeContext, typeid, |
@@ -1026,16 +1026,20 @@ def binary_arithmetic(fn: str, x: Var, y: Var, rounding_mode: Optional[RoundingM |
1026 | 1026 | x_ty = require_tile_maybe_loose_type(x) |
1027 | 1027 | y_ty = require_tile_maybe_loose_type(y) |
1028 | 1028 |
|
1029 | | - if get_dtype(x_ty) == get_dtype(y_ty) == datatype.bool_: |
1030 | | - raise TileTypeError(f'Binary arithmetic op `{fn}` does not support bool, ' |
1031 | | - f'please cast bool to int') |
1032 | | - |
1033 | 1029 | if isinstance(x_ty, LooselyTypedScalar) and isinstance(y_ty, LooselyTypedScalar): |
1034 | 1030 | return _binop_propagate_constant(fn, x_ty.value, y_ty.value, None) |
1035 | 1031 |
|
1036 | 1032 | force_float = (fn == "truediv") |
1037 | 1033 | common_ty = promote_types(x_ty, y_ty, force_float=force_float) |
1038 | 1034 |
|
| 1035 | + common_dtype = get_dtype(common_ty) |
| 1036 | + if common_dtype == datatype.bool_: |
| 1037 | + raise TileTypeError(f'Binary arithmetic op `{fn}` does not support bool, ' |
| 1038 | + f'please cast bool to int') |
| 1039 | + if datatype.is_restricted_float(common_dtype): |
| 1040 | + raise TileTypeError( |
| 1041 | + f'Binary arithmetic op `{fn}` does not support restricted float dtype {common_dtype}') |
| 1042 | + |
1039 | 1043 | x = _promote_and_broadcast_to(x, common_ty) |
1040 | 1044 | y = _promote_and_broadcast_to(y, common_ty) |
1041 | 1045 |
|
@@ -1386,7 +1390,7 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value: |
1386 | 1390 | flush_to_zero = self.flush_to_zero |
1387 | 1391 | input_type = ctx.typeof(self.operand) |
1388 | 1392 | input_dtype = get_dtype(input_type) |
1389 | | - flt = is_float(input_dtype) or is_restricted_float(input_dtype) |
| 1393 | + flt = is_float(input_dtype) |
1390 | 1394 | res_type_id = ctx.typeid_of(self.result_var) |
1391 | 1395 |
|
1392 | 1396 | match self.fn, flt: |
@@ -1467,7 +1471,7 @@ def unary(fn: str, behavior: _UnaryBehavior, x: Var, |
1467 | 1471 | if behavior.int_handler is None: |
1468 | 1472 | raise TileTypeError("Integer inputs are not supported") |
1469 | 1473 | x = behavior.int_handler(x) |
1470 | | - elif is_float(input_dtype) or is_restricted_float(input_dtype): |
| 1474 | + elif is_float(input_dtype): |
1471 | 1475 | if behavior.float_handler is None: |
1472 | 1476 | raise TileTypeError("Float inputs are not supported") |
1473 | 1477 | x = behavior.float_handler(x) |
@@ -1569,7 +1573,7 @@ def isnan_impl(x: Var) -> Var: |
1569 | 1573 | return loosely_typed_const(res) |
1570 | 1574 |
|
1571 | 1575 | ty = x.get_type() |
1572 | | - if isinstance(x_type, TileTy) and (is_float(ty.dtype) or is_restricted_float(ty.dtype)): |
| 1576 | + if isinstance(x_type, TileTy) and is_float(ty.dtype): |
1573 | 1577 | if x.is_constant(): |
1574 | 1578 | res = math.isnan(x.get_constant()) |
1575 | 1579 | return strictly_typed_const(res, make_tile_ty(datatype.bool_, ty.shape)) |
@@ -3240,8 +3244,8 @@ async def reduce_simple(fn: str, x: Var, axis: int | None | tuple[int, ...], kee |
3240 | 3244 |
|
3241 | 3245 | async def body(lhs: tuple[Var], rhs: tuple[Var]) -> tuple[Var]: |
3242 | 3246 | [lhs], [rhs] = lhs, rhs |
3243 | | - ret = raw_binary_arithmetic(fn, lhs, rhs, |
3244 | | - rounding_mode=rounding_mode, flush_to_zero=flush_to_zero) |
| 3247 | + ret = binary_arithmetic(fn, lhs, rhs, |
| 3248 | + rounding_mode=rounding_mode, flush_to_zero=flush_to_zero) |
3245 | 3249 | return (ret,) |
3246 | 3250 |
|
3247 | 3251 | [ret] = await reduce((x,), (id_val,), axis, keepdims, body) |
@@ -3453,8 +3457,8 @@ async def scan_simple(fn: str, x: Var, axis: int, reverse: bool, |
3453 | 3457 |
|
3454 | 3458 | async def body(lhs: tuple[Var], rhs: tuple[Var]) -> tuple[Var]: |
3455 | 3459 | [lhs], [rhs] = lhs, rhs |
3456 | | - ret = raw_binary_arithmetic(fn, lhs, rhs, |
3457 | | - rounding_mode=rounding_mode, flush_to_zero=flush_to_zero) |
| 3460 | + ret = binary_arithmetic(fn, lhs, rhs, |
| 3461 | + rounding_mode=rounding_mode, flush_to_zero=flush_to_zero) |
3458 | 3462 | return (ret,) |
3459 | 3463 |
|
3460 | 3464 | [ret] = await raw_scan((x,), (id_val,), axis, reverse, body) |
|
0 commit comments