2626// THE POSSIBILITY OF SUCH DAMAGE.
2727// *****************************************************************************
2828
29- #include < sstream>
3029#include < stdexcept>
3130
3231#include < pybind11/pybind11.h>
3332
34- // dpnp extension infrastructure
33+ // utils extension header
3534#include " ext/common.hpp"
3635
37- // dpctl tensor validation and utility headers
36+ // dpctl tensor headers
3837#include " utils/memory_overlap.hpp"
3938#include " utils/output_validation.hpp"
4039#include " utils/type_utils.hpp"
4140
4241#include " gemv.hpp"
4342#include " types_matrix.hpp"
4443
44+ namespace dpnp ::extensions::sparse
45+ {
4546namespace mkl_sparse = oneapi::mkl::sparse;
4647namespace py = pybind11;
4748namespace type_utils = dpctl::tensor::type_utils;
4849
4950using ext::common::init_dispatch_table;
5051
51- namespace dpnp ::extensions::sparse
52- {
53-
54- // ---------------------------------------------------------------------------
55- // Dispatch table: [value_type_id][index_type_id] -> impl function pointer
56- // Mirrors the 2-D table pattern of blas/gemm.cpp.
57- // ---------------------------------------------------------------------------
58-
5952typedef sycl::event (*gemv_impl_fn_ptr_t )(
6053 sycl::queue &,
6154 oneapi::mkl::transpose,
@@ -74,11 +67,6 @@ typedef sycl::event (*gemv_impl_fn_ptr_t)(
7467static gemv_impl_fn_ptr_t
7568 gemv_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
7669
77-
78- // ---------------------------------------------------------------------------
79- // Typed implementation — one instantiation per (Tv, Ti) pair
80- // ---------------------------------------------------------------------------
81-
8270template <typename Tv, typename Ti>
8371static sycl::event
8472gemv_impl (sycl::queue &exec_q,
@@ -114,8 +102,6 @@ gemv_impl(sycl::queue &exec_q,
114102 try {
115103 mkl_sparse::init_matrix_handle (&spmat);
116104
117- // oneMKL 2025-2 API: set_csr_data now requires explicit nnz and uses
118- // `spmat` nomenclature. The old form without nnz is deprecated.
119105 auto ev_set = mkl_sparse::set_csr_data (
120106 exec_q, spmat,
121107 num_rows, num_cols, nnz,
@@ -155,11 +141,6 @@ gemv_impl(sycl::queue &exec_q,
155141 return gemv_ev;
156142}
157143
158-
159- // ---------------------------------------------------------------------------
160- // Python-facing entry point
161- // ---------------------------------------------------------------------------
162-
163144std::pair<sycl::event, sycl::event>
164145sparse_gemv (sycl::queue &exec_q,
165146 const int trans,
@@ -175,33 +156,28 @@ sparse_gemv(sycl::queue &exec_q,
175156 const std::int64_t nnz,
176157 const std::vector<sycl::event> &depends)
177158{
178- // 1. ndim checks
179159 if (x.get_ndim () != 1 )
180160 throw py::value_error (" sparse_gemv: x must be a 1-D array." );
181161 if (y.get_ndim () != 1 )
182162 throw py::value_error (" sparse_gemv: y must be a 1-D array." );
183163
184- // 2. Queue compatibility
185164 if (!dpctl::utils::queues_are_compatible (
186165 exec_q, {row_ptr.get_queue (), col_ind.get_queue (),
187166 values.get_queue (), x.get_queue (), y.get_queue ()}))
188167 throw py::value_error (
189168 " sparse_gemv: USM allocations are not compatible with the "
190169 " execution queue." );
191170
192- // 3. Memory overlap: x and y must not alias
193171 auto const &overlap = dpctl::tensor::overlap::MemoryOverlap ();
194172 if (overlap (x, y))
195173 throw py::value_error (
196174 " sparse_gemv: input array x and output array y are overlapping "
197175 " segments of memory." );
198176
199- // 4. Output writability and size
200177 dpctl::tensor::validation::CheckWritable::throw_if_not_writable (y);
201178 dpctl::tensor::validation::AmpleMemory::throw_if_not_ample (
202179 y, static_cast <std::size_t >(num_rows));
203180
204- // 5. Map trans integer to oneMKL enum
205181 oneapi::mkl::transpose mkl_trans;
206182 switch (trans) {
207183 case 0 : mkl_trans = oneapi::mkl::transpose::nontrans; break ;
@@ -212,7 +188,6 @@ sparse_gemv(sycl::queue &exec_q,
212188 " sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)" );
213189 }
214190
215- // 6. Dispatch table lookup
216191 auto array_types = dpctl_td_ns::usm_ndarray_types ();
217192 const int val_id = array_types.typenum_to_lookup_id (values.get_typenum ());
218193 const int idx_id = array_types.typenum_to_lookup_id (row_ptr.get_typenum ());
@@ -236,11 +211,6 @@ sparse_gemv(sycl::queue &exec_q,
236211 return std::make_pair (args_ev, gemv_ev);
237212}
238213
239-
240- // ---------------------------------------------------------------------------
241- // Factory and dispatch table initialisation
242- // ---------------------------------------------------------------------------
243-
244214template <typename fnT, typename Tv, typename Ti>
245215struct GemvContigFactory
246216{
@@ -258,5 +228,4 @@ void init_sparse_gemv_dispatch_table(void)
258228 init_dispatch_table<gemv_impl_fn_ptr_t , GemvContigFactory>(
259229 gemv_dispatch_table);
260230}
261-
262231} // namespace dpnp::extensions::sparse
0 commit comments