Skip to content

Commit f98cc31

Browse files
committed
feat: added preallocate to IRLS
1 parent dd404e6 commit f98cc31

2 files changed

Lines changed: 38 additions & 5 deletions

File tree

pylops/optimization/cls_sparsity.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,11 @@ def _step_model(
572572
return x
573573

574574
def step(
575-
self, x: NDArray, engine: str = "scipy", show: bool = False, **kwargs_solver
575+
self,
576+
x: NDArray,
577+
engine: str = "scipy",
578+
show: bool = False,
579+
**kwargs_solver,
576580
) -> NDArray:
577581
r"""Run one step of solver
578582
@@ -936,6 +940,7 @@ def setup(
936940
normalizecols: bool = False,
937941
Opbasis: Optional["LinearOperator"] = None,
938942
optimal_coeff: bool = False,
943+
preallocate: bool = False,
939944
show: bool = False,
940945
) -> None:
941946
r"""Setup solver
@@ -965,6 +970,10 @@ def setup(
965970
:math:`\mathbf{r} - c * \mathbf{Op}^j) norm (``True``) or use the
966971
directly the value from the inner product
967972
:math:`\mathbf{Op}_j^H\,\mathbf{r}_k`.
973+
preallocate : :obj:`bool`, optional
974+
.. versionadded:: 2.5.0
975+
976+
Pre-allocate all variables used by the solver
968977
show : :obj:`bool`, optional
969978
Display setup log
970979
@@ -979,6 +988,7 @@ def setup(
979988

980989
self.ncp = get_array_module(y)
981990
self.isjax = get_module_name(self.ncp) == "jax"
991+
self._setpreallocate(preallocate)
982992

983993
# find normalization factor for each column
984994
if self.normalizecols:
@@ -1005,6 +1015,7 @@ def step(
10051015
cols: InputDimsLike,
10061016
engine: str = "scipy",
10071017
show: bool = False,
1018+
**kwargs_solver,
10081019
) -> NDArray:
10091020
r"""Run one step of solver
10101021
@@ -1020,6 +1031,12 @@ def step(
10201031
Solver to use (``scipy`` or ``pylops``)
10211032
show : :obj:`bool`, optional
10221033
Display iteration log
1034+
**kwargs_solver
1035+
Arbitrary keyword arguments for
1036+
:py:func:`scipy.sparse.linalg.lsqr` solver when using
1037+
numpy data and ``engine='scipy'`` (or
1038+
:py:func:`pylops.optimization.solver.cgls` when using cupy
1039+
data or ``engine='pylops'``)
10231040
10241041
Returns
10251042
-------
@@ -1029,6 +1046,10 @@ def step(
10291046
Current list of chosen elements
10301047
10311048
"""
1049+
# add preallocate to keywords of solver
1050+
if self.preallocate and (engine == "pylops" or self.ncp != np):
1051+
kwargs_solver["preallocate"] = True
1052+
10321053
# compute inner products
10331054
cres = self.Op.rmatvec(self.res)
10341055
if self.normalizecols:
@@ -1060,7 +1081,7 @@ def step(
10601081
)
10611082
if not self.optimal_coeff:
10621083
# update with coefficient that maximizes the inner product
1063-
if self.isjax:
1084+
if not self.preallocate:
10641085
self.res -= Opcol.matvec(cres[imax] * self.ncp.ones(1))
10651086
else:
10661087
self.ncp.subtract(
@@ -1076,7 +1097,7 @@ def step(
10761097
# find optimal coefficient that minimizes the residual (r - cres * col)
10771098
col = Opcol.matvec(self.ncp.ones(1, dtype=Opcol.dtype))
10781099
cresopt = (Opcol.rmatvec(self.res) / Opcol.rmatvec(col))[0]
1079-
if self.isjax:
1100+
if not self.preallocate:
10801101
self.res -= Opcol.matvec(cresopt * self.ncp.ones(1))
10811102
else:
10821103
self.ncp.subtract(
@@ -1090,15 +1111,16 @@ def step(
10901111
# OMP update
10911112
Opcol = self.Op.apply_columns(cols)
10921113
if engine == "scipy" and self.ncp == np:
1093-
x = lsqr(Opcol, self.y, iter_lim=self.niter_inner)[0]
1114+
x = lsqr(Opcol, self.y, iter_lim=self.niter_inner, **kwargs_solver)[0]
10941115
elif engine == "pylops" or self.ncp != np:
10951116
x = cgls(
10961117
Opcol,
10971118
self.y,
10981119
self.ncp.zeros(int(Opcol.shape[1]), dtype=Opcol.dtype),
10991120
niter=self.niter_inner,
1121+
**kwargs_solver,
11001122
)[0]
1101-
if self.isjax:
1123+
if not self.preallocate:
11021124
self.res = self.y - Opcol.matvec(x)
11031125
else:
11041126
self.res = Opcol.matvec(x)
@@ -1205,6 +1227,7 @@ def solve(
12051227
Opbasis: Optional["LinearOperator"] = None,
12061228
optimal_coeff: bool = False,
12071229
engine: str = "scipy",
1230+
preallocate: bool = False,
12081231
show: bool = False,
12091232
itershow: Tuple[int, int, int] = (10, 10, 10),
12101233
) -> Tuple[NDArray, int, NDArray]:
@@ -1239,6 +1262,10 @@ def solve(
12391262
.. versionadded:: 2.5.0
12401263
12411264
Solver to use (``scipy`` or ``pylops``)
1265+
preallocate : :obj:`bool`, optional
1266+
.. versionadded:: 2.5.0
1267+
1268+
Pre-allocate all variables used by the solver
12421269
show : :obj:`bool`, optional
12431270
Display logs
12441271
itershow : :obj:`tuple`, optional
@@ -1264,6 +1291,7 @@ def solve(
12641291
normalizecols=normalizecols,
12651292
Opbasis=Opbasis,
12661293
optimal_coeff=optimal_coeff,
1294+
preallocate=preallocate,
12671295
show=show,
12681296
)
12691297
x: List[NDArray] = []

pylops/optimization/sparsity.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def omp(
145145
show: bool = False,
146146
itershow: Tuple[int, int, int] = (10, 10, 10),
147147
callback: Optional[Callable] = None,
148+
preallocate: bool = False,
148149
) -> Tuple[NDArray, int, NDArray]:
149150
r"""Orthogonal Matching Pursuit (OMP).
150151
@@ -192,7 +193,10 @@ def omp(
192193
Function with signature (``callback(x, cols)``) to call after each iteration
193194
where ``x`` contains the non-zero model coefficient and ``cols`` are the
194195
indices where the current model vector is non-zero
196+
preallocate : :obj:`bool`, optional
197+
.. versionadded:: 2.5.0
195198
199+
Pre-allocate all variables used by the solver
196200
Returns
197201
-------
198202
xinv : :obj:`numpy.ndarray`
@@ -228,6 +232,7 @@ def omp(
228232
engine=engine,
229233
show=show,
230234
itershow=itershow,
235+
preallocate=preallocate,
231236
)
232237
return x, niter_outer, cost
233238

0 commit comments

Comments
 (0)