Skip to content

Commit b6c103e

Browse files
authored
Merge pull request #758 from mrava87/fix-ista_decay
Bug: fix ista decay
2 parents 7961607 + 7a777d5 commit b6c103e

2 files changed

Lines changed: 46 additions & 5 deletions

File tree

docs/source/addingsolver.rst

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,44 @@ As for any Python class, our solver will need an ``__init__`` method. In this ca
5353
of the base class. Two input parameters are passed to the ``__init__`` method and saved as members of our class,
5454
namely the operator :math:`\mathbf{Op}` associated with the system of equations we wish to solve,
5555
:math:`\mathbf{y}=\mathbf{Op}\,\mathbf{x}`, and optionally a :class:`pylops.optimization.callback.Callbacks` object. Moreover,
56-
an additional parameters is created that contains the current time (this is used later to report the execution time
57-
of the solver). Here is the ``__init__`` method of the base class:
56+
two additional parameters are created that contains the counter of the iterations (which will be incremented every time the
57+
``step`` method is called) and the current time (this is used later to report the execution time of the solver). Here is the
58+
``__init__`` method of the base class:
5859

5960
.. code-block:: python
6061
6162
def __init__(self, Op, callbacks=None):
6263
self.Op = Op
6364
self.callbacks = callbacks
6465
self._registercallbacks()
66+
self.iiter = 0
6567
self.tstart = time.time()
6668
69+
Next, we will write the *memory_usage* method. This method allows users to get a prediction of the memory usage of
70+
the solver ahead of time (before running any of the methods of the solver as described below). It is very useful, especially
71+
for large problems, to get a feeling whether the current hardware resources (of the CPU or GPU if the user plans to run the solver on
72+
CuPy arrays and with a CuPy-enabled operator) will be sufficient to succesfully carry out the optimization process.
73+
74+
.. code-block:: python
75+
76+
def memory_usage(self, show: False, unit = "B"):
77+
nbytes = np.dtype(self.Op.dtype).itemsize
78+
79+
# Setup
80+
memuse = (self.Op.shape[1] + 3 * self.Op.shape[0]) * nbytes
81+
82+
# Step (additional variables to those in setup)
83+
memuse += (self.Op.shape[1] + self.Op.shape[0]) * nbytes
84+
85+
if show:
86+
print(f"CG predicted memory usage: {memuse / _units[unit]:.2f} {unit}")
87+
88+
return memuse
89+
90+
Note that, although very useful, this method is not strictly needed to run the solver; so at the beginning you could just
91+
add a ``pass`` to the core of this method. However, since this method is marked as ``@abstractmethod`` in the base class,
92+
you can't simply skip it.
93+
6794
We can now move onto writing the *setup* of the solver in the method ``setup``. We will need to write
6895
a piece of code that prepares the solver prior to being able to apply a step. In general, this requires defining the
6996
data vector ``y`` and the initial guess of the solver ``x0`` (if not provided, this will be automatically set to be a zero

pylops/optimization/cls_sparsity.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,7 +1733,11 @@ def setup(
17331733

17341734
# prepare decay (if not passed)
17351735
if perc is None and decay is None:
1736-
self.decay = self.ncp.ones(niter, dtype=get_real_dtype(self.Op.dtype))
1736+
self.decay = (
1737+
1.0
1738+
if niter is None
1739+
else self.ncp.ones(niter, dtype=get_real_dtype(self.Op.dtype))
1740+
)
17371741

17381742
# step size
17391743
if alpha is not None:
@@ -1889,9 +1893,14 @@ def step(self, x: NDArray, show: bool = False) -> tuple[NDArray, float]:
18891893
self.SOpx_unthesh if self.preallocate else SOpx_unthesh
18901894
)
18911895
if self.perc is None:
1896+
decay = (
1897+
self.decay
1898+
if isinstance(self.decay, (int, float))
1899+
else self.decay[self.iiter]
1900+
) * self.thresh # single-valued decay when niter is not set in setup
18921901
x = self.threshf(
18931902
x_unthesh_or_SOpx_unthesh,
1894-
self.decay[self.iiter] * self.thresh,
1903+
decay,
18951904
)
18961905
else:
18971906
x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc)
@@ -2317,9 +2326,14 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
23172326
self.SOpx_unthesh if self.preallocate else SOpx_unthesh
23182327
)
23192328
if self.perc is None:
2329+
decay = (
2330+
self.decay
2331+
if isinstance(self.decay, (int, float))
2332+
else self.decay[self.iiter]
2333+
) * self.thresh # single-valued decay when niter is not set in setup
23202334
x = self.threshf(
23212335
x_unthesh_or_SOpx_unthesh,
2322-
self.decay[self.iiter] * self.thresh,
2336+
decay,
23232337
)
23242338
else:
23252339
x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc)

0 commit comments

Comments
 (0)