Skip to content

Commit ac3bed5

Browse files
committed
black formatting
1 parent 4442530 commit ac3bed5

File tree

2 files changed

+259
-112
lines changed

2 files changed

+259
-112
lines changed

dpnp/scipy/sparse/linalg/_interface.py

Lines changed: 160 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@
4747

4848
import dpnp
4949

50-
5150
# ---------------------------------------------------------------------------
5251
# helpers
5352
# ---------------------------------------------------------------------------
5453

54+
5555
def _isshape(shape):
5656
"""Return True if shape is a length-2 tuple of non-negative integers."""
5757
if not isinstance(shape, tuple) or len(shape) != 2:
@@ -77,6 +77,7 @@ def _get_dtype(operators, dtypes=None):
7777
dtypes.append(obj.dtype)
7878
return dpnp.result_type(*dtypes) if dtypes else None
7979

80+
8081
class LinearOperator:
8182
"""Drop-in replacement for cupyx/scipy LinearOperator backed by dpnp arrays.
8283
@@ -91,8 +92,10 @@ def __new__(cls, *args, **kwargs):
9192
return super().__new__(_CustomLinearOperator)
9293
else:
9394
obj = super().__new__(cls)
94-
if (type(obj)._matvec is LinearOperator._matvec
95-
and type(obj)._matmat is LinearOperator._matmat):
95+
if (
96+
type(obj)._matvec is LinearOperator._matvec
97+
and type(obj)._matmat is LinearOperator._matmat
98+
):
9699
warnings.warn(
97100
"LinearOperator subclass should implement at least one of "
98101
"_matvec and _matmat.",
@@ -125,13 +128,13 @@ def _matvec(self, x):
125128
return self.matmat(x.reshape(-1, 1))
126129

127130
def _matmat(self, X):
128-
return dpnp.hstack(
129-
[self.matvec(col.reshape(-1, 1)) for col in X.T]
130-
)
131+
return dpnp.hstack([self.matvec(col.reshape(-1, 1)) for col in X.T])
131132

132133
def _rmatvec(self, x):
133134
if type(self)._adjoint is LinearOperator._adjoint:
134-
raise NotImplementedError("rmatvec is not defined for this LinearOperator")
135+
raise NotImplementedError(
136+
"rmatvec is not defined for this LinearOperator"
137+
)
135138
return self.H.matvec(x)
136139

137140
def _rmatmat(self, X):
@@ -163,14 +166,18 @@ def matmat(self, X):
163166
if X.ndim != 2:
164167
raise ValueError(f"expected 2-D array, got {X.ndim}-D")
165168
if X.shape[0] != self.shape[1]:
166-
raise ValueError(f"dimension mismatch: {self.shape!r} vs {X.shape!r}")
169+
raise ValueError(
170+
f"dimension mismatch: {self.shape!r} vs {X.shape!r}"
171+
)
167172
return self._matmat(X)
168173

169174
def rmatmat(self, X):
170175
if X.ndim != 2:
171176
raise ValueError(f"expected 2-D array, got {X.ndim}-D")
172177
if X.shape[0] != self.shape[0]:
173-
raise ValueError(f"dimension mismatch: {self.shape!r} vs {X.shape!r}")
178+
raise ValueError(
179+
f"dimension mismatch: {self.shape!r} vs {X.shape!r}"
180+
)
174181
return self._rmatmat(X)
175182

176183
def dot(self, x):
@@ -184,7 +191,9 @@ def dot(self, x):
184191
return self.matvec(x)
185192
elif x.ndim == 2:
186193
return self.matmat(x)
187-
raise ValueError(f"expected 1-D or 2-D array or LinearOperator, got {x!r}")
194+
raise ValueError(
195+
f"expected 1-D or 2-D array or LinearOperator, got {x!r}"
196+
)
188197

189198
def __call__(self, x):
190199
return self * x
@@ -194,12 +203,16 @@ def __mul__(self, x):
194203

195204
def __matmul__(self, x):
196205
if dpnp.isscalar(x):
197-
raise ValueError("Scalar operands not allowed with '@'; use '*' instead")
206+
raise ValueError(
207+
"Scalar operands not allowed with '@'; use '*' instead"
208+
)
198209
return self.__mul__(x)
199210

200211
def __rmatmul__(self, x):
201212
if dpnp.isscalar(x):
202-
raise ValueError("Scalar operands not allowed with '@'; use '*' instead")
213+
raise ValueError(
214+
"Scalar operands not allowed with '@'; use '*' instead"
215+
)
203216
return self.__rmul__(x)
204217

205218
def __rmul__(self, x):
@@ -245,28 +258,33 @@ def transpose(self):
245258
T = property(transpose)
246259

247260
def __repr__(self):
248-
dt = "unspecified dtype" if self.dtype is None else f"dtype={self.dtype}"
261+
dt = (
262+
"unspecified dtype" if self.dtype is None else f"dtype={self.dtype}"
263+
)
249264
return f"<{self.shape[0]}x{self.shape[1]} {self.__class__.__name__} with {dt}>"
250265

251266

252267
# ---------------------------------------------------------------------------
253268
# Concrete operator classes
254269
# ---------------------------------------------------------------------------
255270

271+
256272
class _CustomLinearOperator(LinearOperator):
257273
"""Created when the user calls LinearOperator(shape, matvec=...) directly."""
258274

259-
def __init__(self, shape, matvec, rmatvec=None, matmat=None,
260-
dtype=None, rmatmat=None):
275+
def __init__(
276+
self, shape, matvec, rmatvec=None, matmat=None, dtype=None, rmatmat=None
277+
):
261278
super().__init__(dtype, shape)
262279
self.args = ()
263-
self.__matvec_impl = matvec
280+
self.__matvec_impl = matvec
264281
self.__rmatvec_impl = rmatvec
265282
self.__rmatmat_impl = rmatmat
266-
self.__matmat_impl = matmat
283+
self.__matmat_impl = matmat
267284
self._init_dtype()
268285

269-
def _matvec(self, x): return self.__matvec_impl(x)
286+
def _matvec(self, x):
287+
return self.__matvec_impl(x)
270288

271289
def _matmat(self, X):
272290
if self.__matmat_impl is not None:
@@ -275,7 +293,9 @@ def _matmat(self, X):
275293

276294
def _rmatvec(self, x):
277295
if self.__rmatvec_impl is None:
278-
raise NotImplementedError("rmatvec is not defined for this operator")
296+
raise NotImplementedError(
297+
"rmatvec is not defined for this operator"
298+
)
279299
return self.__rmatvec_impl(x)
280300

281301
def _rmatmat(self, X):
@@ -300,11 +320,20 @@ def __init__(self, A):
300320
self.A = A
301321
self.args = (A,)
302322

303-
def _matvec(self, x): return self.A._rmatvec(x)
304-
def _rmatvec(self, x): return self.A._matvec(x)
305-
def _matmat(self, X): return self.A._rmatmat(X)
306-
def _rmatmat(self, X): return self.A._matmat(X)
307-
def _adjoint(self): return self.A
323+
def _matvec(self, x):
324+
return self.A._rmatvec(x)
325+
326+
def _rmatvec(self, x):
327+
return self.A._matvec(x)
328+
329+
def _matmat(self, X):
330+
return self.A._rmatmat(X)
331+
332+
def _rmatmat(self, X):
333+
return self.A._matmat(X)
334+
335+
def _adjoint(self):
336+
return self.A
308337

309338

310339
class _TransposedLinearOperator(LinearOperator):
@@ -313,11 +342,20 @@ def __init__(self, A):
313342
self.A = A
314343
self.args = (A,)
315344

316-
def _matvec(self, x): return dpnp.conj(self.A._rmatvec(dpnp.conj(x)))
317-
def _rmatvec(self, x): return dpnp.conj(self.A._matvec(dpnp.conj(x)))
318-
def _matmat(self, X): return dpnp.conj(self.A._rmatmat(dpnp.conj(X)))
319-
def _rmatmat(self, X): return dpnp.conj(self.A._matmat(dpnp.conj(X)))
320-
def _transpose(self): return self.A
345+
def _matvec(self, x):
346+
return dpnp.conj(self.A._rmatvec(dpnp.conj(x)))
347+
348+
def _rmatvec(self, x):
349+
return dpnp.conj(self.A._matvec(dpnp.conj(x)))
350+
351+
def _matmat(self, X):
352+
return dpnp.conj(self.A._rmatmat(dpnp.conj(X)))
353+
354+
def _rmatmat(self, X):
355+
return dpnp.conj(self.A._matmat(dpnp.conj(X)))
356+
357+
def _transpose(self):
358+
return self.A
321359

322360

323361
class _SumLinearOperator(LinearOperator):
@@ -327,11 +365,20 @@ def __init__(self, A, B):
327365
super().__init__(_get_dtype([A, B]), A.shape)
328366
self.args = (A, B)
329367

330-
def _matvec(self, x): return self.args[0].matvec(x) + self.args[1].matvec(x)
331-
def _rmatvec(self, x): return self.args[0].rmatvec(x) + self.args[1].rmatvec(x)
332-
def _matmat(self, X): return self.args[0].matmat(X) + self.args[1].matmat(X)
333-
def _rmatmat(self, X): return self.args[0].rmatmat(X) + self.args[1].rmatmat(X)
334-
def _adjoint(self): return self.args[0].H + self.args[1].H
368+
def _matvec(self, x):
369+
return self.args[0].matvec(x) + self.args[1].matvec(x)
370+
371+
def _rmatvec(self, x):
372+
return self.args[0].rmatvec(x) + self.args[1].rmatvec(x)
373+
374+
def _matmat(self, X):
375+
return self.args[0].matmat(X) + self.args[1].matmat(X)
376+
377+
def _rmatmat(self, X):
378+
return self.args[0].rmatmat(X) + self.args[1].rmatmat(X)
379+
380+
def _adjoint(self):
381+
return self.args[0].H + self.args[1].H
335382

336383

337384
class _ProductLinearOperator(LinearOperator):
@@ -341,29 +388,53 @@ def __init__(self, A, B):
341388
super().__init__(_get_dtype([A, B]), (A.shape[0], B.shape[1]))
342389
self.args = (A, B)
343390

344-
def _matvec(self, x): return self.args[0].matvec(self.args[1].matvec(x))
345-
def _rmatvec(self, x): return self.args[1].rmatvec(self.args[0].rmatvec(x))
346-
def _matmat(self, X): return self.args[0].matmat(self.args[1].matmat(X))
347-
def _rmatmat(self, X): return self.args[1].rmatmat(self.args[0].rmatmat(X))
348-
def _adjoint(self): A, B = self.args; return B.H * A.H
391+
def _matvec(self, x):
392+
return self.args[0].matvec(self.args[1].matvec(x))
393+
394+
def _rmatvec(self, x):
395+
return self.args[1].rmatvec(self.args[0].rmatvec(x))
396+
397+
def _matmat(self, X):
398+
return self.args[0].matmat(self.args[1].matmat(X))
399+
400+
def _rmatmat(self, X):
401+
return self.args[1].rmatmat(self.args[0].rmatmat(X))
402+
403+
def _adjoint(self):
404+
A, B = self.args
405+
return B.H * A.H
406+
349407

350408
class _ScaledLinearOperator(LinearOperator):
351409
def __init__(self, A, alpha):
352410
super().__init__(_get_dtype([A], [type(alpha)]), A.shape)
353411
self.args = (A, alpha)
354412

355-
def _matvec(self, x): return self.args[1] * self.args[0].matvec(x)
356-
def _rmatvec(self, x): return dpnp.conj(self.args[1]) * self.args[0].rmatvec(x)
357-
def _matmat(self, X): return self.args[1] * self.args[0].matmat(X)
358-
def _rmatmat(self, X): return dpnp.conj(self.args[1]) * self.args[0].rmatmat(X)
359-
def _adjoint(self): A, alpha = self.args; return A.H * dpnp.conj(alpha)
413+
def _matvec(self, x):
414+
return self.args[1] * self.args[0].matvec(x)
415+
416+
def _rmatvec(self, x):
417+
return dpnp.conj(self.args[1]) * self.args[0].rmatvec(x)
418+
419+
def _matmat(self, X):
420+
return self.args[1] * self.args[0].matmat(X)
421+
422+
def _rmatmat(self, X):
423+
return dpnp.conj(self.args[1]) * self.args[0].rmatmat(X)
424+
425+
def _adjoint(self):
426+
A, alpha = self.args
427+
return A.H * dpnp.conj(alpha)
428+
360429

361430
class _PowerLinearOperator(LinearOperator):
362431
def __init__(self, A, p):
363432
if A.shape[0] != A.shape[1]:
364433
raise ValueError("matrix power requires a square operator")
365434
if not _isintlike(p) or p < 0:
366-
raise ValueError("matrix power requires a non-negative integer exponent")
435+
raise ValueError(
436+
"matrix power requires a non-negative integer exponent"
437+
)
367438
super().__init__(_get_dtype([A]), A.shape)
368439
self.args = (A, int(p))
369440

@@ -373,24 +444,37 @@ def _power(self, f, x):
373444
res = f(res)
374445
return res
375446

376-
def _matvec(self, x): return self._power(self.args[0].matvec, x)
377-
def _rmatvec(self, x): return self._power(self.args[0].rmatvec, x)
378-
def _matmat(self, X): return self._power(self.args[0].matmat, X)
379-
def _rmatmat(self, X): return self._power(self.args[0].rmatmat, X)
380-
def _adjoint(self): A, p = self.args; return A.H ** p
447+
def _matvec(self, x):
448+
return self._power(self.args[0].matvec, x)
449+
450+
def _rmatvec(self, x):
451+
return self._power(self.args[0].rmatvec, x)
452+
453+
def _matmat(self, X):
454+
return self._power(self.args[0].matmat, X)
455+
456+
def _rmatmat(self, X):
457+
return self._power(self.args[0].rmatmat, X)
458+
459+
def _adjoint(self):
460+
A, p = self.args
461+
return A.H**p
381462

382463

383464
class MatrixLinearOperator(LinearOperator):
384465
"""Wrap a dense dpnp matrix (or sparse matrix) as a LinearOperator."""
385466

386467
def __init__(self, A):
387468
super().__init__(A.dtype, A.shape)
388-
self.A = A
469+
self.A = A
389470
self.__adj = None
390-
self.args = (A,)
471+
self.args = (A,)
472+
473+
def _matmat(self, X):
474+
return self.A.dot(X)
391475

392-
def _matmat(self, X): return self.A.dot(X)
393-
def _rmatmat(self, X): return dpnp.conj(self.A.T).dot(X)
476+
def _rmatmat(self, X):
477+
return dpnp.conj(self.A.T).dot(X)
394478

395479
def _adjoint(self):
396480
if self.__adj is None:
@@ -400,10 +484,10 @@ def _adjoint(self):
400484

401485
class _AdjointMatrixOperator(MatrixLinearOperator):
402486
def __init__(self, adjoint):
403-
self.A = dpnp.conj(adjoint.A.T)
487+
self.A = dpnp.conj(adjoint.A.T)
404488
self.__adjoint = adjoint
405-
self.args = (adjoint,)
406-
self.shape = (adjoint.shape[1], adjoint.shape[0])
489+
self.args = (adjoint,)
490+
self.shape = (adjoint.shape[1], adjoint.shape[0])
407491

408492
@property
409493
def dtype(self):
@@ -419,12 +503,24 @@ class IdentityOperator(LinearOperator):
419503
def __init__(self, shape, dtype=None):
420504
super().__init__(dtype, shape)
421505

422-
def _matvec(self, x): return x
423-
def _rmatvec(self, x): return x
424-
def _matmat(self, X): return X
425-
def _rmatmat(self, X): return X
426-
def _adjoint(self): return self
427-
def _transpose(self): return self
506+
def _matvec(self, x):
507+
return x
508+
509+
def _rmatvec(self, x):
510+
return x
511+
512+
def _matmat(self, X):
513+
return X
514+
515+
def _rmatmat(self, X):
516+
return X
517+
518+
def _adjoint(self):
519+
return self
520+
521+
def _transpose(self):
522+
return self
523+
428524

429525
def aslinearoperator(A) -> LinearOperator:
430526
"""Wrap A as a LinearOperator if it is not already one.
@@ -440,6 +536,7 @@ def aslinearoperator(A) -> LinearOperator:
440536

441537
try:
442538
from dpnp.scipy import sparse as _sp
539+
443540
if _sp.issparse(A):
444541
return MatrixLinearOperator(A)
445542
except (ImportError, AttributeError):

0 commit comments

Comments
 (0)