Skip to content

Commit b0647db

Browse files
authored
add tensor linalg extension (#2799)
This PR migrates the `_tensor_linalg_impl` extension to `dpctl_ext.tensor` and extends `dpctl_ext.tensor` Python API with `dpctl.tensor` functions `matmul`, `matrix_transpose`, `tensordot`, and `vecdot`
1 parent 3a0c2ff commit b0647db

File tree

12 files changed

+8077
-5
lines changed

12 files changed

+8077
-5
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ set(_sorting_sources
166166
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
167167
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp
168168
)
169+
set(_linalg_sources
170+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
171+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp
172+
)
169173
set(_tensor_accumulation_impl_sources
170174
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
171175
${_accumulator_sources}
@@ -182,6 +186,10 @@ set(_tensor_sorting_impl_sources
182186
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
183187
${_sorting_sources}
184188
)
189+
set(_tensor_linalg_impl_sources
190+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_linalg.cpp
191+
${_linalg_sources}
192+
)
185193

186194
set(_static_lib_trgt simplify_iteration_space)
187195

@@ -228,6 +236,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_s
228236
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
229237
list(APPEND _py_trgts ${python_module_name})
230238

239+
set(python_module_name _tensor_linalg_impl)
240+
pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
241+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})
242+
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
243+
list(APPEND _py_trgts ${python_module_name})
244+
231245
set(_clang_prefix "")
232246
if(WIN32)
233247
set(_clang_prefix "/clang:")
@@ -245,7 +259,7 @@ list(
245259
${_elementwise_sources}
246260
${_reduction_sources}
247261
${_sorting_sources}
248-
# ${_linalg_sources}
262+
${_linalg_sources}
249263
${_accumulator_sources}
250264
)
251265

dpctl_ext/tensor/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@
107107
take,
108108
take_along_axis,
109109
)
110+
from ._linear_algebra_functions import (
111+
matmul,
112+
matrix_transpose,
113+
tensordot,
114+
vecdot,
115+
)
110116
from ._manipulation_functions import (
111117
broadcast_arrays,
112118
broadcast_to,
@@ -216,6 +222,8 @@
216222
"min",
217223
"moveaxis",
218224
"permute_dims",
225+
"matmul",
226+
"matrix_transpose",
219227
"negative",
220228
"nonzero",
221229
"ones",
@@ -251,6 +259,7 @@
251259
"take_along_axis",
252260
"tan",
253261
"tanh",
262+
"tensordot",
254263
"tile",
255264
"top_k",
256265
"to_numpy",
@@ -262,6 +271,7 @@
262271
"unique_inverse",
263272
"unique_values",
264273
"unstack",
274+
"vecdot",
265275
"where",
266276
"zeros",
267277
"zeros_like",

0 commit comments

Comments
 (0)