Skip to content

Commit 125dab5

Browse files
committed
Fix dtype.char AttributeError on dpnp dtype objects in CG/GMRES/MINRES
1 parent 4292518 commit 125dab5

1 file changed

Lines changed: 50 additions & 49 deletions

File tree

dpnp/scipy/sparse/linalg/_iterative.py

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,20 @@
9090
# Helpers
9191
# ---------------------------------------------------------------------------
9292

93+
def _np_dtype(dp_dtype) -> _np.dtype:
94+
"""Convert a dpnp dtype (or any dtype-like) to a concrete numpy dtype.
95+
96+
dpnp dtype objects (e.g. dpnp.float64) are *type objects*, not
97+
numpy dtype instances, so they have no ``.char`` attribute.
98+
Wrapping them with ``_np.dtype(...)`` normalises everything to a
99+
proper numpy dtype regardless of whether the input is a dpnp type,
100+
a numpy type, a string, or already a numpy dtype.
101+
"""
102+
return _np.dtype(dp_dtype)
103+
104+
93105
def _check_dtype(dtype, name: str) -> None:
94-
if dtype.char not in _SUPPORTED_DTYPES:
106+
if _np_dtype(dtype).char not in _SUPPORTED_DTYPES:
95107
raise TypeError(
96108
f"{name} has unsupported dtype {dtype}; "
97109
"only float32, float64, complex64, complex128 are accepted."
@@ -149,8 +161,8 @@ def _make_system(A, M, x0, b):
149161
dtype = _dpnp.complex128
150162
else:
151163
dtype = _dpnp.float64
152-
if A_op.dtype is not None and A_op.dtype.char in "fF":
153-
dtype = _dpnp.complex64 if A_op.dtype.char == "F" else _dpnp.float32
164+
if A_op.dtype is not None and _np_dtype(A_op.dtype).char in "fF":
165+
dtype = _dpnp.complex64 if _np_dtype(A_op.dtype).char == "F" else _dpnp.float32
154166

155167
b = b.astype(dtype, copy=False)
156168
_check_dtype(b.dtype, "b")
@@ -240,7 +252,8 @@ def cg(
240252
maxiter = n * 10
241253

242254
# Machine-epsilon breakdown tolerance (mirrors SciPy bicg rhotol)
243-
rhotol = float(_np.finfo(_np.dtype(dtype.char)).eps ** 2)
255+
# Use _np_dtype() to safely convert dpnp dtype to numpy dtype.
256+
rhotol = float(_np.finfo(_np_dtype(dtype)).eps ** 2)
244257

245258
r = b - A_op.matvec(x) if _dpnp.any(x) else b.copy()
246259
z = M_op.matvec(r)
@@ -350,7 +363,8 @@ def gmres(
350363

351364
is_cpx = _dpnp.issubdtype(dtype, _dpnp.complexfloating)
352365
H_dtype = _np.complex128 if is_cpx else _np.float64
353-
rhotol = float(_np.finfo(H_dtype).eps ** 2)
366+
# Use _np_dtype() so this works whether dtype is a dpnp type or numpy dtype.
367+
rhotol = float(_np.finfo(_np_dtype(dtype)).eps ** 2)
354368

355369
total_iters = 0
356370
info = maxiter
@@ -520,7 +534,8 @@ def minres(
520534
A_op, M_op, x, b, dtype = _make_system(A, M, x0, b)
521535
n = b.shape[0]
522536
is_cpx = _dpnp.issubdtype(dtype, _dpnp.complexfloating)
523-
eps = float(_np.finfo(_np.dtype(dtype.char)).eps)
537+
# Use _np_dtype() to convert dpnp dtype to numpy dtype before finfo.
538+
eps = float(_np.finfo(_np_dtype(dtype)).eps)
524539

525540
if maxiter is None:
526541
maxiter = 5 * n
@@ -570,6 +585,10 @@ def minres(
570585
w2 = _dpnp.zeros_like(x)
571586
r2 = _dpnp.array(v, copy=True)
572587

588+
# Givens rotation scalars from the previous step
589+
cs_n = 0.0
590+
sn_n = 0.0
591+
573592
info = 1
574593
for itr in range(1, maxiter + 1):
575594
# Lanczos step
@@ -596,53 +615,35 @@ def minres(
596615
info = 2
597616
break
598617

599-
# QR update — Givens rotation plane
600-
oldeps = epln
601-
epln = dltan * (-dbar) if itr > 1 else 0.0
602-
dltan = gbar
603-
delta = dltan * _np.cos(0.0) # cos(theta)=dltan/sqrt(dltan^2+beta^2)
604-
605-
# ---- Symmetric QR on the Lanczos tridiagonal ---
606-
# Simplified scalar recurrence (Paige-Saunders §6.4)
607-
eps2 = alpha - shift
608-
dbar = _np.hypot(dbar, beta) # hypothetical: used below in full form
609-
610-
# Givens rotation to zero out the sub-diagonal
611-
eps2sq = float(eps2)
612-
betan = float(beta)
613-
gabar = float(gbar)
614-
rhs1 = float(phibar)
615-
616-
# Full Paige-Saunders Givens step
617-
cs_old = 0.0 if itr == 1 else cs_n
618-
sn_old = 0.0 if itr == 1 else sn_n
619-
620-
# Recurrence: eps, delta, gbar from previous Givens
621-
eps_n = sn_old * betan
622-
dbar = -cs_old * betan
623-
delta_n = _np.hypot(gbar, betan)
618+
# Save previous Givens rotation scalars before overwriting
619+
cs_old = cs_n
620+
sn_old = sn_n
621+
622+
# Givens rotation to annihilate the sub-diagonal of the tridiagonal
623+
# Current diagonal entry in the shifted system
624+
eps_n = sn_old * beta
625+
dbar = -cs_old * beta
626+
delta_n = _np.hypot(gbar, beta)
624627
if delta_n == 0.0:
625628
delta_n = eps
626-
cs_n = gbar / delta_n
627-
sn_n = betan / delta_n
628-
phi = cs_n * phibar
629-
phibar = sn_n * phibar
630-
631-
denom = 1.0 / delta_n
632-
w2old = w2.copy()
633-
w2 = (v - eps_n * w - delta_n * w2) * denom # NOT right yet
634-
# Correct: w update is w_{k} = (v_k - delta*w_{k-1} - eps*w_{k-2}) / gamma
635-
# Redo with right symbols:
636-
w_new = (v - oldeps * w - (delta_n * denom) * w2old)
637-
w = w2old
638-
w2 = w_new
629+
cs_n = gbar / delta_n
630+
sn_n = beta / delta_n
631+
phi = cs_n * phibar
632+
phibar = sn_n * phibar
639633

640-
x = x + phi * w2
634+
# Solution update using the Paige-Saunders w-vectors
635+
denom = 1.0 / delta_n
636+
w_new = (v - eps_n * w - dbar * w2) * denom
637+
x = x + phi * w_new
638+
w = w2.copy()
639+
w2 = w_new
641640

642-
# Residual norm estimate
643-
rnorm = abs(phibar)
641+
# Update gbar for next iteration
642+
gbar = sn_n * (alpha - shift) - cs_n * dbar
643+
# rnorm estimate: |phibar|
644+
rnorm = abs(phibar)
644645

645-
dnorm = _np.hypot(dnorm, phi / delta_n) if delta_n != 0.0 else dnorm
646+
dnorm = _np.hypot(dnorm, phi * denom) if delta_n != 0.0 else dnorm
646647

647648
if callback is not None:
648649
callback(x)
@@ -652,7 +653,7 @@ def minres(
652653
break
653654

654655
# Stagnation guard
655-
if phi / delta_n < eps:
656+
if phi * denom < eps:
656657
info = 2
657658
break
658659
else:

0 commit comments

Comments
 (0)