Skip to content

Commit dbfe2e5

Browse files
committed
feat: added rtol to solvers
1 parent cca7184 commit dbfe2e5

6 files changed

Lines changed: 205 additions & 34 deletions

File tree

pylops/optimization/basic.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from typing import TYPE_CHECKING, Callable, Optional, Tuple
88

9+
from pylops.optimization.callback import ResidualNormCallback
910
from pylops.optimization.cls_basic import CG, CGLS, LSQR
1011
from pylops.utils.decorators import add_ndarray_support_to_solver
1112
from pylops.utils.typing import NDArray
@@ -21,6 +22,7 @@ def cg(
2122
x0: Optional[NDArray] = None,
2223
niter: int = 10,
2324
tol: float = 1e-4,
25+
rtol: bool = 0.0,
2426
show: bool = False,
2527
itershow: Tuple[int, int, int] = (10, 10, 10),
2628
callback: Optional[Callable] = None,
@@ -42,7 +44,12 @@ def cg(
4244
niter : :obj:`int`, optional
4345
Number of iterations
4446
tol : :obj:`float`, optional
45-
Tolerance on residual norm
47+
Absolute tolerance on residual norm. Stops the solver when the
48+
residual norm is below this value.
49+
rtol : :obj:`float`, optional
50+
Relative tolerance on residual norm. Stops the solver when the
51+
ratio of the current residual norm to the initial residual norm is
52+
below this value.
4653
show : :obj:`bool`, optional
4754
Display iterations log
4855
itershow : :obj:`tuple`, optional
@@ -71,7 +78,13 @@ def cg(
7178
See :class:`pylops.optimization.cls_basic.CG`
7279
7380
"""
74-
cgsolve = CG(Op)
81+
rcallback = ResidualNormCallback(rtol)
82+
cgsolve = CG(
83+
Op,
84+
callbacks=[
85+
rcallback,
86+
],
87+
)
7588
if callback is not None:
7689
cgsolve.callback = callback
7790
x, iiter, cost = cgsolve.solve(
@@ -94,6 +107,7 @@ def cgls(
94107
niter: int = 10,
95108
damp: float = 0.0,
96109
tol: float = 1e-4,
110+
rtol: float = 0.0,
97111
show: bool = False,
98112
itershow: Tuple[int, int, int] = (10, 10, 10),
99113
callback: Optional[Callable] = None,
@@ -117,7 +131,12 @@ def cgls(
117131
damp : :obj:`float`, optional
118132
Damping coefficient
119133
tol : :obj:`float`, optional
120-
Tolerance on residual norm
134+
Absolute tolerance on residual norm. Stops the solver when the
135+
residual norm is below this value.
136+
rtol : :obj:`float`, optional
137+
Relative tolerance on residual norm. Stops the solver when the
138+
ratio of the current residual norm to the initial residual norm is
139+
below this value.
121140
show : :obj:`bool`, optional
122141
Display iterations log
123142
itershow : :obj:`tuple`, optional
@@ -161,7 +180,13 @@ def cgls(
161180
See :class:`pylops.optimization.cls_basic.CGLS`
162181
163182
"""
164-
cgsolve = CGLS(Op)
183+
rcallback = ResidualNormCallback(rtol)
184+
cgsolve = CGLS(
185+
Op,
186+
callbacks=[
187+
rcallback,
188+
],
189+
)
165190
if callback is not None:
166191
cgsolve.callback = callback
167192
x, istop, iiter, r1norm, r2norm, cost = cgsolve.solve(

pylops/optimization/callback.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__all__ = [
22
"Callbacks",
33
"MetricsCallback",
4+
"ResidualNormCallback",
45
]
56

67
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
@@ -28,6 +29,11 @@ class Callbacks:
2829
2930
All methods take two input parameters: the solver itself, and the vector ``x``.
3031
32+
Moreover, some callback may be used to implement custom stopping criteria for the solvers.
33+
This can be done by adding a boolean attribute `stop` to the callback object, which will
34+
be initially set to `False`. As soon as the callback sets this attribute to `True`, the
35+
``run`` method of the solver will stop iterating and return the current model vector.
36+
3137
Examples
3238
--------
3339
>>> import numpy as np
@@ -181,3 +187,53 @@ def on_step_end(self, solver: "Solver", x: NDArray) -> None:
181187
self.metrics["snr"].append(snr(self.xtrue, x))
182188
if "psnr" in self.which:
183189
self.metrics["psnr"].append(psnr(self.xtrue, x))
190+
191+
192+
class ResidualNormCallback(Callbacks):
193+
"""Residual norm callback
194+
195+
This callback can be used to stop the solver when the residual norm
196+
is below a certain threshold defined as a percentage of the
197+
initial residual norm.
198+
199+
Parameters
200+
----------
201+
rtol : :obj:`float`
202+
Percentage of the initial residual norm below which the solver
203+
will stop iterating. For example, if `rtol` is 0.1, the solver
204+
will stop when the residual norm is below 10% of the initial
205+
residual norm.
206+
207+
"""
208+
209+
def __init__(self, rtol: float) -> None:
210+
self.rtol = rtol
211+
self.stop = False
212+
213+
def on_step_end(self, solver: "Solver", x: NDArray) -> None:
214+
if solver.cost[-1] < self.rtol * solver.cost[0]:
215+
self.stop = True
216+
217+
218+
def _callback_stop(callbacks: Sequence[Callbacks]) -> bool:
219+
"""Check if any callback has raised a stop flag
220+
221+
Parameters
222+
----------
223+
callbacks : :obj:`pylops.optimization.callback.Callbacks`
224+
List of callbacks to evaluate
225+
226+
Returns
227+
-------
228+
stop : :obj:`bool`
229+
Whether to stop the solver or not
230+
231+
"""
232+
if callbacks is not None:
233+
stop = [
234+
False if not hasattr(callback, "stop") else callback.stop
235+
for callback in callbacks
236+
]
237+
if any(stop):
238+
return True
239+
return False

pylops/optimization/cls_basic.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111

1212
from pylops.optimization.basesolver import Solver, _units
13+
from pylops.optimization.callback import _callback_stop
1314
from pylops.utils.backend import (
1415
get_array_module,
1516
get_module_name,
@@ -121,7 +122,8 @@ def setup(
121122
Number of iterations (default to ``None`` in case a user wants to
122123
manually step over the solver)
123124
tol : :obj:`float`, optional
124-
Tolerance on residual norm
125+
Absolute tolerance on residual norm. Stops the solver when the
126+
residual norm is below this value.
125127
preallocate : :obj:`bool`, optional
126128
.. versionadded:: 2.6.0
127129
@@ -260,6 +262,10 @@ def run(
260262
)
261263
x = self.step(x, showstep)
262264
self.callback(x)
265+
# check if any callback has raised a stop flag
266+
stop = _callback_stop(self.callbacks)
267+
if stop:
268+
break
263269
return x
264270

265271
def finalize(self, show: bool = False) -> None:
@@ -592,22 +598,26 @@ def run(
592598
Estimated model of size :math:`[M \times 1]`
593599
594600
"""
595-
niter = self.niter if niter is None else niter
596-
if niter is None:
601+
self.niter = self.niter if niter is None else niter
602+
if self.niter is None:
597603
raise ValueError("niter must not be None")
598-
while self.iiter < niter and self.kold > self.tol:
604+
while self.iiter < self.niter and self.kold > self.tol:
599605
showstep = (
600606
True
601607
if show
602608
and (
603609
self.iiter < itershow[0]
604-
or niter - self.iiter < itershow[1]
610+
or self.niter - self.iiter < itershow[1]
605611
or self.iiter % itershow[2] == 0
606612
)
607613
else False
608614
)
609615
x = self.step(x, showstep)
610616
self.callback(x)
617+
# check if any callback has raised a stop flag
618+
stop = _callback_stop(self.callbacks)
619+
if stop:
620+
break
611621
return x
612622

613623
def finalize(self, show: bool = False) -> None:
@@ -622,7 +632,12 @@ def finalize(self, show: bool = False) -> None:
622632
self.tend = time.time()
623633
self.telapsed = self.tend - self.tstart
624634
# reason for termination
625-
self.istop = 1 if self.kold < self.tol else 2
635+
if self.kold < self.tol:
636+
self.istop = 1
637+
elif self.iiter >= self.niter:
638+
self.istop = 2
639+
else:
640+
self.istop = 3
626641
self.r1norm = self.kold
627642
self.r2norm = self.cost1[self.iiter]
628643
if show:
@@ -677,10 +692,14 @@ def solve(
677692
Gives the reason for termination
678693
679694
``1`` means :math:`\mathbf{x}` is an approximate solution to
680-
:math:`\mathbf{y} = \mathbf{Op}\,\mathbf{x}`
695+
:math:`\mathbf{y} = \mathbf{Op}\,\mathbf{x}` with the provided
696+
tolerance ``tol``
681697
682698
``2`` means :math:`\mathbf{x}` approximately solves the least-squares
683-
problem
699+
problem (reached the maximum number of iterations ``niter``)
700+
701+
``3`` means another stopping criterion implemented via a callback
702+
was reached
684703
iit : :obj:`int`
685704
Iteration number upon termination
686705
r1norm : :obj:`float`

pylops/optimization/cls_leastsquares.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import numpy as np
1111
from scipy.sparse.linalg import cg as sp_cg
12-
from scipy.sparse.linalg import lsqr
12+
from scipy.sparse.linalg import lsqr as sp_lsqr
1313

1414
from pylops.basicoperators import Diagonal, VStack
1515
from pylops.optimization.basesolver import Solver, _units
@@ -676,7 +676,7 @@ def run(
676676
if engine == "scipy" and self.ncp == np:
677677
if show:
678678
kwargs_solver["show"] = 1
679-
xinv, istop, itn, r1norm, r2norm = lsqr(
679+
xinv, istop, itn, r1norm, r2norm = sp_lsqr(
680680
self.RegOp, self.datatot, x0=x, **kwargs_solver
681681
)[0:5]
682682
elif engine == "pylops" or self.ncp != np:
@@ -938,7 +938,7 @@ def run(
938938
if engine == "scipy" and self.ncp == np:
939939
if show:
940940
kwargs_solver["show"] = 1
941-
pinv, istop, itn, r1norm, r2norm = lsqr(
941+
pinv, istop, itn, r1norm, r2norm = sp_lsqr(
942942
self.POp,
943943
self.y,
944944
x0=x,

pylops/optimization/cls_sparsity.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pylops.basicoperators import Diagonal, Identity, VStack
2020
from pylops.optimization.basesolver import Solver, _units
2121
from pylops.optimization.basic import cgls
22+
from pylops.optimization.callback import _callback_stop
2223
from pylops.optimization.eigs import power_iteration
2324
from pylops.optimization.leastsquares import regularized_inversion
2425
from pylops.utils import deps
@@ -399,14 +400,14 @@ def setup(
399400
epsR : :obj:`float`, optional
400401
Damping to be applied to residuals for weighting term
401402
epsI : :obj:`float`, optional
402-
Tikhonov damping (for ``kind="data"``) or L1 model damping
403-
(for ``kind="datamodel"``)
403+
Tikhonov damping
404404
tolIRLS : :obj:`float`, optional
405405
Tolerance. Stop outer iterations if difference between inverted model
406406
at subsequent iterations is smaller than ``tolIRLS``
407407
warm : :obj:`bool`, optional
408-
Warm start each inversion inner step with previous estimate (``True``) or not (``False``).
409-
This only applies to ``kind="data"`` and ``kind="datamodel"``
408+
Warm start each inversion inner step with previous estimate (``True``)
409+
or not (``False``). This only applies to ``kind="data"`` and
410+
``kind="datamodel"``
410411
kind : :obj:`str`, optional
411412
Kind of solver (``model``, ``data`` or ``datamodel``)
412413
preallocate : :obj:`bool`, optional
@@ -432,9 +433,6 @@ def setup(
432433
self.isjax = get_module_name(self.ncp) == "jax"
433434
self._setpreallocate(preallocate)
434435

435-
# initiate outer iteration counter
436-
self.iiter = 0
437-
438436
# choose step to use
439437
if self.kind == "data":
440438
self._step = self._step_data
@@ -456,6 +454,13 @@ def setup(
456454
self.rw = self.ncp.empty_like(self.y)
457455
else:
458456
self.rw = self.ncp.empty(self.Op.shape[1], dtype=self.Op.dtype)
457+
458+
# create variables to track the residual norm and iterations
459+
self.cost = [
460+
float(np.linalg.norm(self.y)),
461+
]
462+
self.iiter = 0
463+
459464
# print setup
460465
if show:
461466
self._print_setup()
@@ -619,6 +624,7 @@ def step(
619624
self.rnorm = self.ncp.linalg.norm(self.r)
620625

621626
self.iiter += 1
627+
self.cost.append(float(self.rnorm))
622628
if show:
623629
self._print_step(x)
624630
return x
@@ -687,6 +693,10 @@ def run(
687693
xold = x.copy()
688694
x = self.step(x, engine, showstep, **kwargs_solver)
689695
self.callback(x)
696+
# check if any callback has raised a stop flag
697+
stop = _callback_stop(self.callbacks)
698+
if stop:
699+
break
690700

691701
# adding initial guess
692702
if hasattr(self, "x0"):
@@ -1134,7 +1144,7 @@ def step(
11341144
self.ncp.subtract(self.res, self.y, out=self.res)
11351145

11361146
self.iiter += 1
1137-
self.cost.append(float(np.linalg.norm(self.res)))
1147+
self.cost.append(float(self.ncp.linalg.norm(self.res)))
11381148
if show:
11391149
self._print_step(x)
11401150
return x, cols
@@ -1187,6 +1197,10 @@ def run(
11871197
)
11881198
x, cols = self.step(x, cols, engine, showstep)
11891199
self.callback(x, cols)
1200+
# check if any callback has raised a stop flag
1201+
stop = _callback_stop(self.callbacks)
1202+
if stop:
1203+
break
11901204
return x, cols
11911205

11921206
def finalize(
@@ -1824,6 +1838,10 @@ def run(
18241838
)
18251839
x, xupdate = self.step(x, showstep)
18261840
self.callback(x)
1841+
# check if any callback has raised a stop flag
1842+
stop = _callback_stop(self.callbacks)
1843+
if stop:
1844+
break
18271845
if xupdate <= self.tol:
18281846
logger.info("Update smaller that tolerance for iteration %d", self.iiter)
18291847
return x
@@ -2205,6 +2223,10 @@ def run(
22052223
)
22062224
x, z, xupdate = self.step(x, z, showstep)
22072225
self.callback(x)
2226+
# check if any callback has raised a stop flag
2227+
stop = _callback_stop(self.callbacks)
2228+
if stop:
2229+
break
22082230
if xupdate <= self.tol:
22092231
logger.warning(
22102232
"Update smaller that tolerance for " "iteration %d", self.iiter
@@ -2943,7 +2965,10 @@ def run(
29432965
)
29442966
x = self.step(x, engine, showstep, show_inner, **kwargs_lsqr)
29452967
self.callback(x)
2946-
2968+
# check if any callback has raised a stop flag
2969+
stop = _callback_stop(self.callbacks)
2970+
if stop:
2971+
break
29472972
return x
29482973

29492974
def finalize(self, show: bool = False) -> NDArray:

0 commit comments

Comments
 (0)