Skip to content

Commit 4a8c05d

Browse files
Move ti.isnan()/log()/log1p() and reuse them
1 parent eab5fec commit 4a8c05d

File tree

15 files changed

+1285
-16
lines changed

15 files changed

+1285
-16
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ set(_elementwise_sources
106106
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/imag.cpp
107107
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isfinite.cpp
108108
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isinf.cpp
109-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isnan.cpp
109+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isnan.cpp
110110
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less_equal.cpp
111111
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less.cpp
112-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log.cpp
113-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log1p.cpp
112+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log.cpp
113+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log1p.cpp
114114
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log2.cpp
115115
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log10.cpp
116116
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logaddexp.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@
103103
imag,
104104
isfinite,
105105
isinf,
106+
isnan,
107+
log,
108+
log1p,
106109
)
107110
from ._reduction import (
108111
argmax,
@@ -179,8 +182,11 @@
179182
"isinf",
180183
"isdtype",
181184
"isin",
185+
"isnan",
182186
"linspace",
187+
"log",
183188
"logsumexp",
189+
"log1p",
184190
"max",
185191
"meshgrid",
186192
"min",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,90 @@
552552
)
553553
del _isinf_docstring_
554554

555+
# U19: ==== ISNAN (x)
556+
_isnan_docstring_ = r"""
557+
isnan(x, /, \*, out=None, order='K')
558+
559+
Test if each element of an input array is a NaN.
560+
561+
Args:
562+
x (usm_ndarray):
563+
Input array. May have any data type.
564+
out (Union[usm_ndarray, None], optional):
565+
Output array to populate.
566+
Array must have the correct shape and the expected data type.
567+
order ("C","F","A","K", optional):
568+
Memory layout of the new output array, if parameter
569+
`out` is ``None``.
570+
Default: "K".
571+
572+
Returns:
573+
usm_ndarray:
574+
An array which is True where x is NaN, False otherwise.
575+
The data type of the returned array is `bool`.
576+
"""
577+
578+
isnan = UnaryElementwiseFunc(
579+
"isnan", ti._isnan_result_type, ti._isnan, _isnan_docstring_
580+
)
581+
del _isnan_docstring_
582+
583+
# U20: ==== LOG (x)
584+
_log_docstring = r"""
585+
log(x, /, \*, out=None, order='K')
586+
587+
Computes the natural logarithm for each element `x_i` of input array `x`.
588+
589+
Args:
590+
x (usm_ndarray):
591+
Input array, expected to have a floating-point data type.
592+
out (usm_ndarray):
593+
Output array to populate. Array must have the correct
594+
shape and the expected data type.
595+
order ("C","F","A","K", optional): memory layout of the new
596+
output array, if parameter `out` is ``None``.
597+
Default: "K".
598+
599+
Returns:
600+
usm_ndarray:
601+
An array containing the element-wise natural logarithm values.
602+
The data type of the returned array is determined by the Type
603+
Promotion Rules.
604+
"""
605+
606+
log = UnaryElementwiseFunc("log", ti._log_result_type, ti._log, _log_docstring)
607+
del _log_docstring
608+
609+
# U21: ==== LOG1P (x)
610+
_log1p_docstring = r"""
611+
log1p(x, /, \*, out=None, order='K')
612+
613+
Computes the natural logarithm of (1 + `x`) for each element `x_i` of input
614+
array `x`.
615+
616+
This function calculates `log(1 + x)` more accurately for small values of `x`.
617+
618+
Args:
619+
x (usm_ndarray):
620+
Input array, expected to have a floating-point data type.
621+
out (usm_ndarray):
622+
Output array to populate. Array must have the correct
623+
shape and the expected data type.
624+
order ("C","F","A","K", optional): memory layout of the new
625+
output array, if parameter `out` is ``None``.
626+
Default: "K".
627+
628+
Returns:
629+
usm_ndarray:
630+
An array containing the element-wise `log(1 + x)` results. The data type
631+
of the returned array is determined by the Type Promotion Rules.
632+
"""
633+
634+
log1p = UnaryElementwiseFunc(
635+
"log1p", ti._log1p_result_type, ti._log1p, _log1p_docstring
636+
)
637+
del _log1p_docstring
638+
555639
# U43: ==== ANGLE (x)
556640
_angle_docstring = r"""
557641
angle(x, /, \*, out=None, order='K')
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of ISNAN(x)
33+
/// function that tests whether a tensor element is a NaN.
34+
//===---------------------------------------------------------------------===//
35+
36+
#pragma once
37+
#include <complex>
38+
#include <cstddef>
39+
#include <cstdint>
40+
#include <type_traits>
41+
#include <vector>
42+
43+
#include <sycl/sycl.hpp>
44+
45+
#include "vec_size_util.hpp"
46+
47+
#include "kernels/dpctl_tensor_types.hpp"
48+
#include "kernels/elementwise_functions/common.hpp"
49+
50+
#include "utils/offset_utils.hpp"
51+
#include "utils/type_dispatch_building.hpp"
52+
#include "utils/type_utils.hpp"
53+
54+
namespace dpctl::tensor::kernels::isnan
55+
{
56+
57+
using dpctl::tensor::ssize_t;
58+
namespace td_ns = dpctl::tensor::type_dispatch;
59+
60+
using dpctl::tensor::type_utils::is_complex;
61+
using dpctl::tensor::type_utils::vec_cast;
62+
63+
template <typename argT, typename resT>
64+
struct IsNanFunctor
65+
{
66+
static_assert(std::is_same_v<resT, bool>);
67+
68+
/*
69+
std::is_same<argT, bool>::value ||
70+
std::is_integral<argT>::value
71+
*/
72+
using is_constant = typename std::disjunction<std::is_same<argT, bool>,
73+
std::is_integral<argT>>;
74+
static constexpr resT constant_value = false;
75+
using supports_vec = typename std::true_type;
76+
using supports_sg_loadstore = typename std::negation<
77+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
78+
79+
resT operator()(const argT &in) const
80+
{
81+
if constexpr (is_complex<argT>::value) {
82+
const bool real_isnan = sycl::isnan(std::real(in));
83+
const bool imag_isnan = sycl::isnan(std::imag(in));
84+
return (real_isnan || imag_isnan);
85+
}
86+
else if constexpr (std::is_same<argT, bool>::value ||
87+
std::is_integral<argT>::value)
88+
{
89+
return constant_value;
90+
}
91+
else {
92+
return sycl::isnan(in);
93+
}
94+
}
95+
96+
template <int vec_sz>
97+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in) const
98+
{
99+
auto const &res_vec = sycl::isnan(in);
100+
101+
using deducedT = typename std::remove_cv_t<
102+
std::remove_reference_t<decltype(res_vec)>>::element_type;
103+
104+
return vec_cast<bool, deducedT, vec_sz>(res_vec);
105+
}
106+
};
107+
108+
template <typename argT,
109+
typename resT = bool,
110+
std::uint8_t vec_sz = 4u,
111+
std::uint8_t n_vecs = 2u,
112+
bool enable_sg_loadstore = true>
113+
using IsNanContigFunctor =
114+
elementwise_common::UnaryContigFunctor<argT,
115+
resT,
116+
IsNanFunctor<argT, resT>,
117+
vec_sz,
118+
n_vecs,
119+
enable_sg_loadstore>;
120+
121+
template <typename argTy, typename resTy, typename IndexerT>
122+
using IsNanStridedFunctor = elementwise_common::
123+
UnaryStridedFunctor<argTy, resTy, IndexerT, IsNanFunctor<argTy, resTy>>;
124+
125+
template <typename argTy>
126+
struct IsNanOutputType
127+
{
128+
using value_type = bool;
129+
};
130+
131+
namespace hyperparam_detail
132+
{
133+
134+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
135+
136+
using vsu_ns::ContigHyperparameterSetDefault;
137+
using vsu_ns::UnaryContigHyperparameterSetEntry;
138+
139+
template <typename argTy>
140+
struct IsNanContigHyperparameterSet
141+
{
142+
using value_type =
143+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
144+
145+
constexpr static auto vec_sz = value_type::vec_sz;
146+
constexpr static auto n_vecs = value_type::n_vecs;
147+
};
148+
149+
} // end of namespace hyperparam_detail
150+
151+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
152+
class isnan_contig_kernel;
153+
154+
template <typename argTy>
155+
sycl::event isnan_contig_impl(sycl::queue &exec_q,
156+
std::size_t nelems,
157+
const char *arg_p,
158+
char *res_p,
159+
const std::vector<sycl::event> &depends = {})
160+
{
161+
using IsNanHS = hyperparam_detail::IsNanContigHyperparameterSet<argTy>;
162+
static constexpr std::uint8_t vec_sz = IsNanHS::vec_sz;
163+
static constexpr std::uint8_t n_vecs = IsNanHS::n_vecs;
164+
165+
return elementwise_common::unary_contig_impl<
166+
argTy, IsNanOutputType, IsNanContigFunctor, isnan_contig_kernel, vec_sz,
167+
n_vecs>(exec_q, nelems, arg_p, res_p, depends);
168+
}
169+
170+
template <typename fnT, typename T>
171+
struct IsNanContigFactory
172+
{
173+
fnT get()
174+
{
175+
fnT fn = isnan_contig_impl<T>;
176+
return fn;
177+
}
178+
};
179+
180+
template <typename fnT, typename T>
181+
struct IsNanTypeMapFactory
182+
{
183+
/*! @brief get typeid for output type of sycl::isnan(T x) */
184+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
185+
{
186+
using rT = typename IsNanOutputType<T>::value_type;
187+
return td_ns::GetTypeid<rT>{}.get();
188+
}
189+
};
190+
191+
template <typename T1, typename T2, typename T3>
192+
class isnan_strided_kernel;
193+
194+
template <typename argTy>
195+
sycl::event
196+
isnan_strided_impl(sycl::queue &exec_q,
197+
std::size_t nelems,
198+
int nd,
199+
const ssize_t *shape_and_strides,
200+
const char *arg_p,
201+
ssize_t arg_offset,
202+
char *res_p,
203+
ssize_t res_offset,
204+
const std::vector<sycl::event> &depends,
205+
const std::vector<sycl::event> &additional_depends)
206+
{
207+
return elementwise_common::unary_strided_impl<
208+
argTy, IsNanOutputType, IsNanStridedFunctor, isnan_strided_kernel>(
209+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
210+
res_offset, depends, additional_depends);
211+
}
212+
213+
template <typename fnT, typename T>
214+
struct IsNanStridedFactory
215+
{
216+
fnT get()
217+
{
218+
fnT fn = isnan_strided_impl<T>;
219+
return fn;
220+
}
221+
};
222+
223+
} // namespace dpctl::tensor::kernels::isnan

0 commit comments

Comments
 (0)