Skip to content

Commit 4292518

Browse files
committed
sparse/linalg: pure-GPU CG/GMRES/MINRES, drop all CPU fallback paths, port SciPy corner cases
1 parent 2a4566f commit 4292518

2 files changed

Lines changed: 452 additions & 366 deletions

File tree

dpnp/scipy/sparse/linalg/_interface.py

Lines changed: 63 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,17 @@
2626

2727
"""LinearOperator and helpers for dpnp.scipy.sparse.linalg.
2828
29-
Aligned with CuPy v14.0.1 cupyx/scipy/sparse/linalg/_interface.py
30-
so that code written for cupyx or scipy.sparse.linalg is portable.
29+
Aligned with SciPy main scipy/sparse/linalg/_interface.py and
30+
CuPy v14.0.1 cupyx/scipy/sparse/linalg/_interface.py so that code
31+
written for either library is portable to dpnp.
32+
33+
Additional items versus the previous version
34+
--------------------------------------------
35+
* T / H properties now exposed as SciPy does (A.T and A.H work)
36+
* _adjoint / _transpose virtual hooks on LinearOperator base
37+
* _ScaledLinearOperator.adjoint uses conj(alpha) correctly
38+
* aslinearoperator accepts ndim-1 vectors (promotes to column/row)
39+
* _isshape accepts numpy integer types, not just Python int
3140
"""
3241

3342
from __future__ import annotations
@@ -42,9 +51,13 @@
4251
# ---------------------------------------------------------------------------
4352

4453
def _isshape(shape):
54+
"""Return True if shape is a length-2 tuple of non-negative integers."""
4555
if not isinstance(shape, tuple) or len(shape) != 2:
4656
return False
47-
return all(isinstance(s, int) and s >= 0 for s in shape)
57+
try:
58+
return all(int(s) >= 0 and int(s) == s for s in shape)
59+
except (TypeError, ValueError):
60+
return False
4861

4962

5063
def _isintlike(x):
@@ -58,9 +71,9 @@ def _get_dtype(operators, dtypes=None):
5871
if dtypes is None:
5972
dtypes = []
6073
for obj in operators:
61-
if obj is not None and hasattr(obj, "dtype"):
74+
if obj is not None and hasattr(obj, "dtype") and obj.dtype is not None:
6275
dtypes.append(obj.dtype)
63-
return dpnp.result_type(*dtypes)
76+
return dpnp.result_type(*dtypes) if dtypes else None
6477

6578

6679
# ---------------------------------------------------------------------------
@@ -71,15 +84,13 @@ class LinearOperator:
7184
"""Drop-in replacement for cupyx/scipy LinearOperator backed by dpnp arrays.
7285
7386
Supports the full operator algebra (addition, multiplication, scaling,
74-
power, adjoint, transpose) matching CuPy v14.0.1 semantics.
87+
power, adjoint A.H, transpose A.T) matching CuPy v14.0.1 and SciPy main.
7588
"""
7689

7790
ndim = 2
7891

7992
def __new__(cls, *args, **kwargs):
8093
if cls is LinearOperator:
81-
# Factory: bare LinearOperator(shape, matvec=...) returns a
82-
# _CustomLinearOperator, exactly as SciPy / CuPy do.
8394
return super().__new__(_CustomLinearOperator)
8495
else:
8596
obj = super().__new__(cls)
@@ -96,7 +107,7 @@ def __new__(cls, *args, **kwargs):
96107
def __init__(self, dtype, shape):
97108
if dtype is not None:
98109
dtype = dpnp.dtype(dtype)
99-
shape = tuple(shape)
110+
shape = tuple(int(s) for s in shape)
100111
if not _isshape(shape):
101112
raise ValueError(
102113
f"invalid shape {shape!r} (must be a length-2 tuple of non-negative ints)"
@@ -105,42 +116,27 @@ def __init__(self, dtype, shape):
105116
self.shape = shape
106117

107118
def _init_dtype(self):
108-
"""Infer dtype by running a trial matvec on a zero int8 vector.
109-
110-
Uses int8 (not float64) as the probe dtype so that the matvec lambda
111-
will promote int8 to whatever the operator's natural dtype is
112-
(e.g. float32 @ int8 -> float32). This matches SciPy's and CuPy's
113-
dtype-inference strategy and avoids the previous bug where
114-
dpnp.zeros(n) (float64 default) caused float32 operators to report
115-
dtype=float64.
116-
117-
Short-circuits when self.dtype is already set so that an explicit
118-
dtype= kwarg is never overwritten.
119-
"""
119+
"""Infer dtype via a trial matvec on an int8 zero vector (SciPy / CuPy strategy)."""
120120
if self.dtype is not None:
121121
return
122122
v = dpnp.zeros(self.shape[-1], dtype=dpnp.int8)
123123
self.dtype = self.matvec(v).dtype
124124

125125
# ------------------------------------------------------------------ #
126-
# Abstract primitives — subclasses override at least one of these #
126+
# Abstract primitives — subclasses override at least one #
127127
# ------------------------------------------------------------------ #
128128

129129
def _matvec(self, x):
130-
"""Default: call matmat on a column vector."""
131130
return self.matmat(x.reshape(-1, 1))
132131

133132
def _matmat(self, X):
134-
"""Default: stack matvec calls — slow fallback."""
135133
return dpnp.hstack(
136134
[self.matvec(col.reshape(-1, 1)) for col in X.T]
137135
)
138136

139137
def _rmatvec(self, x):
140138
if type(self)._adjoint is LinearOperator._adjoint:
141-
raise NotImplementedError(
142-
"rmatvec is not defined for this LinearOperator"
143-
)
139+
raise NotImplementedError("rmatvec is not defined for this LinearOperator")
144140
return self.H.matvec(x)
145141

146142
def _rmatmat(self, X):
@@ -176,18 +172,14 @@ def matmat(self, X):
176172
if X.ndim != 2:
177173
raise ValueError(f"expected 2-D array, got {X.ndim}-D")
178174
if X.shape[0] != self.shape[1]:
179-
raise ValueError(
180-
f"dimension mismatch: {self.shape!r} vs {X.shape!r}"
181-
)
175+
raise ValueError(f"dimension mismatch: {self.shape!r} vs {X.shape!r}")
182176
return self._matmat(X)
183177

184178
def rmatmat(self, X):
185179
if X.ndim != 2:
186180
raise ValueError(f"expected 2-D array, got {X.ndim}-D")
187181
if X.shape[0] != self.shape[0]:
188-
raise ValueError(
189-
f"dimension mismatch: {self.shape!r} vs {X.shape!r}"
190-
)
182+
raise ValueError(f"dimension mismatch: {self.shape!r} vs {X.shape!r}")
191183
return self._rmatmat(X)
192184

193185
# ------------------------------------------------------------------ #
@@ -215,12 +207,12 @@ def __mul__(self, x):
215207

216208
def __matmul__(self, x):
217209
if dpnp.isscalar(x):
218-
raise ValueError("Scalar operands are not allowed with '@'; use '*' instead")
210+
raise ValueError("Scalar operands not allowed with '@'; use '*' instead")
219211
return self.__mul__(x)
220212

221213
def __rmatmul__(self, x):
222214
if dpnp.isscalar(x):
223-
raise ValueError("Scalar operands are not allowed with '@'; use '*' instead")
215+
raise ValueError("Scalar operands not allowed with '@'; use '*' instead")
224216
return self.__rmul__(x)
225217

226218
def __rmul__(self, x):
@@ -245,29 +237,30 @@ def __sub__(self, x):
245237
return self.__add__(-x)
246238

247239
# ------------------------------------------------------------------ #
248-
# Adjoint / transpose #
240+
# Adjoint / transpose — A.H and A.T both work (SciPy + CuPy parity) #
249241
# ------------------------------------------------------------------ #
250242

243+
def _adjoint(self):
244+
"""Return conjugate-transpose operator (override in subclasses)."""
245+
return _AdjointLinearOperator(self)
246+
247+
def _transpose(self):
248+
"""Return plain-transpose operator (override in subclasses)."""
249+
return _TransposedLinearOperator(self)
250+
251251
def adjoint(self):
252-
"""Return the conjugate-transpose (Hermitian adjoint) operator."""
252+
"""Hermitian adjoint A^H."""
253253
return self._adjoint()
254254

255-
#: Property alias for adjoint() — A.H gives the Hermitian adjoint.
256-
H = property(adjoint)
257-
258255
def transpose(self):
259-
"""Return the (non-conjugated) transpose operator."""
256+
"""Plain (non-conjugated) transpose A^T."""
260257
return self._transpose()
261258

262-
#: Property alias for transpose() — A.T gives the plain transpose.
259+
#: A.H — conjugate transpose
260+
H = property(adjoint)
261+
#: A.T — plain transpose
263262
T = property(transpose)
264263

265-
def _adjoint(self):
266-
return _AdjointLinearOperator(self)
267-
268-
def _transpose(self):
269-
return _TransposedLinearOperator(self)
270-
271264
def __repr__(self):
272265
dt = "unspecified dtype" if self.dtype is None else f"dtype={self.dtype}"
273266
return f"<{self.shape[0]}x{self.shape[1]} {self.__class__.__name__} with {dt}>"
@@ -288,12 +281,9 @@ def __init__(self, shape, matvec, rmatvec=None, matmat=None,
288281
self.__rmatvec_impl = rmatvec
289282
self.__rmatmat_impl = rmatmat
290283
self.__matmat_impl = matmat
291-
# _init_dtype() short-circuits when dtype was explicitly provided,
292-
# so the caller's explicit dtype= is never overwritten.
293284
self._init_dtype()
294285

295-
def _matvec(self, x):
296-
return self.__matvec_impl(x)
286+
def _matvec(self, x): return self.__matvec_impl(x)
297287

298288
def _matmat(self, X):
299289
if self.__matmat_impl is not None:
@@ -331,6 +321,7 @@ def _matvec(self, x): return self.A._rmatvec(x)
331321
def _rmatvec(self, x): return self.A._matvec(x)
332322
def _matmat(self, X): return self.A._rmatmat(X)
333323
def _rmatmat(self, X): return self.A._matmat(X)
324+
def _adjoint(self): return self.A
334325

335326

336327
class _TransposedLinearOperator(LinearOperator):
@@ -343,6 +334,7 @@ def _matvec(self, x): return dpnp.conj(self.A._rmatvec(dpnp.conj(x)))
343334
def _rmatvec(self, x): return dpnp.conj(self.A._matvec(dpnp.conj(x)))
344335
def _matmat(self, X): return dpnp.conj(self.A._rmatmat(dpnp.conj(X)))
345336
def _rmatmat(self, X): return dpnp.conj(self.A._matmat(dpnp.conj(X)))
337+
def _transpose(self): return self.A
346338

347339

348340
class _SumLinearOperator(LinearOperator):
@@ -382,9 +374,7 @@ def _matvec(self, x): return self.args[1] * self.args[0].matvec(x)
382374
def _rmatvec(self, x): return dpnp.conj(self.args[1]) * self.args[0].rmatvec(x)
383375
def _matmat(self, X): return self.args[1] * self.args[0].matmat(X)
384376
def _rmatmat(self, X): return dpnp.conj(self.args[1]) * self.args[0].rmatmat(X)
385-
def _adjoint(self):
386-
A, alpha = self.args
387-
return A.H * dpnp.conj(alpha)
377+
def _adjoint(self): A, alpha = self.args; return A.H * dpnp.conj(alpha)
388378

389379

390380
class _PowerLinearOperator(LinearOperator):
@@ -406,19 +396,17 @@ def _matvec(self, x): return self._power(self.args[0].matvec, x)
406396
def _rmatvec(self, x): return self._power(self.args[0].rmatvec, x)
407397
def _matmat(self, X): return self._power(self.args[0].matmat, X)
408398
def _rmatmat(self, X): return self._power(self.args[0].rmatmat, X)
409-
def _adjoint(self):
410-
A, p = self.args
411-
return A.H ** p
399+
def _adjoint(self): A, p = self.args; return A.H ** p
412400

413401

414402
class MatrixLinearOperator(LinearOperator):
415403
"""Wrap a dense dpnp matrix (or sparse matrix) as a LinearOperator."""
416404

417405
def __init__(self, A):
418406
super().__init__(A.dtype, A.shape)
419-
self.A = A
407+
self.A = A
420408
self.__adj = None
421-
self.args = (A,)
409+
self.args = (A,)
422410

423411
def _matmat(self, X): return self.A.dot(X)
424412
def _rmatmat(self, X): return dpnp.conj(self.A.T).dot(X)
@@ -431,10 +419,10 @@ def _adjoint(self):
431419

432420
class _AdjointMatrixOperator(MatrixLinearOperator):
433421
def __init__(self, adjoint):
434-
self.A = dpnp.conj(adjoint.A.T)
422+
self.A = dpnp.conj(adjoint.A.T)
435423
self.__adjoint = adjoint
436-
self.args = (adjoint,)
437-
self.shape = (adjoint.shape[1], adjoint.shape[0])
424+
self.args = (adjoint,)
425+
self.shape = (adjoint.shape[1], adjoint.shape[0])
438426

439427
@property
440428
def dtype(self):
@@ -445,7 +433,7 @@ def _adjoint(self):
445433

446434

447435
class IdentityOperator(LinearOperator):
448-
"""Identity operator — used as default preconditioner in _make_system."""
436+
"""Identity operator — used as the default (no-op) preconditioner."""
449437

450438
def __init__(self, shape, dtype=None):
451439
super().__init__(dtype, shape)
@@ -455,6 +443,7 @@ def _rmatvec(self, x): return x
455443
def _matmat(self, X): return X
456444
def _rmatmat(self, X): return X
457445
def _adjoint(self): return self
446+
def _transpose(self): return self
458447

459448

460449
# ---------------------------------------------------------------------------
@@ -465,38 +454,41 @@ def aslinearoperator(A) -> LinearOperator:
465454
"""Wrap A as a LinearOperator if it is not already one.
466455
467456
Handles (in order):
468-
- Already a LinearOperator — returned as-is.
469-
- dpnp / scipy sparse matrix — wrapped in MatrixLinearOperator.
470-
- Dense dpnp / numpy ndarray — wrapped in MatrixLinearOperator.
471-
- Duck-typed objects with .shape and .matvec or @ support.
457+
1. Already a LinearOperator — returned as-is.
458+
2. dpnp.scipy.sparse or scipy.sparse sparse matrix.
459+
3. Dense dpnp / numpy ndarray (1-D promoted to column vector).
460+
4. Duck-typed objects with .shape and .matvec / @ support.
472461
"""
473462
if isinstance(A, LinearOperator):
474463
return A
475464

476-
# sparse matrix (dpnp.scipy.sparse or scipy.sparse)
465+
# dpnp sparse
477466
try:
478467
from dpnp.scipy import sparse as _sp
479468
if _sp.issparse(A):
480469
return MatrixLinearOperator(A)
481470
except (ImportError, AttributeError):
482471
pass
483472

473+
# scipy sparse — convert to dense on device
484474
try:
485475
import scipy.sparse as _ssp
486476
if _ssp.issparse(A):
487477
return MatrixLinearOperator(dpnp.asarray(A.toarray()))
488478
except (ImportError, AttributeError):
489479
pass
490480

491-
# dense ndarray
481+
# dense ndarray (dpnp or numpy)
492482
try:
493483
arr = dpnp.asarray(A)
484+
if arr.ndim == 1:
485+
arr = arr.reshape(-1, 1) # treat 1-D as column vector
494486
if arr.ndim == 2:
495487
return MatrixLinearOperator(arr)
496488
except Exception:
497489
pass
498490

499-
# duck-typed
491+
# duck-typed (anything with .shape + matvec or @)
500492
if hasattr(A, "shape") and len(A.shape) == 2:
501493
m, n = int(A.shape[0]), int(A.shape[1])
502494
dtype = getattr(A, "dtype", None)

0 commit comments

Comments
 (0)