Skip to content

Commit 6cc6b6c

Browse files
Move _tensor_reductions_impl extension and use it for dpnp (#2794)
This PR completely moves `_tensor_reductions_impl` pybind11 extension into `dpctl_ext.tensor` and extends dpctl_ext.tensor Python API with the functions: `all, any, diff, argmax, argmin, count_nonzero, logsumexp, max. min, prod, reduce_hypot and sum` reusing them in dpnp
1 parent dd62b6e commit 6cc6b6c

38 files changed

+9970
-30
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,19 @@ set(_accumulator_sources
6969
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
7070
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
7171
)
72+
set(_reduction_sources
73+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduction_common.cpp
74+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/all.cpp
75+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/any.cpp
76+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmax.cpp
77+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmin.cpp
78+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/logsumexp.cpp
79+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/max.cpp
80+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/min.cpp
81+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/prod.cpp
82+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduce_hypot.cpp
83+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
84+
)
7285
set(_sorting_sources
7386
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp
7487
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
@@ -82,6 +95,10 @@ set(_tensor_accumulation_impl_sources
8295
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
8396
${_accumulator_sources}
8497
)
98+
set(_tensor_reductions_impl_sources
99+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_reductions.cpp
100+
${_reduction_sources}
101+
)
85102
set(_tensor_sorting_impl_sources
86103
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
87104
${_sorting_sources}
@@ -114,6 +131,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_accumulation_i
114131
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
115132
list(APPEND _py_trgts ${python_module_name})
116133

134+
set(python_module_name _tensor_reductions_impl)
135+
pybind11_add_module(${python_module_name} MODULE ${_tensor_reductions_impl_sources})
136+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_reductions_impl_sources})
137+
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
138+
list(APPEND _py_trgts ${python_module_name})
139+
117140
set(python_module_name _tensor_sorting_impl)
118141
pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_impl_sources})
119142
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_sources})
@@ -135,7 +158,7 @@ set(_no_fast_math_sources
135158
list(
136159
APPEND _no_fast_math_sources
137160
# ${_elementwise_sources}
138-
# ${_reduction_sources}
161+
${_reduction_sources}
139162
${_sorting_sources}
140163
# ${_linalg_sources}
141164
${_accumulator_sources}

dpctl_ext/tensor/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,17 @@
7878
tile,
7979
unstack,
8080
)
81+
from ._reduction import (
82+
argmax,
83+
argmin,
84+
count_nonzero,
85+
logsumexp,
86+
max,
87+
min,
88+
prod,
89+
reduce_hypot,
90+
sum,
91+
)
8192
from ._reshape import reshape
8293
from ._search_functions import where
8394
from ._searchsorted import searchsorted
@@ -90,9 +101,14 @@
90101
)
91102
from ._sorting import argsort, sort, top_k
92103
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
104+
from ._utility_functions import all, any, diff
93105

94106
__all__ = [
107+
"all",
108+
"any",
95109
"arange",
110+
"argmax",
111+
"argmin",
96112
"argsort",
97113
"asarray",
98114
"asnumpy",
@@ -102,10 +118,12 @@
102118
"can_cast",
103119
"concat",
104120
"copy",
121+
"count_nonzero",
105122
"clip",
106123
"cumulative_logsumexp",
107124
"cumulative_prod",
108125
"cumulative_sum",
126+
"diff",
109127
"empty",
110128
"empty_like",
111129
"extract",
@@ -120,15 +138,20 @@
120138
"isdtype",
121139
"isin",
122140
"linspace",
141+
"logsumexp",
142+
"max",
123143
"meshgrid",
144+
"min",
124145
"moveaxis",
125146
"permute_dims",
126147
"nonzero",
127148
"ones",
128149
"ones_like",
129150
"place",
151+
"prod",
130152
"put",
131153
"put_along_axis",
154+
"reduce_hypot",
132155
"repeat",
133156
"reshape",
134157
"result_type",
@@ -137,6 +160,7 @@
137160
"sort",
138161
"squeeze",
139162
"stack",
163+
"sum",
140164
"swapaxes",
141165
"take",
142166
"take_along_axis",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,7 @@ def repeat(x, repeats, /, *, axis=None):
624624
"'repeats' array must be broadcastable to the size of "
625625
"the repeated axis"
626626
)
627-
if not dpt.all(repeats >= 0):
627+
if not dpt_ext.all(repeats >= 0):
628628
raise ValueError("'repeats' elements must be positive")
629629

630630
elif isinstance(repeats, (tuple, list, range)):
@@ -646,7 +646,7 @@ def repeat(x, repeats, /, *, axis=None):
646646
repeats = dpt_ext.asarray(
647647
repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
648648
)
649-
if not dpt.all(repeats >= 0):
649+
if not dpt_ext.all(repeats >= 0):
650650
raise ValueError("`repeats` elements must be positive")
651651
else:
652652
raise TypeError(

0 commit comments

Comments
 (0)