diff --git a/docs/source/sg_execution_times.rst b/docs/source/sg_execution_times.rst
deleted file mode 100644
index 9641132d2..000000000
--- a/docs/source/sg_execution_times.rst
+++ /dev/null
@@ -1,271 +0,0 @@
-
-:orphan:
-
-.. _sphx_glr_sg_execution_times:
-
-
-Computation times
-=================
-**00:00.703** total execution time for 79 files **from all galleries**:
-
-.. container::
-
- .. raw:: html
-
-
-
-
-
-
-
- .. list-table::
- :header-rows: 1
- :class: table table-striped sg-datatable
-
- * - Example
- - Time
- - Mem (MB)
- * - :ref:`sphx_glr_tutorials_jaxop.py` (``../../tutorials/jaxop.py``)
- - 00:00.703
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_avo.py` (``../../examples/plot_avo.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_bayeslinearregr.py` (``../../examples/plot_bayeslinearregr.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_bilinear.py` (``../../examples/plot_bilinear.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_blending.py` (``../../examples/plot_blending.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_causalintegration.py` (``../../examples/plot_causalintegration.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_cgls.py` (``../../examples/plot_cgls.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_chirpradon.py` (``../../examples/plot_chirpradon.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_conj.py` (``../../examples/plot_conj.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_convolve.py` (``../../examples/plot_convolve.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_dct.py` (``../../examples/plot_dct.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_derivative.py` (``../../examples/plot_derivative.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_describe.py` (``../../examples/plot_describe.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_diagonal.py` (``../../examples/plot_diagonal.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_dtcwt.py` (``../../examples/plot_dtcwt.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_fft.py` (``../../examples/plot_fft.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_flip.py` (``../../examples/plot_flip.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_fourierradon.py` (``../../examples/plot_fourierradon.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_identity.py` (``../../examples/plot_identity.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_imag.py` (``../../examples/plot_imag.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_ista.py` (``../../examples/plot_ista.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_l1l1.py` (``../../examples/plot_l1l1.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_linearregr.py` (``../../examples/plot_linearregr.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_matrixmult.py` (``../../examples/plot_matrixmult.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_mdc.py` (``../../examples/plot_mdc.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_multiproc.py` (``../../examples/plot_multiproc.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_nmo.py` (``../../examples/plot_nmo.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_nonstatconvolve.py` (``../../examples/plot_nonstatconvolve.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_nonstatfilter.py` (``../../examples/plot_nonstatfilter.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_pad.py` (``../../examples/plot_pad.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_patching.py` (``../../examples/plot_patching.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_phaseshift.py` (``../../examples/plot_phaseshift.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_prestack.py` (``../../examples/plot_prestack.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_radon.py` (``../../examples/plot_radon.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_real.py` (``../../examples/plot_real.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_regr.py` (``../../examples/plot_regr.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_restriction.py` (``../../examples/plot_restriction.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_roll.py` (``../../examples/plot_roll.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_seislet.py` (``../../examples/plot_seislet.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_seismicevents.py` (``../../examples/plot_seismicevents.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_shift.py` (``../../examples/plot_shift.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_sliding.py` (``../../examples/plot_sliding.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_slopeest.py` (``../../examples/plot_slopeest.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_smoothing1d.py` (``../../examples/plot_smoothing1d.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_smoothing2d.py` (``../../examples/plot_smoothing2d.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_spread.py` (``../../examples/plot_spread.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_stacking.py` (``../../examples/plot_stacking.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_sum.py` (``../../examples/plot_sum.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_symmetrize.py` (``../../examples/plot_symmetrize.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_tapers.py` (``../../examples/plot_tapers.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_tndarray.py` (``../../examples/plot_tndarray.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_transpose.py` (``../../examples/plot_transpose.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_tvreg.py` (``../../examples/plot_tvreg.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_twoway.py` (``../../examples/plot_twoway.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_wavelet.py` (``../../examples/plot_wavelet.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_wavest.py` (``../../examples/plot_wavest.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_wavs.py` (``../../examples/plot_wavs.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_gallery_plot_zero.py` (``../../examples/plot_zero.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_bayesian.py` (``../../tutorials/bayesian.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_classsolvers.py` (``../../tutorials/classsolvers.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_ctscan.py` (``../../tutorials/ctscan.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_deblending.py` (``../../tutorials/deblending.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_deblurring.py` (``../../tutorials/deblurring.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_deghosting.py` (``../../tutorials/deghosting.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_dottest.py` (``../../tutorials/dottest.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_ilsm.py` (``../../tutorials/ilsm.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_interpolation.py` (``../../tutorials/interpolation.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_linearoperator.py` (``../../tutorials/linearoperator.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_lsm.py` (``../../tutorials/lsm.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_marchenko.py` (``../../tutorials/marchenko.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_mdd.py` (``../../tutorials/mdd.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_poststack.py` (``../../tutorials/poststack.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_prestack.py` (``../../tutorials/prestack.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_radonfiltering.py` (``../../tutorials/radonfiltering.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_realcomplex.py` (``../../tutorials/realcomplex.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_seismicinterpolation.py` (``../../tutorials/seismicinterpolation.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_solvers.py` (``../../tutorials/solvers.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_torchop.py` (``../../tutorials/torchop.py``)
- - 00:00.000
- - 0.0
- * - :ref:`sphx_glr_tutorials_wavefielddecomposition.py` (``../../tutorials/wavefielddecomposition.py``)
- - 00:00.000
- - 0.0
diff --git a/pylops/optimization/basesolver.py b/pylops/optimization/basesolver.py
index ded819424..490abcf3b 100644
--- a/pylops/optimization/basesolver.py
+++ b/pylops/optimization/basesolver.py
@@ -1,6 +1,7 @@
__all__ = ["Solver"]
import functools
+import logging
import time
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Any
@@ -11,6 +12,8 @@
if TYPE_CHECKING:
from pylops.linearoperator import LinearOperator
+_units = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3}
+
class Solver(metaclass=ABCMeta):
r"""Solver
@@ -19,6 +22,8 @@ class Solver(metaclass=ABCMeta):
This class comprises of the following mandatory methods:
- ``__init__``: initialization method to which the operator `Op` must be passed
+ - ``memory_usage``: a method to compute upfront the memory used by each
+ step of the solver
- ``setup``: a method that is invoked to setup the solver, basically it will create
anything required prior to applying a step of the solver
- ``step``: a method applying a single step of the solver
@@ -121,11 +126,55 @@ def wrapper(*args, **kwargs):
),
)
+ def _setpreallocate(self, preallocate: bool) -> None:
+ # Check if the solver can work in preallocate mode
+ # (basically all the time except when JAX arrays are
+ # used) and force it to be False otherwise.
+ self.preallocate = preallocate if not self.isjax else False
+
+ if preallocate and self.isjax:
+ logging.warning(
+ "Preallocation is not supported for JAX arrays. "
+ "Setting preallocate to False."
+ )
+
+ @abstractmethod
+ def memory_usage(
+ self,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ This method computes an estimate of the memory required by the solver given
+ the shape of the operator. This is useful to assess upfront if the solver
+ will run out of memory.
+
+ Note, that the memory usage of the operator itself is not taken into account
+ in this estimate.
+
+ Parameters
+ ----------
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in bytes
+
+ """
+ pass
+
@abstractmethod
def setup(
self,
y: NDArray,
*args,
+ preallocate: bool = False,
show: bool = False,
**kwargs,
) -> None:
@@ -138,6 +187,10 @@ def setup(
----------
y : :obj:`np.ndarray`
Data of size :math:`[N \times 1]`
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver.
show : :obj:`bool`, optional
Display setup log
diff --git a/pylops/optimization/basic.py b/pylops/optimization/basic.py
index 34b03bb71..f8b060d2d 100644
--- a/pylops/optimization/basic.py
+++ b/pylops/optimization/basic.py
@@ -24,6 +24,7 @@ def cg(
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
+ preallocate: bool = False,
) -> Tuple[NDArray, int, NDArray]:
r"""Conjugate gradient
@@ -51,6 +52,10 @@ def cg(
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver
Returns
-------
@@ -70,7 +75,13 @@ def cg(
if callback is not None:
cgsolve.callback = callback
x, iiter, cost = cgsolve.solve(
- y=y, x0=x0, tol=tol, niter=niter, show=show, itershow=itershow
+ y=y,
+ x0=x0,
+ tol=tol,
+ niter=niter,
+ show=show,
+ itershow=itershow,
+ preallocate=preallocate,
)
return x, iiter, cost
@@ -86,6 +97,7 @@ def cgls(
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
+ preallocate: bool = False,
) -> Tuple[NDArray, int, int, float, float, NDArray]:
r"""Conjugate gradient least squares
@@ -115,6 +127,10 @@ def cgls(
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.5.0
+
+ Pre-allocate all variables used by the solver
Returns
-------
@@ -149,7 +165,14 @@ def cgls(
if callback is not None:
cgsolve.callback = callback
x, istop, iiter, r1norm, r2norm, cost = cgsolve.solve(
- y=y, x0=x0, tol=tol, niter=niter, damp=damp, show=show, itershow=itershow
+ y=y,
+ x0=x0,
+ tol=tol,
+ niter=niter,
+ damp=damp,
+ show=show,
+ itershow=itershow,
+ preallocate=preallocate,
)
return x, istop, iiter, r1norm, r2norm, cost
@@ -168,6 +191,7 @@ def lsqr(
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
+ preallocate: bool = False,
) -> Tuple[NDArray, int, int, float, float, float, float, float, float, float, NDArray]:
r"""LSQR
@@ -213,6 +237,10 @@ def lsqr(
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver
Returns
-------
@@ -299,5 +327,6 @@ def lsqr(
calc_var=calc_var,
show=show,
itershow=itershow,
+ preallocate=preallocate,
)
return x, istop, iiter, r1norm, r2norm, anorm, acond, arnorm, xnorm, var, cost
diff --git a/pylops/optimization/cls_basic.py b/pylops/optimization/cls_basic.py
index 1fb1ee7da..0713773ea 100644
--- a/pylops/optimization/cls_basic.py
+++ b/pylops/optimization/cls_basic.py
@@ -9,10 +9,10 @@
import numpy as np
-from pylops.optimization.basesolver import Solver
+from pylops.optimization.basesolver import Solver, _units
from pylops.utils.backend import (
get_array_module,
- to_cupy_conditional,
+ get_module_name,
to_numpy,
to_numpy_conditional,
)
@@ -64,12 +64,48 @@ def _print_step(self, x: NDArray) -> None:
msg = f"{self.iiter:6g} " + strx + f"{self.cost[self.iiter]:11.4e}"
print(msg)
+ def memory_usage(
+ self,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ Parameters
+ ----------
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: x0 - y, self.r, self.c
+ memuse = (self.Op.shape[1] + 3 * self.Op.shape[0]) * nbytes
+
+ # Step (additional variables to those in setup): c1 - Opc
+ memuse += (self.Op.shape[1] + self.Op.shape[0]) * nbytes
+
+ if show:
+ print(f"CG predicted memory usage: {memuse / _units[unit]:.2f} {unit}")
+
+ return memuse
+
def setup(
self,
y: NDArray,
x0: Optional[NDArray] = None,
niter: Optional[int] = None,
tol: float = 1e-4,
+ preallocate: bool = False,
show: bool = False,
) -> NDArray:
r"""Setup solver
@@ -86,6 +122,13 @@ def setup(
manually step over the solver)
tol : :obj:`float`, optional
Tolerance on residual norm
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
+
show : :obj:`bool`, optional
Display setup log
@@ -98,7 +141,10 @@ def setup(
self.y = y
self.niter = niter
self.tol = tol
+
self.ncp = get_array_module(y)
+ self.isjax = get_module_name(self.ncp) == "jax"
+ self._setpreallocate(preallocate)
# initialize solver
if x0 is None:
@@ -106,10 +152,18 @@ def setup(
self.r = self.y.copy()
else:
x = x0.copy()
- self.r = self.y - self.Op.matvec(x)
+ if not self.preallocate:
+ self.r = self.y - self.Op.matvec(x)
+ else:
+ self.r = self.ncp.empty_like(self.y)
+ self.ncp.subtract(self.y, self.Op.matvec(x), out=self.r)
self.c = self.r.copy()
self.kold = self.ncp.abs(self.r.dot(self.r.conj()))
+ # initialize other internal variabled
+ if self.preallocate:
+ self.c1 = self.ncp.empty_like(x)
+
# create variables to track the residual norm and iterations
self.cost: List = []
self.cost.append(float(np.sqrt(self.kold)))
@@ -136,14 +190,24 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
Updated model vector
"""
- Opc = self.Op.matvec(to_cupy_conditional(x, self.c))
+ Opc = self.Op.matvec(self.c)
cOpc = self.ncp.abs(self.c.dot(Opc.conj()))
a = self.kold / cOpc
- x += to_cupy_conditional(x, a) * to_cupy_conditional(x, self.c)
- self.r -= a * Opc
+ if not self.preallocate:
+ x += a * self.c
+ self.r -= a * Opc
+ else:
+ self.ncp.multiply(self.c, a, out=self.c1)
+ self.ncp.add(x, self.c1, out=x)
+ self.ncp.multiply(Opc, a, out=Opc)
+ self.ncp.subtract(self.r, Opc, out=self.r)
k = self.ncp.abs(self.r.dot(self.r.conj()))
b = k / self.kold
- self.c = self.r + b * self.c
+ if not self.preallocate:
+ self.c = self.r + b * self.c
+ else:
+ self.ncp.multiply(self.c, b, out=self.c)
+ self.ncp.add(self.c, self.r, out=self.c)
self.kold = k
self.iiter += 1
self.cost.append(float(np.sqrt(self.kold)))
@@ -219,6 +283,7 @@ def solve(
x0: Optional[NDArray] = None,
niter: int = 10,
tol: float = 1e-4,
+ preallocate: bool = False,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
) -> Tuple[NDArray, int, NDArray]:
@@ -235,6 +300,12 @@ def solve(
Number of iterations
tol : :obj:`float`, optional
Tolerance on residual norm
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display logs
itershow : :obj:`tuple`, optional
@@ -252,7 +323,9 @@ def solve(
History of the L2 norm of the residual
"""
- x = self.setup(y=y, x0=x0, niter=niter, tol=tol, show=show)
+ x = self.setup(
+ y=y, x0=x0, niter=niter, tol=tol, preallocate=preallocate, show=show
+ )
x = self.run(x, niter, show=show, itershow=itershow)
self.finalize(show)
return x, self.iiter, self.cost
@@ -307,6 +380,41 @@ def _print_step(self, x: NDArray) -> None:
)
print(msg)
+ def memory_usage(
+ self,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ Parameters
+ ----------
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: x0, self.c - y, self.s, self.q
+ memuse = (2 * self.Op.shape[1] + 3 * self.Op.shape[0]) * nbytes
+
+ # Step (additional variables to those in setup): r, x1, c1
+ memuse += (3 * self.Op.shape[1]) * nbytes
+
+ if show:
+ print(f"CGLS predicted memory usage: {memuse / _units[unit]:.2f} {unit}")
+
+ return memuse
+
def setup(
self,
y: NDArray,
@@ -314,6 +422,7 @@ def setup(
niter: Optional[int] = None,
damp: float = 0.0,
tol: float = 1e-4,
+ preallocate: bool = False,
show: bool = False,
) -> NDArray:
r"""Setup solver
@@ -323,7 +432,7 @@ def setup(
y : :obj:`np.ndarray`
Data of size :math:`[N \times 1]`
x0 : :obj:`np.ndarray`, optional
- Initial guess of size :math:`[M \times 1]`. If ``None``, initialize
+ Initial guess of size :math:`[M \times 1]`. If ``None``, initialize
internally as zero vector
niter : :obj:`int`, optional
Number of iterations (default to ``None`` in case a user wants to
@@ -332,6 +441,12 @@ def setup(
Damping coefficient
tol : :obj:`float`, optional
Tolerance on residual norm
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display setup log
@@ -345,20 +460,36 @@ def setup(
self.damp = damp**2
self.tol = tol
self.niter = niter
+
self.ncp = get_array_module(y)
+ self.isjax = get_module_name(self.ncp) == "jax"
+ self._setpreallocate(preallocate)
# initialize solver
if x0 is None:
x = self.ncp.zeros(self.Op.shape[1], dtype=y.dtype)
self.s = self.y.copy()
- r = self.Op.rmatvec(self.s)
+ self.c = self.Op.rmatvec(self.s)
else:
x = x0.copy()
- self.s = self.y - self.Op.matvec(x)
- r = self.Op.rmatvec(self.s) - damp * x
- self.c = r.copy()
+ if not self.preallocate:
+ self.s = self.y - self.Op.matvec(x)
+ self.c = self.Op.rmatvec(self.s) - damp * x
+ else:
+ self.s = self.ncp.empty_like(self.y)
+ self.ncp.subtract(self.y, self.Op.matvec(x), out=self.s)
+ x1 = self.ncp.empty_like(x)
+ self.c = self.ncp.empty_like(x)
+ self.ncp.multiply(x, damp, out=x1)
+ self.ncp.subtract(self.Op.rmatvec(self.s), x1, out=self.c)
self.q = self.Op.matvec(self.c)
- self.kold = self.ncp.abs(r.dot(r.conj()))
+ self.kold = self.ncp.abs(self.c.dot(self.c.conj()))
+
+ # initialize other internal variables
+ if self.preallocate:
+ self.c1 = self.ncp.empty_like(self.c)
+ self.x1 = self.ncp.empty_like(x)
+ self.r = self.ncp.empty_like(x)
# create variables to track the residual norm and iterations
self.cost = []
@@ -390,12 +521,32 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
a = self.kold / (
self.q.dot(self.q.conj()) + self.damp * self.c.dot(self.c.conj())
)
- x = x + a * self.c
- self.s = self.s - to_numpy_conditional(self.q, a) * self.q
- r = self.Op.rmatvec(self.s) - self.damp * x
- k = self.ncp.abs(r.dot(r.conj()))
+ if not self.preallocate:
+ x = x + a * self.c
+ self.s = self.s - a * self.q
+ r = self.Op.rmatvec(self.s) - self.damp * x
+ else:
+ self.ncp.multiply(self.c, a, out=self.c1)
+ self.ncp.add(x, self.c1, out=x)
+
+ self.ncp.multiply(self.q, a, out=self.q)
+ self.ncp.subtract(self.s, self.q, out=self.s)
+
+ self.ncp.multiply(x, self.damp, out=self.x1)
+ self.ncp.subtract(
+ self.Op.rmatvec(self.s),
+ self.x1,
+ out=self.r,
+ )
+ k = self.ncp.abs(
+ self.r.dot(self.r.conj()) if self.preallocate else r.dot(r.conj())
+ )
b = k / self.kold
- self.c = r + b * self.c
+ if not self.preallocate:
+ self.c = r + b * self.c
+ else:
+ self.ncp.multiply(self.c, b, out=self.c)
+ self.ncp.add(self.c, self.r, out=self.c)
self.q = self.Op.matvec(self.c)
self.kold = k
self.iiter += 1
@@ -485,6 +636,7 @@ def solve(
niter: int = 10,
damp: float = 0.0,
tol: float = 1e-4,
+ preallocate: bool = False,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
) -> Tuple[NDArray, int, int, float, float, NDArray]:
@@ -504,6 +656,12 @@ def solve(
Damping coefficient
tol : :obj:`float`, optional
Tolerance on residual norm
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display logs
itershow : :obj:`tuple`, optional
@@ -536,7 +694,15 @@ def solve(
History of r1norm through iterations
"""
- x = self.setup(y=y, x0=x0, niter=niter, damp=damp, tol=tol, show=show)
+ x = self.setup(
+ y=y,
+ x0=x0,
+ niter=niter,
+ damp=damp,
+ tol=tol,
+ preallocate=preallocate,
+ show=show,
+ )
x = self.run(x, niter, show=show, itershow=itershow)
self.finalize(show)
return x, self.istop, self.iiter, self.r1norm, self.r2norm, self.cost
@@ -631,6 +797,41 @@ def _print_finalize(self) -> None:
print(str5)
print("-" * 90 + "\n")
+ def memory_usage(
+ self,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ Parameters
+ ----------
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: x0, self.v, self.w, self.dk - y, self.u
+ memuse = (4 * self.Op.shape[1] + 2 * self.Op.shape[0]) * nbytes
+
+ # Step (additional variables to those in setup): w1
+ memuse += self.Op.shape[1] * nbytes
+
+ if show:
+ print(f"LSQR predicted memory usage: {memuse / _units[unit]:.2f} {unit}")
+
+ return memuse
+
def setup(
self,
y: NDArray,
@@ -641,6 +842,7 @@ def setup(
conlim: float = 100000000.0,
niter: int = 10,
calc_var: bool = True,
+ preallocate: bool = False,
show: bool = False,
) -> NDArray:
r"""Setup solver
@@ -670,7 +872,13 @@ def setup(
Number of iterations
calc_var : :obj:`bool`, optional
Estimate diagonals of :math:`(\mathbf{Op}^H\mathbf{Op} +
- \epsilon^2\mathbf{I})^{-1}`.
+ \epsilon^2\mathbf{I})^{-1}`
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display setup log
@@ -687,7 +895,10 @@ def setup(
self.conlim = conlim
self.niter = niter
self.calc_var = calc_var
+
self.ncp = get_array_module(y)
+ self.isjax = get_module_name(self.ncp) == "jax"
+ self._setpreallocate(preallocate)
m, n = self.Op.shape
@@ -719,15 +930,25 @@ def setup(
self.u = y.copy()
else:
x = x0.copy()
- self.u = self.y - self.Op.matvec(x0)
+ if not self.preallocate:
+ self.u = self.y - self.Op.matvec(x0)
+ else:
+ self.u = self.ncp.empty_like(self.y)
+ self.ncp.subtract(self.y, self.Op.matvec(x0), out=self.u)
self.alfa = 0.0
self.beta = self.ncp.linalg.norm(self.u)
if self.beta > 0.0:
- self.u = self.u / self.beta
+ if not self.preallocate:
+ self.u = self.u / self.beta
+ else:
+ self.ncp.divide(self.u, self.beta, out=self.u)
self.v = self.Op.rmatvec(self.u)
self.alfa = self.ncp.linalg.norm(self.v)
if self.alfa > 0:
- self.v = self.v / self.alfa
+ if not self.preallocate:
+ self.v = self.v / self.alfa
+ else:
+ self.ncp.divide(self.v, self.alfa, out=self.v)
else:
self.v = x.copy()
self.alfa = 0
@@ -736,6 +957,11 @@ def setup(
# check if solution is already found
self.arnorm: float = self.alfa * self.beta
+ # initialize other internal variables
+ if self.preallocate:
+ self.dk = self.ncp.empty_like(self.w)
+ self.w1 = self.ncp.empty_like(self.w)
+
# finalize setup
self.arnorm0: float = self.arnorm
self.rhobar: float = self.alfa
@@ -778,19 +1004,31 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
# next beta, u, alfa, v. These satisfy the relations
# beta*u = Op*v - alfa*u,
# alfa*v = Op'*u - beta*v'
- self.u = (
- self.Op.matvec(self.v) - to_numpy_conditional(self.u, self.alfa) * self.u
- )
+ if not self.preallocate:
+ self.u = self.Op.matvec(self.v) - self.alfa * self.u
+ else:
+ self.ncp.multiply(self.u, self.alfa, out=self.u)
+ self.ncp.subtract(self.Op.matvec(self.v), self.u, out=self.u)
self.beta = self.ncp.linalg.norm(self.u)
if self.beta > 0:
- self.u = self.u / self.beta
+ if not self.preallocate:
+ self.u = self.u / self.beta
+ else:
+ self.ncp.divide(self.u, self.beta, out=self.u)
self.anorm = np.linalg.norm(
[self.anorm, to_numpy(self.alfa), to_numpy(self.beta), self.damp]
)
- self.v = self.Op.rmatvec(self.u) - self.beta * self.v
+ if not self.preallocate:
+ self.v = self.Op.rmatvec(self.u) - self.beta * self.v
+ else:
+ self.ncp.multiply(self.v, self.beta, out=self.v)
+ self.ncp.subtract(self.Op.rmatvec(self.u), self.v, out=self.v)
self.alfa = self.ncp.linalg.norm(self.v)
if self.alfa > 0:
- self.v = self.v / self.alfa
+ if not self.preallocate:
+ self.v = self.v / self.alfa
+ else:
+ self.ncp.divide(self.v, self.alfa, out=self.v)
# use a plane rotation to eliminate the damping parameter.
# This alters the diagonal (rhobar) of the lower-bidiagonal matrix.
@@ -814,9 +1052,16 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
# update x and w.
self.t1 = self.phi / self.rho
self.t2 = -self.theta / self.rho
- self.dk = self.w / self.rho
- x = x + self.t1 * self.w
- self.w = self.v + self.t2 * self.w
+ if not self.preallocate:
+ self.dk = self.w / self.rho
+ x = x + self.t1 * self.w
+ self.w = self.v + self.t2 * self.w
+ else:
+ self.ncp.divide(self.w, self.rho, out=self.dk)
+ self.ncp.multiply(self.w, self.t1, out=self.w1)
+ self.ncp.add(x, self.w1, out=x)
+ self.ncp.multiply(self.w, self.t2, out=self.w)
+ self.ncp.add(self.v, self.w, out=self.w)
self.ddnorm = self.ddnorm + self.ncp.linalg.norm(self.dk) ** 2
if self.calc_var:
self.var = self.var + to_numpy_conditional(
@@ -965,6 +1210,7 @@ def solve(
conlim: float = 100000000.0,
niter: int = 10,
calc_var: bool = True,
+ preallocate: bool = False,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
) -> Tuple[
@@ -1008,6 +1254,12 @@ def solve(
calc_var : :obj:`bool`, optional
Estimate diagonals of :math:`(\mathbf{Op}^H\mathbf{Op} +
\epsilon^2\mathbf{I})^{-1}`.
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display logs
itershow : :obj:`tuple`, optional
@@ -1079,6 +1331,7 @@ def solve(
conlim=conlim,
niter=niter,
calc_var=calc_var,
+ preallocate=preallocate,
show=show,
)
x = self.run(x, niter=niter, show=show, itershow=itershow)
diff --git a/pylops/optimization/cls_leastsquares.py b/pylops/optimization/cls_leastsquares.py
index e824c86f2..8d66867bd 100644
--- a/pylops/optimization/cls_leastsquares.py
+++ b/pylops/optimization/cls_leastsquares.py
@@ -13,7 +13,7 @@
from scipy.sparse.linalg import lsqr
from pylops.basicoperators import Diagonal, VStack
-from pylops.optimization.basesolver import Solver
+from pylops.optimization.basesolver import Solver, _units
from pylops.optimization.basic import cg, cgls
from pylops.utils.backend import get_array_module
from pylops.utils.typing import NDArray
@@ -88,6 +88,53 @@ def _print_finalize(self) -> None:
print(f"\nTotal time (s) = {self.telapsed:.2f}")
print("-" * 55 + "\n")
+ def memory_usage(
+ self,
+ nopRegs: Optional[Tuple[int]] = None,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ Parameters
+ ----------
+ nopRegs : :obj:`tuple`, optional
+ Number of data elements of ``Regs`` operators
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Convert nopRegs if None
+ if nopRegs is None:
+ nopRegs = 0
+
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: y_normal, data_regs (temporary as these are
+ # later projected and summed to y_normal)
+ memuse = (self.Op.shape[1] + np.prod(nopRegs)) * nbytes
+
+ # Run (additional variables to those in setup): Setup and Step
+ # of CG solver on normal equations
+ memuse += (self.Op.shape[1] + 3 * self.Op.shape[1]) * nbytes
+ memuse += (2 * self.Op.shape[1]) * nbytes
+
+ if show:
+ print(
+ f"NormalEquationsInversion predicted memory usage: {memuse / _units[unit]:.2f} {unit}"
+ )
+
+ return memuse
+
def setup(
self,
y: NDArray,
@@ -235,27 +282,23 @@ def run(
``<0``: illegal input or breakdown
"""
- if x is not None:
- self.y_normal = self.y_normal - self.Op_normal.matvec(x)
if engine == "scipy" and self.ncp == np:
if "tol" in kwargs_solver:
kwargs_solver["atol"] = kwargs_solver["tol"]
kwargs_solver.pop("tol")
- xinv, istop = sp_cg(self.Op_normal, self.y_normal, **kwargs_solver)
+ xinv, istop = sp_cg(self.Op_normal, self.y_normal, x0=x, **kwargs_solver)
elif engine == "pylops" or self.ncp != np:
if show:
kwargs_solver["show"] = True
xinv = cg(
self.Op_normal,
self.y_normal,
- x0=self.ncp.zeros(self.Op_normal.shape[1], dtype=self.Op_normal.dtype),
+ x0=x,
**kwargs_solver,
)[0]
istop = None
else:
raise NotImplementedError("Engine must be scipy or pylops")
- if x is not None:
- xinv = x + xinv
return xinv, istop
def solve(
@@ -448,6 +491,60 @@ def _print_finalize(self) -> None:
print(f"\nTotal time (s) = {self.telapsed:.2f}")
print("-" * 65 + "\n")
+ def memory_usage(
+ self,
+ nopRegs: Optional[Tuple[int]] = None,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ .. note:: The memory usage is computed assuming that
+ ``engine="pylops"`` is used. When ``engine="scipy"`` is
+ used instead, :func:`scipy.sparse.linalg.lsqr` may consume
+ more memory.
+
+ Parameters
+ ----------
+ nopRegs : :obj:`tuple`, optional
+ Number of data elements of ``Regs`` operators
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Convert nopRegs if None
+ if nopRegs is None:
+ nopRegs = 0
+
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: datatot
+ ndatatot = self.Op.shape[0] + np.prod(nopRegs)
+ memuse = ndatatot * nbytes
+
+ # Run (additional variables to those in setup): Setup and Step
+ # of PyLops CGLS solver on augumented equations. Note that when
+ # engine="scipy", the SciPy LSQR solver is used instead (which
+ # may consume more memory).
+ memuse += (2 * self.Op.shape[1] + 3 * ndatatot) * nbytes
+ memuse += (3 * self.Op.shape[1]) * nbytes
+
+ if show:
+ print(
+ f"RegularizedInversion predicted memory usage: {memuse / _units[unit]:.2f} {unit}"
+ )
+
+ return memuse
+
def setup(
self,
y: NDArray,
@@ -520,7 +617,7 @@ def setup(
# augumented operator
if self.epsRs is not None and self.dataregs is not None:
for epsR, datareg in zip(self.epsRs, self.dataregs):
- self.datatot = np.hstack((self.datatot, epsR * datareg))
+ self.datatot = self.ncp.hstack((self.datatot, epsR * datareg))
# print setup
if show:
@@ -580,13 +677,11 @@ def run(
Equal to ``r1norm`` if :math:`\epsilon=0`
"""
- if x is not None:
- self.datatot = self.datatot - self.RegOp.matvec(x)
if engine == "scipy" and self.ncp == np:
if show:
kwargs_solver["show"] = 1
xinv, istop, itn, r1norm, r2norm = lsqr(
- self.RegOp, self.datatot, **kwargs_solver
+ self.RegOp, self.datatot, x0=x, **kwargs_solver
)[0:5]
elif engine == "pylops" or self.ncp != np:
if show:
@@ -594,13 +689,11 @@ def run(
xinv, istop, itn, r1norm, r2norm = cgls(
self.RegOp,
self.datatot,
- x0=self.ncp.zeros(self.RegOp.shape[1], dtype=self.RegOp.dtype),
+ x0=x,
**kwargs_solver,
)[0:5]
else:
raise NotImplementedError("Engine must be scipy or pylops")
- if x is not None:
- xinv = x + xinv
return xinv, istop, itn, r1norm, r2norm
def solve(
@@ -717,6 +810,52 @@ def _print_finalize(self) -> None:
print(f"\nTotal time (s) = {self.telapsed:.2f}")
print("-" * 65 + "\n")
+ def memory_usage(
+ self,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ .. note:: The memory usage is computed assuming that
+ ``engine="pylops"`` is used. When ``engine="scipy"`` is
+ used instead, :func:`scipy.sparse.linalg.lsqr` may consume
+ more memory.
+
+ Parameters
+ ----------
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: y
+ memuse = self.Op.shape[0] * nbytes
+
+ # Run (additional variables to those in setup): Setup and Step
+ # of PyLops CGLS solver on augumented equations. Note that when
+ # engine="scipy", the SciPy LSQR solver is used instead (which
+ # may consume more memory).
+ memuse += (2 * self.Op.shape[1] + 3 * self.Op.shape[0]) * nbytes
+ memuse += (3 * self.Op.shape[1]) * nbytes
+
+ if show:
+ print(
+ f"PreconditionedInversion predicted memory usage: {memuse / _units[unit]:.2f} {unit}"
+ )
+
+ return memuse
+
def setup(
self,
y: NDArray,
@@ -800,14 +939,13 @@ def run(
Equal to ``r1norm`` if :math:`\epsilon=0`
"""
- if x is not None:
- self.y = self.y - self.Op.matvec(x)
if engine == "scipy" and self.ncp == np:
if show:
kwargs_solver["show"] = 1
pinv, istop, itn, r1norm, r2norm = lsqr(
self.POp,
self.y,
+ x0=x,
**kwargs_solver,
)[0:5]
elif engine == "pylops" or self.ncp != np:
@@ -816,7 +954,7 @@ def run(
pinv, istop, itn, r1norm, r2norm = cgls(
self.POp,
self.y,
- x0=self.ncp.zeros(self.POp.shape[1], dtype=self.POp.dtype),
+ x0=x,
**kwargs_solver,
)[0:5]
# force it 1d as we decorate this method with disable_ndarray_multiplication
@@ -824,8 +962,6 @@ def run(
else:
raise NotImplementedError("Engine must be scipy or pylops")
xinv = self.P.matvec(pinv)
- if x is not None:
- xinv = x + xinv
return xinv, istop, itn, r1norm, r2norm
def solve(
diff --git a/pylops/optimization/cls_sparsity.py b/pylops/optimization/cls_sparsity.py
index ceb4c6700..b0bff8638 100644
--- a/pylops/optimization/cls_sparsity.py
+++ b/pylops/optimization/cls_sparsity.py
@@ -9,6 +9,7 @@
import logging
import time
+from math import sqrt
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import numpy as np
@@ -16,12 +17,12 @@
from pylops import LinearOperator
from pylops.basicoperators import Diagonal, Identity, VStack
-from pylops.optimization.basesolver import Solver
+from pylops.optimization.basesolver import Solver, _units
from pylops.optimization.basic import cgls
from pylops.optimization.eigs import power_iteration
from pylops.optimization.leastsquares import regularized_inversion
from pylops.utils import deps
-from pylops.utils.backend import get_array_module, get_module_name
+from pylops.utils.backend import get_array_module, get_module_name, inplace_set
from pylops.utils.typing import InputDimsLike, NDArray, SamplingLike
spgl1_message = deps.spgl1_import("the spgl1 solver")
@@ -29,6 +30,8 @@
if spgl1_message is None:
from spgl1 import spgl1 as ext_spgl1
+logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)
+
def _hardthreshold(x: NDArray, thresh: float) -> NDArray:
r"""Hard thresholding.
@@ -54,7 +57,7 @@ def _hardthreshold(x: NDArray, thresh: float) -> NDArray:
"""
x1 = x.copy()
- x1[np.abs(x) <= np.sqrt(2 * thresh)] = 0
+ x1[np.abs(x) <= sqrt(2 * thresh)] = 0
return x1
@@ -324,6 +327,51 @@ def _print_step(self, x: NDArray) -> None:
str2 = f" {self.rnorm:10.3e}"
print(str1 + str2)
+ def memory_usage(
+ self,
+ kind: str = "data",
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ Parameters
+ ----------
+ kind : :obj:`str`, optional
+ Kind of solver (``model``, ``data`` or ``datamodel``)
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: y + augmented y if kind=datamodel
+ memuse = self.Op.shape[0] * nbytes
+ if kind == "datamodel":
+ memuse += self.Op.shape[1] * nbytes
+
+ # Step (additional variables to those in setup): rw
+ if kind == "data":
+ memuse += self.Op.shape[0] * nbytes
+ elif kind == "model":
+ memuse += self.Op.shape[1] * nbytes
+ elif kind == "datamodel":
+ memuse += 2 * self.Op.shape[0] * nbytes
+
+ if show:
+ print(f"IRLS predicted memory usage: {memuse / _units[unit]:.2f} {unit}")
+
+ return memuse
+
def setup(
self,
y: NDArray,
@@ -334,6 +382,7 @@ def setup(
tolIRLS: float = 1e-10,
warm: bool = False,
kind: str = "data",
+ preallocate: bool = False,
show: bool = False,
) -> None:
r"""Setup solver
@@ -360,6 +409,12 @@ def setup(
This only applies to ``kind="data"`` and ``kind="datamodel"``
kind : :obj:`str`, optional
Kind of solver (``model``, ``data`` or ``datamodel``)
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display setup log
@@ -372,7 +427,12 @@ def setup(
self.tolIRLS = tolIRLS
self.warm = warm
self.kind = kind
+
self.ncp = get_array_module(y)
+ self.isjax = get_module_name(self.ncp) == "jax"
+ self._setpreallocate(preallocate)
+
+ # initiate outer iteration counter
self.iiter = 0
# choose step to use
@@ -386,33 +446,61 @@ def setup(
# augment Op and y
self.Op = VStack([self.Op, epsI * Identity(self.Op.shape[1])])
self.epsI = 0.0 # as epsI is added to the augmented system already
- self.y = np.hstack([self.y, np.zeros(self.Op.shape[1])])
+ self.y = self.ncp.hstack([self.y, self.ncp.zeros(self.Op.shape[1])])
else:
raise NotImplementedError("kind must be model, data or datamodel")
+ if self.preallocate:
+ self.r = self.ncp.empty_like(self.y)
+ if "data" in self.kind:
+ self.rw = self.ncp.empty_like(self.y)
+ else:
+ self.rw = self.ncp.empty(self.Op.shape[1], dtype=self.Op.dtype)
# print setup
if show:
self._print_setup()
- def _step_data(self, x: NDArray, **kwargs_solver) -> NDArray:
+ def _step_data(self, x: NDArray, engine: str = "scipy", **kwargs_solver) -> NDArray:
r"""Run one step of solver with L1 data term"""
+ # add preallocate to keywords of solver
+ if self.preallocate and (engine == "pylops" or self.ncp != np):
+ kwargs_solver["preallocate"] = True
if self.iiter == 0:
+ # first iteration (standard least-squares)
x = regularized_inversion(
self.Op,
self.y,
None,
x0=x if self.warm else None,
damp=self.epsI,
+ engine=engine,
**kwargs_solver,
)[0]
else:
# other iterations (weighted least-squares)
- if self.threshR:
- self.rw = 1.0 / self.ncp.maximum(self.ncp.abs(self.r), self.epsR)
+ if self.preallocate and self.iiter == 1:
+ self.rw = self.ncp.zeros_like(self.y)
+
+ if not self.preallocate:
+ if self.threshR:
+ self.rw = 1.0 / self.ncp.maximum(self.ncp.abs(self.r), self.epsR)
+ else:
+ self.rw = 1.0 / (self.ncp.abs(self.r) + self.epsR)
+ self.rw = self.rw / self.rw.max()
else:
- self.rw = 1.0 / (self.ncp.abs(self.r) + self.epsR)
- self.rw = self.rw / self.rw.max()
- R = Diagonal(np.sqrt(self.rw))
+ if self.threshR:
+ self.ncp.divide(
+ 1.0,
+ self.ncp.maximum(self.ncp.abs(self.r), self.epsR),
+ out=self.rw,
+ )
+ else:
+ self.ncp.divide(
+ 1.0, (self.ncp.abs(self.r) + self.epsR), out=self.rw
+ )
+ self.ncp.divide(self.rw, self.rw.max(), out=self.rw)
+
+ R = Diagonal(self.ncp.sqrt(self.rw))
x = regularized_inversion(
self.Op,
self.y,
@@ -420,15 +508,21 @@ def _step_data(self, x: NDArray, **kwargs_solver) -> NDArray:
Weight=R,
x0=x if self.warm else None,
damp=self.epsI,
+ engine=engine,
**kwargs_solver,
)[0]
return x
- def _step_model(self, x: NDArray, **kwargs_solver) -> NDArray:
+ def _step_model(
+ self, x: NDArray, engine: str = "scipy", **kwargs_solver
+ ) -> NDArray:
r"""Run one step of solver with L1 model term"""
+ # add preallocate to keywords of solver
+ if self.preallocate and (engine == "pylops" or self.ncp != np):
+ kwargs_solver["preallocate"] = True
if self.iiter == 0:
# first iteration (unweighted least-squares)
- if self.ncp == np:
+ if engine == "scipy" and self.ncp == np:
x = self.Op.rmatvec(
lsqr(
self.Op @ self.Op.H + (self.epsI**2) * self.Iop,
@@ -436,7 +530,7 @@ def _step_model(self, x: NDArray, **kwargs_solver) -> NDArray:
**kwargs_solver,
)[0]
)
- else:
+ elif engine == "pylops" or self.ncp != np:
x = self.Op.rmatvec(
cgls(
self.Op @ self.Op.H + (self.epsI**2) * self.Iop,
@@ -447,10 +541,17 @@ def _step_model(self, x: NDArray, **kwargs_solver) -> NDArray:
)
else:
# other iterations (weighted least-squares)
- self.rw = np.abs(x)
- self.rw = self.rw / self.rw.max()
+ if self.preallocate and self.iiter == 1:
+ self.rw = self.ncp.zeros_like(x)
+ if not self.preallocate:
+ self.rw = self.ncp.abs(x)
+ self.rw = self.rw / self.rw.max()
+ else:
+ self.ncp.abs(x, out=self.rw)
+ self.ncp.divide(self.rw, self.rw.max(), out=self.rw)
+
R = Diagonal(self.rw, dtype=self.rw.dtype)
- if self.ncp == np:
+ if engine == "scipy" and self.ncp == np:
x = R.matvec(
self.Op.rmatvec(
lsqr(
@@ -460,7 +561,7 @@ def _step_model(self, x: NDArray, **kwargs_solver) -> NDArray:
)[0]
)
)
- else:
+ elif engine == "pylops" or self.ncp != np:
x = R.matvec(
self.Op.rmatvec(
cgls(
@@ -473,21 +574,33 @@ def _step_model(self, x: NDArray, **kwargs_solver) -> NDArray:
)
return x
- def step(self, x: NDArray, show: bool = False, **kwargs_solver) -> NDArray:
+ def step(
+ self,
+ x: NDArray,
+ engine: str = "scipy",
+ show: bool = False,
+ **kwargs_solver,
+ ) -> NDArray:
r"""Run one step of solver
Parameters
----------
x : :obj:`np.ndarray`
Current model vector to be updated by a step of ISTA
+ engine : :obj:`str`, optional
+ .. versionadded:: 2.6.0
+
+ Solver to use (``scipy`` or ``pylops``)
show : :obj:`bool`, optional
Display iteration log
**kwargs_solver
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.cg` solver for data IRLS and
:py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using
- numpy data(or :py:func:`pylops.optimization.solver.cg` and
- :py:func:`pylops.optimization.solver.cgls` when using cupy data)
+ numpy data and ``engine='scipy'`` (or
+ :py:func:`pylops.optimization.solver.cg` and
+ :py:func:`pylops.optimization.solver.cgls` when using cupy data or
+ ``engine='pylops'``)
Returns
-------
@@ -496,10 +609,13 @@ def step(self, x: NDArray, show: bool = False, **kwargs_solver) -> NDArray:
"""
# update model
- x = self._step(x, **kwargs_solver)
+ x = self._step(x, engine=engine, **kwargs_solver)
# compute residual
- self.r: NDArray = self.y - self.Op.matvec(x)
+ if not self.preallocate:
+ self.r: NDArray = self.y - self.Op.matvec(x)
+ else:
+ self.ncp.subtract(self.y, self.Op.matvec(x), out=self.r)
self.rnorm = self.ncp.linalg.norm(self.r)
self.iiter += 1
@@ -509,8 +625,9 @@ def step(self, x: NDArray, show: bool = False, **kwargs_solver) -> NDArray:
def run(
self,
- x: NDArray,
+ x: Optional[NDArray],
nouter: int = 10,
+ engine: str = "scipy",
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
**kwargs_solver,
@@ -520,9 +637,14 @@ def run(
Parameters
----------
x : :obj:`np.ndarray`
- Current model vector to be updated by multiple steps of IRLS
+ Current model vector to be updated by multiple steps of IRLS. Provide
+ ``None`` to initialize internally as zero vector
nouter : :obj:`int`, optional
Number of outer iterations.
+ engine : :obj:`str`, optional
+ .. versionadded:: 2.6.0
+
+ Solver to use (``scipy`` or ``pylops``)
show : :obj:`bool`, optional
Display logs
itershow : :obj:`tuple`, optional
@@ -533,8 +655,10 @@ def run(
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.cg` solver for data IRLS and
:py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using
- numpy data(or :py:func:`pylops.optimization.solver.cg` and
- :py:func:`pylops.optimization.solver.cgls` when using cupy data)
+ numpy data and ``engine='scipy'`` (or
+ :py:func:`pylops.optimization.solver.cg` and
+ :py:func:`pylops.optimization.solver.cgls` when using cupy data or
+ ``engine='pylops'``)
Returns
-------
@@ -546,6 +670,7 @@ def run(
if x is not None:
self.x0 = x.copy()
self.y = self.y - self.Op.matvec(x)
+
# choose xold to ensure tolerance test is passed initially
xold = x.copy() + np.inf
while self.iiter < nouter and self.ncp.linalg.norm(x - xold) >= self.tolIRLS:
@@ -560,7 +685,7 @@ def run(
else False
)
xold = x.copy()
- x = self.step(x, showstep, **kwargs_solver)
+ x = self.step(x, engine, showstep, **kwargs_solver)
self.callback(x)
# adding initial guess
@@ -594,6 +719,8 @@ def solve(
tolIRLS: float = 1e-10,
kind: str = "data",
warm: bool = False,
+ engine: str = "scipy",
+ preallocate: bool = False,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
**kwargs_solver,
@@ -624,6 +751,16 @@ def solve(
This only applies to ``kind="data"`` and ``kind="datamodel"``
kind : :obj:`str`, optional
Kind of solver (``data`` or ``model``)
+ engine : :obj:`str`, optional
+ .. versionadded:: 2.6.0
+
+ Solver to use (``scipy`` or ``pylops``)
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display setup log
itershow : :obj:`tuple`, optional
@@ -651,11 +788,19 @@ def solve(
tolIRLS=tolIRLS,
warm=warm,
kind=kind,
+ preallocate=preallocate,
show=show,
)
if x0 is None:
x0 = self.ncp.zeros(self.Op.shape[1], dtype=self.y.dtype)
- x = self.run(x0, nouter=nouter, show=show, itershow=itershow, **kwargs_solver)
+ x = self.run(
+ x0,
+ nouter=nouter,
+ engine=engine,
+ show=show,
+ itershow=itershow,
+ **kwargs_solver,
+ )
self.finalize(show)
return x, self.nouter
@@ -756,6 +901,41 @@ def _print_step(self, x: NDArray) -> None:
str2 = f" {self.cost[-1]:10.3e}"
print(str1 + str2)
+ def memory_usage(
+ self,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ Parameters
+ ----------
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: y, res
+ memuse = (2 * self.Op.shape[0]) * nbytes
+
+ # Step (additional variables to those in setup): cres, cres_abs
+ memuse += (2 * self.Op.shape[0]) * nbytes
+
+ if show:
+ print(f"OMP predicted memory usage: {memuse / _units[unit]:.2f} {unit}")
+
+ return memuse
+
def setup(
self,
y: NDArray,
@@ -765,6 +945,7 @@ def setup(
normalizecols: bool = False,
Opbasis: Optional["LinearOperator"] = None,
optimal_coeff: bool = False,
+ preallocate: bool = False,
show: bool = False,
) -> None:
r"""Setup solver
@@ -794,6 +975,12 @@ def setup(
:math:`\mathbf{r} - c * \mathbf{Op}^j) norm (``True``) or use the
directly the value from the inner product
:math:`\mathbf{Op}_j^H\,\mathbf{r}_k`.
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display setup log
@@ -805,7 +992,10 @@ def setup(
self.normalizecols = normalizecols
self.Opbasis = Opbasis if Opbasis is not None else self.Op
self.optimal_coeff = optimal_coeff
+
self.ncp = get_array_module(y)
+ self.isjax = get_module_name(self.ncp) == "jax"
+ self._setpreallocate(preallocate)
# find normalization factor for each column
if self.normalizecols:
@@ -814,8 +1004,8 @@ def setup(
for icol in range(ncols):
unit = self.ncp.zeros(ncols, dtype=self.Opbasis.dtype)
unit[icol] = 1
- self.norms[icol] = np.linalg.norm(self.Opbasis.matvec(unit))
- print(f"{self.norms = }")
+ self.norms[icol] = self.ncp.linalg.norm(self.Opbasis.matvec(unit))
+
# create variables to track the residual norm and iterations
self.res = self.y.copy()
self.cost = [
@@ -830,7 +1020,9 @@ def step(
self,
x: NDArray,
cols: InputDimsLike,
+ engine: str = "scipy",
show: bool = False,
+ **kwargs_solver,
) -> NDArray:
r"""Run one step of solver
@@ -840,8 +1032,18 @@ def step(
Current model vector to be updated by a step of OMP
cols : :obj:`list`
Current list of chosen elements of vector x to be updated by a step of OMP
+ engine : :obj:`str`, optional
+ .. versionadded:: 2.6.0
+
+ Solver to use (``scipy`` or ``pylops``)
show : :obj:`bool`, optional
Display iteration log
+ **kwargs_solver
+ Arbitrary keyword arguments for
+ :py:func:`scipy.sparse.linalg.lsqr` solver when using
+ numpy data and ``engine='scipy'`` (or
+ :py:func:`pylops.optimization.solver.cgls` when using cupy
+ data or ``engine='pylops'``)
Returns
-------
@@ -851,14 +1053,18 @@ def step(
Current list of chosen elements
"""
+ # add preallocate to keywords of solver
+ if self.preallocate and (engine == "pylops" or self.ncp != np):
+ kwargs_solver["preallocate"] = True
+
# compute inner products
cres = self.Op.rmatvec(self.res)
if self.normalizecols:
cres = cres / self.norms
- cres_abs = np.abs(cres)
+ cres_abs = self.ncp.abs(cres)
# choose column with max cres
- cres_max = np.max(cres_abs)
- imax = np.argwhere(cres_abs == cres_max).ravel()
+ cres_max = self.ncp.max(cres_abs)
+ imax = self.ncp.argwhere(cres_abs == cres_max).ravel()
nimax = len(imax)
if nimax > 0:
imax = imax[np.random.permutation(nimax)[0]]
@@ -882,7 +1088,14 @@ def step(
)
if not self.optimal_coeff:
# update with coefficient that maximizes the inner product
- self.res -= Opcol.matvec(cres[imax] * self.ncp.ones(1))
+ if not self.preallocate:
+ self.res -= Opcol.matvec(cres[imax] * self.ncp.ones(1))
+ else:
+ self.ncp.subtract(
+ self.res,
+ Opcol.matvec(cres[imax] * self.ncp.ones(1)),
+ out=self.res,
+ )
if addnew:
x.append(cres[imax])
else:
@@ -891,7 +1104,12 @@ def step(
# find optimal coefficient that minimizes the residual (r - cres * col)
col = Opcol.matvec(self.ncp.ones(1, dtype=Opcol.dtype))
cresopt = (Opcol.rmatvec(self.res) / Opcol.rmatvec(col))[0]
- self.res -= Opcol.matvec(cresopt * self.ncp.ones(1))
+ if not self.preallocate:
+ self.res -= Opcol.matvec(cresopt * self.ncp.ones(1))
+ else:
+ self.ncp.subtract(
+ self.res, Opcol.matvec(cresopt * self.ncp.ones(1)), out=self.res
+ )
if addnew:
x.append(cresopt)
else:
@@ -899,16 +1117,21 @@ def step(
else:
# OMP update
Opcol = self.Op.apply_columns(cols)
- if self.ncp == np:
- x = lsqr(Opcol, self.y, iter_lim=self.niter_inner)[0]
- else:
+ if engine == "scipy" and self.ncp == np:
+ x = lsqr(Opcol, self.y, iter_lim=self.niter_inner, **kwargs_solver)[0]
+ elif engine == "pylops" or self.ncp != np:
x = cgls(
Opcol,
self.y,
self.ncp.zeros(int(Opcol.shape[1]), dtype=Opcol.dtype),
niter=self.niter_inner,
+ **kwargs_solver,
)[0]
- self.res = self.y - Opcol.matvec(x)
+ if not self.preallocate:
+ self.res = self.y - Opcol.matvec(x)
+ else:
+ self.res = Opcol.matvec(x)
+ self.ncp.subtract(self.res, self.y, out=self.res)
self.iiter += 1
self.cost.append(float(np.linalg.norm(self.res)))
@@ -920,6 +1143,7 @@ def run(
self,
x: NDArray,
cols: InputDimsLike,
+ engine: str = "scipy",
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
) -> Tuple[NDArray, InputDimsLike]:
@@ -931,6 +1155,10 @@ def run(
Current model vector to be updated by multiple steps of IRLS
cols : :obj:`list`
Current list of chosen elements of vector x to be updated by a step of OMP
+ engine : :obj:`str`, optional
+ .. versionadded:: 2.6.0
+
+ Solver to use (``scipy`` or ``pylops``)
show : :obj:`bool`, optional
Display logs
itershow : :obj:`tuple`, optional
@@ -957,7 +1185,7 @@ def run(
)
else False
)
- x, cols = self.step(x, cols, showstep)
+ x, cols = self.step(x, cols, engine, showstep)
self.callback(x, cols)
return x, cols
@@ -990,7 +1218,8 @@ def finalize(
self.nouter = self.iiter
xfin = self.ncp.zeros(int(self.Op.shape[1]), dtype=self.Op.dtype)
- xfin[cols] = self.ncp.array(x)
+ xfin = inplace_set(self.ncp.array(x), xfin, self.ncp.array(cols))
+
if show:
self._print_finalize(nbar=55)
return xfin
@@ -1004,6 +1233,8 @@ def solve(
normalizecols: bool = False,
Opbasis: Optional["LinearOperator"] = None,
optimal_coeff: bool = False,
+ engine: str = "scipy",
+ preallocate: bool = False,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
) -> Tuple[NDArray, int, NDArray]:
@@ -1034,6 +1265,16 @@ def solve(
:math:`\mathbf{r} - c * \mathbf{Op}^j) norm (``True``) or use the
directly the value from the inner product
:math:`\mathbf{Op}_j^H\,\mathbf{r}_k`.
+ engine : :obj:`str`, optional
+ .. versionadded:: 2.6.0
+
+ Solver to use (``scipy`` or ``pylops``)
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display logs
itershow : :obj:`tuple`, optional
@@ -1059,11 +1300,12 @@ def solve(
normalizecols=normalizecols,
Opbasis=Opbasis,
optimal_coeff=optimal_coeff,
+ preallocate=preallocate,
show=show,
)
x: List[NDArray] = []
cols: List[InputDimsLike] = []
- x, cols = self.run(x, cols, show=show, itershow=itershow)
+ x, cols = self.run(x, cols, engine=engine, show=show, itershow=itershow)
x = self.finalize(x, cols, show)
return x, self.nouter, self.cost
@@ -1180,6 +1422,41 @@ def _print_step(
)
print(msg)
+ def memory_usage(
+ self,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ Parameters
+ ----------
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: x0 - y
+ memuse = (self.Op.shape[1] + self.Op.shape[0]) * nbytes
+
+ # Step (additional variables to those in setup): xold, grad, x_unthesh - res
+ memuse += (3 * self.Op.shape[1] + self.Op.shape[0]) * nbytes
+
+ if show:
+ print(f"ISTA predicted memory usage: {memuse / _units[unit]:.2f} {unit}")
+
+ return memuse
+
def setup(
self,
y: NDArray,
@@ -1194,6 +1471,7 @@ def setup(
perc: Optional[float] = None,
decay: Optional[NDArray] = None,
monitorres: bool = False,
+ preallocate: bool = False,
show: bool = False,
) -> NDArray:
r"""Setup solver
@@ -1234,6 +1512,12 @@ def setup(
Decay factor to be applied to thresholding during iterations
monitorres : :obj:`bool`, optional
Monitor that residual is decreasing
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display setup log
@@ -1255,6 +1539,8 @@ def setup(
self.monitorres = monitorres
self.ncp = get_array_module(y)
+ self.isjax = get_module_name(self.ncp) == "jax"
+ self._setpreallocate(preallocate)
# choose matvec/rmatvec or matmat/rmatmat based on R
if y.ndim > 1 and y.shape[1] > 1:
@@ -1309,7 +1595,7 @@ def setup(
# prepare decay (if not passed)
if perc is None and decay is None:
- self.decay = self.ncp.ones(niter)
+ self.decay = self.ncp.ones(niter, dtype=self.Op)
# step size
if alpha is not None:
@@ -1357,6 +1643,17 @@ def setup(
else:
x = x0.copy()
+ # initialize other internal variabled
+ if self.preallocate:
+ self.res = self.ncp.empty_like(y)
+ self.grad = self.ncp.empty_like(x)
+ self.x_unthesh = self.ncp.empty_like(x)
+ self.xold = self.ncp.empty_like(x)
+ if self.SOp is not None:
+ self.SOpx_unthesh: NDArray = self.ncp.zeros(
+ self.SOp.shape[1], dtype=self.SOp.dtype
+ )
+
# create variable to track residual
if monitorres:
self.normresold = np.inf
@@ -1392,11 +1689,19 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
"""
# store old vector
- xold = x.copy()
+ if self.preallocate:
+ self.xold[:] = x[:]
+ else:
+ xold = x.copy()
+
# compute residual
- res: NDArray = self.y - self.Opmatvec(x)
+ if not self.preallocate:
+ res: NDArray = self.y - self.Opmatvec(x)
+ else:
+ self.ncp.subtract(self.y, self.Opmatvec(x), out=self.res)
+
if self.monitorres:
- self.normres = np.linalg.norm(res)
+ self.normres = np.linalg.norm(self.res if self.preallocate else res)
if self.normres > self.normresold:
raise ValueError(
f"ISTA stopped at iteration {self.iiter} due to "
@@ -1407,25 +1712,67 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
self.normresold = self.normres
# compute gradient
- grad: NDArray = self.alpha * (self.Oprmatvec(res))
+ if not self.preallocate:
+ grad: NDArray = self.alpha * (self.Oprmatvec(res))
+ else:
+ self.ncp.multiply(
+ self.Oprmatvec(self.res),
+ self.alpha,
+ out=self.grad,
+ )
# update inverted model
- x_unthesh: NDArray = x + grad
+ if not self.preallocate:
+ x_unthesh: NDArray = x + grad
+ else:
+ self.ncp.add(
+ x,
+ self.grad,
+ out=self.x_unthesh,
+ )
+
+ # apply SOp.H to current x
if self.SOp is not None:
- x_unthesh = self.SOprmatvec(x_unthesh)
- if self.perc is None and self.decay is not None:
- x = self.threshf(x_unthesh, self.decay[self.iiter] * self.thresh)
- elif self.perc is not None:
- x = self.threshf(x_unthesh, 100 - self.perc)
+ if self.preallocate:
+ self.SOpx_unthesh[:] = self.SOprmatvec(self.x_unthesh)
+ else:
+ SOpx_unthesh = self.SOprmatvec(x_unthesh)
+ # threshold current solution or current solution projected onto SOp.H space
+ if self.SOp is None:
+ x_unthesh_or_SOpx_unthesh = (
+ self.x_unthesh if self.preallocate else x_unthesh
+ )
+ else:
+ x_unthesh_or_SOpx_unthesh = (
+ self.SOpx_unthesh if self.preallocate else SOpx_unthesh
+ )
+ if self.perc is None:
+ x = self.threshf(
+ x_unthesh_or_SOpx_unthesh,
+ self.decay[self.iiter] * self.thresh,
+ )
+ else:
+ x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc)
+ # apply SOp to thresholded x
if self.SOp is not None:
x = self.SOpmatvec(x)
- # model update
- xupdate = np.linalg.norm(x - xold)
+ # check model update
+ if not self.preallocate:
+ xupdate = np.linalg.norm(x - xold)
+ else:
+ self.ncp.subtract(
+ x,
+ self.xold,
+ out=self.xold,
+ )
+ xupdate = np.linalg.norm(self.xold)
- costdata = 0.5 * np.linalg.norm(res) ** 2
+ # cost functions
+ costdata = 0.5 * np.linalg.norm(self.res if self.preallocate else res) ** 2
costreg = self.eps * np.linalg.norm(x, ord=1)
self.cost.append(float(costdata + costreg))
+
self.iiter += 1
if show:
self._print_step(x, costdata, costreg, xupdate)
@@ -1512,6 +1859,7 @@ def solve(
perc: Optional[float] = None,
decay: Optional[NDArray] = None,
monitorres: bool = False,
+ preallocate: bool = False,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
) -> Tuple[NDArray, int, NDArray]:
@@ -1552,6 +1900,12 @@ def solve(
Decay factor to be applied to thresholding during iterations
monitorres : :obj:`bool`, optional
Monitor that residual is decreasing
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display logs
itershow : :obj:`tuple`, optional
@@ -1582,6 +1936,7 @@ def solve(
perc=perc,
decay=decay,
monitorres=monitorres,
+ preallocate=preallocate,
show=show,
)
x = self.run(x, niter, show=show, itershow=itershow)
@@ -1647,6 +2002,41 @@ class FISTA(ISTA):
"""
+ def memory_usage(
+ self,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ Parameters
+ ----------
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: x0 - y
+ memuse = (self.Op.shape[1] + self.Op.shape[0]) * nbytes
+
+ # Step (additional variables to those in setup): xold, grad, x_unthesh, z - res
+ memuse += (4 * self.Op.shape[1] + self.Op.shape[0]) * nbytes
+
+ if show:
+ print(f"FISTA predicted memory usage: {memuse / _units[unit]:.2f} {unit}")
+
+ return memuse
+
def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
r"""Run one step of solver
@@ -1670,45 +2060,101 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
"""
# store old vector
- xold = x.copy()
+ if self.preallocate:
+ self.xold[:] = x[:]
+ else:
+ xold = x.copy()
+
# compute residual
- resz: NDArray = self.y - self.Opmatvec(z)
+ if not self.preallocate:
+ res: NDArray = self.y - self.Opmatvec(z)
+ else:
+ self.ncp.subtract(self.y, self.Opmatvec(z), out=self.res)
+
if self.monitorres:
- self.normres = np.linalg.norm(resz)
+ self.normres = np.linalg.norm(self.res if self.preallocate else res)
if self.normres > self.normresold:
raise ValueError(
- f"ISTA stopped at iteration {self.iiter} due to "
+ f"FISTA stopped at iteration {self.iiter} due to "
"residual increasing, consider modifying "
"eps and/or alpha..."
)
else:
self.normresold = self.normres
- # compute gradient
- grad: NDArray = self.alpha * (self.Oprmatvec(resz))
+ # compute gradient and update inverted model
+ if not self.preallocate:
+ grad: NDArray = self.alpha * (self.Oprmatvec(res))
+ x_unthesh: NDArray = z + grad
+ else:
+ self.ncp.multiply(
+ self.Oprmatvec(self.res),
+ self.alpha,
+ out=self.grad,
+ )
+ self.ncp.add(
+ z,
+ self.grad,
+ out=self.x_unthesh,
+ )
- # update inverted model
- x_unthesh: NDArray = z + grad
+ # apply SOp.H to current x
if self.SOp is not None:
- x_unthesh = self.SOprmatvec(x_unthesh)
- if self.perc is None and self.decay is not None:
- x = self.threshf(x_unthesh, self.decay[self.iiter] * self.thresh)
- elif self.perc is not None:
- x = self.threshf(x_unthesh, 100 - self.perc)
+ if self.preallocate:
+ self.SOpx_unthesh[:] = self.SOprmatvec(self.x_unthesh)
+ else:
+ SOpx_unthesh = self.SOprmatvec(x_unthesh)
+
+ # threshold current solution or current solution projected onto SOp.H space
+ if self.SOp is None:
+ x_unthesh_or_SOpx_unthesh = (
+ self.x_unthesh if self.preallocate else x_unthesh
+ )
+ else:
+ x_unthesh_or_SOpx_unthesh = (
+ self.SOpx_unthesh if self.preallocate else SOpx_unthesh
+ )
+ if self.perc is None:
+ x = self.threshf(
+ x_unthesh_or_SOpx_unthesh,
+ self.decay[self.iiter] * self.thresh,
+ )
+ else:
+ x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc)
+
+ # apply SOp to thresholded x
if self.SOp is not None:
x = self.SOpmatvec(x)
# update auxiliary coefficients
told = self.t
- self.t = (1.0 + np.sqrt(1.0 + 4.0 * self.t**2)) / 2.0
- z = x + ((told - 1.0) / self.t) * (x - xold)
+ self.t = (1.0 + sqrt(1.0 + 4.0 * self.t**2)) / 2.0
# model update
- xupdate = np.linalg.norm(x - xold)
+ if not self.preallocate:
+ z = x + ((told - 1.0) / self.t) * (x - xold)
+ else:
+ self.ncp.subtract(
+ x,
+ self.xold,
+ out=self.xold,
+ )
+ self.ncp.multiply(self.xold, ((told - 1.0) / self.t), out=z)
+ self.ncp.add(x, z, out=z)
+ # check model update
+ if not self.preallocate:
+ xupdate = np.linalg.norm(x - xold)
+ else:
+ # note that x - xold has been already computed as part of the
+ # intermediate calculation of x in model update step
+ xupdate = np.linalg.norm(self.xold)
+
+ # cost functions
costdata = 0.5 * np.linalg.norm(self.y - self.Op @ x) ** 2
costreg = self.eps * np.linalg.norm(x, ord=1)
self.cost.append(float(costdata + costreg))
+
self.iiter += 1
if show:
self._print_step(x, costdata, costreg, xupdate)
@@ -1829,6 +2275,13 @@ def _print_finalize(self) -> None:
print(f"\nTotal time (s) = {self.telapsed:.2f}")
print("-" * 80 + "\n")
+ def memory_usage(
+ self,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ pass
+
def setup(
self,
y: NDArray,
@@ -2141,6 +2594,60 @@ def _print_step(self, x: NDArray) -> None:
str2 = f"{self.costdata:10.3e} {self.costtot:9.3e}"
print(str1 + str2)
+ def memory_usage(
+ self,
+ nopRegsL1: Optional[Tuple[int]] = None,
+ nopRegsL2: Optional[Tuple[int]] = None,
+ show: bool = False,
+ unit: str = "B",
+ ) -> float:
+ """Compute memory usage of the solver
+
+ Parameters
+ ----------
+ nopRegsL1 : :obj:`tuple`, optional
+ Number of data elements of ``RegsL1`` operators
+ nopRegsL2 : :obj:`tuple`, optional
+ Number of data elements of ``RegsL2`` operators
+ show : :obj:`bool`, optional
+ Display memory usage
+ unit: :obj:`str`, optional
+ Unit used to display memory usage (
+ ``B``, ``KB``, ``MB`` or ``GB``)
+
+ Returns
+ -------
+ memuse :obj:`float`
+ Memory usage in Bytes
+
+ """
+ # Convert nopRegsL1 and nopRegsL2 if None
+ if nopRegsL1 is None:
+ nopRegsL1 = 0
+ if nopRegsL2 is None:
+ nopRegsL2 = 0
+
+ # Get number of bytes of dtype used in the solver
+ nbytes = np.dtype(self.Op.dtype).itemsize
+
+ # Setup: x0 - y - b, d dataregsL1 - dataregsL2
+ memuse = (
+ self.Op.shape[1]
+ + self.Op.shape[0]
+ + 3 * np.prod(nopRegsL1)
+ + np.prod(nopRegsL2)
+ ) * nbytes
+
+ # Step (additional variables to those in setup): dataregs
+ memuse += np.prod(nopRegsL1) * nbytes
+
+ if show:
+ print(
+ f"Split-Bregman predicted memory usage: {memuse / _units[unit]:.2f} {unit}"
+ )
+
+ return memuse
+
def setup(
self,
y: NDArray,
@@ -2156,6 +2663,7 @@ def setup(
tol: float = 1e-10,
tau: float = 1.0,
restart: bool = False,
+ preallocate: bool = False,
show: bool = False,
) -> NDArray:
r"""Setup solver
@@ -2201,6 +2709,12 @@ def setup(
the initial guess (``True``) or with the last estimate (``False``).
Note that when this is set to ``True``, the ``x0`` provided in the setup will
be used in all iterations.
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display setup log
@@ -2222,14 +2736,19 @@ def setup(
self.tol = tol
self.tau = tau
self.restart = restart
+
self.ncp = get_array_module(y)
+ self.isjax = get_module_name(self.ncp) == "jax"
+ self._setpreallocate(preallocate)
# L1 regularizations
self.nregsL1 = len(RegsL1)
self.b = [
self.ncp.zeros(RegL1.shape[0], dtype=self.Op.dtype) for RegL1 in RegsL1
]
- self.d = self.b.copy()
+ self.d = [
+ self.ncp.zeros(RegL1.shape[0], dtype=self.Op.dtype) for RegL1 in RegsL1
+ ]
# L2 regularizations
self.nregsL2 = 0 if RegsL2 is None else len(RegsL2)
@@ -2247,13 +2766,11 @@ def setup(
self.epsRs: List[float] = []
if epsRL2s is not None:
self.epsRs += [
- np.sqrt(epsRL2s[ireg] / 2) / np.sqrt(mu / 2)
- for ireg in range(self.nregsL2)
+ sqrt(epsRL2s[ireg] / 2) / sqrt(mu / 2) for ireg in range(self.nregsL2)
]
if epsRL1s is not None:
self.epsRs += [
- np.sqrt(epsRL1s[ireg] / 2) / np.sqrt(mu / 2)
- for ireg in range(self.nregsL1)
+ sqrt(epsRL1s[ireg] / 2) / sqrt(mu / 2) for ireg in range(self.nregsL1)
]
self.x0 = x0
@@ -2270,9 +2787,10 @@ def setup(
def step(
self,
x: NDArray,
+ engine: str = "scipy",
show: bool = False,
show_inner: bool = False,
- **kwargs_lsqr,
+ **kwargs_solver,
) -> NDArray:
r"""Run one step of solver
@@ -2280,14 +2798,18 @@ def step(
----------
x : :obj:`list` or :obj:`np.ndarray`
Current model vector to be updated by a step of OMP
- show_inner : :obj:`bool`, optional
- Display inner iteration logs of lsqr
+ engine : :obj:`str`, optional
+ Solver to use (``scipy`` or ``pylops``)
show : :obj:`bool`, optional
Display iteration log
- **kwargs_lsqr
- Arbitrary keyword arguments for
- :py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first
- subproblem in the first step of the Split Bregman algorithm.
+ show_inner : :obj:`bool`, optional
+ Display inner iteration logs of lsqr
+ **kwargs_solver
+ Arbitrary keyword arguments for chosen solver
+ used to solve the first subproblem in the first step of the
+ Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` and
+ :py:func:`pylops.optimization.solver.cgls` are used as default
+ for numpy and cupy `data`, respectively).
Returns
-------
@@ -2295,11 +2817,22 @@ def step(
Updated model vector
"""
+ # add preallocate to keywords of solver
+ if self.preallocate and (engine == "pylops" or self.ncp != np):
+ kwargs_solver["preallocate"] = True
+
for _ in range(self.niter_inner):
# regularized problem
- dataregs = self.dataregsL2 + [
- self.d[ireg] - self.b[ireg] for ireg in range(self.nregsL1)
- ]
+ if not self.preallocate:
+ dataregs = self.dataregsL2 + [
+ self.d[ireg] - self.b[ireg] for ireg in range(self.nregsL1)
+ ]
+ else:
+ for ireg in range(self.nregsL1):
+ self.ncp.subtract(self.d[ireg], self.b[ireg], out=self.d[ireg])
+ dataregs = self.dataregsL2 + [
+ self.d[ireg] for ireg in range(self.nregsL1)
+ ]
x = regularized_inversion(
self.Op,
self.y,
@@ -2308,21 +2841,25 @@ def step(
epsRs=self.epsRs,
x0=self.x0 if self.restart else x,
show=show_inner,
- **kwargs_lsqr,
+ engine=engine,
+ **kwargs_solver,
)[0]
- # Shrinkage
- self.d = [
- _softthreshold(
- self.RegsL1[ireg].matvec(x) + self.b[ireg], self.epsRL1s[ireg]
- )
- for ireg in range(self.nregsL1)
- ]
+ # shrinkage
+ if not self.preallocate:
+ for ireg in range(self.nregsL1):
+ self.d[ireg] = _softthreshold(
+ self.RegsL1[ireg].matvec(x) + self.b[ireg], self.epsRL1s[ireg]
+ )
+ else:
+ for ireg in range(self.nregsL1):
+ self.ncp.add(
+ self.RegsL1[ireg].matvec(x), self.b[ireg], out=self.d[ireg]
+ )
+ self.d[ireg] = _softthreshold(self.d[ireg], self.epsRL1s[ireg])
# Bregman update
- self.b = [
- self.b[ireg] + self.tau * (self.RegsL1[ireg].matvec(x) - self.d[ireg])
- for ireg in range(self.nregsL1)
- ]
+ for ireg in range(self.nregsL1):
+ self.b[ireg] += self.tau * (self.RegsL1[ireg].matvec(x) - self.d[ireg])
# compute residual norms
self.costdata = (
@@ -2340,7 +2877,7 @@ def step(
)
self.costregL1 = [
self.ncp.linalg.norm(RegL1.matvec(x), ord=1)
- for epsRL1, RegL1 in zip(self.epsRL1s, self.RegsL1)
+ for _, RegL1 in zip(self.epsRL1s, self.RegsL1)
]
self.costtot = (
self.costdata
@@ -2358,6 +2895,7 @@ def step(
def run(
self,
x: NDArray,
+ engine: str = "scipy",
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
show_inner: bool = False,
@@ -2369,6 +2907,8 @@ def run(
----------
x : :obj:`np.ndarray`
Current model vector to be updated by multiple steps of IRLS
+ engine : :obj:`str`, optional
+ Solver to use (``scipy`` or ``pylops``)
show : :obj:`bool`, optional
Display logs
itershow : :obj:`tuple`, optional
@@ -2403,8 +2943,9 @@ def run(
)
else False
)
- x = self.step(x, showstep, show_inner, **kwargs_lsqr)
+ x = self.step(x, engine, showstep, show_inner, **kwargs_lsqr)
self.callback(x)
+
return x
def finalize(self, show: bool = False) -> NDArray:
@@ -2443,6 +2984,8 @@ def solve(
tol: float = 1e-10,
tau: float = 1.0,
restart: bool = False,
+ engine: str = "scipy",
+ preallocate: bool = False,
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
show_inner: bool = False,
@@ -2491,6 +3034,14 @@ def solve(
the initial guess (``True``) or with the last estimate (``False``).
Note that when this is set to ``True``, the ``x0`` provided in the setup will
be used in all iterations.
+ engine : :obj:`str`, optional
+ Solver to use (``scipy`` or ``pylops``)
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
show : :obj:`bool`, optional
Display logs
itershow : :obj:`tuple`, optional
@@ -2528,10 +3079,16 @@ def solve(
tol=tol,
tau=tau,
restart=restart,
+ preallocate=preallocate,
show=show,
)
x = self.run(
- x, show=show, itershow=itershow, show_inner=show_inner, **kwargs_lsqr
+ x,
+ engine=engine,
+ show=show,
+ itershow=itershow,
+ show_inner=show_inner,
+ **kwargs_lsqr,
)
self.finalize(show)
return x, self.iiter, self.cost
diff --git a/pylops/optimization/leastsquares.py b/pylops/optimization/leastsquares.py
index 3174ee6c8..0e61b8a26 100644
--- a/pylops/optimization/leastsquares.py
+++ b/pylops/optimization/leastsquares.py
@@ -239,9 +239,9 @@ def preconditioned_inversion(
x0 : :obj:`numpy.ndarray`
Initial guess of size :math:`[M \times 1]`
engine : :obj:`str`, optional
- Solver to use (``scipy`` or ``pylops``)
+ Solver to use (``scipy`` or ``pylops``)
show : :obj:`bool`, optional
- Display normal equations solver log
+ Display normal equations solver log
**kwargs_solver
Arbitrary keyword arguments for chosen solver
(:py:func:`scipy.sparse.linalg.lsqr` and
diff --git a/pylops/optimization/sparsity.py b/pylops/optimization/sparsity.py
index 6838e9461..c7c720efb 100644
--- a/pylops/optimization/sparsity.py
+++ b/pylops/optimization/sparsity.py
@@ -28,9 +28,11 @@ def irls(
tolIRLS: float = 1e-10,
warm: bool = False,
kind: str = "data",
+ engine: str = "scipy",
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
+ preallocate: bool = False,
**kwargs_solver,
) -> Tuple[NDArray, int]:
r"""Iteratively reweighted least squares.
@@ -74,6 +76,8 @@ def irls(
This only applies to ``kind="data"`` and ``kind="datamodel"``
kind : :obj:`str`, optional
Kind of solver (``model``, ``data`` or ``datamodel``)
+ engine : :obj:`str`, optional
+ Solver to use (``scipy`` or ``pylops``)
show : :obj:`bool`, optional
Display logs
itershow : :obj:`tuple`, optional
@@ -83,6 +87,12 @@ def irls(
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
**kwargs_solver
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.cg` solver for data IRLS and
@@ -113,8 +123,10 @@ def irls(
epsR=epsR,
epsI=epsI,
tolIRLS=tolIRLS,
- warm=warm,
kind=kind,
+ warm=warm,
+ engine=engine,
+ preallocate=preallocate,
show=show,
itershow=itershow,
**kwargs_solver,
@@ -131,9 +143,11 @@ def omp(
normalizecols: bool = False,
Opbasis: Optional["LinearOperator"] = None,
optimal_coeff: bool = False,
+ engine: str = "scipy",
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
+ preallocate: bool = False,
) -> Tuple[NDArray, int, NDArray]:
r"""Orthogonal Matching Pursuit (OMP).
@@ -169,6 +183,8 @@ def omp(
:math:`\mathbf{r} - c * \mathbf{Op}^j) norm (``True``) or use the
directly the value from the inner product
:math:`\mathbf{Op}_j^H\,\mathbf{r}_k`.
+ engine : :obj:`str`, optional
+ Solver to use (``scipy`` or ``pylops``)
show : :obj:`bool`, optional
Display iterations log
itershow : :obj:`tuple`, optional
@@ -179,7 +195,12 @@ def omp(
Function with signature (``callback(x, cols)``) to call after each iteration
where ``x`` contains the non-zero model coefficient and ``cols`` are the
indices where the current model vector is non-zero
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
Returns
-------
xinv : :obj:`numpy.ndarray`
@@ -212,8 +233,10 @@ def omp(
normalizecols=normalizecols,
Opbasis=Opbasis,
optimal_coeff=optimal_coeff,
+ engine=engine,
show=show,
itershow=itershow,
+ preallocate=preallocate,
)
return x, niter_outer, cost
@@ -235,6 +258,7 @@ def ista(
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
+ preallocate: bool = False,
) -> Tuple[NDArray, int, NDArray]:
r"""Iterative Shrinkage-Thresholding Algorithm (ISTA).
@@ -289,6 +313,12 @@ def ista(
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
Returns
-------
@@ -340,6 +370,7 @@ def ista(
monitorres=monitorres,
show=show,
itershow=itershow,
+ preallocate=preallocate,
)
return x, iiter, cost
@@ -361,6 +392,7 @@ def fista(
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
callback: Optional[Callable] = None,
+ preallocate: bool = False,
) -> Tuple[NDArray, int, NDArray]:
r"""Fast Iterative Shrinkage-Thresholding Algorithm (FISTA).
@@ -415,6 +447,12 @@ def fista(
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
Returns
-------
@@ -464,6 +502,7 @@ def fista(
monitorres=monitorres,
show=show,
itershow=itershow,
+ preallocate=preallocate,
)
return x, iiter, cost
@@ -600,10 +639,12 @@ def splitbregman(
tol: float = 1e-10,
tau: float = 1.0,
restart: bool = False,
+ engine: str = "scipy",
show: bool = False,
itershow: Tuple[int, int, int] = (10, 10, 10),
show_inner: bool = False,
callback: Optional[Callable] = None,
+ preallocate: bool = False,
**kwargs_lsqr,
) -> Tuple[NDArray, int, NDArray]:
r"""Split Bregman for mixed L2-L1 norms.
@@ -653,6 +694,8 @@ def splitbregman(
restart : :obj:`bool`, optional
The unconstrained inverse problem in inner loop is initialized with
the initial guess (``True``) or with the last estimate (``False``)
+ engine : :obj:`str`, optional
+ Solver to use (``scipy`` or ``pylops``)
show : :obj:`bool`, optional
Display iterations log
itershow : :obj:`tuple`, optional
@@ -664,6 +707,12 @@ def splitbregman(
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
+ preallocate : :obj:`bool`, optional
+ .. versionadded:: 2.6.0
+
+ Pre-allocate all variables used by the solver. Note that if ``y``
+ is a JAX array, this option is ignored and variables are not
+ pre-allocated since JAX does not support in-place operations.
**kwargs_lsqr
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first
@@ -700,6 +749,8 @@ def splitbregman(
tol=tol,
tau=tau,
restart=restart,
+ engine=engine,
+ preallocate=preallocate,
show=show,
itershow=itershow,
show_inner=show_inner,
diff --git a/pylops/waveeqprocessing/seismicinterpolation.py b/pylops/waveeqprocessing/seismicinterpolation.py
index e1159d4aa..3d18187e9 100644
--- a/pylops/waveeqprocessing/seismicinterpolation.py
+++ b/pylops/waveeqprocessing/seismicinterpolation.py
@@ -271,6 +271,8 @@ def SeismicInterpolation(
Pop = FFT2D(dims=dims, nffts=nffts, sampling=sampling)
Pop = Pop.H
SIop = Rop * Pop
+ # Force data to be of same dtype of operator
+ data = data.astype(SIop.dtype)
elif "chirpradon" in kind:
prec = True
dotcflag = 0
diff --git a/tutorials/solvers.py b/tutorials/solvers.py
index c0bfe72c8..58037c41a 100755
--- a/tutorials/solvers.py
+++ b/tutorials/solvers.py
@@ -224,7 +224,7 @@
y,
[D2op],
epsRs=[np.sqrt(0.1)],
- **dict(damp=np.sqrt(1e-4), iter_lim=50, show=0)
+ **dict(damp=np.sqrt(1e-4), iter_lim=50, show=0),
)[0]
###############################################################################
@@ -292,8 +292,8 @@
# \epsilon \|\mathbf{p}\|_1
#
# where :math:`\mathbf{F}` is the FFT operator. We will thus use the
-# :py:class:`pylops.optimization.sparsity.ista` and
-# :py:class:`pylops.optimization.sparsity.fista` solvers to estimate our input
+# :py:func:`pylops.optimization.sparsity.ista` and
+# :py:func:`pylops.optimization.sparsity.fista` solvers to estimate our input
# signal.
pista, niteri, costi = pylops.optimization.sparsity.ista(
@@ -347,7 +347,7 @@
# converges much faster than ISTA as expected and should be preferred when
# using sparse solvers.
#
-# Finally we consider a slightly different cost function (note that in this
+# We now consider a slightly different cost function (note that in this
# case we try to solve a constrained problem):
#
# .. math::
@@ -356,7 +356,7 @@
# \mathbf{R} \mathbf{F} \mathbf{p}\|
#
# A very popular solver to solve such kind of cost function is called *spgl1*
-# and can be accessed via :py:class:`pylops.optimization.sparsity.spgl1`.
+# and can be accessed via :py:func:`pylops.optimization.sparsity.spgl1`.
xspgl1, pspgl1, info = pylops.optimization.sparsity.spgl1(
Rop, y, SOp=FFTop, tau=3, iter_lim=200
@@ -384,3 +384,35 @@
ax.legend()
ax.grid(True)
plt.tight_layout()
+
+###############################################################################
+# Finally, we go back to the :py:func:`pylops.optimization.sparsity.ista`
+# solver and show a new feature that was introduced in PyLops v2.6.0.
+# When the solver is run with ``preallocate=True``, all internal vectors
+# are pre-allocated in the ``setup`` method of the
+# :py:class:`pylops.optimization.cls_sparsity.ISTA` class. This is likely to
+# improve the performance of the solver (see :ref:`sphx_glr_tutorials_classsolvers.py`
+# for more details), especially when it is applied to large problems.
+
+# Original ISTA
+pista = pylops.optimization.sparsity.ista(
+ Rop * FFTop.H,
+ y,
+ niter=100,
+ alpha=1e-2,
+ eps=0.1,
+ tol=1e-7,
+)[0]
+
+# ISTA with preallocation
+pista_prealloc = pylops.optimization.sparsity.ista(
+ Rop * FFTop.H,
+ y,
+ niter=100,
+ alpha=1e-2,
+ eps=0.1,
+ tol=1e-7,
+ preallocate=True,
+)[0]
+
+print(f"Norm of difference: {np.linalg.norm(pista - pista_prealloc)}")