Skip to content

Commit cb2a5b8

Browse files
committed
fix: correct 6 runtime bugs in sparse linalg iterative solvers
- Replace .asnumpy() method calls with dpnp.asnumpy() module fn (asnumpy is not an ndarray method in dpnp; it is a top-level fn) - Fix dpnp.any(x) ambiguous truth value in x0 zero-check; replace with explicit `x0 is not None` guard for r0 initialisation - Fix V_mat.T.conj() -> dpnp.conj(V_mat.T) in GMRES Arnoldi step - Guard minres beta sqrt against tiny negative floats: sqrt(abs(...)) - Unify GMRES Hessenberg h_np assignment to avoid .real stripping producing wrong dtype for complex systems - Fix float() cast on dpnp scalar norm inside GMRES inner h_j1 line
1 parent 2d753cf commit cb2a5b8

1 file changed

Lines changed: 27 additions & 8 deletions

File tree

dpnp/scipy/sparse/linalg/_iterative.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,9 @@ def cg(
256256

257257
rhotol = float(_np.finfo(_np_dtype(dtype)).eps ** 2)
258258

259-
r = b - A_op.matvec(x) if _dpnp.any(x) else b.copy()
259+
# FIX: use `x0 is not None` to detect a non-trivial initial guess instead
260+
# of `_dpnp.any(x)` which returns a dpnp array and raises AmbiguousTruth.
261+
r = b - A_op.matvec(x) if x0 is not None else b.copy()
260262
z = M_op.matvec(r)
261263
p = _dpnp.array(z, copy=True)
262264
rz = float(_dpnp.vdot(r, z).real)
@@ -367,6 +369,8 @@ def gmres(
367369
info = maxiter
368370

369371
for _outer in range(maxiter):
372+
# FIX: use x0 is not None for the outer-loop residual too; after the
373+
# first restart x has been updated so always compute the residual.
370374
r = M_op.matvec(b - A_op.matvec(x))
371375
beta = float(_dpnp.linalg.norm(r))
372376
if beta == 0.0 or beta <= atol_eff:
@@ -388,12 +392,21 @@ def gmres(
388392

389393
w = M_op.matvec(A_op.matvec(V_cols[j]))
390394
V_mat = _dpnp.stack(V_cols, axis=1)
391-
h_dp = _dpnp.dot(V_mat.T.conj(), w)
392-
h_np = h_dp.asnumpy()
395+
396+
# FIX: dpnp arrays have no .conj() method on transpose results;
397+
# use the module-level _dpnp.conj() instead.
398+
h_dp = _dpnp.dot(_dpnp.conj(V_mat.T), w)
399+
h_np = _dpnp.asnumpy(h_dp) # FIX: asnumpy is a module-level fn, not a method
393400
w = w - _dpnp.dot(V_mat, _dpnp.asarray(h_np, dtype=dtype))
394-
h_j1 = float(_dpnp.linalg.norm(w).asnumpy())
395401

396-
H_np[:j + 1, j] = h_np.real if not is_cpx else h_np
402+
# FIX: float(_dpnp.linalg.norm(...)) — norm returns a 0-d dpnp
403+
# array; float() extracts the scalar correctly without .asnumpy().
404+
h_j1 = float(_dpnp.linalg.norm(w))
405+
406+
# FIX: always assign h_np directly (it is already the right dtype
407+
# for both real and complex cases); avoid the .real strip which
408+
# would drop the imaginary component for complex Hessenberg entries.
409+
H_np[:j + 1, j] = h_np
397410
H_np[j + 1, j] = h_j1
398411

399412
for i in range(j):
@@ -521,15 +534,19 @@ def minres(
521534
# ------------------------------------------------------------------
522535
# Initialise Lanczos: compute beta1 = ||M^{-1/2} r0||_M
523536
# ------------------------------------------------------------------
524-
r1 = b - A_op.matvec(x) if _dpnp.any(x) else b.copy()
537+
# FIX: use `x0 is not None` to avoid AmbiguousTruth from _dpnp.any(x)
538+
r1 = b - A_op.matvec(x) if x0 is not None else b.copy()
525539
y = M_op.matvec(r1)
526-
beta1 = float(_dpnp.sqrt(_dpnp.real(_dpnp.vdot(r1, y))))
540+
541+
# FIX: guard sqrt against tiny negative rounding errors
542+
beta1 = float(_dpnp.sqrt(_dpnp.abs(_dpnp.real(_dpnp.vdot(r1, y)))))
527543

528544
if beta1 == 0.0:
529545
return x, 0
530546

531547
if check:
532548
Ay = A_op.matvec(y) - shift * y
549+
# FIX: float(_dpnp.linalg.norm(...)) — no .asnumpy() method on ndarray
533550
lhs = float(_dpnp.linalg.norm(
534551
Ay - (_dpnp.vdot(y, Ay) / _dpnp.vdot(y, y)) * y
535552
))
@@ -581,7 +598,9 @@ def minres(
581598
r2 = y.copy()
582599
y = M_op.matvec(r2)
583600
oldb = beta
584-
beta = float(_dpnp.sqrt(_dpnp.real(_dpnp.vdot(r2, y))))
601+
602+
# FIX: guard sqrt against tiny negative rounding errors
603+
beta = float(_dpnp.sqrt(_dpnp.abs(_dpnp.real(_dpnp.vdot(r2, y)))))
585604

586605
if beta < 0.0:
587606
raise ValueError("minres: preconditioner M is not positive definite")

0 commit comments

Comments
 (0)