3030#include < cstddef>
3131#include < cstdint>
3232#include < memory>
33- #include < pybind11/pybind11.h>
34- #include < pybind11/stl.h>
35- #include < sycl/sycl.hpp>
3633#include < type_traits>
3734#include < utility>
3835#include < vector>
3936
40- #include " choose_kernel.hpp"
37+ #include < sycl/sycl.hpp>
38+
4139#include " dpctl4pybind11.hpp"
40+ #include < pybind11/pybind11.h>
41+ #include < pybind11/stl.h>
4242
43- // utils extension header
4443#include " ext/common.hpp"
44+ #include " kernels/indexing/choose.hpp"
4545
4646// dpctl tensor headers
4747#include " utils/indexing_utils.hpp"
4848#include " utils/memory_overlap.hpp"
49+ #include " utils/offset_utils.hpp" //
4950#include " utils/output_validation.hpp"
5051#include " utils/sycl_alloc_utils.hpp"
5152#include " utils/type_dispatch.hpp"
5253
5354namespace dpnp ::extensions::indexing
5455{
55-
56+ namespace py = pybind11;
5657namespace td_ns = dpctl::tensor::type_dispatch;
5758
58- static kernels::choose_fn_ptr_t choose_clip_dispatch_table[td_ns::num_types]
59- [td_ns::num_types];
60- static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
61- [td_ns::num_types];
59+ using dpctl::tensor::ssize_t ;
60+
61+ typedef sycl::event (*choose_fn_ptr_t )(sycl::queue &,
62+ size_t ,
63+ ssize_t ,
64+ int ,
65+ const ssize_t *,
66+ const char *,
67+ char *,
68+ char **,
69+ ssize_t ,
70+ ssize_t ,
71+ const ssize_t *,
72+ const std::vector<sycl::event> &);
73+
74+ static choose_fn_ptr_t choose_clip_dispatch_table[td_ns::num_types]
75+ [td_ns::num_types];
76+ static choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
77+ [td_ns::num_types];
78+
79+ template <typename ProjectorT, typename indTy, typename Ty>
80+ sycl::event choose_impl (sycl::queue &q,
81+ size_t nelems,
82+ ssize_t n_chcs,
83+ int nd,
84+ const ssize_t *shape_and_strides,
85+ const char *ind_cp,
86+ char *dst_cp,
87+ char **chcs_cp,
88+ ssize_t ind_offset,
89+ ssize_t dst_offset,
90+ const ssize_t *chc_offsets,
91+ const std::vector<sycl::event> &depends)
92+ {
93+ dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
94+
95+ const indTy *ind_tp = reinterpret_cast <const indTy *>(ind_cp);
96+ Ty *dst_tp = reinterpret_cast <Ty *>(dst_cp);
6297
63- namespace py = pybind11;
98+ sycl::event choose_ev = q.submit ([&](sycl::handler &cgh) {
99+ cgh.depends_on (depends);
64100
65- namespace detail
101+ using InOutIndexerT =
102+ dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
103+ const InOutIndexerT ind_out_indexer{nd, ind_offset, dst_offset,
104+ shape_and_strides};
105+
106+ using NthChoiceIndexerT =
107+ dpnp::kernels::choose::strides::NthStrideOffsetUnpacked;
108+ const NthChoiceIndexerT choices_indexer{
109+ nd, chc_offsets, shape_and_strides, shape_and_strides + 3 * nd};
110+
111+ using ChooseFunc =
112+ dpnp::kernels::choose::ChooseFunctor<ProjectorT, InOutIndexerT,
113+ NthChoiceIndexerT, indTy, Ty>;
114+
115+ cgh.parallel_for <ChooseFunc>(sycl::range<1 >(nelems),
116+ ChooseFunc (ind_tp, dst_tp, chcs_cp, n_chcs,
117+ ind_out_indexer,
118+ choices_indexer));
119+ });
120+
121+ return choose_ev;
122+ }
123+
124+ template <typename fnT, typename IndT, typename T, typename Index>
125+ struct ChooseFactory
66126{
127+ fnT get ()
128+ {
129+ if constexpr (std::is_integral<IndT>::value &&
130+ !std::is_same<IndT, bool >::value) {
131+ fnT fn = choose_impl<Index, IndT, T>;
132+ return fn;
133+ }
134+ else {
135+ fnT fn = nullptr ;
136+ return fn;
137+ }
138+ }
139+ };
67140
141+ namespace detail
142+ {
68143using host_ptrs_allocator_t =
69144 dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
70145using ptrs_t = std::vector<char *, host_ptrs_allocator_t >;
@@ -191,7 +266,6 @@ std::vector<dpctl::tensor::usm_ndarray> parse_py_chcs(const sycl::queue &q,
191266
192267 return res;
193268}
194-
195269} // namespace detail
196270
197271std::pair<sycl::event, sycl::event>
@@ -412,23 +486,6 @@ std::pair<sycl::event, sycl::event>
412486 return std::make_pair (arg_cleanup_ev, choose_generic_ev);
413487}
414488
415- template <typename fnT, typename IndT, typename T, typename Index>
416- struct ChooseFactory
417- {
418- fnT get ()
419- {
420- if constexpr (std::is_integral<IndT>::value &&
421- !std::is_same<IndT, bool >::value) {
422- fnT fn = kernels::choose_impl<Index, IndT, T>;
423- return fn;
424- }
425- else {
426- fnT fn = nullptr ;
427- return fn;
428- }
429- }
430- };
431-
432489using dpctl::tensor::indexing_utils::ClipIndex;
433490using dpctl::tensor::indexing_utils::WrapIndex;
434491
@@ -441,7 +498,6 @@ using ChooseClipFactory = ChooseFactory<fnT, IndT, T, ClipIndex<IndT>>;
441498void init_choose_dispatch_tables (void )
442499{
443500 using ext::common::init_dispatch_table;
444- using kernels::choose_fn_ptr_t ;
445501
446502 init_dispatch_table<choose_fn_ptr_t , ChooseClipFactory>(
447503 choose_clip_dispatch_table);
0 commit comments