Skip to content

Commit 14cb5c4

Browse files
committed
sparse: capture exec_q from CSR data at closure construction
1 parent a7ddc1c commit 14cb5c4

1 file changed

Lines changed: 10 additions & 6 deletions

File tree

dpnp/scipy/sparse/linalg/_iterative.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
# ---------------------------------------------------------------------------
5858
# Try to import the compiled _sparse_impl extension (oneMKL sparse::gemv).
5959
# If the extension has not been built yet the pure-Python / A.dot fallback
60-
# is used transparently — no import error is raised at module load time.
60+
# is used transparently - no import error is raised at module load time.
6161
# ---------------------------------------------------------------------------
6262
try:
6363
from dpnp.backend.extensions.sparse import _sparse_impl as _si
@@ -128,17 +128,21 @@ def _make_fast_matvec(A):
128128
if _HAS_SPARSE_IMPL:
129129
# --- fast path: oneMKL sparse::gemv via pybind11 ---
130130
# Pull CSR arrays once; they are already in USM device memory.
131-
indptr = A.indptr # row_ptr — int32 or int64 USM array
132-
indices = A.indices # col_ind — int32 or int64 USM array
133-
data = A.data # values — float32 or float64 USM array
131+
indptr = A.indptr # row_ptr - int32 or int64 USM array
132+
indices = A.indices # col_ind - int32 or int64 USM array
133+
data = A.data # values - float32 or float64 USM array
134134
nrows = int(A.shape[0])
135135
ncols = int(A.shape[1])
136136
nnz = int(data.shape[0])
137+
# Capture the SYCL queue from the matrix data array at closure-creation
138+
# time, not from x at call time. This avoids queue mismatch when x is
139+
# constructed on a different (e.g. default CPU) queue.
140+
exec_q = data.sycl_queue
137141

138142
def _csr_spmv(x: _dpnp.ndarray) -> _dpnp.ndarray:
139-
y = _dpnp.zeros(nrows, dtype=data.dtype, sycl_queue=x.sycl_queue)
143+
y = _dpnp.zeros(nrows, dtype=data.dtype, sycl_queue=exec_q)
140144
_, ev = _si._sparse_gemv(
141-
x.sycl_queue,
145+
exec_q,
142146
0, # trans = NoTrans
143147
1.0, # alpha
144148
indptr, indices, data,

0 commit comments

Comments
 (0)