4141#include < pybind11/pybind11.h>
4242#include < pybind11/stl.h>
4343
44+ #include " kernels/elementwise_functions/interpolate.hpp"
45+
4446// dpctl tensor headers
4547#include " utils/type_dispatch.hpp"
4648#include " utils/type_utils.hpp"
4749
48- #include " kernels/elementwise_functions/interpolate.hpp"
49-
5050// utils extension headers
5151#include " ext/common.hpp"
5252#include " ext/validation_utils.hpp"
@@ -57,7 +57,6 @@ namespace type_utils = dpctl::tensor::type_utils;
5757
5858using ext::common::value_type_of;
5959using ext::validation::array_names;
60- using ext::validation::array_ptr;
6160
6261using ext::common::dtype_from_typenum;
6362using ext::validation::check_has_dtype;
@@ -68,7 +67,6 @@ using ext::validation::common_checks;
6867
6968namespace dpnp ::extensions::ufunc
7069{
71-
7270namespace impl
7371{
7472using ext::common::init_dispatch_vector;
@@ -88,8 +86,10 @@ typedef sycl::event (*interpolate_fn_ptr_t)(sycl::queue &,
8886 const std::size_t , // xp_size
8987 const std::vector<sycl::event> &);
9088
89+ interpolate_fn_ptr_t interpolate_dispatch_vector[td_ns::num_types];
90+
9191template <typename T, typename TIdx = std::int64_t >
92- sycl::event interpolate_call (sycl::queue &exec_q ,
92+ sycl::event interpolate_impl (sycl::queue &q ,
9393 const void *vx,
9494 const void *vidx,
9595 const void *vxp,
@@ -101,6 +101,8 @@ sycl::event interpolate_call(sycl::queue &exec_q,
101101 const std::size_t xp_size,
102102 const std::vector<sycl::event> &depends)
103103{
104+ dpctl::tensor::type_utils::validate_type_for_device<T>(q);
105+
104106 using type_utils::is_complex_v;
105107 using TCoord = std::conditional_t <is_complex_v<T>, value_type_of_t <T>, T>;
106108
@@ -112,23 +114,62 @@ sycl::event interpolate_call(sycl::queue &exec_q,
112114 const T *right = static_cast <const T *>(vright);
113115 T *out = static_cast <T *>(vout);
114116
115- using dpnp::kernels::interpolate::interpolate_impl;
116- sycl::event interpolate_ev = interpolate_impl<TCoord, T>(
117- exec_q, x, idx, xp, fp, left, right, out, n, xp_size, depends);
117+ sycl::event interpolate_ev = q.submit ([&](sycl::handler &cgh) {
118+ cgh.depends_on (depends);
119+
120+ using InterpolateFunc =
121+ dpnp::kernels::interpolate::InterpolateFunctor<TCoord, T>;
122+
123+ cgh.parallel_for <InterpolateFunc>(
124+ sycl::range<1 >(n),
125+ InterpolateFunc (x, idx, xp, fp, left, right, out, xp_size));
126+ });
118127
119128 return interpolate_ev;
120129}
121130
122- interpolate_fn_ptr_t interpolate_dispatch_vector[td_ns::num_types];
131+ /* *
132+ * @brief A factory to define pairs of supported types for which
133+ * interpolate function is available.
134+ *
135+ * @tparam T Type of input vector `a` and of result vector `y`.
136+ */
137+ template <typename T>
138+ struct InterpolateOutputType
139+ {
140+ using value_type = typename std::disjunction<
141+ td_ns::TypeMapResultEntry<T, float >,
142+ td_ns::TypeMapResultEntry<T, double >,
143+ td_ns::TypeMapResultEntry<T, std::complex <float >>,
144+ td_ns::TypeMapResultEntry<T, std::complex <double >>,
145+ td_ns::DefaultResultEntry<void >>::result_type;
146+ };
123147
124- void common_interpolate_checks (
125- const dpctl::tensor::usm_ndarray &x,
126- const dpctl::tensor::usm_ndarray &idx,
127- const dpctl::tensor::usm_ndarray &xp,
128- const dpctl::tensor::usm_ndarray &fp,
129- const dpctl::tensor::usm_ndarray &out,
130- const std::optional<const dpctl::tensor::usm_ndarray> &left,
131- const std::optional<const dpctl::tensor::usm_ndarray> &right)
148+ template <typename fnT, typename T>
149+ struct InterpolateFactory
150+ {
151+ fnT get ()
152+ {
153+ if constexpr (std::is_same_v<
154+ typename InterpolateOutputType<T>::value_type, void >)
155+ {
156+ return nullptr ;
157+ }
158+ else {
159+ return interpolate_impl<T>;
160+ }
161+ }
162+ };
163+
164+ namespace detail
165+ {
166+ void validate (const dpctl::tensor::usm_ndarray &x,
167+ const dpctl::tensor::usm_ndarray &idx,
168+ const dpctl::tensor::usm_ndarray &xp,
169+ const dpctl::tensor::usm_ndarray &fp,
170+ const dpctl::tensor::usm_ndarray &out,
171+ const std::optional<const dpctl::tensor::usm_ndarray> &left,
172+ const std::optional<const dpctl::tensor::usm_ndarray> &right)
132173{
133174 array_names names = {{&x, " x" }, {&xp, " xp" }, {&fp, " fp" }, {&out, " out" }};
134175
@@ -158,6 +199,7 @@ void common_interpolate_checks(
158199 throw py::value_error (" array of sample points is empty" );
159200 }
160201}
202+ } // namespace detail
161203
162204std::pair<sycl::event, sycl::event>
163205 py_interpolate (const dpctl::tensor::usm_ndarray &x,
@@ -170,7 +212,7 @@ std::pair<sycl::event, sycl::event>
170212 sycl::queue &exec_q,
171213 const std::vector<sycl::event> &depends)
172214{
173- common_interpolate_checks (x, idx, xp, fp, out, left, right);
215+ detail::validate (x, idx, xp, fp, out, left, right);
174216
175217 int out_typenum = out.get_typenum ();
176218
@@ -214,56 +256,20 @@ std::pair<sycl::event, sycl::event>
214256 return std::make_pair (args_ev, ev);
215257}
216258
217- /* *
218- * @brief A factory to define pairs of supported types for which
219- * interpolate function is available.
220- *
221- * @tparam T Type of input vector `a` and of result vector `y`.
222- */
223- template <typename T>
224- struct InterpolateOutputType
225- {
226- using value_type = typename std::disjunction<
227- td_ns::TypeMapResultEntry<T, float >,
228- td_ns::TypeMapResultEntry<T, double >,
229- td_ns::TypeMapResultEntry<T, std::complex <float >>,
230- td_ns::TypeMapResultEntry<T, std::complex <double >>,
231- td_ns::DefaultResultEntry<void >>::result_type;
232- };
233-
234- template <typename fnT, typename T>
235- struct InterpolateFactory
236- {
237- fnT get ()
238- {
239- if constexpr (std::is_same_v<
240- typename InterpolateOutputType<T>::value_type,
241- void >) {
242- return nullptr ;
243- }
244- else {
245- return interpolate_call<T>;
246- }
247- }
248- };
249-
250259static void init_interpolate_dispatch_vectors ()
251260{
252- init_dispatch_vector<interpolate_fn_ptr_t , InterpolateFactory>(
261+ init_dispatch_vector<interpolate_fn_ptr_t , impl:: InterpolateFactory>(
253262 interpolate_dispatch_vector);
254263}
255-
256264} // namespace impl
257265
258266void init_interpolate (py::module_ m)
259267{
260268 impl::init_interpolate_dispatch_vectors ();
261269
262- using impl::py_interpolate;
263- m.def (" _interpolate" , &py_interpolate, " " , py::arg (" x" ), py::arg (" idx" ),
264- py::arg (" xp" ), py::arg (" fp" ), py::arg (" left" ), py::arg (" right" ),
265- py::arg (" out" ), py::arg (" sycl_queue" ),
270+ m.def (" _interpolate" , &impl::py_interpolate, " " , py::arg (" x" ),
271+ py::arg (" idx" ), py::arg (" xp" ), py::arg (" fp" ), py::arg (" left" ),
272+ py::arg (" right" ), py::arg (" out" ), py::arg (" sycl_queue" ),
266273 py::arg (" depends" ) = py::list ());
267274}
268-
269275} // namespace dpnp::extensions::ufunc
0 commit comments