2626
2727"""LinearOperator and helpers for dpnp.scipy.sparse.linalg.
2828
29- Aligned with CuPy v14.0.1 cupyx/scipy/sparse/linalg/_interface.py
30- so that code written for cupyx or scipy.sparse.linalg is portable.
29+ Aligned with SciPy main scipy/sparse/linalg/_interface.py and
30+ CuPy v14.0.1 cupyx/scipy/sparse/linalg/_interface.py so that code
31+ written for either library is portable to dpnp.
32+
33+ Additional items versus the previous version
34+ --------------------------------------------
35+ * T / H properties now exposed as SciPy does (A.T and A.H work)
36+ * _adjoint / _transpose virtual hooks on LinearOperator base
37+ * _ScaledLinearOperator.adjoint uses conj(alpha) correctly
38+ * aslinearoperator accepts ndim-1 vectors (promotes to a column vector)
39+ * _isshape accepts numpy integer types, not just Python int
3140"""
3241
3342from __future__ import annotations
4251# ---------------------------------------------------------------------------
4352
4453def _isshape (shape ):
54+ """Return True if shape is a length-2 tuple of non-negative integer-like values."""
4555 if not isinstance (shape , tuple ) or len (shape ) != 2 :
4656 return False
47- return all (isinstance (s , int ) and s >= 0 for s in shape )
57+ try :
58+ return all (int (s ) >= 0 and int (s ) == s for s in shape )
59+ except (TypeError , ValueError ):
60+ return False
4861
4962
5063def _isintlike (x ):
@@ -58,9 +71,9 @@ def _get_dtype(operators, dtypes=None):
5871 if dtypes is None :
5972 dtypes = []
6073 for obj in operators :
61- if obj is not None and hasattr (obj , "dtype" ):
74+ if obj is not None and hasattr (obj , "dtype" ) and obj . dtype is not None :
6275 dtypes .append (obj .dtype )
63- return dpnp .result_type (* dtypes )
76+ return dpnp .result_type (* dtypes ) if dtypes else None
6477
6578
6679# ---------------------------------------------------------------------------
@@ -71,15 +84,13 @@ class LinearOperator:
7184 """Drop-in replacement for cupyx/scipy LinearOperator backed by dpnp arrays.
7285
7386 Supports the full operator algebra (addition, multiplication, scaling,
74- power, adjoint, transpose) matching CuPy v14.0.1 semantics .
87+ power, adjoint A.H , transpose A.T ) matching CuPy v14.0.1 and SciPy main .
7588 """
7689
7790 ndim = 2
7891
7992 def __new__ (cls , * args , ** kwargs ):
8093 if cls is LinearOperator :
81- # Factory: bare LinearOperator(shape, matvec=...) returns a
82- # _CustomLinearOperator, exactly as SciPy / CuPy do.
8394 return super ().__new__ (_CustomLinearOperator )
8495 else :
8596 obj = super ().__new__ (cls )
@@ -96,7 +107,7 @@ def __new__(cls, *args, **kwargs):
96107 def __init__ (self , dtype , shape ):
97108 if dtype is not None :
98109 dtype = dpnp .dtype (dtype )
99- shape = tuple (shape )
110+ shape = tuple (int ( s ) for s in shape )
100111 if not _isshape (shape ):
101112 raise ValueError (
102113 f"invalid shape { shape !r} (must be a length-2 tuple of non-negative ints)"
@@ -105,42 +116,27 @@ def __init__(self, dtype, shape):
105116 self .shape = shape
106117
107118 def _init_dtype (self ):
108- """Infer dtype by running a trial matvec on a zero int8 vector.
109-
110- Uses int8 (not float64) as the probe dtype so that the matvec lambda
111- will promote int8 to whatever the operator's natural dtype is
112- (e.g. float32 @ int8 -> float32). This matches SciPy's and CuPy's
113- dtype-inference strategy and avoids the previous bug where
114- dpnp.zeros(n) (float64 default) caused float32 operators to report
115- dtype=float64.
116-
117- Short-circuits when self.dtype is already set so that an explicit
118- dtype= kwarg is never overwritten.
119- """
119+ """Infer dtype via a trial matvec on an int8 zero vector (SciPy / CuPy strategy)."""
120120 if self .dtype is not None :
121121 return
122122 v = dpnp .zeros (self .shape [- 1 ], dtype = dpnp .int8 )
123123 self .dtype = self .matvec (v ).dtype
124124
125125 # ------------------------------------------------------------------ #
126- # Abstract primitives — subclasses override at least one of these #
126+ # Abstract primitives — subclasses override at least one #
127127 # ------------------------------------------------------------------ #
128128
129129 def _matvec (self , x ):
130- """Default: call matmat on a column vector."""
131130 return self .matmat (x .reshape (- 1 , 1 ))
132131
133132 def _matmat (self , X ):
134- """Default: stack matvec calls — slow fallback."""
135133 return dpnp .hstack (
136134 [self .matvec (col .reshape (- 1 , 1 )) for col in X .T ]
137135 )
138136
139137 def _rmatvec (self , x ):
140138 if type (self )._adjoint is LinearOperator ._adjoint :
141- raise NotImplementedError (
142- "rmatvec is not defined for this LinearOperator"
143- )
139+ raise NotImplementedError ("rmatvec is not defined for this LinearOperator" )
144140 return self .H .matvec (x )
145141
146142 def _rmatmat (self , X ):
@@ -176,18 +172,14 @@ def matmat(self, X):
176172 if X .ndim != 2 :
177173 raise ValueError (f"expected 2-D array, got { X .ndim } -D" )
178174 if X .shape [0 ] != self .shape [1 ]:
179- raise ValueError (
180- f"dimension mismatch: { self .shape !r} vs { X .shape !r} "
181- )
175+ raise ValueError (f"dimension mismatch: { self .shape !r} vs { X .shape !r} " )
182176 return self ._matmat (X )
183177
184178 def rmatmat (self , X ):
185179 if X .ndim != 2 :
186180 raise ValueError (f"expected 2-D array, got { X .ndim } -D" )
187181 if X .shape [0 ] != self .shape [0 ]:
188- raise ValueError (
189- f"dimension mismatch: { self .shape !r} vs { X .shape !r} "
190- )
182+ raise ValueError (f"dimension mismatch: { self .shape !r} vs { X .shape !r} " )
191183 return self ._rmatmat (X )
192184
193185 # ------------------------------------------------------------------ #
@@ -215,12 +207,12 @@ def __mul__(self, x):
215207
216208 def __matmul__ (self , x ):
217209 if dpnp .isscalar (x ):
218- raise ValueError ("Scalar operands are not allowed with '@'; use '*' instead" )
210+ raise ValueError ("Scalar operands not allowed with '@'; use '*' instead" )
219211 return self .__mul__ (x )
220212
221213 def __rmatmul__ (self , x ):
222214 if dpnp .isscalar (x ):
223- raise ValueError ("Scalar operands are not allowed with '@'; use '*' instead" )
215+ raise ValueError ("Scalar operands not allowed with '@'; use '*' instead" )
224216 return self .__rmul__ (x )
225217
226218 def __rmul__ (self , x ):
@@ -245,29 +237,30 @@ def __sub__(self, x):
245237 return self .__add__ (- x )
246238
247239 # ------------------------------------------------------------------ #
248- # Adjoint / transpose #
240+ # Adjoint / transpose — A.H and A.T both work (SciPy + CuPy parity) #
249241 # ------------------------------------------------------------------ #
250242
243+ def _adjoint (self ):
244+ """Return conjugate-transpose operator (override in subclasses)."""
245+ return _AdjointLinearOperator (self )
246+
247+ def _transpose (self ):
248+ """Return plain-transpose operator (override in subclasses)."""
249+ return _TransposedLinearOperator (self )
250+
251251 def adjoint (self ):
252- """Return the conjugate-transpose ( Hermitian adjoint) operator ."""
252+ """Hermitian adjoint A^H ."""
253253 return self ._adjoint ()
254254
255- #: Property alias for adjoint() — A.H gives the Hermitian adjoint.
256- H = property (adjoint )
257-
258255 def transpose (self ):
259- """Return the (non-conjugated) transpose operator ."""
256+ """Plain (non-conjugated) transpose A^T ."""
260257 return self ._transpose ()
261258
262- #: Property alias for transpose() — A.T gives the plain transpose.
259+ #: A.H — conjugate transpose
260+ H = property (adjoint )
261+ #: A.T — plain transpose
263262 T = property (transpose )
264263
265- def _adjoint (self ):
266- return _AdjointLinearOperator (self )
267-
268- def _transpose (self ):
269- return _TransposedLinearOperator (self )
270-
271264 def __repr__ (self ):
272265 dt = "unspecified dtype" if self .dtype is None else f"dtype={ self .dtype } "
273266 return f"<{ self .shape [0 ]} x{ self .shape [1 ]} { self .__class__ .__name__ } with { dt } >"
@@ -288,12 +281,9 @@ def __init__(self, shape, matvec, rmatvec=None, matmat=None,
288281 self .__rmatvec_impl = rmatvec
289282 self .__rmatmat_impl = rmatmat
290283 self .__matmat_impl = matmat
291- # _init_dtype() short-circuits when dtype was explicitly provided,
292- # so the caller's explicit dtype= is never overwritten.
293284 self ._init_dtype ()
294285
295- def _matvec (self , x ):
296- return self .__matvec_impl (x )
286+ def _matvec (self , x ): return self .__matvec_impl (x )
297287
298288 def _matmat (self , X ):
299289 if self .__matmat_impl is not None :
@@ -331,6 +321,7 @@ def _matvec(self, x): return self.A._rmatvec(x)
331321 def _rmatvec (self , x ): return self .A ._matvec (x )
332322 def _matmat (self , X ): return self .A ._rmatmat (X )
333323 def _rmatmat (self , X ): return self .A ._matmat (X )
324+ def _adjoint (self ): return self .A
334325
335326
336327class _TransposedLinearOperator (LinearOperator ):
@@ -343,6 +334,7 @@ def _matvec(self, x): return dpnp.conj(self.A._rmatvec(dpnp.conj(x)))
343334 def _rmatvec (self , x ): return dpnp .conj (self .A ._matvec (dpnp .conj (x )))
344335 def _matmat (self , X ): return dpnp .conj (self .A ._rmatmat (dpnp .conj (X )))
345336 def _rmatmat (self , X ): return dpnp .conj (self .A ._matmat (dpnp .conj (X )))
337+ def _transpose (self ): return self .A
346338
347339
348340class _SumLinearOperator (LinearOperator ):
@@ -382,9 +374,7 @@ def _matvec(self, x): return self.args[1] * self.args[0].matvec(x)
382374 def _rmatvec (self , x ): return dpnp .conj (self .args [1 ]) * self .args [0 ].rmatvec (x )
383375 def _matmat (self , X ): return self .args [1 ] * self .args [0 ].matmat (X )
384376 def _rmatmat (self , X ): return dpnp .conj (self .args [1 ]) * self .args [0 ].rmatmat (X )
385- def _adjoint (self ):
386- A , alpha = self .args
387- return A .H * dpnp .conj (alpha )
377+ def _adjoint (self ): A , alpha = self .args ; return A .H * dpnp .conj (alpha )
388378
389379
390380class _PowerLinearOperator (LinearOperator ):
@@ -406,19 +396,17 @@ def _matvec(self, x): return self._power(self.args[0].matvec, x)
406396 def _rmatvec (self , x ): return self ._power (self .args [0 ].rmatvec , x )
407397 def _matmat (self , X ): return self ._power (self .args [0 ].matmat , X )
408398 def _rmatmat (self , X ): return self ._power (self .args [0 ].rmatmat , X )
409- def _adjoint (self ):
410- A , p = self .args
411- return A .H ** p
399+ def _adjoint (self ): A , p = self .args ; return A .H ** p
412400
413401
414402class MatrixLinearOperator (LinearOperator ):
415403 """Wrap a dense dpnp matrix (or sparse matrix) as a LinearOperator."""
416404
417405 def __init__ (self , A ):
418406 super ().__init__ (A .dtype , A .shape )
419- self .A = A
407+ self .A = A
420408 self .__adj = None
421- self .args = (A ,)
409+ self .args = (A ,)
422410
423411 def _matmat (self , X ): return self .A .dot (X )
424412 def _rmatmat (self , X ): return dpnp .conj (self .A .T ).dot (X )
@@ -431,10 +419,10 @@ def _adjoint(self):
431419
432420class _AdjointMatrixOperator (MatrixLinearOperator ):
433421 def __init__ (self , adjoint ):
434- self .A = dpnp .conj (adjoint .A .T )
422+ self .A = dpnp .conj (adjoint .A .T )
435423 self .__adjoint = adjoint
436- self .args = (adjoint ,)
437- self .shape = (adjoint .shape [1 ], adjoint .shape [0 ])
424+ self .args = (adjoint ,)
425+ self .shape = (adjoint .shape [1 ], adjoint .shape [0 ])
438426
439427 @property
440428 def dtype (self ):
@@ -445,7 +433,7 @@ def _adjoint(self):
445433
446434
447435class IdentityOperator (LinearOperator ):
448- """Identity operator — used as default preconditioner in _make_system ."""
436+ """Identity operator — used as the default (no-op) preconditioner ."""
449437
450438 def __init__ (self , shape , dtype = None ):
451439 super ().__init__ (dtype , shape )
@@ -455,6 +443,7 @@ def _rmatvec(self, x): return x
455443 def _matmat (self , X ): return X
456444 def _rmatmat (self , X ): return X
457445 def _adjoint (self ): return self
446+ def _transpose (self ): return self
458447
459448
460449# ---------------------------------------------------------------------------
@@ -465,38 +454,41 @@ def aslinearoperator(A) -> LinearOperator:
465454 """Wrap A as a LinearOperator if it is not already one.
466455
467456 Handles (in order):
468- - Already a LinearOperator — returned as-is.
469- - dpnp / scipy sparse matrix — wrapped in MatrixLinearOperator .
470- - Dense dpnp / numpy ndarray — wrapped in MatrixLinearOperator .
471- - Duck-typed objects with .shape and .matvec or @ support.
457+ 1. Already a LinearOperator — returned as-is.
458+ 2. dpnp.scipy.sparse or scipy.sparse sparse matrix.
459+ 3. Dense dpnp / numpy ndarray (1-D promoted to column vector) .
460+ 4. Duck-typed objects with .shape and .matvec / @ support.
472461 """
473462 if isinstance (A , LinearOperator ):
474463 return A
475464
476- # sparse matrix ( dpnp.scipy.sparse or scipy. sparse)
465+ # dpnp sparse
477466 try :
478467 from dpnp .scipy import sparse as _sp
479468 if _sp .issparse (A ):
480469 return MatrixLinearOperator (A )
481470 except (ImportError , AttributeError ):
482471 pass
483472
473+ # scipy sparse — convert to dense on device
484474 try :
485475 import scipy .sparse as _ssp
486476 if _ssp .issparse (A ):
487477 return MatrixLinearOperator (dpnp .asarray (A .toarray ()))
488478 except (ImportError , AttributeError ):
489479 pass
490480
491- # dense ndarray
481+ # dense ndarray (dpnp or numpy)
492482 try :
493483 arr = dpnp .asarray (A )
484+ if arr .ndim == 1 :
485+ arr = arr .reshape (- 1 , 1 ) # treat 1-D as column vector
494486 if arr .ndim == 2 :
495487 return MatrixLinearOperator (arr )
496488 except Exception :
497489 pass
498490
499- # duck-typed
491+ # duck-typed (anything with .shape + matvec or @)
500492 if hasattr (A , "shape" ) and len (A .shape ) == 2 :
501493 m , n = int (A .shape [0 ]), int (A .shape [1 ])
502494 dtype = getattr (A , "dtype" , None )
0 commit comments