diff --git a/docs/source/sg_execution_times.rst b/docs/source/sg_execution_times.rst deleted file mode 100644 index 9641132d2..000000000 --- a/docs/source/sg_execution_times.rst +++ /dev/null @@ -1,271 +0,0 @@ - -:orphan: - -.. _sphx_glr_sg_execution_times: - - -Computation times -================= -**00:00.703** total execution time for 79 files **from all galleries**: - -.. container:: - - .. raw:: html - - - - - - - - .. list-table:: - :header-rows: 1 - :class: table table-striped sg-datatable - - * - Example - - Time - - Mem (MB) - * - :ref:`sphx_glr_tutorials_jaxop.py` (``../../tutorials/jaxop.py``) - - 00:00.703 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_avo.py` (``../../examples/plot_avo.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_bayeslinearregr.py` (``../../examples/plot_bayeslinearregr.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_bilinear.py` (``../../examples/plot_bilinear.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_blending.py` (``../../examples/plot_blending.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_causalintegration.py` (``../../examples/plot_causalintegration.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_cgls.py` (``../../examples/plot_cgls.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_chirpradon.py` (``../../examples/plot_chirpradon.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_conj.py` (``../../examples/plot_conj.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_convolve.py` (``../../examples/plot_convolve.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_dct.py` (``../../examples/plot_dct.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_derivative.py` (``../../examples/plot_derivative.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_describe.py` (``../../examples/plot_describe.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_diagonal.py` 
(``../../examples/plot_diagonal.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_dtcwt.py` (``../../examples/plot_dtcwt.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_fft.py` (``../../examples/plot_fft.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_flip.py` (``../../examples/plot_flip.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_fourierradon.py` (``../../examples/plot_fourierradon.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_identity.py` (``../../examples/plot_identity.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_imag.py` (``../../examples/plot_imag.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_ista.py` (``../../examples/plot_ista.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_l1l1.py` (``../../examples/plot_l1l1.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_linearregr.py` (``../../examples/plot_linearregr.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_matrixmult.py` (``../../examples/plot_matrixmult.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_mdc.py` (``../../examples/plot_mdc.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_multiproc.py` (``../../examples/plot_multiproc.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_nmo.py` (``../../examples/plot_nmo.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_nonstatconvolve.py` (``../../examples/plot_nonstatconvolve.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_nonstatfilter.py` (``../../examples/plot_nonstatfilter.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_pad.py` (``../../examples/plot_pad.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_patching.py` (``../../examples/plot_patching.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_phaseshift.py` (``../../examples/plot_phaseshift.py``) - - 00:00.000 - - 0.0 - * - 
:ref:`sphx_glr_gallery_plot_prestack.py` (``../../examples/plot_prestack.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_radon.py` (``../../examples/plot_radon.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_real.py` (``../../examples/plot_real.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_regr.py` (``../../examples/plot_regr.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_restriction.py` (``../../examples/plot_restriction.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_roll.py` (``../../examples/plot_roll.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_seislet.py` (``../../examples/plot_seislet.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_seismicevents.py` (``../../examples/plot_seismicevents.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_shift.py` (``../../examples/plot_shift.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_sliding.py` (``../../examples/plot_sliding.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_slopeest.py` (``../../examples/plot_slopeest.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_smoothing1d.py` (``../../examples/plot_smoothing1d.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_smoothing2d.py` (``../../examples/plot_smoothing2d.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_spread.py` (``../../examples/plot_spread.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_stacking.py` (``../../examples/plot_stacking.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_sum.py` (``../../examples/plot_sum.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_symmetrize.py` (``../../examples/plot_symmetrize.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_tapers.py` (``../../examples/plot_tapers.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_tndarray.py` (``../../examples/plot_tndarray.py``) - - 00:00.000 - - 
0.0 - * - :ref:`sphx_glr_gallery_plot_transpose.py` (``../../examples/plot_transpose.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_tvreg.py` (``../../examples/plot_tvreg.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_twoway.py` (``../../examples/plot_twoway.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_wavelet.py` (``../../examples/plot_wavelet.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_wavest.py` (``../../examples/plot_wavest.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_wavs.py` (``../../examples/plot_wavs.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_gallery_plot_zero.py` (``../../examples/plot_zero.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_bayesian.py` (``../../tutorials/bayesian.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_classsolvers.py` (``../../tutorials/classsolvers.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_ctscan.py` (``../../tutorials/ctscan.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_deblending.py` (``../../tutorials/deblending.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_deblurring.py` (``../../tutorials/deblurring.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_deghosting.py` (``../../tutorials/deghosting.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_dottest.py` (``../../tutorials/dottest.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_ilsm.py` (``../../tutorials/ilsm.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_interpolation.py` (``../../tutorials/interpolation.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_linearoperator.py` (``../../tutorials/linearoperator.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_lsm.py` (``../../tutorials/lsm.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_marchenko.py` (``../../tutorials/marchenko.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_mdd.py` 
(``../../tutorials/mdd.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_poststack.py` (``../../tutorials/poststack.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_prestack.py` (``../../tutorials/prestack.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_radonfiltering.py` (``../../tutorials/radonfiltering.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_realcomplex.py` (``../../tutorials/realcomplex.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_seismicinterpolation.py` (``../../tutorials/seismicinterpolation.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_solvers.py` (``../../tutorials/solvers.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_torchop.py` (``../../tutorials/torchop.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_tutorials_wavefielddecomposition.py` (``../../tutorials/wavefielddecomposition.py``) - - 00:00.000 - - 0.0 diff --git a/pylops/optimization/basesolver.py b/pylops/optimization/basesolver.py index ded819424..490abcf3b 100644 --- a/pylops/optimization/basesolver.py +++ b/pylops/optimization/basesolver.py @@ -1,6 +1,7 @@ __all__ = ["Solver"] import functools +import logging import time from abc import ABCMeta, abstractmethod from typing import TYPE_CHECKING, Any @@ -11,6 +12,8 @@ if TYPE_CHECKING: from pylops.linearoperator import LinearOperator +_units = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3} + class Solver(metaclass=ABCMeta): r"""Solver @@ -19,6 +22,8 @@ class Solver(metaclass=ABCMeta): This class comprises of the following mandatory methods: - ``__init__``: initialization method to which the operator `Op` must be passed + - ``memory_usage``: a method to compute upfront the memory used by each + step of the solver - ``setup``: a method that is invoked to setup the solver, basically it will create anything required prior to applying a step of the solver - ``step``: a method applying a single step of the solver @@ -121,11 +126,55 @@ def wrapper(*args, 
**kwargs): ), ) + def _setpreallocate(self, preallocate: bool) -> None: + # Check if the solver can work in preallocate mode + # (basically all the time except when JAX arrays are + # used) and force it to be False otherwise. + self.preallocate = preallocate if not self.isjax else False + + if preallocate and self.isjax: + logging.warning( + "Preallocation is not supported for JAX arrays. " + "Setting preallocate to False." + ) + + @abstractmethod + def memory_usage( + self, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + This method computes an estimate of the memory required by the solver given + the shape of the operator. This is useful to assess upfront if the solver + will run out of memory. + + Note, that the memory usage of the operator itself is not taken into account + in this estimate. + + Parameters + ---------- + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in bytes + + """ + pass + @abstractmethod def setup( self, y: NDArray, *args, + preallocate: bool = False, show: bool = False, **kwargs, ) -> None: @@ -138,6 +187,10 @@ def setup( ---------- y : :obj:`np.ndarray` Data of size :math:`[N \times 1]` + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. 
show : :obj:`bool`, optional Display setup log diff --git a/pylops/optimization/basic.py b/pylops/optimization/basic.py index 34b03bb71..f8b060d2d 100644 --- a/pylops/optimization/basic.py +++ b/pylops/optimization/basic.py @@ -24,6 +24,7 @@ def cg( show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), callback: Optional[Callable] = None, + preallocate: bool = False, ) -> Tuple[NDArray, int, NDArray]: r"""Conjugate gradient @@ -51,6 +52,10 @@ def cg( callback : :obj:`callable`, optional Function with signature (``callback(x)``) to call after each iteration where ``x`` is the current model vector + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver Returns ------- @@ -70,7 +75,13 @@ def cg( if callback is not None: cgsolve.callback = callback x, iiter, cost = cgsolve.solve( - y=y, x0=x0, tol=tol, niter=niter, show=show, itershow=itershow + y=y, + x0=x0, + tol=tol, + niter=niter, + show=show, + itershow=itershow, + preallocate=preallocate, ) return x, iiter, cost @@ -86,6 +97,7 @@ def cgls( show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), callback: Optional[Callable] = None, + preallocate: bool = False, ) -> Tuple[NDArray, int, int, float, float, NDArray]: r"""Conjugate gradient least squares @@ -115,6 +127,10 @@ def cgls( callback : :obj:`callable`, optional Function with signature (``callback(x)``) to call after each iteration where ``x`` is the current model vector + preallocate : :obj:`bool`, optional + .. 
versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver Returns ------- @@ -149,7 +165,14 @@ def cgls( if callback is not None: cgsolve.callback = callback x, istop, iiter, r1norm, r2norm, cost = cgsolve.solve( - y=y, x0=x0, tol=tol, niter=niter, damp=damp, show=show, itershow=itershow + y=y, + x0=x0, + tol=tol, + niter=niter, + damp=damp, + show=show, + itershow=itershow, + preallocate=preallocate, ) return x, istop, iiter, r1norm, r2norm, cost @@ -168,6 +191,7 @@ def lsqr( show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), callback: Optional[Callable] = None, + preallocate: bool = False, ) -> Tuple[NDArray, int, int, float, float, float, float, float, float, float, NDArray]: r"""LSQR @@ -213,6 +237,10 @@ def lsqr( callback : :obj:`callable`, optional Function with signature (``callback(x)``) to call after each iteration where ``x`` is the current model vector + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver Returns ------- @@ -299,5 +327,6 @@ def lsqr( calc_var=calc_var, show=show, itershow=itershow, + preallocate=preallocate, ) return x, istop, iiter, r1norm, r2norm, anorm, acond, arnorm, xnorm, var, cost diff --git a/pylops/optimization/cls_basic.py b/pylops/optimization/cls_basic.py index 1fb1ee7da..0713773ea 100644 --- a/pylops/optimization/cls_basic.py +++ b/pylops/optimization/cls_basic.py @@ -9,10 +9,10 @@ import numpy as np -from pylops.optimization.basesolver import Solver +from pylops.optimization.basesolver import Solver, _units from pylops.utils.backend import ( get_array_module, - to_cupy_conditional, + get_module_name, to_numpy, to_numpy_conditional, ) @@ -64,12 +64,48 @@ def _print_step(self, x: NDArray) -> None: msg = f"{self.iiter:6g} " + strx + f"{self.cost[self.iiter]:11.4e}" print(msg) + def memory_usage( + self, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + Parameters + ---------- + show :
:obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in Bytes + + """ + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: x0 - y, self.r, self.c + memuse = (self.Op.shape[1] + 3 * self.Op.shape[0]) * nbytes + + # Step (additional variables to those in setup): c1 - Opc + memuse += (self.Op.shape[1] + self.Op.shape[0]) * nbytes + + if show: + print(f"CG predicted memory usage: {memuse / _units[unit]:.2f} {unit}") + + return memuse + def setup( self, y: NDArray, x0: Optional[NDArray] = None, niter: Optional[int] = None, tol: float = 1e-4, + preallocate: bool = False, show: bool = False, ) -> NDArray: r"""Setup solver @@ -86,6 +122,13 @@ def setup( manually step over the solver) tol : :obj:`float`, optional Tolerance on residual norm + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. 
+ show : :obj:`bool`, optional Display setup log @@ -98,7 +141,10 @@ def setup( self.y = y self.niter = niter self.tol = tol + self.ncp = get_array_module(y) + self.isjax = get_module_name(self.ncp) == "jax" + self._setpreallocate(preallocate) # initialize solver if x0 is None: @@ -106,10 +152,18 @@ def setup( self.r = self.y.copy() else: x = x0.copy() - self.r = self.y - self.Op.matvec(x) + if not self.preallocate: + self.r = self.y - self.Op.matvec(x) + else: + self.r = self.ncp.empty_like(self.y) + self.ncp.subtract(self.y, self.Op.matvec(x), out=self.r) self.c = self.r.copy() self.kold = self.ncp.abs(self.r.dot(self.r.conj())) + # initialize other internal variables + if self.preallocate: + self.c1 = self.ncp.empty_like(x) + # create variables to track the residual norm and iterations self.cost: List = [] self.cost.append(float(np.sqrt(self.kold))) @@ -136,14 +190,24 @@ def step(self, x: NDArray, show: bool = False) -> NDArray: Updated model vector """ - Opc = self.Op.matvec(to_cupy_conditional(x, self.c)) + Opc = self.Op.matvec(self.c) cOpc = self.ncp.abs(self.c.dot(Opc.conj())) a = self.kold / cOpc - x += to_cupy_conditional(x, a) * to_cupy_conditional(x, self.c) - self.r -= a * Opc + if not self.preallocate: + x += a * self.c + self.r -= a * Opc + else: + self.ncp.multiply(self.c, a, out=self.c1) + self.ncp.add(x, self.c1, out=x) + self.ncp.multiply(Opc, a, out=Opc) + self.ncp.subtract(self.r, Opc, out=self.r) k = self.ncp.abs(self.r.dot(self.r.conj())) b = k / self.kold - self.c = self.r + b * self.c + if not self.preallocate: + self.c = self.r + b * self.c + else: + self.ncp.multiply(self.c, b, out=self.c) + self.ncp.add(self.c, self.r, out=self.c) self.kold = k self.iiter += 1 self.cost.append(float(np.sqrt(self.kold))) @@ -219,6 +283,7 @@ def solve( x0: Optional[NDArray] = None, niter: int = 10, tol: float = 1e-4, + preallocate: bool = False, show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), ) -> Tuple[NDArray, int, NDArray]: @@ -235,6
+300,12 @@ def solve( Number of iterations tol : :obj:`float`, optional Tolerance on residual norm + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. show : :obj:`bool`, optional Display logs itershow : :obj:`tuple`, optional @@ -252,7 +323,9 @@ def solve( History of the L2 norm of the residual """ - x = self.setup(y=y, x0=x0, niter=niter, tol=tol, show=show) + x = self.setup( + y=y, x0=x0, niter=niter, tol=tol, preallocate=preallocate, show=show + ) x = self.run(x, niter, show=show, itershow=itershow) self.finalize(show) return x, self.iiter, self.cost @@ -307,6 +380,41 @@ def _print_step(self, x: NDArray) -> None: ) print(msg) + def memory_usage( + self, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + Parameters + ---------- + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in Bytes + + """ + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: x0, self.c - y, self.s, self.q + memuse = (2 * self.Op.shape[1] + 3 * self.Op.shape[0]) * nbytes + + # Step (additional variables to those in setup): r, x1, c1 + memuse += (3 * self.Op.shape[1]) * nbytes + + if show: + print(f"CGLS predicted memory usage: {memuse / _units[unit]:.2f} {unit}") + + return memuse + def setup( self, y: NDArray, @@ -314,6 +422,7 @@ def setup( niter: Optional[int] = None, damp: float = 0.0, tol: float = 1e-4, + preallocate: bool = False, show: bool = False, ) -> NDArray: r"""Setup solver @@ -323,7 +432,7 @@ def setup( y : :obj:`np.ndarray` Data of size :math:`[N \times 1]` x0 : :obj:`np.ndarray`, optional - Initial guess of 
size :math:`[M \times 1]`. If ``None``, initialize + Initial guess of size :math:`[M \times 1]`. If ``None``, initialize internally as zero vector niter : :obj:`int`, optional Number of iterations (default to ``None`` in case a user wants to @@ -332,6 +441,12 @@ def setup( Damping coefficient tol : :obj:`float`, optional Tolerance on residual norm + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. show : :obj:`bool`, optional Display setup log @@ -345,20 +460,36 @@ def setup( self.damp = damp**2 self.tol = tol self.niter = niter + self.ncp = get_array_module(y) + self.isjax = get_module_name(self.ncp) == "jax" + self._setpreallocate(preallocate) # initialize solver if x0 is None: x = self.ncp.zeros(self.Op.shape[1], dtype=y.dtype) self.s = self.y.copy() - r = self.Op.rmatvec(self.s) + self.c = self.Op.rmatvec(self.s) else: x = x0.copy() - self.s = self.y - self.Op.matvec(x) - r = self.Op.rmatvec(self.s) - damp * x - self.c = r.copy() + if not self.preallocate: + self.s = self.y - self.Op.matvec(x) + self.c = self.Op.rmatvec(self.s) - damp * x + else: + self.s = self.ncp.empty_like(self.y) + self.ncp.subtract(self.y, self.Op.matvec(x), out=self.s) + x1 = self.ncp.empty_like(x) + self.c = self.ncp.empty_like(x) + self.ncp.multiply(x, damp, out=x1) + self.ncp.subtract(self.Op.rmatvec(self.s), x1, out=self.c) self.q = self.Op.matvec(self.c) - self.kold = self.ncp.abs(r.dot(r.conj())) + self.kold = self.ncp.abs(self.c.dot(self.c.conj())) + + # initialize other internal variables + if self.preallocate: + self.c1 = self.ncp.empty_like(self.c) + self.x1 = self.ncp.empty_like(x) + self.r = self.ncp.empty_like(x) # create variables to track the residual norm and iterations self.cost = [] @@ -390,12 +521,32 @@ def step(self, x: NDArray, show: bool = False) -> NDArray: a 
= self.kold / ( self.q.dot(self.q.conj()) + self.damp * self.c.dot(self.c.conj()) ) - x = x + a * self.c - self.s = self.s - to_numpy_conditional(self.q, a) * self.q - r = self.Op.rmatvec(self.s) - self.damp * x - k = self.ncp.abs(r.dot(r.conj())) + if not self.preallocate: + x = x + a * self.c + self.s = self.s - a * self.q + r = self.Op.rmatvec(self.s) - self.damp * x + else: + self.ncp.multiply(self.c, a, out=self.c1) + self.ncp.add(x, self.c1, out=x) + + self.ncp.multiply(self.q, a, out=self.q) + self.ncp.subtract(self.s, self.q, out=self.s) + + self.ncp.multiply(x, self.damp, out=self.x1) + self.ncp.subtract( + self.Op.rmatvec(self.s), + self.x1, + out=self.r, + ) + k = self.ncp.abs( + self.r.dot(self.r.conj()) if self.preallocate else r.dot(r.conj()) + ) b = k / self.kold - self.c = r + b * self.c + if not self.preallocate: + self.c = r + b * self.c + else: + self.ncp.multiply(self.c, b, out=self.c) + self.ncp.add(self.c, self.r, out=self.c) self.q = self.Op.matvec(self.c) self.kold = k self.iiter += 1 @@ -485,6 +636,7 @@ def solve( niter: int = 10, damp: float = 0.0, tol: float = 1e-4, + preallocate: bool = False, show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), ) -> Tuple[NDArray, int, int, float, float, NDArray]: @@ -504,6 +656,12 @@ def solve( Damping coefficient tol : :obj:`float`, optional Tolerance on residual norm + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. 
show : :obj:`bool`, optional Display logs itershow : :obj:`tuple`, optional @@ -536,7 +694,15 @@ def solve( History of r1norm through iterations """ - x = self.setup(y=y, x0=x0, niter=niter, damp=damp, tol=tol, show=show) + x = self.setup( + y=y, + x0=x0, + niter=niter, + damp=damp, + tol=tol, + preallocate=preallocate, + show=show, + ) x = self.run(x, niter, show=show, itershow=itershow) self.finalize(show) return x, self.istop, self.iiter, self.r1norm, self.r2norm, self.cost @@ -631,6 +797,41 @@ def _print_finalize(self) -> None: print(str5) print("-" * 90 + "\n") + def memory_usage( + self, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + Parameters + ---------- + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in Bytes + + """ + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: x0, self.v, self.w, self.dk - y, self.u + memuse = (4 * self.Op.shape[1] + 2 * self.Op.shape[0]) * nbytes + + # Step (additional variables to those in setup): w1 + memuse += self.Op.shape[1] * nbytes + + if show: + print(f"LSQR predicted memory usage: {memuse / _units[unit]:.2f} {unit}") + + return memuse + def setup( self, y: NDArray, @@ -641,6 +842,7 @@ def setup( conlim: float = 100000000.0, niter: int = 10, calc_var: bool = True, + preallocate: bool = False, show: bool = False, ) -> NDArray: r"""Setup solver @@ -670,7 +872,13 @@ def setup( Number of iterations calc_var : :obj:`bool`, optional Estimate diagonals of :math:`(\mathbf{Op}^H\mathbf{Op} + - \epsilon^2\mathbf{I})^{-1}`. + \epsilon^2\mathbf{I})^{-1}` + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. 
Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. show : :obj:`bool`, optional Display setup log @@ -687,7 +895,10 @@ def setup( self.conlim = conlim self.niter = niter self.calc_var = calc_var + self.ncp = get_array_module(y) + self.isjax = get_module_name(self.ncp) == "jax" + self._setpreallocate(preallocate) m, n = self.Op.shape @@ -719,15 +930,25 @@ def setup( self.u = y.copy() else: x = x0.copy() - self.u = self.y - self.Op.matvec(x0) + if not self.preallocate: + self.u = self.y - self.Op.matvec(x0) + else: + self.u = self.ncp.empty_like(self.y) + self.ncp.subtract(self.y, self.Op.matvec(x0), out=self.u) self.alfa = 0.0 self.beta = self.ncp.linalg.norm(self.u) if self.beta > 0.0: - self.u = self.u / self.beta + if not self.preallocate: + self.u = self.u / self.beta + else: + self.ncp.divide(self.u, self.beta, out=self.u) self.v = self.Op.rmatvec(self.u) self.alfa = self.ncp.linalg.norm(self.v) if self.alfa > 0: - self.v = self.v / self.alfa + if not self.preallocate: + self.v = self.v / self.alfa + else: + self.ncp.divide(self.v, self.alfa, out=self.v) else: self.v = x.copy() self.alfa = 0 @@ -736,6 +957,11 @@ def setup( # check if solution is already found self.arnorm: float = self.alfa * self.beta + # initialize other internal variables + if self.preallocate: + self.dk = self.ncp.empty_like(self.w) + self.w1 = self.ncp.empty_like(self.w) + # finalize setup self.arnorm0: float = self.arnorm self.rhobar: float = self.alfa @@ -778,19 +1004,31 @@ def step(self, x: NDArray, show: bool = False) -> NDArray: # next beta, u, alfa, v. 
These satisfy the relations # beta*u = Op*v - alfa*u, # alfa*v = Op'*u - beta*v' - self.u = ( - self.Op.matvec(self.v) - to_numpy_conditional(self.u, self.alfa) * self.u - ) + if not self.preallocate: + self.u = self.Op.matvec(self.v) - self.alfa * self.u + else: + self.ncp.multiply(self.u, self.alfa, out=self.u) + self.ncp.subtract(self.Op.matvec(self.v), self.u, out=self.u) self.beta = self.ncp.linalg.norm(self.u) if self.beta > 0: - self.u = self.u / self.beta + if not self.preallocate: + self.u = self.u / self.beta + else: + self.ncp.divide(self.u, self.beta, out=self.u) self.anorm = np.linalg.norm( [self.anorm, to_numpy(self.alfa), to_numpy(self.beta), self.damp] ) - self.v = self.Op.rmatvec(self.u) - self.beta * self.v + if not self.preallocate: + self.v = self.Op.rmatvec(self.u) - self.beta * self.v + else: + self.ncp.multiply(self.v, self.beta, out=self.v) + self.ncp.subtract(self.Op.rmatvec(self.u), self.v, out=self.v) self.alfa = self.ncp.linalg.norm(self.v) if self.alfa > 0: - self.v = self.v / self.alfa + if not self.preallocate: + self.v = self.v / self.alfa + else: + self.ncp.divide(self.v, self.alfa, out=self.v) # use a plane rotation to eliminate the damping parameter. # This alters the diagonal (rhobar) of the lower-bidiagonal matrix. @@ -814,9 +1052,16 @@ def step(self, x: NDArray, show: bool = False) -> NDArray: # update x and w. 
self.t1 = self.phi / self.rho self.t2 = -self.theta / self.rho - self.dk = self.w / self.rho - x = x + self.t1 * self.w - self.w = self.v + self.t2 * self.w + if not self.preallocate: + self.dk = self.w / self.rho + x = x + self.t1 * self.w + self.w = self.v + self.t2 * self.w + else: + self.ncp.divide(self.w, self.rho, out=self.dk) + self.ncp.multiply(self.w, self.t1, out=self.w1) + self.ncp.add(x, self.w1, out=x) + self.ncp.multiply(self.w, self.t2, out=self.w) + self.ncp.add(self.v, self.w, out=self.w) self.ddnorm = self.ddnorm + self.ncp.linalg.norm(self.dk) ** 2 if self.calc_var: self.var = self.var + to_numpy_conditional( @@ -965,6 +1210,7 @@ def solve( conlim: float = 100000000.0, niter: int = 10, calc_var: bool = True, + preallocate: bool = False, show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), ) -> Tuple[ @@ -1008,6 +1254,12 @@ def solve( calc_var : :obj:`bool`, optional Estimate diagonals of :math:`(\mathbf{Op}^H\mathbf{Op} + \epsilon^2\mathbf{I})^{-1}`. + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. 
show : :obj:`bool`, optional Display logs itershow : :obj:`tuple`, optional @@ -1079,6 +1331,7 @@ def solve( conlim=conlim, niter=niter, calc_var=calc_var, + preallocate=preallocate, show=show, ) x = self.run(x, niter=niter, show=show, itershow=itershow) diff --git a/pylops/optimization/cls_leastsquares.py b/pylops/optimization/cls_leastsquares.py index e824c86f2..8d66867bd 100644 --- a/pylops/optimization/cls_leastsquares.py +++ b/pylops/optimization/cls_leastsquares.py @@ -13,7 +13,7 @@ from scipy.sparse.linalg import lsqr from pylops.basicoperators import Diagonal, VStack -from pylops.optimization.basesolver import Solver +from pylops.optimization.basesolver import Solver, _units from pylops.optimization.basic import cg, cgls from pylops.utils.backend import get_array_module from pylops.utils.typing import NDArray @@ -88,6 +88,53 @@ def _print_finalize(self) -> None: print(f"\nTotal time (s) = {self.telapsed:.2f}") print("-" * 55 + "\n") + def memory_usage( + self, + nopRegs: Optional[Tuple[int]] = None, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + Parameters + ---------- + nopRegs : :obj:`tuple`, optional + Number of data elements of ``Regs`` operators + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in Bytes + + """ + # Convert nopRegs if None + if nopRegs is None: + nopRegs = 0 + + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: y_normal, data_regs (temporary as these are + # later projected and summed to y_normal) + memuse = (self.Op.shape[1] + np.prod(nopRegs)) * nbytes + + # Run (additional variables to those in setup): Setup and Step + # of CG solver on normal equations + memuse += (self.Op.shape[1] + 3 * self.Op.shape[1]) * nbytes + memuse += (2 * self.Op.shape[1]) * nbytes + + if 
show: + print( + f"NormalEquationsInversion predicted memory usage: {memuse / _units[unit]:.2f} {unit}" + ) + + return memuse + def setup( self, y: NDArray, @@ -235,27 +282,23 @@ def run( ``<0``: illegal input or breakdown """ - if x is not None: - self.y_normal = self.y_normal - self.Op_normal.matvec(x) if engine == "scipy" and self.ncp == np: if "tol" in kwargs_solver: kwargs_solver["atol"] = kwargs_solver["tol"] kwargs_solver.pop("tol") - xinv, istop = sp_cg(self.Op_normal, self.y_normal, **kwargs_solver) + xinv, istop = sp_cg(self.Op_normal, self.y_normal, x0=x, **kwargs_solver) elif engine == "pylops" or self.ncp != np: if show: kwargs_solver["show"] = True xinv = cg( self.Op_normal, self.y_normal, - x0=self.ncp.zeros(self.Op_normal.shape[1], dtype=self.Op_normal.dtype), + x0=x, **kwargs_solver, )[0] istop = None else: raise NotImplementedError("Engine must be scipy or pylops") - if x is not None: - xinv = x + xinv return xinv, istop def solve( @@ -448,6 +491,60 @@ def _print_finalize(self) -> None: print(f"\nTotal time (s) = {self.telapsed:.2f}") print("-" * 65 + "\n") + def memory_usage( + self, + nopRegs: Optional[Tuple[int]] = None, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + .. note:: The memory usage is computed assuming that + ``engine="pylops"`` is used. When ``engine="scipy"`` is + used instead, :func:`scipy.sparse.linalg.lsqr` may consume + more memory. 
+ + Parameters + ---------- + nopRegs : :obj:`tuple`, optional + Number of data elements of ``Regs`` operators + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in Bytes + + """ + # Convert nopRegs if None + if nopRegs is None: + nopRegs = 0 + + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: datatot + ndatatot = self.Op.shape[0] + np.prod(nopRegs) + memuse = ndatatot * nbytes + + # Run (additional variables to those in setup): Setup and Step + # of PyLops CGLS solver on augumented equations. Note that when + # engine="scipy", the SciPy LSQR solver is used instead (which + # may consume more memory). + memuse += (2 * self.Op.shape[1] + 3 * ndatatot) * nbytes + memuse += (3 * self.Op.shape[1]) * nbytes + + if show: + print( + f"RegularizedInversion predicted memory usage: {memuse / _units[unit]:.2f} {unit}" + ) + + return memuse + def setup( self, y: NDArray, @@ -520,7 +617,7 @@ def setup( # augumented operator if self.epsRs is not None and self.dataregs is not None: for epsR, datareg in zip(self.epsRs, self.dataregs): - self.datatot = np.hstack((self.datatot, epsR * datareg)) + self.datatot = self.ncp.hstack((self.datatot, epsR * datareg)) # print setup if show: @@ -580,13 +677,11 @@ def run( Equal to ``r1norm`` if :math:`\epsilon=0` """ - if x is not None: - self.datatot = self.datatot - self.RegOp.matvec(x) if engine == "scipy" and self.ncp == np: if show: kwargs_solver["show"] = 1 xinv, istop, itn, r1norm, r2norm = lsqr( - self.RegOp, self.datatot, **kwargs_solver + self.RegOp, self.datatot, x0=x, **kwargs_solver )[0:5] elif engine == "pylops" or self.ncp != np: if show: @@ -594,13 +689,11 @@ def run( xinv, istop, itn, r1norm, r2norm = cgls( self.RegOp, self.datatot, - x0=self.ncp.zeros(self.RegOp.shape[1], dtype=self.RegOp.dtype), + x0=x, 
**kwargs_solver, )[0:5] else: raise NotImplementedError("Engine must be scipy or pylops") - if x is not None: - xinv = x + xinv return xinv, istop, itn, r1norm, r2norm def solve( @@ -717,6 +810,52 @@ def _print_finalize(self) -> None: print(f"\nTotal time (s) = {self.telapsed:.2f}") print("-" * 65 + "\n") + def memory_usage( + self, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + .. note:: The memory usage is computed assuming that + ``engine="pylops"`` is used. When ``engine="scipy"`` is + used instead, :func:`scipy.sparse.linalg.lsqr` may consume + more memory. + + Parameters + ---------- + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in Bytes + + """ + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: y + memuse = self.Op.shape[0] * nbytes + + # Run (additional variables to those in setup): Setup and Step + # of PyLops CGLS solver on augumented equations. Note that when + # engine="scipy", the SciPy LSQR solver is used instead (which + # may consume more memory). 
+ memuse += (2 * self.Op.shape[1] + 3 * self.Op.shape[0]) * nbytes + memuse += (3 * self.Op.shape[1]) * nbytes + + if show: + print( + f"PreconditionedInversion predicted memory usage: {memuse / _units[unit]:.2f} {unit}" + ) + + return memuse + def setup( self, y: NDArray, @@ -800,14 +939,13 @@ def run( Equal to ``r1norm`` if :math:`\epsilon=0` """ - if x is not None: - self.y = self.y - self.Op.matvec(x) if engine == "scipy" and self.ncp == np: if show: kwargs_solver["show"] = 1 pinv, istop, itn, r1norm, r2norm = lsqr( self.POp, self.y, + x0=x, **kwargs_solver, )[0:5] elif engine == "pylops" or self.ncp != np: @@ -816,7 +954,7 @@ def run( pinv, istop, itn, r1norm, r2norm = cgls( self.POp, self.y, - x0=self.ncp.zeros(self.POp.shape[1], dtype=self.POp.dtype), + x0=x, **kwargs_solver, )[0:5] # force it 1d as we decorate this method with disable_ndarray_multiplication @@ -824,8 +962,6 @@ def run( else: raise NotImplementedError("Engine must be scipy or pylops") xinv = self.P.matvec(pinv) - if x is not None: - xinv = x + xinv return xinv, istop, itn, r1norm, r2norm def solve( diff --git a/pylops/optimization/cls_sparsity.py b/pylops/optimization/cls_sparsity.py index ceb4c6700..b0bff8638 100644 --- a/pylops/optimization/cls_sparsity.py +++ b/pylops/optimization/cls_sparsity.py @@ -9,6 +9,7 @@ import logging import time +from math import sqrt from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import numpy as np @@ -16,12 +17,12 @@ from pylops import LinearOperator from pylops.basicoperators import Diagonal, Identity, VStack -from pylops.optimization.basesolver import Solver +from pylops.optimization.basesolver import Solver, _units from pylops.optimization.basic import cgls from pylops.optimization.eigs import power_iteration from pylops.optimization.leastsquares import regularized_inversion from pylops.utils import deps -from pylops.utils.backend import get_array_module, get_module_name +from pylops.utils.backend import get_array_module, 
get_module_name, inplace_set from pylops.utils.typing import InputDimsLike, NDArray, SamplingLike spgl1_message = deps.spgl1_import("the spgl1 solver") @@ -29,6 +30,8 @@ if spgl1_message is None: from spgl1 import spgl1 as ext_spgl1 +logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) + def _hardthreshold(x: NDArray, thresh: float) -> NDArray: r"""Hard thresholding. @@ -54,7 +57,7 @@ def _hardthreshold(x: NDArray, thresh: float) -> NDArray: """ x1 = x.copy() - x1[np.abs(x) <= np.sqrt(2 * thresh)] = 0 + x1[np.abs(x) <= sqrt(2 * thresh)] = 0 return x1 @@ -324,6 +327,51 @@ def _print_step(self, x: NDArray) -> None: str2 = f" {self.rnorm:10.3e}" print(str1 + str2) + def memory_usage( + self, + kind: str = "data", + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + Parameters + ---------- + kind : :obj:`str`, optional + Kind of solver (``model``, ``data`` or ``datamodel``) + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in Bytes + + """ + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: y + augmented y if kind=datamodel + memuse = self.Op.shape[0] * nbytes + if kind == "datamodel": + memuse += self.Op.shape[1] * nbytes + + # Step (additional variables to those in setup): rw + if kind == "data": + memuse += self.Op.shape[0] * nbytes + elif kind == "model": + memuse += self.Op.shape[1] * nbytes + elif kind == "datamodel": + memuse += 2 * self.Op.shape[0] * nbytes + + if show: + print(f"IRLS predicted memory usage: {memuse / _units[unit]:.2f} {unit}") + + return memuse + def setup( self, y: NDArray, @@ -334,6 +382,7 @@ def setup( tolIRLS: float = 1e-10, warm: bool = False, kind: str = "data", + preallocate: bool = False, show: bool = False, ) -> None: r"""Setup solver @@ -360,6 
+409,12 @@ def setup( This only applies to ``kind="data"`` and ``kind="datamodel"`` kind : :obj:`str`, optional Kind of solver (``model``, ``data`` or ``datamodel``) + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. show : :obj:`bool`, optional Display setup log @@ -372,7 +427,12 @@ def setup( self.tolIRLS = tolIRLS self.warm = warm self.kind = kind + self.ncp = get_array_module(y) + self.isjax = get_module_name(self.ncp) == "jax" + self._setpreallocate(preallocate) + + # initiate outer iteration counter self.iiter = 0 # choose step to use @@ -386,33 +446,61 @@ def setup( # augment Op and y self.Op = VStack([self.Op, epsI * Identity(self.Op.shape[1])]) self.epsI = 0.0 # as epsI is added to the augmented system already - self.y = np.hstack([self.y, np.zeros(self.Op.shape[1])]) + self.y = self.ncp.hstack([self.y, self.ncp.zeros(self.Op.shape[1])]) else: raise NotImplementedError("kind must be model, data or datamodel") + if self.preallocate: + self.r = self.ncp.empty_like(self.y) + if "data" in self.kind: + self.rw = self.ncp.empty_like(self.y) + else: + self.rw = self.ncp.empty(self.Op.shape[1], dtype=self.Op.dtype) # print setup if show: self._print_setup() - def _step_data(self, x: NDArray, **kwargs_solver) -> NDArray: + def _step_data(self, x: NDArray, engine: str = "scipy", **kwargs_solver) -> NDArray: r"""Run one step of solver with L1 data term""" + # add preallocate to keywords of solver + if self.preallocate and (engine == "pylops" or self.ncp != np): + kwargs_solver["preallocate"] = True if self.iiter == 0: + # first iteration (standard least-squares) x = regularized_inversion( self.Op, self.y, None, x0=x if self.warm else None, damp=self.epsI, + engine=engine, **kwargs_solver, )[0] else: # other iterations (weighted least-squares) - if self.threshR: 
- self.rw = 1.0 / self.ncp.maximum(self.ncp.abs(self.r), self.epsR) + if self.preallocate and self.iiter == 1: + self.rw = self.ncp.zeros_like(self.y) + + if not self.preallocate: + if self.threshR: + self.rw = 1.0 / self.ncp.maximum(self.ncp.abs(self.r), self.epsR) + else: + self.rw = 1.0 / (self.ncp.abs(self.r) + self.epsR) + self.rw = self.rw / self.rw.max() else: - self.rw = 1.0 / (self.ncp.abs(self.r) + self.epsR) - self.rw = self.rw / self.rw.max() - R = Diagonal(np.sqrt(self.rw)) + if self.threshR: + self.ncp.divide( + 1.0, + self.ncp.maximum(self.ncp.abs(self.r), self.epsR), + out=self.rw, + ) + else: + self.ncp.divide( + 1.0, (self.ncp.abs(self.r) + self.epsR), out=self.rw + ) + self.ncp.divide(self.rw, self.rw.max(), out=self.rw) + + R = Diagonal(self.ncp.sqrt(self.rw)) x = regularized_inversion( self.Op, self.y, @@ -420,15 +508,21 @@ def _step_data(self, x: NDArray, **kwargs_solver) -> NDArray: Weight=R, x0=x if self.warm else None, damp=self.epsI, + engine=engine, **kwargs_solver, )[0] return x - def _step_model(self, x: NDArray, **kwargs_solver) -> NDArray: + def _step_model( + self, x: NDArray, engine: str = "scipy", **kwargs_solver + ) -> NDArray: r"""Run one step of solver with L1 model term""" + # add preallocate to keywords of solver + if self.preallocate and (engine == "pylops" or self.ncp != np): + kwargs_solver["preallocate"] = True if self.iiter == 0: # first iteration (unweighted least-squares) - if self.ncp == np: + if engine == "scipy" and self.ncp == np: x = self.Op.rmatvec( lsqr( self.Op @ self.Op.H + (self.epsI**2) * self.Iop, @@ -436,7 +530,7 @@ def _step_model(self, x: NDArray, **kwargs_solver) -> NDArray: **kwargs_solver, )[0] ) - else: + elif engine == "pylops" or self.ncp != np: x = self.Op.rmatvec( cgls( self.Op @ self.Op.H + (self.epsI**2) * self.Iop, @@ -447,10 +541,17 @@ def _step_model(self, x: NDArray, **kwargs_solver) -> NDArray: ) else: # other iterations (weighted least-squares) - self.rw = np.abs(x) - self.rw = self.rw / 
self.rw.max() + if self.preallocate and self.iiter == 1: + self.rw = self.ncp.zeros_like(x) + if not self.preallocate: + self.rw = self.ncp.abs(x) + self.rw = self.rw / self.rw.max() + else: + self.ncp.abs(x, out=self.rw) + self.ncp.divide(self.rw, self.rw.max(), out=self.rw) + R = Diagonal(self.rw, dtype=self.rw.dtype) - if self.ncp == np: + if engine == "scipy" and self.ncp == np: x = R.matvec( self.Op.rmatvec( lsqr( @@ -460,7 +561,7 @@ def _step_model(self, x: NDArray, **kwargs_solver) -> NDArray: )[0] ) ) - else: + elif engine == "pylops" or self.ncp != np: x = R.matvec( self.Op.rmatvec( cgls( @@ -473,21 +574,33 @@ def _step_model(self, x: NDArray, **kwargs_solver) -> NDArray: ) return x - def step(self, x: NDArray, show: bool = False, **kwargs_solver) -> NDArray: + def step( + self, + x: NDArray, + engine: str = "scipy", + show: bool = False, + **kwargs_solver, + ) -> NDArray: r"""Run one step of solver Parameters ---------- x : :obj:`np.ndarray` Current model vector to be updated by a step of ISTA + engine : :obj:`str`, optional + .. 
versionadded:: 2.6.0 + + Solver to use (``scipy`` or ``pylops``) show : :obj:`bool`, optional Display iteration log **kwargs_solver Arbitrary keyword arguments for :py:func:`scipy.sparse.linalg.cg` solver for data IRLS and :py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using - numpy data(or :py:func:`pylops.optimization.solver.cg` and - :py:func:`pylops.optimization.solver.cgls` when using cupy data) + numpy data and ``engine='scipy'`` (or + :py:func:`pylops.optimization.solver.cg` and + :py:func:`pylops.optimization.solver.cgls` when using cupy data or + ``engine='pylops'``) Returns ------- @@ -496,10 +609,13 @@ def step(self, x: NDArray, show: bool = False, **kwargs_solver) -> NDArray: """ # update model - x = self._step(x, **kwargs_solver) + x = self._step(x, engine=engine, **kwargs_solver) # compute residual - self.r: NDArray = self.y - self.Op.matvec(x) + if not self.preallocate: + self.r: NDArray = self.y - self.Op.matvec(x) + else: + self.ncp.subtract(self.y, self.Op.matvec(x), out=self.r) self.rnorm = self.ncp.linalg.norm(self.r) self.iiter += 1 @@ -509,8 +625,9 @@ def step(self, x: NDArray, show: bool = False, **kwargs_solver) -> NDArray: def run( self, - x: NDArray, + x: Optional[NDArray], nouter: int = 10, + engine: str = "scipy", show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), **kwargs_solver, @@ -520,9 +637,14 @@ def run( Parameters ---------- x : :obj:`np.ndarray` - Current model vector to be updated by multiple steps of IRLS + Current model vector to be updated by multiple steps of IRLS. Provide + ``None`` to initialize internally as zero vector nouter : :obj:`int`, optional Number of outer iterations. + engine : :obj:`str`, optional + .. 
versionadded:: 2.6.0 + + Solver to use (``scipy`` or ``pylops``) show : :obj:`bool`, optional Display logs itershow : :obj:`tuple`, optional @@ -533,8 +655,10 @@ def run( Arbitrary keyword arguments for :py:func:`scipy.sparse.linalg.cg` solver for data IRLS and :py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using - numpy data(or :py:func:`pylops.optimization.solver.cg` and - :py:func:`pylops.optimization.solver.cgls` when using cupy data) + numpy data and ``engine='scipy'`` (or + :py:func:`pylops.optimization.solver.cg` and + :py:func:`pylops.optimization.solver.cgls` when using cupy data or + ``engine='pylops'``) Returns ------- @@ -546,6 +670,7 @@ def run( if x is not None: self.x0 = x.copy() self.y = self.y - self.Op.matvec(x) + # choose xold to ensure tolerance test is passed initially xold = x.copy() + np.inf while self.iiter < nouter and self.ncp.linalg.norm(x - xold) >= self.tolIRLS: @@ -560,7 +685,7 @@ def run( else False ) xold = x.copy() - x = self.step(x, showstep, **kwargs_solver) + x = self.step(x, engine, showstep, **kwargs_solver) self.callback(x) # adding initial guess @@ -594,6 +719,8 @@ def solve( tolIRLS: float = 1e-10, kind: str = "data", warm: bool = False, + engine: str = "scipy", + preallocate: bool = False, show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), **kwargs_solver, @@ -624,6 +751,16 @@ def solve( This only applies to ``kind="data"`` and ``kind="datamodel"`` kind : :obj:`str`, optional Kind of solver (``data`` or ``model``) + engine : :obj:`str`, optional + .. versionadded:: 2.6.0 + + Solver to use (``scipy`` or ``pylops``) + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. 
show : :obj:`bool`, optional Display setup log itershow : :obj:`tuple`, optional @@ -651,11 +788,19 @@ def solve( tolIRLS=tolIRLS, warm=warm, kind=kind, + preallocate=preallocate, show=show, ) if x0 is None: x0 = self.ncp.zeros(self.Op.shape[1], dtype=self.y.dtype) - x = self.run(x0, nouter=nouter, show=show, itershow=itershow, **kwargs_solver) + x = self.run( + x0, + nouter=nouter, + engine=engine, + show=show, + itershow=itershow, + **kwargs_solver, + ) self.finalize(show) return x, self.nouter @@ -756,6 +901,41 @@ def _print_step(self, x: NDArray) -> None: str2 = f" {self.cost[-1]:10.3e}" print(str1 + str2) + def memory_usage( + self, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + Parameters + ---------- + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in Bytes + + """ + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: y, res + memuse = (2 * self.Op.shape[0]) * nbytes + + # Step (additional variables to those in setup): cres, cres_abs + memuse += (2 * self.Op.shape[0]) * nbytes + + if show: + print(f"OMP predicted memory usage: {memuse / _units[unit]:.2f} {unit}") + + return memuse + def setup( self, y: NDArray, @@ -765,6 +945,7 @@ def setup( normalizecols: bool = False, Opbasis: Optional["LinearOperator"] = None, optimal_coeff: bool = False, + preallocate: bool = False, show: bool = False, ) -> None: r"""Setup solver @@ -794,6 +975,12 @@ def setup( :math:`\mathbf{r} - c * \mathbf{Op}^j) norm (``True``) or use the directly the value from the inner product :math:`\mathbf{Op}_j^H\,\mathbf{r}_k`. + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. 
Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. show : :obj:`bool`, optional Display setup log @@ -805,7 +992,10 @@ def setup( self.normalizecols = normalizecols self.Opbasis = Opbasis if Opbasis is not None else self.Op self.optimal_coeff = optimal_coeff + self.ncp = get_array_module(y) + self.isjax = get_module_name(self.ncp) == "jax" + self._setpreallocate(preallocate) # find normalization factor for each column if self.normalizecols: @@ -814,8 +1004,8 @@ def setup( for icol in range(ncols): unit = self.ncp.zeros(ncols, dtype=self.Opbasis.dtype) unit[icol] = 1 - self.norms[icol] = np.linalg.norm(self.Opbasis.matvec(unit)) - print(f"{self.norms = }") + self.norms[icol] = self.ncp.linalg.norm(self.Opbasis.matvec(unit)) + # create variables to track the residual norm and iterations self.res = self.y.copy() self.cost = [ @@ -830,7 +1020,9 @@ def step( self, x: NDArray, cols: InputDimsLike, + engine: str = "scipy", show: bool = False, + **kwargs_solver, ) -> NDArray: r"""Run one step of solver @@ -840,8 +1032,18 @@ def step( Current model vector to be updated by a step of OMP cols : :obj:`list` Current list of chosen elements of vector x to be updated by a step of OMP + engine : :obj:`str`, optional + .. 
versionadded:: 2.6.0 + + Solver to use (``scipy`` or ``pylops``) show : :obj:`bool`, optional Display iteration log + **kwargs_solver + Arbitrary keyword arguments for + :py:func:`scipy.sparse.linalg.lsqr` solver when using + numpy data and ``engine='scipy'`` (or + :py:func:`pylops.optimization.solver.cgls` when using cupy + data or ``engine='pylops'``) Returns ------- @@ -851,14 +1053,18 @@ def step( Current list of chosen elements """ + # add preallocate to keywords of solver + if self.preallocate and (engine == "pylops" or self.ncp != np): + kwargs_solver["preallocate"] = True + # compute inner products cres = self.Op.rmatvec(self.res) if self.normalizecols: cres = cres / self.norms - cres_abs = np.abs(cres) + cres_abs = self.ncp.abs(cres) # choose column with max cres - cres_max = np.max(cres_abs) - imax = np.argwhere(cres_abs == cres_max).ravel() + cres_max = self.ncp.max(cres_abs) + imax = self.ncp.argwhere(cres_abs == cres_max).ravel() nimax = len(imax) if nimax > 0: imax = imax[np.random.permutation(nimax)[0]] @@ -882,7 +1088,14 @@ def step( ) if not self.optimal_coeff: # update with coefficient that maximizes the inner product - self.res -= Opcol.matvec(cres[imax] * self.ncp.ones(1)) + if not self.preallocate: + self.res -= Opcol.matvec(cres[imax] * self.ncp.ones(1)) + else: + self.ncp.subtract( + self.res, + Opcol.matvec(cres[imax] * self.ncp.ones(1)), + out=self.res, + ) if addnew: x.append(cres[imax]) else: @@ -891,7 +1104,12 @@ def step( # find optimal coefficient that minimizes the residual (r - cres * col) col = Opcol.matvec(self.ncp.ones(1, dtype=Opcol.dtype)) cresopt = (Opcol.rmatvec(self.res) / Opcol.rmatvec(col))[0] - self.res -= Opcol.matvec(cresopt * self.ncp.ones(1)) + if not self.preallocate: + self.res -= Opcol.matvec(cresopt * self.ncp.ones(1)) + else: + self.ncp.subtract( + self.res, Opcol.matvec(cresopt * self.ncp.ones(1)), out=self.res + ) if addnew: x.append(cresopt) else: @@ -899,16 +1117,21 @@ def step( else: # OMP update Opcol = 
self.Op.apply_columns(cols) - if self.ncp == np: - x = lsqr(Opcol, self.y, iter_lim=self.niter_inner)[0] - else: + if engine == "scipy" and self.ncp == np: + x = lsqr(Opcol, self.y, iter_lim=self.niter_inner, **kwargs_solver)[0] + elif engine == "pylops" or self.ncp != np: x = cgls( Opcol, self.y, self.ncp.zeros(int(Opcol.shape[1]), dtype=Opcol.dtype), niter=self.niter_inner, + **kwargs_solver, )[0] - self.res = self.y - Opcol.matvec(x) + if not self.preallocate: + self.res = self.y - Opcol.matvec(x) + else: + self.res = Opcol.matvec(x) + self.ncp.subtract(self.res, self.y, out=self.res) self.iiter += 1 self.cost.append(float(np.linalg.norm(self.res))) @@ -920,6 +1143,7 @@ def run( self, x: NDArray, cols: InputDimsLike, + engine: str = "scipy", show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), ) -> Tuple[NDArray, InputDimsLike]: @@ -931,6 +1155,10 @@ def run( Current model vector to be updated by multiple steps of IRLS cols : :obj:`list` Current list of chosen elements of vector x to be updated by a step of OMP + engine : :obj:`str`, optional + .. 
versionadded:: 2.6.0 + + Solver to use (``scipy`` or ``pylops``) show : :obj:`bool`, optional Display logs itershow : :obj:`tuple`, optional @@ -957,7 +1185,7 @@ def run( ) else False ) - x, cols = self.step(x, cols, showstep) + x, cols = self.step(x, cols, engine, showstep) self.callback(x, cols) return x, cols @@ -990,7 +1218,8 @@ def finalize( self.nouter = self.iiter xfin = self.ncp.zeros(int(self.Op.shape[1]), dtype=self.Op.dtype) - xfin[cols] = self.ncp.array(x) + xfin = inplace_set(self.ncp.array(x), xfin, self.ncp.array(cols)) + if show: self._print_finalize(nbar=55) return xfin @@ -1004,6 +1233,8 @@ def solve( normalizecols: bool = False, Opbasis: Optional["LinearOperator"] = None, optimal_coeff: bool = False, + engine: str = "scipy", + preallocate: bool = False, show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), ) -> Tuple[NDArray, int, NDArray]: @@ -1034,6 +1265,16 @@ def solve( :math:`\mathbf{r} - c * \mathbf{Op}^j) norm (``True``) or use the directly the value from the inner product :math:`\mathbf{Op}_j^H\,\mathbf{r}_k`. + engine : :obj:`str`, optional + .. versionadded:: 2.6.0 + + Solver to use (``scipy`` or ``pylops``) + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. 
show : :obj:`bool`, optional Display logs itershow : :obj:`tuple`, optional @@ -1059,11 +1300,12 @@ def solve( normalizecols=normalizecols, Opbasis=Opbasis, optimal_coeff=optimal_coeff, + preallocate=preallocate, show=show, ) x: List[NDArray] = [] cols: List[InputDimsLike] = [] - x, cols = self.run(x, cols, show=show, itershow=itershow) + x, cols = self.run(x, cols, engine=engine, show=show, itershow=itershow) x = self.finalize(x, cols, show) return x, self.nouter, self.cost @@ -1180,6 +1422,41 @@ def _print_step( ) print(msg) + def memory_usage( + self, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + Parameters + ---------- + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in Bytes + + """ + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: x0 - y + memuse = (self.Op.shape[1] + self.Op.shape[0]) * nbytes + + # Step (additional variables to those in setup): xold, grad, x_unthesh - res + memuse += (3 * self.Op.shape[1] + self.Op.shape[0]) * nbytes + + if show: + print(f"ISTA predicted memory usage: {memuse / _units[unit]:.2f} {unit}") + + return memuse + def setup( self, y: NDArray, @@ -1194,6 +1471,7 @@ def setup( perc: Optional[float] = None, decay: Optional[NDArray] = None, monitorres: bool = False, + preallocate: bool = False, show: bool = False, ) -> NDArray: r"""Setup solver @@ -1234,6 +1512,12 @@ def setup( Decay factor to be applied to thresholding during iterations monitorres : :obj:`bool`, optional Monitor that residual is decreasing + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. 
Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. show : :obj:`bool`, optional Display setup log @@ -1255,6 +1539,8 @@ def setup( self.monitorres = monitorres self.ncp = get_array_module(y) + self.isjax = get_module_name(self.ncp) == "jax" + self._setpreallocate(preallocate) # choose matvec/rmatvec or matmat/rmatmat based on R if y.ndim > 1 and y.shape[1] > 1: @@ -1309,7 +1595,7 @@ def setup( # prepare decay (if not passed) if perc is None and decay is None: - self.decay = self.ncp.ones(niter) + self.decay = self.ncp.ones(niter, dtype=self.Op) # step size if alpha is not None: @@ -1357,6 +1643,17 @@ def setup( else: x = x0.copy() + # initialize other internal variabled + if self.preallocate: + self.res = self.ncp.empty_like(y) + self.grad = self.ncp.empty_like(x) + self.x_unthesh = self.ncp.empty_like(x) + self.xold = self.ncp.empty_like(x) + if self.SOp is not None: + self.SOpx_unthesh: NDArray = self.ncp.zeros( + self.SOp.shape[1], dtype=self.SOp.dtype + ) + # create variable to track residual if monitorres: self.normresold = np.inf @@ -1392,11 +1689,19 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]: """ # store old vector - xold = x.copy() + if self.preallocate: + self.xold[:] = x[:] + else: + xold = x.copy() + # compute residual - res: NDArray = self.y - self.Opmatvec(x) + if not self.preallocate: + res: NDArray = self.y - self.Opmatvec(x) + else: + self.ncp.subtract(self.y, self.Opmatvec(x), out=self.res) + if self.monitorres: - self.normres = np.linalg.norm(res) + self.normres = np.linalg.norm(self.res if self.preallocate else res) if self.normres > self.normresold: raise ValueError( f"ISTA stopped at iteration {self.iiter} due to " @@ -1407,25 +1712,67 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]: self.normresold = self.normres # compute gradient - grad: NDArray = self.alpha * (self.Oprmatvec(res)) + if not 
self.preallocate: + grad: NDArray = self.alpha * (self.Oprmatvec(res)) + else: + self.ncp.multiply( + self.Oprmatvec(self.res), + self.alpha, + out=self.grad, + ) # update inverted model - x_unthesh: NDArray = x + grad + if not self.preallocate: + x_unthesh: NDArray = x + grad + else: + self.ncp.add( + x, + self.grad, + out=self.x_unthesh, + ) + + # apply SOp.H to current x if self.SOp is not None: - x_unthesh = self.SOprmatvec(x_unthesh) - if self.perc is None and self.decay is not None: - x = self.threshf(x_unthesh, self.decay[self.iiter] * self.thresh) - elif self.perc is not None: - x = self.threshf(x_unthesh, 100 - self.perc) + if self.preallocate: + self.SOpx_unthesh[:] = self.SOprmatvec(self.x_unthesh) + else: + SOpx_unthesh = self.SOprmatvec(x_unthesh) + # threshold current solution or current solution projected onto SOp.H space + if self.SOp is None: + x_unthesh_or_SOpx_unthesh = ( + self.x_unthesh if self.preallocate else x_unthesh + ) + else: + x_unthesh_or_SOpx_unthesh = ( + self.SOpx_unthesh if self.preallocate else SOpx_unthesh + ) + if self.perc is None: + x = self.threshf( + x_unthesh_or_SOpx_unthesh, + self.decay[self.iiter] * self.thresh, + ) + else: + x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc) + # apply SOp to thresholded x if self.SOp is not None: x = self.SOpmatvec(x) - # model update - xupdate = np.linalg.norm(x - xold) + # check model update + if not self.preallocate: + xupdate = np.linalg.norm(x - xold) + else: + self.ncp.subtract( + x, + self.xold, + out=self.xold, + ) + xupdate = np.linalg.norm(self.xold) - costdata = 0.5 * np.linalg.norm(res) ** 2 + # cost functions + costdata = 0.5 * np.linalg.norm(self.res if self.preallocate else res) ** 2 costreg = self.eps * np.linalg.norm(x, ord=1) self.cost.append(float(costdata + costreg)) + self.iiter += 1 if show: self._print_step(x, costdata, costreg, xupdate) @@ -1512,6 +1859,7 @@ def solve( perc: Optional[float] = None, decay: Optional[NDArray] = None, monitorres: bool = 
False, + preallocate: bool = False, show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), ) -> Tuple[NDArray, int, NDArray]: @@ -1552,6 +1900,12 @@ def solve( Decay factor to be applied to thresholding during iterations monitorres : :obj:`bool`, optional Monitor that residual is decreasing + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. show : :obj:`bool`, optional Display logs itershow : :obj:`tuple`, optional @@ -1582,6 +1936,7 @@ def solve( perc=perc, decay=decay, monitorres=monitorres, + preallocate=preallocate, show=show, ) x = self.run(x, niter, show=show, itershow=itershow) @@ -1647,6 +2002,41 @@ class FISTA(ISTA): """ + def memory_usage( + self, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + Parameters + ---------- + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` + Memory usage in Bytes + + """ + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: x0 - y + memuse = (self.Op.shape[1] + self.Op.shape[0]) * nbytes + + # Step (additional variables to those in setup): xold, grad, x_unthesh, z - res + memuse += (4 * self.Op.shape[1] + self.Op.shape[0]) * nbytes + + if show: + print(f"FISTA predicted memory usage: {memuse / _units[unit]:.2f} {unit}") + + return memuse + def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray: r"""Run one step of solver @@ -1670,45 +2060,101 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray: """ # store old vector - xold = x.copy() + if self.preallocate: + self.xold[:] = x[:] + else: + xold = x.copy() + # compute 
residual - resz: NDArray = self.y - self.Opmatvec(z) + if not self.preallocate: + res: NDArray = self.y - self.Opmatvec(z) + else: + self.ncp.subtract(self.y, self.Opmatvec(z), out=self.res) + if self.monitorres: - self.normres = np.linalg.norm(resz) + self.normres = np.linalg.norm(self.res if self.preallocate else res) if self.normres > self.normresold: raise ValueError( - f"ISTA stopped at iteration {self.iiter} due to " + f"FISTA stopped at iteration {self.iiter} due to " "residual increasing, consider modifying " "eps and/or alpha..." ) else: self.normresold = self.normres - # compute gradient - grad: NDArray = self.alpha * (self.Oprmatvec(resz)) + # compute gradient and update inverted model + if not self.preallocate: + grad: NDArray = self.alpha * (self.Oprmatvec(res)) + x_unthesh: NDArray = z + grad + else: + self.ncp.multiply( + self.Oprmatvec(self.res), + self.alpha, + out=self.grad, + ) + self.ncp.add( + z, + self.grad, + out=self.x_unthesh, + ) - # update inverted model - x_unthesh: NDArray = z + grad + # apply SOp.H to current x if self.SOp is not None: - x_unthesh = self.SOprmatvec(x_unthesh) - if self.perc is None and self.decay is not None: - x = self.threshf(x_unthesh, self.decay[self.iiter] * self.thresh) - elif self.perc is not None: - x = self.threshf(x_unthesh, 100 - self.perc) + if self.preallocate: + self.SOpx_unthesh[:] = self.SOprmatvec(self.x_unthesh) + else: + SOpx_unthesh = self.SOprmatvec(x_unthesh) + + # threshold current solution or current solution projected onto SOp.H space + if self.SOp is None: + x_unthesh_or_SOpx_unthesh = ( + self.x_unthesh if self.preallocate else x_unthesh + ) + else: + x_unthesh_or_SOpx_unthesh = ( + self.SOpx_unthesh if self.preallocate else SOpx_unthesh + ) + if self.perc is None: + x = self.threshf( + x_unthesh_or_SOpx_unthesh, + self.decay[self.iiter] * self.thresh, + ) + else: + x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc) + + # apply SOp to thresholded x if self.SOp is not None: x = 
self.SOpmatvec(x) # update auxiliary coefficients told = self.t - self.t = (1.0 + np.sqrt(1.0 + 4.0 * self.t**2)) / 2.0 - z = x + ((told - 1.0) / self.t) * (x - xold) + self.t = (1.0 + sqrt(1.0 + 4.0 * self.t**2)) / 2.0 # model update - xupdate = np.linalg.norm(x - xold) + if not self.preallocate: + z = x + ((told - 1.0) / self.t) * (x - xold) + else: + self.ncp.subtract( + x, + self.xold, + out=self.xold, + ) + self.ncp.multiply(self.xold, ((told - 1.0) / self.t), out=z) + self.ncp.add(x, z, out=z) + # check model update + if not self.preallocate: + xupdate = np.linalg.norm(x - xold) + else: + # note that x - xold has been already computed as part of the + # intermediate calculation of x in model update step + xupdate = np.linalg.norm(self.xold) + + # cost functions costdata = 0.5 * np.linalg.norm(self.y - self.Op @ x) ** 2 costreg = self.eps * np.linalg.norm(x, ord=1) self.cost.append(float(costdata + costreg)) + self.iiter += 1 if show: self._print_step(x, costdata, costreg, xupdate) @@ -1829,6 +2275,13 @@ def _print_finalize(self) -> None: print(f"\nTotal time (s) = {self.telapsed:.2f}") print("-" * 80 + "\n") + def memory_usage( + self, + show: bool = False, + unit: str = "B", + ) -> float: + pass + def setup( self, y: NDArray, @@ -2141,6 +2594,60 @@ def _print_step(self, x: NDArray) -> None: str2 = f"{self.costdata:10.3e} {self.costtot:9.3e}" print(str1 + str2) + def memory_usage( + self, + nopRegsL1: Optional[Tuple[int]] = None, + nopRegsL2: Optional[Tuple[int]] = None, + show: bool = False, + unit: str = "B", + ) -> float: + """Compute memory usage of the solver + + Parameters + ---------- + nopRegsL1 : :obj:`tuple`, optional + Number of data elements of ``RegsL1`` operators + nopRegsL2 : :obj:`tuple`, optional + Number of data elements of ``RegsL2`` operators + show : :obj:`bool`, optional + Display memory usage + unit: :obj:`str`, optional + Unit used to display memory usage ( + ``B``, ``KB``, ``MB`` or ``GB``) + + Returns + ------- + memuse :obj:`float` 
+ Memory usage in Bytes + + """ + # Convert nopRegsL1 and nopRegsL2 if None + if nopRegsL1 is None: + nopRegsL1 = 0 + if nopRegsL2 is None: + nopRegsL2 = 0 + + # Get number of bytes of dtype used in the solver + nbytes = np.dtype(self.Op.dtype).itemsize + + # Setup: x0 - y - b, d dataregsL1 - dataregsL2 + memuse = ( + self.Op.shape[1] + + self.Op.shape[0] + + 3 * np.prod(nopRegsL1) + + np.prod(nopRegsL2) + ) * nbytes + + # Step (additional variables to those in setup): dataregs + memuse += np.prod(nopRegsL1) * nbytes + + if show: + print( + f"Split-Bregman predicted memory usage: {memuse / _units[unit]:.2f} {unit}" + ) + + return memuse + def setup( self, y: NDArray, @@ -2156,6 +2663,7 @@ def setup( tol: float = 1e-10, tau: float = 1.0, restart: bool = False, + preallocate: bool = False, show: bool = False, ) -> NDArray: r"""Setup solver @@ -2201,6 +2709,12 @@ def setup( the initial guess (``True``) or with the last estimate (``False``). Note that when this is set to ``True``, the ``x0`` provided in the setup will be used in all iterations. + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. 
show : :obj:`bool`, optional Display setup log @@ -2222,14 +2736,19 @@ def setup( self.tol = tol self.tau = tau self.restart = restart + self.ncp = get_array_module(y) + self.isjax = get_module_name(self.ncp) == "jax" + self._setpreallocate(preallocate) # L1 regularizations self.nregsL1 = len(RegsL1) self.b = [ self.ncp.zeros(RegL1.shape[0], dtype=self.Op.dtype) for RegL1 in RegsL1 ] - self.d = self.b.copy() + self.d = [ + self.ncp.zeros(RegL1.shape[0], dtype=self.Op.dtype) for RegL1 in RegsL1 + ] # L2 regularizations self.nregsL2 = 0 if RegsL2 is None else len(RegsL2) @@ -2247,13 +2766,11 @@ def setup( self.epsRs: List[float] = [] if epsRL2s is not None: self.epsRs += [ - np.sqrt(epsRL2s[ireg] / 2) / np.sqrt(mu / 2) - for ireg in range(self.nregsL2) + sqrt(epsRL2s[ireg] / 2) / sqrt(mu / 2) for ireg in range(self.nregsL2) ] if epsRL1s is not None: self.epsRs += [ - np.sqrt(epsRL1s[ireg] / 2) / np.sqrt(mu / 2) - for ireg in range(self.nregsL1) + sqrt(epsRL1s[ireg] / 2) / sqrt(mu / 2) for ireg in range(self.nregsL1) ] self.x0 = x0 @@ -2270,9 +2787,10 @@ def setup( def step( self, x: NDArray, + engine: str = "scipy", show: bool = False, show_inner: bool = False, - **kwargs_lsqr, + **kwargs_solver, ) -> NDArray: r"""Run one step of solver @@ -2280,14 +2798,18 @@ def step( ---------- x : :obj:`list` or :obj:`np.ndarray` Current model vector to be updated by a step of OMP - show_inner : :obj:`bool`, optional - Display inner iteration logs of lsqr + engine : :obj:`str`, optional + Solver to use (``scipy`` or ``pylops``) show : :obj:`bool`, optional Display iteration log - **kwargs_lsqr - Arbitrary keyword arguments for - :py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first - subproblem in the first step of the Split Bregman algorithm. 
+ show_inner : :obj:`bool`, optional + Display inner iteration logs of lsqr + **kwargs_solver + Arbitrary keyword arguments for chosen solver + used to solve the first subproblem in the first step of the + Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` and + :py:func:`pylops.optimization.solver.cgls` are used as default + for numpy and cupy `data`, respectively). Returns ------- @@ -2295,11 +2817,22 @@ def step( Updated model vector """ + # add preallocate to keywords of solver + if self.preallocate and (engine == "pylops" or self.ncp != np): + kwargs_solver["preallocate"] = True + for _ in range(self.niter_inner): # regularized problem - dataregs = self.dataregsL2 + [ - self.d[ireg] - self.b[ireg] for ireg in range(self.nregsL1) - ] + if not self.preallocate: + dataregs = self.dataregsL2 + [ + self.d[ireg] - self.b[ireg] for ireg in range(self.nregsL1) + ] + else: + for ireg in range(self.nregsL1): + self.ncp.subtract(self.d[ireg], self.b[ireg], out=self.d[ireg]) + dataregs = self.dataregsL2 + [ + self.d[ireg] for ireg in range(self.nregsL1) + ] x = regularized_inversion( self.Op, self.y, @@ -2308,21 +2841,25 @@ def step( epsRs=self.epsRs, x0=self.x0 if self.restart else x, show=show_inner, - **kwargs_lsqr, + engine=engine, + **kwargs_solver, )[0] - # Shrinkage - self.d = [ - _softthreshold( - self.RegsL1[ireg].matvec(x) + self.b[ireg], self.epsRL1s[ireg] - ) - for ireg in range(self.nregsL1) - ] + # shrinkage + if not self.preallocate: + for ireg in range(self.nregsL1): + self.d[ireg] = _softthreshold( + self.RegsL1[ireg].matvec(x) + self.b[ireg], self.epsRL1s[ireg] + ) + else: + for ireg in range(self.nregsL1): + self.ncp.add( + self.RegsL1[ireg].matvec(x), self.b[ireg], out=self.d[ireg] + ) + self.d[ireg] = _softthreshold(self.d[ireg], self.epsRL1s[ireg]) # Bregman update - self.b = [ - self.b[ireg] + self.tau * (self.RegsL1[ireg].matvec(x) - self.d[ireg]) - for ireg in range(self.nregsL1) - ] + for ireg in range(self.nregsL1): + self.b[ireg] += 
self.tau * (self.RegsL1[ireg].matvec(x) - self.d[ireg]) # compute residual norms self.costdata = ( @@ -2340,7 +2877,7 @@ def step( ) self.costregL1 = [ self.ncp.linalg.norm(RegL1.matvec(x), ord=1) - for epsRL1, RegL1 in zip(self.epsRL1s, self.RegsL1) + for _, RegL1 in zip(self.epsRL1s, self.RegsL1) ] self.costtot = ( self.costdata @@ -2358,6 +2895,7 @@ def step( def run( self, x: NDArray, + engine: str = "scipy", show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), show_inner: bool = False, @@ -2369,6 +2907,8 @@ def run( ---------- x : :obj:`np.ndarray` Current model vector to be updated by multiple steps of IRLS + engine : :obj:`str`, optional + Solver to use (``scipy`` or ``pylops``) show : :obj:`bool`, optional Display logs itershow : :obj:`tuple`, optional @@ -2403,8 +2943,9 @@ def run( ) else False ) - x = self.step(x, showstep, show_inner, **kwargs_lsqr) + x = self.step(x, engine, showstep, show_inner, **kwargs_lsqr) self.callback(x) + return x def finalize(self, show: bool = False) -> NDArray: @@ -2443,6 +2984,8 @@ def solve( tol: float = 1e-10, tau: float = 1.0, restart: bool = False, + engine: str = "scipy", + preallocate: bool = False, show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), show_inner: bool = False, @@ -2491,6 +3034,14 @@ def solve( the initial guess (``True``) or with the last estimate (``False``). Note that when this is set to ``True``, the ``x0`` provided in the setup will be used in all iterations. + engine : :obj:`str`, optional + Solver to use (``scipy`` or ``pylops``) + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. 
show : :obj:`bool`, optional Display logs itershow : :obj:`tuple`, optional @@ -2528,10 +3079,16 @@ def solve( tol=tol, tau=tau, restart=restart, + preallocate=preallocate, show=show, ) x = self.run( - x, show=show, itershow=itershow, show_inner=show_inner, **kwargs_lsqr + x, + engine=engine, + show=show, + itershow=itershow, + show_inner=show_inner, + **kwargs_lsqr, ) self.finalize(show) return x, self.iiter, self.cost diff --git a/pylops/optimization/leastsquares.py b/pylops/optimization/leastsquares.py index 3174ee6c8..0e61b8a26 100644 --- a/pylops/optimization/leastsquares.py +++ b/pylops/optimization/leastsquares.py @@ -239,9 +239,9 @@ def preconditioned_inversion( x0 : :obj:`numpy.ndarray` Initial guess of size :math:`[M \times 1]` engine : :obj:`str`, optional - Solver to use (``scipy`` or ``pylops``) + Solver to use (``scipy`` or ``pylops``) show : :obj:`bool`, optional - Display normal equations solver log + Display normal equations solver log **kwargs_solver Arbitrary keyword arguments for chosen solver (:py:func:`scipy.sparse.linalg.lsqr` and diff --git a/pylops/optimization/sparsity.py b/pylops/optimization/sparsity.py index 6838e9461..c7c720efb 100644 --- a/pylops/optimization/sparsity.py +++ b/pylops/optimization/sparsity.py @@ -28,9 +28,11 @@ def irls( tolIRLS: float = 1e-10, warm: bool = False, kind: str = "data", + engine: str = "scipy", show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), callback: Optional[Callable] = None, + preallocate: bool = False, **kwargs_solver, ) -> Tuple[NDArray, int]: r"""Iteratively reweighted least squares. 
@@ -74,6 +76,8 @@ def irls( This only applies to ``kind="data"`` and ``kind="datamodel"`` kind : :obj:`str`, optional Kind of solver (``model``, ``data`` or ``datamodel``) + engine : :obj:`str`, optional + Solver to use (``scipy`` or ``pylops``) show : :obj:`bool`, optional Display logs itershow : :obj:`tuple`, optional @@ -83,6 +87,12 @@ def irls( callback : :obj:`callable`, optional Function with signature (``callback(x)``) to call after each iteration where ``x`` is the current model vector + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. **kwargs_solver Arbitrary keyword arguments for :py:func:`scipy.sparse.linalg.cg` solver for data IRLS and @@ -113,8 +123,10 @@ def irls( epsR=epsR, epsI=epsI, tolIRLS=tolIRLS, - warm=warm, kind=kind, + warm=warm, + engine=engine, + preallocate=preallocate, show=show, itershow=itershow, **kwargs_solver, @@ -131,9 +143,11 @@ def omp( normalizecols: bool = False, Opbasis: Optional["LinearOperator"] = None, optimal_coeff: bool = False, + engine: str = "scipy", show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), callback: Optional[Callable] = None, + preallocate: bool = False, ) -> Tuple[NDArray, int, NDArray]: r"""Orthogonal Matching Pursuit (OMP). @@ -169,6 +183,8 @@ def omp( :math:`\mathbf{r} - c * \mathbf{Op}^j) norm (``True``) or use the directly the value from the inner product :math:`\mathbf{Op}_j^H\,\mathbf{r}_k`. 
+ engine : :obj:`str`, optional + Solver to use (``scipy`` or ``pylops``) show : :obj:`bool`, optional Display iterations log itershow : :obj:`tuple`, optional @@ -179,7 +195,12 @@ def omp( Function with signature (``callback(x, cols)``) to call after each iteration where ``x`` contains the non-zero model coefficient and ``cols`` are the indices where the current model vector is non-zero + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. Returns ------- xinv : :obj:`numpy.ndarray` @@ -212,8 +233,10 @@ def omp( normalizecols=normalizecols, Opbasis=Opbasis, optimal_coeff=optimal_coeff, + engine=engine, show=show, itershow=itershow, + preallocate=preallocate, ) return x, niter_outer, cost @@ -235,6 +258,7 @@ def ista( show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), callback: Optional[Callable] = None, + preallocate: bool = False, ) -> Tuple[NDArray, int, NDArray]: r"""Iterative Shrinkage-Thresholding Algorithm (ISTA). @@ -289,6 +313,12 @@ def ista( callback : :obj:`callable`, optional Function with signature (``callback(x)``) to call after each iteration where ``x`` is the current model vector + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. 
Returns ------- @@ -340,6 +370,7 @@ def ista( monitorres=monitorres, show=show, itershow=itershow, + preallocate=preallocate, ) return x, iiter, cost @@ -361,6 +392,7 @@ def fista( show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), callback: Optional[Callable] = None, + preallocate: bool = False, ) -> Tuple[NDArray, int, NDArray]: r"""Fast Iterative Shrinkage-Thresholding Algorithm (FISTA). @@ -415,6 +447,12 @@ def fista( callback : :obj:`callable`, optional Function with signature (``callback(x)``) to call after each iteration where ``x`` is the current model vector + preallocate : :obj:`bool`, optional + .. versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. Returns ------- @@ -464,6 +502,7 @@ def fista( monitorres=monitorres, show=show, itershow=itershow, + preallocate=preallocate, ) return x, iiter, cost @@ -600,10 +639,12 @@ def splitbregman( tol: float = 1e-10, tau: float = 1.0, restart: bool = False, + engine: str = "scipy", show: bool = False, itershow: Tuple[int, int, int] = (10, 10, 10), show_inner: bool = False, callback: Optional[Callable] = None, + preallocate: bool = False, **kwargs_lsqr, ) -> Tuple[NDArray, int, NDArray]: r"""Split Bregman for mixed L2-L1 norms. @@ -653,6 +694,8 @@ def splitbregman( restart : :obj:`bool`, optional The unconstrained inverse problem in inner loop is initialized with the initial guess (``True``) or with the last estimate (``False``) + engine : :obj:`str`, optional + Solver to use (``scipy`` or ``pylops``) show : :obj:`bool`, optional Display iterations log itershow : :obj:`tuple`, optional @@ -664,6 +707,12 @@ def splitbregman( callback : :obj:`callable`, optional Function with signature (``callback(x)``) to call after each iteration where ``x`` is the current model vector + preallocate : :obj:`bool`, optional + .. 
versionadded:: 2.6.0 + + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. **kwargs_lsqr Arbitrary keyword arguments for :py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first @@ -700,6 +749,8 @@ def splitbregman( tol=tol, tau=tau, restart=restart, + engine=engine, + preallocate=preallocate, show=show, itershow=itershow, show_inner=show_inner, diff --git a/pylops/waveeqprocessing/seismicinterpolation.py b/pylops/waveeqprocessing/seismicinterpolation.py index e1159d4aa..3d18187e9 100644 --- a/pylops/waveeqprocessing/seismicinterpolation.py +++ b/pylops/waveeqprocessing/seismicinterpolation.py @@ -271,6 +271,8 @@ def SeismicInterpolation( Pop = FFT2D(dims=dims, nffts=nffts, sampling=sampling) Pop = Pop.H SIop = Rop * Pop + # Force data to be of same dtype of operator + data = data.astype(SIop.dtype) elif "chirpradon" in kind: prec = True dotcflag = 0 diff --git a/tutorials/solvers.py b/tutorials/solvers.py index c0bfe72c8..58037c41a 100755 --- a/tutorials/solvers.py +++ b/tutorials/solvers.py @@ -224,7 +224,7 @@ y, [D2op], epsRs=[np.sqrt(0.1)], - **dict(damp=np.sqrt(1e-4), iter_lim=50, show=0) + **dict(damp=np.sqrt(1e-4), iter_lim=50, show=0), )[0] ############################################################################### @@ -292,8 +292,8 @@ # \epsilon \|\mathbf{p}\|_1 # # where :math:`\mathbf{F}` is the FFT operator. We will thus use the -# :py:class:`pylops.optimization.sparsity.ista` and -# :py:class:`pylops.optimization.sparsity.fista` solvers to estimate our input +# :py:func:`pylops.optimization.sparsity.ista` and +# :py:func:`pylops.optimization.sparsity.fista` solvers to estimate our input # signal. pista, niteri, costi = pylops.optimization.sparsity.ista( @@ -347,7 +347,7 @@ # converges much faster than ISTA as expected and should be preferred when # using sparse solvers. 
# -# Finally we consider a slightly different cost function (note that in this +# We now consider a slightly different cost function (note that in this # case we try to solve a constrained problem): # # .. math:: @@ -356,7 +356,7 @@ # \mathbf{R} \mathbf{F} \mathbf{p}\| # # A very popular solver to solve such kind of cost function is called *spgl1* -# and can be accessed via :py:class:`pylops.optimization.sparsity.spgl1`. +# and can be accessed via :py:func:`pylops.optimization.sparsity.spgl1`. xspgl1, pspgl1, info = pylops.optimization.sparsity.spgl1( Rop, y, SOp=FFTop, tau=3, iter_lim=200 @@ -384,3 +384,35 @@ ax.legend() ax.grid(True) plt.tight_layout() + +############################################################################### +# Finally, we go back to the :py:func:`pylops.optimization.sparsity.ista` +# solver and show a new feature that was introduced in PyLops v2.6.0. +# When the solver is run with ``preallocate=True``, all internal vectors +# are pre-allocated in the ``setup`` method of the +# :py:class:`pylops.optimization.cls_sparsity.ISTA` class. This is likely to +# improve the performance of the solver (see :ref:`sphx_glr_tutorials_classsolvers.py` +# for more details), especially when it is applied to large problems. + +# Original ISTA +pista = pylops.optimization.sparsity.ista( + Rop * FFTop.H, + y, + niter=100, + alpha=1e-2, + eps=0.1, + tol=1e-7, +)[0] + +# ISTA with preallocation +pista_prealloc = pylops.optimization.sparsity.ista( + Rop * FFTop.H, + y, + niter=100, + alpha=1e-2, + eps=0.1, + tol=1e-7, + preallocate=True, +)[0] + +print(f"Norm of difference: {np.linalg.norm(pista - pista_prealloc)}")