Skip to content

Commit a1707b2

Browse files
Move ti.square()/sqrt()/tan()/tanh() and reuse them
1 parent 0bc8973 commit a1707b2

File tree

17 files changed

+1842
-20
lines changed

17 files changed

+1842
-20
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,11 @@ set(_elementwise_sources
136136
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/signbit.cpp
137137
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sin.cpp
138138
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sinh.cpp
139-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sqrt.cpp
140-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/square.cpp
139+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sqrt.cpp
140+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/square.cpp
141141
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/subtract.cpp
142-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tan.cpp
143-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tanh.cpp
142+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tan.cpp
143+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tanh.cpp
144144
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/true_divide.cpp
145145
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/trunc.cpp
146146
)

dpctl_ext/tensor/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@
118118
signbit,
119119
sin,
120120
sinh,
121+
sqrt,
122+
square,
123+
tan,
124+
tanh,
121125
)
122126
from ._reduction import (
123127
argmax,
@@ -230,12 +234,16 @@
230234
"sin",
231235
"sinh",
232236
"sort",
237+
"sqrt",
238+
"square",
233239
"squeeze",
234240
"stack",
235241
"sum",
236242
"swapaxes",
237243
"take",
238244
"take_along_axis",
245+
"tan",
246+
"tanh",
239247
"tile",
240248
"top_k",
241249
"to_numpy",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,117 @@
931931
)
932932
del _sinh_docstring
933933

934+
# U32: ==== SQUARE (x)
935+
_square_docstring_ = r"""
936+
square(x, /, \*, out=None, order='K')
937+
938+
Squares each element `x_i` of input array `x`.
939+
940+
Args:
941+
x (usm_ndarray):
942+
Input array. May have any data type.
943+
out (Union[usm_ndarray, None], optional):
944+
Output array to populate.
945+
Array must have the correct shape and the expected data type.
946+
order ("C","F","A","K", optional):
947+
Memory layout of the new output array, if parameter
948+
`out` is ``None``.
949+
Default: "K".
950+
951+
Returns:
952+
usm_ndarray:
953+
An array containing the element-wise squares of `x`. The data type of
954+
the returned array is determined by the Type Promotion Rules.
955+
"""
956+
957+
square = UnaryElementwiseFunc(
958+
"square", ti._square_result_type, ti._square, _square_docstring_
959+
)
960+
del _square_docstring_
961+
962+
# U33: ==== SQRT (x)
963+
_sqrt_docstring_ = r"""
964+
sqrt(x, /, \*, out=None, order='K')
965+
966+
Computes the positive square-root for each element `x_i` of input array `x`.
967+
968+
Args:
969+
x (usm_ndarray):
970+
Input array, expected to have a floating-point data type.
971+
out (Union[usm_ndarray, None], optional):
972+
Output array to populate.
973+
Array must have the correct shape and the expected data type.
974+
order ("C","F","A","K", optional):
975+
Memory layout of the new output array, if parameter
976+
`out` is ``None``.
977+
Default: "K".
978+
979+
Returns:
980+
usm_ndarray:
981+
An array containing the element-wise positive square-roots of `x`. The
982+
data type of the returned array is determined by the Type Promotion
983+
Rules.
984+
"""
985+
986+
sqrt = UnaryElementwiseFunc(
987+
"sqrt", ti._sqrt_result_type, ti._sqrt, _sqrt_docstring_
988+
)
989+
del _sqrt_docstring_
990+
991+
# U34: ==== TAN (x)
992+
_tan_docstring = r"""
993+
tan(x, /, \*, out=None, order='K')
994+
995+
Computes tangent for each element `x_i` for input array `x`.
996+
997+
Args:
998+
x (usm_ndarray):
999+
Input array, expected to have a floating-point data type.
1000+
out (Union[usm_ndarray, None], optional):
1001+
Output array to populate.
1002+
Array must have the correct shape and the expected data type.
1003+
order ("C","F","A","K", optional):
1004+
Memory layout of the new output array, if parameter
1005+
`out` is ``None``.
1006+
Default: "K".
1007+
1008+
Returns:
1009+
usm_ndarray:
1010+
An array containing the element-wise tangent. The data type
1011+
of the returned array is determined by the Type Promotion Rules.
1012+
"""
1013+
1014+
tan = UnaryElementwiseFunc("tan", ti._tan_result_type, ti._tan, _tan_docstring)
1015+
del _tan_docstring
1016+
1017+
# U35: ==== TANH (x)
1018+
_tanh_docstring = r"""
1019+
tanh(x, /, \*, out=None, order='K')
1020+
1021+
Computes hyperbolic tangent for each element `x_i` for input array `x`.
1022+
1023+
Args:
1024+
x (usm_ndarray):
1025+
Input array, expected to have a floating-point data type.
1026+
out (Union[usm_ndarray, None], optional):
1027+
Output array to populate.
1028+
Array must have the correct shape and the expected data type.
1029+
order ("C","F","A","K", optional):
1030+
Memory layout of the new output array, if parameter
1031+
`out` is ``None``.
1032+
Default: "K".
1033+
1034+
Returns:
1035+
usm_ndarray:
1036+
An array containing the element-wise hyperbolic tangent. The data type
1037+
of the returned array is determined by the Type Promotion Rules.
1038+
"""
1039+
1040+
tanh = UnaryElementwiseFunc(
1041+
"tanh", ti._tanh_result_type, ti._tanh, _tanh_docstring
1042+
)
1043+
del _tanh_docstring
1044+
9341045
# U40: ==== PROJ (x)
9351046
_proj_docstring = r"""
9361047
proj(x, /, \*, out=None, order='K')
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of SQRT(x)
33+
/// function that computes a square root.
34+
//===---------------------------------------------------------------------===//
35+
36+
#pragma once
37+
#include <complex>
38+
#include <cstddef>
39+
#include <cstdint>
40+
#include <type_traits>
41+
#include <vector>
42+
43+
#include <sycl/sycl.hpp>
44+
45+
#include "sycl_complex.hpp"
46+
#include "vec_size_util.hpp"
47+
48+
#include "kernels/dpctl_tensor_types.hpp"
49+
#include "kernels/elementwise_functions/common.hpp"
50+
51+
#include "utils/type_dispatch_building.hpp"
52+
#include "utils/type_utils.hpp"
53+
54+
namespace dpctl::tensor::kernels::sqrt
55+
{
56+
57+
using dpctl::tensor::ssize_t;
58+
namespace td_ns = dpctl::tensor::type_dispatch;
59+
60+
using dpctl::tensor::type_utils::is_complex;
61+
62+
template <typename argT, typename resT>
63+
struct SqrtFunctor
64+
{
65+
66+
// is function constant for given argT
67+
using is_constant = typename std::false_type;
68+
// constant value, if constant
69+
// constexpr resT constant_value = resT{};
70+
// is function defined for sycl::vec
71+
using supports_vec = typename std::false_type;
72+
// do both argT and resT support subgroup store/load operations
73+
using supports_sg_loadstore = typename std::negation<
74+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
75+
76+
resT operator()(const argT &in) const
77+
{
78+
if constexpr (is_complex<argT>::value) {
79+
using realT = typename argT::value_type;
80+
return exprm_ns::sqrt(exprm_ns::complex<realT>(in));
81+
}
82+
else {
83+
return sycl::sqrt(in);
84+
}
85+
}
86+
};
87+
88+
template <typename argTy,
89+
typename resTy = argTy,
90+
std::uint8_t vec_sz = 4u,
91+
std::uint8_t n_vecs = 2u,
92+
bool enable_sg_loadstore = true>
93+
using SqrtContigFunctor =
94+
elementwise_common::UnaryContigFunctor<argTy,
95+
resTy,
96+
SqrtFunctor<argTy, resTy>,
97+
vec_sz,
98+
n_vecs,
99+
enable_sg_loadstore>;
100+
101+
template <typename argTy, typename resTy, typename IndexerT>
102+
using SqrtStridedFunctor = elementwise_common::
103+
UnaryStridedFunctor<argTy, resTy, IndexerT, SqrtFunctor<argTy, resTy>>;
104+
105+
template <typename T>
106+
struct SqrtOutputType
107+
{
108+
using value_type = typename std::disjunction<
109+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
110+
td_ns::TypeMapResultEntry<T, float, float>,
111+
td_ns::TypeMapResultEntry<T, double, double>,
112+
td_ns::TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
113+
td_ns::
114+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
115+
td_ns::DefaultResultEntry<void>>::result_type;
116+
117+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
118+
};
119+
120+
namespace hyperparam_detail
121+
{
122+
123+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
124+
125+
using vsu_ns::ContigHyperparameterSetDefault;
126+
using vsu_ns::UnaryContigHyperparameterSetEntry;
127+
128+
template <typename argTy>
129+
struct SqrtContigHyperparameterSet
130+
{
131+
using value_type =
132+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
133+
134+
constexpr static auto vec_sz = value_type::vec_sz;
135+
constexpr static auto n_vecs = value_type::n_vecs;
136+
};
137+
138+
} // end of namespace hyperparam_detail
139+
140+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
141+
class sqrt_contig_kernel;
142+
143+
template <typename argTy>
144+
sycl::event sqrt_contig_impl(sycl::queue &exec_q,
145+
std::size_t nelems,
146+
const char *arg_p,
147+
char *res_p,
148+
const std::vector<sycl::event> &depends = {})
149+
{
150+
using SqrtHS = hyperparam_detail::SqrtContigHyperparameterSet<argTy>;
151+
static constexpr std::uint8_t vec_sz = SqrtHS::vec_sz;
152+
static constexpr std::uint8_t n_vecs = SqrtHS::n_vecs;
153+
154+
return elementwise_common::unary_contig_impl<
155+
argTy, SqrtOutputType, SqrtContigFunctor, sqrt_contig_kernel, vec_sz,
156+
n_vecs>(exec_q, nelems, arg_p, res_p, depends);
157+
}
158+
159+
template <typename fnT, typename T>
160+
struct SqrtContigFactory
161+
{
162+
fnT get()
163+
{
164+
if constexpr (!SqrtOutputType<T>::is_defined) {
165+
fnT fn = nullptr;
166+
return fn;
167+
}
168+
else {
169+
fnT fn = sqrt_contig_impl<T>;
170+
return fn;
171+
}
172+
}
173+
};
174+
175+
template <typename fnT, typename T>
176+
struct SqrtTypeMapFactory
177+
{
178+
/*! @brief get typeid for output type of sqrt(T x), i.e. SqrtOutputType<T>::value_type */
179+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
180+
{
181+
using rT = typename SqrtOutputType<T>::value_type;
182+
return td_ns::GetTypeid<rT>{}.get();
183+
}
184+
};
185+
186+
template <typename T1, typename T2, typename T3>
187+
class sqrt_strided_kernel;
188+
189+
template <typename argTy>
190+
sycl::event
191+
sqrt_strided_impl(sycl::queue &exec_q,
192+
std::size_t nelems,
193+
int nd,
194+
const ssize_t *shape_and_strides,
195+
const char *arg_p,
196+
ssize_t arg_offset,
197+
char *res_p,
198+
ssize_t res_offset,
199+
const std::vector<sycl::event> &depends,
200+
const std::vector<sycl::event> &additional_depends)
201+
{
202+
return elementwise_common::unary_strided_impl<
203+
argTy, SqrtOutputType, SqrtStridedFunctor, sqrt_strided_kernel>(
204+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
205+
res_offset, depends, additional_depends);
206+
}
207+
208+
template <typename fnT, typename T>
209+
struct SqrtStridedFactory
210+
{
211+
fnT get()
212+
{
213+
if constexpr (!SqrtOutputType<T>::is_defined) {
214+
fnT fn = nullptr;
215+
return fn;
216+
}
217+
else {
218+
fnT fn = sqrt_strided_impl<T>;
219+
return fn;
220+
}
221+
}
222+
};
223+
224+
} // namespace dpctl::tensor::kernels::sqrt

0 commit comments

Comments
 (0)