2828
2929#include < sstream>
3030#include < stdexcept>
31- #include < vector>
3231
3332#include < pybind11/pybind11.h>
3433
35- // ext/common.hpp — dpctl_td_ns; mirrors every other dpnp extension
34+ // dpnp extension infrastructure
3635#include " ext/common.hpp"
3736
38- // dpctl tensor validation and utility headers — same set as blas/gemm.cpp
37+ // dpctl tensor validation and utility headers
3938#include " utils/memory_overlap.hpp"
4039#include " utils/output_validation.hpp"
4140#include " utils/type_utils.hpp"
4241
4342#include " gemv.hpp"
43+ #include " types_matrix.hpp"
4444
45- // oneMKL sparse BLAS
4645namespace mkl_sparse = oneapi::mkl::sparse;
47- namespace py = pybind11;
46+ namespace py = pybind11;
4847namespace type_utils = dpctl::tensor::type_utils;
4948
49+ using ext::common::init_dispatch_table;
50+
5051namespace dpnp ::extensions::sparse
5152{
5253
5354// ---------------------------------------------------------------------------
54- // Type-dispatched implementation: y = alpha * op(A) * x + beta * y
55+ // Dispatch table: [value_type_id][index_type_id] -> impl function pointer
56+ // Mirrors the 2-D table pattern of blas/gemm.cpp.
57+ // ---------------------------------------------------------------------------
58+
59+ typedef sycl::event (*gemv_impl_fn_ptr_t )(
60+ sycl::queue &,
61+ oneapi::mkl::transpose,
62+ double , // alpha (always passed as double; cast inside)
63+ const char *, // row_ptr (typeless)
64+ const char *, // col_ind (typeless)
65+ const char *, // values (typeless)
66+ std::int64_t , // num_rows
67+ std::int64_t , // num_cols
68+ std::int64_t , // nnz
69+ const char *, // x (typeless)
70+ double , // beta (always passed as double; cast inside)
71+ char *, // y (typeless, writable)
72+ const std::vector<sycl::event> &);
73+
74+ static gemv_impl_fn_ptr_t
75+ gemv_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
76+
77+
78+ // ---------------------------------------------------------------------------
79+ // Typed implementation — one instantiation per (Tv, Ti) pair
5580// ---------------------------------------------------------------------------
5681
57- template <typename T , typename intType >
82+ template <typename Tv , typename Ti >
5883static sycl::event
59- sparse_gemv_impl (sycl::queue &exec_q,
60- oneapi::mkl::transpose mkl_trans,
61- T alpha ,
62- intType *row_ptr_ptr ,
63- intType *col_ind_ptr ,
64- T *values_ptr ,
65- std:: int64_t num_rows,
66- std:: int64_t num_cols,
67- std:: int64_t nnz,
68- T *x_ptr ,
69- T beta ,
70- T *y_ptr ,
71- const std::vector<sycl::event> &depends)
84+ gemv_impl (sycl::queue &exec_q,
85+ oneapi::mkl::transpose mkl_trans,
86+ double alpha_d ,
87+ const char *row_ptr_data ,
88+ const char *col_ind_data ,
89+ const char *values_data ,
90+ std:: int64_t num_rows,
91+ std:: int64_t num_cols,
92+ std:: int64_t nnz,
93+ const char *x_data ,
94+ double beta_d ,
95+ char *y_data ,
96+ const std::vector<sycl::event> &depends)
7297{
73- // Validate that T is supported on this device (mirrors gemm_impl pattern)
74- type_utils::validate_type_for_device<T>(exec_q);
98+ type_utils::validate_type_for_device<Tv>(exec_q);
99+
100+ const Tv alpha = static_cast <Tv>(alpha_d);
101+ const Tv beta = static_cast <Tv>(beta_d);
102+ const Ti *row_ptr = reinterpret_cast <const Ti *>(row_ptr_data);
103+ const Ti *col_ind = reinterpret_cast <const Ti *>(col_ind_data);
104+ const Tv *values = reinterpret_cast <const Tv *>(values_data);
105+ const Tv *x = reinterpret_cast <const Tv *>(x_data);
106+ Tv *y = reinterpret_cast <Tv *>(y_data);
75107
76108 std::stringstream error_msg;
77109 bool is_exception_caught = false ;
@@ -86,40 +118,35 @@ sparse_gemv_impl(sycl::queue &exec_q,
86118 exec_q, handle,
87119 num_rows, num_cols,
88120 oneapi::mkl::index_base::zero,
89- row_ptr_ptr, col_ind_ptr, values_ptr,
121+ const_cast <Ti *>(row_ptr),
122+ const_cast <Ti *>(col_ind),
123+ const_cast <Tv *>(values),
90124 depends);
91125
92- // optimize_gemv performs internal analysis — amortises over repeated SpMV
93126 auto ev_opt = mkl_sparse::optimize_gemv (
94127 exec_q, mkl_trans, handle, {ev_set});
95128
96129 gemv_ev = mkl_sparse::gemv (
97130 exec_q, mkl_trans,
98131 alpha, handle,
99- x_ptr , beta, y_ptr ,
132+ x , beta, y ,
100133 {ev_opt});
101134
102- // async release — waits for gemv_ev internally
103135 mkl_sparse::release_matrix_handle (exec_q, &handle, {gemv_ev});
104136
105137 } catch (oneapi::mkl::exception const &e) {
106- error_msg
107- << " Unexpected MKL exception caught during sparse_gemv() call:"
108- " \n reason: "
109- << e.what ();
138+ error_msg << " Unexpected MKL exception caught during sparse_gemv() "
139+ " call:\n reason: " << e.what ();
110140 is_exception_caught = true ;
111141 } catch (sycl::exception const &e) {
112- error_msg
113- << " Unexpected SYCL exception caught during sparse_gemv() call:\n "
114- << e.what ();
142+ error_msg << " Unexpected SYCL exception caught during sparse_gemv() "
143+ " call:\n " << e.what ();
115144 is_exception_caught = true ;
116145 }
117146
118147 if (is_exception_caught) {
119- // Best-effort handle cleanup before re-raising
120- if (handle != nullptr ) {
148+ if (handle != nullptr )
121149 mkl_sparse::release_matrix_handle (exec_q, &handle, {});
122- }
123150 throw std::runtime_error (error_msg.str ());
124151 }
125152
@@ -128,56 +155,51 @@ sparse_gemv_impl(sycl::queue &exec_q,
128155
129156
130157// ---------------------------------------------------------------------------
131- // Python-facing function
158+ // Python-facing entry point
132159// ---------------------------------------------------------------------------
133160
134161std::pair<sycl::event, sycl::event>
135- sparse_gemv (sycl::queue &exec_q,
136- const int trans,
137- const double alpha,
138- const dpctl::tensor::usm_ndarray &row_ptr,
139- const dpctl::tensor::usm_ndarray &col_ind,
140- const dpctl::tensor::usm_ndarray &values,
141- const dpctl::tensor::usm_ndarray &x,
142- const double beta,
143- const dpctl::tensor::usm_ndarray &y,
144- const std::int64_t num_rows,
145- const std::int64_t num_cols,
146- const std::int64_t nnz,
147- const std::vector<sycl::event> &depends)
162+ sparse_gemv (sycl::queue &exec_q,
163+ const int trans,
164+ const double alpha,
165+ const dpctl::tensor::usm_ndarray &row_ptr,
166+ const dpctl::tensor::usm_ndarray &col_ind,
167+ const dpctl::tensor::usm_ndarray &values,
168+ const dpctl::tensor::usm_ndarray &x,
169+ const double beta,
170+ const dpctl::tensor::usm_ndarray &y,
171+ const std::int64_t num_rows,
172+ const std::int64_t num_cols,
173+ const std::int64_t nnz,
174+ const std::vector<sycl::event> &depends)
148175{
149- // --- 1. ndim checks ---
150- if (x.get_ndim () != 1 ) {
176+ // 1. ndim checks
177+ if (x.get_ndim () != 1 )
151178 throw py::value_error (" sparse_gemv: x must be a 1-D array." );
152- }
153- if (y.get_ndim () != 1 ) {
179+ if (y.get_ndim () != 1 )
154180 throw py::value_error (" sparse_gemv: y must be a 1-D array." );
155- }
156181
157- // --- 2. Queue compatibility (all USM arrays must share the same queue) ---
182+ // 2. Queue compatibility
158183 if (!dpctl::utils::queues_are_compatible (
159- exec_q,
160- {row_ptr.get_queue (), col_ind.get_queue (),
161- values.get_queue (), x.get_queue (), y.get_queue ()})) {
184+ exec_q, {row_ptr.get_queue (), col_ind.get_queue (),
185+ values.get_queue (), x.get_queue (), y.get_queue ()}))
162186 throw py::value_error (
163187 " sparse_gemv: USM allocations are not compatible with the "
164188 " execution queue." );
165- }
166189
167- // --- 3. Memory overlap: x and y must not alias ---
190+ // 3. Memory overlap: x and y must not alias
168191 auto const &overlap = dpctl::tensor::overlap::MemoryOverlap ();
169- if (overlap (x, y)) {
192+ if (overlap (x, y))
170193 throw py::value_error (
171194 " sparse_gemv: input array x and output array y are overlapping "
172195 " segments of memory." );
173- }
174196
175- // --- 4. Output writability and size ---
197+ // 4. Output writability and size
176198 dpctl::tensor::validation::CheckWritable::throw_if_not_writable (y);
177199 dpctl::tensor::validation::AmpleMemory::throw_if_not_ample (
178200 y, static_cast <std::size_t >(num_rows));
179201
180- // --- 5. Map trans integer to oneMKL enum ---
202+ // 5. Map trans integer to oneMKL enum
181203 oneapi::mkl::transpose mkl_trans;
182204 switch (trans) {
183205 case 0 : mkl_trans = oneapi::mkl::transpose::nontrans; break ;
@@ -188,74 +210,24 @@ sparse_gemv(sycl::queue &exec_q,
188210 " sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)" );
189211 }
190212
191- // --- 6. Type dispatch (value type x index type) ---
192- // oneMKL sparse BLAS supports float32 and float64 (no complex yet)
193- int val_typenum = values.get_typenum ();
194- int idx_typenum = row_ptr.get_typenum ();
213+ // 6. Dispatch table lookup — replaces the explicit if/else chain
214+ auto array_types = dpctl_td_ns::usm_ndarray_types ();
215+ const int val_id = array_types. typenum_to_lookup_id ( values.get_typenum () );
216+ const int idx_id = array_types. typenum_to_lookup_id ( row_ptr.get_typenum () );
195217
196- sycl::event gemv_ev ;
197-
198- if (val_typenum == UAR_FLOAT) {
199- auto alpha_f = static_cast < float >(alpha);
200- auto beta_f = static_cast < float >(beta );
218+ gemv_impl_fn_ptr_t gemv_fn = gemv_dispatch_table[val_id][idx_id] ;
219+ if (gemv_fn == nullptr )
220+ throw py::value_error (
221+ " sparse_gemv: no implementation for the given value/index dtype "
222+ " combination. Supported: float32/float64 with int32/int64 indices. " );
201223
202- if (idx_typenum == UAR_INT32) {
203- gemv_ev = sparse_gemv_impl<float , std::int32_t >(
204- exec_q, mkl_trans, alpha_f,
205- row_ptr.get_data <std::int32_t >(),
206- col_ind.get_data <std::int32_t >(),
207- values.get_data <float >(),
208- num_rows, num_cols, nnz,
209- x.get_data <float >(), beta_f,
210- y.get_data <float >(), depends);
211- }
212- else if (idx_typenum == UAR_INT64) {
213- gemv_ev = sparse_gemv_impl<float , std::int64_t >(
214- exec_q, mkl_trans, alpha_f,
215- row_ptr.get_data <std::int64_t >(),
216- col_ind.get_data <std::int64_t >(),
217- values.get_data <float >(),
218- num_rows, num_cols, nnz,
219- x.get_data <float >(), beta_f,
220- y.get_data <float >(), depends);
221- }
222- else {
223- throw std::runtime_error (
224- " sparse_gemv: index dtype must be int32 or int64" );
225- }
226- }
227- else if (val_typenum == UAR_DOUBLE) {
228- if (idx_typenum == UAR_INT32) {
229- gemv_ev = sparse_gemv_impl<double , std::int32_t >(
230- exec_q, mkl_trans, alpha,
231- row_ptr.get_data <std::int32_t >(),
232- col_ind.get_data <std::int32_t >(),
233- values.get_data <double >(),
234- num_rows, num_cols, nnz,
235- x.get_data <double >(), beta,
236- y.get_data <double >(), depends);
237- }
238- else if (idx_typenum == UAR_INT64) {
239- gemv_ev = sparse_gemv_impl<double , std::int64_t >(
240- exec_q, mkl_trans, alpha,
241- row_ptr.get_data <std::int64_t >(),
242- col_ind.get_data <std::int64_t >(),
243- values.get_data <double >(),
224+ sycl::event gemv_ev =
225+ gemv_fn (exec_q, mkl_trans, alpha,
226+ row_ptr.get_data (), col_ind.get_data (), values.get_data (),
244227 num_rows, num_cols, nnz,
245- x.get_data <double >(), beta,
246- y.get_data <double >(), depends);
247- }
248- else {
249- throw std::runtime_error (
250- " sparse_gemv: index dtype must be int32 or int64" );
251- }
252- }
253- else {
254- throw std::runtime_error (
255- " sparse_gemv: value dtype must be float32 or float64" );
256- }
228+ x.get_data (), beta, y.get_data (),
229+ depends);
257230
258- // Keep all input/output USM arrays alive until gemv_ev completes
259231 sycl::event args_ev = dpctl::utils::keep_args_alive (
260232 exec_q, {row_ptr, col_ind, values, x, y}, {gemv_ev});
261233
@@ -264,15 +236,26 @@ sparse_gemv(sycl::queue &exec_q,
264236
265237
266238// ---------------------------------------------------------------------------
267- // Dispatch vector init (placeholder — matches blas convention)
239+ // Factory and dispatch table initialisation
240+ // Mirrors blas/gemm.cpp: GemmContigFactory -> GemvContigFactory
268241// ---------------------------------------------------------------------------
269242
270- void init_sparse_gemv_dispatch_vector (void )
243+ template <typename fnT, typename Tv, typename Ti>
244+ struct GemvContigFactory
245+ {
246+ fnT get ()
247+ {
248+ if constexpr (types::SparseGemvTypePairSupportFactory<Tv, Ti>::is_defined)
249+ return gemv_impl<Tv, Ti>;
250+ else
251+ return nullptr ;
252+ }
253+ };
254+
255+ void init_sparse_gemv_dispatch_table (void )
271256{
272- // No dispatch table needed for sparse_gemv since we do explicit
273- // type switching in the function body (oneMKL sparse API uses
274- // opaque handles, not templated dispatch tables).
275- // This function exists to match the dpnp extension convention.
257+ init_dispatch_table<gemv_impl_fn_ptr_t , GemvContigFactory>(
258+ gemv_dispatch_table);
276259}
277260
278261} // namespace dpnp::extensions::sparse
0 commit comments