Skip to content

Commit 18bd2c3

Browse files
committed
fix: 3 bugs in _iterative.py (asnumpy, GMRES V alloc, MINRES atol)
Bug 1 — GMRES crash: `_dpnp.asnumpy(h_dp)` does not exist as a module-level function in dpnp. Changed to the correct array-method form `h_dp.asnumpy()`. Bug 2 — GMRES performance: `_dpnp.stack(V_cols, axis=1)` was called on every inner Arnoldi iteration, reallocating a growing (n x j) device matrix at each step (O(j^2*n) memory traffic per restart). Replaced with a pre-allocated V matrix `(n, restart+1)` filled column-by-column; back-substitution and the solution update now index directly into V rather than stacking V_cols. Bug 3 — MINRES silent ignore of atol: `_get_atol(\"minres\", bnrm, atol=None, rtol=tol)` hard-coded `atol=None`, discarding the caller's `atol` argument entirely. Changed to `atol=atol` so the caller's absolute tolerance is respected.
1 parent 969b1e9 commit 18bd2c3

1 file changed

Lines changed: 38 additions & 29 deletions

File tree

dpnp/scipy/sparse/linalg/_iterative.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
* GMRES: Givens-rotation Hessenberg QR, allocation-free scalar CPU side;
5151
all matvec + inner-product work stays on device.
5252
* GMRES: happy breakdown via h_{j+1,j} == 0
53+
* GMRES: V basis pre-allocated as (n, restart+1); no per-iteration stack().
5354
* MINRES: native Paige-Saunders (1975) recurrence — no scipy host round-trip.
5455
QR step uses the exact two-rotation recurrence from SciPy minres.py:
5556
oldeps = epsln
@@ -264,7 +265,7 @@ def cg(
264265

265266
rhotol = float(_np.finfo(_np_dtype(dtype)).eps ** 2)
266267

267-
# FIX: use `x0 is not None` to detect a non-trivial initial guess instead
268+
# use `x0 is not None` to detect a non-trivial initial guess instead
268269
# of `_dpnp.any(x)` which returns a dpnp array and raises AmbiguousTruth.
269270
r = b - A_op.matvec(x) if x0 is not None else b.copy()
270271
z = M_op.matvec(r)
@@ -377,19 +378,23 @@ def gmres(
377378
info = maxiter
378379

379380
for _outer in range(maxiter):
380-
# FIX: use x0 is not None for the outer-loop residual too; after the
381-
# first restart x has been updated so always compute the residual.
382381
r = M_op.matvec(b - A_op.matvec(x))
383382
beta = float(_dpnp.linalg.norm(r))
384383
if beta == 0.0 or beta <= atol_eff:
385384
info = 0
386385
break
387386

388-
V_cols = [r / beta]
389-
H_np = _np.zeros((restart + 1, restart), dtype=H_dtype)
390-
cs_np = _np.zeros(restart, dtype=H_dtype)
391-
sn_np = _np.zeros(restart, dtype=H_dtype)
392-
g_np = _np.zeros(restart + 1, dtype=H_dtype)
387+
# FIX (Bug 2): Pre-allocate V as (n, restart+1) and fill
388+
# column-by-column. The previous code called
389+
# `_dpnp.stack(V_cols, axis=1)` on every inner iteration,
390+
# reallocating a growing device matrix at O(j^2*n) total cost.
391+
V = _dpnp.zeros((n, restart + 1), dtype=dtype)
392+
V[:, 0] = r / beta
393+
394+
H_np = _np.zeros((restart + 1, restart), dtype=H_dtype)
395+
cs_np = _np.zeros(restart, dtype=H_dtype)
396+
sn_np = _np.zeros(restart, dtype=H_dtype)
397+
g_np = _np.zeros(restart + 1, dtype=H_dtype)
393398
g_np[0] = beta
394399

395400
j_final = 0
@@ -398,25 +403,23 @@ def gmres(
398403
for j in range(restart):
399404
total_iters += 1
400405

401-
w = M_op.matvec(A_op.matvec(V_cols[j]))
402-
V_mat = _dpnp.stack(V_cols, axis=1)
406+
w = M_op.matvec(A_op.matvec(V[:, j]))
403407

404-
# FIX: dpnp arrays have no .conj() method on transpose results;
405-
# use the module-level _dpnp.conj() instead.
406-
h_dp = _dpnp.dot(_dpnp.conj(V_mat.T), w)
407-
h_np = _dpnp.asnumpy(h_dp) # FIX: asnumpy is a module-level fn, not a method
408-
w = w - _dpnp.dot(V_mat, _dpnp.asarray(h_np, dtype=dtype))
408+
# Modified Gram-Schmidt orthogonalisation against V[:, :j+1].
409+
# h_dp is a (j+1,) device vector; pull to host with .asnumpy().
410+
# FIX (Bug 1): use the array method `.asnumpy()` — there is no
411+
# module-level `_dpnp.asnumpy()` function in dpnp.
412+
h_dp = _dpnp.dot(_dpnp.conj(V[:, :j + 1].T), w)
413+
h_np = h_dp.asnumpy() # (j+1,) numpy array
414+
w = w - _dpnp.dot(V[:, :j + 1],
415+
_dpnp.asarray(h_np, dtype=dtype))
409416

410-
# FIX: float(_dpnp.linalg.norm(...)) — norm returns a 0-d dpnp
411-
# array; float() extracts the scalar correctly without .asnumpy().
412-
h_j1 = float(_dpnp.linalg.norm(w))
417+
h_j1 = float(_dpnp.linalg.norm(w))
413418

414-
# FIX: always assign h_np directly (it is already the right dtype
415-
# for both real and complex cases); avoid the .real strip which
416-
# would drop the imaginary component for complex Hessenberg entries.
417419
H_np[:j + 1, j] = h_np
418420
H_np[j + 1, j] = h_j1
419421

422+
# Apply previous Givens rotations to column j of H
420423
for i in range(j):
421424
tmp = cs_np[i] * H_np[i, j] + sn_np[i] * H_np[i + 1, j]
422425
H_np[i + 1, j] = -_np.conj(sn_np[i]) * H_np[i, j] + cs_np[i] * H_np[i + 1, j]
@@ -452,9 +455,11 @@ def gmres(
452455
happy = True
453456
break
454457

455-
V_cols.append(w / h_j1)
458+
if j + 1 < restart:
459+
V[:, j + 1] = w / h_j1
456460
j_final = j
457461

462+
# Back-substitution: solve upper-triangular H[:k,:k] y = g[:k]
458463
k = j_final + 1
459464
y_np = _np.zeros(k, dtype=H_dtype)
460465
for i in range(k - 1, -1, -1):
@@ -466,8 +471,8 @@ def gmres(
466471
else:
467472
y_np[i] /= H_np[i, i]
468473

469-
V_k = _dpnp.stack(V_cols[:k], axis=1)
470-
x = x + _dpnp.dot(V_k, _dpnp.asarray(y_np, dtype=dtype))
474+
# Solution update: x += V[:, :k] @ y
475+
x = x + _dpnp.dot(V[:, :k], _dpnp.asarray(y_np, dtype=dtype))
471476

472477
res_norm = float(_dpnp.linalg.norm(M_op.matvec(b - A_op.matvec(x))))
473478

@@ -501,6 +506,7 @@ def minres(
501506
M=None,
502507
callback: Optional[Callable] = None,
503508
check: bool = False,
509+
atol=None,
504510
) -> Tuple[_dpnp.ndarray, int]:
505511
"""MINRES for symmetric (possibly indefinite) A — pure dpnp/oneMKL.
506512
@@ -536,6 +542,7 @@ def minres(
536542
M : LinearOperator, optional — SPD preconditioner
537543
callback: callable, optional — callback(xk) after each step
538544
check : bool — verify A symmetry before iterating
545+
atol : float, optional — absolute tolerance
539546
540547
Returns
541548
-------
@@ -554,7 +561,9 @@ def minres(
554561
if bnrm == 0.0:
555562
return _dpnp.zeros_like(b), 0
556563

557-
atol_eff = _get_atol("minres", bnrm, atol=None, rtol=tol)
564+
# FIX (Bug 3): pass the caller's `atol` argument instead of hard-coded
565+
# `atol=None`, so the absolute tolerance is actually respected.
566+
atol_eff = _get_atol("minres", bnrm, atol=atol, rtol=tol)
558567

559568
# ------------------------------------------------------------------
560569
# Initialise Lanczos: compute beta1 = ||M^{-1/2} r0||_M
@@ -635,7 +644,7 @@ def minres(
635644
# QR step: correct Paige-Saunders (1975) two-rotation recurrence.
636645
#
637646
# Apply the PREVIOUS Givens rotation Q_{k-1} to the current
638-
# tridiagonal column. The column is [dbar, (alpha-shift), beta].
647+
# tridiagonal column. The column is [dbar, alpha, beta].
639648
# (alpha already incorporates the shift via the Lanczos matvec above
640649
# so the column below uses plain `alpha`.)
641650
#
@@ -654,7 +663,7 @@ def minres(
654663
delta = cs * dbar + sn * alpha # apply previous rotation — diagonal
655664
gbar_k = sn * dbar - cs * alpha # remaining entry -> new rotation
656665
epsln = sn * beta # sub-sub-diagonal for next step
657-
dbar = -cs * beta # carry forward for next step
666+
dbar = -cs * beta # carry forward for next step
658667

659668
gamma = _np.hypot(gbar_k, beta)
660669
if gamma == 0.0:
@@ -681,14 +690,14 @@ def minres(
681690
if callback is not None:
682691
callback(x)
683692

684-
# FIX: convergence check MUST come before stagnation check so that
693+
# Convergence check MUST come before stagnation check so that
685694
# a boundary iteration that satisfies both conditions is correctly
686695
# reported as converged (info=0) rather than stagnated (info=2).
687696
if rnorm <= atol_eff:
688697
info = 0
689698
break
690699

691-
# FIX: use stag_eps (10*eps) instead of bare eps to prevent
700+
# Use stag_eps (10*eps) instead of bare eps to prevent
692701
# float32 runs with tol near machine-epsilon from false-positive
693702
# stagnation before the residual norm has had a chance to converge.
694703
if phi * denom < stag_eps:

0 commit comments

Comments
 (0)