Skip to content

Commit 549e940

Browse files
Add device-aware output dtype for dpt.round() with boolean input (#2851)
This PR proposes device-aware output dtype resolution for `dpnp.tensor.round()` with boolean input, to handle devices that do not support `float16`. Boolean support for `round()` was originally added in #2817 [6f5a792](6f5a792) to match NumPy behavior, where `numpy.round(bool)` returns `float16` rather than an integral type like `int8`. However, on devices without fp16 support, returning `float16` is not viable. The bool type mapping was removed from the round kernel, and an acceptance function `_acceptance_fn_round` was added to ensure that the fallback in `_find_buf_dtype` prefers floating-point output over integral types for boolean input. Result: on fp16 devices, `round(bool)` -> `float16`; on non-fp16 devices, `round(bool)` -> `float32`.
1 parent 4b163bf commit 549e940

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

dpnp/tensor/_elementwise_funcs.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
_acceptance_fn_divide,
3434
_acceptance_fn_negative,
3535
_acceptance_fn_reciprocal,
36+
_acceptance_fn_round,
3637
_acceptance_fn_subtract,
3738
_resolve_weak_types_all_py_ints,
3839
)
@@ -1723,7 +1724,11 @@
17231724
"""
17241725

17251726
round = UnaryElementwiseFunc(
1726-
"round", ti._round_result_type, ti._round, _round_docstring
1727+
"round",
1728+
ti._round_result_type,
1729+
ti._round,
1730+
_round_docstring,
1731+
acceptance_fn=_acceptance_fn_round,
17271732
)
17281733
del _round_docstring
17291734

dpnp/tensor/_type_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
133133
return True
134134

135135

136+
def _acceptance_fn_round(arg_dtype, buf_dt, res_dt, sycl_dev):
137+
# for boolean input, prefer floating-point output over integral
138+
if arg_dtype.kind == "b" and res_dt.kind != "f":
139+
return False
140+
return True
141+
142+
136143
def _acceptance_fn_subtract(
137144
arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev
138145
):
@@ -970,6 +977,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
970977
"_find_buf_dtype2",
971978
"_to_device_supported_dtype",
972979
"_acceptance_fn_default_unary",
980+
"_acceptance_fn_round",
973981
"_acceptance_fn_reciprocal",
974982
"_acceptance_fn_default_binary",
975983
"_acceptance_fn_divide",

dpnp/tensor/libtensor/include/kernels/elementwise_functions/round.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ template <typename T>
116116
struct RoundOutputType
117117
{
118118
using value_type = typename std::disjunction<
119-
td_ns::TypeMapResultEntry<T, bool, sycl::half>,
120119
td_ns::TypeMapResultEntry<T, std::uint8_t>,
121120
td_ns::TypeMapResultEntry<T, std::uint16_t>,
122121
td_ns::TypeMapResultEntry<T, std::uint32_t>,

0 commit comments

Comments
 (0)