Skip to content

Commit 4a23090

Browse files
Update _acceptance_fn_round and revert _find_buf_dtype changes
1 parent 772e981 commit 4a23090

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

dpnp/tensor/_type_utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
135135

136136
def _acceptance_fn_round(arg_dtype, buf_dt, res_dt, sycl_dev):
137137
# for boolean input, prefer floating-point output over integral
138-
if arg_dtype.char == "?" and res_dt.kind in "biu":
138+
if arg_dtype.kind == "b" and res_dt.kind != "f":
139139
return False
140140
return True
141141

@@ -195,19 +195,17 @@ def _dtype_supported_by_device_impl(
195195

196196

197197
def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn):
198-
_fp16 = sycl_dev.has_aspect_fp16
199-
_fp64 = sycl_dev.has_aspect_fp64
200-
201198
res_dt = query_fn(arg_dtype)
202199
if res_dt:
203-
if _dtype_supported_by_device_impl(res_dt, _fp16, _fp64):
204-
return None, res_dt
200+
return None, res_dt
205201

202+
_fp16 = sycl_dev.has_aspect_fp16
203+
_fp64 = sycl_dev.has_aspect_fp64
206204
all_dts = _all_data_types(_fp16, _fp64)
207205
for buf_dt in all_dts:
208206
if _can_cast(arg_dtype, buf_dt, _fp16, _fp64):
209207
res_dt = query_fn(buf_dt)
210-
if res_dt and _dtype_supported_by_device_impl(res_dt, _fp16, _fp64):
208+
if res_dt:
211209
acceptable = acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev)
212210
if acceptable:
213211
return buf_dt, res_dt

0 commit comments

Comments
 (0)