Skip to content

Commit b197a63

Browse files
committed
add tensor linalg extension
1 parent 0cdc3e4 commit b197a63

File tree

12 files changed

+8131
-5
lines changed

12 files changed

+8131
-5
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ set(_sorting_sources
166166
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
167167
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp
168168
)
169+
set(_linalg_sources
170+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
171+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp
172+
)
169173
set(_tensor_accumulation_impl_sources
170174
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
171175
${_accumulator_sources}
@@ -182,6 +186,10 @@ set(_tensor_sorting_impl_sources
182186
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
183187
${_sorting_sources}
184188
)
189+
set(_tensor_linalg_impl_sources
190+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_linalg.cpp
191+
${_linalg_sources}
192+
)
185193

186194
set(_static_lib_trgt simplify_iteration_space)
187195

@@ -228,6 +236,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_s
228236
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
229237
list(APPEND _py_trgts ${python_module_name})
230238

239+
set(python_module_name _tensor_linalg_impl)
240+
pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
241+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})
242+
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
243+
list(APPEND _py_trgts ${python_module_name})
244+
231245
set(_clang_prefix "")
232246
if(WIN32)
233247
set(_clang_prefix "/clang:")
@@ -245,7 +259,7 @@ list(
245259
${_elementwise_sources}
246260
${_reduction_sources}
247261
${_sorting_sources}
248-
# ${_linalg_sources}
262+
${_linalg_sources}
249263
${_accumulator_sources}
250264
)
251265

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@
6262
take,
6363
take_along_axis,
6464
)
65+
from dpctl_ext.tensor._linalg_functions import (
66+
matmul,
67+
matrix_transpose,
68+
tensordot,
69+
vecdot,
70+
)
6571
from dpctl_ext.tensor._manipulation_functions import (
6672
broadcast_arrays,
6773
broadcast_to,

0 commit comments

Comments
 (0)