Skip to content

Commit 838dfd8

Browse files
committed
sparse/gemv: replace explicit if/else type dispatch with 2-D dispatch table
Modeled after blas/gemm.cpp (2-D table: value type x index type) and blas/gemv.cpp (dispatch vector pattern with ContigFactory + init_dispatch_table). Changes: - Add sparse/types_matrix.hpp with SparseGemvTypePairSupportFactory<Tv, Ti> encoding the 4 supported combinations: {float32,float64} x {int32,int64} - Rewrite sparse_gemv_impl() to take typeless char* pointers (matching the blas gemv_impl signature style) — type info flows through template params only, no runtime branching inside the impl - Replace the 60-line if/else val_typenum/idx_typenum chain in sparse_gemv() with a 2-D dispatch table lookup (gemv_dispatch_table[val_id][idx_id]) - Rename init_sparse_gemv_dispatch_vector -> init_sparse_gemv_dispatch_table and implement it via init_dispatch_table<> from ext/common.hpp - All validation guards and exception handling from prior commit are preserved
1 parent 890238c commit 838dfd8

2 files changed

Lines changed: 189 additions & 135 deletions

File tree

dpnp/backend/extensions/sparse/gemv.cpp

Lines changed: 118 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -28,50 +28,82 @@
2828

2929
#include <sstream>
3030
#include <stdexcept>
31-
#include <vector>
3231

3332
#include <pybind11/pybind11.h>
3433

35-
// ext/common.hpp — dpctl_td_ns; mirrors every other dpnp extension
34+
// dpnp extension infrastructure
3635
#include "ext/common.hpp"
3736

38-
// dpctl tensor validation and utility headers — same set as blas/gemm.cpp
37+
// dpctl tensor validation and utility headers
3938
#include "utils/memory_overlap.hpp"
4039
#include "utils/output_validation.hpp"
4140
#include "utils/type_utils.hpp"
4241

4342
#include "gemv.hpp"
43+
#include "types_matrix.hpp"
4444

45-
// oneMKL sparse BLAS
4645
namespace mkl_sparse = oneapi::mkl::sparse;
47-
namespace py = pybind11;
46+
namespace py = pybind11;
4847
namespace type_utils = dpctl::tensor::type_utils;
4948

49+
using ext::common::init_dispatch_table;
50+
5051
namespace dpnp::extensions::sparse
5152
{
5253

5354
// ---------------------------------------------------------------------------
54-
// Type-dispatched implementation: y = alpha * op(A) * x + beta * y
55+
// Dispatch table: [value_type_id][index_type_id] -> impl function pointer
56+
// Mirrors the 2-D table pattern of blas/gemm.cpp.
57+
// ---------------------------------------------------------------------------
58+
59+
typedef sycl::event (*gemv_impl_fn_ptr_t)(
60+
sycl::queue &,
61+
oneapi::mkl::transpose,
62+
double, // alpha (always passed as double; cast inside)
63+
const char *, // row_ptr (typeless)
64+
const char *, // col_ind (typeless)
65+
const char *, // values (typeless)
66+
std::int64_t, // num_rows
67+
std::int64_t, // num_cols
68+
std::int64_t, // nnz
69+
const char *, // x (typeless)
70+
double, // beta (always passed as double; cast inside)
71+
char *, // y (typeless, writable)
72+
const std::vector<sycl::event> &);
73+
74+
static gemv_impl_fn_ptr_t
75+
gemv_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
76+
77+
78+
// ---------------------------------------------------------------------------
79+
// Typed implementation — one instantiation per (Tv, Ti) pair
5580
// ---------------------------------------------------------------------------
5681

57-
template <typename T, typename intType>
82+
template <typename Tv, typename Ti>
5883
static sycl::event
59-
sparse_gemv_impl(sycl::queue &exec_q,
60-
oneapi::mkl::transpose mkl_trans,
61-
T alpha,
62-
intType *row_ptr_ptr,
63-
intType *col_ind_ptr,
64-
T *values_ptr,
65-
std::int64_t num_rows,
66-
std::int64_t num_cols,
67-
std::int64_t nnz,
68-
T *x_ptr,
69-
T beta,
70-
T *y_ptr,
71-
const std::vector<sycl::event> &depends)
84+
gemv_impl(sycl::queue &exec_q,
85+
oneapi::mkl::transpose mkl_trans,
86+
double alpha_d,
87+
const char *row_ptr_data,
88+
const char *col_ind_data,
89+
const char *values_data,
90+
std::int64_t num_rows,
91+
std::int64_t num_cols,
92+
std::int64_t nnz,
93+
const char *x_data,
94+
double beta_d,
95+
char *y_data,
96+
const std::vector<sycl::event> &depends)
7297
{
73-
// Validate that T is supported on this device (mirrors gemm_impl pattern)
74-
type_utils::validate_type_for_device<T>(exec_q);
98+
type_utils::validate_type_for_device<Tv>(exec_q);
99+
100+
const Tv alpha = static_cast<Tv>(alpha_d);
101+
const Tv beta = static_cast<Tv>(beta_d);
102+
const Ti *row_ptr = reinterpret_cast<const Ti *>(row_ptr_data);
103+
const Ti *col_ind = reinterpret_cast<const Ti *>(col_ind_data);
104+
const Tv *values = reinterpret_cast<const Tv *>(values_data);
105+
const Tv *x = reinterpret_cast<const Tv *>(x_data);
106+
Tv *y = reinterpret_cast<Tv *>(y_data);
75107

76108
std::stringstream error_msg;
77109
bool is_exception_caught = false;
@@ -86,40 +118,35 @@ sparse_gemv_impl(sycl::queue &exec_q,
86118
exec_q, handle,
87119
num_rows, num_cols,
88120
oneapi::mkl::index_base::zero,
89-
row_ptr_ptr, col_ind_ptr, values_ptr,
121+
const_cast<Ti *>(row_ptr),
122+
const_cast<Ti *>(col_ind),
123+
const_cast<Tv *>(values),
90124
depends);
91125

92-
// optimize_gemv performs internal analysis — amortises over repeated SpMV
93126
auto ev_opt = mkl_sparse::optimize_gemv(
94127
exec_q, mkl_trans, handle, {ev_set});
95128

96129
gemv_ev = mkl_sparse::gemv(
97130
exec_q, mkl_trans,
98131
alpha, handle,
99-
x_ptr, beta, y_ptr,
132+
x, beta, y,
100133
{ev_opt});
101134

102-
// async release — waits for gemv_ev internally
103135
mkl_sparse::release_matrix_handle(exec_q, &handle, {gemv_ev});
104136

105137
} catch (oneapi::mkl::exception const &e) {
106-
error_msg
107-
<< "Unexpected MKL exception caught during sparse_gemv() call:"
108-
"\nreason: "
109-
<< e.what();
138+
error_msg << "Unexpected MKL exception caught during sparse_gemv() "
139+
"call:\nreason: " << e.what();
110140
is_exception_caught = true;
111141
} catch (sycl::exception const &e) {
112-
error_msg
113-
<< "Unexpected SYCL exception caught during sparse_gemv() call:\n"
114-
<< e.what();
142+
error_msg << "Unexpected SYCL exception caught during sparse_gemv() "
143+
"call:\n" << e.what();
115144
is_exception_caught = true;
116145
}
117146

118147
if (is_exception_caught) {
119-
// Best-effort handle cleanup before re-raising
120-
if (handle != nullptr) {
148+
if (handle != nullptr)
121149
mkl_sparse::release_matrix_handle(exec_q, &handle, {});
122-
}
123150
throw std::runtime_error(error_msg.str());
124151
}
125152

@@ -128,56 +155,51 @@ sparse_gemv_impl(sycl::queue &exec_q,
128155

129156

130157
// ---------------------------------------------------------------------------
131-
// Python-facing function
158+
// Python-facing entry point
132159
// ---------------------------------------------------------------------------
133160

134161
std::pair<sycl::event, sycl::event>
135-
sparse_gemv(sycl::queue &exec_q,
136-
const int trans,
137-
const double alpha,
138-
const dpctl::tensor::usm_ndarray &row_ptr,
139-
const dpctl::tensor::usm_ndarray &col_ind,
140-
const dpctl::tensor::usm_ndarray &values,
141-
const dpctl::tensor::usm_ndarray &x,
142-
const double beta,
143-
const dpctl::tensor::usm_ndarray &y,
144-
const std::int64_t num_rows,
145-
const std::int64_t num_cols,
146-
const std::int64_t nnz,
147-
const std::vector<sycl::event> &depends)
162+
sparse_gemv(sycl::queue &exec_q,
163+
const int trans,
164+
const double alpha,
165+
const dpctl::tensor::usm_ndarray &row_ptr,
166+
const dpctl::tensor::usm_ndarray &col_ind,
167+
const dpctl::tensor::usm_ndarray &values,
168+
const dpctl::tensor::usm_ndarray &x,
169+
const double beta,
170+
const dpctl::tensor::usm_ndarray &y,
171+
const std::int64_t num_rows,
172+
const std::int64_t num_cols,
173+
const std::int64_t nnz,
174+
const std::vector<sycl::event> &depends)
148175
{
149-
// --- 1. ndim checks ---
150-
if (x.get_ndim() != 1) {
176+
// 1. ndim checks
177+
if (x.get_ndim() != 1)
151178
throw py::value_error("sparse_gemv: x must be a 1-D array.");
152-
}
153-
if (y.get_ndim() != 1) {
179+
if (y.get_ndim() != 1)
154180
throw py::value_error("sparse_gemv: y must be a 1-D array.");
155-
}
156181

157-
// --- 2. Queue compatibility (all USM arrays must share the same queue) ---
182+
// 2. Queue compatibility
158183
if (!dpctl::utils::queues_are_compatible(
159-
exec_q,
160-
{row_ptr.get_queue(), col_ind.get_queue(),
161-
values.get_queue(), x.get_queue(), y.get_queue()})) {
184+
exec_q, {row_ptr.get_queue(), col_ind.get_queue(),
185+
values.get_queue(), x.get_queue(), y.get_queue()}))
162186
throw py::value_error(
163187
"sparse_gemv: USM allocations are not compatible with the "
164188
"execution queue.");
165-
}
166189

167-
// --- 3. Memory overlap: x and y must not alias ---
190+
// 3. Memory overlap: x and y must not alias
168191
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
169-
if (overlap(x, y)) {
192+
if (overlap(x, y))
170193
throw py::value_error(
171194
"sparse_gemv: input array x and output array y are overlapping "
172195
"segments of memory.");
173-
}
174196

175-
// --- 4. Output writability and size ---
197+
// 4. Output writability and size
176198
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(y);
177199
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
178200
y, static_cast<std::size_t>(num_rows));
179201

180-
// --- 5. Map trans integer to oneMKL enum ---
202+
// 5. Map trans integer to oneMKL enum
181203
oneapi::mkl::transpose mkl_trans;
182204
switch (trans) {
183205
case 0: mkl_trans = oneapi::mkl::transpose::nontrans; break;
@@ -188,74 +210,24 @@ sparse_gemv(sycl::queue &exec_q,
188210
"sparse_gemv: trans must be 0 (N), 1 (T), or 2 (C)");
189211
}
190212

191-
// --- 6. Type dispatch (value type x index type) ---
192-
// oneMKL sparse BLAS supports float32 and float64 (no complex yet)
193-
int val_typenum = values.get_typenum();
194-
int idx_typenum = row_ptr.get_typenum();
213+
// 6. Dispatch table lookup — replaces the explicit if/else chain
214+
auto array_types = dpctl_td_ns::usm_ndarray_types();
215+
const int val_id = array_types.typenum_to_lookup_id(values.get_typenum());
216+
const int idx_id = array_types.typenum_to_lookup_id(row_ptr.get_typenum());
195217

196-
sycl::event gemv_ev;
197-
198-
if (val_typenum == UAR_FLOAT) {
199-
auto alpha_f = static_cast<float>(alpha);
200-
auto beta_f = static_cast<float>(beta);
218+
gemv_impl_fn_ptr_t gemv_fn = gemv_dispatch_table[val_id][idx_id];
219+
if (gemv_fn == nullptr)
220+
throw py::value_error(
221+
"sparse_gemv: no implementation for the given value/index dtype "
222+
"combination. Supported: float32/float64 with int32/int64 indices.");
201223

202-
if (idx_typenum == UAR_INT32) {
203-
gemv_ev = sparse_gemv_impl<float, std::int32_t>(
204-
exec_q, mkl_trans, alpha_f,
205-
row_ptr.get_data<std::int32_t>(),
206-
col_ind.get_data<std::int32_t>(),
207-
values.get_data<float>(),
208-
num_rows, num_cols, nnz,
209-
x.get_data<float>(), beta_f,
210-
y.get_data<float>(), depends);
211-
}
212-
else if (idx_typenum == UAR_INT64) {
213-
gemv_ev = sparse_gemv_impl<float, std::int64_t>(
214-
exec_q, mkl_trans, alpha_f,
215-
row_ptr.get_data<std::int64_t>(),
216-
col_ind.get_data<std::int64_t>(),
217-
values.get_data<float>(),
218-
num_rows, num_cols, nnz,
219-
x.get_data<float>(), beta_f,
220-
y.get_data<float>(), depends);
221-
}
222-
else {
223-
throw std::runtime_error(
224-
"sparse_gemv: index dtype must be int32 or int64");
225-
}
226-
}
227-
else if (val_typenum == UAR_DOUBLE) {
228-
if (idx_typenum == UAR_INT32) {
229-
gemv_ev = sparse_gemv_impl<double, std::int32_t>(
230-
exec_q, mkl_trans, alpha,
231-
row_ptr.get_data<std::int32_t>(),
232-
col_ind.get_data<std::int32_t>(),
233-
values.get_data<double>(),
234-
num_rows, num_cols, nnz,
235-
x.get_data<double>(), beta,
236-
y.get_data<double>(), depends);
237-
}
238-
else if (idx_typenum == UAR_INT64) {
239-
gemv_ev = sparse_gemv_impl<double, std::int64_t>(
240-
exec_q, mkl_trans, alpha,
241-
row_ptr.get_data<std::int64_t>(),
242-
col_ind.get_data<std::int64_t>(),
243-
values.get_data<double>(),
224+
sycl::event gemv_ev =
225+
gemv_fn(exec_q, mkl_trans, alpha,
226+
row_ptr.get_data(), col_ind.get_data(), values.get_data(),
244227
num_rows, num_cols, nnz,
245-
x.get_data<double>(), beta,
246-
y.get_data<double>(), depends);
247-
}
248-
else {
249-
throw std::runtime_error(
250-
"sparse_gemv: index dtype must be int32 or int64");
251-
}
252-
}
253-
else {
254-
throw std::runtime_error(
255-
"sparse_gemv: value dtype must be float32 or float64");
256-
}
228+
x.get_data(), beta, y.get_data(),
229+
depends);
257230

258-
// Keep all input/output USM arrays alive until gemv_ev completes
259231
sycl::event args_ev = dpctl::utils::keep_args_alive(
260232
exec_q, {row_ptr, col_ind, values, x, y}, {gemv_ev});
261233

@@ -264,15 +236,26 @@ sparse_gemv(sycl::queue &exec_q,
264236

265237

266238
// ---------------------------------------------------------------------------
267-
// Dispatch vector init (placeholder — matches blas convention)
239+
// Factory and dispatch table initialisation
240+
// Mirrors blas/gemm.cpp: GemmContigFactory -> GemvContigFactory
268241
// ---------------------------------------------------------------------------
269242

270-
void init_sparse_gemv_dispatch_vector(void)
243+
template <typename fnT, typename Tv, typename Ti>
244+
struct GemvContigFactory
245+
{
246+
fnT get()
247+
{
248+
if constexpr (types::SparseGemvTypePairSupportFactory<Tv, Ti>::is_defined)
249+
return gemv_impl<Tv, Ti>;
250+
else
251+
return nullptr;
252+
}
253+
};
254+
255+
void init_sparse_gemv_dispatch_table(void)
271256
{
272-
// No dispatch table needed for sparse_gemv since we do explicit
273-
// type switching in the function body (oneMKL sparse API uses
274-
// opaque handles, not templated dispatch tables).
275-
// This function exists to match the dpnp extension convention.
257+
init_dispatch_table<gemv_impl_fn_ptr_t, GemvContigFactory>(
258+
gemv_dispatch_table);
276259
}
277260

278261
} // namespace dpnp::extensions::sparse

0 commit comments

Comments
 (0)