Skip to content

Commit a548dbb

Browse files
Move _angle/_asin/_asinh to _tensor_elementwise_impl
1 parent d4293a2 commit a548dbb

File tree

11 files changed

+1319
-9
lines changed

11 files changed

+1319
-9
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ set(_elementwise_sources
7676
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp
7777
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp
7878
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/add.cpp
79-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/angle.cpp
80-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp
81-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asinh.cpp
79+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/angle.cpp
80+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp
81+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asinh.cpp
8282
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan.cpp
8383
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan2.cpp
8484
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atanh.cpp
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of ANGLE(x) function.
33+
//===---------------------------------------------------------------------===//
34+
35+
#include <cmath>
36+
#include <complex>
37+
#include <cstddef>
38+
#include <cstdint>
39+
#include <type_traits>
40+
#include <vector>
41+
42+
#include <sycl/sycl.hpp>
43+
44+
#include "sycl_complex.hpp"
45+
#include "vec_size_util.hpp"
46+
47+
#include "kernels/dpctl_tensor_types.hpp"
48+
#include "kernels/elementwise_functions/common.hpp"
49+
50+
#include "utils/offset_utils.hpp"
51+
#include "utils/type_dispatch_building.hpp"
52+
#include "utils/type_utils.hpp"
53+
54+
namespace dpctl::tensor::kernels::angle
55+
{
56+
57+
using dpctl::tensor::ssize_t;
58+
namespace td_ns = dpctl::tensor::type_dispatch;
59+
60+
using dpctl::tensor::type_utils::is_complex;
61+
62+
template <typename argT, typename resT>
63+
struct AngleFunctor
64+
{
65+
66+
// is function constant for given argT
67+
using is_constant = typename std::false_type;
68+
// constant value, if constant
69+
// constexpr resT constant_value = resT{};
70+
// is function defined for sycl::vec
71+
using supports_vec = typename std::false_type;
72+
// do both argTy and resTy support sugroup store/load operation
73+
using supports_sg_loadstore = typename std::negation<
74+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
75+
76+
resT operator()(const argT &in) const
77+
{
78+
using rT = typename argT::value_type;
79+
80+
return exprm_ns::arg(exprm_ns::complex<rT>(in)); // arg(in);
81+
}
82+
};
83+
84+
template <typename argTy,
85+
typename resTy = argTy,
86+
std::uint8_t vec_sz = 4u,
87+
std::uint8_t n_vecs = 2u,
88+
bool enable_sg_loadstore = true>
89+
using AngleContigFunctor =
90+
elementwise_common::UnaryContigFunctor<argTy,
91+
resTy,
92+
AngleFunctor<argTy, resTy>,
93+
vec_sz,
94+
n_vecs,
95+
enable_sg_loadstore>;
96+
97+
template <typename argTy, typename resTy, typename IndexerT>
98+
using AngleStridedFunctor = elementwise_common::
99+
UnaryStridedFunctor<argTy, resTy, IndexerT, AngleFunctor<argTy, resTy>>;
100+
101+
template <typename T>
102+
struct AngleOutputType
103+
{
104+
using value_type = typename std::disjunction<
105+
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
106+
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
107+
td_ns::DefaultResultEntry<void>>::result_type;
108+
109+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
110+
};
111+
112+
namespace hyperparam_detail
113+
{
114+
115+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
116+
117+
using vsu_ns::ContigHyperparameterSetDefault;
118+
using vsu_ns::UnaryContigHyperparameterSetEntry;
119+
120+
template <typename argTy>
121+
struct AngleContigHyperparameterSet
122+
{
123+
using value_type =
124+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
125+
126+
constexpr static auto vec_sz = value_type::vec_sz;
127+
constexpr static auto n_vecs = value_type::n_vecs;
128+
};
129+
130+
} // end of namespace hyperparam_detail
131+
132+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
133+
class angle_contig_kernel;
134+
135+
template <typename argTy>
136+
sycl::event angle_contig_impl(sycl::queue &exec_q,
137+
std::size_t nelems,
138+
const char *arg_p,
139+
char *res_p,
140+
const std::vector<sycl::event> &depends = {})
141+
{
142+
using AngleHS = hyperparam_detail::AngleContigHyperparameterSet<argTy>;
143+
static constexpr std::uint8_t vec_sz = AngleHS::vec_sz;
144+
static constexpr std::uint8_t n_vec = AngleHS::n_vecs;
145+
146+
return elementwise_common::unary_contig_impl<
147+
argTy, AngleOutputType, AngleContigFunctor, angle_contig_kernel, vec_sz,
148+
n_vec>(exec_q, nelems, arg_p, res_p, depends);
149+
}
150+
151+
template <typename fnT, typename T>
152+
struct AngleContigFactory
153+
{
154+
fnT get()
155+
{
156+
if constexpr (!AngleOutputType<T>::is_defined) {
157+
fnT fn = nullptr;
158+
return fn;
159+
}
160+
else {
161+
fnT fn = angle_contig_impl<T>;
162+
return fn;
163+
}
164+
}
165+
};
166+
167+
template <typename fnT, typename T>
168+
struct AngleTypeMapFactory
169+
{
170+
/*! @brief get typeid for output type of std::arg(T x) */
171+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
172+
{
173+
using rT = typename AngleOutputType<T>::value_type;
174+
return td_ns::GetTypeid<rT>{}.get();
175+
}
176+
};
177+
178+
template <typename T1, typename T2, typename T3>
179+
class angle_strided_kernel;
180+
181+
template <typename argTy>
182+
sycl::event
183+
angle_strided_impl(sycl::queue &exec_q,
184+
std::size_t nelems,
185+
int nd,
186+
const ssize_t *shape_and_strides,
187+
const char *arg_p,
188+
ssize_t arg_offset,
189+
char *res_p,
190+
ssize_t res_offset,
191+
const std::vector<sycl::event> &depends,
192+
const std::vector<sycl::event> &additional_depends)
193+
{
194+
return elementwise_common::unary_strided_impl<
195+
argTy, AngleOutputType, AngleStridedFunctor, angle_strided_kernel>(
196+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
197+
res_offset, depends, additional_depends);
198+
}
199+
200+
template <typename fnT, typename T>
201+
struct AngleStridedFactory
202+
{
203+
fnT get()
204+
{
205+
if constexpr (!AngleOutputType<T>::is_defined) {
206+
fnT fn = nullptr;
207+
return fn;
208+
}
209+
else {
210+
fnT fn = angle_strided_impl<T>;
211+
return fn;
212+
}
213+
}
214+
};
215+
216+
} // namespace dpctl::tensor::kernels::angle

0 commit comments

Comments
 (0)