Skip to content

Commit 890238c

Browse files
committed
sparse/gemv: add missing headers, input validation, and MKL/SYCL exception handling
Align gemv.cpp with the conventions established in blas/gemm.cpp:

Headers added:
- ext/common.hpp (dpctl_td_ns, consistent with other extensions)
- utils/memory_overlap.hpp (MemoryOverlap guard on x vs y)
- utils/output_validation.hpp (CheckWritable + AmpleMemory on y)
- utils/type_utils.hpp (validate_type_for_device<T> in impl)
- <sstream> (needed for stringstream error_msg)

Exception handling added in sparse_gemv_impl():
- try/catch(oneapi::mkl::exception) around all oneMKL sparse calls
- try/catch(sycl::exception) around all oneMKL sparse calls
- release_matrix_handle cleanup in the exception error path
- throw std::runtime_error with descriptive message on catch

Input validation added in sparse_gemv():
- ndim checks: x and y must be 1-D
- queues_are_compatible() across all 5 USM arrays
- MemoryOverlap()(x, y) aliasing guard
- CheckWritable::throw_if_not_writable(y)
- AmpleMemory::throw_if_not_ample(y, num_rows)
- keep_args_alive() at function exit (was missing, returning empty event)
1 parent 14cb5c4 commit 890238c

1 file changed

Lines changed: 102 additions & 24 deletions

File tree

dpnp/backend/extensions/sparse/gemv.cpp

Lines changed: 102 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,26 @@
2626
// THE POSSIBILITY OF SUCH DAMAGE.
2727
//*****************************************************************************
2828

29+
#include <sstream>
2930
#include <stdexcept>
3031
#include <vector>
3132

33+
#include <pybind11/pybind11.h>
34+
35+
// ext/common.hpp — dpctl_td_ns; mirrors every other dpnp extension
36+
#include "ext/common.hpp"
37+
38+
// dpctl tensor validation and utility headers — same set as blas/gemm.cpp
39+
#include "utils/memory_overlap.hpp"
40+
#include "utils/output_validation.hpp"
41+
#include "utils/type_utils.hpp"
42+
3243
#include "gemv.hpp"
3344

3445
// oneMKL sparse BLAS
3546
namespace mkl_sparse = oneapi::mkl::sparse;
47+
namespace py = pybind11;
48+
namespace type_utils = dpctl::tensor::type_utils;
3649

3750
namespace dpnp::extensions::sparse
3851
{
@@ -57,30 +70,60 @@ sparse_gemv_impl(sycl::queue &exec_q,
5770
T *y_ptr,
5871
const std::vector<sycl::event> &depends)
5972
{
73+
// Validate that T is supported on this device (mirrors gemm_impl pattern)
74+
type_utils::validate_type_for_device<T>(exec_q);
75+
76+
std::stringstream error_msg;
77+
bool is_exception_caught = false;
78+
6079
mkl_sparse::matrix_handle_t handle = nullptr;
61-
mkl_sparse::init_matrix_handle(&handle);
80+
sycl::event gemv_ev;
6281

63-
auto ev_set = mkl_sparse::set_csr_data(
64-
exec_q, handle,
65-
num_rows, num_cols,
66-
oneapi::mkl::index_base::zero,
67-
row_ptr_ptr, col_ind_ptr, values_ptr,
68-
depends);
82+
try {
83+
mkl_sparse::init_matrix_handle(&handle);
6984

70-
// optimize_gemv performs internal analysis — amortises over repeated SpMV
71-
auto ev_opt = mkl_sparse::optimize_gemv(
72-
exec_q, mkl_trans, handle, {ev_set});
85+
auto ev_set = mkl_sparse::set_csr_data(
86+
exec_q, handle,
87+
num_rows, num_cols,
88+
oneapi::mkl::index_base::zero,
89+
row_ptr_ptr, col_ind_ptr, values_ptr,
90+
depends);
7391

74-
auto ev_gemv = mkl_sparse::gemv(
75-
exec_q, mkl_trans,
76-
alpha, handle,
77-
x_ptr, beta, y_ptr,
78-
{ev_opt});
92+
// optimize_gemv performs internal analysis — amortises over repeated SpMV
93+
auto ev_opt = mkl_sparse::optimize_gemv(
94+
exec_q, mkl_trans, handle, {ev_set});
7995

80-
// async release — waits for ev_gemv internally
81-
mkl_sparse::release_matrix_handle(exec_q, &handle, {ev_gemv});
96+
gemv_ev = mkl_sparse::gemv(
97+
exec_q, mkl_trans,
98+
alpha, handle,
99+
x_ptr, beta, y_ptr,
100+
{ev_opt});
82101

83-
return ev_gemv;
102+
// async release — waits for gemv_ev internally
103+
mkl_sparse::release_matrix_handle(exec_q, &handle, {gemv_ev});
104+
105+
} catch (oneapi::mkl::exception const &e) {
106+
error_msg
107+
<< "Unexpected MKL exception caught during sparse_gemv() call:"
108+
"\nreason: "
109+
<< e.what();
110+
is_exception_caught = true;
111+
} catch (sycl::exception const &e) {
112+
error_msg
113+
<< "Unexpected SYCL exception caught during sparse_gemv() call:\n"
114+
<< e.what();
115+
is_exception_caught = true;
116+
}
117+
118+
if (is_exception_caught) {
119+
// Best-effort handle cleanup before re-raising
120+
if (handle != nullptr) {
121+
mkl_sparse::release_matrix_handle(exec_q, &handle, {});
122+
}
123+
throw std::runtime_error(error_msg.str());
124+
}
125+
126+
return gemv_ev;
84127
}
85128

86129

@@ -103,24 +146,55 @@ sparse_gemv(sycl::queue &exec_q,
103146
const std::int64_t nnz,
104147
const std::vector<sycl::event> &depends)
105148
{
106-
// Map trans integer to oneMKL enum
149+
// --- 1. ndim checks ---
150+
if (x.get_ndim() != 1) {
151+
throw py::value_error("sparse_gemv: x must be a 1-D array.");
152+
}
153+
if (y.get_ndim() != 1) {
154+
throw py::value_error("sparse_gemv: y must be a 1-D array.");
155+
}
156+
157+
// --- 2. Queue compatibility (all USM arrays must share the same queue) ---
158+
if (!dpctl::utils::queues_are_compatible(
159+
exec_q,
160+
{row_ptr.get_queue(), col_ind.get_queue(),
161+
values.get_queue(), x.get_queue(), y.get_queue()})) {
162+
throw py::value_error(
163+
"sparse_gemv: USM allocations are not compatible with the "
164+
"execution queue.");
165+
}
166+
167+
// --- 3. Memory overlap: x and y must not alias ---
168+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
169+
if (overlap(x, y)) {
170+
throw py::value_error(
171+
"sparse_gemv: input array x and output array y are overlapping "
172+
"segments of memory.");
173+
}
174+
175+
// --- 4. Output writability and size ---
176+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(y);
177+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
178+
y, static_cast<std::size_t>(num_rows));
179+
180+
// --- 5. Map trans integer to oneMKL enum ---
107181
oneapi::mkl::transpose mkl_trans;
108182
switch (trans) {
109-
case 0: mkl_trans = oneapi::mkl::transpose::nontrans; break;
110-
case 1: mkl_trans = oneapi::mkl::transpose::trans; break;
183+
case 0: mkl_trans = oneapi::mkl::transpose::nontrans; break;
184+
case 1: mkl_trans = oneapi::mkl::transpose::trans; break;
111185
case 2: mkl_trans = oneapi::mkl::transpose::conjtrans; break;
112186
default:
113187
throw std::invalid_argument(
114188
"sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)");
115189
}
116190

191+
// --- 6. Type dispatch (value type x index type) ---
192+
// oneMKL sparse BLAS supports float32 and float64 (no complex yet)
117193
int val_typenum = values.get_typenum();
118194
int idx_typenum = row_ptr.get_typenum();
119195

120196
sycl::event gemv_ev;
121197

122-
// Dispatch on value type x index type
123-
// oneMKL sparse BLAS supports float32, float64 (no complex yet)
124198
if (val_typenum == UAR_FLOAT) {
125199
auto alpha_f = static_cast<float>(alpha);
126200
auto beta_f = static_cast<float>(beta);
@@ -181,7 +255,11 @@ sparse_gemv(sycl::queue &exec_q,
181255
"sparse_gemv: value dtype must be float32 or float64");
182256
}
183257

184-
return std::make_pair(sycl::event{}, gemv_ev);
258+
// Keep all input/output USM arrays alive until gemv_ev completes
259+
sycl::event args_ev = dpctl::utils::keep_args_alive(
260+
exec_q, {row_ptr, col_ind, values, x, y}, {gemv_ev});
261+
262+
return std::make_pair(args_ev, gemv_ev);
185263
}
186264

187265

0 commit comments

Comments
 (0)