|
9 | 9 |
|
10 | 10 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple |
11 | 11 |
|
12 | | -from pylops.optimization.callback import ResidualNormCallback |
| 12 | +from pylops.optimization.callback import ( |
| 13 | + ResidualNormToDataCallback, |
| 14 | + ResidualNormToInitialCallback, |
| 15 | +) |
13 | 16 | from pylops.optimization.cls_sparsity import FISTA, IRLS, ISTA, OMP, SPGL1, SplitBregman |
14 | 17 | from pylops.utils.decorators import add_ndarray_support_to_solver |
15 | 18 | from pylops.utils.typing import NDArray, SamplingLike |
@@ -142,6 +145,7 @@ def omp( |
142 | 145 | niter_inner: int = 40, |
143 | 146 | sigma: float = 1e-4, |
144 | 147 | rtol: float = 0.0, |
| 148 | + rtol1: float = 0.0, |
145 | 149 | normalizecols: bool = False, |
146 | 150 | Opbasis: Optional["LinearOperator"] = None, |
147 | 151 | optimal_coeff: bool = False, |
@@ -172,9 +176,13 @@ def omp( |
172 | 176 | sigma : :obj:`float`, optional |
173 | 177 | Maximum :math:`L_2` norm of residual. When smaller stop iterations. |
174 | 178 | rtol : :obj:`float`, optional |
175 | | - Relative tolerance on residual. Stops the solver when the |
176 | | - ratio of the current residual norm to the initial residual norm |
177 | | - is below this value. |
| 179 | + Relative tolerance on residual norm wrt initial residual norm. Stops |
| 180 | + the solver when the ratio of the current residual norm to the initial |
| 181 | + residual norm is below this value. |
| 182 | + rtol1 : :obj:`float`, optional |
| 183 | + Relative tolerance on residual norm wrt to data. Stops the solver |
| 184 | + when the ratio of the current residual norm to the data norm is |
| 185 | + below this value. |
178 | 186 | normalizecols : :obj:`list`, optional |
179 | 187 | Normalize columns (``True``) or not (``False``). Note that this can be |
180 | 188 | expensive as it requires applying the forward operator |
@@ -229,12 +237,15 @@ def omp( |
229 | 237 | See :class:`pylops.optimization.cls_sparsity.OMP` |
230 | 238 |
|
231 | 239 | """ |
232 | | - rcallback = ResidualNormCallback(rtol) |
| 240 | + callbacks = [] |
| 241 | + if rtol > 0.0: |
| 242 | + callbacks.append(ResidualNormToInitialCallback(rtol)) |
| 243 | + if rtol1 > 0.0: |
| 244 | + callbacks.append(ResidualNormToDataCallback(rtol1)) |
| 245 | + |
233 | 246 | ompsolve = OMP( |
234 | 247 | Op, |
235 | | - callbacks=[ |
236 | | - rcallback, |
237 | | - ], |
| 248 | + callbacks=callbacks if len(callbacks) > 0 else None, |
238 | 249 | ) |
239 | 250 | if callback is not None: |
240 | 251 | ompsolve.callback = callback |
@@ -264,7 +275,8 @@ def ista( |
264 | 275 | alpha: Optional[float] = None, |
265 | 276 | eigsdict: Optional[Dict[str, Any]] = None, |
266 | 277 | tol: float = 1e-10, |
267 | | - rtol: bool = 0.0, |
| 278 | + rtol: float = 0.0, |
| 279 | + rtol1: float = 0.0, |
268 | 280 | threshkind: str = "soft", |
269 | 281 | perc: Optional[float] = None, |
270 | 282 | decay: Optional[NDArray] = None, |
@@ -309,9 +321,13 @@ def ista( |
309 | 321 | Absolute tolerance on model update. Stop iterations if difference between inverted model |
310 | 322 | at subsequent iterations is smaller than ``tol`` |
311 | 323 | rtol : :obj:`float`, optional |
312 | | - Relative tolerance on total cost function. Stops the solver when the |
313 | | - ratio of the current cost function to the initial cost function |
314 | | - is below this value. |
| 324 | + Relative tolerance on residual norm wrt initial residual norm. Stops |
| 325 | + the solver when the ratio of the current residual norm to the initial |
| 326 | + residual norm is below this value. |
| 327 | + rtol1 : :obj:`float`, optional |
| 328 | + Relative tolerance on residual norm wrt to data. Stops the solver |
| 329 | + when the ratio of the current residual norm to the data norm is |
| 330 | + below this value. |
315 | 331 | threshkind : :obj:`str`, optional |
316 | 332 | Kind of thresholding ('hard', 'soft', 'half', 'hard-percentile', |
317 | 333 | 'soft-percentile', or 'half-percentile' - 'soft' used as default) |
@@ -370,12 +386,15 @@ def ista( |
370 | 386 | See :class:`pylops.optimization.cls_sparsity.ISTA` |
371 | 387 |
|
372 | 388 | """ |
373 | | - rcallback = ResidualNormCallback(rtol) |
| 389 | + callbacks = [] |
| 390 | + if rtol > 0.0: |
| 391 | + callbacks.append(ResidualNormToInitialCallback(rtol)) |
| 392 | + if rtol1 > 0.0: |
| 393 | + callbacks.append(ResidualNormToDataCallback(rtol1)) |
| 394 | + |
374 | 395 | istasolve = ISTA( |
375 | 396 | Op, |
376 | | - callbacks=[ |
377 | | - rcallback, |
378 | | - ], |
| 397 | + callbacks=callbacks if len(callbacks) > 0 else None, |
379 | 398 | ) |
380 | 399 | if callback is not None: |
381 | 400 | istasolve.callback = callback |
@@ -410,6 +429,7 @@ def fista( |
410 | 429 | eigsdict: Optional[Dict[str, Any]] = None, |
411 | 430 | tol: float = 1e-10, |
412 | 431 | rtol: float = 0.0, |
| 432 | + rtol1: float = 0.0, |
413 | 433 | threshkind: str = "soft", |
414 | 434 | perc: Optional[float] = None, |
415 | 435 | decay: Optional[NDArray] = None, |
@@ -513,12 +533,15 @@ def fista( |
513 | 533 | See :class:`pylops.optimization.cls_sparsity.FISTA` |
514 | 534 |
|
515 | 535 | """ |
516 | | - rcallback = ResidualNormCallback(rtol) |
| 536 | + callbacks = [] |
| 537 | + if rtol > 0.0: |
| 538 | + callbacks.append(ResidualNormToInitialCallback(rtol)) |
| 539 | + if rtol1 > 0.0: |
| 540 | + callbacks.append(ResidualNormToDataCallback(rtol1)) |
| 541 | + |
517 | 542 | fistasolve = FISTA( |
518 | 543 | Op, |
519 | | - callbacks=[ |
520 | | - rcallback, |
521 | | - ], |
| 544 | + callbacks=callbacks if len(callbacks) > 0 else None, |
522 | 545 | ) |
523 | 546 | if callback is not None: |
524 | 547 | fistasolve.callback = callback |
@@ -673,6 +696,7 @@ def splitbregman( |
673 | 696 | epsRL2s: Optional[SamplingLike] = None, |
674 | 697 | tol: float = 1e-10, |
675 | 698 | rtol: float = 0.0, |
| 699 | + rtol1: float = 0.0, |
676 | 700 | tau: float = 1.0, |
677 | 701 | restart: bool = False, |
678 | 702 | engine: str = "scipy", |
@@ -726,9 +750,13 @@ def splitbregman( |
726 | 750 | Tolerance. Stop the solver if difference between inverted model |
727 | 751 | at subsequent iterations is smaller than ``tol`` |
728 | 752 | rtol : :obj:`float`, optional |
729 | | - Relative tolerance on total cost function. Stops the solver when the |
730 | | - ratio of the current cost function to the initial cost function |
731 | | - is below this value. |
| 753 | + Relative tolerance on residual norm wrt initial residual norm. Stops |
| 754 | + the solver when the ratio of the current residual norm to the initial |
| 755 | + residual norm is below this value. |
| 756 | + rtol1 : :obj:`float`, optional |
| 757 | + Relative tolerance on residual norm wrt to data. Stops the solver |
| 758 | + when the ratio of the current residual norm to the data norm is |
| 759 | + below this value. |
732 | 760 | tau : :obj:`float`, optional |
733 | 761 | Scaling factor in the Bregman update (must be close to 1) |
734 | 762 | restart : :obj:`bool`, optional |
@@ -772,12 +800,14 @@ def splitbregman( |
772 | 800 | See :class:`pylops.optimization.cls_sparsity.SplitBregman` |
773 | 801 |
|
774 | 802 | """ |
775 | | - rcallback = ResidualNormCallback(rtol) |
| 803 | + callbacks = [] |
| 804 | + if rtol > 0.0: |
| 805 | + callbacks.append(ResidualNormToInitialCallback(rtol)) |
| 806 | + if rtol1 > 0.0: |
| 807 | + callbacks.append(ResidualNormToDataCallback(rtol1)) |
776 | 808 | sbsolve = SplitBregman( |
777 | 809 | Op, |
778 | | - callbacks=[ |
779 | | - rcallback, |
780 | | - ], |
| 810 | + callbacks=callbacks if len(callbacks) > 0 else None, |
781 | 811 | ) |
782 | 812 | if callback is not None: |
783 | 813 | sbsolve.callback = callback |
|
0 commit comments