Skip to content

Commit 72d2109

Browse files
Move _cumlogsumexp_over_axis to dpctl_ext.tensor._tensor_accumulation_impl
1 parent 668f3fb commit 72d2109

File tree

4 files changed

+400
-3
lines changed

4 files changed

+400
-3
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ set(_tensor_impl_sources
6565
)
6666
# Accumulator kernels compiled into the tensor extension.
set(_accumulator_sources
    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
)

dpctl_ext/tensor/libtensor/source/accumulators/accumulators_common.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
#include <pybind11/pybind11.h>
3737

38-
// #include "cumulative_logsumexp.hpp"
38+
#include "cumulative_logsumexp.hpp"
3939
#include "cumulative_prod.hpp"
4040
#include "cumulative_sum.hpp"
4141

@@ -47,7 +47,7 @@ namespace dpctl::tensor::py_internal
4747
/*! @brief Add accumulators to Python module */
4848
void init_accumulator_functions(py::module_ m)
4949
{
50-
// init_cumulative_logsumexp(m);
50+
init_cumulative_logsumexp(m);
5151
init_cumulative_prod(m);
5252
init_cumulative_sum(m);
5353
}
Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===----------------------------------------------------------------------===//
30+
///
/// \file
/// This file defines functions of dpctl_ext.tensor._tensor_accumulation_impl
/// extensions
//===----------------------------------------------------------------------===//
35+
36+
#include <cstdint>
37+
#include <type_traits>
38+
#include <vector>
39+
40+
#include <sycl/sycl.hpp>
41+
42+
#include "dpnp4pybind11.hpp"
43+
#include <pybind11/numpy.h>
44+
#include <pybind11/pybind11.h>
45+
#include <pybind11/stl.h>
46+
47+
#include "accumulate_over_axis.hpp"
48+
#include "kernels/accumulators.hpp"
49+
#include "utils/sycl_utils.hpp"
50+
#include "utils/type_dispatch_building.hpp"
51+
52+
namespace py = pybind11;
53+
54+
namespace dpctl::tensor::py_internal
55+
{
56+
57+
namespace su_ns = dpctl::tensor::sycl_utils;
58+
namespace td_ns = dpctl::tensor::type_dispatch;
59+
60+
namespace impl
61+
{
62+
63+
using dpctl::tensor::kernels::accumulators::accumulate_1d_contig_impl_fn_ptr_t;
64+
static accumulate_1d_contig_impl_fn_ptr_t
65+
cumlogsumexp_1d_contig_dispatch_table[td_ns::num_types][td_ns::num_types];
66+
67+
using dpctl::tensor::kernels::accumulators::accumulate_strided_impl_fn_ptr_t;
68+
static accumulate_strided_impl_fn_ptr_t
69+
cumlogsumexp_strided_dispatch_table[td_ns::num_types][td_ns::num_types];
70+
71+
static accumulate_1d_contig_impl_fn_ptr_t
72+
cumlogsumexp_1d_include_initial_contig_dispatch_table[td_ns::num_types]
73+
[td_ns::num_types];
74+
75+
static accumulate_strided_impl_fn_ptr_t
76+
cumlogsumexp_include_initial_strided_dispatch_table[td_ns::num_types]
77+
[td_ns::num_types];
78+
79+
template <typename argTy, typename outTy>
80+
struct TypePairSupportDataForLogSumExpAccumulation
81+
{
82+
static constexpr bool is_defined = std::disjunction<
83+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, sycl::half>,
84+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, float>,
85+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, double>,
86+
87+
// input int8_t
88+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, sycl::half>,
89+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, float>,
90+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, double>,
91+
92+
// input uint8_t
93+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, sycl::half>,
94+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, float>,
95+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, double>,
96+
97+
// input int16_t
98+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, float>,
99+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, double>,
100+
101+
// input uint16_t
102+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, float>,
103+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, double>,
104+
105+
// input int32_t
106+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, float>,
107+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, double>,
108+
109+
// input uint32_t
110+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, float>,
111+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, double>,
112+
113+
// input int64_t
114+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, float>,
115+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
116+
117+
// input uint64_t
118+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, float>,
119+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
120+
121+
// input half
122+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
123+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float>,
124+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, double>,
125+
126+
// input float
127+
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
128+
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
129+
130+
// input double
131+
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
132+
133+
// fall-through
134+
td_ns::NotDefinedEntry>::is_defined;
135+
};
136+
137+
template <typename fnT, typename srcTy, typename dstTy>
138+
struct CumLogSumExp1DContigFactory
139+
{
140+
fnT get()
141+
{
142+
if constexpr (TypePairSupportDataForLogSumExpAccumulation<
143+
srcTy, dstTy>::is_defined)
144+
{
145+
using ScanOpT = su_ns::LogSumExp<dstTy>;
146+
static constexpr bool include_initial = false;
147+
if constexpr (std::is_same_v<srcTy, dstTy>) {
148+
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
149+
fnT fn = dpctl::tensor::kernels::accumulators::
150+
accumulate_1d_contig_impl<srcTy, dstTy,
151+
NoOpTransformer<dstTy>, ScanOpT,
152+
include_initial>;
153+
return fn;
154+
}
155+
else {
156+
using dpctl::tensor::kernels::accumulators::CastTransformer;
157+
fnT fn = dpctl::tensor::kernels::accumulators::
158+
accumulate_1d_contig_impl<srcTy, dstTy,
159+
CastTransformer<srcTy, dstTy>,
160+
ScanOpT, include_initial>;
161+
return fn;
162+
}
163+
}
164+
else {
165+
return nullptr;
166+
}
167+
}
168+
};
169+
170+
template <typename fnT, typename srcTy, typename dstTy>
171+
struct CumLogSumExp1DIncludeInitialContigFactory
172+
{
173+
fnT get()
174+
{
175+
if constexpr (TypePairSupportDataForLogSumExpAccumulation<
176+
srcTy, dstTy>::is_defined)
177+
{
178+
using ScanOpT = su_ns::LogSumExp<dstTy>;
179+
static constexpr bool include_initial = true;
180+
if constexpr (std::is_same_v<srcTy, dstTy>) {
181+
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
182+
fnT fn = dpctl::tensor::kernels::accumulators::
183+
accumulate_1d_contig_impl<srcTy, dstTy,
184+
NoOpTransformer<dstTy>, ScanOpT,
185+
include_initial>;
186+
return fn;
187+
}
188+
else {
189+
using dpctl::tensor::kernels::accumulators::CastTransformer;
190+
fnT fn = dpctl::tensor::kernels::accumulators::
191+
accumulate_1d_contig_impl<srcTy, dstTy,
192+
CastTransformer<srcTy, dstTy>,
193+
ScanOpT, include_initial>;
194+
return fn;
195+
}
196+
}
197+
else {
198+
return nullptr;
199+
}
200+
}
201+
};
202+
203+
template <typename fnT, typename srcTy, typename dstTy>
204+
struct CumLogSumExpStridedFactory
205+
{
206+
fnT get()
207+
{
208+
if constexpr (TypePairSupportDataForLogSumExpAccumulation<
209+
srcTy, dstTy>::is_defined)
210+
{
211+
using ScanOpT = su_ns::LogSumExp<dstTy>;
212+
static constexpr bool include_initial = false;
213+
if constexpr (std::is_same_v<srcTy, dstTy>) {
214+
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
215+
fnT fn = dpctl::tensor::kernels::accumulators::
216+
accumulate_strided_impl<srcTy, dstTy,
217+
NoOpTransformer<dstTy>, ScanOpT,
218+
include_initial>;
219+
return fn;
220+
}
221+
else {
222+
using dpctl::tensor::kernels::accumulators::CastTransformer;
223+
fnT fn = dpctl::tensor::kernels::accumulators::
224+
accumulate_strided_impl<srcTy, dstTy,
225+
CastTransformer<srcTy, dstTy>,
226+
ScanOpT, include_initial>;
227+
return fn;
228+
}
229+
}
230+
else {
231+
return nullptr;
232+
}
233+
}
234+
};
235+
236+
template <typename fnT, typename srcTy, typename dstTy>
237+
struct CumLogSumExpIncludeInitialStridedFactory
238+
{
239+
fnT get()
240+
{
241+
if constexpr (TypePairSupportDataForLogSumExpAccumulation<
242+
srcTy, dstTy>::is_defined)
243+
{
244+
using ScanOpT = su_ns::LogSumExp<dstTy>;
245+
static constexpr bool include_initial = true;
246+
if constexpr (std::is_same_v<srcTy, dstTy>) {
247+
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
248+
fnT fn = dpctl::tensor::kernels::accumulators::
249+
accumulate_strided_impl<srcTy, dstTy,
250+
NoOpTransformer<dstTy>, ScanOpT,
251+
include_initial>;
252+
return fn;
253+
}
254+
else {
255+
using dpctl::tensor::kernels::accumulators::CastTransformer;
256+
fnT fn = dpctl::tensor::kernels::accumulators::
257+
accumulate_strided_impl<srcTy, dstTy,
258+
CastTransformer<srcTy, dstTy>,
259+
ScanOpT, include_initial>;
260+
return fn;
261+
}
262+
}
263+
else {
264+
return nullptr;
265+
}
266+
}
267+
};
268+
269+
void populate_cumlogsumexp_dispatch_tables(void)
270+
{
271+
td_ns::DispatchTableBuilder<accumulate_1d_contig_impl_fn_ptr_t,
272+
CumLogSumExp1DContigFactory, td_ns::num_types>
273+
dtb1;
274+
dtb1.populate_dispatch_table(cumlogsumexp_1d_contig_dispatch_table);
275+
276+
td_ns::DispatchTableBuilder<accumulate_strided_impl_fn_ptr_t,
277+
CumLogSumExpStridedFactory, td_ns::num_types>
278+
dtb2;
279+
dtb2.populate_dispatch_table(cumlogsumexp_strided_dispatch_table);
280+
281+
td_ns::DispatchTableBuilder<accumulate_1d_contig_impl_fn_ptr_t,
282+
CumLogSumExp1DIncludeInitialContigFactory,
283+
td_ns::num_types>
284+
dtb3;
285+
dtb3.populate_dispatch_table(
286+
cumlogsumexp_1d_include_initial_contig_dispatch_table);
287+
288+
td_ns::DispatchTableBuilder<accumulate_strided_impl_fn_ptr_t,
289+
CumLogSumExpIncludeInitialStridedFactory,
290+
td_ns::num_types>
291+
dtb4;
292+
dtb4.populate_dispatch_table(
293+
cumlogsumexp_include_initial_strided_dispatch_table);
294+
295+
return;
296+
}
297+
298+
} // namespace impl
299+
300+
/*! @brief Register cumulative logsumexp Python entry points on module @p m.
 *
 *  Exposes:
 *   - _cumlogsumexp_over_axis: scan over trailing axes
 *   - _cumlogsumexp_final_axis_include_initial: scan with initial value
 *   - _cumlogsumexp_dtype_supported: query (arg, out) dtype support
 */
void init_cumulative_logsumexp(py::module_ m)
{
    using arrayT = dpctl::tensor::usm_ndarray;
    using event_vecT = std::vector<sycl::event>;

    impl::populate_cumlogsumexp_dispatch_tables();

    // The lambdas below refer to namespace-scope static dispatch tables,
    // so they capture nothing and stay valid after this function returns.
    auto over_axis_fn = [](const arrayT &src, int trailing_dims_to_accumulate,
                           const arrayT &dst, sycl::queue &exec_q,
                           const event_vecT &depends = {}) {
        using dpctl::tensor::py_internal::py_accumulate_over_axis;
        return py_accumulate_over_axis(
            src, trailing_dims_to_accumulate, dst, exec_q, depends,
            impl::cumlogsumexp_strided_dispatch_table,
            impl::cumlogsumexp_1d_contig_dispatch_table);
    };
    m.def("_cumlogsumexp_over_axis", over_axis_fn, "", py::arg("src"),
          py::arg("trailing_dims_to_accumulate"), py::arg("dst"),
          py::arg("sycl_queue"), py::arg("depends") = py::list());

    auto include_initial_fn = [](const arrayT &src, const arrayT &dst,
                                 sycl::queue &exec_q,
                                 const event_vecT &depends = {}) {
        using dpctl::tensor::py_internal::
            py_accumulate_final_axis_include_initial;
        return py_accumulate_final_axis_include_initial(
            src, dst, exec_q, depends,
            impl::cumlogsumexp_include_initial_strided_dispatch_table,
            impl::cumlogsumexp_1d_include_initial_contig_dispatch_table);
    };
    m.def("_cumlogsumexp_final_axis_include_initial", include_initial_fn, "",
          py::arg("src"), py::arg("dst"), py::arg("sycl_queue"),
          py::arg("depends") = py::list());

    auto dtype_supported_fn = [](const py::dtype &input_dtype,
                                 const py::dtype &output_dtype) {
        using dpctl::tensor::py_internal::py_accumulate_dtype_supported;
        return py_accumulate_dtype_supported(
            input_dtype, output_dtype,
            impl::cumlogsumexp_strided_dispatch_table);
    };
    m.def("_cumlogsumexp_dtype_supported", dtype_supported_fn, "",
          py::arg("arg_dtype"), py::arg("out_dtype"));
}
350+
351+
} // namespace dpctl::tensor::py_internal

0 commit comments

Comments
 (0)