4747
4848import dpnp
4949
50-
5150# ---------------------------------------------------------------------------
5251# helpers
5352# ---------------------------------------------------------------------------
5453
54+
5555def _isshape (shape ):
5656 """Return True if shape is a length-2 tuple of non-negative integers."""
5757 if not isinstance (shape , tuple ) or len (shape ) != 2 :
@@ -77,6 +77,7 @@ def _get_dtype(operators, dtypes=None):
7777 dtypes .append (obj .dtype )
7878 return dpnp .result_type (* dtypes ) if dtypes else None
7979
80+
8081class LinearOperator :
8182 """Drop-in replacement for cupyx/scipy LinearOperator backed by dpnp arrays.
8283
@@ -91,8 +92,10 @@ def __new__(cls, *args, **kwargs):
9192 return super ().__new__ (_CustomLinearOperator )
9293 else :
9394 obj = super ().__new__ (cls )
94- if (type (obj )._matvec is LinearOperator ._matvec
95- and type (obj )._matmat is LinearOperator ._matmat ):
95+ if (
96+ type (obj )._matvec is LinearOperator ._matvec
97+ and type (obj )._matmat is LinearOperator ._matmat
98+ ):
9699 warnings .warn (
97100 "LinearOperator subclass should implement at least one of "
98101 "_matvec and _matmat." ,
@@ -125,13 +128,13 @@ def _matvec(self, x):
125128 return self .matmat (x .reshape (- 1 , 1 ))
126129
127130 def _matmat (self , X ):
128- return dpnp .hstack (
129- [self .matvec (col .reshape (- 1 , 1 )) for col in X .T ]
130- )
131+ return dpnp .hstack ([self .matvec (col .reshape (- 1 , 1 )) for col in X .T ])
131132
132133 def _rmatvec (self , x ):
133134 if type (self )._adjoint is LinearOperator ._adjoint :
134- raise NotImplementedError ("rmatvec is not defined for this LinearOperator" )
135+ raise NotImplementedError (
136+ "rmatvec is not defined for this LinearOperator"
137+ )
135138 return self .H .matvec (x )
136139
137140 def _rmatmat (self , X ):
@@ -163,14 +166,18 @@ def matmat(self, X):
163166 if X .ndim != 2 :
164167 raise ValueError (f"expected 2-D array, got { X .ndim } -D" )
165168 if X .shape [0 ] != self .shape [1 ]:
166- raise ValueError (f"dimension mismatch: { self .shape !r} vs { X .shape !r} " )
169+ raise ValueError (
170+ f"dimension mismatch: { self .shape !r} vs { X .shape !r} "
171+ )
167172 return self ._matmat (X )
168173
169174 def rmatmat (self , X ):
170175 if X .ndim != 2 :
171176 raise ValueError (f"expected 2-D array, got { X .ndim } -D" )
172177 if X .shape [0 ] != self .shape [0 ]:
173- raise ValueError (f"dimension mismatch: { self .shape !r} vs { X .shape !r} " )
178+ raise ValueError (
179+ f"dimension mismatch: { self .shape !r} vs { X .shape !r} "
180+ )
174181 return self ._rmatmat (X )
175182
176183 def dot (self , x ):
@@ -184,7 +191,9 @@ def dot(self, x):
184191 return self .matvec (x )
185192 elif x .ndim == 2 :
186193 return self .matmat (x )
187- raise ValueError (f"expected 1-D or 2-D array or LinearOperator, got { x !r} " )
194+ raise ValueError (
195+ f"expected 1-D or 2-D array or LinearOperator, got { x !r} "
196+ )
188197
189198 def __call__ (self , x ):
190199 return self * x
@@ -194,12 +203,16 @@ def __mul__(self, x):
194203
195204 def __matmul__ (self , x ):
196205 if dpnp .isscalar (x ):
197- raise ValueError ("Scalar operands not allowed with '@'; use '*' instead" )
206+ raise ValueError (
207+ "Scalar operands not allowed with '@'; use '*' instead"
208+ )
198209 return self .__mul__ (x )
199210
200211 def __rmatmul__ (self , x ):
201212 if dpnp .isscalar (x ):
202- raise ValueError ("Scalar operands not allowed with '@'; use '*' instead" )
213+ raise ValueError (
214+ "Scalar operands not allowed with '@'; use '*' instead"
215+ )
203216 return self .__rmul__ (x )
204217
205218 def __rmul__ (self , x ):
@@ -245,28 +258,33 @@ def transpose(self):
245258 T = property (transpose )
246259
247260 def __repr__ (self ):
248- dt = "unspecified dtype" if self .dtype is None else f"dtype={ self .dtype } "
261+ dt = (
262+ "unspecified dtype" if self .dtype is None else f"dtype={ self .dtype } "
263+ )
249264 return f"<{ self .shape [0 ]} x{ self .shape [1 ]} { self .__class__ .__name__ } with { dt } >"
250265
251266
252267# ---------------------------------------------------------------------------
253268# Concrete operator classes
254269# ---------------------------------------------------------------------------
255270
271+
256272class _CustomLinearOperator (LinearOperator ):
257273 """Created when the user calls LinearOperator(shape, matvec=...) directly."""
258274
259- def __init__ (self , shape , matvec , rmatvec = None , matmat = None ,
260- dtype = None , rmatmat = None ):
275+ def __init__ (
276+ self , shape , matvec , rmatvec = None , matmat = None , dtype = None , rmatmat = None
277+ ):
261278 super ().__init__ (dtype , shape )
262279 self .args = ()
263- self .__matvec_impl = matvec
280+ self .__matvec_impl = matvec
264281 self .__rmatvec_impl = rmatvec
265282 self .__rmatmat_impl = rmatmat
266- self .__matmat_impl = matmat
283+ self .__matmat_impl = matmat
267284 self ._init_dtype ()
268285
269- def _matvec (self , x ): return self .__matvec_impl (x )
286+ def _matvec (self , x ):
287+ return self .__matvec_impl (x )
270288
271289 def _matmat (self , X ):
272290 if self .__matmat_impl is not None :
@@ -275,7 +293,9 @@ def _matmat(self, X):
275293
276294 def _rmatvec (self , x ):
277295 if self .__rmatvec_impl is None :
278- raise NotImplementedError ("rmatvec is not defined for this operator" )
296+ raise NotImplementedError (
297+ "rmatvec is not defined for this operator"
298+ )
279299 return self .__rmatvec_impl (x )
280300
281301 def _rmatmat (self , X ):
@@ -300,11 +320,20 @@ def __init__(self, A):
300320 self .A = A
301321 self .args = (A ,)
302322
303- def _matvec (self , x ): return self .A ._rmatvec (x )
304- def _rmatvec (self , x ): return self .A ._matvec (x )
305- def _matmat (self , X ): return self .A ._rmatmat (X )
306- def _rmatmat (self , X ): return self .A ._matmat (X )
307- def _adjoint (self ): return self .A
323+ def _matvec (self , x ):
324+ return self .A ._rmatvec (x )
325+
326+ def _rmatvec (self , x ):
327+ return self .A ._matvec (x )
328+
329+ def _matmat (self , X ):
330+ return self .A ._rmatmat (X )
331+
332+ def _rmatmat (self , X ):
333+ return self .A ._matmat (X )
334+
335+ def _adjoint (self ):
336+ return self .A
308337
309338
310339class _TransposedLinearOperator (LinearOperator ):
@@ -313,11 +342,20 @@ def __init__(self, A):
313342 self .A = A
314343 self .args = (A ,)
315344
316- def _matvec (self , x ): return dpnp .conj (self .A ._rmatvec (dpnp .conj (x )))
317- def _rmatvec (self , x ): return dpnp .conj (self .A ._matvec (dpnp .conj (x )))
318- def _matmat (self , X ): return dpnp .conj (self .A ._rmatmat (dpnp .conj (X )))
319- def _rmatmat (self , X ): return dpnp .conj (self .A ._matmat (dpnp .conj (X )))
320- def _transpose (self ): return self .A
345+ def _matvec (self , x ):
346+ return dpnp .conj (self .A ._rmatvec (dpnp .conj (x )))
347+
348+ def _rmatvec (self , x ):
349+ return dpnp .conj (self .A ._matvec (dpnp .conj (x )))
350+
351+ def _matmat (self , X ):
352+ return dpnp .conj (self .A ._rmatmat (dpnp .conj (X )))
353+
354+ def _rmatmat (self , X ):
355+ return dpnp .conj (self .A ._matmat (dpnp .conj (X )))
356+
357+ def _transpose (self ):
358+ return self .A
321359
322360
323361class _SumLinearOperator (LinearOperator ):
@@ -327,11 +365,20 @@ def __init__(self, A, B):
327365 super ().__init__ (_get_dtype ([A , B ]), A .shape )
328366 self .args = (A , B )
329367
330- def _matvec (self , x ): return self .args [0 ].matvec (x ) + self .args [1 ].matvec (x )
331- def _rmatvec (self , x ): return self .args [0 ].rmatvec (x ) + self .args [1 ].rmatvec (x )
332- def _matmat (self , X ): return self .args [0 ].matmat (X ) + self .args [1 ].matmat (X )
333- def _rmatmat (self , X ): return self .args [0 ].rmatmat (X ) + self .args [1 ].rmatmat (X )
334- def _adjoint (self ): return self .args [0 ].H + self .args [1 ].H
368+ def _matvec (self , x ):
369+ return self .args [0 ].matvec (x ) + self .args [1 ].matvec (x )
370+
371+ def _rmatvec (self , x ):
372+ return self .args [0 ].rmatvec (x ) + self .args [1 ].rmatvec (x )
373+
374+ def _matmat (self , X ):
375+ return self .args [0 ].matmat (X ) + self .args [1 ].matmat (X )
376+
377+ def _rmatmat (self , X ):
378+ return self .args [0 ].rmatmat (X ) + self .args [1 ].rmatmat (X )
379+
380+ def _adjoint (self ):
381+ return self .args [0 ].H + self .args [1 ].H
335382
336383
337384class _ProductLinearOperator (LinearOperator ):
@@ -341,29 +388,53 @@ def __init__(self, A, B):
341388 super ().__init__ (_get_dtype ([A , B ]), (A .shape [0 ], B .shape [1 ]))
342389 self .args = (A , B )
343390
344- def _matvec (self , x ): return self .args [0 ].matvec (self .args [1 ].matvec (x ))
345- def _rmatvec (self , x ): return self .args [1 ].rmatvec (self .args [0 ].rmatvec (x ))
346- def _matmat (self , X ): return self .args [0 ].matmat (self .args [1 ].matmat (X ))
347- def _rmatmat (self , X ): return self .args [1 ].rmatmat (self .args [0 ].rmatmat (X ))
348- def _adjoint (self ): A , B = self .args ; return B .H * A .H
391+ def _matvec (self , x ):
392+ return self .args [0 ].matvec (self .args [1 ].matvec (x ))
393+
394+ def _rmatvec (self , x ):
395+ return self .args [1 ].rmatvec (self .args [0 ].rmatvec (x ))
396+
397+ def _matmat (self , X ):
398+ return self .args [0 ].matmat (self .args [1 ].matmat (X ))
399+
400+ def _rmatmat (self , X ):
401+ return self .args [1 ].rmatmat (self .args [0 ].rmatmat (X ))
402+
403+ def _adjoint (self ):
404+ A , B = self .args
405+ return B .H * A .H
406+
349407
350408class _ScaledLinearOperator (LinearOperator ):
351409 def __init__ (self , A , alpha ):
352410 super ().__init__ (_get_dtype ([A ], [type (alpha )]), A .shape )
353411 self .args = (A , alpha )
354412
355- def _matvec (self , x ): return self .args [1 ] * self .args [0 ].matvec (x )
356- def _rmatvec (self , x ): return dpnp .conj (self .args [1 ]) * self .args [0 ].rmatvec (x )
357- def _matmat (self , X ): return self .args [1 ] * self .args [0 ].matmat (X )
358- def _rmatmat (self , X ): return dpnp .conj (self .args [1 ]) * self .args [0 ].rmatmat (X )
359- def _adjoint (self ): A , alpha = self .args ; return A .H * dpnp .conj (alpha )
413+ def _matvec (self , x ):
414+ return self .args [1 ] * self .args [0 ].matvec (x )
415+
416+ def _rmatvec (self , x ):
417+ return dpnp .conj (self .args [1 ]) * self .args [0 ].rmatvec (x )
418+
419+ def _matmat (self , X ):
420+ return self .args [1 ] * self .args [0 ].matmat (X )
421+
422+ def _rmatmat (self , X ):
423+ return dpnp .conj (self .args [1 ]) * self .args [0 ].rmatmat (X )
424+
425+ def _adjoint (self ):
426+ A , alpha = self .args
427+ return A .H * dpnp .conj (alpha )
428+
360429
361430class _PowerLinearOperator (LinearOperator ):
362431 def __init__ (self , A , p ):
363432 if A .shape [0 ] != A .shape [1 ]:
364433 raise ValueError ("matrix power requires a square operator" )
365434 if not _isintlike (p ) or p < 0 :
366- raise ValueError ("matrix power requires a non-negative integer exponent" )
435+ raise ValueError (
436+ "matrix power requires a non-negative integer exponent"
437+ )
367438 super ().__init__ (_get_dtype ([A ]), A .shape )
368439 self .args = (A , int (p ))
369440
@@ -373,24 +444,37 @@ def _power(self, f, x):
373444 res = f (res )
374445 return res
375446
376- def _matvec (self , x ): return self ._power (self .args [0 ].matvec , x )
377- def _rmatvec (self , x ): return self ._power (self .args [0 ].rmatvec , x )
378- def _matmat (self , X ): return self ._power (self .args [0 ].matmat , X )
379- def _rmatmat (self , X ): return self ._power (self .args [0 ].rmatmat , X )
380- def _adjoint (self ): A , p = self .args ; return A .H ** p
447+ def _matvec (self , x ):
448+ return self ._power (self .args [0 ].matvec , x )
449+
450+ def _rmatvec (self , x ):
451+ return self ._power (self .args [0 ].rmatvec , x )
452+
453+ def _matmat (self , X ):
454+ return self ._power (self .args [0 ].matmat , X )
455+
456+ def _rmatmat (self , X ):
457+ return self ._power (self .args [0 ].rmatmat , X )
458+
459+ def _adjoint (self ):
460+ A , p = self .args
461+ return A .H ** p
381462
382463
383464class MatrixLinearOperator (LinearOperator ):
384465 """Wrap a dense dpnp matrix (or sparse matrix) as a LinearOperator."""
385466
386467 def __init__ (self , A ):
387468 super ().__init__ (A .dtype , A .shape )
388- self .A = A
469+ self .A = A
389470 self .__adj = None
390- self .args = (A ,)
471+ self .args = (A ,)
472+
473+ def _matmat (self , X ):
474+ return self .A .dot (X )
391475
392- def _matmat (self , X ): return self . A . dot ( X )
393- def _rmatmat ( self , X ): return dpnp .conj (self .A .T ).dot (X )
476+ def _rmatmat (self , X ):
477+ return dpnp .conj (self .A .T ).dot (X )
394478
395479 def _adjoint (self ):
396480 if self .__adj is None :
@@ -400,10 +484,10 @@ def _adjoint(self):
400484
401485class _AdjointMatrixOperator (MatrixLinearOperator ):
402486 def __init__ (self , adjoint ):
403- self .A = dpnp .conj (adjoint .A .T )
487+ self .A = dpnp .conj (adjoint .A .T )
404488 self .__adjoint = adjoint
405- self .args = (adjoint ,)
406- self .shape = (adjoint .shape [1 ], adjoint .shape [0 ])
489+ self .args = (adjoint ,)
490+ self .shape = (adjoint .shape [1 ], adjoint .shape [0 ])
407491
408492 @property
409493 def dtype (self ):
@@ -419,12 +503,24 @@ class IdentityOperator(LinearOperator):
419503 def __init__ (self , shape , dtype = None ):
420504 super ().__init__ (dtype , shape )
421505
422- def _matvec (self , x ): return x
423- def _rmatvec (self , x ): return x
424- def _matmat (self , X ): return X
425- def _rmatmat (self , X ): return X
426- def _adjoint (self ): return self
427- def _transpose (self ): return self
506+ def _matvec (self , x ):
507+ return x
508+
509+ def _rmatvec (self , x ):
510+ return x
511+
512+ def _matmat (self , X ):
513+ return X
514+
515+ def _rmatmat (self , X ):
516+ return X
517+
518+ def _adjoint (self ):
519+ return self
520+
521+ def _transpose (self ):
522+ return self
523+
428524
429525def aslinearoperator (A ) -> LinearOperator :
430526 """Wrap A as a LinearOperator if it is not already one.
@@ -440,6 +536,7 @@ def aslinearoperator(A) -> LinearOperator:
440536
441537 try :
442538 from dpnp .scipy import sparse as _sp
539+
443540 if _sp .issparse (A ):
444541 return MatrixLinearOperator (A )
445542 except (ImportError , AttributeError ):
0 commit comments