Skip to content

Commit 4f00632

Browse files
Move _clip to dpctl_ext/tensor/libtensor
1 parent c58b531 commit 4f00632

File tree

5 files changed

+693
-12
lines changed

5 files changed

+693
-12
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ set(_tensor_impl_sources
6161
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
6262
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
6363
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
64-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
64+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
6565
)
6666

6767
set(_static_lib_trgt simplify_iteration_space)
@@ -94,7 +94,7 @@ set(_no_fast_math_sources
9494
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
9595
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
9696
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
97-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
97+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
9898
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
9999
)
100100
#list(
Lines changed: 357 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for dpctl.tensor.clip.
33+
//===---------------------------------------------------------------------===//
34+
35+
#pragma once
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>

#include <sycl/sycl.hpp>

#include "dpctl_tensor_types.hpp"
#include "kernels/alignment.hpp"
#include "utils/math_utils.hpp"
#include "utils/offset_utils.hpp"
#include "utils/sycl_utils.hpp"
#include "utils/type_utils.hpp"
52+
namespace dpctl::tensor::kernels::clip
53+
{
54+
55+
using dpctl::tensor::ssize_t;
56+
using namespace dpctl::tensor::offset_utils;
57+
58+
using dpctl::tensor::kernels::alignment_utils::
59+
disabled_sg_loadstore_wrapper_krn;
60+
using dpctl::tensor::kernels::alignment_utils::is_aligned;
61+
using dpctl::tensor::kernels::alignment_utils::required_alignment;
62+
63+
using dpctl::tensor::sycl_utils::sub_group_load;
64+
using dpctl::tensor::sycl_utils::sub_group_store;
65+
66+
template <typename T>
67+
T clip(const T &x, const T &min, const T &max)
68+
{
69+
using dpctl::tensor::type_utils::is_complex;
70+
if constexpr (is_complex<T>::value) {
71+
using dpctl::tensor::math_utils::max_complex;
72+
using dpctl::tensor::math_utils::min_complex;
73+
return min_complex(max_complex(x, min), max);
74+
}
75+
else if constexpr (std::is_floating_point_v<T> ||
76+
std::is_same_v<T, sycl::half>) {
77+
auto tmp = (std::isnan(x) || x > min) ? x : min;
78+
return (std::isnan(tmp) || tmp < max) ? tmp : max;
79+
}
80+
else if constexpr (std::is_same_v<T, bool>) {
81+
return (x || min) && max;
82+
}
83+
else {
84+
auto tmp = (x > min) ? x : min;
85+
return (tmp < max) ? tmp : max;
86+
}
87+
}
88+
89+
/// Functor computing dst[i] = clip(x[i], min[i], max[i]) over four
/// contiguous arrays of @p nelems elements each.
///
/// Each work-item is responsible for up to n_vecs * vec_sz elements.
/// When enable_sg_loadstore is true and T is not complex, full tiles are
/// moved with vectorized sub-group loads/stores; otherwise (complex T, or
/// unaligned pointers — see clip_contig_impl) a scalar sub-group-strided
/// loop is used.
template <typename T,
          std::uint8_t vec_sz = 4,
          std::uint8_t n_vecs = 2,
          bool enable_sg_loadstore = true>
class ClipContigFunctor
{
private:
    std::size_t nelems = 0;   // total number of elements to process
    const T *x_p = nullptr;   // input values
    const T *min_p = nullptr; // per-element lower bounds
    const T *max_p = nullptr; // per-element upper bounds
    T *dst_p = nullptr;       // output

public:
    ClipContigFunctor(std::size_t nelems_,
                      const T *x_p_,
                      const T *min_p_,
                      const T *max_p_,
                      T *dst_p_)
        : nelems(nelems_), x_p(x_p_), min_p(min_p_), max_p(max_p_),
          dst_p(dst_p_)
    {
    }

    void operator()(sycl::nd_item<1> ndit) const
    {
        // Number of elements each work-item covers.
        static constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz;

        using dpctl::tensor::type_utils::is_complex;
        if constexpr (is_complex<T>::value || !enable_sg_loadstore) {
            // Scalar path: each sub-group owns a contiguous chunk of
            // nelems_per_sg elements; lanes stride through it by sgSize.
            const std::uint16_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::uint16_t nelems_per_sg = sgSize * nelems_per_wi;

            // First element of this lane; clamp the chunk end to nelems.
            const std::size_t start =
                (gid / sgSize) * (nelems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems, start + nelems_per_sg);

            for (std::size_t offset = start; offset < end; offset += sgSize) {
                dst_p[offset] = clip(x_p[offset], min_p[offset], max_p[offset]);
            }
        }
        else {
            // Vectorized path using sub-group block loads/stores.
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            // First element of the tile owned by this sub-group.
            const std::size_t base =
                nelems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                 sg.get_group_id()[0] * sgSize);

            if (base + nelems_per_wi * sgSize < nelems) {
                // Full interior tile: process n_vecs blocks of vec_sz
                // elements per lane with sub-group loads/stores.
                sycl::vec<T, vec_sz> dst_vec;
#pragma unroll
                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
                    const std::size_t idx = base + it * sgSize;
                    auto x_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&x_p[idx]);
                    auto min_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&min_p[idx]);
                    auto max_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&max_p[idx]);
                    auto dst_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&dst_p[idx]);

                    const sycl::vec<T, vec_sz> x_vec =
                        sub_group_load<vec_sz>(sg, x_multi_ptr);
                    const sycl::vec<T, vec_sz> min_vec =
                        sub_group_load<vec_sz>(sg, min_multi_ptr);
                    const sycl::vec<T, vec_sz> max_vec =
                        sub_group_load<vec_sz>(sg, max_multi_ptr);
#pragma unroll
                    for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
                        dst_vec[vec_id] = clip(x_vec[vec_id], min_vec[vec_id],
                                               max_vec[vec_id]);
                    }
                    sub_group_store<vec_sz>(sg, dst_vec, dst_multi_ptr);
                }
            }
            else {
                // Tail tile: fall back to a scalar sub-group-strided loop.
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems; k += sgSize) {
                    dst_p[k] = clip(x_p[k], min_p[k], max_p[k]);
                }
            }
        }
    }
};
181+
182+
// Kernel name type for the contiguous clip implementation.
template <typename T, int vec_sz, int n_vecs>
class clip_contig_kernel;

// Signature of a type-dispatched contiguous clip implementation:
// (queue, nelems, x, min, max, dst, depends) -> event.
// Array arguments are passed as untyped char pointers and reinterpreted
// inside the per-type implementation.
typedef sycl::event (*clip_contig_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    const char *,
    const char *,
    const char *,
    char *,
    const std::vector<sycl::event> &);
193+
194+
template <typename T>
195+
sycl::event clip_contig_impl(sycl::queue &q,
196+
std::size_t nelems,
197+
const char *x_cp,
198+
const char *min_cp,
199+
const char *max_cp,
200+
char *dst_cp,
201+
const std::vector<sycl::event> &depends)
202+
{
203+
const T *x_tp = reinterpret_cast<const T *>(x_cp);
204+
const T *min_tp = reinterpret_cast<const T *>(min_cp);
205+
const T *max_tp = reinterpret_cast<const T *>(max_cp);
206+
T *dst_tp = reinterpret_cast<T *>(dst_cp);
207+
208+
sycl::event clip_ev = q.submit([&](sycl::handler &cgh) {
209+
cgh.depends_on(depends);
210+
211+
std::size_t lws = 64;
212+
static constexpr std::uint8_t vec_sz = 4;
213+
static constexpr std::uint8_t n_vecs = 2;
214+
const std::size_t n_groups =
215+
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
216+
const auto gws_range = sycl::range<1>(n_groups * lws);
217+
const auto lws_range = sycl::range<1>(lws);
218+
219+
if (is_aligned<required_alignment>(x_cp) &&
220+
is_aligned<required_alignment>(min_cp) &&
221+
is_aligned<required_alignment>(max_cp) &&
222+
is_aligned<required_alignment>(dst_cp))
223+
{
224+
static constexpr bool enable_sg_loadstore = true;
225+
using KernelName = clip_contig_kernel<T, vec_sz, n_vecs>;
226+
using Impl =
227+
ClipContigFunctor<T, vec_sz, n_vecs, enable_sg_loadstore>;
228+
229+
cgh.parallel_for<KernelName>(
230+
sycl::nd_range<1>(gws_range, lws_range),
231+
Impl(nelems, x_tp, min_tp, max_tp, dst_tp));
232+
}
233+
else {
234+
static constexpr bool disable_sg_loadstore = false;
235+
using InnerKernelName = clip_contig_kernel<T, vec_sz, n_vecs>;
236+
using KernelName =
237+
disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
238+
using Impl =
239+
ClipContigFunctor<T, vec_sz, n_vecs, disable_sg_loadstore>;
240+
241+
cgh.parallel_for<KernelName>(
242+
sycl::nd_range<1>(gws_range, lws_range),
243+
Impl(nelems, x_tp, min_tp, max_tp, dst_tp));
244+
}
245+
});
246+
247+
return clip_ev;
248+
}
249+
250+
template <typename T, typename IndexerT>
251+
class ClipStridedFunctor
252+
{
253+
private:
254+
const T *x_p = nullptr;
255+
const T *min_p = nullptr;
256+
const T *max_p = nullptr;
257+
T *dst_p = nullptr;
258+
IndexerT indexer;
259+
260+
public:
261+
ClipStridedFunctor(const T *x_p_,
262+
const T *min_p_,
263+
const T *max_p_,
264+
T *dst_p_,
265+
const IndexerT &indexer_)
266+
: x_p(x_p_), min_p(min_p_), max_p(max_p_), dst_p(dst_p_),
267+
indexer(indexer_)
268+
{
269+
}
270+
271+
void operator()(sycl::id<1> id) const
272+
{
273+
std::size_t gid = id[0];
274+
auto offsets = indexer(static_cast<ssize_t>(gid));
275+
dst_p[offsets.get_fourth_offset()] = clip(
276+
x_p[offsets.get_first_offset()], min_p[offsets.get_second_offset()],
277+
max_p[offsets.get_third_offset()]);
278+
}
279+
};
280+
281+
// Kernel name type for the strided clip implementation.
template <typename T, typename IndexerT>
class clip_strided_kernel;

// Signature of a type-dispatched strided clip implementation:
// (queue, nelems, nd, x, min, max, dst, packed shape+strides,
//  x_offset, min_offset, max_offset, dst_offset, depends) -> event.
typedef sycl::event (*clip_strided_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    int,
    const char *,
    const char *,
    const char *,
    char *,
    const ssize_t *,
    ssize_t,
    ssize_t,
    ssize_t,
    ssize_t,
    const std::vector<sycl::event> &);
298+
299+
template <typename T>
300+
sycl::event clip_strided_impl(sycl::queue &q,
301+
std::size_t nelems,
302+
int nd,
303+
const char *x_cp,
304+
const char *min_cp,
305+
const char *max_cp,
306+
char *dst_cp,
307+
const ssize_t *shape_strides,
308+
ssize_t x_offset,
309+
ssize_t min_offset,
310+
ssize_t max_offset,
311+
ssize_t dst_offset,
312+
const std::vector<sycl::event> &depends)
313+
{
314+
const T *x_tp = reinterpret_cast<const T *>(x_cp);
315+
const T *min_tp = reinterpret_cast<const T *>(min_cp);
316+
const T *max_tp = reinterpret_cast<const T *>(max_cp);
317+
T *dst_tp = reinterpret_cast<T *>(dst_cp);
318+
319+
sycl::event clip_ev = q.submit([&](sycl::handler &cgh) {
320+
cgh.depends_on(depends);
321+
322+
const FourOffsets_StridedIndexer indexer{
323+
nd, x_offset, min_offset, max_offset, dst_offset, shape_strides};
324+
325+
using KernelName = clip_strided_kernel<T, FourOffsets_StridedIndexer>;
326+
using Impl = ClipStridedFunctor<T, FourOffsets_StridedIndexer>;
327+
328+
cgh.parallel_for<KernelName>(
329+
sycl::range<1>(nelems),
330+
Impl(x_tp, min_tp, max_tp, dst_tp, indexer));
331+
});
332+
333+
return clip_ev;
334+
}
335+
336+
template <typename fnT, typename T>
337+
struct ClipStridedFactory
338+
{
339+
fnT get()
340+
{
341+
fnT fn = clip_strided_impl<T>;
342+
return fn;
343+
}
344+
};
345+
346+
template <typename fnT, typename T>
347+
struct ClipContigFactory
348+
{
349+
fnT get()
350+
{
351+
352+
fnT fn = clip_contig_impl<T>;
353+
return fn;
354+
}
355+
};
356+
357+
} // namespace dpctl::tensor::kernels::clip

0 commit comments

Comments
 (0)