Skip to content

Commit 3041d7d

Browse files
Move _logsumexp/hypot_over_axis to _tensor_reductions_impl
1 parent 6b27b1b commit 3041d7d

File tree

6 files changed

+612
-6
lines changed

6 files changed

+612
-6
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ set(_reduction_sources
7575
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/any.cpp
7676
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmax.cpp
7777
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmin.cpp
78-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/logsumexp.cpp
78+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/logsumexp.cpp
7979
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/max.cpp
8080
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/min.cpp
8181
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/prod.cpp
82-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduce_hypot.cpp
82+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduce_hypot.cpp
8383
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
8484
)
8585
set(_sorting_sources
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_reductions_impl
33+
/// extension.
34+
//===---------------------------------------------------------------------===//
35+
36+
#include <cstdint>
37+
#include <type_traits>
38+
#include <vector>
39+
40+
#include <sycl/sycl.hpp>
41+
42+
#include "dpnp4pybind11.hpp"
43+
#include <pybind11/pybind11.h>
44+
#include <pybind11/stl.h>
45+
46+
#include "kernels/reductions.hpp"
47+
#include "reduction_over_axis.hpp"
48+
#include "utils/sycl_utils.hpp"
49+
#include "utils/type_dispatch_building.hpp"
50+
51+
namespace dpctl::tensor::py_internal
52+
{
53+
54+
namespace py = pybind11;
55+
namespace td_ns = dpctl::tensor::type_dispatch;
56+
namespace su_ns = dpctl::tensor::sycl_utils;
57+
58+
namespace impl
59+
{
60+
61+
using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr;
62+
static reduction_strided_impl_fn_ptr
63+
logsumexp_over_axis_strided_temps_dispatch_table[td_ns::num_types]
64+
[td_ns::num_types];
65+
66+
using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr;
67+
static reduction_contig_impl_fn_ptr
68+
logsumexp_over_axis1_contig_temps_dispatch_table[td_ns::num_types]
69+
[td_ns::num_types];
70+
static reduction_contig_impl_fn_ptr
71+
logsumexp_over_axis0_contig_temps_dispatch_table[td_ns::num_types]
72+
[td_ns::num_types];
73+
74+
template <typename argTy, typename outTy>
75+
struct TypePairSupportDataForLogSumExpReductionTemps
76+
{
77+
78+
static constexpr bool is_defined = std::disjunction<
79+
#if 1
80+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, sycl::half>,
81+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, float>,
82+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, double>,
83+
84+
// input int8_t
85+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, sycl::half>,
86+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, float>,
87+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, double>,
88+
89+
// input uint8_t
90+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, sycl::half>,
91+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, float>,
92+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, double>,
93+
94+
// input int16_t
95+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, float>,
96+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, double>,
97+
98+
// input uint16_t
99+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, float>,
100+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, double>,
101+
102+
// input int32_t
103+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, float>,
104+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, double>,
105+
106+
// input uint32_t
107+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, float>,
108+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, double>,
109+
110+
// input int64_t
111+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, float>,
112+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
113+
114+
// input uint64_t
115+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, float>,
116+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
117+
// input half
118+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
119+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float>,
120+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, double>,
121+
122+
// input float
123+
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
124+
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
125+
126+
// input double
127+
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
128+
#endif
129+
130+
// fall-through
131+
td_ns::NotDefinedEntry>::is_defined;
132+
};
133+
134+
template <typename fnT, typename srcTy, typename dstTy>
135+
struct LogSumExpOverAxisTempsStridedFactory
136+
{
137+
fnT get() const
138+
{
139+
if constexpr (TypePairSupportDataForLogSumExpReductionTemps<
140+
srcTy, dstTy>::is_defined)
141+
{
142+
using ReductionOpT = su_ns::LogSumExp<dstTy>;
143+
return dpctl::tensor::kernels::
144+
reduction_over_group_temps_strided_impl<srcTy, dstTy,
145+
ReductionOpT>;
146+
}
147+
else {
148+
return nullptr;
149+
}
150+
}
151+
};
152+
153+
template <typename fnT, typename srcTy, typename dstTy>
154+
struct LogSumExpOverAxis1TempsContigFactory
155+
{
156+
fnT get() const
157+
{
158+
if constexpr (TypePairSupportDataForLogSumExpReductionTemps<
159+
srcTy, dstTy>::is_defined)
160+
{
161+
using ReductionOpT = su_ns::LogSumExp<dstTy>;
162+
return dpctl::tensor::kernels::
163+
reduction_axis1_over_group_temps_contig_impl<srcTy, dstTy,
164+
ReductionOpT>;
165+
}
166+
else {
167+
return nullptr;
168+
}
169+
}
170+
};
171+
172+
template <typename fnT, typename srcTy, typename dstTy>
173+
struct LogSumExpOverAxis0TempsContigFactory
174+
{
175+
fnT get() const
176+
{
177+
if constexpr (TypePairSupportDataForLogSumExpReductionTemps<
178+
srcTy, dstTy>::is_defined)
179+
{
180+
using ReductionOpT = su_ns::LogSumExp<dstTy>;
181+
return dpctl::tensor::kernels::
182+
reduction_axis0_over_group_temps_contig_impl<srcTy, dstTy,
183+
ReductionOpT>;
184+
}
185+
else {
186+
return nullptr;
187+
}
188+
}
189+
};
190+
191+
void populate_logsumexp_over_axis_dispatch_tables(void)
192+
{
193+
using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr;
194+
using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr;
195+
using namespace td_ns;
196+
197+
DispatchTableBuilder<reduction_strided_impl_fn_ptr,
198+
LogSumExpOverAxisTempsStridedFactory, num_types>
199+
dtb1;
200+
dtb1.populate_dispatch_table(
201+
logsumexp_over_axis_strided_temps_dispatch_table);
202+
203+
DispatchTableBuilder<reduction_contig_impl_fn_ptr,
204+
LogSumExpOverAxis1TempsContigFactory, td_ns::num_types>
205+
dtb2;
206+
dtb2.populate_dispatch_table(
207+
logsumexp_over_axis1_contig_temps_dispatch_table);
208+
209+
DispatchTableBuilder<reduction_contig_impl_fn_ptr,
210+
LogSumExpOverAxis0TempsContigFactory, td_ns::num_types>
211+
dtb3;
212+
dtb3.populate_dispatch_table(
213+
logsumexp_over_axis0_contig_temps_dispatch_table);
214+
}
215+
216+
} // namespace impl
217+
218+
void init_logsumexp(py::module_ m)
219+
{
220+
using arrayT = dpctl::tensor::usm_ndarray;
221+
using event_vecT = std::vector<sycl::event>;
222+
{
223+
using impl::populate_logsumexp_over_axis_dispatch_tables;
224+
populate_logsumexp_over_axis_dispatch_tables();
225+
using impl::logsumexp_over_axis0_contig_temps_dispatch_table;
226+
using impl::logsumexp_over_axis1_contig_temps_dispatch_table;
227+
using impl::logsumexp_over_axis_strided_temps_dispatch_table;
228+
229+
using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr;
230+
using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr;
231+
232+
auto logsumexp_pyapi = [&](const arrayT &src,
233+
int trailing_dims_to_reduce,
234+
const arrayT &dst, sycl::queue &exec_q,
235+
const event_vecT &depends = {}) {
236+
using dpctl::tensor::py_internal::py_tree_reduction_over_axis;
237+
return py_tree_reduction_over_axis(
238+
src, trailing_dims_to_reduce, dst, exec_q, depends,
239+
logsumexp_over_axis_strided_temps_dispatch_table,
240+
logsumexp_over_axis0_contig_temps_dispatch_table,
241+
logsumexp_over_axis1_contig_temps_dispatch_table);
242+
};
243+
m.def("_logsumexp_over_axis", logsumexp_pyapi, "", py::arg("src"),
244+
py::arg("trailing_dims_to_reduce"), py::arg("dst"),
245+
py::arg("sycl_queue"), py::arg("depends") = py::list());
246+
247+
auto logsumexp_dtype_supported = [&](const py::dtype &input_dtype,
248+
const py::dtype &output_dtype) {
249+
using dpctl::tensor::py_internal::py_tree_reduction_dtype_supported;
250+
return py_tree_reduction_dtype_supported(
251+
input_dtype, output_dtype,
252+
logsumexp_over_axis_strided_temps_dispatch_table);
253+
};
254+
m.def("_logsumexp_over_axis_dtype_supported", logsumexp_dtype_supported,
255+
"", py::arg("arg_dtype"), py::arg("out_dtype"));
256+
}
257+
}
258+
259+
} // namespace dpctl::tensor::py_internal
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_reductions_impl
33+
/// extension.
34+
//===---------------------------------------------------------------------===//
35+
36+
#pragma once
37+
#include <pybind11/pybind11.h>
38+
39+
namespace py = pybind11;
40+
41+
namespace dpctl::tensor::py_internal
42+
{
43+
44+
extern void init_logsumexp(py::module_ m);
45+
46+
} // namespace dpctl::tensor::py_internal

0 commit comments

Comments
 (0)