@@ -135,7 +135,7 @@ def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
135135
136136def _acceptance_fn_round (arg_dtype , buf_dt , res_dt , sycl_dev ):
137137 # for boolean input, prefer floating-point output over integral
138- if arg_dtype .char == "? " and res_dt .kind in "biu " :
138+ if arg_dtype .kind == "b " and res_dt .kind != "f " :
139139 return False
140140 return True
141141
@@ -195,19 +195,17 @@ def _dtype_supported_by_device_impl(
195195
196196
197197def _find_buf_dtype (arg_dtype , query_fn , sycl_dev , acceptance_fn ):
198- _fp16 = sycl_dev .has_aspect_fp16
199- _fp64 = sycl_dev .has_aspect_fp64
200-
201198 res_dt = query_fn (arg_dtype )
202199 if res_dt :
203- if _dtype_supported_by_device_impl (res_dt , _fp16 , _fp64 ):
204- return None , res_dt
200+ return None , res_dt
205201
202+ _fp16 = sycl_dev .has_aspect_fp16
203+ _fp64 = sycl_dev .has_aspect_fp64
206204 all_dts = _all_data_types (_fp16 , _fp64 )
207205 for buf_dt in all_dts :
208206 if _can_cast (arg_dtype , buf_dt , _fp16 , _fp64 ):
209207 res_dt = query_fn (buf_dt )
210- if res_dt and _dtype_supported_by_device_impl ( res_dt , _fp16 , _fp64 ) :
208+ if res_dt :
211209 acceptable = acceptance_fn (arg_dtype , buf_dt , res_dt , sycl_dev )
212210 if acceptable :
213211 return buf_dt , res_dt
0 commit comments