Skip to content

Commit 4e10a53

Browse files
Move _atan/_atanh to _tensor_elementwise_impl
1 parent 0cec8c3 commit 4e10a53

File tree

8 files changed

+922
-6
lines changed

8 files changed

+922
-6
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ set(_elementwise_sources
7979
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/angle.cpp
8080
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp
8181
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asinh.cpp
82-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan.cpp
82+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan.cpp
8383
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan2.cpp
84-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atanh.cpp
84+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atanh.cpp
8585
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_and.cpp
8686
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_invert.cpp
8787
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_left_shift.cpp
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of ATAN(x) function.
33+
//===---------------------------------------------------------------------===//
34+
35+
#pragma once
36+
#include <cmath>
37+
#include <complex>
38+
#include <cstddef>
39+
#include <cstdint>
40+
#include <limits>
41+
#include <type_traits>
42+
#include <vector>
43+
44+
#include <sycl/sycl.hpp>
45+
46+
#include "sycl_complex.hpp"
47+
#include "vec_size_util.hpp"
48+
49+
#include "kernels/dpctl_tensor_types.hpp"
50+
#include "kernels/elementwise_functions/common.hpp"
51+
52+
#include "utils/offset_utils.hpp"
53+
#include "utils/type_dispatch_building.hpp"
54+
#include "utils/type_utils.hpp"
55+
56+
namespace dpctl::tensor::kernels::atan
57+
{
58+
59+
using dpctl::tensor::ssize_t;
60+
namespace td_ns = dpctl::tensor::type_dispatch;
61+
62+
using dpctl::tensor::kernels::vec_size_utils::ContigHyperparameterSetDefault;
63+
using dpctl::tensor::kernels::vec_size_utils::UnaryContigHyperparameterSetEntry;
64+
65+
using dpctl::tensor::type_utils::is_complex;
66+
67+
template <typename argT, typename resT>
68+
struct AtanFunctor
69+
{
70+
71+
// is function constant for given argT
72+
using is_constant = typename std::false_type;
73+
// constant value, if constant
74+
// constexpr resT constant_value = resT{};
75+
// is function defined for sycl::vec
76+
using supports_vec = typename std::false_type;
77+
// do both argTy and resTy support sugroup store/load operation
78+
using supports_sg_loadstore = typename std::negation<
79+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
80+
81+
resT operator()(const argT &in) const
82+
{
83+
if constexpr (is_complex<argT>::value) {
84+
using realT = typename argT::value_type;
85+
86+
static constexpr realT q_nan =
87+
std::numeric_limits<realT>::quiet_NaN();
88+
/*
89+
* atan(in) = I * conj( atanh(I * conj(in)) )
90+
* so we first calculate w = atanh(I * conj(in)) with
91+
* x = real(I * conj(in)) = imag(in)
92+
* y = imag(I * conj(in)) = real(in)
93+
* and then return {imag(w), real(w)} which is atan(in)
94+
*/
95+
const realT x = std::imag(in);
96+
const realT y = std::real(in);
97+
if (std::isnan(x)) {
98+
/* atanh(NaN + I*+-Inf) = sign(NaN)*0 + I*+-Pi/2 */
99+
if (std::isinf(y)) {
100+
const realT pi_half = sycl::atan(realT(1)) * 2;
101+
102+
const realT atanh_re = sycl::copysign(realT(0), x);
103+
const realT atanh_im = sycl::copysign(pi_half, y);
104+
return resT{atanh_im, atanh_re};
105+
}
106+
/*
107+
* All other cases involving NaN return NaN + I*NaN.
108+
*/
109+
return resT{q_nan, q_nan};
110+
}
111+
else if (std::isnan(y)) {
112+
/* atanh(+-Inf + I*NaN) = +-0 + I*NaN */
113+
if (std::isinf(x)) {
114+
const realT atanh_re = sycl::copysign(realT(0), x);
115+
const realT atanh_im = q_nan;
116+
return resT{atanh_im, atanh_re};
117+
}
118+
/* atanh(+-0 + I*NaN) = +-0 + I*NaN */
119+
if (x == realT(0)) {
120+
return resT{q_nan, x};
121+
}
122+
/*
123+
* All other cases involving NaN return NaN + I*NaN.
124+
*/
125+
return resT{q_nan, q_nan};
126+
}
127+
128+
/*
129+
* For large x or y including
130+
* atanh(+-Inf + I*+-Inf) = 0 + I*+-PI/2
131+
* The sign of pi/2 depends on the sign of imaginary part of the
132+
* input.
133+
*/
134+
static constexpr realT r_eps =
135+
realT(1) / std::numeric_limits<realT>::epsilon();
136+
if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) {
137+
const realT pi_half = sycl::atan(realT(1)) * 2;
138+
139+
const realT atanh_re = realT(0);
140+
const realT atanh_im = sycl::copysign(pi_half, y);
141+
return resT{atanh_im, atanh_re};
142+
}
143+
/* ordinary cases */
144+
return exprm_ns::atan(exprm_ns::complex<realT>(in)); // atan(in);
145+
}
146+
else {
147+
static_assert(std::is_floating_point_v<argT> ||
148+
std::is_same_v<argT, sycl::half>);
149+
return sycl::atan(in);
150+
}
151+
}
152+
};
153+
154+
template <typename argTy,
155+
typename resTy = argTy,
156+
std::uint8_t vec_sz = 4u,
157+
std::uint8_t n_vecs = 2u,
158+
bool enable_sg_loadstore = true>
159+
using AtanContigFunctor =
160+
elementwise_common::UnaryContigFunctor<argTy,
161+
resTy,
162+
AtanFunctor<argTy, resTy>,
163+
vec_sz,
164+
n_vecs,
165+
enable_sg_loadstore>;
166+
167+
template <typename argTy, typename resTy, typename IndexerT>
168+
using AtanStridedFunctor = elementwise_common::
169+
UnaryStridedFunctor<argTy, resTy, IndexerT, AtanFunctor<argTy, resTy>>;
170+
171+
template <typename T>
172+
struct AtanOutputType
173+
{
174+
using value_type = typename std::disjunction<
175+
td_ns::TypeMapResultEntry<T, sycl::half>,
176+
td_ns::TypeMapResultEntry<T, float>,
177+
td_ns::TypeMapResultEntry<T, double>,
178+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
179+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
180+
td_ns::DefaultResultEntry<void>>::result_type;
181+
182+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
183+
};
184+
185+
namespace hyperparam_detail
186+
{
187+
188+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
189+
190+
using vsu_ns::ContigHyperparameterSetDefault;
191+
using vsu_ns::UnaryContigHyperparameterSetEntry;
192+
193+
template <typename argTy>
194+
struct AtanContigHyperparameterSet
195+
{
196+
using value_type =
197+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
198+
199+
constexpr static auto vec_sz = value_type::vec_sz;
200+
constexpr static auto n_vecs = value_type::n_vecs;
201+
};
202+
203+
} // end of namespace hyperparam_detail
204+
205+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
206+
class atan_contig_kernel;
207+
208+
template <typename argTy>
209+
sycl::event atan_contig_impl(sycl::queue &exec_q,
210+
std::size_t nelems,
211+
const char *arg_p,
212+
char *res_p,
213+
const std::vector<sycl::event> &depends = {})
214+
{
215+
using AtanHS = hyperparam_detail::AtanContigHyperparameterSet<argTy>;
216+
static constexpr std::uint8_t vec_sz = AtanHS::vec_sz;
217+
static constexpr std::uint8_t n_vec = AtanHS::n_vecs;
218+
219+
return elementwise_common::unary_contig_impl<
220+
argTy, AtanOutputType, AtanContigFunctor, atan_contig_kernel, vec_sz,
221+
n_vec>(exec_q, nelems, arg_p, res_p, depends);
222+
}
223+
224+
template <typename fnT, typename T>
225+
struct AtanContigFactory
226+
{
227+
fnT get()
228+
{
229+
if constexpr (!AtanOutputType<T>::is_defined) {
230+
fnT fn = nullptr;
231+
return fn;
232+
}
233+
else {
234+
fnT fn = atan_contig_impl<T>;
235+
return fn;
236+
}
237+
}
238+
};
239+
240+
template <typename fnT, typename T>
241+
struct AtanTypeMapFactory
242+
{
243+
/*! @brief get typeid for output type of sycl::atan(T x) */
244+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
245+
{
246+
using rT = typename AtanOutputType<T>::value_type;
247+
return td_ns::GetTypeid<rT>{}.get();
248+
}
249+
};
250+
251+
template <typename T1, typename T2, typename T3>
252+
class atan_strided_kernel;
253+
254+
template <typename argTy>
255+
sycl::event
256+
atan_strided_impl(sycl::queue &exec_q,
257+
std::size_t nelems,
258+
int nd,
259+
const ssize_t *shape_and_strides,
260+
const char *arg_p,
261+
ssize_t arg_offset,
262+
char *res_p,
263+
ssize_t res_offset,
264+
const std::vector<sycl::event> &depends,
265+
const std::vector<sycl::event> &additional_depends)
266+
{
267+
return elementwise_common::unary_strided_impl<
268+
argTy, AtanOutputType, AtanStridedFunctor, atan_strided_kernel>(
269+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
270+
res_offset, depends, additional_depends);
271+
}
272+
273+
template <typename fnT, typename T>
274+
struct AtanStridedFactory
275+
{
276+
fnT get()
277+
{
278+
if constexpr (!AtanOutputType<T>::is_defined) {
279+
fnT fn = nullptr;
280+
return fn;
281+
}
282+
else {
283+
fnT fn = atan_strided_impl<T>;
284+
return fn;
285+
}
286+
}
287+
};
288+
289+
} // namespace dpctl::tensor::kernels::atan

0 commit comments

Comments
 (0)