Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion dpnp/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_acceptance_fn_divide,
_acceptance_fn_negative,
_acceptance_fn_reciprocal,
_acceptance_fn_round,
_acceptance_fn_subtract,
_resolve_weak_types_all_py_ints,
)
Expand Down Expand Up @@ -1723,7 +1724,11 @@
"""

round = UnaryElementwiseFunc(
"round", ti._round_result_type, ti._round, _round_docstring
"round",
ti._round_result_type,
ti._round,
_round_docstring,
acceptance_fn=_acceptance_fn_round,
)
del _round_docstring

Expand Down
8 changes: 8 additions & 0 deletions dpnp/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
return True


def _acceptance_fn_round(arg_dtype, buf_dt, res_dt, sycl_dev):
# for boolean input, prefer floating-point output over integral
if arg_dtype.kind == "b" and res_dt.kind != "f":
return False
return True


def _acceptance_fn_subtract(
arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev
):
Expand Down Expand Up @@ -970,6 +977,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
"_find_buf_dtype2",
"_to_device_supported_dtype",
"_acceptance_fn_default_unary",
"_acceptance_fn_round",
"_acceptance_fn_reciprocal",
"_acceptance_fn_default_binary",
"_acceptance_fn_divide",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ template <typename T>
struct RoundOutputType
{
using value_type = typename std::disjunction<
td_ns::TypeMapResultEntry<T, bool, sycl::half>,
td_ns::TypeMapResultEntry<T, std::uint8_t>,
td_ns::TypeMapResultEntry<T, std::uint16_t>,
td_ns::TypeMapResultEntry<T, std::uint32_t>,
Expand Down
1 change: 1 addition & 0 deletions dpnp/tests/tensor/elementwise/test_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_exp_real_contig(dtype):
assert_allclose(dpt.asnumpy(Z), np.repeat(Ynp, n_rep), atol=tol, rtol=tol)


@pytest.mark.filterwarnings("ignore:overflow encountered:RuntimeWarning")
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_exp_complex_contig(dtype):
q = get_queue_or_skip()
Expand Down
4 changes: 3 additions & 1 deletion dpnp/tests/tensor/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,9 @@ def test_copy_via_host_gh_1789():
get_queue_or_skip()
x_np = np.ones((10, 10), dtype="i4")
# strides are no longer multiple of itemsize
x_np.strides = (x_np.strides[0] - 1, x_np.strides[1])
x_np = np.lib.stride_tricks.as_strided(
x_np, shape=x_np.shape, strides=(x_np.strides[0] - 1, x_np.strides[1])
)
with pytest.raises(BufferError):
dpt.from_dlpack(x_np)
with pytest.raises(BufferError):
Expand Down
Loading