Skip to content

Commit 87cde88

Browse files
committed
Get rid of dpnp/backend/extensions/indexing/choose_kernel.hpp
1 parent f019aad commit 87cde88

File tree

3 files changed

+126
-180
lines changed

3 files changed

+126
-180
lines changed

dpnp/backend/extensions/indexing/choose.cpp

Lines changed: 87 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,41 +30,116 @@
3030
#include <cstddef>
3131
#include <cstdint>
3232
#include <memory>
33-
#include <pybind11/pybind11.h>
34-
#include <pybind11/stl.h>
35-
#include <sycl/sycl.hpp>
3633
#include <type_traits>
3734
#include <utility>
3835
#include <vector>
3936

40-
#include "choose_kernel.hpp"
37+
#include <sycl/sycl.hpp>
38+
4139
#include "dpctl4pybind11.hpp"
40+
#include <pybind11/pybind11.h>
41+
#include <pybind11/stl.h>
4242

43-
// utils extension header
4443
#include "ext/common.hpp"
44+
#include "kernels/indexing/choose.hpp"
4545

4646
// dpctl tensor headers
4747
#include "utils/indexing_utils.hpp"
4848
#include "utils/memory_overlap.hpp"
49+
#include "utils/offset_utils.hpp" //
4950
#include "utils/output_validation.hpp"
5051
#include "utils/sycl_alloc_utils.hpp"
5152
#include "utils/type_dispatch.hpp"
5253

5354
namespace dpnp::extensions::indexing
5455
{
55-
56+
namespace py = pybind11;
5657
namespace td_ns = dpctl::tensor::type_dispatch;
5758

58-
static kernels::choose_fn_ptr_t choose_clip_dispatch_table[td_ns::num_types]
59-
[td_ns::num_types];
60-
static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
61-
[td_ns::num_types];
59+
using dpctl::tensor::ssize_t;
60+
61+
typedef sycl::event (*choose_fn_ptr_t)(sycl::queue &,
62+
size_t,
63+
ssize_t,
64+
int,
65+
const ssize_t *,
66+
const char *,
67+
char *,
68+
char **,
69+
ssize_t,
70+
ssize_t,
71+
const ssize_t *,
72+
const std::vector<sycl::event> &);
73+
74+
static choose_fn_ptr_t choose_clip_dispatch_table[td_ns::num_types]
75+
[td_ns::num_types];
76+
static choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
77+
[td_ns::num_types];
78+
79+
template <typename ProjectorT, typename indTy, typename Ty>
80+
sycl::event choose_impl(sycl::queue &q,
81+
size_t nelems,
82+
ssize_t n_chcs,
83+
int nd,
84+
const ssize_t *shape_and_strides,
85+
const char *ind_cp,
86+
char *dst_cp,
87+
char **chcs_cp,
88+
ssize_t ind_offset,
89+
ssize_t dst_offset,
90+
const ssize_t *chc_offsets,
91+
const std::vector<sycl::event> &depends)
92+
{
93+
dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
94+
95+
const indTy *ind_tp = reinterpret_cast<const indTy *>(ind_cp);
96+
Ty *dst_tp = reinterpret_cast<Ty *>(dst_cp);
6297

63-
namespace py = pybind11;
98+
sycl::event choose_ev = q.submit([&](sycl::handler &cgh) {
99+
cgh.depends_on(depends);
64100

65-
namespace detail
101+
using InOutIndexerT =
102+
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
103+
const InOutIndexerT ind_out_indexer{nd, ind_offset, dst_offset,
104+
shape_and_strides};
105+
106+
using NthChoiceIndexerT =
107+
dpnp::kernels::choose::strides::NthStrideOffsetUnpacked;
108+
const NthChoiceIndexerT choices_indexer{
109+
nd, chc_offsets, shape_and_strides, shape_and_strides + 3 * nd};
110+
111+
using ChooseFunc =
112+
dpnp::kernels::choose::ChooseFunctor<ProjectorT, InOutIndexerT,
113+
NthChoiceIndexerT, indTy, Ty>;
114+
115+
cgh.parallel_for<ChooseFunc>(sycl::range<1>(nelems),
116+
ChooseFunc(ind_tp, dst_tp, chcs_cp, n_chcs,
117+
ind_out_indexer,
118+
choices_indexer));
119+
});
120+
121+
return choose_ev;
122+
}
123+
124+
template <typename fnT, typename IndT, typename T, typename Index>
125+
struct ChooseFactory
66126
{
127+
fnT get()
128+
{
129+
if constexpr (std::is_integral<IndT>::value &&
130+
!std::is_same<IndT, bool>::value) {
131+
fnT fn = choose_impl<Index, IndT, T>;
132+
return fn;
133+
}
134+
else {
135+
fnT fn = nullptr;
136+
return fn;
137+
}
138+
}
139+
};
67140

141+
namespace detail
142+
{
68143
using host_ptrs_allocator_t =
69144
dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
70145
using ptrs_t = std::vector<char *, host_ptrs_allocator_t>;
@@ -191,7 +266,6 @@ std::vector<dpctl::tensor::usm_ndarray> parse_py_chcs(const sycl::queue &q,
191266

192267
return res;
193268
}
194-
195269
} // namespace detail
196270

197271
std::pair<sycl::event, sycl::event>
@@ -412,23 +486,6 @@ std::pair<sycl::event, sycl::event>
412486
return std::make_pair(arg_cleanup_ev, choose_generic_ev);
413487
}
414488

415-
template <typename fnT, typename IndT, typename T, typename Index>
416-
struct ChooseFactory
417-
{
418-
fnT get()
419-
{
420-
if constexpr (std::is_integral<IndT>::value &&
421-
!std::is_same<IndT, bool>::value) {
422-
fnT fn = kernels::choose_impl<Index, IndT, T>;
423-
return fn;
424-
}
425-
else {
426-
fnT fn = nullptr;
427-
return fn;
428-
}
429-
}
430-
};
431-
432489
using dpctl::tensor::indexing_utils::ClipIndex;
433490
using dpctl::tensor::indexing_utils::WrapIndex;
434491

@@ -441,7 +498,6 @@ using ChooseClipFactory = ChooseFactory<fnT, IndT, T, ClipIndex<IndT>>;
441498
void init_choose_dispatch_tables(void)
442499
{
443500
using ext::common::init_dispatch_table;
444-
using kernels::choose_fn_ptr_t;
445501

446502
init_dispatch_table<choose_fn_ptr_t, ChooseClipFactory>(
447503
choose_clip_dispatch_table);

dpnp/backend/extensions/indexing/choose_kernel.hpp

Lines changed: 0 additions & 149 deletions
This file was deleted.

dpnp/backend/kernels/indexing/choose.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <sycl/sycl.hpp>
3232

3333
#include "kernels/dpctl_tensor_types.hpp"
34+
#include "utils/strided_iters.hpp"
3435

3536
namespace dpnp::kernels::choose
3637
{
@@ -84,4 +85,42 @@ class ChooseFunctor
8485
dst[dst_offset] = chc[chc_offset];
8586
}
8687
};
88+
89+
namespace strides
90+
{
91+
using dpctl::tensor::strides::CIndexer_vector;
92+
93+
struct NthStrideOffsetUnpacked
94+
{
95+
NthStrideOffsetUnpacked(int common_nd,
96+
ssize_t const *_offsets,
97+
ssize_t const *_shape,
98+
ssize_t const *_strides)
99+
: _ind(common_nd), nd(common_nd), offsets(_offsets), shape(_shape),
100+
strides(_strides)
101+
{
102+
}
103+
104+
template <typename nT>
105+
size_t operator()(ssize_t gid, nT n) const
106+
{
107+
ssize_t relative_offset(0);
108+
_ind.get_displacement<const ssize_t *, const ssize_t *>(
109+
gid, shape, strides + (n * nd), relative_offset);
110+
111+
return relative_offset + offsets[n];
112+
}
113+
114+
private:
115+
CIndexer_vector<ssize_t> _ind;
116+
117+
int nd;
118+
ssize_t const *offsets;
119+
ssize_t const *shape;
120+
ssize_t const *strides;
121+
};
122+
123+
static_assert(sycl::is_device_copyable_v<NthStrideOffsetUnpacked>);
124+
125+
} // namespace strides
87126
} // namespace dpnp::kernels::choose

0 commit comments

Comments
 (0)