Skip to content

Commit 6828c57

Browse files
committed
Address review: align CG damp with other solvers
- shorten damp docstrings to 'Damping coefficient' (cls_basic CG.setup/solve, basic.cg) and simplify CG Notes to a single damped-system sentence - follow preallocate vs non-preallocate residual-update pattern for the initial damping correction in setup - shorten the damping comment in step to '# add damping contribution' - test_cg_damp: build y from a known x (y = Aop * x), rename loop var to 'preallocate', drop issue ref from docstring
1 parent f39187c commit 6828c57

3 files changed

Lines changed: 17 additions & 27 deletions

File tree

pylops/optimization/basic.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,7 @@ def cg(
4747
niter : :obj:`int`, optional
4848
Number of iterations
4949
damp : :obj:`float`, optional
50-
Damping coefficient. When non-zero, the damped system
51-
:math:`(\mathbf{Op} + \epsilon\mathbf{I})\,\mathbf{x} = \mathbf{y}` is
52-
solved instead of :math:`\mathbf{Op}\,\mathbf{x} = \mathbf{y}`. ``Op``
53-
must be square and symmetric positive-definite.
50+
Damping coefficient
5451
tol : :obj:`float`, optional
5552
Absolute tolerance on residual norm. Stops the solver when the
5653
residual norm is below this value.

pylops/optimization/cls_basic.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,9 @@ class CG(Solver):
6565
6666
Notes
6767
-----
68-
Solve the :math:`\mathbf{y} = \mathbf{Op}\,\mathbf{x}` problem using conjugate gradient
69-
iterations [1]_.
70-
71-
When a non-zero damping coefficient :math:`\epsilon` is provided to ``setup``/``solve``,
72-
the damped system :math:`(\mathbf{Op} + \epsilon\mathbf{I})\,\mathbf{x} = \mathbf{y}` is
73-
solved instead; this is achieved by adding :math:`\epsilon\mathbf{c}` to every operator
74-
application :math:`\mathbf{Op}\,\mathbf{c}` in the iterations. ``Op`` is still required to
75-
be square and symmetric positive-definite (with :math:`\epsilon \geq 0` the damped
76-
operator remains so).
68+
Solve the :math:`\mathbf{y} = (\mathbf{Op} + \epsilon\mathbf{I})\,\mathbf{x}` problem
69+
using conjugate gradient iterations [1]_, where :math:`\epsilon` is the damping
70+
coefficient.
7771
7872
.. [1] Hestenes, M R., Stiefel, E., “Methods of Conjugate Gradients for Solving
7973
Linear Systems”, Journal of Research of the National Bureau of Standards.
@@ -159,9 +153,7 @@ def setup(
159153
Number of iterations (default to ``None`` in case a user wants to
160154
manually step over the solver)
161155
damp : :obj:`float`, optional
162-
Damping coefficient. When non-zero, the damped system
163-
:math:`(\mathbf{Op} + \epsilon\mathbf{I})\,\mathbf{x} = \mathbf{y}`
164-
is solved instead of :math:`\mathbf{Op}\,\mathbf{x} = \mathbf{y}`.
156+
Damping coefficient
165157
tol : :obj:`float`, optional
166158
Absolute tolerance on residual norm. Stops the solver when the
167159
residual norm is below this value.
@@ -202,7 +194,10 @@ def setup(
202194
self.ncp.subtract(self.y, self.Op.matvec(x), out=self.r)
203195
# account for the damping term in the initial residual
204196
if self.damp != 0.0:
205-
self.r -= self.damp * x
197+
if not self.preallocate:
198+
self.r = self.r - self.damp * x
199+
else:
200+
self.ncp.subtract(self.r, self.damp * x, out=self.r)
206201
self.c = self.r.copy()
207202
self.kold = self.ncp.abs(self.r.dot(self.r.conj()))
208203

@@ -237,8 +232,7 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
237232
238233
"""
239234
Opc = self.Op.matvec(self.c)
240-
# solve the damped system (Op + damp*I) x = y by adding the damping
241-
# contribution to every application of the operator
235+
# add damping contribution
242236
if self.damp != 0.0:
243237
Opc = Opc + self.damp * self.c
244238
cOpc = self.ncp.abs(self.c.dot(Opc.conj()))
@@ -355,9 +349,7 @@ def solve(
355349
niter : :obj:`int`, optional
356350
Number of iterations
357351
damp : :obj:`float`, optional
358-
Damping coefficient. When non-zero, the damped system
359-
:math:`(\mathbf{Op} + \epsilon\mathbf{I})\,\mathbf{x} = \mathbf{y}`
360-
is solved instead of :math:`\mathbf{Op}\,\mathbf{x} = \mathbf{y}`.
352+
Damping coefficient
361353
tol : :obj:`float`, optional
362354
Absolute tolerance on residual norm. Stops the solver when the
363355
residual norm is below this value.

pytests/test_solver.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_cg(par):
104104

105105
@pytest.mark.parametrize("par", [(par1), (par2)])
106106
def test_cg_damp(par):
107-
"""CG with damping solves the damped system (Op + damp*I) x = y (issue #406).
107+
"""CG with damping solves the damped system (Op + damp*I) x = y.
108108
109109
Verify equivalence between internal damping, adding ``damp*I`` to the
110110
operator (with ``damp=0`` in the solver), and the dense solution; and that
@@ -116,16 +116,17 @@ def test_cg_damp(par):
116116
A = np.random.normal(0, 1, (n, n)) + par["imag"] * np.random.normal(0, 1, (n, n))
117117
A = np.conj(A).T @ A + np.eye(n) # symmetric positive-definite
118118
Aop = MatrixMult(A, dtype=par["dtype"])
119-
y = np.random.normal(0, 1, n) + par["imag"] * np.random.normal(0, 1, n)
119+
x = np.ones(n) + par["imag"] * np.ones(n)
120+
y = Aop * x
120121
damp = 0.8
121122

122123
x_dense = np.linalg.solve(A + damp * np.eye(n), y)
123124
# adding damp*I to the operator and using damp=0 must match internal damping
124125
Aop_damped = MatrixMult(A + damp * np.eye(n), dtype=par["dtype"])
125126

126-
for prealloc in [False, True]:
127-
x_int = cg(Aop, y, niter=4 * n, tol=1e-10, damp=damp, preallocate=prealloc)[0]
128-
x_ext = cg(Aop_damped, y, niter=4 * n, tol=1e-10, preallocate=prealloc)[0]
127+
for preallocate in [False, True]:
128+
x_int = cg(Aop, y, niter=4 * n, tol=1e-10, damp=damp, preallocate=preallocate)[0]
129+
x_ext = cg(Aop_damped, y, niter=4 * n, tol=1e-10, preallocate=preallocate)[0]
129130
assert_array_almost_equal(x_int, x_dense, decimal=5)
130131
assert_array_almost_equal(x_int, x_ext, decimal=5)
131132

0 commit comments

Comments
 (0)