3131#include < pybind11/stl.h>
3232
3333// dpctl tensor headers
34- #include " utils/output_validation.hpp"
3534#include " utils/type_dispatch.hpp"
35+ #include " utils/type_utils.hpp"
3636
3737#include " kernels/elementwise_functions/interpolate.hpp"
3838
4141
4242namespace py = pybind11;
4343namespace td_ns = dpctl::tensor::type_dispatch;
44+ namespace type_utils = dpctl::tensor::type_utils;
4445
4546using ext::common::value_type_of;
4647using ext::validation::array_names;
@@ -57,18 +58,18 @@ template <typename T>
5758using value_type_of_t = typename value_type_of<T>::type;
5859
5960typedef sycl::event (*interpolate_fn_ptr_t )(sycl::queue &,
60- const void *, // x
61- const void *, // idx
62- const void *, // xp
63- const void *, // fp
64- const void *, // left
65- const void *, // right
66- void *, // out
67- std::size_t , // n
68- std::size_t , // xp_size
61+ const void *, // x
62+ const void *, // idx
63+ const void *, // xp
64+ const void *, // fp
65+ const void *, // left
66+ const void *, // right
67+ void *, // out
68+ const std::size_t , // n
69+ const std::size_t , // xp_size
6970 const std::vector<sycl::event> &);
7071
71- template <typename T>
72+ template <typename T, typename TIdx = std:: int64_t >
7273sycl::event interpolate_call (sycl::queue &exec_q,
7374 const void *vx,
7475 const void *vidx,
@@ -77,15 +78,15 @@ sycl::event interpolate_call(sycl::queue &exec_q,
7778 const void *vleft,
7879 const void *vright,
7980 void *vout,
80- std::size_t n,
81- std::size_t xp_size,
81+ const std::size_t n,
82+ const std::size_t xp_size,
8283 const std::vector<sycl::event> &depends)
8384{
84- using dpctl::tensor:: type_utils::is_complex_v;
85+ using type_utils::is_complex_v;
8586 using TCoord = std::conditional_t <is_complex_v<T>, value_type_of_t <T>, T>;
8687
8788 const TCoord *x = static_cast <const TCoord *>(vx);
88- const std:: int64_t *idx = static_cast <const std:: int64_t *>(vidx);
89+ const TIdx *idx = static_cast <const TIdx *>(vidx);
8990 const TCoord *xp = static_cast <const TCoord *>(vxp);
9091 const T *fp = static_cast <const T *>(vfp);
9192 const T *left = static_cast <const T *>(vleft);
@@ -114,6 +115,7 @@ void common_interpolate_checks(
114115
115116 auto array_types = td_ns::usm_ndarray_types ();
116117 int x_type_id = array_types.typenum_to_lookup_id (x.get_typenum ());
118+ int idx_type_id = array_types.typenum_to_lookup_id (idx.get_typenum ());
117119 int xp_type_id = array_types.typenum_to_lookup_id (xp.get_typenum ());
118120 int fp_type_id = array_types.typenum_to_lookup_id (fp.get_typenum ());
119121 int out_type_id = array_types.typenum_to_lookup_id (out.get_typenum ());
@@ -124,38 +126,41 @@ void common_interpolate_checks(
124126 if (fp_type_id != out_type_id) {
125127 throw py::value_error (" fp and out must have the same dtype" );
126128 }
129+ if (idx_type_id != static_cast <int >(td_ns::typenum_t ::INT64)) {
130+ throw py::value_error (" The type of idx must be int64" );
131+ }
127132
128- if ( left) {
129- const auto &l = left. value ();
130- names.insert ({&l , " left" });
131- if (l. get_ndim () != 0 ) {
133+ auto left_v = left ? &left. value () : nullptr ;
134+ if (left_v) {
135+ names.insert ({left_v , " left" });
136+ if (left_v-> get_ndim () != 0 ) {
132137 throw py::value_error (" left must be a zero-dimensional array" );
133138 }
134139
135- int left_type_id = array_types.typenum_to_lookup_id (l.get_typenum ());
140+ int left_type_id =
141+ array_types.typenum_to_lookup_id (left_v->get_typenum ());
136142 if (left_type_id != fp_type_id) {
137143 throw py::value_error (
138144 " left must have the same dtype as fp and out" );
139145 }
140146 }
141147
142- if ( right) {
143- const auto &r = right. value ();
144- names.insert ({&r , " right" });
145- if (r. get_ndim () != 0 ) {
148+ auto right_v = right ? &right. value () : nullptr ;
149+ if (right_v) {
150+ names.insert ({right_v , " right" });
151+ if (right_v-> get_ndim () != 0 ) {
146152 throw py::value_error (" right must be a zero-dimensional array" );
147153 }
148154
149- int right_type_id = array_types.typenum_to_lookup_id (r.get_typenum ());
155+ int right_type_id =
156+ array_types.typenum_to_lookup_id (right_v->get_typenum ());
150157 if (right_type_id != fp_type_id) {
151158 throw py::value_error (
152159 " right must have the same dtype as fp and out" );
153160 }
154161 }
155162
156- common_checks ({&x, &xp, &fp, left ? &left.value () : nullptr ,
157- right ? &right.value () : nullptr },
158- {&out}, names);
163+ common_checks ({&x, &xp, &fp, left_v, right_v}, {&out}, names);
159164
160165 if (x.get_ndim () != 1 || xp.get_ndim () != 1 || fp.get_ndim () != 1 ||
161166 idx.get_ndim () != 1 || out.get_ndim () != 1 )
@@ -167,6 +172,10 @@ void common_interpolate_checks(
167172 throw py::value_error (" xp and fp must have the same size" );
168173 }
169174
175+ if (xp.get_size () == 0 ) {
176+ throw py::value_error (" array of sample points is empty" );
177+ }
178+
170179 if (x.get_size () != out.get_size () || x.get_size () != idx.get_size ()) {
171180 throw py::value_error (" x, idx, and out must have the same size" );
172181 }
@@ -183,12 +192,12 @@ std::pair<sycl::event, sycl::event>
183192 sycl::queue &exec_q,
184193 const std::vector<sycl::event> &depends)
185194{
195+ common_interpolate_checks (x, idx, xp, fp, out, left, right);
196+
186197 if (x.get_size () == 0 ) {
187198 return {sycl::event (), sycl::event ()};
188199 }
189200
190- common_interpolate_checks (x, idx, xp, fp, out, left, right);
191-
192201 int out_typenum = out.get_typenum ();
193202
194203 auto array_types = td_ns::usm_ndarray_types ();
@@ -215,13 +224,10 @@ std::pair<sycl::event, sycl::event>
215224 args_ev = dpctl::utils::keep_args_alive (
216225 exec_q, {x, idx, xp, fp, out, left.value (), right.value ()}, {ev});
217226 }
218- else if (left) {
219- args_ev = dpctl::utils::keep_args_alive (
220- exec_q, {x, idx, xp, fp, out, left.value ()}, {ev});
221- }
222- else if (right) {
227+ else if (left || right) {
223228 args_ev = dpctl::utils::keep_args_alive (
224- exec_q, {x, idx, xp, fp, out, right.value ()}, {ev});
229+ exec_q, {x, idx, xp, fp, out, left ? left.value () : right.value ()},
230+ {ev});
225231 }
226232 else {
227233 args_ev =
0 commit comments