@@ -51,7 +51,7 @@ namespace dpnp::extensions::sparse
5151{
5252
5353namespace mkl_sparse = oneapi::mkl::sparse;
54- namespace py = pybind11;
54+ namespace py = pybind11;
5555namespace type_utils = dpctl::tensor::type_utils;
5656
5757using ext::common::init_dispatch_table;
@@ -68,12 +68,12 @@ using ext::common::init_dispatch_table;
6868typedef std::pair<std::uintptr_t , sycl::event> (*gemv_init_fn_ptr_t )(
6969 sycl::queue &,
7070 oneapi::mkl::transpose,
71- const char *, // row_ptr (typeless)
72- const char *, // col_ind (typeless)
73- const char *, // values (typeless)
74- std::int64_t , // num_rows
75- std::int64_t , // num_cols
76- std::int64_t , // nnz
71+ const char *, // row_ptr (typeless)
72+ const char *, // col_ind (typeless)
73+ const char *, // values (typeless)
74+ std::int64_t , // num_rows
75+ std::int64_t , // num_cols
76+ std::int64_t , // nnz
7777 const std::vector<sycl::event> &);
7878
7979/* *
@@ -84,15 +84,15 @@ typedef sycl::event (*gemv_compute_fn_ptr_t)(
8484 sycl::queue &,
8585 oneapi::mkl::sparse::matrix_handle_t ,
8686 oneapi::mkl::transpose,
87- double , // alpha (cast to Tv inside)
88- const char *, // x (typeless)
89- double , // beta (cast to Tv inside)
90- char *, // y (typeless, writable)
87+ double , // alpha (cast to Tv inside)
88+ const char *, // x (typeless)
89+ double , // beta (cast to Tv inside)
90+ char *, // y (typeless, writable)
9191 const std::vector<sycl::event> &);
9292
9393// Init dispatch: 2-D on (Tv, Ti).
94- static gemv_init_fn_ptr_t
95- gemv_init_dispatch_table[dpctl_td_ns::num_types] [dpctl_td_ns::num_types];
94+ static gemv_init_fn_ptr_t gemv_init_dispatch_table[dpctl_td_ns::num_types]
95+ [dpctl_td_ns::num_types];
9696
9797// Compute dispatch: 1-D on Tv. The index type is baked into the handle,
9898// so compute doesn't need it.
@@ -105,48 +105,43 @@ static gemv_compute_fn_ptr_t
105105
106106template <typename Tv, typename Ti>
107107static std::pair<std::uintptr_t , sycl::event>
108- gemv_init_impl (sycl::queue &exec_q,
109- oneapi::mkl::transpose mkl_trans,
110- const char *row_ptr_data,
111- const char *col_ind_data,
112- const char *values_data,
113- std::int64_t num_rows,
114- std::int64_t num_cols,
115- std::int64_t nnz,
116- const std::vector<sycl::event> &depends)
108+ gemv_init_impl (sycl::queue &exec_q,
109+ oneapi::mkl::transpose mkl_trans,
110+ const char *row_ptr_data,
111+ const char *col_ind_data,
112+ const char *values_data,
113+ std::int64_t num_rows,
114+ std::int64_t num_cols,
115+ std::int64_t nnz,
116+ const std::vector<sycl::event> &depends)
117117{
118118 type_utils::validate_type_for_device<Tv>(exec_q);
119119
120120 const Ti *row_ptr = reinterpret_cast <const Ti *>(row_ptr_data);
121121 const Ti *col_ind = reinterpret_cast <const Ti *>(col_ind_data);
122- const Tv *values = reinterpret_cast <const Tv *>(values_data);
122+ const Tv *values = reinterpret_cast <const Tv *>(values_data);
123123
124124 mkl_sparse::matrix_handle_t spmat = nullptr ;
125125 mkl_sparse::init_matrix_handle (&spmat);
126126
127127 auto ev_set = mkl_sparse::set_csr_data (
128- exec_q, spmat,
129- num_rows, num_cols, nnz,
130- oneapi::mkl::index_base::zero,
131- const_cast <Ti *>(row_ptr),
132- const_cast <Ti *>(col_ind),
133- const_cast <Tv *>(values),
134- depends);
128+ exec_q, spmat, num_rows, num_cols, nnz, oneapi::mkl::index_base::zero,
129+ const_cast <Ti *>(row_ptr), const_cast <Ti *>(col_ind),
130+ const_cast <Tv *>(values), depends);
135131
136132 sycl::event ev_opt;
137133 try {
138- ev_opt = mkl_sparse::optimize_gemv (
139- exec_q, mkl_trans, spmat, {ev_set});
134+ ev_opt = mkl_sparse::optimize_gemv (exec_q, mkl_trans, spmat, {ev_set});
140135 } catch (oneapi::mkl::exception const &e) {
141136 mkl_sparse::release_matrix_handle (exec_q, &spmat, {});
142137 throw std::runtime_error (
143- std::string (" sparse_gemv_init: MKL exception in optimize_gemv: " )
144- + e.what ());
138+ std::string (" sparse_gemv_init: MKL exception in optimize_gemv: " ) +
139+ e.what ());
145140 } catch (sycl::exception const &e) {
146141 mkl_sparse::release_matrix_handle (exec_q, &spmat, {});
147142 throw std::runtime_error (
148- std::string (" sparse_gemv_init: SYCL exception in optimize_gemv: " )
149- + e.what ());
143+ std::string (" sparse_gemv_init: SYCL exception in optimize_gemv: " ) +
144+ e.what ());
150145 }
151146
152147 auto handle_ptr = reinterpret_cast <std::uintptr_t >(spmat);
@@ -158,32 +153,28 @@ gemv_init_impl(sycl::queue &exec_q,
158153// ---------------------------------------------------------------------------
159154
160155template <typename Tv>
161- static sycl::event
162- gemv_compute_impl (sycl::queue &exec_q,
163- mkl_sparse::matrix_handle_t spmat,
164- oneapi::mkl::transpose mkl_trans,
165- double alpha_d,
166- const char *x_data,
167- double beta_d,
168- char *y_data,
169- const std::vector<sycl::event> &depends)
156+ static sycl::event gemv_compute_impl (sycl::queue &exec_q,
157+ mkl_sparse::matrix_handle_t spmat,
158+ oneapi::mkl::transpose mkl_trans,
159+ double alpha_d,
160+ const char *x_data,
161+ double beta_d,
162+ char *y_data,
163+ const std::vector<sycl::event> &depends)
170164{
171165 // For complex Tv the single-arg constructor sets imag to zero.
172166 // Solvers use alpha=1, beta=0 so this is exact; other callers
173167 // passing complex scalars via this path will lose the imag
174168 // component silently.
175169 const Tv alpha = static_cast <Tv>(alpha_d);
176- const Tv beta = static_cast <Tv>(beta_d);
170+ const Tv beta = static_cast <Tv>(beta_d);
177171
178172 const Tv *x = reinterpret_cast <const Tv *>(x_data);
179- Tv *y = reinterpret_cast <Tv *>(y_data);
173+ Tv *y = reinterpret_cast <Tv *>(y_data);
180174
181175 try {
182- return mkl_sparse::gemv (
183- exec_q, mkl_trans,
184- alpha, spmat,
185- x, beta, y,
186- depends);
176+ return mkl_sparse::gemv (exec_q, mkl_trans, alpha, spmat, x, beta, y,
177+ depends);
187178 } catch (oneapi::mkl::exception const &e) {
188179 throw std::runtime_error (
189180 std::string (" sparse_gemv_compute: MKL exception: " ) + e.what ());
@@ -197,33 +188,35 @@ gemv_compute_impl(sycl::queue &exec_q,
197188// Public entry points
198189// ---------------------------------------------------------------------------
199190
200- static oneapi::mkl::transpose
201- decode_trans (const int trans)
191+ static oneapi::mkl::transpose decode_trans (const int trans)
202192{
203193 switch (trans) {
204- case 0 : return oneapi::mkl::transpose::nontrans;
205- case 1 : return oneapi::mkl::transpose::trans;
206- case 2 : return oneapi::mkl::transpose::conjtrans;
207- default :
208- throw std::invalid_argument (
209- " sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)" );
194+ case 0 :
195+ return oneapi::mkl::transpose::nontrans;
196+ case 1 :
197+ return oneapi::mkl::transpose::trans;
198+ case 2 :
199+ return oneapi::mkl::transpose::conjtrans;
200+ default :
201+ throw std::invalid_argument (
202+ " sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)" );
210203 }
211204}
212205
213206std::tuple<std::uintptr_t , int , sycl::event>
214- sparse_gemv_init (sycl::queue &exec_q,
215- const int trans,
216- const dpctl::tensor::usm_ndarray &row_ptr,
217- const dpctl::tensor::usm_ndarray &col_ind,
218- const dpctl::tensor::usm_ndarray &values,
219- const std::int64_t num_rows,
220- const std::int64_t num_cols,
221- const std::int64_t nnz,
222- const std::vector<sycl::event> &depends)
207+ sparse_gemv_init (sycl::queue &exec_q,
208+ const int trans,
209+ const dpctl::tensor::usm_ndarray &row_ptr,
210+ const dpctl::tensor::usm_ndarray &col_ind,
211+ const dpctl::tensor::usm_ndarray &values,
212+ const std::int64_t num_rows,
213+ const std::int64_t num_cols,
214+ const std::int64_t nnz,
215+ const std::vector<sycl::event> &depends)
223216{
224217 if (!dpctl::utils::queues_are_compatible (
225- exec_q, {row_ptr. get_queue (), col_ind. get_queue (),
226- values.get_queue ()}))
218+ exec_q,
219+ {row_ptr. get_queue (), col_ind. get_queue (), values.get_queue ()}))
227220 throw py::value_error (
228221 " sparse_gemv_init: USM allocations are not compatible with the "
229222 " execution queue." );
@@ -260,34 +253,32 @@ sparse_gemv_init(sycl::queue &exec_q,
260253 " dtype combination. Supported: {float32,float64,complex64,"
261254 " complex128} x {int32,int64}." );
262255
263- auto [handle_ptr, ev_opt] = init_fn (
264- exec_q, mkl_trans,
265- row_ptr.get_data (), col_ind.get_data (), values.get_data (),
266- num_rows, num_cols, nnz, depends);
256+ auto [handle_ptr, ev_opt] =
257+ init_fn (exec_q, mkl_trans, row_ptr.get_data (), col_ind.get_data (),
258+ values.get_data (), num_rows, num_cols, nnz, depends);
267259
268260 return {handle_ptr, val_id, ev_opt};
269261}
270262
271- sycl::event
272- sparse_gemv_compute (sycl::queue &exec_q,
273- const std::uintptr_t handle_ptr,
274- const int val_type_id,
275- const int trans,
276- const double alpha,
277- const dpctl::tensor::usm_ndarray &x,
278- const double beta,
279- const dpctl::tensor::usm_ndarray &y,
280- const std::int64_t num_rows,
281- const std::int64_t num_cols,
282- const std::vector<sycl::event> &depends)
263+ sycl::event sparse_gemv_compute (sycl::queue &exec_q,
264+ const std::uintptr_t handle_ptr,
265+ const int val_type_id,
266+ const int trans,
267+ const double alpha,
268+ const dpctl::tensor::usm_ndarray &x,
269+ const double beta,
270+ const dpctl::tensor::usm_ndarray &y,
271+ const std::int64_t num_rows,
272+ const std::int64_t num_cols,
273+ const std::vector<sycl::event> &depends)
283274{
284275 if (x.get_ndim () != 1 )
285276 throw py::value_error (" sparse_gemv_compute: x must be a 1-D array." );
286277 if (y.get_ndim () != 1 )
287278 throw py::value_error (" sparse_gemv_compute: y must be a 1-D array." );
288279
289- if (!dpctl::utils::queues_are_compatible (
290- exec_q, {x.get_queue (), y.get_queue ()}))
280+ if (!dpctl::utils::queues_are_compatible (exec_q,
281+ {x.get_queue (), y.get_queue ()}))
291282 throw py::value_error (
292283 " sparse_gemv_compute: USM allocations are not compatible with the "
293284 " execution queue." );
@@ -302,8 +293,7 @@ sparse_gemv_compute(sycl::queue &exec_q,
302293 // Shape validation: op(A) is (num_rows, num_cols) for trans=N,
303294 // (num_cols, num_rows) for trans={T,C}.
304295 auto mkl_trans = decode_trans (trans);
305- const bool is_non_trans =
306- (mkl_trans == oneapi::mkl::transpose::nontrans);
296+ const bool is_non_trans = (mkl_trans == oneapi::mkl::transpose::nontrans);
307297 const std::int64_t op_rows = is_non_trans ? num_rows : num_cols;
308298 const std::int64_t op_cols = is_non_trans ? num_cols : num_rows;
309299
@@ -328,28 +318,22 @@ sparse_gemv_compute(sycl::queue &exec_q,
328318 " of the sparse matrix used to build the handle." );
329319
330320 if (val_type_id < 0 || val_type_id >= dpctl_td_ns::num_types)
331- throw py::value_error (
332- " sparse_gemv_compute: val_type_id out of range." );
321+ throw py::value_error (" sparse_gemv_compute: val_type_id out of range." );
333322
334- gemv_compute_fn_ptr_t compute_fn =
335- gemv_compute_dispatch_table[val_type_id];
323+ gemv_compute_fn_ptr_t compute_fn = gemv_compute_dispatch_table[val_type_id];
336324
337325 if (compute_fn == nullptr )
338- throw py::value_error (
339- " sparse_gemv_compute: unsupported value dtype." );
326+ throw py::value_error (" sparse_gemv_compute: unsupported value dtype." );
340327
341328 auto spmat = reinterpret_cast <mkl_sparse::matrix_handle_t >(handle_ptr);
342329
343- return compute_fn (exec_q, spmat, mkl_trans, alpha,
344- x.get_data (), beta,
345- const_cast <char *>(y.get_data ()),
346- depends);
330+ return compute_fn (exec_q, spmat, mkl_trans, alpha, x.get_data (), beta,
331+ const_cast <char *>(y.get_data ()), depends);
347332}
348333
349- sycl::event
350- sparse_gemv_release (sycl::queue &exec_q,
351- const std::uintptr_t handle_ptr,
352- const std::vector<sycl::event> &depends)
334+ sycl::event sparse_gemv_release (sycl::queue &exec_q,
335+ const std::uintptr_t handle_ptr,
336+ const std::vector<sycl::event> &depends)
353337{
354338 auto spmat = reinterpret_cast <mkl_sparse::matrix_handle_t >(handle_ptr);
355339
@@ -378,7 +362,8 @@ struct GemvInitContigFactory
378362{
379363 fnT get ()
380364 {
381- if constexpr (types::SparseGemvInitTypePairSupportFactory<Tv, Ti>::is_defined)
365+ if constexpr (types::SparseGemvInitTypePairSupportFactory<
366+ Tv, Ti>::is_defined)
382367 return gemv_init_impl<Tv, Ti>;
383368 else
384369 return nullptr ;
@@ -390,7 +375,8 @@ struct GemvComputeContigFactory
390375{
391376 fnT get ()
392377 {
393- if constexpr (types::SparseGemvComputeTypeSupportFactory<Tv>::is_defined)
378+ if constexpr (types::SparseGemvComputeTypeSupportFactory<
379+ Tv>::is_defined)
394380 return gemv_compute_impl<Tv>;
395381 else
396382 return nullptr ;
@@ -406,9 +392,7 @@ void init_sparse_gemv_dispatch_tables(void)
406392 // 1-D table on Tv for compute. dpctl's type dispatch headers expose
407393 // DispatchVectorBuilder as the 1-D analogue of DispatchTableBuilder.
408394 dpctl_td_ns::DispatchVectorBuilder<
409- gemv_compute_fn_ptr_t ,
410- GemvComputeContigFactory,
411- dpctl_td_ns::num_types>
395+ gemv_compute_fn_ptr_t , GemvComputeContigFactory, dpctl_td_ns::num_types>
412396 builder;
413397 builder.populate_dispatch_vector (gemv_compute_dispatch_table);
414398}
0 commit comments