9090# Helpers
9191# ---------------------------------------------------------------------------
9292
93+ def _np_dtype (dp_dtype ) -> _np .dtype :
94+ """Convert a dpnp dtype (or any dtype-like) to a concrete numpy dtype.
95+
96+ dpnp dtype objects (e.g. dpnp.float64) are *type objects*, not
97+ numpy dtype instances, so they have no ``.char`` attribute.
98+ Wrapping them with ``_np.dtype(...)`` normalises everything to a
99+ proper numpy dtype regardless of whether the input is a dpnp type,
100+ a numpy type, a string, or already a numpy dtype.
101+ """
102+ return _np .dtype (dp_dtype )
103+
104+
93105def _check_dtype (dtype , name : str ) -> None :
94- if dtype .char not in _SUPPORTED_DTYPES :
106+ if _np_dtype ( dtype ) .char not in _SUPPORTED_DTYPES :
95107 raise TypeError (
96108 f"{ name } has unsupported dtype { dtype } ; "
97109 "only float32, float64, complex64, complex128 are accepted."
@@ -149,8 +161,8 @@ def _make_system(A, M, x0, b):
149161 dtype = _dpnp .complex128
150162 else :
151163 dtype = _dpnp .float64
152- if A_op .dtype is not None and A_op .dtype .char in "fF" :
153- dtype = _dpnp .complex64 if A_op .dtype .char == "F" else _dpnp .float32
164+ if A_op .dtype is not None and _np_dtype ( A_op .dtype ) .char in "fF" :
165+ dtype = _dpnp .complex64 if _np_dtype ( A_op .dtype ) .char == "F" else _dpnp .float32
154166
155167 b = b .astype (dtype , copy = False )
156168 _check_dtype (b .dtype , "b" )
@@ -240,7 +252,8 @@ def cg(
240252 maxiter = n * 10
241253
242254 # Machine-epsilon breakdown tolerance (mirrors SciPy bicg rhotol)
243- rhotol = float (_np .finfo (_np .dtype (dtype .char )).eps ** 2 )
255+ # Use _np_dtype() to safely convert dpnp dtype to numpy dtype.
256+ rhotol = float (_np .finfo (_np_dtype (dtype )).eps ** 2 )
244257
245258 r = b - A_op .matvec (x ) if _dpnp .any (x ) else b .copy ()
246259 z = M_op .matvec (r )
@@ -350,7 +363,8 @@ def gmres(
350363
351364 is_cpx = _dpnp .issubdtype (dtype , _dpnp .complexfloating )
352365 H_dtype = _np .complex128 if is_cpx else _np .float64
353- rhotol = float (_np .finfo (H_dtype ).eps ** 2 )
366+ # Use _np_dtype() so this works whether dtype is a dpnp type or numpy dtype.
367+ rhotol = float (_np .finfo (_np_dtype (dtype )).eps ** 2 )
354368
355369 total_iters = 0
356370 info = maxiter
@@ -520,7 +534,8 @@ def minres(
520534 A_op , M_op , x , b , dtype = _make_system (A , M , x0 , b )
521535 n = b .shape [0 ]
522536 is_cpx = _dpnp .issubdtype (dtype , _dpnp .complexfloating )
523- eps = float (_np .finfo (_np .dtype (dtype .char )).eps )
537+ # Use _np_dtype() to convert dpnp dtype to numpy dtype before finfo.
538+ eps = float (_np .finfo (_np_dtype (dtype )).eps )
524539
525540 if maxiter is None :
526541 maxiter = 5 * n
@@ -570,6 +585,10 @@ def minres(
570585 w2 = _dpnp .zeros_like (x )
571586 r2 = _dpnp .array (v , copy = True )
572587
588+ # Givens rotation scalars from the previous step
589+ cs_n = 0.0
590+ sn_n = 0.0
591+
573592 info = 1
574593 for itr in range (1 , maxiter + 1 ):
575594 # Lanczos step
@@ -596,53 +615,35 @@ def minres(
596615 info = 2
597616 break
598617
599- # QR update — Givens rotation plane
600- oldeps = epln
601- epln = dltan * (- dbar ) if itr > 1 else 0.0
602- dltan = gbar
603- delta = dltan * _np .cos (0.0 ) # cos(theta)=dltan/sqrt(dltan^2+beta^2)
604-
605- # ---- Symmetric QR on the Lanczos tridiagonal ---
606- # Simplified scalar recurrence (Paige-Saunders §6.4)
607- eps2 = alpha - shift
608- dbar = _np .hypot (dbar , beta ) # hypothetical: used below in full form
609-
610- # Givens rotation to zero out the sub-diagonal
611- eps2sq = float (eps2 )
612- betan = float (beta )
613- gabar = float (gbar )
614- rhs1 = float (phibar )
615-
616- # Full Paige-Saunders Givens step
617- cs_old = 0.0 if itr == 1 else cs_n
618- sn_old = 0.0 if itr == 1 else sn_n
619-
620- # Recurrence: eps, delta, gbar from previous Givens
621- eps_n = sn_old * betan
622- dbar = - cs_old * betan
623- delta_n = _np .hypot (gbar , betan )
618+ # Save previous Givens rotation scalars before overwriting
619+ cs_old = cs_n
620+ sn_old = sn_n
621+
622+ # Givens rotation to annihilate the sub-diagonal of the tridiagonal
623+ # Current diagonal entry in the shifted system
624+ eps_n = sn_old * beta
625+ dbar = - cs_old * beta
626+ delta_n = _np .hypot (gbar , beta )
624627 if delta_n == 0.0 :
625628 delta_n = eps
626- cs_n = gbar / delta_n
627- sn_n = betan / delta_n
628- phi = cs_n * phibar
629- phibar = sn_n * phibar
630-
631- denom = 1.0 / delta_n
632- w2old = w2 .copy ()
633- w2 = (v - eps_n * w - delta_n * w2 ) * denom # NOT right yet
634- # Correct: w update is w_{k} = (v_k - delta*w_{k-1} - eps*w_{k-2}) / gamma
635- # Redo with right symbols:
636- w_new = (v - oldeps * w - (delta_n * denom ) * w2old )
637- w = w2old
638- w2 = w_new
629+ cs_n = gbar / delta_n
630+ sn_n = beta / delta_n
631+ phi = cs_n * phibar
632+ phibar = sn_n * phibar
639633
640- x = x + phi * w2
634+ # Solution update using the Paige-Saunders w-vectors
635+ denom = 1.0 / delta_n
636+ w_new = (v - eps_n * w - dbar * w2 ) * denom
637+ x = x + phi * w_new
638+ w = w2 .copy ()
639+ w2 = w_new
641640
642- # Residual norm estimate
643- rnorm = abs (phibar )
641+ # Update gbar for next iteration
642+ gbar = sn_n * (alpha - shift ) - cs_n * dbar
643+ # rnorm estimate: |phibar|
644+ rnorm = abs (phibar )
644645
645- dnorm = _np .hypot (dnorm , phi / delta_n ) if delta_n != 0.0 else dnorm
646+ dnorm = _np .hypot (dnorm , phi * denom ) if delta_n != 0.0 else dnorm
646647
647648 if callback is not None :
648649 callback (x )
@@ -652,7 +653,7 @@ def minres(
652653 break
653654
654655 # Stagnation guard
655- if phi / delta_n < eps :
656+ if phi * denom < eps :
656657 info = 2
657658 break
658659 else :
0 commit comments