Skip to content

Commit e54114f

Browse files
Move __bitwise_invert/_ceil/_conj to _tensor_elementwise_impl
1 parent d509f6d commit e54114f

File tree

11 files changed

+1230
-9
lines changed

11 files changed

+1230
-9
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,14 @@ set(_elementwise_sources
8383
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan2.cpp
8484
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atanh.cpp
8585
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_and.cpp
86-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_invert.cpp
86+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_invert.cpp
8787
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_left_shift.cpp
8888
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_or.cpp
8989
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_right_shift.cpp
9090
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_xor.cpp
9191
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cbrt.cpp
92-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/ceil.cpp
93-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/conj.cpp
92+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/ceil.cpp
93+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/conj.cpp
9494
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/copysign.cpp
9595
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cos.cpp
9696
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cosh.cpp
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of bitwise_invert(x)
33+
/// function that inverts bits of binary representation of the argument.
34+
//===---------------------------------------------------------------------===//
35+
36+
#pragma once
37+
#include <cstddef>
38+
#include <cstdint>
39+
#include <type_traits>
40+
#include <vector>
41+
42+
#include <sycl/sycl.hpp>
43+
44+
#include "vec_size_util.hpp"
45+
46+
#include "kernels/dpctl_tensor_types.hpp"
47+
#include "kernels/elementwise_functions/common.hpp"
48+
49+
#include "utils/offset_utils.hpp"
50+
#include "utils/type_dispatch_building.hpp"
51+
#include "utils/type_utils.hpp"
52+
53+
namespace dpctl::tensor::kernels::bitwise_invert
54+
{
55+
56+
using dpctl::tensor::ssize_t;
57+
namespace td_ns = dpctl::tensor::type_dispatch;
58+
namespace tu_ns = dpctl::tensor::type_utils;
59+
60+
using dpctl::tensor::type_utils::vec_cast;
61+
62+
template <typename argT, typename resT>
63+
struct BitwiseInvertFunctor
64+
{
65+
static_assert(std::is_same_v<argT, resT>);
66+
static_assert(std::is_integral_v<argT> || std::is_same_v<argT, bool>);
67+
68+
using is_constant = typename std::false_type;
69+
// constexpr resT constant_value = resT{};
70+
using supports_vec = typename std::negation<std::is_same<argT, bool>>;
71+
using supports_sg_loadstore = typename std::true_type;
72+
73+
resT operator()(const argT &in) const
74+
{
75+
if constexpr (std::is_same_v<argT, bool>) {
76+
return !in;
77+
}
78+
else {
79+
return ~in;
80+
}
81+
}
82+
83+
template <int vec_sz>
84+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in) const
85+
{
86+
return ~in;
87+
}
88+
};
89+
90+
template <typename argT,
91+
typename resT = argT,
92+
std::uint8_t vec_sz = 4u,
93+
std::uint8_t n_vecs = 2u,
94+
bool enable_sg_loadstore = true>
95+
using BitwiseInvertContigFunctor =
96+
elementwise_common::UnaryContigFunctor<argT,
97+
resT,
98+
BitwiseInvertFunctor<argT, resT>,
99+
vec_sz,
100+
n_vecs,
101+
enable_sg_loadstore>;
102+
103+
template <typename argTy, typename resTy, typename IndexerT>
104+
using BitwiseInvertStridedFunctor =
105+
elementwise_common::UnaryStridedFunctor<argTy,
106+
resTy,
107+
IndexerT,
108+
BitwiseInvertFunctor<argTy, resTy>>;
109+
110+
template <typename argTy>
111+
struct BitwiseInvertOutputType
112+
{
113+
using value_type = typename std::disjunction<
114+
td_ns::TypeMapResultEntry<argTy, bool>,
115+
td_ns::TypeMapResultEntry<argTy, std::uint8_t>,
116+
td_ns::TypeMapResultEntry<argTy, std::uint16_t>,
117+
td_ns::TypeMapResultEntry<argTy, std::uint32_t>,
118+
td_ns::TypeMapResultEntry<argTy, std::uint64_t>,
119+
td_ns::TypeMapResultEntry<argTy, std::int8_t>,
120+
td_ns::TypeMapResultEntry<argTy, std::int16_t>,
121+
td_ns::TypeMapResultEntry<argTy, std::int32_t>,
122+
td_ns::TypeMapResultEntry<argTy, std::int64_t>,
123+
td_ns::DefaultResultEntry<void>>::result_type;
124+
125+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
126+
};
127+
128+
namespace hyperparam_detail
129+
{
130+
131+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
132+
133+
using vsu_ns::ContigHyperparameterSetDefault;
134+
using vsu_ns::UnaryContigHyperparameterSetEntry;
135+
136+
template <typename argTy>
137+
struct BitwiseInvertContigHyperparameterSet
138+
{
139+
using value_type =
140+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
141+
142+
constexpr static auto vec_sz = value_type::vec_sz;
143+
constexpr static auto n_vecs = value_type::n_vecs;
144+
};
145+
146+
} // end of namespace hyperparam_detail
147+
148+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
149+
class bitwise_invert_contig_kernel;
150+
151+
template <typename argTy>
152+
sycl::event
153+
bitwise_invert_contig_impl(sycl::queue &exec_q,
154+
std::size_t nelems,
155+
const char *arg_p,
156+
char *res_p,
157+
const std::vector<sycl::event> &depends = {})
158+
{
159+
using BitwiseInvertHS =
160+
hyperparam_detail::BitwiseInvertContigHyperparameterSet<argTy>;
161+
static constexpr std::uint8_t vec_sz = BitwiseInvertHS::vec_sz;
162+
static constexpr std::uint8_t n_vec = BitwiseInvertHS::n_vecs;
163+
164+
return elementwise_common::unary_contig_impl<
165+
argTy, BitwiseInvertOutputType, BitwiseInvertContigFunctor,
166+
bitwise_invert_contig_kernel, vec_sz, n_vec>(exec_q, nelems, arg_p,
167+
res_p, depends);
168+
}
169+
170+
template <typename fnT, typename T>
171+
struct BitwiseInvertContigFactory
172+
{
173+
fnT get()
174+
{
175+
if constexpr (!BitwiseInvertOutputType<T>::is_defined) {
176+
fnT fn = nullptr;
177+
return fn;
178+
}
179+
else {
180+
fnT fn = bitwise_invert_contig_impl<T>;
181+
return fn;
182+
}
183+
}
184+
};
185+
186+
template <typename fnT, typename T>
187+
struct BitwiseInvertTypeMapFactory
188+
{
189+
/*! @brief get typeid for output type of sycl::logical_not(T x) */
190+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
191+
{
192+
using rT = typename BitwiseInvertOutputType<T>::value_type;
193+
return td_ns::GetTypeid<rT>{}.get();
194+
}
195+
};
196+
197+
template <typename T1, typename T2, typename T3>
198+
class bitwise_invert_strided_kernel;
199+
200+
template <typename argTy>
201+
sycl::event bitwise_invert_strided_impl(
202+
sycl::queue &exec_q,
203+
std::size_t nelems,
204+
int nd,
205+
const ssize_t *shape_and_strides,
206+
const char *arg_p,
207+
ssize_t arg_offset,
208+
char *res_p,
209+
ssize_t res_offset,
210+
const std::vector<sycl::event> &depends,
211+
const std::vector<sycl::event> &additional_depends)
212+
{
213+
return elementwise_common::unary_strided_impl<
214+
argTy, BitwiseInvertOutputType, BitwiseInvertStridedFunctor,
215+
bitwise_invert_strided_kernel>(exec_q, nelems, nd, shape_and_strides,
216+
arg_p, arg_offset, res_p, res_offset,
217+
depends, additional_depends);
218+
}
219+
220+
template <typename fnT, typename T>
221+
struct BitwiseInvertStridedFactory
222+
{
223+
fnT get()
224+
{
225+
if constexpr (!BitwiseInvertOutputType<T>::is_defined) {
226+
fnT fn = nullptr;
227+
return fn;
228+
}
229+
else {
230+
fnT fn = bitwise_invert_strided_impl<T>;
231+
return fn;
232+
}
233+
}
234+
};
235+
236+
} // namespace dpctl::tensor::kernels::bitwise_invert

0 commit comments

Comments
 (0)