Skip to content

Commit 817720d

Browse files
authored
Merge pull request #693 from mrava87/feat-rtol1
Feat: more callbacks
2 parents 2604cfd + 69efe67 commit 817720d

8 files changed

Lines changed: 311 additions & 203 deletions

File tree

docs/source/addingsolver.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,7 @@ input and returns some of the most valuable properties of the class-based solver
189189
190190
def cg(Op, y, x0, niter=10, tol=1e-4, rtol=0.0,
191191
show=False, itershow=(10, 10, 10), callback=None):
192-
rcallback = ResidualNormCallback(rtol)
193-
cgsolve = CG(Op, callbacks=[rcallback, ])
192+
cgsolve = CG(Op, callbacks=[CostToInitialCallback(rtol), ])
194193
if callback is not None:
195194
cgsolve.callback = callback
196195
x, iiter, cost = cgsolve.solve(

docs/source/api/index.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ Callbacks
249249
:toctree: generated/
250250

251251
Callbacks
252+
CostNanInfCallback
253+
CostToDataCallback
254+
CostToInitialCallback
252255
MetricsCallback
253256

254257

pylops/optimization/basic.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import TYPE_CHECKING, Callable, Optional, Tuple
88

9-
from pylops.optimization.callback import ResidualNormCallback
9+
from pylops.optimization.callback import CostToDataCallback, CostToInitialCallback
1010
from pylops.optimization.cls_basic import CG, CGLS, LSQR
1111
from pylops.utils.decorators import add_ndarray_support_to_solver
1212
from pylops.utils.typing import NDArray
@@ -22,7 +22,8 @@ def cg(
2222
x0: Optional[NDArray] = None,
2323
niter: int = 10,
2424
tol: float = 1e-4,
25-
rtol: bool = 0.0,
25+
rtol: float = 0.0,
26+
rtol1: float = 0.0,
2627
show: bool = False,
2728
itershow: Tuple[int, int, int] = (10, 10, 10),
2829
callback: Optional[Callable] = None,
@@ -47,8 +48,12 @@ def cg(
4748
Absolute tolerance on residual norm. Stops the solver when the
4849
residual norm is below this value.
4950
rtol : :obj:`float`, optional
50-
Relative tolerance on residual norm. Stops the solver when the
51-
ratio of the current residual norm to the initial residual norm is
51+
Relative tolerance on residual norm wrt initial residual norm. Stops
52+
the solver when the ratio of the current residual norm to the initial
53+
residual norm is below this value.
54+
rtol1 : :obj:`float`, optional
55+
Relative tolerance on residual norm wrt to data. Stops the solver
56+
when the ratio of the current residual norm to the data norm is
5257
below this value.
5358
show : :obj:`bool`, optional
5459
Display iterations log
@@ -78,12 +83,15 @@ def cg(
7883
See :class:`pylops.optimization.cls_basic.CG`
7984
8085
"""
81-
rcallback = ResidualNormCallback(rtol)
86+
callbacks = []
87+
if rtol > 0.0:
88+
callbacks.append(CostToInitialCallback(rtol))
89+
if rtol1 > 0.0:
90+
callbacks.append(CostToDataCallback(rtol1))
91+
8292
cgsolve = CG(
8393
Op,
84-
callbacks=[
85-
rcallback,
86-
],
94+
callbacks=callbacks if len(callbacks) > 0 else None,
8795
)
8896
if callback is not None:
8997
cgsolve.callback = callback
@@ -108,6 +116,7 @@ def cgls(
108116
damp: float = 0.0,
109117
tol: float = 1e-4,
110118
rtol: float = 0.0,
119+
rtol1: float = 0.0,
111120
show: bool = False,
112121
itershow: Tuple[int, int, int] = (10, 10, 10),
113122
callback: Optional[Callable] = None,
@@ -134,8 +143,12 @@ def cgls(
134143
Absolute tolerance on residual norm. Stops the solver when the
135144
residual norm is below this value.
136145
rtol : :obj:`float`, optional
137-
Relative tolerance on residual norm. Stops the solver when the
138-
ratio of the current residual norm to the initial residual norm is
146+
Relative tolerance on residual norm wrt initial residual norm. Stops
147+
the solver when the ratio of the current residual norm to the initial
148+
residual norm is below this value.
149+
rtol1 : :obj:`float`, optional
150+
Relative tolerance on residual norm wrt to data. Stops the solver
151+
when the ratio of the current residual norm to the data norm is
139152
below this value.
140153
show : :obj:`bool`, optional
141154
Display iterations log
@@ -180,12 +193,14 @@ def cgls(
180193
See :class:`pylops.optimization.cls_basic.CGLS`
181194
182195
"""
183-
rcallback = ResidualNormCallback(rtol)
196+
callbacks = []
197+
if rtol > 0.0:
198+
callbacks.append(CostToInitialCallback(rtol))
199+
if rtol1 > 0.0:
200+
callbacks.append(CostToDataCallback(rtol1))
184201
cgsolve = CGLS(
185202
Op,
186-
callbacks=[
187-
rcallback,
188-
],
203+
callbacks=callbacks if len(callbacks) > 0 else None,
189204
)
190205
if callback is not None:
191206
cgsolve.callback = callback

pylops/optimization/callback.py

Lines changed: 82 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
__all__ = [
22
"Callbacks",
3+
"CostNanInfCallback",
4+
"CostToDataCallback",
5+
"CostToInitialCallback",
36
"MetricsCallback",
4-
"ResidualNormCallback",
57
]
68

79
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
810

11+
import numpy as np
12+
913
from pylops.utils.metrics import mae, mse, psnr, snr
1014
from pylops.utils.typing import NDArray
1115

@@ -141,6 +145,83 @@ def on_run_end(self, solver: "Solver", x: NDArray) -> None:
141145
pass
142146

143147

148+
class CostToDataCallback(Callbacks):
149+
"""Cost to data callback
150+
151+
This callback can be used to stop the solver when the ``cost`` parameter
152+
of the solver is below a certain threshold defined as a percentage of the
153+
Euclidean norm of the data.
154+
155+
Note that the meaning of ``cost`` can change from solver to solver - e.g.,
156+
it can represent the misfit of the data term or the total cost function.
157+
158+
Parameters
159+
----------
160+
rtol : :obj:`float`
161+
Percentage of the initial cost below which the solver
162+
will stop iterating. For example, if ``rtol`` is 0.1, the solver
163+
will stop when the cost is below 10% of the Euclidean norm of
164+
the data.
165+
166+
"""
167+
168+
def __init__(self, rtol: float) -> None:
169+
self.rtol = rtol
170+
self.stop = False
171+
172+
def on_setup_end(self, solver: "Solver", x: NDArray) -> None:
173+
self.ynorm = solver.ncp.linalg.norm(solver.y)
174+
175+
def on_step_end(self, solver: "Solver", x: NDArray) -> None:
176+
if solver.cost[-1] < self.rtol * self.ynorm:
177+
self.stop = True
178+
179+
180+
class CostToInitialCallback(Callbacks):
181+
"""Cost to initial callback
182+
183+
This callback can be used to stop the solver when the ``cost``
184+
parameter of the solver is below a certain threshold defined as a
185+
percentage of the initial residual norm.
186+
187+
Note that the meaning of ``cost`` can change from solver to solver - e.g.,
188+
it can represent the misfit of the data term or the total cost function.
189+
190+
Parameters
191+
----------
192+
rtol : :obj:`float`
193+
Percentage of the initial cost below which the solver
194+
will stop iterating. For example, if ``rtol`` is 0.1, the solver
195+
will stop when the cost is below 10% of the initial
196+
cost.
197+
198+
"""
199+
200+
def __init__(self, rtol: float) -> None:
201+
self.rtol = rtol
202+
self.stop = False
203+
204+
def on_step_end(self, solver: "Solver", x: NDArray) -> None:
205+
if solver.cost[-1] < self.rtol * solver.cost[0]:
206+
self.stop = True
207+
208+
209+
class CostNanInfCallback(Callbacks):
210+
"""Cost Nan/Inf callback
211+
212+
This callback can be used to stop the solver when the ``cost``
213+
becomes either ``np.nan`` or ``np.inf``
214+
215+
"""
216+
217+
def __init__(self) -> None:
218+
self.stop = False
219+
220+
def on_step_end(self, solver: "Solver", x: NDArray) -> None:
221+
if np.isnan(solver.cost[-1]) or np.isinf(solver.cost[-1]):
222+
self.stop = True
223+
224+
144225
class MetricsCallback(Callbacks):
145226
r"""Metrics callback
146227
@@ -191,32 +272,6 @@ def on_step_end(self, solver: "Solver", x: NDArray) -> None:
191272
self.metrics["psnr"].append(psnr(self.xtrue, x))
192273

193274

194-
class ResidualNormCallback(Callbacks):
195-
"""Residual norm callback
196-
197-
This callback can be used to stop the solver when the residual norm
198-
is below a certain threshold defined as a percentage of the
199-
initial residual norm.
200-
201-
Parameters
202-
----------
203-
rtol : :obj:`float`
204-
Percentage of the initial residual norm below which the solver
205-
will stop iterating. For example, if `rtol` is 0.1, the solver
206-
will stop when the residual norm is below 10% of the initial
207-
residual norm.
208-
209-
"""
210-
211-
def __init__(self, rtol: float) -> None:
212-
self.rtol = rtol
213-
self.stop = False
214-
215-
def on_step_end(self, solver: "Solver", x: NDArray) -> None:
216-
if solver.cost[-1] < self.rtol * solver.cost[0]:
217-
self.stop = True
218-
219-
220275
def _callback_stop(callbacks: Sequence[Callbacks]) -> bool:
221276
"""Check if any callback has raised a stop flag
222277

pylops/optimization/cls_sparsity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1028,7 +1028,7 @@ def setup(
10281028
# create variables to track the residual norm and iterations
10291029
self.res = self.y.copy()
10301030
self.cost = [
1031-
float(np.linalg.norm(self.y)),
1031+
float(np.linalg.norm(self.res)),
10321032
]
10331033
self.iiter = 0
10341034

0 commit comments

Comments
 (0)