diff --git a/docs/source/addingsolver.rst b/docs/source/addingsolver.rst index 2002c6876..fbd321225 100755 --- a/docs/source/addingsolver.rst +++ b/docs/source/addingsolver.rst @@ -189,8 +189,7 @@ input and returns some of the most valuable properties of the class-based solver def cg(Op, y, x0, niter=10, tol=1e-4, rtol=0.0, show=False, itershow=(10, 10, 10), callback=None): - rcallback = ResidualNormCallback(rtol) - cgsolve = CG(Op, callbacks=[rcallback, ]) + cgsolve = CG(Op, callbacks=[CostToInitialCallback(rtol), ]) if callback is not None: cgsolve.callback = callback x, iiter, cost = cgsolve.solve( diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 0c23cf9aa..5d5c10a32 100755 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -249,6 +249,9 @@ Callbacks :toctree: generated/ Callbacks + CostNanInfCallback + CostToDataCallback + CostToInitialCallback MetricsCallback diff --git a/pylops/optimization/basic.py b/pylops/optimization/basic.py index 32f3658c0..07493172c 100644 --- a/pylops/optimization/basic.py +++ b/pylops/optimization/basic.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple -from pylops.optimization.callback import ResidualNormCallback +from pylops.optimization.callback import CostToDataCallback, CostToInitialCallback from pylops.optimization.cls_basic import CG, CGLS, LSQR from pylops.utils.decorators import add_ndarray_support_to_solver from pylops.utils.typing import NDArray @@ -22,7 +22,8 @@ def cg( x0: Optional[NDArray] = None, niter: int = 10, tol: float = 1e-4, - rtol: bool = 0.0, + rtol: float = 0.0, + rtol1: float = 0.0, show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), callback: Optional[Callable] = None, @@ -47,8 +48,12 @@ def cg( Absolute tolerance on residual norm. Stops the solver when the residual norm is below this value. rtol : :obj:`float`, optional - Relative tolerance on residual norm. Stops the solver when the - ratio of the current residual norm to the initial residual norm is + Relative tolerance on residual norm wrt initial residual norm. Stops + the solver when the ratio of the current residual norm to the initial + residual norm is below this value. + rtol1 : :obj:`float`, optional + Relative tolerance on residual norm wrt to data. Stops the solver + when the ratio of the current residual norm to the data norm is below this value. show : :obj:`bool`, optional Display iterations log @@ -78,12 +83,15 @@ def cg( See :class:`pylops.optimization.cls_basic.CG` """ - rcallback = ResidualNormCallback(rtol) + callbacks = [] + if rtol > 0.0: + callbacks.append(CostToInitialCallback(rtol)) + if rtol1 > 0.0: + callbacks.append(CostToDataCallback(rtol1)) + cgsolve = CG( Op, - callbacks=[ - rcallback, - ], + callbacks=callbacks if len(callbacks) > 0 else None, ) if callback is not None: cgsolve.callback = callback @@ -108,6 +116,7 @@ def cgls( damp: float = 0.0, tol: float = 1e-4, rtol: float = 0.0, + rtol1: float = 0.0, show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), callback: Optional[Callable] = None, @@ -134,8 +143,12 @@ def cgls( Absolute tolerance on residual norm. Stops the solver when the residual norm is below this value. rtol : :obj:`float`, optional - Relative tolerance on residual norm. Stops the solver when the - ratio of the current residual norm to the initial residual norm is + Relative tolerance on residual norm wrt initial residual norm. Stops + the solver when the ratio of the current residual norm to the initial + residual norm is below this value. + rtol1 : :obj:`float`, optional + Relative tolerance on residual norm wrt to data. Stops the solver + when the ratio of the current residual norm to the data norm is below this value. show : :obj:`bool`, optional Display iterations log @@ -180,12 +193,14 @@ def cgls( See :class:`pylops.optimization.cls_basic.CGLS` """ - rcallback = ResidualNormCallback(rtol) + callbacks = [] + if rtol > 0.0: + callbacks.append(CostToInitialCallback(rtol)) + if rtol1 > 0.0: + callbacks.append(CostToDataCallback(rtol1)) cgsolve = CGLS( Op, - callbacks=[ - rcallback, - ], + callbacks=callbacks if len(callbacks) > 0 else None, ) if callback is not None: cgsolve.callback = callback diff --git a/pylops/optimization/callback.py b/pylops/optimization/callback.py index 095dd2e1e..422712ae9 100644 --- a/pylops/optimization/callback.py +++ b/pylops/optimization/callback.py @@ -1,11 +1,15 @@ __all__ = [ "Callbacks", + "CostNanInfCallback", + "CostToDataCallback", + "CostToInitialCallback", "MetricsCallback", - "ResidualNormCallback", ] from typing import TYPE_CHECKING, Dict, List, Optional, Sequence +import numpy as np + from pylops.utils.metrics import mae, mse, psnr, snr from pylops.utils.typing import NDArray @@ -141,6 +145,83 @@ def on_run_end(self, solver: "Solver", x: NDArray) -> None: pass +class CostToDataCallback(Callbacks): + """Cost to data callback + + This callback can be used to stop the solver when the ``cost`` parameter + of the solver is below a certain threshold defined as a percentage of the + Euclidean norm of the data. + + Note that the meaning of ``cost`` can change from solver to solver - e.g., + it can represent the misfit of the data term or the total cost function. + + Parameters + ---------- + rtol : :obj:`float` + Percentage of the initial cost below which the solver + will stop iterating. For example, if ``rtol`` is 0.1, the solver + will stop when the cost is below 10% of the Euclidean norm of + the data. + + """ + + def __init__(self, rtol: float) -> None: + self.rtol = rtol + self.stop = False + + def on_setup_end(self, solver: "Solver", x: NDArray) -> None: + self.ynorm = solver.ncp.linalg.norm(solver.y) + + def on_step_end(self, solver: "Solver", x: NDArray) -> None: + if solver.cost[-1] < self.rtol * self.ynorm: + self.stop = True + + +class CostToInitialCallback(Callbacks): + """Cost to initial callback + + This callback can be used to stop the solver when the ``cost`` + parameter of the solver is below a certain threshold defined as a + percentage of the initial residual norm. + + Note that the meaning of ``cost`` can change from solver to solver - e.g., + it can represent the misfit of the data term or the total cost function. + + Parameters + ---------- + rtol : :obj:`float` + Percentage of the initial cost below which the solver + will stop iterating. For example, if ``rtol`` is 0.1, the solver + will stop when the cost is below 10% of the initial + cost. + + """ + + def __init__(self, rtol: float) -> None: + self.rtol = rtol + self.stop = False + + def on_step_end(self, solver: "Solver", x: NDArray) -> None: + if solver.cost[-1] < self.rtol * solver.cost[0]: + self.stop = True + + +class CostNanInfCallback(Callbacks): + """Cost Nan/Inf callback + + This callback can be used to stop the solver when the ``cost`` + becomes either ``np.nan`` or ``np.inf`` + + """ + + def __init__(self) -> None: + self.stop = False + + def on_step_end(self, solver: "Solver", x: NDArray) -> None: + if np.isnan(solver.cost[-1]) or np.isinf(solver.cost[-1]): + self.stop = True + + class MetricsCallback(Callbacks): r"""Metrics callback @@ -191,32 +272,6 @@ def on_step_end(self, solver: "Solver", x: NDArray) -> None: self.metrics["psnr"].append(psnr(self.xtrue, x)) -class ResidualNormCallback(Callbacks): - """Residual norm callback - - This callback can be used to stop the solver when the residual norm - is below a certain threshold defined as a percentage of the - initial residual norm. - - Parameters - ---------- - rtol : :obj:`float` - Percentage of the initial residual norm below which the solver - will stop iterating. For example, if `rtol` is 0.1, the solver - will stop when the residual norm is below 10% of the initial - residual norm. - - """ - - def __init__(self, rtol: float) -> None: - self.rtol = rtol - self.stop = False - - def on_step_end(self, solver: "Solver", x: NDArray) -> None: - if solver.cost[-1] < self.rtol * solver.cost[0]: - self.stop = True - - def _callback_stop(callbacks: Sequence[Callbacks]) -> bool: """Check if any callback has raised a stop flag diff --git a/pylops/optimization/cls_sparsity.py b/pylops/optimization/cls_sparsity.py index f800a6138..8f615935e 100644 --- a/pylops/optimization/cls_sparsity.py +++ b/pylops/optimization/cls_sparsity.py @@ -1028,7 +1028,7 @@ def setup( # create variables to track the residual norm and iterations self.res = self.y.copy() self.cost = [ - float(np.linalg.norm(self.y)), + float(np.linalg.norm(self.res)), ] self.iiter = 0 diff --git a/pylops/optimization/sparsity.py b/pylops/optimization/sparsity.py index 54e155d72..70fb408e4 100644 --- a/pylops/optimization/sparsity.py +++ b/pylops/optimization/sparsity.py @@ -9,7 +9,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple -from pylops.optimization.callback import ResidualNormCallback +from pylops.optimization.callback import ( + CostNanInfCallback, + CostToDataCallback, + CostToInitialCallback, +) from pylops.optimization.cls_sparsity import FISTA, IRLS, ISTA, OMP, SPGL1, SplitBregman from pylops.utils.decorators import add_ndarray_support_to_solver from pylops.utils.typing import NDArray, SamplingLike @@ -142,6 +146,7 @@ def omp( niter_inner: int = 40, sigma: float = 1e-4, rtol: float = 0.0, + rtol1: float = 0.0, normalizecols: bool = False, Opbasis: Optional["LinearOperator"] = None, optimal_coeff: bool = False, @@ -172,9 +177,13 @@ def omp( sigma : :obj:`float`, optional Maximum :math:`L_2` norm of residual. When smaller stop iterations. rtol : :obj:`float`, optional - Relative tolerance on residual. Stops the solver when the - ratio of the current residual norm to the initial residual norm - is below this value. + Relative tolerance on residual norm wrt initial residual norm. Stops + the solver when the ratio of the current residual norm to the initial + residual norm is below this value. + rtol1 : :obj:`float`, optional + Relative tolerance on residual norm wrt to data. Stops the solver + when the ratio of the current residual norm to the data norm is + below this value. normalizecols : :obj:`list`, optional Normalize columns (``True``) or not (``False``). Note that this can be expensive as it requires applying the forward operator @@ -229,12 +238,17 @@ def omp( See :class:`pylops.optimization.cls_sparsity.OMP` """ - rcallback = ResidualNormCallback(rtol) + callbacks = [ + CostNanInfCallback(), + ] + if rtol > 0.0: + callbacks.append(CostToInitialCallback(rtol)) + if rtol1 > 0.0: + callbacks.append(CostToDataCallback(rtol1)) + ompsolve = OMP( Op, - callbacks=[ - rcallback, - ], + callbacks=callbacks, ) if callback is not None: ompsolve.callback = callback @@ -264,7 +278,8 @@ def ista( alpha: Optional[float] = None, eigsdict: Optional[Dict[str, Any]] = None, tol: float = 1e-10, - rtol: bool = 0.0, + rtol: float = 0.0, + rtol1: float = 0.0, threshkind: str = "soft", perc: Optional[float] = None, decay: Optional[NDArray] = None, @@ -309,9 +324,12 @@ def ista( Absolute tolerance on model update. Stop iterations if difference between inverted model at subsequent iterations is smaller than ``tol`` rtol : :obj:`float`, optional - Relative tolerance on total cost function. Stops the solver when the - ratio of the current cost function to the initial cost function - is below this value. + Relative tolerance on total cost function wrt initial total cost + function. Stops the solver when the ratio of the current total cost function + to the initial total cost function is below this value. + rtol1 : :obj:`float`, optional + Relative tolerance on total cost function wrt to data. Stops the solver when + the ratio of the current total cost function to the data norm is below this value. threshkind : :obj:`str`, optional Kind of thresholding ('hard', 'soft', 'half', 'hard-percentile', 'soft-percentile', or 'half-percentile' - 'soft' used as default) @@ -370,12 +388,17 @@ def ista( See :class:`pylops.optimization.cls_sparsity.ISTA` """ - rcallback = ResidualNormCallback(rtol) + callbacks = [ + CostNanInfCallback(), + ] + if rtol > 0.0: + callbacks.append(CostToInitialCallback(rtol)) + if rtol1 > 0.0: + callbacks.append(CostToDataCallback(rtol1)) + istasolve = ISTA( Op, - callbacks=[ - rcallback, - ], + callbacks=callbacks, ) if callback is not None: istasolve.callback = callback @@ -410,6 +433,7 @@ def fista( eigsdict: Optional[Dict[str, Any]] = None, tol: float = 1e-10, rtol: float = 0.0, + rtol1: float = 0.0, threshkind: str = "soft", perc: Optional[float] = None, decay: Optional[NDArray] = None, @@ -454,9 +478,12 @@ def fista( Absolute tolerance on model update. Stop iterations if difference between inverted model at subsequent iterations is smaller than ``tol`` rtol : :obj:`float`, optional - Relative tolerance on total cost function. Stops the solver when the - ratio of the current cost function to the initial cost function - is below this value. + Relative tolerance on total cost function wrt initial total cost + function. Stops the solver when the ratio of the current total cost function + to the initial total cost function is below this value. + rtol1 : :obj:`float`, optional + Relative tolerance on total cost function wrt to data. Stops the solver when + the ratio of the current total cost function to the data norm is below this value. threshkind : :obj:`str`, optional Kind of thresholding ('hard', 'soft', 'half', 'soft-percentile', or 'half-percentile' - 'soft' used as default) @@ -513,12 +540,17 @@ def fista( See :class:`pylops.optimization.cls_sparsity.FISTA` """ - rcallback = ResidualNormCallback(rtol) + callbacks = [ + CostNanInfCallback(), + ] + if rtol > 0.0: + callbacks.append(CostToInitialCallback(rtol)) + if rtol1 > 0.0: + callbacks.append(CostToDataCallback(rtol1)) + fistasolve = FISTA( Op, - callbacks=[ - rcallback, - ], + callbacks=callbacks, ) if callback is not None: fistasolve.callback = callback @@ -673,6 +705,7 @@ def splitbregman( epsRL2s: Optional[SamplingLike] = None, tol: float = 1e-10, rtol: float = 0.0, + rtol1: float = 0.0, tau: float = 1.0, restart: bool = False, engine: str = "scipy", @@ -726,9 +759,12 @@ def splitbregman( Tolerance. Stop the solver if difference between inverted model at subsequent iterations is smaller than ``tol`` rtol : :obj:`float`, optional - Relative tolerance on total cost function. Stops the solver when the - ratio of the current cost function to the initial cost function - is below this value. + Relative tolerance on total cost function wrt initial total cost + function. Stops the solver when the ratio of the current total cost function + to the initial total cost function is below this value. + rtol1 : :obj:`float`, optional + Relative tolerance on total cost function wrt to data. Stops the solver when + the ratio of the current total cost function to the data norm is below this value. tau : :obj:`float`, optional Scaling factor in the Bregman update (must be close to 1) restart : :obj:`bool`, optional @@ -765,19 +801,23 @@ def splitbregman( itn_out : :obj:`int` Iteration number of outer loop upon termination cost : :obj:`numpy.ndarray`, optional - History of cost function through iterations + History of the total cost function through iterations Notes ----- See :class:`pylops.optimization.cls_sparsity.SplitBregman` """ - rcallback = ResidualNormCallback(rtol) + callbacks = [ + CostNanInfCallback(), + ] + if rtol > 0.0: + callbacks.append(CostToInitialCallback(rtol)) + if rtol1 > 0.0: + callbacks.append(CostToDataCallback(rtol1)) sbsolve = SplitBregman( Op, - callbacks=[ - rcallback, - ], + callbacks=callbacks, ) if callback is not None: sbsolve.callback = callback diff --git a/pytests/test_solver.py b/pytests/test_solver.py index a59c08df1..13a2ca0cd 100644 --- a/pytests/test_solver.py +++ b/pytests/test_solver.py @@ -187,6 +187,7 @@ def test_cg_stopping(par): y = Aop * x + # test CostToInitialCallback callback for preallocate in [False, True]: rtol = 1e-2 _, _, cost = cg( @@ -195,6 +196,16 @@ def test_cg_stopping(par): assert cost[-2] / cost[0] >= rtol assert cost[-1] / cost[0] < rtol + # test CostToDataCallback callback + for preallocate in [False, True]: + ynorm = np.linalg.norm(y) + rtol = 1e-2 + _, _, cost = cg( + Aop, y, x0=x0, niter=par["nx"], tol=0, rtol1=rtol, preallocate=preallocate + ) + assert cost[-2] / ynorm >= rtol + assert cost[-1] / ynorm < rtol + @pytest.mark.parametrize( "par", [(par1), (par2), (par3), (par4), (par1j), (par2j), (par3j), (par3j)] @@ -307,6 +318,7 @@ def test_cgls_stopping(par): y = Aop * x + # test CostToInitialCallback callback for preallocate in [False, True]: rtol = 1e-2 cost = cgls( @@ -315,6 +327,16 @@ def test_cgls_stopping(par): assert cost[-2] / cost[0] >= rtol assert cost[-1] / cost[0] < rtol + # test CostToDataCallback callback + for preallocate in [False, True]: + ynorm = np.linalg.norm(y) + rtol = 1e-2 + cost = cgls( + Aop, y, x0=x0, niter=par["nx"], tol=0, rtol1=rtol, preallocate=preallocate + )[-1] + assert cost[-2] / ynorm >= rtol + assert cost[-1] / ynorm < rtol + @pytest.mark.skipif( int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled" diff --git a/pytests/test_sparsity.py b/pytests/test_sparsity.py index a996289ee..90eb59629 100644 --- a/pytests/test_sparsity.py +++ b/pytests/test_sparsity.py @@ -14,7 +14,7 @@ import pytest from pylops.basicoperators import FirstDerivative, Identity, MatrixMult -from pylops.optimization.callback import ResidualNormCallback +from pylops.optimization.callback import CostToInitialCallback from pylops.optimization.cls_sparsity import IRLS from pylops.optimization.sparsity import fista, irls, ista, omp, spgl1, splitbregman @@ -226,7 +226,7 @@ def test_IRLS_model_stopping(par): rtol = 6e-1 kwars_solver = dict(iter_lim=5) if backend == "numpy" else dict(niter=5) - rcallback = ResidualNormCallback(rtol) + rcallback = CostToInitialCallback(rtol) irlssolve = IRLS( Aop, callbacks=[ @@ -312,12 +312,22 @@ def test_OMP_stopping(par): y = Aop * x maxit = 100 + + # test CostToInitialCallback callback for preallocate in [False, True]: rtol = 1e-2 _, _, cost = omp(Aop, y, maxit, sigma=0.0, rtol=rtol, preallocate=preallocate) assert cost[-2] / cost[0] >= rtol assert cost[-1] / cost[0] < rtol + # test CostToDataCallback callback + for preallocate in [False, True]: + ynorm = np.linalg.norm(y) + rtol = 1e-2 + _, _, cost = omp(Aop, y, maxit, sigma=0.0, rtol1=rtol, preallocate=preallocate) + assert cost[-2] / ynorm >= rtol + assert cost[-1] / ynorm < rtol + def test_ISTA_FISTA_unknown_threshkind(): """Check error is raised if unknown threshkind is passed""" @@ -335,6 +345,45 @@ def test_ISTA_FISTA_missing_perc(): _ = fista(Identity(5), np.ones(5), 10, perc=None, threshkind="soft-percentile") +@pytest.mark.parametrize("par", [(par1), (par3), (par5), (par1j), (par3j), (par5j)]) +def test_ISTA_FISTA_alpha_too_high(par): + """Check error is raised or solver is stopped when alpha is chosen + too high""" + npp.random.seed(42) + A = npp.random.randn(par["ny"], par["nx"]) + par["imag"] * npp.random.randn( + par["ny"], par["nx"] + ) + Aop = MatrixMult(np.asarray(A), dtype=par["dtype"]) + + x = np.zeros(par["nx"]) + par["imag"] * np.zeros(par["nx"]) + x[par["nx"] // 2] = 1.0 + par["imag"] * 1.0 + y = Aop * x + + for solver in [ista, fista]: + # check that exception is raised + with pytest.raises(ValueError): + _, _, _ = solver( + Aop, + y, + niter=100, + eps=0.1, + alpha=1e5, + monitorres=True, + tol=0, + ) + + # check that CostNanInfCallback catches cost=np.inf + _, _, cost = solver( + Aop, + y, + niter=100, + eps=0.1, + alpha=1e5, + tol=0, + ) + assert np.isinf(cost[-1]) + + @pytest.mark.parametrize("par", [(par1), (par3), (par5), (par1j), (par3j), (par5j)]) def test_ISTA_FISTA(par): """Invert problem with ISTA/FISTA""" @@ -350,78 +399,42 @@ def test_ISTA_FISTA(par): x[par["nx"] - 4] = -1.0 - par["imag"] * 1.0 y = Aop * x - # some parameters need to be tuned differently for different problem sizes + # Some parameters need to be tuned differently for different problem sizes eps = 1.0 if par["ny"] >= par["nx"] else 2.0 perc = 50 if par["ny"] >= par["nx"] else 30 maxit = 500 - # ISTA with too high alpha (check that exception is raised) - with pytest.raises(ValueError): - xinv, _, _ = ista( - Aop, - y, - niter=maxit, - eps=eps, - alpha=1e5, - monitorres=True, - tol=0, - ) - # Regularization based ISTA and FISTA threshkinds = ["hard", "soft", "half"] if backend == "numpy" else ["soft", "half"] for threshkind in threshkinds: for preallocate in [False, True]: - # ISTA - xinv, _, _ = ista( - Aop, - y, - niter=maxit, - eps=eps, - threshkind=threshkind, - tol=0, - preallocate=preallocate, - ) - assert_array_almost_equal(x, xinv, decimal=1) - - # FISTA - xinv, _, _ = fista( - Aop, - y, - niter=maxit, - eps=eps, - threshkind=threshkind, - tol=0, - preallocate=preallocate, - ) - assert_array_almost_equal(x, xinv, decimal=1) - - # Percentile based ISTA and FISTA - if backend == "numpy": - for threshkind in ["hard-percentile", "soft-percentile", "half-percentile"]: - for preallocate in [False, True]: - # ISTA - xinv, _, _ = ista( + for solver in [ista, fista]: + xinv, _, _ = solver( Aop, y, niter=maxit, - perc=perc, + eps=eps, threshkind=threshkind, tol=0, preallocate=preallocate, ) assert_array_almost_equal(x, xinv, decimal=1) - # FISTA - xinv, _, _ = fista( - Aop, - y, - niter=maxit, - perc=perc, - threshkind=threshkind, - tol=0, - preallocate=preallocate, - ) - assert_array_almost_equal(x, xinv, decimal=1) + # Percentile based ISTA and FISTA + if backend == "numpy": + for threshkind in ["hard-percentile", "soft-percentile", "half-percentile"]: + for preallocate in [False, True]: + for solver in [ista, fista]: + xinv, _, _ = solver( + Aop, + y, + niter=maxit, + perc=perc, + threshkind=threshkind, + tol=0, + preallocate=preallocate, + ) + assert_array_almost_equal(x, xinv, decimal=1) @pytest.mark.parametrize("par", [(par1), (par3), (par5), (par1j), (par3j), (par5j)]) @@ -440,7 +453,7 @@ def test_ISTA_FISTA_multiplerhs(par): x = np.outer(x, np.ones(3)) y = Aop * x - # some parameters need to be tuned differently for different problem sizes + # Some parameters need to be tuned differently for different problem sizes eps = 1.0 if par["ny"] >= par["nx"] else 2.0 perc = 50 if par["ny"] >= par["nx"] else 30 maxit = 500 @@ -449,57 +462,33 @@ def test_ISTA_FISTA_multiplerhs(par): threshkinds = ["hard", "soft", "half"] if backend == "numpy" else ["soft", "half"] for threshkind in threshkinds: for preallocate in [False, True]: - # ISTA - xinv, _, _ = ista( - Aop, - y, - niter=maxit, - eps=eps, - threshkind=threshkind, - tol=0, - preallocate=preallocate, - ) - assert_array_almost_equal(x, xinv, decimal=1) - - # FISTA - xinv, _, _ = fista( - Aop, - y, - niter=maxit, - eps=eps, - threshkind=threshkind, - tol=0, - preallocate=preallocate, - ) - assert_array_almost_equal(x, xinv, decimal=1) - - # Percentile based ISTA and FISTA - if backend == "numpy": - for threshkind in ["hard-percentile", "soft-percentile", "half-percentile"]: - for preallocate in [False, True]: - # ISTA - xinv, _, _ = ista( + for solver in [ista, fista]: + xinv, _, _ = solver( Aop, y, niter=maxit, - perc=perc, + eps=eps, threshkind=threshkind, tol=0, preallocate=preallocate, ) assert_array_almost_equal(x, xinv, decimal=1) - # FISTA - xinv, _, _ = fista( - Aop, - y, - niter=maxit, - perc=perc, - threshkind=threshkind, - tol=0, - preallocate=preallocate, - ) - assert_array_almost_equal(x, xinv, decimal=1) + # Percentile based ISTA and FISTA + if backend == "numpy": + for threshkind in ["hard-percentile", "soft-percentile", "half-percentile"]: + for preallocate in [False, True]: + for solver in [ista, fista]: + xinv, _, _ = solver( + Aop, + y, + niter=maxit, + perc=perc, + threshkind=threshkind, + tol=0, + preallocate=preallocate, + ) + assert_array_almost_equal(x, xinv, decimal=1) @pytest.mark.parametrize("par", [(par1), (par3), (par5), (par1j), (par3j), (par5j)]) @@ -525,34 +514,19 @@ def test_ISTA_FISTA_stopping(par): threshkinds = ["hard", "soft", "half"] if backend == "numpy" else ["soft", "half"] for threshkind in threshkinds: for preallocate in [False, True]: - - # ISTA - _, _, cost = ista( - Aop, - y, - niter=maxit, - eps=eps, - threshkind=threshkind, - tol=0.0, - rtol=rtol, - preallocate=preallocate, - ) - assert cost[-2] / cost[0] >= rtol - assert cost[-1] / cost[0] < rtol - - # FISTA - _, _, cost = fista( - Aop, - y, - niter=maxit, - eps=eps, - threshkind=threshkind, - tol=0.0, - rtol=rtol, - preallocate=preallocate, - ) - assert cost[-2] / cost[0] >= rtol - assert cost[-1] / cost[0] < rtol + for solver in [ista, fista]: + _, _, cost = solver( + Aop, + y, + niter=maxit, + eps=eps, + threshkind=threshkind, + tol=0.0, + rtol=rtol, + preallocate=preallocate, + ) + assert cost[-2] / cost[0] >= rtol + assert cost[-1] / cost[0] < rtol @pytest.mark.skipif(