Skip to content

Commit e96405c

Browse files
Move _tensor_accumulation_impl extension and use it for dpnp (#2791)
This PR completely moves `_tensor_accumulation_impl` pybind11 extension into `dpctl_ext.tensor` and extends `dpctl_ext.tensor` Python API with the functions `cumulative_logsumexp, cumulative_prod and cumulative_sum` reusing them in dpnp
1 parent 585f2e5 commit e96405c

15 files changed

+2305
-13
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,16 @@ set(_tensor_impl_sources
6363
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
6464
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
6565
)
66+
set(_accumulator_sources
67+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
68+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
69+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
70+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
71+
)
72+
set(_tensor_accumulation_impl_sources
73+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
74+
${_accumulator_sources}
75+
)
6676

6777
set(_static_lib_trgt simplify_iteration_space)
6878

@@ -85,6 +95,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_impl_sources})
8595
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
8696
list(APPEND _py_trgts ${python_module_name})
8797

98+
set(python_module_name _tensor_accumulation_impl)
99+
pybind11_add_module(${python_module_name} MODULE ${_tensor_accumulation_impl_sources})
100+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_accumulation_impl_sources})
101+
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
102+
list(APPEND _py_trgts ${python_module_name})
103+
88104
set(_clang_prefix "")
89105
if(WIN32)
90106
set(_clang_prefix "/clang:")
@@ -97,14 +113,14 @@ set(_no_fast_math_sources
97113
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
98114
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
99115
)
100-
#list(
101-
#APPEND _no_fast_math_sources
102-
# ${_elementwise_sources}
103-
# ${_reduction_sources}
104-
# ${_sorting_sources}
105-
# ${_linalg_sources}
106-
# ${_accumulator_sources}
107-
#)
116+
list(
117+
APPEND _no_fast_math_sources
118+
# ${_elementwise_sources}
119+
# ${_reduction_sources}
120+
# ${_sorting_sources}
121+
# ${_linalg_sources}
122+
${_accumulator_sources}
123+
)
108124

109125
foreach(_src_fn ${_no_fast_math_sources})
110126
get_source_file_property(_cmpl_options_prop ${_src_fn} COMPILE_OPTIONS)

dpctl_ext/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
# *****************************************************************************
2828

2929

30+
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
3031
from ._clip import clip
3132
from ._copy_utils import (
3233
asnumpy,
@@ -92,6 +93,9 @@
9293
"concat",
9394
"copy",
9495
"clip",
96+
"cumulative_logsumexp",
97+
"cumulative_prod",
98+
"cumulative_sum",
9599
"empty",
96100
"empty_like",
97101
"extract",

0 commit comments

Comments
 (0)