Skip to content

Commit a51d34f

Browse files
Move ti.atan2()/bitwise_and() and reuse them
1 parent d8aab36 commit a51d34f

File tree

13 files changed

+1229
-13
lines changed

13 files changed

+1229
-13
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ set(_elementwise_sources
8080
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp
8181
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asinh.cpp
8282
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan.cpp
83-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan2.cpp
83+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan2.cpp
8484
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atanh.cpp
85-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_and.cpp
85+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_and.cpp
8686
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_invert.cpp
8787
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_left_shift.cpp
8888
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_or.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@
9898
asin,
9999
asinh,
100100
atan,
101+
atan2,
101102
atanh,
103+
bitwise_and,
102104
bitwise_invert,
103105
cbrt,
104106
ceil,
@@ -176,6 +178,8 @@
176178
"astype",
177179
"atan",
178180
"atanh",
181+
"atan2",
182+
"bitwise_and",
179183
"bitwise_invert",
180184
"broadcast_arrays",
181185
"broadcast_to",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,41 @@
246246
)
247247
del _atan_docstring
248248

249+
# B02: ===== ATAN2 (x1, x2)
250+
_atan2_docstring_ = r"""
251+
atan2(x1, x2, /, \*, out=None, order='K')
252+
253+
Calculates the inverse tangent of the quotient `x1_i/x2_i` for each element
254+
`x1_i` of the input array `x1` with the respective element `x2_i` of the
255+
input array `x2`. Each element-wise result is expressed in radians.
256+
257+
Args:
258+
x1 (usm_ndarray):
259+
First input array, expected to have a real-valued floating-point
260+
data type.
261+
x2 (usm_ndarray):
262+
Second input array, also expected to have a real-valued
263+
floating-point data type.
264+
out (Union[usm_ndarray, None], optional):
265+
Output array to populate.
266+
Array must have the correct shape and the expected data type.
267+
order ("C","F","A","K", optional):
268+
Memory layout of the new output array, if parameter
269+
`out` is ``None``.
270+
Default: "K".
271+
272+
Returns:
273+
usm_ndarray:
274+
An array containing the inverse tangent of the quotient `x1`/`x2`.
275+
The returned array must have a real-valued floating-point data type
276+
determined by Type Promotion Rules.
277+
"""
278+
279+
atan2 = BinaryElementwiseFunc(
280+
"atan2", ti._atan2_result_type, ti._atan2, _atan2_docstring_
281+
)
282+
del _atan2_docstring_
283+
249284
# U07: ===== ATANH (x)
250285
_atanh_docstring = r"""
251286
atanh(x, /, \*, out=None, order='K')
@@ -275,6 +310,43 @@
275310
)
276311
del _atanh_docstring
277312

313+
# B03: ===== BITWISE_AND (x1, x2)
314+
_bitwise_and_docstring_ = r"""
315+
bitwise_and(x1, x2, /, \*, out=None, order='K')
316+
317+
Computes the bitwise AND of the underlying binary representation of each
318+
element `x1_i` of the input array `x1` with the respective element `x2_i`
319+
of the input array `x2`.
320+
321+
Args:
322+
x1 (usm_ndarray):
323+
First input array, expected to have integer or boolean data type.
324+
x2 (usm_ndarray):
325+
Second input array, also expected to have integer or boolean data
326+
type.
327+
out (Union[usm_ndarray, None], optional):
328+
Output array to populate.
329+
Array must have the correct shape and the expected data type.
330+
order ("C","F","A","K", optional):
331+
Memory layout of the new output array, if parameter
332+
`out` is ``None``.
333+
Default: "K".
334+
335+
Returns:
336+
usm_ndarray:
337+
An array containing the element-wise results. The data type
338+
of the returned array is determined by the Type Promotion Rules.
339+
"""
340+
341+
bitwise_and = BinaryElementwiseFunc(
342+
"bitwise_and",
343+
ti._bitwise_and_result_type,
344+
ti._bitwise_and,
345+
_bitwise_and_docstring_,
346+
binary_inplace_fn=ti._bitwise_and_inplace,
347+
)
348+
del _bitwise_and_docstring_
349+
278350
# U08: ===== BITWISE_INVERT (x)
279351
_bitwise_invert_docstring = r"""
280352
bitwise_invert(x, /, \*, out=None, order='K')

dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@
4545
#include "sycl_complex.hpp"
4646
#include "vec_size_util.hpp"
4747

48-
#include "utils/type_dispatch_building.hpp"
49-
#include "utils/type_utils.hpp"
50-
5148
#include "kernels/dpctl_tensor_types.hpp"
5249
#include "kernels/elementwise_functions/common.hpp"
5350
#include "kernels/elementwise_functions/common_inplace.hpp"
5451

52+
#include "utils/type_dispatch_building.hpp"
53+
#include "utils/type_utils.hpp"
54+
5555
namespace dpctl::tensor::kernels::add
5656
{
5757

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of ATAN2(x1, x2)
33+
/// function.
34+
//===---------------------------------------------------------------------===//
35+
36+
#pragma once
37+
#include <cstddef>
38+
#include <cstdint>
39+
#include <type_traits>
40+
#include <vector>
41+
42+
#include <sycl/sycl.hpp>
43+
44+
#include "vec_size_util.hpp"
45+
46+
#include "kernels/dpctl_tensor_types.hpp"
47+
#include "kernels/elementwise_functions/common.hpp"
48+
49+
#include "utils/type_dispatch_building.hpp"
50+
51+
namespace dpctl::tensor::kernels::atan2
52+
{
53+
54+
using dpctl::tensor::ssize_t;
55+
namespace td_ns = dpctl::tensor::type_dispatch;
56+
57+
template <typename argT1, typename argT2, typename resT>
58+
struct Atan2Functor
59+
{
60+
61+
using supports_sg_loadstore = std::true_type;
62+
using supports_vec = std::false_type;
63+
64+
resT operator()(const argT1 &in1, const argT2 &in2) const
65+
{
66+
if (std::isinf(in2) && !sycl::signbit(in2)) {
67+
if (std::isfinite(in1)) {
68+
return sycl::copysign(resT(0), in1);
69+
}
70+
}
71+
return sycl::atan2(in1, in2);
72+
}
73+
};
74+
75+
template <typename argT1,
76+
typename argT2,
77+
typename resT,
78+
std::uint8_t vec_sz = 4u,
79+
std::uint8_t n_vecs = 2u,
80+
bool enable_sg_loadstore = true>
81+
using Atan2ContigFunctor =
82+
elementwise_common::BinaryContigFunctor<argT1,
83+
argT2,
84+
resT,
85+
Atan2Functor<argT1, argT2, resT>,
86+
vec_sz,
87+
n_vecs,
88+
enable_sg_loadstore>;
89+
90+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
91+
using Atan2StridedFunctor =
92+
elementwise_common::BinaryStridedFunctor<argT1,
93+
argT2,
94+
resT,
95+
IndexerT,
96+
Atan2Functor<argT1, argT2, resT>>;
97+
98+
template <typename T1, typename T2>
99+
struct Atan2OutputType
100+
{
101+
using value_type = typename std::disjunction<
102+
td_ns::BinaryTypeMapResultEntry<T1,
103+
sycl::half,
104+
T2,
105+
sycl::half,
106+
sycl::half>,
107+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
108+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
109+
td_ns::DefaultResultEntry<void>>::result_type;
110+
111+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
112+
};
113+
114+
namespace hyperparam_detail
115+
{
116+
117+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
118+
119+
using vsu_ns::BinaryContigHyperparameterSetEntry;
120+
using vsu_ns::ContigHyperparameterSetDefault;
121+
122+
template <typename argTy1, typename argTy2>
123+
struct Atan2ContigHyperparameterSet
124+
{
125+
using value_type =
126+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
127+
128+
constexpr static auto vec_sz = value_type::vec_sz;
129+
constexpr static auto n_vecs = value_type::n_vecs;
130+
};
131+
132+
} // end of namespace hyperparam_detail
133+
134+
template <typename argT1,
135+
typename argT2,
136+
typename resT,
137+
std::uint8_t vec_sz,
138+
std::uint8_t n_vecs>
139+
class atan2_contig_kernel;
140+
141+
template <typename argTy1, typename argTy2>
142+
sycl::event atan2_contig_impl(sycl::queue &exec_q,
143+
std::size_t nelems,
144+
const char *arg1_p,
145+
ssize_t arg1_offset,
146+
const char *arg2_p,
147+
ssize_t arg2_offset,
148+
char *res_p,
149+
ssize_t res_offset,
150+
const std::vector<sycl::event> &depends = {})
151+
{
152+
using Atan2HS =
153+
hyperparam_detail::Atan2ContigHyperparameterSet<argTy1, argTy2>;
154+
static constexpr std::uint8_t vec_sz = Atan2HS::vec_sz;
155+
static constexpr std::uint8_t n_vecs = Atan2HS::n_vecs;
156+
157+
return elementwise_common::binary_contig_impl<
158+
argTy1, argTy2, Atan2OutputType, Atan2ContigFunctor,
159+
atan2_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg1_p,
160+
arg1_offset, arg2_p, arg2_offset,
161+
res_p, res_offset, depends);
162+
}
163+
164+
template <typename fnT, typename T1, typename T2>
165+
struct Atan2ContigFactory
166+
{
167+
fnT get()
168+
{
169+
if constexpr (!Atan2OutputType<T1, T2>::is_defined) {
170+
fnT fn = nullptr;
171+
return fn;
172+
}
173+
else {
174+
fnT fn = atan2_contig_impl<T1, T2>;
175+
return fn;
176+
}
177+
}
178+
};
179+
180+
template <typename fnT, typename T1, typename T2>
181+
struct Atan2TypeMapFactory
182+
{
183+
/*! @brief get typeid for output type of sycl::atan2(T1 x, T2 y) */
184+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
185+
{
186+
using rT = typename Atan2OutputType<T1, T2>::value_type;
187+
return td_ns::GetTypeid<rT>{}.get();
188+
}
189+
};
190+
191+
template <typename T1, typename T2, typename resT, typename IndexerT>
192+
class atan2_strided_kernel;
193+
194+
template <typename argTy1, typename argTy2>
195+
sycl::event
196+
atan2_strided_impl(sycl::queue &exec_q,
197+
std::size_t nelems,
198+
int nd,
199+
const ssize_t *shape_and_strides,
200+
const char *arg1_p,
201+
ssize_t arg1_offset,
202+
const char *arg2_p,
203+
ssize_t arg2_offset,
204+
char *res_p,
205+
ssize_t res_offset,
206+
const std::vector<sycl::event> &depends,
207+
const std::vector<sycl::event> &additional_depends)
208+
{
209+
return elementwise_common::binary_strided_impl<
210+
argTy1, argTy2, Atan2OutputType, Atan2StridedFunctor,
211+
atan2_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
212+
arg1_offset, arg2_p, arg2_offset, res_p,
213+
res_offset, depends, additional_depends);
214+
}
215+
216+
template <typename fnT, typename T1, typename T2>
217+
struct Atan2StridedFactory
218+
{
219+
fnT get()
220+
{
221+
if constexpr (!Atan2OutputType<T1, T2>::is_defined) {
222+
fnT fn = nullptr;
223+
return fn;
224+
}
225+
else {
226+
fnT fn = atan2_strided_impl<T1, T2>;
227+
return fn;
228+
}
229+
}
230+
};
231+
232+
} // namespace dpctl::tensor::kernels::atan2

0 commit comments

Comments
 (0)