Skip to content

Commit 0108e8b

Browse files
committed
feat: added ResidualNormToInitialCallback and rtol1 to solvers
1 parent 2604cfd commit 0108e8b

5 files changed

Lines changed: 127 additions & 48 deletions

File tree

docs/source/addingsolver.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,7 @@ input and returns some of the most valuable properties of the class-based solver
189189
190190
def cg(Op, y, x0, niter=10, tol=1e-4, rtol=0.0,
191191
show=False, itershow=(10, 10, 10), callback=None):
192-
rcallback = ResidualNormCallback(rtol)
193-
cgsolve = CG(Op, callbacks=[rcallback, ])
192+
cgsolve = CG(Op, callbacks=[ResidualNormToInitialCallback(rtol), ])
194193
if callback is not None:
195194
cgsolve.callback = callback
196195
x, iiter, cost = cgsolve.solve(

pylops/optimization/basic.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66

77
from typing import TYPE_CHECKING, Callable, Optional, Tuple
88

9-
from pylops.optimization.callback import ResidualNormCallback
9+
from pylops.optimization.callback import (
10+
ResidualNormToDataCallback,
11+
ResidualNormToInitialCallback,
12+
)
1013
from pylops.optimization.cls_basic import CG, CGLS, LSQR
1114
from pylops.utils.decorators import add_ndarray_support_to_solver
1215
from pylops.utils.typing import NDArray
@@ -22,7 +25,8 @@ def cg(
2225
x0: Optional[NDArray] = None,
2326
niter: int = 10,
2427
tol: float = 1e-4,
25-
rtol: bool = 0.0,
28+
rtol: float = 0.0,
29+
rtol1: float = 0.0,
2630
show: bool = False,
2731
itershow: Tuple[int, int, int] = (10, 10, 10),
2832
callback: Optional[Callable] = None,
@@ -47,8 +51,12 @@ def cg(
4751
Absolute tolerance on residual norm. Stops the solver when the
4852
residual norm is below this value.
4953
rtol : :obj:`float`, optional
50-
Relative tolerance on residual norm. Stops the solver when the
51-
ratio of the current residual norm to the initial residual norm is
54+
Relative tolerance on residual norm wrt initial residual norm. Stops
55+
the solver when the ratio of the current residual norm to the initial
56+
residual norm is below this value.
57+
rtol1 : :obj:`float`, optional
58+
Relative tolerance on residual norm wrt to data. Stops the solver
59+
when the ratio of the current residual norm to the data norm is
5260
below this value.
5361
show : :obj:`bool`, optional
5462
Display iterations log
@@ -78,12 +86,15 @@ def cg(
7886
See :class:`pylops.optimization.cls_basic.CG`
7987
8088
"""
81-
rcallback = ResidualNormCallback(rtol)
89+
callbacks = []
90+
if rtol > 0.0:
91+
callbacks.append(ResidualNormToInitialCallback(rtol))
92+
if rtol1 > 0.0:
93+
callbacks.append(ResidualNormToDataCallback(rtol1))
94+
8295
cgsolve = CG(
8396
Op,
84-
callbacks=[
85-
rcallback,
86-
],
97+
callbacks=callbacks if len(callbacks) > 0 else None,
8798
)
8899
if callback is not None:
89100
cgsolve.callback = callback
@@ -108,6 +119,7 @@ def cgls(
108119
damp: float = 0.0,
109120
tol: float = 1e-4,
110121
rtol: float = 0.0,
122+
rtol1: float = 0.0,
111123
show: bool = False,
112124
itershow: Tuple[int, int, int] = (10, 10, 10),
113125
callback: Optional[Callable] = None,
@@ -134,8 +146,12 @@ def cgls(
134146
Absolute tolerance on residual norm. Stops the solver when the
135147
residual norm is below this value.
136148
rtol : :obj:`float`, optional
137-
Relative tolerance on residual norm. Stops the solver when the
138-
ratio of the current residual norm to the initial residual norm is
149+
Relative tolerance on residual norm wrt initial residual norm. Stops
150+
the solver when the ratio of the current residual norm to the initial
151+
residual norm is below this value.
152+
rtol1 : :obj:`float`, optional
153+
Relative tolerance on residual norm wrt to data. Stops the solver
154+
when the ratio of the current residual norm to the data norm is
139155
below this value.
140156
show : :obj:`bool`, optional
141157
Display iterations log
@@ -180,12 +196,14 @@ def cgls(
180196
See :class:`pylops.optimization.cls_basic.CGLS`
181197
182198
"""
183-
rcallback = ResidualNormCallback(rtol)
199+
callbacks = []
200+
if rtol > 0.0:
201+
callbacks.append(ResidualNormToInitialCallback(rtol))
202+
if rtol1 > 0.0:
203+
callbacks.append(ResidualNormToDataCallback(rtol1))
184204
cgsolve = CGLS(
185205
Op,
186-
callbacks=[
187-
rcallback,
188-
],
206+
callbacks=callbacks if len(callbacks) > 0 else None,
189207
)
190208
if callback is not None:
191209
cgsolve.callback = callback

pylops/optimization/callback.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
__all__ = [
22
"Callbacks",
33
"MetricsCallback",
4-
"ResidualNormCallback",
4+
"ResidualNormToDataCallback",
5+
"ResidualNormToInitialCallback",
56
]
67

78
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
89

10+
import numpy as np
11+
912
from pylops.utils.metrics import mae, mse, psnr, snr
1013
from pylops.utils.typing import NDArray
1114

@@ -191,8 +194,37 @@ def on_step_end(self, solver: "Solver", x: NDArray) -> None:
191194
self.metrics["psnr"].append(psnr(self.xtrue, x))
192195

193196

194-
class ResidualNormCallback(Callbacks):
195-
"""Residual norm callback
197+
class ResidualNormToDataCallback(Callbacks):
198+
"""Residual norm to data callback
199+
200+
This callback can be used to stop the solver when the residual norm
201+
is below a certain threshold defined as a percentage of the
202+
initial residual norm.
203+
204+
Parameters
205+
----------
206+
rtol : :obj:`float`
207+
Percentage of the initial residual norm below which the solver
208+
will stop iterating. For example, if `rtol` is 0.1, the solver
209+
will stop when the residual norm is below 10% of the initial
210+
residual norm.
211+
212+
"""
213+
214+
def __init__(self, rtol: float) -> None:
215+
self.rtol = rtol
216+
self.stop = False
217+
218+
def on_setup_end(self, solver: "Solver", x: NDArray) -> None:
219+
self.ynorm = self.ncp.linalg.norm(self.y)
220+
221+
def on_step_end(self, solver: "Solver", x: NDArray) -> None:
222+
if solver.cost[-1] < self.rtol * self.ynorm:
223+
self.stop = True
224+
225+
226+
class ResidualNormToInitialCallback(Callbacks):
227+
"""Residual norm to initial callback
196228
197229
This callback can be used to stop the solver when the residual norm
198230
is below a certain threshold defined as a percentage of the

pylops/optimization/sparsity.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
1111

12-
from pylops.optimization.callback import ResidualNormCallback
12+
from pylops.optimization.callback import (
13+
ResidualNormToDataCallback,
14+
ResidualNormToInitialCallback,
15+
)
1316
from pylops.optimization.cls_sparsity import FISTA, IRLS, ISTA, OMP, SPGL1, SplitBregman
1417
from pylops.utils.decorators import add_ndarray_support_to_solver
1518
from pylops.utils.typing import NDArray, SamplingLike
@@ -142,6 +145,7 @@ def omp(
142145
niter_inner: int = 40,
143146
sigma: float = 1e-4,
144147
rtol: float = 0.0,
148+
rtol1: float = 0.0,
145149
normalizecols: bool = False,
146150
Opbasis: Optional["LinearOperator"] = None,
147151
optimal_coeff: bool = False,
@@ -172,9 +176,13 @@ def omp(
172176
sigma : :obj:`float`, optional
173177
Maximum :math:`L_2` norm of residual. When smaller stop iterations.
174178
rtol : :obj:`float`, optional
175-
Relative tolerance on residual. Stops the solver when the
176-
ratio of the current residual norm to the initial residual norm
177-
is below this value.
179+
Relative tolerance on residual norm wrt initial residual norm. Stops
180+
the solver when the ratio of the current residual norm to the initial
181+
residual norm is below this value.
182+
rtol1 : :obj:`float`, optional
183+
Relative tolerance on residual norm wrt to data. Stops the solver
184+
when the ratio of the current residual norm to the data norm is
185+
below this value.
178186
normalizecols : :obj:`list`, optional
179187
Normalize columns (``True``) or not (``False``). Note that this can be
180188
expensive as it requires applying the forward operator
@@ -229,12 +237,15 @@ def omp(
229237
See :class:`pylops.optimization.cls_sparsity.OMP`
230238
231239
"""
232-
rcallback = ResidualNormCallback(rtol)
240+
callbacks = []
241+
if rtol > 0.0:
242+
callbacks.append(ResidualNormToInitialCallback(rtol))
243+
if rtol1 > 0.0:
244+
callbacks.append(ResidualNormToDataCallback(rtol1))
245+
233246
ompsolve = OMP(
234247
Op,
235-
callbacks=[
236-
rcallback,
237-
],
248+
callbacks=callbacks if len(callbacks) > 0 else None,
238249
)
239250
if callback is not None:
240251
ompsolve.callback = callback
@@ -264,7 +275,8 @@ def ista(
264275
alpha: Optional[float] = None,
265276
eigsdict: Optional[Dict[str, Any]] = None,
266277
tol: float = 1e-10,
267-
rtol: bool = 0.0,
278+
rtol: float = 0.0,
279+
rtol1: float = 0.0,
268280
threshkind: str = "soft",
269281
perc: Optional[float] = None,
270282
decay: Optional[NDArray] = None,
@@ -309,9 +321,13 @@ def ista(
309321
Absolute tolerance on model update. Stop iterations if difference between inverted model
310322
at subsequent iterations is smaller than ``tol``
311323
rtol : :obj:`float`, optional
312-
Relative tolerance on total cost function. Stops the solver when the
313-
ratio of the current cost function to the initial cost function
314-
is below this value.
324+
Relative tolerance on residual norm wrt initial residual norm. Stops
325+
the solver when the ratio of the current residual norm to the initial
326+
residual norm is below this value.
327+
rtol1 : :obj:`float`, optional
328+
Relative tolerance on residual norm wrt to data. Stops the solver
329+
when the ratio of the current residual norm to the data norm is
330+
below this value.
315331
threshkind : :obj:`str`, optional
316332
Kind of thresholding ('hard', 'soft', 'half', 'hard-percentile',
317333
'soft-percentile', or 'half-percentile' - 'soft' used as default)
@@ -370,12 +386,15 @@ def ista(
370386
See :class:`pylops.optimization.cls_sparsity.ISTA`
371387
372388
"""
373-
rcallback = ResidualNormCallback(rtol)
389+
callbacks = []
390+
if rtol > 0.0:
391+
callbacks.append(ResidualNormToInitialCallback(rtol))
392+
if rtol1 > 0.0:
393+
callbacks.append(ResidualNormToDataCallback(rtol1))
394+
374395
istasolve = ISTA(
375396
Op,
376-
callbacks=[
377-
rcallback,
378-
],
397+
callbacks=callbacks if len(callbacks) > 0 else None,
379398
)
380399
if callback is not None:
381400
istasolve.callback = callback
@@ -410,6 +429,7 @@ def fista(
410429
eigsdict: Optional[Dict[str, Any]] = None,
411430
tol: float = 1e-10,
412431
rtol: float = 0.0,
432+
rtol1: float = 0.0,
413433
threshkind: str = "soft",
414434
perc: Optional[float] = None,
415435
decay: Optional[NDArray] = None,
@@ -513,12 +533,15 @@ def fista(
513533
See :class:`pylops.optimization.cls_sparsity.FISTA`
514534
515535
"""
516-
rcallback = ResidualNormCallback(rtol)
536+
callbacks = []
537+
if rtol > 0.0:
538+
callbacks.append(ResidualNormToInitialCallback(rtol))
539+
if rtol1 > 0.0:
540+
callbacks.append(ResidualNormToDataCallback(rtol1))
541+
517542
fistasolve = FISTA(
518543
Op,
519-
callbacks=[
520-
rcallback,
521-
],
544+
callbacks=callbacks if len(callbacks) > 0 else None,
522545
)
523546
if callback is not None:
524547
fistasolve.callback = callback
@@ -673,6 +696,7 @@ def splitbregman(
673696
epsRL2s: Optional[SamplingLike] = None,
674697
tol: float = 1e-10,
675698
rtol: float = 0.0,
699+
rtol1: float = 0.0,
676700
tau: float = 1.0,
677701
restart: bool = False,
678702
engine: str = "scipy",
@@ -726,9 +750,13 @@ def splitbregman(
726750
Tolerance. Stop the solver if difference between inverted model
727751
at subsequent iterations is smaller than ``tol``
728752
rtol : :obj:`float`, optional
729-
Relative tolerance on total cost function. Stops the solver when the
730-
ratio of the current cost function to the initial cost function
731-
is below this value.
753+
Relative tolerance on residual norm wrt initial residual norm. Stops
754+
the solver when the ratio of the current residual norm to the initial
755+
residual norm is below this value.
756+
rtol1 : :obj:`float`, optional
757+
Relative tolerance on residual norm wrt to data. Stops the solver
758+
when the ratio of the current residual norm to the data norm is
759+
below this value.
732760
tau : :obj:`float`, optional
733761
Scaling factor in the Bregman update (must be close to 1)
734762
restart : :obj:`bool`, optional
@@ -772,12 +800,14 @@ def splitbregman(
772800
See :class:`pylops.optimization.cls_sparsity.SplitBregman`
773801
774802
"""
775-
rcallback = ResidualNormCallback(rtol)
803+
callbacks = []
804+
if rtol > 0.0:
805+
callbacks.append(ResidualNormToInitialCallback(rtol))
806+
if rtol1 > 0.0:
807+
callbacks.append(ResidualNormToDataCallback(rtol1))
776808
sbsolve = SplitBregman(
777809
Op,
778-
callbacks=[
779-
rcallback,
780-
],
810+
callbacks=callbacks if len(callbacks) > 0 else None,
781811
)
782812
if callback is not None:
783813
sbsolve.callback = callback

pytests/test_sparsity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import pytest
1515

1616
from pylops.basicoperators import FirstDerivative, Identity, MatrixMult
17-
from pylops.optimization.callback import ResidualNormCallback
17+
from pylops.optimization.callback import ResidualNormToInitialCallback
1818
from pylops.optimization.cls_sparsity import IRLS
1919
from pylops.optimization.sparsity import fista, irls, ista, omp, spgl1, splitbregman
2020

@@ -226,7 +226,7 @@ def test_IRLS_model_stopping(par):
226226
rtol = 6e-1
227227
kwars_solver = dict(iter_lim=5) if backend == "numpy" else dict(niter=5)
228228

229-
rcallback = ResidualNormCallback(rtol)
229+
rcallback = ResidualNormToInitialCallback(rtol)
230230
irlssolve = IRLS(
231231
Aop,
232232
callbacks=[

0 commit comments

Comments
 (0)