Skip to content

Commit 44ac844

Browse files
Move _exp/_expmp/_floor to dpctl_ext.tensor and reuse them
1 parent e6d0d6f commit 44ac844

File tree

15 files changed

+1397
-15
lines changed

15 files changed

+1397
-15
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ set(_elementwise_sources
9595
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cos.cpp
9696
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cosh.cpp
9797
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/equal.cpp
98-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp.cpp
98+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp.cpp
9999
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp2.cpp
100-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/expm1.cpp
100+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/expm1.cpp
101101
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor_divide.cpp
102-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor.cpp
102+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor.cpp
103103
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/greater_equal.cpp
104104
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/greater.cpp
105105
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/hypot.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@
9797
conj,
9898
cos,
9999
cosh,
100+
exp,
101+
expm1,
102+
floor,
100103
)
101104
from ._reduction import (
102105
argmax,
@@ -159,8 +162,11 @@
159162
"extract",
160163
"expand_dims",
161164
"eye",
165+
"exp",
166+
"expm1",
162167
"finfo",
163168
"flip",
169+
"floor",
164170
"from_numpy",
165171
"full",
166172
"full_like",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,92 @@
378378
)
379379
del _cosh_docstring
380380

381+
# U13: ==== EXP (x)
382+
_exp_docstring = r"""
383+
exp(x, /, \*, out=None, order='K')
384+
385+
Computes the exponential for each element `x_i` of input array `x`.
386+
387+
Args:
388+
x (usm_ndarray):
389+
Input array, expected to have a floating-point data type.
390+
out (Union[usm_ndarray, None], optional):
391+
Output array to populate.
392+
Array must have the correct shape and the expected data type.
393+
order ("C","F","A","K", optional):
394+
Memory layout of the new output array, if parameter
395+
`out` is ``None``.
396+
Default: "K".
397+
398+
Returns:
399+
usm_ndarray:
400+
An array containing the element-wise exponential of `x`.
401+
The data type of the returned array is determined by
402+
the Type Promotion Rules.
403+
"""
404+
405+
exp = UnaryElementwiseFunc("exp", ti._exp_result_type, ti._exp, _exp_docstring)
406+
del _exp_docstring
407+
408+
# U14: ==== EXPM1 (x)
409+
_expm1_docstring = r"""
410+
expm1(x, /, \*, out=None, order='K')
411+
412+
Computes the exponential minus 1 for each element `x_i` of input array `x`.
413+
414+
This function calculates `exp(x) - 1.0` more accurately for small values of `x`.
415+
416+
Args:
417+
x (usm_ndarray):
418+
Input array, expected to have a floating-point data type.
419+
out (usm_ndarray):
420+
Output array to populate. Array must have the correct
421+
shape and the expected data type.
422+
order ("C","F","A","K", optional): memory layout of the new
423+
output array, if parameter `out` is ``None``.
424+
Default: "K".
425+
426+
Returns:
427+
usm_ndarray:
428+
An array containing the element-wise `exp(x) - 1` results.
429+
The data type of the returned array is determined by the Type
430+
Promotion Rules.
431+
"""
432+
433+
expm1 = UnaryElementwiseFunc(
434+
"expm1", ti._expm1_result_type, ti._expm1, _expm1_docstring
435+
)
436+
del _expm1_docstring
437+
438+
# U15: ==== FLOOR (x)
439+
_floor_docstring = r"""
440+
floor(x, /, \*, out=None, order='K')
441+
442+
Returns the floor for each element `x_i` for input array `x`.
443+
444+
The floor of `x_i` is the largest integer `n`, such that `n <= x_i`.
445+
446+
Args:
447+
x (usm_ndarray):
448+
Input array, expected to have a boolean or real-valued data type.
449+
out (Union[usm_ndarray, None], optional):
450+
Output array to populate.
451+
Array must have the correct shape and the expected data type.
452+
order ("C","F","A","K", optional):
453+
Memory layout of the new output array, if parameter
454+
`out` is ``None``.
455+
Default: "K".
456+
457+
Returns:
458+
usm_ndarray:
459+
An array containing the element-wise floor.
460+
"""
461+
462+
floor = UnaryElementwiseFunc(
463+
"floor", ti._floor_result_type, ti._floor, _floor_docstring
464+
)
465+
del _floor_docstring
466+
381467
# U43: ==== ANGLE (x)
382468
_angle_docstring = r"""
383469
angle(x, /, \*, out=None, order='K')
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
/// \file
30+
/// This file defines kernels for elementwise evaluation of EXP(x) function.
31+
//===---------------------------------------------------------------------===//
32+
33+
#pragma once
34+
#include <cmath>
35+
#include <cstddef>
36+
#include <cstdint>
37+
#include <limits>
38+
#include <type_traits>
39+
40+
#include <sycl/sycl.hpp>
41+
42+
#include "sycl_complex.hpp"
43+
#include "vec_size_util.hpp"
44+
45+
#include "kernels/dpctl_tensor_types.hpp"
46+
#include "kernels/elementwise_functions/common.hpp"
47+
48+
#include "utils/offset_utils.hpp"
49+
#include "utils/type_dispatch_building.hpp"
50+
#include "utils/type_utils.hpp"
51+
52+
namespace dpctl::tensor::kernels::exp
53+
{
54+
55+
using dpctl::tensor::ssize_t;
56+
namespace td_ns = dpctl::tensor::type_dispatch;
57+
58+
using dpctl::tensor::type_utils::is_complex;
59+
60+
template <typename argT, typename resT>
61+
struct ExpFunctor
62+
{
63+
// is function constant for given argT
64+
using is_constant = typename std::false_type;
65+
// constant value, if constant
66+
// constexpr resT constant_value = resT{};
67+
// is function defined for sycl::vec
68+
using supports_vec = typename std::false_type;
69+
// do both argTy and resTy support sugroup store/load operation
70+
using supports_sg_loadstore = typename std::negation<
71+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
72+
73+
resT operator()(const argT &in) const
74+
{
75+
if constexpr (is_complex<argT>::value) {
76+
using realT = typename argT::value_type;
77+
78+
static constexpr realT q_nan =
79+
std::numeric_limits<realT>::quiet_NaN();
80+
81+
const realT x = std::real(in);
82+
const realT y = std::imag(in);
83+
if (std::isfinite(x)) {
84+
if (std::isfinite(y)) {
85+
return exprm_ns::exp(
86+
exprm_ns::complex<realT>(in)); // exp(in);
87+
}
88+
else {
89+
return resT{q_nan, q_nan};
90+
}
91+
}
92+
else if (std::isnan(x)) {
93+
/* x is nan */
94+
if (y == realT(0)) {
95+
return resT{in};
96+
}
97+
else {
98+
return resT{x, q_nan};
99+
}
100+
}
101+
else {
102+
if (!sycl::signbit(x)) { /* x is +inf */
103+
if (y == realT(0)) {
104+
return resT{x, y};
105+
}
106+
else if (std::isfinite(y)) {
107+
return resT{x * sycl::cos(y), x * sycl::sin(y)};
108+
}
109+
else {
110+
/* x = +inf, y = +-inf || nan */
111+
return resT{x, q_nan};
112+
}
113+
}
114+
else { /* x is -inf */
115+
if (std::isfinite(y)) {
116+
realT exp_x = sycl::exp(x);
117+
return resT{exp_x * sycl::cos(y), exp_x * sycl::sin(y)};
118+
}
119+
else {
120+
/* x = -inf, y = +-inf || nan */
121+
return resT{0, 0};
122+
}
123+
}
124+
}
125+
}
126+
else {
127+
return sycl::exp(in);
128+
}
129+
}
130+
};
131+
132+
template <typename argTy,
133+
typename resTy = argTy,
134+
std::uint8_t vec_sz = 4u,
135+
std::uint8_t n_vecs = 2u,
136+
bool enable_sg_loadstore = true>
137+
using ExpContigFunctor =
138+
elementwise_common::UnaryContigFunctor<argTy,
139+
resTy,
140+
ExpFunctor<argTy, resTy>,
141+
vec_sz,
142+
n_vecs,
143+
enable_sg_loadstore>;
144+
145+
template <typename argTy, typename resTy, typename IndexerT>
146+
using ExpStridedFunctor = elementwise_common::
147+
UnaryStridedFunctor<argTy, resTy, IndexerT, ExpFunctor<argTy, resTy>>;
148+
149+
template <typename T>
150+
struct ExpOutputType
151+
{
152+
using value_type = typename std::disjunction<
153+
td_ns::TypeMapResultEntry<T, sycl::half>,
154+
td_ns::TypeMapResultEntry<T, float>,
155+
td_ns::TypeMapResultEntry<T, double>,
156+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
157+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
158+
td_ns::DefaultResultEntry<void>>::result_type;
159+
160+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
161+
};
162+
163+
namespace hyperparam_detail
164+
{
165+
166+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
167+
168+
using vsu_ns::ContigHyperparameterSetDefault;
169+
using vsu_ns::UnaryContigHyperparameterSetEntry;
170+
171+
template <typename argTy>
172+
struct ExpContigHyperparameterSet
173+
{
174+
using value_type =
175+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
176+
177+
constexpr static auto vec_sz = value_type::vec_sz;
178+
constexpr static auto n_vecs = value_type::n_vecs;
179+
};
180+
181+
} // end of namespace hyperparam_detail
182+
183+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
184+
class exp_contig_kernel;
185+
186+
template <typename argTy>
187+
sycl::event exp_contig_impl(sycl::queue &exec_q,
188+
std::size_t nelems,
189+
const char *arg_p,
190+
char *res_p,
191+
const std::vector<sycl::event> &depends = {})
192+
{
193+
using ExpHS = hyperparam_detail::ExpContigHyperparameterSet<argTy>;
194+
static constexpr std::uint8_t vec_sz = ExpHS::vec_sz;
195+
static constexpr std::uint8_t n_vecs = ExpHS::n_vecs;
196+
197+
return elementwise_common::unary_contig_impl<
198+
argTy, ExpOutputType, ExpContigFunctor, exp_contig_kernel, vec_sz,
199+
n_vecs>(exec_q, nelems, arg_p, res_p, depends);
200+
}
201+
202+
template <typename fnT, typename T>
203+
struct ExpContigFactory
204+
{
205+
fnT get()
206+
{
207+
if constexpr (!ExpOutputType<T>::is_defined) {
208+
fnT fn = nullptr;
209+
return fn;
210+
}
211+
else {
212+
fnT fn = exp_contig_impl<T>;
213+
return fn;
214+
}
215+
}
216+
};
217+
218+
template <typename fnT, typename T>
219+
struct ExpTypeMapFactory
220+
{
221+
/*! @brief get typeid for output type of sycl::exp(T x) */
222+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
223+
{
224+
using rT = typename ExpOutputType<T>::value_type;
225+
return td_ns::GetTypeid<rT>{}.get();
226+
}
227+
};
228+
229+
template <typename T1, typename T2, typename T3>
230+
class exp_strided_kernel;
231+
232+
template <typename argTy>
233+
sycl::event exp_strided_impl(sycl::queue &exec_q,
234+
std::size_t nelems,
235+
int nd,
236+
const ssize_t *shape_and_strides,
237+
const char *arg_p,
238+
ssize_t arg_offset,
239+
char *res_p,
240+
ssize_t res_offset,
241+
const std::vector<sycl::event> &depends,
242+
const std::vector<sycl::event> &additional_depends)
243+
{
244+
return elementwise_common::unary_strided_impl<
245+
argTy, ExpOutputType, ExpStridedFunctor, exp_strided_kernel>(
246+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
247+
res_offset, depends, additional_depends);
248+
}
249+
250+
template <typename fnT, typename T>
251+
struct ExpStridedFactory
252+
{
253+
fnT get()
254+
{
255+
if constexpr (!ExpOutputType<T>::is_defined) {
256+
fnT fn = nullptr;
257+
return fn;
258+
}
259+
else {
260+
fnT fn = exp_strided_impl<T>;
261+
return fn;
262+
}
263+
}
264+
};
265+
266+
} // namespace dpctl::tensor::kernels::exp

0 commit comments

Comments
 (0)