Skip to content

Commit 28288ae

Browse files
Move ti.logaddexp() and reuse it
1 parent fefaa17 commit 28288ae

File tree

7 files changed

+284
-63
lines changed

7 files changed

+284
-63
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ set(_elementwise_sources
113113
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log1p.cpp
114114
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log2.cpp
115115
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log10.cpp
116-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logaddexp.cpp
116+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logaddexp.cpp
117117
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_and.cpp
118118
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_not.cpp
119119
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_or.cpp

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,43 @@
12061206
)
12071207
del _log10_docstring_
12081208

1209+
# B15: ==== LOGADDEXP (x1, x2)
1210+
_logaddexp_docstring_ = r"""
1211+
logaddexp(x1, x2, /, \*, out=None, order='K')
1212+
1213+
Calculates the natural logarithm of the sum of exponentials for each element
1214+
`x1_i` of the input array `x1` with the respective element `x2_i` of the input
1215+
array `x2`.
1216+
1217+
This function calculates `log(exp(x1) + exp(x2))` more accurately for small
1218+
values of `x`.
1219+
1220+
Args:
1221+
x1 (usm_ndarray):
1222+
First input array, expected to have a real-valued floating-point data
1223+
type.
1224+
x2 (usm_ndarray):
1225+
Second input array, also expected to have a real-valued floating-point
1226+
data type.
1227+
out (Union[usm_ndarray, None], optional):
1228+
Output array to populate.
1229+
Array must have the correct shape and the expected data type.
1230+
order ("C","F","A","K", optional):
1231+
Memory layout of the new output array, if parameter
1232+
`out` is ``None``.
1233+
Default: "K".
1234+
1235+
Returns:
1236+
usm_ndarray:
1237+
An array containing the element-wise results. The data type
1238+
of the returned array is determined by the Type Promotion Rules.
1239+
"""
1240+
1241+
logaddexp = BinaryElementwiseFunc(
1242+
"logaddexp", ti._logaddexp_result_type, ti._logaddexp, _logaddexp_docstring_
1243+
)
1244+
del _logaddexp_docstring_
1245+
12091246
# U24: ==== LOGICAL_NOT (x)
12101247
_logical_not_docstring = r"""
12111248
logical_not(x, /, \*, out=None, order='K')

dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,13 @@
4949
#include "utils/math_utils.hpp"
5050
#include "utils/offset_utils.hpp"
5151
#include "utils/type_dispatch_building.hpp"
52-
#include "utils/type_utils.hpp"
5352

5453
#include "kernels/dpctl_tensor_types.hpp"
5554

5655
namespace dpctl::tensor::kernels::logaddexp
5756
{
5857
using dpctl::tensor::ssize_t;
5958
namespace td_ns = dpctl::tensor::type_dispatch;
60-
namespace tu_ns = dpctl::tensor::type_utils;
61-
62-
using dpctl::tensor::type_utils::is_complex;
63-
using dpctl::tensor::type_utils::vec_cast;
6459

6560
template <typename argT1, typename argT2, typename resT>
6661
struct LogAddExpFunctor

dpctl_ext/tensor/libtensor/source/elementwise_functions/elementwise_common.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
#include "log10.hpp"
7777
#include "log1p.hpp"
7878
#include "log2.hpp"
79-
// #include "logaddexp.hpp"
79+
#include "logaddexp.hpp"
8080
// #include "logical_and.hpp"
8181
#include "logical_not.hpp"
8282
// #include "logical_or.hpp"
@@ -157,7 +157,7 @@ void init_elementwise_functions(py::module_ m)
157157
init_log10(m);
158158
init_log1p(m);
159159
init_log2(m);
160-
// init_logaddexp(m);
160+
init_logaddexp(m);
161161
// init_logical_and(m);
162162
init_logical_not(m);
163163
// init_logical_or(m);
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_elementwise_impl
33+
/// extension, specifically functions for elementwise operations.
34+
//===---------------------------------------------------------------------===//
35+
36+
#include <vector>
37+
38+
#include <sycl/sycl.hpp>
39+
40+
#include "dpnp4pybind11.hpp"
41+
#include <pybind11/numpy.h>
42+
#include <pybind11/pybind11.h>
43+
#include <pybind11/stl.h>
44+
45+
#include "elementwise_functions.hpp"
46+
#include "logaddexp.hpp"
47+
#include "utils/type_dispatch.hpp"
48+
49+
#include "kernels/elementwise_functions/common.hpp"
50+
#include "kernels/elementwise_functions/logaddexp.hpp"
51+
52+
namespace dpctl::tensor::py_internal
53+
{
54+
55+
namespace py = pybind11;
56+
namespace td_ns = dpctl::tensor::type_dispatch;
57+
58+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
59+
using ew_cmn_ns::binary_contig_impl_fn_ptr_t;
60+
using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t;
61+
using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t;
62+
using ew_cmn_ns::binary_strided_impl_fn_ptr_t;
63+
64+
// B15: ===== LOGADDEXP (x1, x2)
65+
namespace impl
66+
{
67+
namespace logaddexp_fn_ns = dpctl::tensor::kernels::logaddexp;
68+
69+
static binary_contig_impl_fn_ptr_t
70+
logaddexp_contig_dispatch_table[td_ns::num_types][td_ns::num_types];
71+
static int logaddexp_output_id_table[td_ns::num_types][td_ns::num_types];
72+
73+
static binary_strided_impl_fn_ptr_t
74+
logaddexp_strided_dispatch_table[td_ns::num_types][td_ns::num_types];
75+
76+
void populate_logaddexp_dispatch_tables(void)
77+
{
78+
using namespace td_ns;
79+
namespace fn_ns = logaddexp_fn_ns;
80+
81+
// which input types are supported, and what is the type of the result
82+
using fn_ns::LogAddExpTypeMapFactory;
83+
DispatchTableBuilder<int, LogAddExpTypeMapFactory, num_types> dtb1;
84+
dtb1.populate_dispatch_table(logaddexp_output_id_table);
85+
86+
// function pointers for operation on general strided arrays
87+
using fn_ns::LogAddExpStridedFactory;
88+
DispatchTableBuilder<binary_strided_impl_fn_ptr_t, LogAddExpStridedFactory,
89+
num_types>
90+
dtb2;
91+
dtb2.populate_dispatch_table(logaddexp_strided_dispatch_table);
92+
93+
// function pointers for operation on contiguous inputs and output
94+
using fn_ns::LogAddExpContigFactory;
95+
DispatchTableBuilder<binary_contig_impl_fn_ptr_t, LogAddExpContigFactory,
96+
num_types>
97+
dtb3;
98+
dtb3.populate_dispatch_table(logaddexp_contig_dispatch_table);
99+
};
100+
101+
} // namespace impl
102+
103+
void init_logaddexp(py::module_ m)
104+
{
105+
using arrayT = dpctl::tensor::usm_ndarray;
106+
using event_vecT = std::vector<sycl::event>;
107+
{
108+
impl::populate_logaddexp_dispatch_tables();
109+
using impl::logaddexp_contig_dispatch_table;
110+
using impl::logaddexp_output_id_table;
111+
using impl::logaddexp_strided_dispatch_table;
112+
113+
auto logaddexp_pyapi = [&](const arrayT &src1, const arrayT &src2,
114+
const arrayT &dst, sycl::queue &exec_q,
115+
const event_vecT &depends = {}) {
116+
return py_binary_ufunc(
117+
src1, src2, dst, exec_q, depends, logaddexp_output_id_table,
118+
// function pointers to handle operation on contiguous arrays
119+
// (pointers may be nullptr)
120+
logaddexp_contig_dispatch_table,
121+
// function pointers to handle operation on strided arrays (most
122+
// general case)
123+
logaddexp_strided_dispatch_table,
124+
// function pointers to handle operation of c-contig matrix and
125+
// c-contig row with broadcasting (may be nullptr)
126+
td_ns::NullPtrTable<
127+
binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{},
128+
// function pointers to handle operation of c-contig matrix and
129+
// c-contig row with broadcasting (may be nullptr)
130+
td_ns::NullPtrTable<
131+
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{});
132+
};
133+
auto logaddexp_result_type_pyapi = [&](const py::dtype &dtype1,
134+
const py::dtype &dtype2) {
135+
return py_binary_ufunc_result_type(dtype1, dtype2,
136+
logaddexp_output_id_table);
137+
};
138+
m.def("_logaddexp", logaddexp_pyapi, "", py::arg("src1"),
139+
py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"),
140+
py::arg("depends") = py::list());
141+
m.def("_logaddexp_result_type", logaddexp_result_type_pyapi, "");
142+
}
143+
}
144+
145+
} // namespace dpctl::tensor::py_internal
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_elementwise_impl
33+
/// extension, specifically functions for elementwise operations.
34+
//===---------------------------------------------------------------------===//
35+
36+
#pragma once
37+
#include <pybind11/pybind11.h>
38+
39+
namespace py = pybind11;
40+
41+
namespace dpctl::tensor::py_internal
42+
{
43+
44+
extern void init_logaddexp(py::module_ m);
45+
46+
} // namespace dpctl::tensor::py_internal

0 commit comments

Comments
 (0)