Skip to content

Commit ed58333

Browse files
committed
minor cleanup for sparse extensions
1 parent 6136da2 commit ed58333

4 files changed

Lines changed: 13 additions & 54 deletions

File tree

dpnp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ add_subdirectory(backend/extensions/statistics)
100100
add_subdirectory(backend/extensions/ufunc)
101101
add_subdirectory(backend/extensions/vm)
102102
add_subdirectory(backend/extensions/window)
103+
add_subdirectory(backend/extensions/sparse)
103104

104105
add_subdirectory(dpnp_algo)
105106
add_subdirectory(dpnp_utils)

dpnp/backend/extensions/sparse/CMakeLists.txt

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ pybind11_add_module(${python_module_name} MODULE ${_module_src})
3737
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src})
3838

3939
if(_dpnp_sycl_targets)
40+
# make fat binary
4041
target_compile_options(
4142
${python_module_name}
4243
PRIVATE ${_dpnp_sycl_target_compile_options}
@@ -45,7 +46,9 @@ if(_dpnp_sycl_targets)
4546
endif()
4647

4748
if(WIN32)
48-
if(${CMAKE_VERSION} VERSION_LESS "3.27")
49+
if(${CMAKE_VERSION} VERSION_LESS "3.27")
50+
# this is a work-around for target_link_options inserting option after -link option, cause
51+
# linker to ignore it.
4952
set(CMAKE_CXX_LINK_FLAGS
5053
"${CMAKE_CXX_LINK_FLAGS} -fsycl-device-code-split=per_kernel"
5154
)
@@ -62,13 +65,11 @@ target_include_directories(
6265
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common
6366
)
6467

68+
# treat below headers as system to suppress the warnings there during the build
6569
target_include_directories(
6670
${python_module_name}
6771
SYSTEM
68-
PRIVATE
69-
${SYCL_INCLUDE_DIR}
70-
${Dpctl_INCLUDE_DIRS}
71-
${Dpctl_TENSOR_INCLUDE_DIR}
72+
PRIVATE ${SYCL_INCLUDE_DIR} ${Dpctl_INCLUDE_DIRS} ${Dpctl_TENSOR_INCLUDE_DIR}
7273
)
7374

7475
if(WIN32)

dpnp/backend/extensions/sparse/gemv.cpp

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,29 @@
2626
// THE POSSIBILITY OF SUCH DAMAGE.
2727
//*****************************************************************************
2828

29-
#include <sstream>
3029
#include <stdexcept>
3130

3231
#include <pybind11/pybind11.h>
3332

34-
// dpnp extension infrastructure
33+
// utils extension header
3534
#include "ext/common.hpp"
3635

37-
// dpctl tensor validation and utility headers
36+
// dpctl tensor headers
3837
#include "utils/memory_overlap.hpp"
3938
#include "utils/output_validation.hpp"
4039
#include "utils/type_utils.hpp"
4140

4241
#include "gemv.hpp"
4342
#include "types_matrix.hpp"
4443

44+
namespace dpnp::extensions::sparse
45+
{
4546
namespace mkl_sparse = oneapi::mkl::sparse;
4647
namespace py = pybind11;
4748
namespace type_utils = dpctl::tensor::type_utils;
4849

4950
using ext::common::init_dispatch_table;
5051

51-
namespace dpnp::extensions::sparse
52-
{
53-
54-
// ---------------------------------------------------------------------------
55-
// Dispatch table: [value_type_id][index_type_id] -> impl function pointer
56-
// Mirrors the 2-D table pattern of blas/gemm.cpp.
57-
// ---------------------------------------------------------------------------
58-
5952
typedef sycl::event (*gemv_impl_fn_ptr_t)(
6053
sycl::queue &,
6154
oneapi::mkl::transpose,
@@ -74,11 +67,6 @@ typedef sycl::event (*gemv_impl_fn_ptr_t)(
7467
static gemv_impl_fn_ptr_t
7568
gemv_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
7669

77-
78-
// ---------------------------------------------------------------------------
79-
// Typed implementation — one instantiation per (Tv, Ti) pair
80-
// ---------------------------------------------------------------------------
81-
8270
template <typename Tv, typename Ti>
8371
static sycl::event
8472
gemv_impl(sycl::queue &exec_q,
@@ -114,8 +102,6 @@ gemv_impl(sycl::queue &exec_q,
114102
try {
115103
mkl_sparse::init_matrix_handle(&spmat);
116104

117-
// oneMKL 2025-2 API: set_csr_data now requires explicit nnz and uses
118-
// `spmat` nomenclature. The old form without nnz is deprecated.
119105
auto ev_set = mkl_sparse::set_csr_data(
120106
exec_q, spmat,
121107
num_rows, num_cols, nnz,
@@ -155,11 +141,6 @@ gemv_impl(sycl::queue &exec_q,
155141
return gemv_ev;
156142
}
157143

158-
159-
// ---------------------------------------------------------------------------
160-
// Python-facing entry point
161-
// ---------------------------------------------------------------------------
162-
163144
std::pair<sycl::event, sycl::event>
164145
sparse_gemv(sycl::queue &exec_q,
165146
const int trans,
@@ -175,33 +156,28 @@ sparse_gemv(sycl::queue &exec_q,
175156
const std::int64_t nnz,
176157
const std::vector<sycl::event> &depends)
177158
{
178-
// 1. ndim checks
179159
if (x.get_ndim() != 1)
180160
throw py::value_error("sparse_gemv: x must be a 1-D array.");
181161
if (y.get_ndim() != 1)
182162
throw py::value_error("sparse_gemv: y must be a 1-D array.");
183163

184-
// 2. Queue compatibility
185164
if (!dpctl::utils::queues_are_compatible(
186165
exec_q, {row_ptr.get_queue(), col_ind.get_queue(),
187166
values.get_queue(), x.get_queue(), y.get_queue()}))
188167
throw py::value_error(
189168
"sparse_gemv: USM allocations are not compatible with the "
190169
"execution queue.");
191170

192-
// 3. Memory overlap: x and y must not alias
193171
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
194172
if (overlap(x, y))
195173
throw py::value_error(
196174
"sparse_gemv: input array x and output array y are overlapping "
197175
"segments of memory.");
198176

199-
// 4. Output writability and size
200177
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(y);
201178
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
202179
y, static_cast<std::size_t>(num_rows));
203180

204-
// 5. Map trans integer to oneMKL enum
205181
oneapi::mkl::transpose mkl_trans;
206182
switch (trans) {
207183
case 0: mkl_trans = oneapi::mkl::transpose::nontrans; break;
@@ -212,7 +188,6 @@ sparse_gemv(sycl::queue &exec_q,
212188
"sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)");
213189
}
214190

215-
// 6. Dispatch table lookup
216191
auto array_types = dpctl_td_ns::usm_ndarray_types();
217192
const int val_id = array_types.typenum_to_lookup_id(values.get_typenum());
218193
const int idx_id = array_types.typenum_to_lookup_id(row_ptr.get_typenum());
@@ -236,11 +211,6 @@ sparse_gemv(sycl::queue &exec_q,
236211
return std::make_pair(args_ev, gemv_ev);
237212
}
238213

239-
240-
// ---------------------------------------------------------------------------
241-
// Factory and dispatch table initialisation
242-
// ---------------------------------------------------------------------------
243-
244214
template <typename fnT, typename Tv, typename Ti>
245215
struct GemvContigFactory
246216
{
@@ -258,5 +228,4 @@ void init_sparse_gemv_dispatch_table(void)
258228
init_dispatch_table<gemv_impl_fn_ptr_t, GemvContigFactory>(
259229
gemv_dispatch_table);
260230
}
261-
262231
} // namespace dpnp::extensions::sparse

dpnp/backend/extensions/sparse/sparse_py.cpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626
// THE POSSIBILITY OF SUCH DAMAGE.
2727
//*****************************************************************************
2828
//
29-
// Defines the dpnp.backend._sparse_impl pybind11 extension module.
30-
// Provides oneMKL sparse BLAS operations on CSR matrices over dpctl USM arrays.
31-
// Equivalent role to _cusparse for the SYCL/oneMKL backend.
29+
// This file defines functions of dpnp.backend._sparse_impl extensions
3230
//
3331
//*****************************************************************************
3432

@@ -42,7 +40,7 @@ namespace py = pybind11;
4240

4341
static void init_dispatch_vectors_tables(void)
4442
{
45-
sparse_ns::init_sparse_gemv_dispatch_vector();
43+
sparse_ns::init_sparse_gemv_dispatch_table();
4644
}
4745

4846
PYBIND11_MODULE(_sparse_impl, m)
@@ -52,13 +50,6 @@ PYBIND11_MODULE(_sparse_impl, m)
5250
using arrayT = dpctl::tensor::usm_ndarray;
5351
using event_vecT = std::vector<sycl::event>;
5452

55-
// ------------------------------------------------------------------
56-
// _sparse_gemv — CSR SpMV: y = alpha * op(A) * x + beta * y
57-
//
58-
// Equivalent to _cusparse.spMV_make_fast_matvec for the SYCL stack.
59-
// Backed by oneMKL sparse::gemv with set_csr_data + optimize_gemv so
60-
// matrix-handle analysis is amortised across repeated calls.
61-
// ------------------------------------------------------------------
6253
{
6354
m.def(
6455
"_sparse_gemv",
@@ -113,9 +104,6 @@ PYBIND11_MODULE(_sparse_impl, m)
113104
py::arg("depends") = py::list());
114105
}
115106

116-
// ------------------------------------------------------------------
117-
// Runtime query: which sparse library backend is active
118-
// ------------------------------------------------------------------
119107
{
120108
m.def(
121109
"_using_onemath",

0 commit comments

Comments
 (0)