@@ -572,7 +572,11 @@ def _step_model(
572572 return x
573573
574574 def step (
575- self , x : NDArray , engine : str = "scipy" , show : bool = False , ** kwargs_solver
575+ self ,
576+ x : NDArray ,
577+ engine : str = "scipy" ,
578+ show : bool = False ,
579+ ** kwargs_solver ,
576580 ) -> NDArray :
577581 r"""Run one step of solver
578582
@@ -936,6 +940,7 @@ def setup(
936940 normalizecols : bool = False ,
937941 Opbasis : Optional ["LinearOperator" ] = None ,
938942 optimal_coeff : bool = False ,
943+ preallocate : bool = False ,
939944 show : bool = False ,
940945 ) -> None :
941946 r"""Setup solver
@@ -965,6 +970,10 @@ def setup(
965970 :math:`\mathbf{r} - c * \mathbf{Op}^j) norm (``True``) or use the
966971 directly the value from the inner product
967972 :math:`\mathbf{Op}_j^H\,\mathbf{r}_k`.
973+ preallocate : :obj:`bool`, optional
974+ .. versionadded:: 2.5.0
975+
976+ Pre-allocate all variables used by the solver
968977 show : :obj:`bool`, optional
969978 Display setup log
970979
@@ -979,6 +988,7 @@ def setup(
979988
980989 self .ncp = get_array_module (y )
981990 self .isjax = get_module_name (self .ncp ) == "jax"
991+ self ._setpreallocate (preallocate )
982992
983993 # find normalization factor for each column
984994 if self .normalizecols :
@@ -1005,6 +1015,7 @@ def step(
10051015 cols : InputDimsLike ,
10061016 engine : str = "scipy" ,
10071017 show : bool = False ,
1018+ ** kwargs_solver ,
10081019 ) -> NDArray :
10091020 r"""Run one step of solver
10101021
@@ -1020,6 +1031,12 @@ def step(
10201031 Solver to use (``scipy`` or ``pylops``)
10211032 show : :obj:`bool`, optional
10221033 Display iteration log
1034+ **kwargs_solver
1035+ Arbitrary keyword arguments for
1036+ :py:func:`scipy.sparse.linalg.lsqr` solver when using
1037+ numpy data and ``engine='scipy'`` (or
1038+ :py:func:`pylops.optimization.solver.cgls` when using cupy
1039+ data or ``engine='pylops'``)
10231040
10241041 Returns
10251042 -------
@@ -1029,6 +1046,10 @@ def step(
10291046 Current list of chosen elements
10301047
10311048 """
1049+ # add preallocate to keywords of solver
1050+ if self .preallocate and (engine == "pylops" or self .ncp != np ):
1051+ kwargs_solver ["preallocate" ] = True
1052+
10321053 # compute inner products
10331054 cres = self .Op .rmatvec (self .res )
10341055 if self .normalizecols :
@@ -1060,7 +1081,7 @@ def step(
10601081 )
10611082 if not self .optimal_coeff :
10621083 # update with coefficient that maximizes the inner product
1063- if self .isjax :
1084+ if not self .preallocate :
10641085 self .res -= Opcol .matvec (cres [imax ] * self .ncp .ones (1 ))
10651086 else :
10661087 self .ncp .subtract (
@@ -1076,7 +1097,7 @@ def step(
10761097 # find optimal coefficient that minimizes the residual (r - cres * col)
10771098 col = Opcol .matvec (self .ncp .ones (1 , dtype = Opcol .dtype ))
10781099 cresopt = (Opcol .rmatvec (self .res ) / Opcol .rmatvec (col ))[0 ]
1079- if self .isjax :
1100+ if not self .preallocate :
10801101 self .res -= Opcol .matvec (cresopt * self .ncp .ones (1 ))
10811102 else :
10821103 self .ncp .subtract (
@@ -1090,15 +1111,16 @@ def step(
10901111 # OMP update
10911112 Opcol = self .Op .apply_columns (cols )
10921113 if engine == "scipy" and self .ncp == np :
1093- x = lsqr (Opcol , self .y , iter_lim = self .niter_inner )[0 ]
1114+ x = lsqr (Opcol , self .y , iter_lim = self .niter_inner , ** kwargs_solver )[0 ]
10941115 elif engine == "pylops" or self .ncp != np :
10951116 x = cgls (
10961117 Opcol ,
10971118 self .y ,
10981119 self .ncp .zeros (int (Opcol .shape [1 ]), dtype = Opcol .dtype ),
10991120 niter = self .niter_inner ,
1121+ ** kwargs_solver ,
11001122 )[0 ]
1101- if self .isjax :
1123+ if not self .preallocate :
11021124 self .res = self .y - Opcol .matvec (x )
11031125 else :
11041126 self .res = Opcol .matvec (x )
@@ -1205,6 +1227,7 @@ def solve(
12051227 Opbasis : Optional ["LinearOperator" ] = None ,
12061228 optimal_coeff : bool = False ,
12071229 engine : str = "scipy" ,
1230+ preallocate : bool = False ,
12081231 show : bool = False ,
12091232 itershow : Tuple [int , int , int ] = (10 , 10 , 10 ),
12101233 ) -> Tuple [NDArray , int , NDArray ]:
@@ -1239,6 +1262,10 @@ def solve(
12391262 .. versionadded:: 2.5.0
12401263
12411264 Solver to use (``scipy`` or ``pylops``)
1265+ preallocate : :obj:`bool`, optional
1266+ .. versionadded:: 2.5.0
1267+
1268+ Pre-allocate all variables used by the solver
12421269 show : :obj:`bool`, optional
12431270 Display logs
12441271 itershow : :obj:`tuple`, optional
@@ -1264,6 +1291,7 @@ def solve(
12641291 normalizecols = normalizecols ,
12651292 Opbasis = Opbasis ,
12661293 optimal_coeff = optimal_coeff ,
1294+ preallocate = preallocate ,
12671295 show = show ,
12681296 )
12691297 x : List [NDArray ] = []
0 commit comments