Skip to content

Commit 7f444bd

Browse files
Move ti.logical_not()/negative()/positive() and reuse them
1 parent 6d9221c commit 7f444bd

File tree

16 files changed

+1289
-16
lines changed

16 files changed

+1289
-16
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,16 @@ set(_elementwise_sources
115115
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log10.cpp
116116
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logaddexp.cpp
117117
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_and.cpp
118-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_not.cpp
118+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_not.cpp
119119
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_or.cpp
120120
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_xor.cpp
121121
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/maximum.cpp
122122
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/minimum.cpp
123123
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/multiply.cpp
124-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/negative.cpp
124+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/negative.cpp
125125
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/nextafter.cpp
126126
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/not_equal.cpp
127-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/positive.cpp
127+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/positive.cpp
128128
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
129129
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
130130
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@
108108
log1p,
109109
log2,
110110
log10,
111+
logical_not,
112+
negative,
113+
positive,
111114
)
112115
from ._reduction import (
113116
argmax,
@@ -187,6 +190,7 @@
187190
"isnan",
188191
"linspace",
189192
"log",
193+
"logical_not",
190194
"logsumexp",
191195
"log1p",
192196
"log2",
@@ -196,10 +200,12 @@
196200
"min",
197201
"moveaxis",
198202
"permute_dims",
203+
"negative",
199204
"nonzero",
200205
"ones",
201206
"ones_like",
202207
"place",
208+
"positive",
203209
"prod",
204210
"put",
205211
"put_along_axis",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
import dpctl_ext.tensor._tensor_elementwise_impl as ti
3232

3333
from ._elementwise_common import UnaryElementwiseFunc
34+
from ._type_utils import (
35+
_acceptance_fn_negative,
36+
)
3437

3538
# U01: ==== ABS (x)
3639
_abs_docstring_ = r"""
@@ -694,6 +697,91 @@
694697
)
695698
del _log10_docstring_
696699

700+
# U24: ==== LOGICAL_NOT (x)
701+
_logical_not_docstring = r"""
702+
logical_not(x, /, \*, out=None, order='K')
703+
704+
Computes the logical NOT for each element `x_i` of input array `x`.
705+
706+
Args:
707+
x (usm_ndarray):
708+
Input array. May have any data type.
709+
out (usm_ndarray):
710+
Output array to populate. Array must have the correct
711+
shape and the expected data type.
712+
order ("C","F","A","K", optional): memory layout of the new
713+
output array, if parameter `out` is ``None``.
714+
Default: "K".
715+
716+
Returns:
717+
usm_ndarray:
718+
An array containing the element-wise logical NOT results.
719+
"""
720+
721+
# Public `logical_not` callable, built as a UnaryElementwiseFunc from the
# `ti._logical_not_result_type` and `ti._logical_not` entry points of the
# `_tensor_elementwise_impl` extension together with the docstring above.
# NOTE(review): presumably the positional arguments are (name, result-type
# query, implementation, docstring) -- confirm against UnaryElementwiseFunc.
logical_not = UnaryElementwiseFunc(
722+
"logical_not",
723+
ti._logical_not_result_type,
724+
ti._logical_not,
725+
_logical_not_docstring,
726+
)
727+
# Remove the temporary docstring variable from the module namespace.
del _logical_not_docstring
728+
729+
# U25: ==== NEGATIVE (x)
730+
_negative_docstring_ = r"""
731+
negative(x, /, \*, out=None, order='K')
732+
733+
Computes the numerical negative for each element `x_i` of input array `x`.
734+
735+
Args:
736+
x (usm_ndarray):
737+
Input array, expected to have a numeric data type.
738+
out (usm_ndarray):
739+
Output array to populate. Array must have the correct
740+
shape and the expected data type.
741+
order ("C","F","A","K", optional): memory layout of the new
742+
output array, if parameter `out` is ``None``.
743+
Default: "K".
744+
745+
Returns:
746+
usm_ndarray:
747+
An array containing the negative of `x`.
748+
"""
749+
750+
# Public `negative` callable, built as a UnaryElementwiseFunc from the
# `ti._negative_result_type` and `ti._negative` entry points of the
# `_tensor_elementwise_impl` extension together with the docstring above.
# Unlike `logical_not` and `positive`, it also passes
# `acceptance_fn=_acceptance_fn_negative` (imported from `._type_utils`).
# NOTE(review): presumably the acceptance function restricts which input
# types / type promotions are allowed -- confirm in `_type_utils`.
negative = UnaryElementwiseFunc(
751+
"negative",
752+
ti._negative_result_type,
753+
ti._negative,
754+
_negative_docstring_,
755+
acceptance_fn=_acceptance_fn_negative,
756+
)
757+
# Remove the temporary docstring variable from the module namespace.
del _negative_docstring_
758+
759+
# U26: ==== POSITIVE (x)
760+
_positive_docstring_ = r"""
761+
positive(x, /, \*, out=None, order='K')
762+
763+
Computes the numerical positive for each element `x_i` of input array `x`.
764+
765+
Args:
766+
x (usm_ndarray):
767+
Input array, expected to have a numeric data type.
768+
out (usm_ndarray):
769+
Output array to populate. Array must have the correct
770+
shape and the expected data type.
771+
order ("C","F","A","K", optional): memory layout of the new
772+
output array, if parameter `out` is ``None``.
773+
Default: "K".
774+
775+
Returns:
776+
usm_ndarray:
777+
An array containing the positive of `x`.
778+
"""
779+
780+
# Public `positive` callable, built as a UnaryElementwiseFunc from the
# `ti._positive_result_type` and `ti._positive` entry points of the
# `_tensor_elementwise_impl` extension together with the docstring above.
# NOTE(review): presumably the positional arguments are (name, result-type
# query, implementation, docstring) -- confirm against UnaryElementwiseFunc.
positive = UnaryElementwiseFunc(
781+
"positive", ti._positive_result_type, ti._positive, _positive_docstring_
782+
)
783+
# Remove the temporary docstring variable from the module namespace.
del _positive_docstring_
784+
697785
# U43: ==== ANGLE (x)
698786
_angle_docstring = r"""
699787
angle(x, /, \*, out=None, order='K')

dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
2626
// THE POSSIBILITY OF SUCH DAMAGE.
2727
//*****************************************************************************
28-
///
28+
//
29+
//===---------------------------------------------------------------------===//
2930
/// \file
3031
/// This file defines kernels for elementwise evaluation of LOGADDEXP(x1, x2)
3132
/// function.
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of LOGICAL_NOT(x)
33+
/// function that computes the logical negation of each tensor element.
34+
//===---------------------------------------------------------------------===//
35+
36+
#pragma once
37+
#include <cstddef>
38+
#include <type_traits>
39+
#include <vector>
40+
41+
#include <sycl/sycl.hpp>
42+
43+
#include "vec_size_util.hpp"
44+
45+
#include "kernels/dpctl_tensor_types.hpp"
46+
#include "kernels/elementwise_functions/common.hpp"
47+
48+
#include "utils/type_dispatch_building.hpp"
49+
#include "utils/type_utils.hpp"
50+
51+
namespace dpctl::tensor::kernels::logical_not
52+
{
53+
54+
using dpctl::tensor::ssize_t;
55+
namespace td_ns = dpctl::tensor::type_dispatch;
56+
namespace tu_ns = dpctl::tensor::type_utils;
57+
58+
/// @brief Element-wise LOGICAL_NOT functor.
///
/// The call operator converts its argument to bool through
/// tu_ns::convert_impl<bool, argT> and returns the negation of that value.
/// resT is required to be bool (enforced by the static_assert below).
///
/// @tparam argT Type of the input element.
/// @tparam resT Type of the result; must be bool.
template <typename argT, typename resT>
59+
struct LogicalNotFunctor
60+
{
61+
static_assert(std::is_same_v<resT, bool>);
62+
63+
// is_constant is false_type: the functor does not declare a constant
// result (hence constant_value below stays commented out)
using is_constant = typename std::false_type;
64+
// constexpr resT constant_value = resT{};
65+
// supports_vec is false_type: only the scalar operator() below is defined
using supports_vec = typename std::false_type;
66+
// true_type unless tu_ns::is_complex reports resT or argT as complex
using supports_sg_loadstore = typename std::negation<
67+
std::disjunction<tu_ns::is_complex<resT>, tu_ns::is_complex<argT>>>;
68+
69+
/// @brief Returns the logical negation of `in` after conversion to bool.
resT operator()(const argT &in) const
70+
{
71+
using tu_ns::convert_impl;
72+
return !convert_impl<bool, argT>(in);
73+
}
74+
};
75+
76+
/// @brief Alias of elementwise_common::UnaryContigFunctor instantiated with
/// LogicalNotFunctor<argT, resT> as the per-element operation.
///
/// Defaults: resT = bool, vec_sz = 4, n_vecs = 2, enable_sg_loadstore = true.
/// NOTE(review): the meaning of vec_sz / n_vecs / enable_sg_loadstore is
/// defined by UnaryContigFunctor in common.hpp, which is not shown here.
template <typename argT,
77+
typename resT = bool,
78+
std::uint8_t vec_sz = 4u,
79+
std::uint8_t n_vecs = 2u,
80+
bool enable_sg_loadstore = true>
81+
using LogicalNotContigFunctor =
82+
elementwise_common::UnaryContigFunctor<argT,
83+
resT,
84+
LogicalNotFunctor<argT, resT>,
85+
vec_sz,
86+
n_vecs,
87+
enable_sg_loadstore>;
88+
89+
/// @brief Alias of elementwise_common::UnaryStridedFunctor instantiated with
/// the given indexer type and LogicalNotFunctor<argTy, resTy> as the
/// per-element operation.
template <typename argTy, typename resTy, typename IndexerT>
90+
using LogicalNotStridedFunctor =
91+
elementwise_common::UnaryStridedFunctor<argTy,
92+
resTy,
93+
IndexerT,
94+
LogicalNotFunctor<argTy, resTy>>;
95+
96+
/// @brief Type map from argument type to result type for LOGICAL_NOT.
///
/// There are no specializations in this header: value_type is bool for
/// every argTy.
template <typename argTy>
97+
struct LogicalNotOutputType
98+
{
99+
using value_type = bool;
100+
};
101+
102+
// Selection of the vec_sz / n_vecs template parameters used by the contiguous
// LOGICAL_NOT kernel.
namespace hyperparam_detail
103+
{
104+
105+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
106+
107+
using vsu_ns::ContigHyperparameterSetDefault;
108+
// NOTE(review): UnaryContigHyperparameterSetEntry is brought into scope but
// not referenced anywhere inside this namespace.
using vsu_ns::UnaryContigHyperparameterSetEntry;
109+
110+
/// @brief Exposes vec_sz and n_vecs for a given argument type.
///
/// value_type is a std::disjunction over a single alternative,
/// ContigHyperparameterSetDefault<4u, 2u>, so the same entry is selected for
/// every argTy; vec_sz and n_vecs are read from that entry.
template <typename argTy>
111+
struct LogicalNotContigHyperparameterSet
112+
{
113+
using value_type =
114+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
115+
116+
constexpr static auto vec_sz = value_type::vec_sz;
117+
constexpr static auto n_vecs = value_type::n_vecs;
118+
};
119+
120+
} // end of namespace hyperparam_detail
121+
122+
/// Declaration-only class template; it is passed as a template argument to
/// elementwise_common::unary_contig_impl by logical_not_contig_impl.
/// NOTE(review): presumably serves as the SYCL kernel name -- confirm in
/// common.hpp.
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
123+
class logical_not_contig_kernel;
124+
125+
/// @brief Forwards a contiguous LOGICAL_NOT request to
/// elementwise_common::unary_contig_impl.
///
/// vec_sz and n_vecs are taken from
/// hyperparam_detail::LogicalNotContigHyperparameterSet<argTy>; all function
/// arguments are passed through unchanged.
///
/// @tparam argTy Element type of the input.
/// @param exec_q  Queue forwarded to unary_contig_impl.
/// @param nelems  Element count forwarded to unary_contig_impl.
/// @param arg_p   Type-erased pointer to the input data.
/// @param res_p   Type-erased pointer to the output data.
/// @param depends Events forwarded to unary_contig_impl (defaults to empty).
/// @return The sycl::event returned by unary_contig_impl.
template <typename argTy>
126+
sycl::event
127+
logical_not_contig_impl(sycl::queue &exec_q,
128+
std::size_t nelems,
129+
const char *arg_p,
130+
char *res_p,
131+
const std::vector<sycl::event> &depends = {})
132+
{
133+
using LogicalNotHS =
134+
hyperparam_detail::LogicalNotContigHyperparameterSet<argTy>;
135+
static constexpr std::uint8_t vec_sz = LogicalNotHS::vec_sz;
136+
static constexpr std::uint8_t n_vecs = LogicalNotHS::n_vecs;
137+
138+
return elementwise_common::unary_contig_impl<
139+
argTy, LogicalNotOutputType, LogicalNotContigFunctor,
140+
logical_not_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg_p, res_p,
141+
depends);
142+
}
143+
144+
/// @brief Factory whose get() returns logical_not_contig_impl<T> converted to
/// the function-pointer type fnT.
/// NOTE(review): presumably consumed by the type-dispatch table builders in
/// utils/type_dispatch_building.hpp -- confirm at the use site.
template <typename fnT, typename T>
145+
struct LogicalNotContigFactory
146+
{
147+
fnT get()
148+
{
149+
fnT fn = logical_not_contig_impl<T>;
150+
return fn;
151+
}
152+
};
153+
154+
/// @brief Factory reporting the type id of the LOGICAL_NOT result type for an
/// argument of type T.
template <typename fnT, typename T>
155+
struct LogicalNotTypeMapFactory
156+
{
157+
/*! @brief get typeid of LogicalNotOutputType<T>::value_type, i.e. of the
 *  result type of LOGICAL_NOT for an argument of type T; the return type is
 *  only well-formed when fnT is int */
158+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
159+
{
160+
using rT = typename LogicalNotOutputType<T>::value_type;
161+
return td_ns::GetTypeid<rT>{}.get();
162+
}
163+
};
164+
165+
/// Declaration-only class template; it is passed as a template argument to
/// elementwise_common::unary_strided_impl by logical_not_strided_impl.
/// NOTE(review): presumably serves as the SYCL kernel name -- confirm in
/// common.hpp.
template <typename T1, typename T2, typename T3>
166+
class logical_not_strided_kernel;
167+
168+
/// @brief Forwards a strided LOGICAL_NOT request to
/// elementwise_common::unary_strided_impl; every argument is passed through
/// unchanged.
///
/// @tparam argTy Element type of the input.
/// @param exec_q             Queue forwarded to unary_strided_impl.
/// @param nelems             Element count forwarded to unary_strided_impl.
/// @param nd                 Number of dimensions forwarded to the callee.
/// @param shape_and_strides  Pointer to packed shape/stride data; layout is
///                           defined by the callee -- TODO confirm.
/// @param arg_p              Type-erased pointer to the input data.
/// @param arg_offset         Offset applied to the input by the callee.
/// @param res_p              Type-erased pointer to the output data.
/// @param res_offset         Offset applied to the output by the callee.
/// @param depends            Events forwarded to unary_strided_impl.
/// @param additional_depends Further events forwarded to unary_strided_impl.
/// @return The sycl::event returned by unary_strided_impl.
template <typename argTy>
169+
sycl::event
170+
logical_not_strided_impl(sycl::queue &exec_q,
171+
std::size_t nelems,
172+
int nd,
173+
const ssize_t *shape_and_strides,
174+
const char *arg_p,
175+
ssize_t arg_offset,
176+
char *res_p,
177+
ssize_t res_offset,
178+
const std::vector<sycl::event> &depends,
179+
const std::vector<sycl::event> &additional_depends)
180+
{
181+
return elementwise_common::unary_strided_impl<argTy, LogicalNotOutputType,
182+
LogicalNotStridedFunctor,
183+
logical_not_strided_kernel>(
184+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
185+
res_offset, depends, additional_depends);
186+
}
187+
188+
/// @brief Factory whose get() returns logical_not_strided_impl<T> converted
/// to the function-pointer type fnT.
/// NOTE(review): presumably consumed by the type-dispatch table builders in
/// utils/type_dispatch_building.hpp -- confirm at the use site.
template <typename fnT, typename T>
189+
struct LogicalNotStridedFactory
190+
{
191+
fnT get()
192+
{
193+
fnT fn = logical_not_strided_impl<T>;
194+
return fn;
195+
}
196+
};
197+
198+
} // namespace dpctl::tensor::kernels::logical_not

0 commit comments

Comments
 (0)