@@ -65,8 +65,9 @@ class CG(Solver):
6565
6666 Notes
6767 -----
68- Solve the :math:`\mathbf{y} = \mathbf{Op}\,\mathbf{x}` problem using conjugate gradient
69- iterations [1]_.
68+ Solve the :math:`\mathbf{y} = (\mathbf{Op} + \epsilon\mathbf{I})\,\mathbf{x}` problem
69+ using conjugate gradient iterations [1]_, where :math:`\epsilon` is the damping
70+ coefficient.
7071
7172 .. [1] Hestenes, M R., Stiefel, E., “Methods of Conjugate Gradients for Solving
7273 Linear Systems”, Journal of Research of the National Bureau of Standards.
@@ -134,6 +135,7 @@ def setup(
134135 y : NDArray ,
135136 x0 : NDArray | None = None ,
136137 niter : int | None = None ,
138+ damp : float = 0.0 ,
137139 tol : float = 1e-4 ,
138140 preallocate : bool = False ,
139141 show : bool = False ,
@@ -150,6 +152,8 @@ def setup(
150152 niter : :obj:`int`, optional
151153 Number of iterations (default to ``None`` in case a user wants to
152154 manually step over the solver)
155+ damp : :obj:`float`, optional
156+ Damping coefficient
153157 tol : :obj:`float`, optional
154158 Absolute tolerance on residual norm. Stops the solver when the
155159 residual norm is below this value.
@@ -170,6 +174,7 @@ def setup(
170174 """
171175 self .y = y
172176 self .niter = niter
177+ self .damp = damp
173178 self .tol = tol
174179
175180 self .ncp = get_array_module (y )
@@ -187,6 +192,12 @@ def setup(
187192 else :
188193 self .r = self .ncp .empty_like (self .y )
189194 self .ncp .subtract (self .y , self .Op .matvec (x ), out = self .r )
195+ # account for the damping term in the initial residual
196+ if self .damp != 0.0 :
197+ if not self .preallocate :
198+ self .r = self .r - self .damp * x
199+ else :
200+ self .ncp .subtract (self .r , self .damp * x , out = self .r )
190201 self .c = self .r .copy ()
191202 self .kold = self .ncp .abs (self .r .dot (self .r .conj ()))
192203
@@ -221,6 +232,9 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
221232
222233 """
223234 Opc = self .Op .matvec (self .c )
235+ # add damping contribution
236+ if self .damp != 0.0 :
237+ Opc = Opc + self .damp * self .c
224238 cOpc = self .ncp .abs (self .c .dot (Opc .conj ()))
225239 a = self .kold / cOpc
226240 if not self .preallocate :
@@ -317,6 +331,7 @@ def solve(
317331 y : NDArray ,
318332 x0 : NDArray | None = None ,
319333 niter : int = 10 ,
334+ damp : float = 0.0 ,
320335 tol : float = 1e-4 ,
321336 preallocate : bool = False ,
322337 show : bool = False ,
@@ -333,6 +348,8 @@ def solve(
333348 internally as zero vector
334349 niter : :obj:`int`, optional
335350 Number of iterations
351+ damp : :obj:`float`, optional
352+ Damping coefficient
336353 tol : :obj:`float`, optional
337354 Absolute tolerance on residual norm. Stops the solver when the
338355 residual norm is below this value.
@@ -360,7 +377,13 @@ def solve(
360377
361378 """
362379 x = self .setup (
363- y = y , x0 = x0 , niter = niter , tol = tol , preallocate = preallocate , show = show
380+ y = y ,
381+ x0 = x0 ,
382+ niter = niter ,
383+ damp = damp ,
384+ tol = tol ,
385+ preallocate = preallocate ,
386+ show = show ,
364387 )
365388 x = self .run (x , niter , show = show , itershow = itershow )
366389 self .finalize (show )
0 commit comments