Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/source/addingsolver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ input and returns some of the most valuable properties of the class-based solver

def cg(Op, y, x0, niter=10, tol=1e-4, rtol=0.0,
show=False, itershow=(10, 10, 10), callback=None):
rcallback = ResidualNormCallback(rtol)
cgsolve = CG(Op, callbacks=[rcallback, ])
cgsolve = CG(Op, callbacks=[CostToInitialCallback(rtol), ])
if callback is not None:
cgsolve.callback = callback
x, iiter, cost = cgsolve.solve(
Expand Down
3 changes: 3 additions & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ Callbacks
:toctree: generated/

Callbacks
CostNanInfCallback
CostToDataCallback
CostToInitialCallback
MetricsCallback


Expand Down
43 changes: 29 additions & 14 deletions pylops/optimization/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import TYPE_CHECKING, Callable, Optional, Tuple

from pylops.optimization.callback import ResidualNormCallback
from pylops.optimization.callback import CostToDataCallback, CostToInitialCallback
from pylops.optimization.cls_basic import CG, CGLS, LSQR
from pylops.utils.decorators import add_ndarray_support_to_solver
from pylops.utils.typing import NDArray
Expand All @@ -22,7 +22,8 @@ def cg(
x0: Optional[NDArray] = None,
niter: int = 10,
tol: float = 1e-4,
rtol: bool = 0.0,
rtol: float = 0.0,
rtol1: float = 0.0,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
Expand All @@ -47,8 +48,12 @@ def cg(
Absolute tolerance on residual norm. Stops the solver when the
residual norm is below this value.
rtol : :obj:`float`, optional
Relative tolerance on residual norm. Stops the solver when the
ratio of the current residual norm to the initial residual norm is
Relative tolerance on residual norm wrt initial residual norm. Stops
the solver when the ratio of the current residual norm to the initial
residual norm is below this value.
rtol1 : :obj:`float`, optional
Relative tolerance on residual norm wrt to data. Stops the solver
when the ratio of the current residual norm to the data norm is
below this value.
show : :obj:`bool`, optional
Display iterations log
Expand Down Expand Up @@ -78,12 +83,15 @@ def cg(
See :class:`pylops.optimization.cls_basic.CG`

"""
rcallback = ResidualNormCallback(rtol)
callbacks = []
if rtol > 0.0:
callbacks.append(CostToInitialCallback(rtol))
if rtol1 > 0.0:
callbacks.append(CostToDataCallback(rtol1))

cgsolve = CG(
Op,
callbacks=[
rcallback,
],
callbacks=callbacks if len(callbacks) > 0 else None,
)
if callback is not None:
cgsolve.callback = callback
Expand All @@ -108,6 +116,7 @@ def cgls(
damp: float = 0.0,
tol: float = 1e-4,
rtol: float = 0.0,
rtol1: float = 0.0,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
Expand All @@ -134,8 +143,12 @@ def cgls(
Absolute tolerance on residual norm. Stops the solver when the
residual norm is below this value.
rtol : :obj:`float`, optional
Relative tolerance on residual norm. Stops the solver when the
ratio of the current residual norm to the initial residual norm is
Relative tolerance on residual norm wrt initial residual norm. Stops
the solver when the ratio of the current residual norm to the initial
residual norm is below this value.
rtol1 : :obj:`float`, optional
Relative tolerance on residual norm wrt to data. Stops the solver
when the ratio of the current residual norm to the data norm is
below this value.
show : :obj:`bool`, optional
Display iterations log
Expand Down Expand Up @@ -180,12 +193,14 @@ def cgls(
See :class:`pylops.optimization.cls_basic.CGLS`

"""
rcallback = ResidualNormCallback(rtol)
callbacks = []
if rtol > 0.0:
callbacks.append(CostToInitialCallback(rtol))
if rtol1 > 0.0:
callbacks.append(CostToDataCallback(rtol1))
cgsolve = CGLS(
Op,
callbacks=[
rcallback,
],
callbacks=callbacks if len(callbacks) > 0 else None,
)
if callback is not None:
cgsolve.callback = callback
Expand Down
109 changes: 82 additions & 27 deletions pylops/optimization/callback.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
__all__ = [
"Callbacks",
"CostNanInfCallback",
"CostToDataCallback",
"CostToInitialCallback",
"MetricsCallback",
"ResidualNormCallback",
]

from typing import TYPE_CHECKING, Dict, List, Optional, Sequence

import numpy as np

from pylops.utils.metrics import mae, mse, psnr, snr
from pylops.utils.typing import NDArray

Expand Down Expand Up @@ -141,6 +145,83 @@ def on_run_end(self, solver: "Solver", x: NDArray) -> None:
pass


class CostToDataCallback(Callbacks):
"""Cost to data callback

This callback can be used to stop the solver when the ``cost`` parameter
of the solver is below a certain threshold defined as a percentage of the
Euclidean norm of the data.

Note that the meaning of ``cost`` can change from solver to solver - e.g.,
it can represent the misfit of the data term or the total cost function.

Parameters
----------
rtol : :obj:`float`
Percentage of the initial cost below which the solver
will stop iterating. For example, if ``rtol`` is 0.1, the solver
will stop when the cost is below 10% of the Euclidean norm of
the data.

"""

def __init__(self, rtol: float) -> None:
self.rtol = rtol
self.stop = False

def on_setup_end(self, solver: "Solver", x: NDArray) -> None:
self.ynorm = solver.ncp.linalg.norm(solver.y)

def on_step_end(self, solver: "Solver", x: NDArray) -> None:
if solver.cost[-1] < self.rtol * self.ynorm:
self.stop = True


class CostToInitialCallback(Callbacks):
"""Cost to initial callback

This callback can be used to stop the solver when the ``cost``
parameter of the solver is below a certain threshold defined as a
percentage of the initial residual norm.

Note that the meaning of ``cost`` can change from solver to solver - e.g.,
it can represent the misfit of the data term or the total cost function.

Parameters
----------
rtol : :obj:`float`
Percentage of the initial cost below which the solver
will stop iterating. For example, if ``rtol`` is 0.1, the solver
will stop when the cost is below 10% of the initial
cost.

"""

def __init__(self, rtol: float) -> None:
self.rtol = rtol
self.stop = False

def on_step_end(self, solver: "Solver", x: NDArray) -> None:
if solver.cost[-1] < self.rtol * solver.cost[0]:
self.stop = True


class CostNanInfCallback(Callbacks):
"""Cost Nan/Inf callback

This callback can be used to stop the solver when the ``cost``
becomes either ``np.nan`` or ``np.inf``

"""

def __init__(self) -> None:
self.stop = False

def on_step_end(self, solver: "Solver", x: NDArray) -> None:
if np.isnan(solver.cost[-1]) or np.isinf(solver.cost[-1]):
self.stop = True


class MetricsCallback(Callbacks):
r"""Metrics callback

Expand Down Expand Up @@ -191,32 +272,6 @@ def on_step_end(self, solver: "Solver", x: NDArray) -> None:
self.metrics["psnr"].append(psnr(self.xtrue, x))


class ResidualNormCallback(Callbacks):
"""Residual norm callback

This callback can be used to stop the solver when the residual norm
is below a certain threshold defined as a percentage of the
initial residual norm.

Parameters
----------
rtol : :obj:`float`
Percentage of the initial residual norm below which the solver
will stop iterating. For example, if `rtol` is 0.1, the solver
will stop when the residual norm is below 10% of the initial
residual norm.

"""

def __init__(self, rtol: float) -> None:
self.rtol = rtol
self.stop = False

def on_step_end(self, solver: "Solver", x: NDArray) -> None:
if solver.cost[-1] < self.rtol * solver.cost[0]:
self.stop = True


def _callback_stop(callbacks: Sequence[Callbacks]) -> bool:
"""Check if any callback has raised a stop flag

Expand Down
2 changes: 1 addition & 1 deletion pylops/optimization/cls_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ def setup(
# create variables to track the residual norm and iterations
self.res = self.y.copy()
self.cost = [
float(np.linalg.norm(self.y)),
float(np.linalg.norm(self.res)),
]
self.iiter = 0

Expand Down
Loading