Skip to content

Commit cc2c154

Browse files
Add C-contig implementation of putmask
1 parent 0086e37 commit cc2c154

6 files changed

Lines changed: 592 additions & 20 deletions

File tree

dpnp/backend/extensions/indexing/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
# Name given to the extension module; it matches the identifier passed to
# PYBIND11_MODULE in indexing_py.cpp. How the variable is consumed lies
# outside this hunk.
set(python_module_name _indexing_impl)
3131
# Source files of the module. putmask.cpp is the entry added by this commit
# (marked "+" in the diff gutter).
set(_module_src
3232
${CMAKE_CURRENT_SOURCE_DIR}/choose.cpp
33+
${CMAKE_CURRENT_SOURCE_DIR}/putmask.cpp
3334
${CMAKE_CURRENT_SOURCE_DIR}/indexing_py.cpp
3435
)
3536

dpnp/backend/extensions/indexing/indexing_py.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,10 @@
3333
#include <pybind11/pybind11.h>
3434

3535
#include "choose.hpp"
36+
#include "putmask.hpp"
3637

3738
// Entry point of the `_indexing_impl` Python extension module: each init_*
// call below receives the module object `m` and registers its bindings on it.
PYBIND11_MODULE(_indexing_impl, m)
3839
{
3940
dpnp::extensions::indexing::init_choose(m);
41+
// Added by this commit: registers the putmask binding (declared in
// putmask.hpp).
dpnp::extensions::indexing::init_putmask(m);
4042
}
Lines changed: 223 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,223 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#include <stdexcept>
30+
#include <type_traits>
31+
#include <vector>
32+
33+
#include <sycl/sycl.hpp>
34+
35+
#include "dpctl4pybind11.hpp"
36+
#include <pybind11/pybind11.h>
37+
#include <pybind11/stl.h>
38+
39+
#include "putmask_kernel.hpp"
40+
41+
#include "../elementwise_functions/simplify_iteration_space.hpp"
42+
43+
// dpctl tensor headers
44+
#include "utils/offset_utils.hpp"
45+
#include "utils/output_validation.hpp"
46+
#include "utils/type_dispatch.hpp"
47+
#include "utils/type_utils.hpp"
48+
49+
// utils extension headers
50+
#include "ext/common.hpp"
51+
#include "ext/validation_utils.hpp"
52+
53+
namespace py = pybind11;
54+
namespace td_ns = dpctl::tensor::type_dispatch;
55+
56+
using dpctl::tensor::usm_ndarray;
57+
58+
using ext::common::dtype_from_typenum;
59+
using ext::validation::array_names;
60+
using ext::validation::check_has_dtype;
61+
using ext::validation::check_no_overlap;
62+
using ext::validation::check_num_dims;
63+
using ext::validation::check_queue;
64+
using ext::validation::check_same_dtype;
65+
using ext::validation::check_same_size;
66+
using ext::validation::check_writable;
67+
68+
namespace dpnp::extensions::indexing
69+
{
70+
using ext::common::init_dispatch_vector;
71+
72+
typedef sycl::event (*putmask_contig_fn_ptr_t)(
73+
sycl::queue &,
74+
const std::size_t, // nelems
75+
char *, // dst
76+
const char *, // mask
77+
const char *, // values
78+
const std::size_t, // values_size
79+
const std::vector<sycl::event> &);
80+
81+
static putmask_contig_fn_ptr_t putmask_contig_dispatch_vector[td_ns::num_types];
82+
83+
std::pair<sycl::event, sycl::event>
84+
py_putmask(const usm_ndarray &dst,
85+
const usm_ndarray &mask,
86+
const usm_ndarray &values,
87+
sycl::queue &exec_q,
88+
const std::vector<sycl::event> &depends = {})
89+
{
90+
array_names names = {{&dst, "dst"}, {&mask, "mask"}, {&values, "values"}};
91+
92+
check_same_dtype(&dst, &values, names);
93+
check_has_dtype(&mask, td_ns::typenum_t::BOOL, names);
94+
95+
check_same_size({&dst, &mask}, names);
96+
const int nd = dst.get_ndim();
97+
check_num_dims(&mask, nd, names);
98+
99+
check_queue({&dst, &mask, &values}, names, exec_q);
100+
check_no_overlap({&mask, &values}, {&dst}, names);
101+
check_writable({&dst}, names);
102+
103+
// values must be 1D
104+
check_num_dims(&values, 1, names);
105+
106+
auto types = td_ns::usm_ndarray_types();
107+
// dst_typeid == values_typeid (check_same_dtype(&dst, &values, names))
108+
int dst_values_typeid = types.typenum_to_lookup_id(dst.get_typenum());
109+
110+
const py::ssize_t *dst_shape = dst.get_shape_raw();
111+
const py::ssize_t *mask_shape = mask.get_shape_raw();
112+
bool shapes_equal(true);
113+
std::size_t nelems(1);
114+
115+
for (int i = 0; i < std::max(nd, 1); ++i) {
116+
const py::ssize_t d = (nd == 0 ? 1 : dst_shape[i]);
117+
const py::ssize_t m = (nd == 0 ? 1 : mask_shape[i]);
118+
nelems *= static_cast<std::size_t>(d);
119+
shapes_equal = shapes_equal && (d == m);
120+
}
121+
if (!shapes_equal) {
122+
throw py::value_error("`mask` and `dst` shapes must match");
123+
}
124+
125+
// if nelems is zero, return
126+
if (nelems == 0) {
127+
return {sycl::event(), sycl::event()};
128+
}
129+
130+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, nelems);
131+
132+
char *dst_p = dst.get_data();
133+
const char *mask_p = mask.get_data();
134+
const char *values_p = values.get_data();
135+
const std::size_t values_size = values.get_size();
136+
137+
// handle C contiguous inputs
138+
const bool is_dst_c_contig = dst.is_c_contiguous();
139+
const bool is_mask_c_contig = mask.is_c_contiguous();
140+
const bool is_values_c_contig = values.is_c_contiguous();
141+
142+
const bool all_c_contig =
143+
(is_dst_c_contig && is_mask_c_contig && is_values_c_contig);
144+
145+
if (all_c_contig) {
146+
auto contig_fn = putmask_contig_dispatch_vector[dst_values_typeid];
147+
148+
if (contig_fn == nullptr) {
149+
py::dtype dst_values_dtype_py =
150+
dtype_from_typenum(dst_values_typeid);
151+
throw std::runtime_error(
152+
"Contiguous implementation is missing for " +
153+
std::string(py::str(dst_values_dtype_py)) + "data type");
154+
}
155+
156+
auto comp_ev = contig_fn(exec_q, nelems, dst_p, mask_p, values_p,
157+
values_size, depends);
158+
sycl::event ht_ev = dpctl::utils::keep_args_alive(
159+
exec_q, {dst, mask, values}, {comp_ev});
160+
161+
return std::make_pair(ht_ev, comp_ev);
162+
}
163+
164+
throw py::value_error("Stride implementation is not implemented yet");
165+
}
166+
167+
/**
168+
* @brief A factory to define pairs of supported types for which
169+
* putmask function is available.
170+
*
171+
* @tparam T Type of input vector `dst` and `values` and of result vector `dst`.
172+
*/
173+
template <typename T>
174+
struct PutMaskOutputType
175+
{
176+
using value_type = typename std::disjunction<
177+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
178+
td_ns::TypeMapResultEntry<T, std::int8_t>,
179+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
180+
td_ns::TypeMapResultEntry<T, std::int16_t>,
181+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
182+
td_ns::TypeMapResultEntry<T, std::int32_t>,
183+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
184+
td_ns::TypeMapResultEntry<T, std::int64_t>,
185+
td_ns::TypeMapResultEntry<T, sycl::half>,
186+
td_ns::TypeMapResultEntry<T, float>,
187+
td_ns::TypeMapResultEntry<T, double>,
188+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
189+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
190+
td_ns::DefaultResultEntry<void>>::result_type;
191+
};
192+
193+
template <typename fnT, typename T>
194+
struct PutMaskContigFactory
195+
{
196+
fnT get()
197+
{
198+
if constexpr (std::is_same_v<typename PutMaskOutputType<T>::value_type,
199+
void>) {
200+
return nullptr;
201+
}
202+
else {
203+
return kernels::putmask_contig_impl<T>;
204+
}
205+
}
206+
};
207+
208+
static void populate_putmask_dispatch_vectors()
209+
{
210+
init_dispatch_vector<putmask_contig_fn_ptr_t, PutMaskContigFactory>(
211+
putmask_contig_dispatch_vector);
212+
}
213+
214+
/**
 * @brief Registers the putmask binding on the given module.
 *
 * Populates the dispatch vector and then exposes `py_putmask` under the
 * Python name `_putmask`, with `depends` defaulting to an empty list.
 *
 * @param m Module that receives the `_putmask` function.
 */
void init_putmask(py::module_ m)
{
    // The dispatch table is filled before the function is exposed.
    populate_putmask_dispatch_vectors();

    m.def("_putmask", &py_putmask, "",
          py::arg("dst"),
          py::arg("mask"),
          py::arg("values"),
          py::arg("sycl_queue"),
          py::arg("depends") = py::list());
}
222+
223+
} // namespace dpnp::extensions::indexing
Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,38 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#pragma once
30+
31+
#include <pybind11/pybind11.h>
32+
33+
namespace py = pybind11;
34+
35+
namespace dpnp::extensions::indexing
36+
{
37+
/// Registers the `_putmask` function on module `m`. The definition in
/// putmask.cpp populates the putmask dispatch vector before adding the
/// binding.
void init_putmask(py::module_ m);
38+
} // namespace dpnp::extensions::indexing

0 commit comments

Comments
 (0)