Skip to content

Commit 6d9221c

Browse files
Move ti.log2()/log10() and reuse them
1 parent 4a8c05d commit 6d9221c

File tree

11 files changed

+895
-10
lines changed

11 files changed

+895
-10
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ set(_elementwise_sources
111111
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less.cpp
112112
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log.cpp
113113
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log1p.cpp
114-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log2.cpp
115-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log10.cpp
114+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log2.cpp
115+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log10.cpp
116116
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logaddexp.cpp
117117
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_and.cpp
118118
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_not.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@
106106
isnan,
107107
log,
108108
log1p,
109+
log2,
110+
log10,
109111
)
110112
from ._reduction import (
111113
argmax,
@@ -187,6 +189,8 @@
187189
"log",
188190
"logsumexp",
189191
"log1p",
192+
"log2",
193+
"log10",
190194
"max",
191195
"meshgrid",
192196
"min",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,64 @@
636636
)
637637
del _log1p_docstring
638638

639+
# U22: ==== LOG2 (x)
# Defines the public `log2` object as a UnaryElementwiseFunc constructed from
# the name "log2", `ti._log2_result_type`, `ti._log2` and the docstring below.
# The temporary docstring name is deleted right after construction.
640+
_log2_docstring_ = r"""
641+
log2(x, /, \*, out=None, order='K')
642+
643+
Computes the base-2 logarithm for each element `x_i` of input array `x`.
644+
645+
Args:
646+
x (usm_ndarray):
647+
Input array, expected to have a floating-point data type.
648+
out (Union[usm_ndarray, None], optional):
649+
Output array to populate.
650+
Array must have the correct shape and the expected data type.
651+
order ("C","F","A","K", optional):
652+
Memory layout of the new output array, if parameter
653+
`out` is ``None``.
654+
Default: "K".
655+
656+
Returns:
657+
usm_ndarray:
658+
An array containing the element-wise base-2 logarithm of `x`.
659+
The data type of the returned array is determined by the
660+
Type Promotion Rules.
661+
"""
662+
663+
log2 = UnaryElementwiseFunc(
664+
"log2", ti._log2_result_type, ti._log2, _log2_docstring_
665+
)
666+
del _log2_docstring_
667+
668+
# U23: ==== LOG10 (x)
# Defines the public `log10` object as a UnaryElementwiseFunc constructed from
# the name "log10", `ti._log10_result_type`, `ti._log10` and the docstring
# below. The temporary docstring name is deleted right after construction.
669+
_log10_docstring_ = r"""
670+
log10(x, /, \*, out=None, order='K')
671+
672+
Computes the base-10 logarithm for each element `x_i` of input array `x`.
673+
674+
Args:
675+
x (usm_ndarray):
676+
Input array, expected to have a floating-point data type.
677+
out (Union[usm_ndarray, None], optional):
678+
Output array to populate.
679+
Array must have the correct shape and the expected data type.
680+
order ("C","F","A","K", optional):
681+
Memory layout of the new output array, if parameter
682+
`out` is ``None``.
683+
Default: `"K"`.
684+
685+
Returns:
686+
usm_ndarray:
687+
An array containing the element-wise base-10 logarithm of `x`.
688+
The data type of the returned array is determined by the
689+
Type Promotion Rules.
690+
"""
691+
692+
log10 = UnaryElementwiseFunc(
693+
"log10", ti._log10_result_type, ti._log10, _log10_docstring_
694+
)
695+
del _log10_docstring_
696+
639697
# U43: ==== ANGLE (x)
640698
_angle_docstring = r"""
641699
angle(x, /, \*, out=None, order='K')
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of LOG10(x) function.
33+
//===---------------------------------------------------------------------===//
34+
35+
#pragma once
36+
#include <complex>
#include <cstddef>
#include <cstdint>
#include <sycl/sycl.hpp>
#include <type_traits>
#include <vector>
41+
42+
#include "sycl_complex.hpp"
43+
#include "vec_size_util.hpp"
44+
45+
#include "kernels/dpctl_tensor_types.hpp"
46+
#include "kernels/elementwise_functions/common.hpp"
47+
48+
#include "utils/offset_utils.hpp"
49+
#include "utils/type_dispatch_building.hpp"
50+
#include "utils/type_utils.hpp"
51+
52+
namespace dpctl::tensor::kernels::log10
53+
{
54+
55+
using dpctl::tensor::ssize_t;
56+
namespace td_ns = dpctl::tensor::type_dispatch;
57+
58+
using dpctl::tensor::type_utils::is_complex;
59+
using dpctl::tensor::type_utils::vec_cast;
60+
61+
template <typename argT, typename resT>
62+
struct Log10Functor
63+
{
64+
65+
// is function constant for given argT
66+
using is_constant = typename std::false_type;
67+
// constant value, if constant
68+
// constexpr resT constant_value = resT{};
69+
// is function defined for sycl::vec
70+
using supports_vec = typename std::negation<
71+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
72+
// do both argTy and resTy support sugroup store/load operation
73+
using supports_sg_loadstore = typename std::negation<
74+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
75+
76+
resT operator()(const argT &in) const
77+
{
78+
if constexpr (is_complex<argT>::value) {
79+
using realT = typename argT::value_type;
80+
// return (log(in) / log(realT{10}));
81+
return exprm_ns::log(exprm_ns::complex<realT>(in)) /
82+
sycl::log(realT{10});
83+
}
84+
else {
85+
return sycl::log10(in);
86+
}
87+
}
88+
89+
template <int vec_sz>
90+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in) const
91+
{
92+
auto const &res_vec = sycl::log10(in);
93+
using deducedT = typename std::remove_cv_t<
94+
std::remove_reference_t<decltype(res_vec)>>::element_type;
95+
if constexpr (std::is_same_v<resT, deducedT>) {
96+
return res_vec;
97+
}
98+
else {
99+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
100+
}
101+
}
102+
};
103+
104+
template <typename argTy,
105+
typename resTy = argTy,
106+
std::uint8_t vec_sz = 4u,
107+
std::uint8_t n_vecs = 2u,
108+
bool enable_sg_loadstore = true>
109+
using Log10ContigFunctor =
110+
elementwise_common::UnaryContigFunctor<argTy,
111+
resTy,
112+
Log10Functor<argTy, resTy>,
113+
vec_sz,
114+
n_vecs,
115+
enable_sg_loadstore>;
116+
117+
template <typename argTy, typename resTy, typename IndexerT>
118+
using Log10StridedFunctor = elementwise_common::
119+
UnaryStridedFunctor<argTy, resTy, IndexerT, Log10Functor<argTy, resTy>>;
120+
121+
template <typename T>
122+
struct Log10OutputType
123+
{
124+
using value_type = typename std::disjunction<
125+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
126+
td_ns::TypeMapResultEntry<T, float, float>,
127+
td_ns::TypeMapResultEntry<T, double, double>,
128+
td_ns::TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
129+
td_ns::
130+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
131+
td_ns::DefaultResultEntry<void>>::result_type;
132+
133+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
134+
};
135+
136+
namespace hyperparam_detail
137+
{
138+
139+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
140+
141+
using vsu_ns::ContigHyperparameterSetDefault;
142+
using vsu_ns::UnaryContigHyperparameterSetEntry;
143+
144+
template <typename argTy>
145+
struct Log10ContigHyperparameterSet
146+
{
147+
using value_type =
148+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
149+
150+
constexpr static auto vec_sz = value_type::vec_sz;
151+
constexpr static auto n_vecs = value_type::n_vecs;
152+
};
153+
154+
} // end of namespace hyperparam_detail
155+
156+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
157+
class log10_contig_kernel;
158+
159+
template <typename argTy>
160+
sycl::event log10_contig_impl(sycl::queue &exec_q,
161+
std::size_t nelems,
162+
const char *arg_p,
163+
char *res_p,
164+
const std::vector<sycl::event> &depends = {})
165+
{
166+
using Log10HS = hyperparam_detail::Log10ContigHyperparameterSet<argTy>;
167+
static constexpr std::uint8_t vec_sz = Log10HS::vec_sz;
168+
static constexpr std::uint8_t n_vecs = Log10HS::n_vecs;
169+
170+
return elementwise_common::unary_contig_impl<
171+
argTy, Log10OutputType, Log10ContigFunctor, log10_contig_kernel, vec_sz,
172+
n_vecs>(exec_q, nelems, arg_p, res_p, depends);
173+
}
174+
175+
template <typename fnT, typename T>
176+
struct Log10ContigFactory
177+
{
178+
fnT get()
179+
{
180+
if constexpr (!Log10OutputType<T>::is_defined) {
181+
fnT fn = nullptr;
182+
return fn;
183+
}
184+
else {
185+
fnT fn = log10_contig_impl<T>;
186+
return fn;
187+
}
188+
}
189+
};
190+
191+
template <typename fnT, typename T>
192+
struct Log10TypeMapFactory
193+
{
194+
/*! @brief get typeid for output type of sycl::log10(T x) */
195+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
196+
{
197+
using rT = typename Log10OutputType<T>::value_type;
198+
return td_ns::GetTypeid<rT>{}.get();
199+
}
200+
};
201+
202+
template <typename T1, typename T2, typename T3>
203+
class log10_strided_kernel;
204+
205+
template <typename argTy>
206+
sycl::event
207+
log10_strided_impl(sycl::queue &exec_q,
208+
std::size_t nelems,
209+
int nd,
210+
const ssize_t *shape_and_strides,
211+
const char *arg_p,
212+
ssize_t arg_offset,
213+
char *res_p,
214+
ssize_t res_offset,
215+
const std::vector<sycl::event> &depends,
216+
const std::vector<sycl::event> &additional_depends)
217+
{
218+
return elementwise_common::unary_strided_impl<
219+
argTy, Log10OutputType, Log10StridedFunctor, log10_strided_kernel>(
220+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
221+
res_offset, depends, additional_depends);
222+
}
223+
224+
template <typename fnT, typename T>
225+
struct Log10StridedFactory
226+
{
227+
fnT get()
228+
{
229+
if constexpr (!Log10OutputType<T>::is_defined) {
230+
fnT fn = nullptr;
231+
return fn;
232+
}
233+
else {
234+
fnT fn = log10_strided_impl<T>;
235+
return fn;
236+
}
237+
}
238+
};
239+
240+
} // namespace dpctl::tensor::kernels::log10

0 commit comments

Comments
 (0)