Skip to content

Commit c29b08e

Browse files
committed
Add dedicated windows functors
1 parent 5b8c4e2 commit c29b08e

File tree

12 files changed

+153
-154
lines changed

12 files changed

+153
-154
lines changed

dpnp/backend/extensions/indexing/choose.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,10 @@
5656

5757
namespace dpnp::extensions::indexing
5858
{
59+
namespace py = pybind11;
60+
5961
namespace impl
6062
{
61-
namespace py = pybind11;
6263
namespace td_ns = dpctl::tensor::type_dispatch;
6364

6465
using dpctl::tensor::ssize_t;

dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,28 +51,19 @@
5151
#include "ext/common.hpp"
5252
#include "ext/validation_utils.hpp"
5353

54-
namespace py = pybind11;
55-
namespace td_ns = dpctl::tensor::type_dispatch;
56-
namespace type_utils = dpctl::tensor::type_utils;
57-
58-
using ext::common::value_type_of;
59-
using ext::validation::array_names;
60-
61-
using ext::common::dtype_from_typenum;
62-
using ext::validation::check_has_dtype;
63-
using ext::validation::check_num_dims;
64-
using ext::validation::check_same_dtype;
65-
using ext::validation::check_same_size;
66-
using ext::validation::common_checks;
67-
6854
namespace dpnp::extensions::ufunc
6955
{
56+
namespace py = pybind11;
57+
7058
namespace impl
7159
{
72-
using ext::common::init_dispatch_vector;
60+
namespace td_ns = dpctl::tensor::type_dispatch;
61+
namespace type_utils = dpctl::tensor::type_utils;
7362

7463
template <typename T>
75-
using value_type_of_t = typename value_type_of<T>::type;
64+
using value_type_of_t = typename ext::common::value_type_of<T>::type;
65+
66+
using ext::common::dtype_from_typenum;
7667

7768
typedef sycl::event (*interpolate_fn_ptr_t)(sycl::queue &,
7869
const void *, // x
@@ -163,6 +154,13 @@ struct InterpolateFactory
163154

164155
namespace detail
165156
{
157+
using ext::validation::array_names;
158+
using ext::validation::check_has_dtype;
159+
using ext::validation::check_num_dims;
160+
using ext::validation::check_same_dtype;
161+
using ext::validation::check_same_size;
162+
using ext::validation::common_checks;
163+
166164
void validate(const dpctl::tensor::usm_ndarray &x,
167165
const dpctl::tensor::usm_ndarray &idx,
168166
const dpctl::tensor::usm_ndarray &xp,
@@ -258,6 +256,7 @@ std::pair<sycl::event, sycl::event>
258256

259257
static void init_interpolate_dispatch_vectors()
260258
{
259+
using ext::common::init_dispatch_vector;
261260
init_dispatch_vector<interpolate_fn_ptr_t, impl::InterpolateFactory>(
262261
interpolate_dispatch_vector);
263262
}

dpnp/backend/extensions/window/common.hpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828

2929
#pragma once
3030

31-
#include <pybind11/pybind11.h>
32-
#include <pybind11/stl.h>
3331
#include <sycl/sycl.hpp>
3432

3533
#include "dpctl4pybind11.hpp"
34+
#include <pybind11/pybind11.h>
35+
#include <pybind11/stl.h>
3636

3737
// dpctl tensor headers
3838
#include "utils/output_validation.hpp"
@@ -41,10 +41,8 @@
4141

4242
namespace dpnp::extensions::window
4343
{
44-
45-
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
46-
4744
namespace py = pybind11;
45+
namespace td_ns = dpctl::tensor::type_dispatch;
4846

4947
typedef sycl::event (*window_fn_ptr_t)(sycl::queue &,
5048
char *,
@@ -72,6 +70,20 @@ sycl::event window_impl(sycl::queue &exec_q,
7270
return window_ev;
7371
}
7472

73+
template <typename fnT, typename T, template <typename> typename FunctorT>
74+
struct Factory
75+
{
76+
fnT get()
77+
{
78+
if constexpr (std::is_floating_point_v<T>) {
79+
return window_impl<T, FunctorT>;
80+
}
81+
else {
82+
return nullptr;
83+
}
84+
}
85+
};
86+
7587
template <typename funcPtrT>
7688
std::tuple<size_t, char *, funcPtrT>
7789
window_fn(sycl::queue &exec_q,
@@ -101,7 +113,7 @@ std::tuple<size_t, char *, funcPtrT>
101113
}
102114

103115
const int result_typenum = result.get_typenum();
104-
auto array_types = dpctl_td_ns::usm_ndarray_types();
116+
auto array_types = td_ns::usm_ndarray_types();
105117
const int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
106118
funcPtrT fn = window_dispatch_vector[result_type_id];
107119

dpnp/backend/extensions/window/kaiser.cpp

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include "kaiser.hpp"
3030
#include "common.hpp"
3131

32+
#include "kernels/window/kaiser.hpp"
33+
3234
// utils extension header
3335
#include "ext/common.hpp"
3436

@@ -39,48 +41,21 @@
3941

4042
#include <sycl/sycl.hpp>
4143

42-
#include "kernels/elementwise_functions/i0.hpp"
43-
4444
namespace dpnp::extensions::window
4545
{
46-
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
47-
48-
using ext::common::init_dispatch_vector;
46+
namespace py = pybind11;
47+
namespace td_ns = dpctl::tensor::type_dispatch;
4948

5049
typedef sycl::event (*kaiser_fn_ptr_t)(sycl::queue &,
5150
char *,
5251
const std::size_t,
5352
const py::object &,
5453
const std::vector<sycl::event> &);
5554

56-
static kaiser_fn_ptr_t kaiser_dispatch_vector[dpctl_td_ns::num_types];
55+
static kaiser_fn_ptr_t kaiser_dispatch_vector[td_ns::num_types];
5756

58-
template <typename T>
59-
class KaiserFunctor
57+
namespace impl
6058
{
61-
private:
62-
T *res = nullptr;
63-
const std::size_t N;
64-
const T beta;
65-
66-
public:
67-
KaiserFunctor(T *res, const std::size_t N, const T beta)
68-
: res(res), N(N), beta(beta)
69-
{
70-
}
71-
72-
void operator()(sycl::id<1> id) const
73-
{
74-
using dpnp::kernels::i0::cyl_bessel_i0;
75-
76-
const auto i = id.get(0);
77-
const T alpha = (N - 1) / T(2);
78-
const T tmp = (i - alpha) / alpha;
79-
res[i] = cyl_bessel_i0(beta * sycl::sqrt(1 - tmp * tmp)) /
80-
cyl_bessel_i0(beta);
81-
}
82-
};
83-
8459
template <typename T>
8560
sycl::event kaiser_impl(sycl::queue &exec_q,
8661
char *result,
@@ -96,7 +71,7 @@ sycl::event kaiser_impl(sycl::queue &exec_q,
9671
sycl::event kaiser_ev = exec_q.submit([&](sycl::handler &cgh) {
9772
cgh.depends_on(depends);
9873

99-
using KaiserKernel = KaiserFunctor<T>;
74+
using KaiserKernel = dpnp::kernels::kaiser::KaiserFunctor<T>;
10075
cgh.parallel_for<KaiserKernel>(sycl::range<1>(nelems),
10176
KaiserKernel(res, nelems, beta));
10277
});
@@ -138,11 +113,12 @@ std::pair<sycl::event, sycl::event>
138113

139114
return std::make_pair(args_ev, kaiser_ev);
140115
}
116+
} // namespace impl
141117

142118
void init_kaiser_dispatch_vectors()
143119
{
144-
init_dispatch_vector<kaiser_fn_ptr_t, KaiserFactory>(
120+
using ext::common::init_dispatch_vector;
121+
init_dispatch_vector<kaiser_fn_ptr_t, impl::KaiserFactory>(
145122
kaiser_dispatch_vector);
146123
}
147-
148124
} // namespace dpnp::extensions::window

dpnp/backend/extensions/window/kaiser.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,20 @@
2828

2929
#pragma once
3030

31-
#include <dpctl4pybind11.hpp>
3231
#include <sycl/sycl.hpp>
3332

33+
#include <dpctl4pybind11.hpp>
34+
#include <pybind11/pybind11.h>
35+
3436
namespace dpnp::extensions::window
3537
{
38+
namespace py = pybind11;
39+
3640
extern std::pair<sycl::event, sycl::event>
3741
py_kaiser(sycl::queue &exec_q,
3842
const py::object &beta,
3943
const dpctl::tensor::usm_ndarray &result,
4044
const std::vector<sycl::event> &depends);
4145

4246
extern void init_kaiser_dispatch_vectors(void);
43-
4447
} // namespace dpnp::extensions::window

dpnp/backend/extensions/window/window_py.cpp

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,12 @@
3333
#include <pybind11/pybind11.h>
3434
#include <pybind11/stl.h>
3535

36-
#include "bartlett.hpp"
37-
#include "blackman.hpp"
36+
#include "kernels/window/bartlett.hpp"
37+
#include "kernels/window/blackman.hpp"
38+
#include "kernels/window/hamming.hpp"
39+
#include "kernels/window/hanning.hpp"
40+
3841
#include "common.hpp"
39-
#include "hamming.hpp"
40-
#include "hanning.hpp"
4142
#include "kaiser.hpp"
4243

4344
// utils extension header
@@ -51,6 +52,22 @@ using window_ns::window_fn_ptr_t;
5152

5253
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
5354

55+
template <typename fnT, typename T>
56+
using BartlettFactory =
57+
window_ns::Factory<fnT, T, dpnp::kernels::bartlett::BartlettFunctor>;
58+
59+
template <typename fnT, typename T>
60+
using BlackmanFactory =
61+
window_ns::Factory<fnT, T, dpnp::kernels::blackman::BlackmanFunctor>;
62+
63+
template <typename fnT, typename T>
64+
using HammingFactory =
65+
window_ns::Factory<fnT, T, dpnp::kernels::hamming::HammingFunctor>;
66+
67+
template <typename fnT, typename T>
68+
using HanningFactory =
69+
window_ns::Factory<fnT, T, dpnp::kernels::hanning::HanningFunctor>;
70+
5471
static window_fn_ptr_t bartlett_dispatch_vector[dpctl_td_ns::num_types];
5572
static window_fn_ptr_t blackman_dispatch_vector[dpctl_td_ns::num_types];
5673
static window_fn_ptr_t hamming_dispatch_vector[dpctl_td_ns::num_types];
@@ -62,8 +79,7 @@ PYBIND11_MODULE(_window_impl, m)
6279
using event_vecT = std::vector<sycl::event>;
6380

6481
{
65-
init_dispatch_vector<window_ns::window_fn_ptr_t,
66-
window_ns::kernels::BartlettFactory>(
82+
init_dispatch_vector<window_ns::window_fn_ptr_t, BartlettFactory>(
6783
bartlett_dispatch_vector);
6884

6985
auto bartlett_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
@@ -78,8 +94,7 @@ PYBIND11_MODULE(_window_impl, m)
7894
}
7995

8096
{
81-
init_dispatch_vector<window_ns::window_fn_ptr_t,
82-
window_ns::kernels::BlackmanFactory>(
97+
init_dispatch_vector<window_ns::window_fn_ptr_t, BlackmanFactory>(
8398
blackman_dispatch_vector);
8499

85100
auto blackman_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
@@ -94,8 +109,7 @@ PYBIND11_MODULE(_window_impl, m)
94109
}
95110

96111
{
97-
init_dispatch_vector<window_ns::window_fn_ptr_t,
98-
window_ns::kernels::HammingFactory>(
112+
init_dispatch_vector<window_ns::window_fn_ptr_t, HammingFactory>(
99113
hamming_dispatch_vector);
100114

101115
auto hamming_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
@@ -110,8 +124,7 @@ PYBIND11_MODULE(_window_impl, m)
110124
}
111125

112126
{
113-
init_dispatch_vector<window_ns::window_fn_ptr_t,
114-
window_ns::kernels::HanningFactory>(
127+
init_dispatch_vector<window_ns::window_fn_ptr_t, HanningFactory>(
115128
hanning_dispatch_vector);
116129

117130
auto hanning_pyapi = [&](sycl::queue &exec_q, const arrayT &result,

dpnp/backend/kernels/indexing/choose.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2024, Intel Corporation
2+
// Copyright (c) 2026, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without

dpnp/backend/extensions/window/bartlett.hpp renamed to dpnp/backend/kernels/window/bartlett.hpp

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2025, Intel Corporation
2+
// Copyright (c) 2026, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without
@@ -19,7 +19,7 @@
1919
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
2020
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
2121
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22-
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, RES, OR PROFITS; OR BUSINESS
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
2323
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
2424
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
2525
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
@@ -28,12 +28,10 @@
2828

2929
#pragma once
3030

31-
#include "common.hpp"
3231
#include <sycl/sycl.hpp>
3332

34-
namespace dpnp::extensions::window::kernels
33+
namespace dpnp::kernels::bartlett
3534
{
36-
3735
template <typename T>
3836
class BartlettFunctor
3937
{
@@ -52,19 +50,4 @@ class BartlettFunctor
5250
res[i] = T(1) - sycl::fabs(i - alpha) / alpha;
5351
}
5452
};
55-
56-
template <typename fnT, typename T>
57-
struct BartlettFactory
58-
{
59-
fnT get()
60-
{
61-
if constexpr (std::is_floating_point_v<T>) {
62-
return window_impl<T, BartlettFunctor>;
63-
}
64-
else {
65-
return nullptr;
66-
}
67-
}
68-
};
69-
70-
} // namespace dpnp::extensions::window::kernels
53+
} // namespace dpnp::kernels::bartlett

0 commit comments

Comments
 (0)