Skip to content

Commit 0ae0af3

Browse files
authored
Merge pull request #690 from mrava87/feat-solversstopping
Feat: stopping criterial for solvers
2 parents cca7184 + 1d9c4ba commit 0ae0af3

10 files changed

Lines changed: 568 additions & 56 deletions

File tree

docs/source/addingsolver.rst

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ Implementing new solvers
44
========================
55
Users are welcome to create new solvers and add them to the PyLops library.
66

7-
In this tutorial, we will go through the key steps in the definition of a solver, using the
8-
:py:class:`pylops.optimization.basic.CG` as an example.
7+
In this tutorial, we will go through the key steps in the definition of a solver, using a
8+
sligthly simplified version of :py:class:`pylops.optimization.basic.CG` as an example.
99

1010
.. note::
1111
In case the solver that you are planning to create falls within the category of proximal solvers,
@@ -83,7 +83,6 @@ note that the ``setup`` method returns the created starting guess ``x`` (does no
8383
.. code-block:: python
8484
8585
def setup(self, y, x0=None, niter=None, tol=1e-4, show=False):
86-
8786
self.y = y
8887
self.tol = tol
8988
self.niter = niter
@@ -134,7 +133,26 @@ can add additional input parameters. For CG, the step is:
134133
135134
136135
Similarly, we also implement a ``run`` method that is in charge of running a number of iterations by repeatedly
137-
calling the ``step`` method. It is also usually convenient to implement a finalize method; this method can do any required post-processing that should
136+
calling the ``step`` method.
137+
138+
.. code-block:: python
139+
140+
def run(self, x, niter, show, itershow):
141+
while self.iiter < niter and self.kold > self.tol:
142+
x = self.step(x, showstep)
143+
self.callback(x)
144+
# check if any callback has raised a stop flag
145+
stop = _callback_stop(self.callbacks)
146+
if stop:
147+
break
148+
return x
149+
150+
It is worth noting that any number of callbacks can be attached to the solver; some of these
151+
callbacks can implement a stopping criterion and set the ``stop`` member to True when a given
152+
condition is met. The ``_callback_stop`` method is in change of checking if any of the callbacks
153+
has set ``stop`` to True and in the case break the iterations.
154+
155+
Finally, it is also usually convenient to implement a ``finalize`` method; this method can do any required post-processing that should
138156
not be applied at the end of each step, rather at the end of the entire optimization process. For CG, this is as simple
139157
as converting the ``cost`` variable from a list to a ``numpy`` array. For more details, see our implementations for CG.
140158

@@ -169,8 +187,10 @@ input and returns some of the most valuable properties of the class-based solver
169187

170188
.. code-block:: python
171189
172-
def cg(Op, y, x0, niter=10, tol=1e-4, show=False, itershow=(10, 10, 10), callback=None):
173-
cgsolve = CG(Op)
190+
def cg(Op, y, x0, niter=10, tol=1e-4, rtol=0.0,
191+
show=False, itershow=(10, 10, 10), callback=None):
192+
rcallback = ResidualNormCallback(rtol)
193+
cgsolve = CG(Op, callbacks=[rcallback, ])
174194
if callback is not None:
175195
cgsolve.callback = callback
176196
x, iiter, cost = cgsolve.solve(

pylops/optimization/basic.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from typing import TYPE_CHECKING, Callable, Optional, Tuple
88

9+
from pylops.optimization.callback import ResidualNormCallback
910
from pylops.optimization.cls_basic import CG, CGLS, LSQR
1011
from pylops.utils.decorators import add_ndarray_support_to_solver
1112
from pylops.utils.typing import NDArray
@@ -21,6 +22,7 @@ def cg(
2122
x0: Optional[NDArray] = None,
2223
niter: int = 10,
2324
tol: float = 1e-4,
25+
rtol: bool = 0.0,
2426
show: bool = False,
2527
itershow: Tuple[int, int, int] = (10, 10, 10),
2628
callback: Optional[Callable] = None,
@@ -42,7 +44,12 @@ def cg(
4244
niter : :obj:`int`, optional
4345
Number of iterations
4446
tol : :obj:`float`, optional
45-
Tolerance on residual norm
47+
Absolute tolerance on residual norm. Stops the solver when the
48+
residual norm is below this value.
49+
rtol : :obj:`float`, optional
50+
Relative tolerance on residual norm. Stops the solver when the
51+
ratio of the current residual norm to the initial residual norm is
52+
below this value.
4653
show : :obj:`bool`, optional
4754
Display iterations log
4855
itershow : :obj:`tuple`, optional
@@ -71,7 +78,13 @@ def cg(
7178
See :class:`pylops.optimization.cls_basic.CG`
7279
7380
"""
74-
cgsolve = CG(Op)
81+
rcallback = ResidualNormCallback(rtol)
82+
cgsolve = CG(
83+
Op,
84+
callbacks=[
85+
rcallback,
86+
],
87+
)
7588
if callback is not None:
7689
cgsolve.callback = callback
7790
x, iiter, cost = cgsolve.solve(
@@ -94,6 +107,7 @@ def cgls(
94107
niter: int = 10,
95108
damp: float = 0.0,
96109
tol: float = 1e-4,
110+
rtol: float = 0.0,
97111
show: bool = False,
98112
itershow: Tuple[int, int, int] = (10, 10, 10),
99113
callback: Optional[Callable] = None,
@@ -117,7 +131,12 @@ def cgls(
117131
damp : :obj:`float`, optional
118132
Damping coefficient
119133
tol : :obj:`float`, optional
120-
Tolerance on residual norm
134+
Absolute tolerance on residual norm. Stops the solver when the
135+
residual norm is below this value.
136+
rtol : :obj:`float`, optional
137+
Relative tolerance on residual norm. Stops the solver when the
138+
ratio of the current residual norm to the initial residual norm is
139+
below this value.
121140
show : :obj:`bool`, optional
122141
Display iterations log
123142
itershow : :obj:`tuple`, optional
@@ -161,7 +180,13 @@ def cgls(
161180
See :class:`pylops.optimization.cls_basic.CGLS`
162181
163182
"""
164-
cgsolve = CGLS(Op)
183+
rcallback = ResidualNormCallback(rtol)
184+
cgsolve = CGLS(
185+
Op,
186+
callbacks=[
187+
rcallback,
188+
],
189+
)
165190
if callback is not None:
166191
cgsolve.callback = callback
167192
x, istop, iiter, r1norm, r2norm, cost = cgsolve.solve(

pylops/optimization/callback.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__all__ = [
22
"Callbacks",
33
"MetricsCallback",
4+
"ResidualNormCallback",
45
]
56

67
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
@@ -28,25 +29,32 @@ class Callbacks:
2829
2930
All methods take two input parameters: the solver itself, and the vector ``x``.
3031
32+
Moreover, some callback may be used to implement custom stopping criteria for the solver.
33+
This can be done by adding a boolean attribute ``stop`` to the callback object, which will
34+
be initially set to ``False``. As soon as the callback sets this attribute to ``True``, the
35+
``run`` method of the solver will stop iterating and return the current model vector.
36+
3137
Examples
3238
--------
3339
>>> import numpy as np
3440
>>> from pylops.basicoperators import MatrixMult
35-
>>> from pylops.optimization.solver import CG
41+
>>> from pylops.optimization.basic import CG
3642
>>> from pylops.optimization.callback import Callbacks
43+
>>>
3744
>>> class StoreIterCallback(Callbacks):
3845
... def __init__(self):
3946
... self.stored = []
4047
... def on_step_end(self, solver, x):
4148
... self.stored.append(solver.iiter)
42-
>>> cb_sto = StoreIterCallback()
49+
>>>
4350
>>> Aop = MatrixMult(np.random.normal(0., 1., 36).reshape(6, 6))
4451
>>> Aop = Aop.H @ Aop
4552
>>> y = Aop @ np.ones(6)
53+
>>> cb_sto = StoreIterCallback()
4654
>>> cgsolve = CG(Aop, callbacks=[cb_sto, ])
4755
>>> xest = cgsolve.solve(y=y, x0=np.zeros(6), tol=0, niter=6, show=False)[0]
48-
>>> xest
49-
array([1., 1., 1., 1., 1., 1.])
56+
>>> xest, cb_sto.stored
57+
(array([1., 1., 1., 1., 1., 1.]), [1, 2, 3, 4, 5, 6])
5058
5159
"""
5260

@@ -181,3 +189,53 @@ def on_step_end(self, solver: "Solver", x: NDArray) -> None:
181189
self.metrics["snr"].append(snr(self.xtrue, x))
182190
if "psnr" in self.which:
183191
self.metrics["psnr"].append(psnr(self.xtrue, x))
192+
193+
194+
class ResidualNormCallback(Callbacks):
195+
"""Residual norm callback
196+
197+
This callback can be used to stop the solver when the residual norm
198+
is below a certain threshold defined as a percentage of the
199+
initial residual norm.
200+
201+
Parameters
202+
----------
203+
rtol : :obj:`float`
204+
Percentage of the initial residual norm below which the solver
205+
will stop iterating. For example, if `rtol` is 0.1, the solver
206+
will stop when the residual norm is below 10% of the initial
207+
residual norm.
208+
209+
"""
210+
211+
def __init__(self, rtol: float) -> None:
212+
self.rtol = rtol
213+
self.stop = False
214+
215+
def on_step_end(self, solver: "Solver", x: NDArray) -> None:
216+
if solver.cost[-1] < self.rtol * solver.cost[0]:
217+
self.stop = True
218+
219+
220+
def _callback_stop(callbacks: Sequence[Callbacks]) -> bool:
221+
"""Check if any callback has raised a stop flag
222+
223+
Parameters
224+
----------
225+
callbacks : :obj:`pylops.optimization.callback.Callbacks`
226+
List of callbacks to evaluate
227+
228+
Returns
229+
-------
230+
stop : :obj:`bool`
231+
Whether to stop the solver or not
232+
233+
"""
234+
if callbacks is not None:
235+
stop = [
236+
False if not hasattr(callback, "stop") else callback.stop
237+
for callback in callbacks
238+
]
239+
if any(stop):
240+
return True
241+
return False

pylops/optimization/cls_basic.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111

1212
from pylops.optimization.basesolver import Solver, _units
13+
from pylops.optimization.callback import _callback_stop
1314
from pylops.utils.backend import (
1415
get_array_module,
1516
get_module_name,
@@ -121,7 +122,8 @@ def setup(
121122
Number of iterations (default to ``None`` in case a user wants to
122123
manually step over the solver)
123124
tol : :obj:`float`, optional
124-
Tolerance on residual norm
125+
Absolute tolerance on residual norm. Stops the solver when the
126+
residual norm is below this value.
125127
preallocate : :obj:`bool`, optional
126128
.. versionadded:: 2.6.0
127129
@@ -260,6 +262,10 @@ def run(
260262
)
261263
x = self.step(x, showstep)
262264
self.callback(x)
265+
# check if any callback has raised a stop flag
266+
stop = _callback_stop(self.callbacks)
267+
if stop:
268+
break
263269
return x
264270

265271
def finalize(self, show: bool = False) -> None:
@@ -299,7 +305,8 @@ def solve(
299305
niter : :obj:`int`, optional
300306
Number of iterations
301307
tol : :obj:`float`, optional
302-
Tolerance on residual norm
308+
Absolute tolerance on residual norm. Stops the solver when the
309+
residual norm is below this value.
303310
preallocate : :obj:`bool`, optional
304311
.. versionadded:: 2.6.0
305312
@@ -440,7 +447,8 @@ def setup(
440447
damp : :obj:`float`, optional
441448
Damping coefficient
442449
tol : :obj:`float`, optional
443-
Tolerance on residual norm
450+
Absolute tolerance on residual norm. Stops the solver when the
451+
residual norm is below this value.
444452
preallocate : :obj:`bool`, optional
445453
.. versionadded:: 2.6.0
446454
@@ -592,22 +600,26 @@ def run(
592600
Estimated model of size :math:`[M \times 1]`
593601
594602
"""
595-
niter = self.niter if niter is None else niter
596-
if niter is None:
603+
self.niter = self.niter if niter is None else niter
604+
if self.niter is None:
597605
raise ValueError("niter must not be None")
598-
while self.iiter < niter and self.kold > self.tol:
606+
while self.iiter < self.niter and self.kold > self.tol:
599607
showstep = (
600608
True
601609
if show
602610
and (
603611
self.iiter < itershow[0]
604-
or niter - self.iiter < itershow[1]
612+
or self.niter - self.iiter < itershow[1]
605613
or self.iiter % itershow[2] == 0
606614
)
607615
else False
608616
)
609617
x = self.step(x, showstep)
610618
self.callback(x)
619+
# check if any callback has raised a stop flag
620+
stop = _callback_stop(self.callbacks)
621+
if stop:
622+
break
611623
return x
612624

613625
def finalize(self, show: bool = False) -> None:
@@ -622,7 +634,12 @@ def finalize(self, show: bool = False) -> None:
622634
self.tend = time.time()
623635
self.telapsed = self.tend - self.tstart
624636
# reason for termination
625-
self.istop = 1 if self.kold < self.tol else 2
637+
if self.kold < self.tol:
638+
self.istop = 1
639+
elif self.iiter >= self.niter:
640+
self.istop = 2
641+
else:
642+
self.istop = 3
626643
self.r1norm = self.kold
627644
self.r2norm = self.cost1[self.iiter]
628645
if show:
@@ -655,7 +672,8 @@ def solve(
655672
damp : :obj:`float`, optional
656673
Damping coefficient
657674
tol : :obj:`float`, optional
658-
Tolerance on residual norm
675+
Absolute tolerance on residual norm. Stops the solver when the
676+
residual norm is below this value.
659677
preallocate : :obj:`bool`, optional
660678
.. versionadded:: 2.6.0
661679
@@ -677,10 +695,14 @@ def solve(
677695
Gives the reason for termination
678696
679697
``1`` means :math:`\mathbf{x}` is an approximate solution to
680-
:math:`\mathbf{y} = \mathbf{Op}\,\mathbf{x}`
698+
:math:`\mathbf{y} = \mathbf{Op}\,\mathbf{x}` with the provided
699+
tolerance ``tol``
681700
682701
``2`` means :math:`\mathbf{x}` approximately solves the least-squares
683-
problem
702+
problem (reached the maximum number of iterations ``niter``)
703+
704+
``3`` means another stopping criterion implemented via a callback
705+
was reached
684706
iit : :obj:`int`
685707
Iteration number upon termination
686708
r1norm : :obj:`float`

pylops/optimization/cls_leastsquares.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import numpy as np
1111
from scipy.sparse.linalg import cg as sp_cg
12-
from scipy.sparse.linalg import lsqr
12+
from scipy.sparse.linalg import lsqr as sp_lsqr
1313

1414
from pylops.basicoperators import Diagonal, VStack
1515
from pylops.optimization.basesolver import Solver, _units
@@ -676,7 +676,7 @@ def run(
676676
if engine == "scipy" and self.ncp == np:
677677
if show:
678678
kwargs_solver["show"] = 1
679-
xinv, istop, itn, r1norm, r2norm = lsqr(
679+
xinv, istop, itn, r1norm, r2norm = sp_lsqr(
680680
self.RegOp, self.datatot, x0=x, **kwargs_solver
681681
)[0:5]
682682
elif engine == "pylops" or self.ncp != np:
@@ -938,7 +938,7 @@ def run(
938938
if engine == "scipy" and self.ncp == np:
939939
if show:
940940
kwargs_solver["show"] = 1
941-
pinv, istop, itn, r1norm, r2norm = lsqr(
941+
pinv, istop, itn, r1norm, r2norm = sp_lsqr(
942942
self.POp,
943943
self.y,
944944
x0=x,

0 commit comments

Comments
 (0)