Skip to content

Commit f019aad

Browse files
committed
Move ChooseFunctor to dpnp/backend/kernels/indexing/choose.hpp
1 parent 7fae3a6 commit f019aad

File tree

3 files changed

+98
-53
lines changed

3 files changed

+98
-53
lines changed

dpnp/backend/extensions/indexing/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ set_target_properties(
6262

6363
target_include_directories(
6464
${python_module_name}
65-
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common
65+
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../ ${CMAKE_CURRENT_SOURCE_DIR}/../common
6666
)
6767

6868
# treat below headers as system to suppress the warnings there during the build

dpnp/backend/extensions/indexing/choose_kernel.hpp

Lines changed: 10 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@
4242
#include "utils/strided_iters.hpp"
4343
#include "utils/type_utils.hpp"
4444

45-
namespace dpnp::extensions::indexing::strides_detail
45+
#include "kernels/indexing/choose.hpp"
46+
47+
namespace dpnp::extensions::indexing
48+
{
49+
namespace strides_detail
4650
{
4751

4852
struct NthStrideOffsetUnpacked
@@ -78,59 +82,12 @@ struct NthStrideOffsetUnpacked
7882

7983
static_assert(sycl::is_device_copyable_v<NthStrideOffsetUnpacked>);
8084

81-
} // namespace dpnp::extensions::indexing::strides_detail
82-
83-
namespace dpnp::extensions::indexing::kernels
84-
{
85+
} // namespace strides_detail
8586

86-
template <typename ProjectorT,
87-
typename IndOutIndexerT,
88-
typename ChoicesIndexerT,
89-
typename IndT,
90-
typename T>
91-
class ChooseFunctor
87+
namespace kernels
9288
{
93-
private:
94-
const IndT *ind = nullptr;
95-
T *dst = nullptr;
96-
char **chcs = nullptr;
97-
dpctl::tensor::ssize_t n_chcs;
98-
const IndOutIndexerT ind_out_indexer;
99-
const ChoicesIndexerT chcs_indexer;
100-
101-
public:
102-
ChooseFunctor(const IndT *ind_,
103-
T *dst_,
104-
char **chcs_,
105-
dpctl::tensor::ssize_t n_chcs_,
106-
const IndOutIndexerT &ind_out_indexer_,
107-
const ChoicesIndexerT &chcs_indexer_)
108-
: ind(ind_), dst(dst_), chcs(chcs_), n_chcs(n_chcs_),
109-
ind_out_indexer(ind_out_indexer_), chcs_indexer(chcs_indexer_)
110-
{
111-
}
112-
113-
void operator()(sycl::id<1> id) const
114-
{
115-
const ProjectorT proj{};
116-
117-
dpctl::tensor::ssize_t i = id[0];
118-
119-
auto ind_dst_offsets = ind_out_indexer(i);
120-
dpctl::tensor::ssize_t ind_offset = ind_dst_offsets.get_first_offset();
121-
dpctl::tensor::ssize_t dst_offset = ind_dst_offsets.get_second_offset();
122-
123-
IndT chc_idx = ind[ind_offset];
124-
// proj produces an index in the range of n_chcs
125-
dpctl::tensor::ssize_t projected_idx = proj(n_chcs, chc_idx);
12689

127-
dpctl::tensor::ssize_t chc_offset = chcs_indexer(i, projected_idx);
128-
129-
T *chc = reinterpret_cast<T *>(chcs[projected_idx]);
130-
131-
dst[dst_offset] = chc[chc_offset];
132-
}
133-
};
90+
using dpnp::kernels::choose::ChooseFunctor;
13491

13592
typedef sycl::event (*choose_fn_ptr_t)(sycl::queue &,
13693
size_t,
@@ -188,4 +145,5 @@ sycl::event choose_impl(sycl::queue &q,
188145
return choose_ev;
189146
}
190147

191-
} // namespace dpnp::extensions::indexing::kernels
148+
} // namespace kernels
149+
} // namespace dpnp::extensions::indexing
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#pragma once
30+
31+
#include <sycl/sycl.hpp>
32+
33+
#include "kernels/dpctl_tensor_types.hpp"
34+
35+
namespace dpnp::kernels::choose
36+
{
37+
using dpctl::tensor::ssize_t;
38+
39+
template <typename ProjectorT,
40+
typename IndOutIndexerT,
41+
typename ChoicesIndexerT,
42+
typename IndT,
43+
typename T>
44+
class ChooseFunctor
45+
{
46+
private:
47+
const IndT *ind = nullptr;
48+
T *dst = nullptr;
49+
char **chcs = nullptr;
50+
ssize_t n_chcs;
51+
const IndOutIndexerT ind_out_indexer;
52+
const ChoicesIndexerT chcs_indexer;
53+
54+
public:
55+
ChooseFunctor(const IndT *ind_,
56+
T *dst_,
57+
char **chcs_,
58+
ssize_t n_chcs_,
59+
const IndOutIndexerT &ind_out_indexer_,
60+
const ChoicesIndexerT &chcs_indexer_)
61+
: ind(ind_), dst(dst_), chcs(chcs_), n_chcs(n_chcs_),
62+
ind_out_indexer(ind_out_indexer_), chcs_indexer(chcs_indexer_)
63+
{
64+
}
65+
66+
void operator()(sycl::id<1> id) const
67+
{
68+
const ProjectorT proj{};
69+
70+
ssize_t i = id[0];
71+
72+
auto ind_dst_offsets = ind_out_indexer(i);
73+
ssize_t ind_offset = ind_dst_offsets.get_first_offset();
74+
ssize_t dst_offset = ind_dst_offsets.get_second_offset();
75+
76+
IndT chc_idx = ind[ind_offset];
77+
// proj produces an index in the range of n_chcs
78+
ssize_t projected_idx = proj(n_chcs, chc_idx);
79+
80+
ssize_t chc_offset = chcs_indexer(i, projected_idx);
81+
82+
T *chc = reinterpret_cast<T *>(chcs[projected_idx]);
83+
84+
dst[dst_offset] = chc[chc_offset];
85+
}
86+
};
87+
} // namespace dpnp::kernels::choose

0 commit comments

Comments
 (0)