Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions docs/source/addingsolver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ Implementing new solvers
========================
Users are welcome to create new solvers and add them to the PyLops library.

In this tutorial, we will go through the key steps in the definition of a solver, using the
:py:class:`pylops.optimization.basic.CG` as an example.
In this tutorial, we will go through the key steps in the definition of a solver, using a
sligthly simplified version of :py:class:`pylops.optimization.basic.CG` as an example.

.. note::
In case the solver that you are planning to create falls within the category of proximal solvers,
Expand Down Expand Up @@ -83,7 +83,6 @@ note that the ``setup`` method returns the created starting guess ``x`` (does no
.. code-block:: python

def setup(self, y, x0=None, niter=None, tol=1e-4, show=False):

self.y = y
self.tol = tol
self.niter = niter
Expand Down Expand Up @@ -134,7 +133,26 @@ can add additional input parameters. For CG, the step is:


Similarly, we also implement a ``run`` method that is in charge of running a number of iterations by repeatedly
calling the ``step`` method. It is also usually convenient to implement a finalize method; this method can do any required post-processing that should
calling the ``step`` method.

.. code-block:: python

def run(self, x, niter, show, itershow):
while self.iiter < niter and self.kold > self.tol:
x = self.step(x, showstep)
self.callback(x)
# check if any callback has raised a stop flag
stop = _callback_stop(self.callbacks)
if stop:
break
return x

It is worth noting that any number of callbacks can be attached to the solver; some of these
callbacks can implement a stopping criterion and set the ``stop`` member to True when a given
condition is met. The ``_callback_stop`` method is in change of checking if any of the callbacks
has set ``stop`` to True and in the case break the iterations.

Finally, it is also usually convenient to implement a ``finalize`` method; this method can do any required post-processing that should
not be applied at the end of each step, rather at the end of the entire optimization process. For CG, this is as simple
as converting the ``cost`` variable from a list to a ``numpy`` array. For more details, see our implementations for CG.

Expand Down Expand Up @@ -169,8 +187,10 @@ input and returns some of the most valuable properties of the class-based solver

.. code-block:: python

def cg(Op, y, x0, niter=10, tol=1e-4, show=False, itershow=(10, 10, 10), callback=None):
cgsolve = CG(Op)
def cg(Op, y, x0, niter=10, tol=1e-4, rtol=0.0,
show=False, itershow=(10, 10, 10), callback=None):
rcallback = ResidualNormCallback(rtol)
cgsolve = CG(Op, callbacks=[rcallback, ])
if callback is not None:
cgsolve.callback = callback
x, iiter, cost = cgsolve.solve(
Expand Down
33 changes: 29 additions & 4 deletions pylops/optimization/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing import TYPE_CHECKING, Callable, Optional, Tuple

from pylops.optimization.callback import ResidualNormCallback
from pylops.optimization.cls_basic import CG, CGLS, LSQR
from pylops.utils.decorators import add_ndarray_support_to_solver
from pylops.utils.typing import NDArray
Expand All @@ -21,6 +22,7 @@ def cg(
x0: Optional[NDArray] = None,
niter: int = 10,
tol: float = 1e-4,
rtol: bool = 0.0,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
Expand All @@ -42,7 +44,12 @@ def cg(
niter : :obj:`int`, optional
Number of iterations
tol : :obj:`float`, optional
Tolerance on residual norm
Absolute tolerance on residual norm. Stops the solver when the
residual norm is below this value.
rtol : :obj:`float`, optional
Relative tolerance on residual norm. Stops the solver when the
ratio of the current residual norm to the initial residual norm is
below this value.
show : :obj:`bool`, optional
Display iterations log
itershow : :obj:`tuple`, optional
Expand Down Expand Up @@ -71,7 +78,13 @@ def cg(
See :class:`pylops.optimization.cls_basic.CG`

"""
cgsolve = CG(Op)
rcallback = ResidualNormCallback(rtol)
cgsolve = CG(
Op,
callbacks=[
rcallback,
],
)
if callback is not None:
cgsolve.callback = callback
x, iiter, cost = cgsolve.solve(
Expand All @@ -94,6 +107,7 @@ def cgls(
niter: int = 10,
damp: float = 0.0,
tol: float = 1e-4,
rtol: float = 0.0,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
Expand All @@ -117,7 +131,12 @@ def cgls(
damp : :obj:`float`, optional
Damping coefficient
tol : :obj:`float`, optional
Tolerance on residual norm
Absolute tolerance on residual norm. Stops the solver when the
residual norm is below this value.
rtol : :obj:`float`, optional
Relative tolerance on residual norm. Stops the solver when the
ratio of the current residual norm to the initial residual norm is
below this value.
show : :obj:`bool`, optional
Display iterations log
itershow : :obj:`tuple`, optional
Expand Down Expand Up @@ -161,7 +180,13 @@ def cgls(
See :class:`pylops.optimization.cls_basic.CGLS`

"""
cgsolve = CGLS(Op)
rcallback = ResidualNormCallback(rtol)
cgsolve = CGLS(
Op,
callbacks=[
rcallback,
],
)
if callback is not None:
cgsolve.callback = callback
x, istop, iiter, r1norm, r2norm, cost = cgsolve.solve(
Expand Down
66 changes: 62 additions & 4 deletions pylops/optimization/callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = [
"Callbacks",
"MetricsCallback",
"ResidualNormCallback",
]

from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
Expand Down Expand Up @@ -28,25 +29,32 @@ class Callbacks:

All methods take two input parameters: the solver itself, and the vector ``x``.

Moreover, some callback may be used to implement custom stopping criteria for the solver.
This can be done by adding a boolean attribute ``stop`` to the callback object, which will
be initially set to ``False``. As soon as the callback sets this attribute to ``True``, the
``run`` method of the solver will stop iterating and return the current model vector.

Examples
--------
>>> import numpy as np
>>> from pylops.basicoperators import MatrixMult
>>> from pylops.optimization.solver import CG
>>> from pylops.optimization.basic import CG
>>> from pylops.optimization.callback import Callbacks
>>>
>>> class StoreIterCallback(Callbacks):
... def __init__(self):
... self.stored = []
... def on_step_end(self, solver, x):
... self.stored.append(solver.iiter)
>>> cb_sto = StoreIterCallback()
>>>
>>> Aop = MatrixMult(np.random.normal(0., 1., 36).reshape(6, 6))
>>> Aop = Aop.H @ Aop
>>> y = Aop @ np.ones(6)
>>> cb_sto = StoreIterCallback()
>>> cgsolve = CG(Aop, callbacks=[cb_sto, ])
>>> xest = cgsolve.solve(y=y, x0=np.zeros(6), tol=0, niter=6, show=False)[0]
>>> xest
array([1., 1., 1., 1., 1., 1.])
>>> xest, cb_sto.stored
(array([1., 1., 1., 1., 1., 1.]), [1, 2, 3, 4, 5, 6])

"""

Expand Down Expand Up @@ -181,3 +189,53 @@ def on_step_end(self, solver: "Solver", x: NDArray) -> None:
self.metrics["snr"].append(snr(self.xtrue, x))
if "psnr" in self.which:
self.metrics["psnr"].append(psnr(self.xtrue, x))


class ResidualNormCallback(Callbacks):
"""Residual norm callback

This callback can be used to stop the solver when the residual norm
is below a certain threshold defined as a percentage of the
initial residual norm.

Parameters
----------
rtol : :obj:`float`
Percentage of the initial residual norm below which the solver
will stop iterating. For example, if `rtol` is 0.1, the solver
will stop when the residual norm is below 10% of the initial
residual norm.

"""

def __init__(self, rtol: float) -> None:
self.rtol = rtol
self.stop = False

def on_step_end(self, solver: "Solver", x: NDArray) -> None:
if solver.cost[-1] < self.rtol * solver.cost[0]:
self.stop = True


def _callback_stop(callbacks: Sequence[Callbacks]) -> bool:
"""Check if any callback has raised a stop flag

Parameters
----------
callbacks : :obj:`pylops.optimization.callback.Callbacks`
List of callbacks to evaluate

Returns
-------
stop : :obj:`bool`
Whether to stop the solver or not

"""
if callbacks is not None:
stop = [
False if not hasattr(callback, "stop") else callback.stop
for callback in callbacks
]
if any(stop):
return True
return False
44 changes: 33 additions & 11 deletions pylops/optimization/cls_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np

from pylops.optimization.basesolver import Solver, _units
from pylops.optimization.callback import _callback_stop
from pylops.utils.backend import (
get_array_module,
get_module_name,
Expand Down Expand Up @@ -121,7 +122,8 @@ def setup(
Number of iterations (default to ``None`` in case a user wants to
manually step over the solver)
tol : :obj:`float`, optional
Tolerance on residual norm
Absolute tolerance on residual norm. Stops the solver when the
residual norm is below this value.
preallocate : :obj:`bool`, optional
.. versionadded:: 2.6.0

Expand Down Expand Up @@ -260,6 +262,10 @@ def run(
)
x = self.step(x, showstep)
self.callback(x)
# check if any callback has raised a stop flag
stop = _callback_stop(self.callbacks)
if stop:
break
return x

def finalize(self, show: bool = False) -> None:
Expand Down Expand Up @@ -299,7 +305,8 @@ def solve(
niter : :obj:`int`, optional
Number of iterations
tol : :obj:`float`, optional
Tolerance on residual norm
Absolute tolerance on residual norm. Stops the solver when the
residual norm is below this value.
preallocate : :obj:`bool`, optional
.. versionadded:: 2.6.0

Expand Down Expand Up @@ -440,7 +447,8 @@ def setup(
damp : :obj:`float`, optional
Damping coefficient
tol : :obj:`float`, optional
Tolerance on residual norm
Absolute tolerance on residual norm. Stops the solver when the
residual norm is below this value.
preallocate : :obj:`bool`, optional
.. versionadded:: 2.6.0

Expand Down Expand Up @@ -592,22 +600,26 @@ def run(
Estimated model of size :math:`[M \times 1]`

"""
niter = self.niter if niter is None else niter
if niter is None:
self.niter = self.niter if niter is None else niter
if self.niter is None:
raise ValueError("niter must not be None")
while self.iiter < niter and self.kold > self.tol:
while self.iiter < self.niter and self.kold > self.tol:
showstep = (
True
if show
and (
self.iiter < itershow[0]
or niter - self.iiter < itershow[1]
or self.niter - self.iiter < itershow[1]
or self.iiter % itershow[2] == 0
)
else False
)
x = self.step(x, showstep)
self.callback(x)
# check if any callback has raised a stop flag
stop = _callback_stop(self.callbacks)
if stop:
break
return x

def finalize(self, show: bool = False) -> None:
Expand All @@ -622,7 +634,12 @@ def finalize(self, show: bool = False) -> None:
self.tend = time.time()
self.telapsed = self.tend - self.tstart
# reason for termination
self.istop = 1 if self.kold < self.tol else 2
if self.kold < self.tol:
self.istop = 1
elif self.iiter >= self.niter:
self.istop = 2
else:
self.istop = 3
self.r1norm = self.kold
self.r2norm = self.cost1[self.iiter]
if show:
Expand Down Expand Up @@ -655,7 +672,8 @@ def solve(
damp : :obj:`float`, optional
Damping coefficient
tol : :obj:`float`, optional
Tolerance on residual norm
Absolute tolerance on residual norm. Stops the solver when the
residual norm is below this value.
preallocate : :obj:`bool`, optional
.. versionadded:: 2.6.0

Expand All @@ -677,10 +695,14 @@ def solve(
Gives the reason for termination

``1`` means :math:`\mathbf{x}` is an approximate solution to
:math:`\mathbf{y} = \mathbf{Op}\,\mathbf{x}`
:math:`\mathbf{y} = \mathbf{Op}\,\mathbf{x}` with the provided
tolerance ``tol``

``2`` means :math:`\mathbf{x}` approximately solves the least-squares
problem
problem (reached the maximum number of iterations ``niter``)

``3`` means another stopping criterion implemented via a callback
was reached
iit : :obj:`int`
Iteration number upon termination
r1norm : :obj:`float`
Expand Down
6 changes: 3 additions & 3 deletions pylops/optimization/cls_leastsquares.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np
from scipy.sparse.linalg import cg as sp_cg
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import lsqr as sp_lsqr

from pylops.basicoperators import Diagonal, VStack
from pylops.optimization.basesolver import Solver, _units
Expand Down Expand Up @@ -676,7 +676,7 @@ def run(
if engine == "scipy" and self.ncp == np:
if show:
kwargs_solver["show"] = 1
xinv, istop, itn, r1norm, r2norm = lsqr(
xinv, istop, itn, r1norm, r2norm = sp_lsqr(
self.RegOp, self.datatot, x0=x, **kwargs_solver
)[0:5]
elif engine == "pylops" or self.ncp != np:
Expand Down Expand Up @@ -938,7 +938,7 @@ def run(
if engine == "scipy" and self.ncp == np:
if show:
kwargs_solver["show"] = 1
pinv, istop, itn, r1norm, r2norm = lsqr(
pinv, istop, itn, r1norm, r2norm = sp_lsqr(
self.POp,
self.y,
x0=x,
Expand Down
Loading
Loading