2626// THE POSSIBILITY OF SUCH DAMAGE.
2727// *****************************************************************************
2828
29+ #include < sstream>
2930#include < stdexcept>
3031#include < vector>
3132
33+ #include < pybind11/pybind11.h>
34+
35+ // ext/common.hpp — dpctl_td_ns; mirrors every other dpnp extension
36+ #include " ext/common.hpp"
37+
38+ // dpctl tensor validation and utility headers — same set as blas/gemm.cpp
39+ #include " utils/memory_overlap.hpp"
40+ #include " utils/output_validation.hpp"
41+ #include " utils/type_utils.hpp"
42+
3243#include " gemv.hpp"
3344
3445// oneMKL sparse BLAS
3546namespace mkl_sparse = oneapi::mkl::sparse;
47+ namespace py = pybind11;
48+ namespace type_utils = dpctl::tensor::type_utils;
3649
3750namespace dpnp ::extensions::sparse
3851{
@@ -57,30 +70,60 @@ sparse_gemv_impl(sycl::queue &exec_q,
5770 T *y_ptr,
5871 const std::vector<sycl::event> &depends)
5972{
73+ // Validate that T is supported on this device (mirrors gemm_impl pattern)
74+ type_utils::validate_type_for_device<T>(exec_q);
75+
76+ std::stringstream error_msg;
77+ bool is_exception_caught = false ;
78+
6079 mkl_sparse::matrix_handle_t handle = nullptr ;
61- mkl_sparse::init_matrix_handle (&handle) ;
80+ sycl::event gemv_ev ;
6281
63- auto ev_set = mkl_sparse::set_csr_data (
64- exec_q, handle,
65- num_rows, num_cols,
66- oneapi::mkl::index_base::zero,
67- row_ptr_ptr, col_ind_ptr, values_ptr,
68- depends);
82+ try {
83+ mkl_sparse::init_matrix_handle (&handle);
6984
70- // optimize_gemv performs internal analysis — amortises over repeated SpMV
71- auto ev_opt = mkl_sparse::optimize_gemv (
72- exec_q, mkl_trans, handle, {ev_set});
85+ auto ev_set = mkl_sparse::set_csr_data (
86+ exec_q, handle,
87+ num_rows, num_cols,
88+ oneapi::mkl::index_base::zero,
89+ row_ptr_ptr, col_ind_ptr, values_ptr,
90+ depends);
7391
74- auto ev_gemv = mkl_sparse::gemv (
75- exec_q, mkl_trans,
76- alpha, handle,
77- x_ptr, beta, y_ptr,
78- {ev_opt});
92+ // optimize_gemv performs internal analysis — amortises over repeated SpMV
93+ auto ev_opt = mkl_sparse::optimize_gemv (
94+ exec_q, mkl_trans, handle, {ev_set});
7995
80- // async release — waits for ev_gemv internally
81- mkl_sparse::release_matrix_handle (exec_q, &handle, {ev_gemv});
96+ gemv_ev = mkl_sparse::gemv (
97+ exec_q, mkl_trans,
98+ alpha, handle,
99+ x_ptr, beta, y_ptr,
100+ {ev_opt});
82101
83- return ev_gemv;
102+ // async release — waits for gemv_ev internally
103+ mkl_sparse::release_matrix_handle (exec_q, &handle, {gemv_ev});
104+
105+ } catch (oneapi::mkl::exception const &e) {
106+ error_msg
107+ << " Unexpected MKL exception caught during sparse_gemv() call:"
108+ " \n reason: "
109+ << e.what ();
110+ is_exception_caught = true ;
111+ } catch (sycl::exception const &e) {
112+ error_msg
113+ << " Unexpected SYCL exception caught during sparse_gemv() call:\n "
114+ << e.what ();
115+ is_exception_caught = true ;
116+ }
117+
118+ if (is_exception_caught) {
119+ // Best-effort handle cleanup before re-raising
120+ if (handle != nullptr ) {
121+ mkl_sparse::release_matrix_handle (exec_q, &handle, {});
122+ }
123+ throw std::runtime_error (error_msg.str ());
124+ }
125+
126+ return gemv_ev;
84127}
85128
86129
@@ -103,24 +146,55 @@ sparse_gemv(sycl::queue &exec_q,
103146 const std::int64_t nnz,
104147 const std::vector<sycl::event> &depends)
105148{
106- // Map trans integer to oneMKL enum
149+ // --- 1. ndim checks ---
150+ if (x.get_ndim () != 1 ) {
151+ throw py::value_error (" sparse_gemv: x must be a 1-D array." );
152+ }
153+ if (y.get_ndim () != 1 ) {
154+ throw py::value_error (" sparse_gemv: y must be a 1-D array." );
155+ }
156+
157+ // --- 2. Queue compatibility (all USM arrays must share the same queue) ---
158+ if (!dpctl::utils::queues_are_compatible (
159+ exec_q,
160+ {row_ptr.get_queue (), col_ind.get_queue (),
161+ values.get_queue (), x.get_queue (), y.get_queue ()})) {
162+ throw py::value_error (
163+ " sparse_gemv: USM allocations are not compatible with the "
164+ " execution queue." );
165+ }
166+
167+ // --- 3. Memory overlap: x and y must not alias ---
168+ auto const &overlap = dpctl::tensor::overlap::MemoryOverlap ();
169+ if (overlap (x, y)) {
170+ throw py::value_error (
171+ " sparse_gemv: input array x and output array y are overlapping "
172+ " segments of memory." );
173+ }
174+
175+ // --- 4. Output writability and size ---
176+ dpctl::tensor::validation::CheckWritable::throw_if_not_writable (y);
177+ dpctl::tensor::validation::AmpleMemory::throw_if_not_ample (
178+ y, static_cast <std::size_t >(num_rows));
179+
180+ // --- 5. Map trans integer to oneMKL enum ---
107181 oneapi::mkl::transpose mkl_trans;
108182 switch (trans) {
109- case 0 : mkl_trans = oneapi::mkl::transpose::nontrans; break ;
110- case 1 : mkl_trans = oneapi::mkl::transpose::trans; break ;
183+ case 0 : mkl_trans = oneapi::mkl::transpose::nontrans; break ;
184+ case 1 : mkl_trans = oneapi::mkl::transpose::trans; break ;
111185 case 2 : mkl_trans = oneapi::mkl::transpose::conjtrans; break ;
112186 default :
113187 throw std::invalid_argument (
114188 " sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)" );
115189 }
116190
191+ // --- 6. Type dispatch (value type x index type) ---
192+ // oneMKL sparse BLAS supports float32 and float64 (no complex yet)
117193 int val_typenum = values.get_typenum ();
118194 int idx_typenum = row_ptr.get_typenum ();
119195
120196 sycl::event gemv_ev;
121197
122- // Dispatch on value type x index type
123- // oneMKL sparse BLAS supports float32, float64 (no complex yet)
124198 if (val_typenum == UAR_FLOAT) {
125199 auto alpha_f = static_cast <float >(alpha);
126200 auto beta_f = static_cast <float >(beta);
@@ -181,7 +255,11 @@ sparse_gemv(sycl::queue &exec_q,
181255 " sparse_gemv: value dtype must be float32 or float64" );
182256 }
183257
184- return std::make_pair (sycl::event{}, gemv_ev);
258+ // Keep all input/output USM arrays alive until gemv_ev completes
259+ sycl::event args_ev = dpctl::utils::keep_args_alive (
260+ exec_q, {row_ptr, col_ind, values, x, y}, {gemv_ev});
261+
262+ return std::make_pair (args_ev, gemv_ev);
185263}
186264
187265
0 commit comments