Skip to content

Commit e7bf84c

Browse files
committed
Add dedicated InterpolateFunctor to the kernels
1 parent 5056633 commit e7bf84c

File tree

2 files changed

+124
-106
lines changed

2 files changed

+124
-106
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.cpp

Lines changed: 64 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@
4141
#include <pybind11/pybind11.h>
4242
#include <pybind11/stl.h>
4343

44+
#include "kernels/elementwise_functions/interpolate.hpp"
45+
4446
// dpctl tensor headers
4547
#include "utils/type_dispatch.hpp"
4648
#include "utils/type_utils.hpp"
4749

48-
#include "kernels/elementwise_functions/interpolate.hpp"
49-
5050
// utils extension headers
5151
#include "ext/common.hpp"
5252
#include "ext/validation_utils.hpp"
@@ -57,7 +57,6 @@ namespace type_utils = dpctl::tensor::type_utils;
5757

5858
using ext::common::value_type_of;
5959
using ext::validation::array_names;
60-
using ext::validation::array_ptr;
6160

6261
using ext::common::dtype_from_typenum;
6362
using ext::validation::check_has_dtype;
@@ -68,7 +67,6 @@ using ext::validation::common_checks;
6867

6968
namespace dpnp::extensions::ufunc
7069
{
71-
7270
namespace impl
7371
{
7472
using ext::common::init_dispatch_vector;
@@ -88,8 +86,10 @@ typedef sycl::event (*interpolate_fn_ptr_t)(sycl::queue &,
8886
const std::size_t, // xp_size
8987
const std::vector<sycl::event> &);
9088

89+
interpolate_fn_ptr_t interpolate_dispatch_vector[td_ns::num_types];
90+
9191
template <typename T, typename TIdx = std::int64_t>
92-
sycl::event interpolate_call(sycl::queue &exec_q,
92+
sycl::event interpolate_impl(sycl::queue &q,
9393
const void *vx,
9494
const void *vidx,
9595
const void *vxp,
@@ -101,6 +101,8 @@ sycl::event interpolate_call(sycl::queue &exec_q,
101101
const std::size_t xp_size,
102102
const std::vector<sycl::event> &depends)
103103
{
104+
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
105+
104106
using type_utils::is_complex_v;
105107
using TCoord = std::conditional_t<is_complex_v<T>, value_type_of_t<T>, T>;
106108

@@ -112,23 +114,62 @@ sycl::event interpolate_call(sycl::queue &exec_q,
112114
const T *right = static_cast<const T *>(vright);
113115
T *out = static_cast<T *>(vout);
114116

115-
using dpnp::kernels::interpolate::interpolate_impl;
116-
sycl::event interpolate_ev = interpolate_impl<TCoord, T>(
117-
exec_q, x, idx, xp, fp, left, right, out, n, xp_size, depends);
117+
sycl::event interpolate_ev = q.submit([&](sycl::handler &cgh) {
118+
cgh.depends_on(depends);
119+
120+
using InterpolateFunc =
121+
dpnp::kernels::interpolate::InterpolateFunctor<TCoord, T>;
122+
123+
cgh.parallel_for<InterpolateFunc>(
124+
sycl::range<1>(n),
125+
InterpolateFunc(x, idx, xp, fp, left, right, out, xp_size));
126+
});
118127

119128
return interpolate_ev;
120129
}
121130

122-
interpolate_fn_ptr_t interpolate_dispatch_vector[td_ns::num_types];
131+
/**
132+
* @brief Type map of the value types for which the
133+
* interpolate function is available.
134+
*
* The entries name `float`, `double`, `std::complex<float>` and
* `std::complex<double>`; `DefaultResultEntry<void>` is the trailing
* fallback, so `value_type` is presumably `void` for any other `T`
* (TypeMapResultEntry / DefaultResultEntry are dpctl helpers that are not
* shown here -- confirm). InterpolateFactory compares `value_type` against
* `void` to decide whether an implementation is returned.
*
135+
* @tparam T Element type being looked up. interpolate_impl<T> treats the
* `right` and `out` buffers as arrays of `T`.
136+
*/
137+
template <typename T>
138+
struct InterpolateOutputType
139+
{
140+
using value_type = typename std::disjunction<
141+
td_ns::TypeMapResultEntry<T, float>,
142+
td_ns::TypeMapResultEntry<T, double>,
143+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
144+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
145+
td_ns::DefaultResultEntry<void>>::result_type;
146+
};
123147

124-
void common_interpolate_checks(
125-
const dpctl::tensor::usm_ndarray &x,
126-
const dpctl::tensor::usm_ndarray &idx,
127-
const dpctl::tensor::usm_ndarray &xp,
128-
const dpctl::tensor::usm_ndarray &fp,
129-
const dpctl::tensor::usm_ndarray &out,
130-
const std::optional<const dpctl::tensor::usm_ndarray> &left,
131-
const std::optional<const dpctl::tensor::usm_ndarray> &right)
148+
/**
* @brief Factory that yields the interpolate implementation for a type.
*
* `get()` returns `nullptr` when `InterpolateOutputType<T>::value_type` is
* `void`, and a pointer to `interpolate_impl<T>` otherwise. The factory is
* handed to `init_dispatch_vector` when `interpolate_dispatch_vector` is
* filled.
*
* @tparam fnT Return type of `get()`; presumably `interpolate_fn_ptr_t`
* when instantiated through `init_dispatch_vector` -- confirm.
* @tparam T Element type the entry is generated for.
*/
template <typename fnT, typename T>
149+
struct InterpolateFactory
150+
{
151+
fnT get()
152+
{
153+
if constexpr (std::is_same_v<
154+
typename InterpolateOutputType<T>::value_type, void>)
155+
{
156+
return nullptr;
157+
}
158+
else {
159+
return interpolate_impl<T>;
160+
}
161+
}
162+
};
163+
164+
namespace detail
165+
{
166+
void validate(const dpctl::tensor::usm_ndarray &x,
167+
const dpctl::tensor::usm_ndarray &idx,
168+
const dpctl::tensor::usm_ndarray &xp,
169+
const dpctl::tensor::usm_ndarray &fp,
170+
const dpctl::tensor::usm_ndarray &out,
171+
const std::optional<const dpctl::tensor::usm_ndarray> &left,
172+
const std::optional<const dpctl::tensor::usm_ndarray> &right)
132173
{
133174
array_names names = {{&x, "x"}, {&xp, "xp"}, {&fp, "fp"}, {&out, "out"}};
134175

@@ -158,6 +199,7 @@ void common_interpolate_checks(
158199
throw py::value_error("array of sample points is empty");
159200
}
160201
}
202+
} // namespace detail
161203

162204
std::pair<sycl::event, sycl::event>
163205
py_interpolate(const dpctl::tensor::usm_ndarray &x,
@@ -170,7 +212,7 @@ std::pair<sycl::event, sycl::event>
170212
sycl::queue &exec_q,
171213
const std::vector<sycl::event> &depends)
172214
{
173-
common_interpolate_checks(x, idx, xp, fp, out, left, right);
215+
detail::validate(x, idx, xp, fp, out, left, right);
174216

175217
int out_typenum = out.get_typenum();
176218

@@ -214,56 +256,20 @@ std::pair<sycl::event, sycl::event>
214256
return std::make_pair(args_ev, ev);
215257
}
216258

217-
/**
218-
* @brief A factory to define pairs of supported types for which
219-
* interpolate function is available.
220-
*
221-
* @tparam T Type of input vector `a` and of result vector `y`.
222-
*/
223-
template <typename T>
224-
struct InterpolateOutputType
225-
{
226-
using value_type = typename std::disjunction<
227-
td_ns::TypeMapResultEntry<T, float>,
228-
td_ns::TypeMapResultEntry<T, double>,
229-
td_ns::TypeMapResultEntry<T, std::complex<float>>,
230-
td_ns::TypeMapResultEntry<T, std::complex<double>>,
231-
td_ns::DefaultResultEntry<void>>::result_type;
232-
};
233-
234-
template <typename fnT, typename T>
235-
struct InterpolateFactory
236-
{
237-
fnT get()
238-
{
239-
if constexpr (std::is_same_v<
240-
typename InterpolateOutputType<T>::value_type, void>)
241-
{
242-
return nullptr;
243-
}
244-
else {
245-
return interpolate_call<T>;
246-
}
247-
}
248-
};
249-
250259
// Fills interpolate_dispatch_vector through init_dispatch_vector, with
// InterpolateFactory supplying the function pointer stored in each entry.
static void init_interpolate_dispatch_vectors()
251260
{
252-
init_dispatch_vector<interpolate_fn_ptr_t, InterpolateFactory>(
261+
init_dispatch_vector<interpolate_fn_ptr_t, impl::InterpolateFactory>(
253262
interpolate_dispatch_vector);
254263
}
255-
256264
} // namespace impl
257265

258266
/**
* @brief Registers the interpolate binding in a Python module.
*
* Initializes the dispatch vector first, then exposes impl::py_interpolate
* as `_interpolate` with the keyword arguments x, idx, xp, fp, left, right,
* out, sycl_queue and depends (the last one defaults to an empty list).
*
* @param m pybind11 module that receives the `_interpolate` function.
*/
void init_interpolate(py::module_ m)
259267
{
260268
impl::init_interpolate_dispatch_vectors();
261269

262-
using impl::py_interpolate;
263-
m.def("_interpolate", &py_interpolate, "", py::arg("x"), py::arg("idx"),
264-
py::arg("xp"), py::arg("fp"), py::arg("left"), py::arg("right"),
265-
py::arg("out"), py::arg("sycl_queue"),
270+
m.def("_interpolate", &impl::py_interpolate, "", py::arg("x"),
271+
py::arg("idx"), py::arg("xp"), py::arg("fp"), py::arg("left"),
272+
py::arg("right"), py::arg("out"), py::arg("sycl_queue"),
266273
py::arg("depends") = py::list());
267274
}
268-
269275
} // namespace dpnp::extensions::ufunc

dpnp/backend/kernels/elementwise_functions/interpolate.hpp

Lines changed: 60 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -28,67 +28,79 @@
2828

2929
#pragma once
3030

31+
#include <cstddef>
32+
#include <cstdint>
33+
3134
#include <sycl/sycl.hpp>
32-
#include <vector>
3335

3436
#include "ext/common.hpp"
3537

36-
using ext::common::IsNan;
37-
3838
namespace dpnp::kernels::interpolate
3939
{
40+
using ext::common::IsNan;
41+
4042
/**
* @brief Kernel function object that computes one interpolated value per
* work-item.
*
* For work-item `id` it reads `x[id]` and `idx[id]` and writes `out[id]`,
* checking the cases in this order:
* - a NaN `x[id]` is copied to the output;
* - `idx[id] - 1 < 0` yields `*left`, or `fp[0]` when `left` is null;
* - `x[id] == xp[xp_size - 1]` yields `fp[xp_size - 1]`;
* - `idx[id] - 1 >= xp_size - 1` yields `*right`, or `fp[xp_size - 1]`
* when `right` is null;
* - otherwise the value is linearly interpolated between the sample points
* at `idx[id] - 1` and `idx[id]`.
*
* NOTE(review): `idx` presumably holds, for every `x`, its position within
* `xp` as computed by the caller -- confirm against the Python side.
*
* @tparam TCoord Type of the elements of `x` and `xp`.
* @tparam TValue Type of the elements of `fp`, `left`, `right` and `out`.
* @tparam TIdx Type of the elements of `idx`.
*/
template <typename TCoord, typename TValue, typename TIdx = std::int64_t>
41-
sycl::event interpolate_impl(sycl::queue &q,
42-
const TCoord *x,
43-
const TIdx *idx,
44-
const TCoord *xp,
45-
const TValue *fp,
46-
const TValue *left,
47-
const TValue *right,
48-
TValue *out,
49-
const std::size_t n,
50-
const std::size_t xp_size,
51-
const std::vector<sycl::event> &depends)
43+
class InterpolateFunctor
5244
{
45+
private:
46+
const TCoord *x = nullptr;
47+
const TIdx *idx = nullptr;
48+
const TCoord *xp = nullptr;
49+
const TValue *fp = nullptr;
50+
const TValue *left = nullptr;
51+
const TValue *right = nullptr;
52+
TValue *out = nullptr;
53+
const std::size_t xp_size;
54+
55+
public:
// Stores the raw pointers and the number of sample points; nothing is
// copied or validated here.
56+
InterpolateFunctor(const TCoord *x_,
57+
const TIdx *idx_,
58+
const TCoord *xp_,
59+
const TValue *fp_,
60+
const TValue *left_,
61+
const TValue *right_,
62+
TValue *out_,
63+
const std::size_t xp_size_)
64+
: x(x_), idx(idx_), xp(xp_), fp(fp_), left(left_), right(right_),
65+
out(out_), xp_size(xp_size_)
66+
{
67+
}
68+
5369
// Selected over the work-group version
5470
// due to simpler execution and slightly better performance.
55-
return q.submit([&](sycl::handler &h) {
56-
h.depends_on(depends);
57-
h.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) {
58-
TValue left_val = left ? *left : fp[0];
59-
TValue right_val = right ? *right : fp[xp_size - 1];
// Computes out[id] for the single element addressed by `id`.
71+
void operator()(sycl::id<1> id) const
72+
{
// Values used outside the sampled range: the explicit left/right value
// when its pointer is non-null, otherwise the first/last entry of fp.
73+
TValue left_val = left ? *left : fp[0];
74+
TValue right_val = right ? *right : fp[xp_size - 1];
6075

61-
TCoord x_val = x[i];
62-
TIdx x_idx = idx[i] - 1;
// x_idx is the index of the left-hand sample point of the interval.
76+
TCoord x_val = x[id];
77+
TIdx x_idx = idx[id] - 1;
6378

64-
if (IsNan<TCoord>::isnan(x_val)) {
65-
out[i] = x_val;
66-
}
67-
else if (x_idx < 0) {
68-
out[i] = left_val;
69-
}
70-
else if (x_val == xp[xp_size - 1]) {
71-
out[i] = fp[xp_size - 1];
72-
}
73-
else if (x_idx >= static_cast<TIdx>(xp_size - 1)) {
74-
out[i] = right_val;
75-
}
76-
else {
77-
TValue slope =
78-
(fp[x_idx + 1] - fp[x_idx]) / (xp[x_idx + 1] - xp[x_idx]);
79-
TValue res = slope * (x_val - xp[x_idx]) + fp[x_idx];
79+
if (IsNan<TCoord>::isnan(x_val)) {
80+
out[id] = x_val;
81+
}
82+
else if (x_idx < 0) {
83+
out[id] = left_val;
84+
}
85+
else if (x_val == xp[xp_size - 1]) {
86+
out[id] = fp[xp_size - 1];
87+
}
88+
else if (x_idx >= static_cast<TIdx>(xp_size - 1)) {
89+
out[id] = right_val;
90+
}
91+
else {
92+
TValue slope =
93+
(fp[x_idx + 1] - fp[x_idx]) / (xp[x_idx + 1] - xp[x_idx]);
94+
TValue res = slope * (x_val - xp[x_idx]) + fp[x_idx];
8095

81-
if (IsNan<TValue>::isnan(res)) {
82-
res = slope * (x_val - xp[x_idx + 1]) + fp[x_idx + 1];
83-
if (IsNan<TValue>::isnan(res) &&
84-
(fp[x_idx] == fp[x_idx + 1])) {
85-
res = fp[x_idx];
86-
}
// A NaN result is recomputed from the right-hand sample point; if it is
// still NaN and both fp values are equal, that common value is used.
96+
if (IsNan<TValue>::isnan(res)) {
97+
res = slope * (x_val - xp[x_idx + 1]) + fp[x_idx + 1];
98+
if (IsNan<TValue>::isnan(res) && (fp[x_idx] == fp[x_idx + 1])) {
99+
res = fp[x_idx];
87100
}
88-
out[i] = res;
89101
}
90-
});
91-
});
92-
}
93-
102+
out[id] = res;
103+
}
104+
}
105+
};
94106
} // namespace dpnp::kernels::interpolate

0 commit comments

Comments
 (0)