Skip to content

Commit 3333fc1

Browse files
committed
add tensor linalg extension
1 parent 0cdc3e4 commit 3333fc1

File tree

12 files changed

+8135
-5
lines changed

12 files changed

+8135
-5
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ set(_sorting_sources
166166
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
167167
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp
168168
)
169+
set(_linalg_sources
170+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
171+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp
172+
)
169173
set(_tensor_accumulation_impl_sources
170174
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
171175
${_accumulator_sources}
@@ -182,6 +186,10 @@ set(_tensor_sorting_impl_sources
182186
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
183187
${_sorting_sources}
184188
)
189+
set(_tensor_linalg_impl_sources
190+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_linalg.cpp
191+
${_linalg_sources}
192+
)
185193

186194
set(_static_lib_trgt simplify_iteration_space)
187195

@@ -228,6 +236,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_s
228236
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
229237
list(APPEND _py_trgts ${python_module_name})
230238

239+
set(python_module_name _tensor_linalg_impl)
240+
pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
241+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})
242+
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
243+
list(APPEND _py_trgts ${python_module_name})
244+
231245
set(_clang_prefix "")
232246
if(WIN32)
233247
set(_clang_prefix "/clang:")
@@ -245,7 +259,7 @@ list(
245259
${_elementwise_sources}
246260
${_reduction_sources}
247261
${_sorting_sources}
248-
# ${_linalg_sources}
262+
${_linalg_sources}
249263
${_accumulator_sources}
250264
)
251265

dpctl_ext/tensor/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@
6262
take,
6363
take_along_axis,
6464
)
65+
from dpctl_ext.tensor._linear_algebra_functions import (
66+
matmul,
67+
matrix_transpose,
68+
tensordot,
69+
vecdot,
70+
)
6571
from dpctl_ext.tensor._manipulation_functions import (
6672
broadcast_arrays,
6773
broadcast_to,
@@ -218,6 +224,8 @@
218224
"min",
219225
"moveaxis",
220226
"permute_dims",
227+
"matmul",
228+
"matrix_transpose",
221229
"negative",
222230
"nonzero",
223231
"ones",
@@ -253,6 +261,7 @@
253261
"take_along_axis",
254262
"tan",
255263
"tanh",
264+
"tensordot",
256265
"tile",
257266
"top_k",
258267
"to_numpy",
@@ -264,6 +273,7 @@
264273
"unique_inverse",
265274
"unique_values",
266275
"unstack",
276+
"vecdot",
267277
"where",
268278
"zeros",
269279
"zeros_like",

0 commit comments

Comments
 (0)