Skip to content

Commit ec74213

Browse files
Move _acos/_acosh to _tensor_elementwise_impl
1 parent d4cda7c commit ec74213

File tree

8 files changed

+928
-6
lines changed

8 files changed

+928
-6
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ set(_elementwise_sources
7373
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_common.cpp
7474
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
7575
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/abs.cpp
76-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp
77-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp
76+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp
77+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp
7878
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/add.cpp
7979
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/angle.cpp
8080
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of ACOS(x) function.
33+
//===---------------------------------------------------------------------===//
34+
35+
#pragma once
36+
#include <cmath>
37+
#include <complex>
38+
#include <cstddef>
39+
#include <cstdint>
40+
#include <limits>
41+
#include <type_traits>
42+
#include <vector>
43+
44+
#include "sycl_complex.hpp"
45+
#include "vec_size_util.hpp"
46+
47+
#include "kernels/dpctl_tensor_types.hpp"
48+
#include "kernels/elementwise_functions/common.hpp"
49+
50+
#include "utils/offset_utils.hpp"
51+
#include "utils/type_dispatch_building.hpp"
52+
#include "utils/type_utils.hpp"
53+
54+
namespace dpctl::tensor::kernels::acos
55+
{
56+
57+
using dpctl::tensor::ssize_t;
58+
namespace td_ns = dpctl::tensor::type_dispatch;
59+
60+
using dpctl::tensor::type_utils::is_complex;
61+
62+
template <typename argT, typename resT>
63+
struct AcosFunctor
64+
{
65+
66+
// is function constant for given argT
67+
using is_constant = typename std::false_type;
68+
// constant value, if constant
69+
// constexpr resT constant_value = resT{};
70+
// is function defined for sycl::vec
71+
using supports_vec = typename std::false_type;
72+
// do both argTy and resTy support sugroup store/load operation
73+
using supports_sg_loadstore = typename std::negation<
74+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
75+
76+
resT operator()(const argT &in) const
77+
{
78+
if constexpr (is_complex<argT>::value) {
79+
using realT = typename argT::value_type;
80+
81+
static constexpr realT q_nan =
82+
std::numeric_limits<realT>::quiet_NaN();
83+
84+
const realT x = std::real(in);
85+
const realT y = std::imag(in);
86+
87+
if (std::isnan(x)) {
88+
/* acos(NaN + I*+-Inf) = NaN + I*-+Inf */
89+
if (std::isinf(y)) {
90+
return resT{q_nan, -y};
91+
}
92+
93+
/* all other cases involving NaN return NaN + I*NaN. */
94+
return resT{q_nan, q_nan};
95+
}
96+
if (std::isnan(y)) {
97+
/* acos(+-Inf + I*NaN) = NaN + I*opt(-)Inf */
98+
if (std::isinf(x)) {
99+
return resT{q_nan, -std::numeric_limits<realT>::infinity()};
100+
}
101+
/* acos(0 + I*NaN) = PI/2 + I*NaN with inexact */
102+
if (x == realT(0)) {
103+
const realT res_re = sycl::atan(realT(1)) * 2; // PI/2
104+
return resT{res_re, q_nan};
105+
}
106+
107+
/* all other cases involving NaN return NaN + I*NaN. */
108+
return resT{q_nan, q_nan};
109+
}
110+
111+
/*
112+
* For large x or y including acos(+-Inf + I*+-Inf)
113+
*/
114+
static constexpr realT r_eps =
115+
realT(1) / std::numeric_limits<realT>::epsilon();
116+
if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) {
117+
using sycl_complexT = exprm_ns::complex<realT>;
118+
sycl_complexT log_in =
119+
exprm_ns::log(exprm_ns::complex<realT>(in));
120+
121+
const realT wx = log_in.real();
122+
const realT wy = log_in.imag();
123+
const realT rx = sycl::fabs(wy);
124+
125+
realT ry = wx + sycl::log(realT(2));
126+
return resT{rx, (sycl::signbit(y)) ? ry : -ry};
127+
}
128+
129+
/* ordinary cases */
130+
return exprm_ns::acos(exprm_ns::complex<realT>(in)); // acos(in);
131+
}
132+
else {
133+
static_assert(std::is_floating_point_v<argT> ||
134+
std::is_same_v<argT, sycl::half>);
135+
return sycl::acos(in);
136+
}
137+
}
138+
};
139+
140+
template <typename argTy,
141+
typename resTy = argTy,
142+
std::uint8_t vec_sz = 4u,
143+
std::uint8_t n_vecs = 2u,
144+
bool enable_sg_loadstore = true>
145+
using AcosContigFunctor =
146+
elementwise_common::UnaryContigFunctor<argTy,
147+
resTy,
148+
AcosFunctor<argTy, resTy>,
149+
vec_sz,
150+
n_vecs,
151+
enable_sg_loadstore>;
152+
153+
template <typename argTy, typename resTy, typename IndexerT>
154+
using AcosStridedFunctor = elementwise_common::
155+
UnaryStridedFunctor<argTy, resTy, IndexerT, AcosFunctor<argTy, resTy>>;
156+
157+
template <typename T>
158+
struct AcosOutputType
159+
{
160+
using value_type = typename std::disjunction<
161+
td_ns::TypeMapResultEntry<T, sycl::half>,
162+
td_ns::TypeMapResultEntry<T, float>,
163+
td_ns::TypeMapResultEntry<T, double>,
164+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
165+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
166+
td_ns::DefaultResultEntry<void>>::result_type;
167+
168+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
169+
};
170+
171+
namespace hyperparam_detail
172+
{
173+
174+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
175+
176+
using vsu_ns::ContigHyperparameterSetDefault;
177+
178+
template <typename argTy>
179+
struct AcosContigHyperparameterSet
180+
{
181+
using value_type =
182+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
183+
184+
constexpr static auto vec_sz = value_type::vec_sz;
185+
constexpr static auto n_vecs = value_type::n_vecs;
186+
};
187+
188+
} // end of namespace hyperparam_detail
189+
190+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
191+
class acos_contig_kernel;
192+
193+
template <typename argTy>
194+
sycl::event acos_contig_impl(sycl::queue &exec_q,
195+
std::size_t nelems,
196+
const char *arg_p,
197+
char *res_p,
198+
const std::vector<sycl::event> &depends = {})
199+
{
200+
using AcosHS = hyperparam_detail::AcosContigHyperparameterSet<argTy>;
201+
static constexpr std::uint8_t vec_sz = AcosHS::vec_sz;
202+
static constexpr std::uint8_t n_vec = AcosHS::n_vecs;
203+
204+
return elementwise_common::unary_contig_impl<
205+
argTy, AcosOutputType, AcosContigFunctor, acos_contig_kernel, vec_sz,
206+
n_vec>(exec_q, nelems, arg_p, res_p, depends);
207+
}
208+
209+
template <typename fnT, typename T>
210+
struct AcosContigFactory
211+
{
212+
fnT get()
213+
{
214+
if constexpr (!AcosOutputType<T>::is_defined) {
215+
fnT fn = nullptr;
216+
return fn;
217+
}
218+
else {
219+
fnT fn = acos_contig_impl<T>;
220+
return fn;
221+
}
222+
}
223+
};
224+
225+
template <typename fnT, typename T>
226+
struct AcosTypeMapFactory
227+
{
228+
/*! @brief get typeid for output type of sycl::acos(T x) */
229+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
230+
{
231+
using rT = typename AcosOutputType<T>::value_type;
232+
return td_ns::GetTypeid<rT>{}.get();
233+
}
234+
};
235+
236+
template <typename T1, typename T2, typename T3>
237+
class acos_strided_kernel;
238+
239+
template <typename argTy>
240+
sycl::event
241+
acos_strided_impl(sycl::queue &exec_q,
242+
std::size_t nelems,
243+
int nd,
244+
const ssize_t *shape_and_strides,
245+
const char *arg_p,
246+
ssize_t arg_offset,
247+
char *res_p,
248+
ssize_t res_offset,
249+
const std::vector<sycl::event> &depends,
250+
const std::vector<sycl::event> &additional_depends)
251+
{
252+
return elementwise_common::unary_strided_impl<
253+
argTy, AcosOutputType, AcosStridedFunctor, acos_strided_kernel>(
254+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
255+
res_offset, depends, additional_depends);
256+
}
257+
258+
template <typename fnT, typename T>
259+
struct AcosStridedFactory
260+
{
261+
fnT get()
262+
{
263+
if constexpr (!AcosOutputType<T>::is_defined) {
264+
fnT fn = nullptr;
265+
return fn;
266+
}
267+
else {
268+
fnT fn = acos_strided_impl<T>;
269+
return fn;
270+
}
271+
}
272+
};
273+
274+
} // namespace dpctl::tensor::kernels::acos

0 commit comments

Comments
 (0)