Skip to content

Commit 00c9874

Browse files
Address remarks
1 parent 4bf2b03 commit 00c9874

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ using ext::common::value_type_of;
4747
using ext::validation::array_names;
4848
using ext::validation::array_ptr;
4949

50+
using ext::common::dtype_from_typenum;
5051
using ext::validation::check_has_dtype;
5152
using ext::validation::check_num_dims;
5253
using ext::validation::check_same_dtype;
@@ -165,7 +166,10 @@ std::pair<sycl::event, sycl::event>
165166

166167
auto fn = interpolate_dispatch_vector[out_type_id];
167168
if (!fn) {
168-
throw py::type_error("Unsupported dtype");
169+
py::dtype out_dtype_py = dtype_from_typenum(out_type_id);
170+
std::string msg = "Unsupported dtype for interpolation: " +
171+
std::string(py::str(out_dtype_py));
172+
throw py::type_error(msg);
169173
}
170174

171175
std::size_t n = x.get_size();
@@ -207,7 +211,6 @@ template <typename T>
207211
struct InterpolateOutputType
208212
{
209213
using value_type = typename std::disjunction<
210-
td_ns::TypeMapResultEntry<T, sycl::half>,
211214
td_ns::TypeMapResultEntry<T, float>,
212215
td_ns::TypeMapResultEntry<T, double>,
213216
td_ns::TypeMapResultEntry<T, std::complex<float>>,

dpnp/tests/test_mathematical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,7 +1150,7 @@ class TestInterp:
11501150
@pytest.mark.parametrize(
11511151
"dtype_xp", get_all_dtypes(no_complex=True, no_none=True)
11521152
)
1153-
@pytest.mark.parametrize("dtype_y", get_all_dtypes())
1153+
@pytest.mark.parametrize("dtype_y", get_all_dtypes(no_none=True))
11541154
def test_all_dtypes(self, dtype_x, dtype_xp, dtype_y):
11551155
x = numpy.linspace(0.1, 9.9, 20).astype(dtype_x)
11561156
xp = numpy.linspace(0.0, 10.0, 5).astype(dtype_xp)
@@ -1281,7 +1281,7 @@ def test_errors(self):
12811281
assert_raises(TypeError, dpnp.interp, x, xp, fp, period=[180])
12821282

12831283
# left is not scalar or 0-dim
1284-
left = dpnp.array([1.0])
1284+
left = [1]
12851285
assert_raises(ValueError, dpnp.interp, x, xp, fp, left=left)
12861286

12871287
# left is 1-d array

0 commit comments

Comments
 (0)