Skip to content

Commit 4be3d9a

Browse files
committed
fix pre-commit issues
1 parent 0badee4 commit 4be3d9a

10 files changed

Lines changed: 220 additions & 285 deletions

File tree

dpnp/backend/extensions/sparse/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ if(_dpnp_sycl_targets)
4646
endif()
4747

4848
if(WIN32)
49-
if(${CMAKE_VERSION} VERSION_LESS "3.27")
49+
if(${CMAKE_VERSION} VERSION_LESS "3.27")
5050
# this is a work-around for target_link_options inserting option after -link option, cause
5151
# linker to ignore it.
5252
set(CMAKE_CXX_LINK_FLAGS

dpnp/backend/extensions/sparse/gemv.cpp

Lines changed: 94 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ namespace dpnp::extensions::sparse
5151
{
5252

5353
namespace mkl_sparse = oneapi::mkl::sparse;
54-
namespace py = pybind11;
54+
namespace py = pybind11;
5555
namespace type_utils = dpctl::tensor::type_utils;
5656

5757
using ext::common::init_dispatch_table;
@@ -68,12 +68,12 @@ using ext::common::init_dispatch_table;
6868
typedef std::pair<std::uintptr_t, sycl::event> (*gemv_init_fn_ptr_t)(
6969
sycl::queue &,
7070
oneapi::mkl::transpose,
71-
const char *, // row_ptr (typeless)
72-
const char *, // col_ind (typeless)
73-
const char *, // values (typeless)
74-
std::int64_t, // num_rows
75-
std::int64_t, // num_cols
76-
std::int64_t, // nnz
71+
const char *, // row_ptr (typeless)
72+
const char *, // col_ind (typeless)
73+
const char *, // values (typeless)
74+
std::int64_t, // num_rows
75+
std::int64_t, // num_cols
76+
std::int64_t, // nnz
7777
const std::vector<sycl::event> &);
7878

7979
/**
@@ -84,15 +84,15 @@ typedef sycl::event (*gemv_compute_fn_ptr_t)(
8484
sycl::queue &,
8585
oneapi::mkl::sparse::matrix_handle_t,
8686
oneapi::mkl::transpose,
87-
double, // alpha (cast to Tv inside)
88-
const char *, // x (typeless)
89-
double, // beta (cast to Tv inside)
90-
char *, // y (typeless, writable)
87+
double, // alpha (cast to Tv inside)
88+
const char *, // x (typeless)
89+
double, // beta (cast to Tv inside)
90+
char *, // y (typeless, writable)
9191
const std::vector<sycl::event> &);
9292

9393
// Init dispatch: 2-D on (Tv, Ti).
94-
static gemv_init_fn_ptr_t
95-
gemv_init_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
94+
static gemv_init_fn_ptr_t gemv_init_dispatch_table[dpctl_td_ns::num_types]
95+
[dpctl_td_ns::num_types];
9696

9797
// Compute dispatch: 1-D on Tv. The index type is baked into the handle,
9898
// so compute doesn't need it.
@@ -105,48 +105,43 @@ static gemv_compute_fn_ptr_t
105105

106106
template <typename Tv, typename Ti>
107107
static std::pair<std::uintptr_t, sycl::event>
108-
gemv_init_impl(sycl::queue &exec_q,
109-
oneapi::mkl::transpose mkl_trans,
110-
const char *row_ptr_data,
111-
const char *col_ind_data,
112-
const char *values_data,
113-
std::int64_t num_rows,
114-
std::int64_t num_cols,
115-
std::int64_t nnz,
116-
const std::vector<sycl::event> &depends)
108+
gemv_init_impl(sycl::queue &exec_q,
109+
oneapi::mkl::transpose mkl_trans,
110+
const char *row_ptr_data,
111+
const char *col_ind_data,
112+
const char *values_data,
113+
std::int64_t num_rows,
114+
std::int64_t num_cols,
115+
std::int64_t nnz,
116+
const std::vector<sycl::event> &depends)
117117
{
118118
type_utils::validate_type_for_device<Tv>(exec_q);
119119

120120
const Ti *row_ptr = reinterpret_cast<const Ti *>(row_ptr_data);
121121
const Ti *col_ind = reinterpret_cast<const Ti *>(col_ind_data);
122-
const Tv *values = reinterpret_cast<const Tv *>(values_data);
122+
const Tv *values = reinterpret_cast<const Tv *>(values_data);
123123

124124
mkl_sparse::matrix_handle_t spmat = nullptr;
125125
mkl_sparse::init_matrix_handle(&spmat);
126126

127127
auto ev_set = mkl_sparse::set_csr_data(
128-
exec_q, spmat,
129-
num_rows, num_cols, nnz,
130-
oneapi::mkl::index_base::zero,
131-
const_cast<Ti *>(row_ptr),
132-
const_cast<Ti *>(col_ind),
133-
const_cast<Tv *>(values),
134-
depends);
128+
exec_q, spmat, num_rows, num_cols, nnz, oneapi::mkl::index_base::zero,
129+
const_cast<Ti *>(row_ptr), const_cast<Ti *>(col_ind),
130+
const_cast<Tv *>(values), depends);
135131

136132
sycl::event ev_opt;
137133
try {
138-
ev_opt = mkl_sparse::optimize_gemv(
139-
exec_q, mkl_trans, spmat, {ev_set});
134+
ev_opt = mkl_sparse::optimize_gemv(exec_q, mkl_trans, spmat, {ev_set});
140135
} catch (oneapi::mkl::exception const &e) {
141136
mkl_sparse::release_matrix_handle(exec_q, &spmat, {});
142137
throw std::runtime_error(
143-
std::string("sparse_gemv_init: MKL exception in optimize_gemv: ")
144-
+ e.what());
138+
std::string("sparse_gemv_init: MKL exception in optimize_gemv: ") +
139+
e.what());
145140
} catch (sycl::exception const &e) {
146141
mkl_sparse::release_matrix_handle(exec_q, &spmat, {});
147142
throw std::runtime_error(
148-
std::string("sparse_gemv_init: SYCL exception in optimize_gemv: ")
149-
+ e.what());
143+
std::string("sparse_gemv_init: SYCL exception in optimize_gemv: ") +
144+
e.what());
150145
}
151146

152147
auto handle_ptr = reinterpret_cast<std::uintptr_t>(spmat);
@@ -158,32 +153,28 @@ gemv_init_impl(sycl::queue &exec_q,
158153
// ---------------------------------------------------------------------------
159154

160155
template <typename Tv>
161-
static sycl::event
162-
gemv_compute_impl(sycl::queue &exec_q,
163-
mkl_sparse::matrix_handle_t spmat,
164-
oneapi::mkl::transpose mkl_trans,
165-
double alpha_d,
166-
const char *x_data,
167-
double beta_d,
168-
char *y_data,
169-
const std::vector<sycl::event> &depends)
156+
static sycl::event gemv_compute_impl(sycl::queue &exec_q,
157+
mkl_sparse::matrix_handle_t spmat,
158+
oneapi::mkl::transpose mkl_trans,
159+
double alpha_d,
160+
const char *x_data,
161+
double beta_d,
162+
char *y_data,
163+
const std::vector<sycl::event> &depends)
170164
{
171165
// For complex Tv the single-arg constructor sets imag to zero.
172166
// Solvers use alpha=1, beta=0 so this is exact; other callers
173167
// passing complex scalars via this path will lose the imag
174168
// component silently.
175169
const Tv alpha = static_cast<Tv>(alpha_d);
176-
const Tv beta = static_cast<Tv>(beta_d);
170+
const Tv beta = static_cast<Tv>(beta_d);
177171

178172
const Tv *x = reinterpret_cast<const Tv *>(x_data);
179-
Tv *y = reinterpret_cast<Tv *>(y_data);
173+
Tv *y = reinterpret_cast<Tv *>(y_data);
180174

181175
try {
182-
return mkl_sparse::gemv(
183-
exec_q, mkl_trans,
184-
alpha, spmat,
185-
x, beta, y,
186-
depends);
176+
return mkl_sparse::gemv(exec_q, mkl_trans, alpha, spmat, x, beta, y,
177+
depends);
187178
} catch (oneapi::mkl::exception const &e) {
188179
throw std::runtime_error(
189180
std::string("sparse_gemv_compute: MKL exception: ") + e.what());
@@ -197,33 +188,35 @@ gemv_compute_impl(sycl::queue &exec_q,
197188
// Public entry points
198189
// ---------------------------------------------------------------------------
199190

200-
static oneapi::mkl::transpose
201-
decode_trans(const int trans)
191+
static oneapi::mkl::transpose decode_trans(const int trans)
202192
{
203193
switch (trans) {
204-
case 0: return oneapi::mkl::transpose::nontrans;
205-
case 1: return oneapi::mkl::transpose::trans;
206-
case 2: return oneapi::mkl::transpose::conjtrans;
207-
default:
208-
throw std::invalid_argument(
209-
"sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)");
194+
case 0:
195+
return oneapi::mkl::transpose::nontrans;
196+
case 1:
197+
return oneapi::mkl::transpose::trans;
198+
case 2:
199+
return oneapi::mkl::transpose::conjtrans;
200+
default:
201+
throw std::invalid_argument(
202+
"sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)");
210203
}
211204
}
212205

213206
std::tuple<std::uintptr_t, int, sycl::event>
214-
sparse_gemv_init(sycl::queue &exec_q,
215-
const int trans,
216-
const dpctl::tensor::usm_ndarray &row_ptr,
217-
const dpctl::tensor::usm_ndarray &col_ind,
218-
const dpctl::tensor::usm_ndarray &values,
219-
const std::int64_t num_rows,
220-
const std::int64_t num_cols,
221-
const std::int64_t nnz,
222-
const std::vector<sycl::event> &depends)
207+
sparse_gemv_init(sycl::queue &exec_q,
208+
const int trans,
209+
const dpctl::tensor::usm_ndarray &row_ptr,
210+
const dpctl::tensor::usm_ndarray &col_ind,
211+
const dpctl::tensor::usm_ndarray &values,
212+
const std::int64_t num_rows,
213+
const std::int64_t num_cols,
214+
const std::int64_t nnz,
215+
const std::vector<sycl::event> &depends)
223216
{
224217
if (!dpctl::utils::queues_are_compatible(
225-
exec_q, {row_ptr.get_queue(), col_ind.get_queue(),
226-
values.get_queue()}))
218+
exec_q,
219+
{row_ptr.get_queue(), col_ind.get_queue(), values.get_queue()}))
227220
throw py::value_error(
228221
"sparse_gemv_init: USM allocations are not compatible with the "
229222
"execution queue.");
@@ -260,34 +253,32 @@ sparse_gemv_init(sycl::queue &exec_q,
260253
"dtype combination. Supported: {float32,float64,complex64,"
261254
"complex128} x {int32,int64}.");
262255

263-
auto [handle_ptr, ev_opt] = init_fn(
264-
exec_q, mkl_trans,
265-
row_ptr.get_data(), col_ind.get_data(), values.get_data(),
266-
num_rows, num_cols, nnz, depends);
256+
auto [handle_ptr, ev_opt] =
257+
init_fn(exec_q, mkl_trans, row_ptr.get_data(), col_ind.get_data(),
258+
values.get_data(), num_rows, num_cols, nnz, depends);
267259

268260
return {handle_ptr, val_id, ev_opt};
269261
}
270262

271-
sycl::event
272-
sparse_gemv_compute(sycl::queue &exec_q,
273-
const std::uintptr_t handle_ptr,
274-
const int val_type_id,
275-
const int trans,
276-
const double alpha,
277-
const dpctl::tensor::usm_ndarray &x,
278-
const double beta,
279-
const dpctl::tensor::usm_ndarray &y,
280-
const std::int64_t num_rows,
281-
const std::int64_t num_cols,
282-
const std::vector<sycl::event> &depends)
263+
sycl::event sparse_gemv_compute(sycl::queue &exec_q,
264+
const std::uintptr_t handle_ptr,
265+
const int val_type_id,
266+
const int trans,
267+
const double alpha,
268+
const dpctl::tensor::usm_ndarray &x,
269+
const double beta,
270+
const dpctl::tensor::usm_ndarray &y,
271+
const std::int64_t num_rows,
272+
const std::int64_t num_cols,
273+
const std::vector<sycl::event> &depends)
283274
{
284275
if (x.get_ndim() != 1)
285276
throw py::value_error("sparse_gemv_compute: x must be a 1-D array.");
286277
if (y.get_ndim() != 1)
287278
throw py::value_error("sparse_gemv_compute: y must be a 1-D array.");
288279

289-
if (!dpctl::utils::queues_are_compatible(
290-
exec_q, {x.get_queue(), y.get_queue()}))
280+
if (!dpctl::utils::queues_are_compatible(exec_q,
281+
{x.get_queue(), y.get_queue()}))
291282
throw py::value_error(
292283
"sparse_gemv_compute: USM allocations are not compatible with the "
293284
"execution queue.");
@@ -302,8 +293,7 @@ sparse_gemv_compute(sycl::queue &exec_q,
302293
// Shape validation: op(A) is (num_rows, num_cols) for trans=N,
303294
// (num_cols, num_rows) for trans={T,C}.
304295
auto mkl_trans = decode_trans(trans);
305-
const bool is_non_trans =
306-
(mkl_trans == oneapi::mkl::transpose::nontrans);
296+
const bool is_non_trans = (mkl_trans == oneapi::mkl::transpose::nontrans);
307297
const std::int64_t op_rows = is_non_trans ? num_rows : num_cols;
308298
const std::int64_t op_cols = is_non_trans ? num_cols : num_rows;
309299

@@ -328,28 +318,22 @@ sparse_gemv_compute(sycl::queue &exec_q,
328318
"of the sparse matrix used to build the handle.");
329319

330320
if (val_type_id < 0 || val_type_id >= dpctl_td_ns::num_types)
331-
throw py::value_error(
332-
"sparse_gemv_compute: val_type_id out of range.");
321+
throw py::value_error("sparse_gemv_compute: val_type_id out of range.");
333322

334-
gemv_compute_fn_ptr_t compute_fn =
335-
gemv_compute_dispatch_table[val_type_id];
323+
gemv_compute_fn_ptr_t compute_fn = gemv_compute_dispatch_table[val_type_id];
336324

337325
if (compute_fn == nullptr)
338-
throw py::value_error(
339-
"sparse_gemv_compute: unsupported value dtype.");
326+
throw py::value_error("sparse_gemv_compute: unsupported value dtype.");
340327

341328
auto spmat = reinterpret_cast<mkl_sparse::matrix_handle_t>(handle_ptr);
342329

343-
return compute_fn(exec_q, spmat, mkl_trans, alpha,
344-
x.get_data(), beta,
345-
const_cast<char *>(y.get_data()),
346-
depends);
330+
return compute_fn(exec_q, spmat, mkl_trans, alpha, x.get_data(), beta,
331+
const_cast<char *>(y.get_data()), depends);
347332
}
348333

349-
sycl::event
350-
sparse_gemv_release(sycl::queue &exec_q,
351-
const std::uintptr_t handle_ptr,
352-
const std::vector<sycl::event> &depends)
334+
sycl::event sparse_gemv_release(sycl::queue &exec_q,
335+
const std::uintptr_t handle_ptr,
336+
const std::vector<sycl::event> &depends)
353337
{
354338
auto spmat = reinterpret_cast<mkl_sparse::matrix_handle_t>(handle_ptr);
355339

@@ -378,7 +362,8 @@ struct GemvInitContigFactory
378362
{
379363
fnT get()
380364
{
381-
if constexpr (types::SparseGemvInitTypePairSupportFactory<Tv, Ti>::is_defined)
365+
if constexpr (types::SparseGemvInitTypePairSupportFactory<
366+
Tv, Ti>::is_defined)
382367
return gemv_init_impl<Tv, Ti>;
383368
else
384369
return nullptr;
@@ -390,7 +375,8 @@ struct GemvComputeContigFactory
390375
{
391376
fnT get()
392377
{
393-
if constexpr (types::SparseGemvComputeTypeSupportFactory<Tv>::is_defined)
378+
if constexpr (types::SparseGemvComputeTypeSupportFactory<
379+
Tv>::is_defined)
394380
return gemv_compute_impl<Tv>;
395381
else
396382
return nullptr;
@@ -406,9 +392,7 @@ void init_sparse_gemv_dispatch_tables(void)
406392
// 1-D table on Tv for compute. dpctl's type dispatch headers expose
407393
// DispatchVectorBuilder as the 1-D analogue of DispatchTableBuilder.
408394
dpctl_td_ns::DispatchVectorBuilder<
409-
gemv_compute_fn_ptr_t,
410-
GemvComputeContigFactory,
411-
dpctl_td_ns::num_types>
395+
gemv_compute_fn_ptr_t, GemvComputeContigFactory, dpctl_td_ns::num_types>
412396
builder;
413397
builder.populate_dispatch_vector(gemv_compute_dispatch_table);
414398
}

0 commit comments

Comments
 (0)