3030#include < cstddef>
3131#include < cstdint>
3232#include < memory>
33- #include < pybind11/pybind11.h >
34- #include < pybind11/stl.h >
35- #include < sycl/sycl.hpp >
33+ #include < stdexcept >
34+ #include < string >
35+ #include < tuple >
3636#include < type_traits>
3737#include < utility>
3838#include < vector>
3939
40- #include " choose_kernel.hpp"
40+ #include < sycl/sycl.hpp>
41+
4142#include " dpctl4pybind11.hpp"
43+ #include < pybind11/pybind11.h>
44+ #include < pybind11/stl.h>
4245
43- // utils extension header
4446#include " ext/common.hpp"
47+ #include " kernels/indexing/choose.hpp"
4548
4649// dpctl tensor headers
4750#include " utils/indexing_utils.hpp"
4851#include " utils/memory_overlap.hpp"
52+ #include " utils/offset_utils.hpp"
4953#include " utils/output_validation.hpp"
5054#include " utils/sycl_alloc_utils.hpp"
5155#include " utils/type_dispatch.hpp"
56+ #include " utils/type_utils.hpp"
5257
5358namespace dpnp ::extensions::indexing
5459{
60+ namespace py = pybind11;
5561
62+ namespace impl
63+ {
5664namespace td_ns = dpctl::tensor::type_dispatch;
5765
58- static kernels::choose_fn_ptr_t choose_clip_dispatch_table[td_ns::num_types]
59- [td_ns::num_types];
60- static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
61- [td_ns::num_types];
66+ using dpctl::tensor::ssize_t ;
67+
68+ typedef sycl::event (*choose_fn_ptr_t )(sycl::queue &,
69+ size_t ,
70+ ssize_t ,
71+ int ,
72+ const ssize_t *,
73+ const char *,
74+ char *,
75+ char **,
76+ ssize_t ,
77+ ssize_t ,
78+ const ssize_t *,
79+ const std::vector<sycl::event> &);
80+
81+ static choose_fn_ptr_t choose_clip_dispatch_table[td_ns::num_types]
82+ [td_ns::num_types];
83+ static choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
84+ [td_ns::num_types];
85+
86+ template <typename ProjectorT, typename indTy, typename Ty>
87+ sycl::event choose_impl (sycl::queue &q,
88+ size_t nelems,
89+ ssize_t n_chcs,
90+ int nd,
91+ const ssize_t *shape_and_strides,
92+ const char *ind_cp,
93+ char *dst_cp,
94+ char **chcs_cp,
95+ ssize_t ind_offset,
96+ ssize_t dst_offset,
97+ const ssize_t *chc_offsets,
98+ const std::vector<sycl::event> &depends)
99+ {
100+ dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
62101
63- namespace py = pybind11;
102+ const indTy *ind_tp = reinterpret_cast <const indTy *>(ind_cp);
103+ Ty *dst_tp = reinterpret_cast <Ty *>(dst_cp);
64104
65- namespace detail
105+ sycl::event choose_ev = q.submit ([&](sycl::handler &cgh) {
106+ cgh.depends_on (depends);
107+
108+ using InOutIndexerT =
109+ dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
110+ const InOutIndexerT ind_out_indexer{nd, ind_offset, dst_offset,
111+ shape_and_strides};
112+
113+ using NthChoiceIndexerT =
114+ dpnp::kernels::choose::strides::NthStrideOffsetUnpacked;
115+ const NthChoiceIndexerT choices_indexer{
116+ nd, chc_offsets, shape_and_strides, shape_and_strides + 3 * nd};
117+
118+ using ChooseFunc =
119+ dpnp::kernels::choose::ChooseFunctor<ProjectorT, InOutIndexerT,
120+ NthChoiceIndexerT, indTy, Ty>;
121+
122+ cgh.parallel_for <ChooseFunc>(sycl::range<1 >(nelems),
123+ ChooseFunc (ind_tp, dst_tp, chcs_cp, n_chcs,
124+ ind_out_indexer,
125+ choices_indexer));
126+ });
127+
128+ return choose_ev;
129+ }
130+
131+ template <typename fnT, typename IndT, typename T, typename Index>
132+ struct ChooseFactory
66133{
134+ fnT get ()
135+ {
136+ if constexpr (std::is_integral<IndT>::value &&
137+ !std::is_same<IndT, bool >::value) {
138+ fnT fn = choose_impl<Index, IndT, T>;
139+ return fn;
140+ }
141+ else {
142+ fnT fn = nullptr ;
143+ return fn;
144+ }
145+ }
146+ };
67147
148+ namespace detail
149+ {
68150using host_ptrs_allocator_t =
69151 dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
70152using ptrs_t = std::vector<char *, host_ptrs_allocator_t >;
@@ -191,7 +273,6 @@ std::vector<dpctl::tensor::usm_ndarray> parse_py_chcs(const sycl::queue &q,
191273
192274 return res;
193275}
194-
195276} // namespace detail
196277
197278std::pair<sycl::event, sycl::event>
@@ -412,23 +493,6 @@ std::pair<sycl::event, sycl::event>
412493 return std::make_pair (arg_cleanup_ev, choose_generic_ev);
413494}
414495
415- template <typename fnT, typename IndT, typename T, typename Index>
416- struct ChooseFactory
417- {
418- fnT get ()
419- {
420- if constexpr (std::is_integral<IndT>::value &&
421- !std::is_same<IndT, bool >::value) {
422- fnT fn = kernels::choose_impl<Index, IndT, T>;
423- return fn;
424- }
425- else {
426- fnT fn = nullptr ;
427- return fn;
428- }
429- }
430- };
431-
432496using dpctl::tensor::indexing_utils::ClipIndex;
433497using dpctl::tensor::indexing_utils::WrapIndex;
434498
@@ -441,23 +505,22 @@ using ChooseClipFactory = ChooseFactory<fnT, IndT, T, ClipIndex<IndT>>;
441505void init_choose_dispatch_tables (void )
442506{
443507 using ext::common::init_dispatch_table;
444- using kernels::choose_fn_ptr_t ;
445508
446509 init_dispatch_table<choose_fn_ptr_t , ChooseClipFactory>(
447510 choose_clip_dispatch_table);
448511 init_dispatch_table<choose_fn_ptr_t , ChooseWrapFactory>(
449512 choose_wrap_dispatch_table);
450513}
514+ } // namespace impl
451515
452516void init_choose (py::module_ m)
453517{
454- dpnp::extensions::indexing ::init_choose_dispatch_tables ();
518+ impl ::init_choose_dispatch_tables ();
455519
456- m.def (" _choose" , &py_choose, " " , py::arg (" src" ), py::arg (" chcs" ),
520+ m.def (" _choose" , &impl:: py_choose, " " , py::arg (" src" ), py::arg (" chcs" ),
457521 py::arg (" dst" ), py::arg (" mode" ), py::arg (" sycl_queue" ),
458522 py::arg (" depends" ) = py::list ());
459523
460524 return ;
461525}
462-
463526} // namespace dpnp::extensions::indexing
0 commit comments