Skip to content

Commit dbf021a

Browse files
Move ti.proj()/real()/round() and reuse them
1 parent 7f444bd commit dbf021a

File tree

14 files changed

+1332
-15
lines changed

14 files changed

+1332
-15
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,11 @@ set(_elementwise_sources
126126
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/not_equal.cpp
127127
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/positive.cpp
128128
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
129-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
130-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp
129+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
130+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp
131131
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/reciprocal.cpp
132132
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/remainder.cpp
133-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/round.cpp
133+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/round.cpp
134134
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/rsqrt.cpp
135135
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sign.cpp
136136
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/signbit.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@
111111
logical_not,
112112
negative,
113113
positive,
114+
proj,
115+
real,
116+
round,
114117
)
115118
from ._reduction import (
116119
argmax,
@@ -207,13 +210,16 @@
207210
"place",
208211
"positive",
209212
"prod",
213+
"proj",
210214
"put",
211215
"put_along_axis",
216+
"real",
212217
"reduce_hypot",
213218
"repeat",
214219
"reshape",
215220
"result_type",
216221
"roll",
222+
"round",
217223
"searchsorted",
218224
"sort",
219225
"squeeze",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,95 @@
782782
)
783783
del _positive_docstring_
784784

785+
# U27: ==== REAL (x)
786+
_real_docstring = r"""
787+
real(x, /, \*, out=None, order='K')
788+
789+
Computes real part of each element `x_i` for input array `x`.
790+
791+
Args:
792+
x (usm_ndarray):
793+
Input array. May have any data type.
794+
out (Union[usm_ndarray, None], optional):
795+
Output array to populate.
796+
Array must have the correct shape and the expected data type.
797+
order ("C","F","A","K", optional):
798+
Memory layout of the new output array, if parameter
799+
`out` is ``None``.
800+
Default: "K".
801+
802+
Returns:
803+
usm_ndarray:
804+
An array containing the element-wise real component of input.
805+
If the input is a real-valued data type, the returned array has
806+
the same data type. If the input is a complex floating-point
807+
data type, the returned array has a floating-point data type
808+
with the same floating-point precision as complex input.
809+
"""
810+
811+
real = UnaryElementwiseFunc(
812+
"real", ti._real_result_type, ti._real, _real_docstring
813+
)
814+
del _real_docstring
815+
816+
# U28: ==== ROUND (x)
817+
_round_docstring = r"""
818+
round(x, /, \*, out=None, order='K')
819+
820+
Rounds each element `x_i` of the input array `x` to
821+
the nearest integer-valued number.
822+
823+
When two integers are equally close to `x_i`, the result is the nearest even
824+
integer to `x_i`.
825+
826+
Args:
827+
x (usm_ndarray):
828+
Input array, expected to have a numeric data type.
829+
out (Union[usm_ndarray, None], optional):
830+
Output array to populate.
831+
Array must have the correct shape and the expected data type.
832+
order ("C","F","A","K", optional):
833+
Memory layout of the new output array, if parameter
834+
`out` is ``None``.
835+
Default: "K".
836+
837+
Returns:
838+
usm_ndarray:
839+
An array containing the element-wise rounded values.
840+
"""
841+
842+
round = UnaryElementwiseFunc(
843+
"round", ti._round_result_type, ti._round, _round_docstring
844+
)
845+
del _round_docstring
846+
847+
# U40: ==== PROJ (x)
848+
_proj_docstring = r"""
849+
proj(x, /, \*, out=None, order='K')
850+
851+
Computes projection of each element `x_i` for input array `x`.
852+
853+
Args:
854+
x (usm_ndarray):
855+
Input array, expected to have a complex data type.
856+
out (Union[usm_ndarray, None], optional):
857+
Output array to populate.
858+
Array must have the correct shape and the expected data type.
859+
order ("C","F","A","K", optional):
860+
Memory layout of the new output array, if parameter
861+
`out` is ``None``.
862+
Default: "K".
863+
864+
Returns:
865+
usm_ndarray:
866+
An array containing the element-wise projection.
867+
"""
868+
869+
proj = UnaryElementwiseFunc(
870+
"proj", ti._proj_result_type, ti._proj, _proj_docstring
871+
)
872+
del _proj_docstring
873+
785874
# U43: ==== ANGLE (x)
786875
_angle_docstring = r"""
787876
angle(x, /, \*, out=None, order='K')
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of PROJ(x) function.
33+
//===---------------------------------------------------------------------===//
34+
35+
#pragma once
36+
#include <cmath>
37+
#include <complex>
38+
#include <cstddef>
39+
#include <cstdint>
40+
#include <limits>
41+
#include <type_traits>
42+
#include <vector>
43+
44+
#include <sycl/sycl.hpp>
45+
46+
#include "vec_size_util.hpp"
47+
48+
#include "kernels/dpctl_tensor_types.hpp"
49+
#include "kernels/elementwise_functions/common.hpp"
50+
51+
#include "utils/type_dispatch_building.hpp"
52+
#include "utils/type_utils.hpp"
53+
54+
namespace dpctl::tensor::kernels::proj
55+
{
56+
57+
using dpctl::tensor::ssize_t;
58+
namespace td_ns = dpctl::tensor::type_dispatch;
59+
60+
using dpctl::tensor::type_utils::is_complex;
61+
62+
template <typename argT, typename resT>
63+
struct ProjFunctor
64+
{
65+
66+
// is function constant for given argT
67+
using is_constant = typename std::false_type;
68+
// constant value, if constant
69+
// constexpr resT constant_value = resT{};
70+
// is function defined for sycl::vec
71+
using supports_vec = typename std::false_type;
72+
// do both argTy and resTy support sugroup store/load operation
73+
using supports_sg_loadstore = typename std::false_type;
74+
75+
resT operator()(const argT &in) const
76+
{
77+
using realT = typename argT::value_type;
78+
const realT x = std::real(in);
79+
const realT y = std::imag(in);
80+
81+
if (std::isinf(x)) {
82+
return value_at_infinity(y);
83+
}
84+
else if (std::isinf(y)) {
85+
return value_at_infinity(y);
86+
}
87+
else {
88+
return in;
89+
}
90+
}
91+
92+
private:
93+
template <typename T>
94+
std::complex<T> value_at_infinity(const T &y) const
95+
{
96+
const T res_im = sycl::copysign(T(0), y);
97+
return std::complex<T>{std::numeric_limits<T>::infinity(), res_im};
98+
}
99+
};
100+
101+
template <typename argTy,
102+
typename resTy = argTy,
103+
std::uint8_t vec_sz = 4u,
104+
std::uint8_t n_vecs = 2u,
105+
bool enable_sg_loadstore = true>
106+
using ProjContigFunctor =
107+
elementwise_common::UnaryContigFunctor<argTy,
108+
resTy,
109+
ProjFunctor<argTy, resTy>,
110+
vec_sz,
111+
n_vecs,
112+
enable_sg_loadstore>;
113+
114+
template <typename argTy, typename resTy, typename IndexerT>
115+
using ProjStridedFunctor = elementwise_common::
116+
UnaryStridedFunctor<argTy, resTy, IndexerT, ProjFunctor<argTy, resTy>>;
117+
118+
template <typename T>
119+
struct ProjOutputType
120+
{
121+
using value_type = typename std::disjunction<
122+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
123+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
124+
td_ns::DefaultResultEntry<void>>::result_type;
125+
126+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
127+
};
128+
129+
namespace hyperparam_detail
130+
{
131+
132+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
133+
134+
using vsu_ns::ContigHyperparameterSetDefault;
135+
using vsu_ns::UnaryContigHyperparameterSetEntry;
136+
137+
template <typename argTy>
138+
struct ProjContigHyperparameterSet
139+
{
140+
using value_type =
141+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
142+
143+
constexpr static auto vec_sz = value_type::vec_sz;
144+
constexpr static auto n_vecs = value_type::n_vecs;
145+
};
146+
147+
} // end of namespace hyperparam_detail
148+
149+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
150+
class proj_contig_kernel;
151+
152+
template <typename argTy>
153+
sycl::event proj_contig_impl(sycl::queue &exec_q,
154+
std::size_t nelems,
155+
const char *arg_p,
156+
char *res_p,
157+
const std::vector<sycl::event> &depends = {})
158+
{
159+
using ProjHS = hyperparam_detail::ProjContigHyperparameterSet<argTy>;
160+
static constexpr std::uint8_t vec_sz = ProjHS::vec_sz;
161+
static constexpr std::uint8_t n_vecs = ProjHS::n_vecs;
162+
163+
return elementwise_common::unary_contig_impl<
164+
argTy, ProjOutputType, ProjContigFunctor, proj_contig_kernel, vec_sz,
165+
n_vecs>(exec_q, nelems, arg_p, res_p, depends);
166+
}
167+
168+
template <typename fnT, typename T>
169+
struct ProjContigFactory
170+
{
171+
fnT get()
172+
{
173+
if constexpr (!ProjOutputType<T>::is_defined) {
174+
fnT fn = nullptr;
175+
return fn;
176+
}
177+
else {
178+
if constexpr (std::is_same_v<T, std::complex<double>>) {
179+
fnT fn = proj_contig_impl<T>;
180+
return fn;
181+
}
182+
else {
183+
fnT fn = proj_contig_impl<T>;
184+
return fn;
185+
}
186+
}
187+
}
188+
};
189+
190+
template <typename fnT, typename T>
191+
struct ProjTypeMapFactory
192+
{
193+
/*! @brief get typeid for output type of std::proj(T x) */
194+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
195+
{
196+
using rT = typename ProjOutputType<T>::value_type;
197+
return td_ns::GetTypeid<rT>{}.get();
198+
}
199+
};
200+
201+
template <typename T1, typename T2, typename T3>
202+
class proj_strided_kernel;
203+
204+
template <typename argTy>
205+
sycl::event
206+
proj_strided_impl(sycl::queue &exec_q,
207+
std::size_t nelems,
208+
int nd,
209+
const ssize_t *shape_and_strides,
210+
const char *arg_p,
211+
ssize_t arg_offset,
212+
char *res_p,
213+
ssize_t res_offset,
214+
const std::vector<sycl::event> &depends,
215+
const std::vector<sycl::event> &additional_depends)
216+
{
217+
return elementwise_common::unary_strided_impl<
218+
argTy, ProjOutputType, ProjStridedFunctor, proj_strided_kernel>(
219+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
220+
res_offset, depends, additional_depends);
221+
}
222+
223+
template <typename fnT, typename T>
224+
struct ProjStridedFactory
225+
{
226+
fnT get()
227+
{
228+
if constexpr (!ProjOutputType<T>::is_defined) {
229+
fnT fn = nullptr;
230+
return fn;
231+
}
232+
else {
233+
fnT fn = proj_strided_impl<T>;
234+
return fn;
235+
}
236+
}
237+
};
238+
239+
} // namespace dpctl::tensor::kernels::proj

0 commit comments

Comments
 (0)