Skip to content

Commit 29f5847

Browse files
committed
feat: restore old behaviour for preallocate=False
When preallocate=False is selected, all solvers behave equivalently to the current solvers without performing any operation in-place.
1 parent feb05a2 commit 29f5847

4 files changed

Lines changed: 255 additions & 134 deletions

File tree

pylops/optimization/basesolver.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__all__ = ["Solver"]
22

33
import functools
4+
import logging
45
import time
56
from abc import ABCMeta, abstractmethod
67
from typing import TYPE_CHECKING, Any
@@ -125,36 +126,48 @@ def wrapper(*args, **kwargs):
125126
),
126127
)
127128

128-
# @abstractmethod
129-
# def memory_usage(
130-
# self,
131-
# show: bool = False,
132-
# unit: str = "B",
133-
# ) -> float:
134-
# """Compute memory usage of the solver
135-
136-
# This method computes an estimate of the memory required by the solver given
137-
# the shape of the operator. This is useful to assess upfront if the solver
138-
# will run out of memory.
139-
140-
# Note, that the memory usage of the operator itself is not taken into account
141-
# in this estimate.
142-
143-
# Parameters
144-
# ----------
145-
# show : :obj:`bool`, optional
146-
# Display memory usage
147-
# unit: :obj:`str`, optional
148-
# Unit used to display memory usage (
149-
# ``B``, ``KB``, ``MB`` or ``GB``)
150-
151-
# Returns
152-
# -------
153-
# memuse :obj:`float`
154-
# Memory usage in bytes
155-
156-
# """
157-
# pass
129+
def _setpreallocate(self, preallocate: bool) -> None:
130+
# Check if the solver can work in preallocate mode
131+
# (basically all the time except when JAX arrays are
132+
# used) and force it to be False otherwise.
133+
self.preallocate = preallocate if not self.isjax else False
134+
135+
if preallocate and self.isjax:
136+
logging.warning(
137+
"Preallocation is not supported for JAX arrays. "
138+
"Setting preallocate to False."
139+
)
140+
141+
@abstractmethod
142+
def memory_usage(
143+
self,
144+
show: bool = False,
145+
unit: str = "B",
146+
) -> float:
147+
"""Compute memory usage of the solver
148+
149+
This method computes an estimate of the memory required by the solver given
150+
the shape of the operator. This is useful to assess upfront if the solver
151+
will run out of memory.
152+
153+
Note, that the memory usage of the operator itself is not taken into account
154+
in this estimate.
155+
156+
Parameters
157+
----------
158+
show : :obj:`bool`, optional
159+
Display memory usage
160+
unit: :obj:`str`, optional
161+
Unit used to display memory usage (
162+
``B``, ``KB``, ``MB`` or ``GB``)
163+
164+
Returns
165+
-------
166+
memuse :obj:`float`
167+
Memory usage in bytes
168+
169+
"""
170+
pass
158171

159172
@abstractmethod
160173
def setup(

pylops/optimization/cls_basic.py

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,11 @@ def memory_usage(
8888
# Get number of bytes of dtype used in the solver
8989
nbytes = np.dtype(self.Op.dtype).itemsize
9090

91-
# Setup: x0, y, self.r, self.c
91+
# Setup: x0 - y, self.r, self.c
9292
memuse = (self.Op.shape[1] + 3 * self.Op.shape[0]) * nbytes
9393

94-
# Step (additional variables to those in setup): Opc, c1
95-
memuse += (2 * self.Op.shape[0]) * nbytes
94+
# Step (additional variables to those in setup): c1 - Opc
95+
memuse += (self.Op.shape[1] + self.Op.shape[0]) * nbytes
9696

9797
if show:
9898
print(f"CG predicted memory usage: {memuse / _units[unit]:.2f} {unit}")
@@ -125,7 +125,10 @@ def setup(
125125
preallocate : :obj:`bool`, optional
126126
.. versionadded:: 2.5.0
127127
128-
Pre-allocate all variables used by the solver
128+
Pre-allocate all variables used by the solver. Note that if ``y``
129+
is a JAX array, this option is ignored and variables are not
130+
pre-allocated since JAX does not support in-place operations.
131+
129132
show : :obj:`bool`, optional
130133
Display setup log
131134
@@ -138,17 +141,18 @@ def setup(
138141
self.y = y
139142
self.niter = niter
140143
self.tol = tol
141-
self.preallocate = preallocate
144+
142145
self.ncp = get_array_module(y)
143146
self.isjax = get_module_name(self.ncp) == "jax"
147+
self._setpreallocate(preallocate)
144148

145149
# initialize solver
146150
if x0 is None:
147151
x = self.ncp.zeros(self.Op.shape[1], dtype=self.y.dtype)
148152
self.r = self.y.copy()
149153
else:
150154
x = x0
151-
if self.isjax:
155+
if not self.preallocate:
152156
self.r = self.y - self.Op.matvec(x)
153157
else:
154158
self.r = self.ncp.empty_like(self.y)
@@ -186,23 +190,20 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
186190
Updated model vector
187191
188192
"""
189-
if not self.preallocate:
190-
c1 = self.ncp.empty_like(self.c)
191-
192193
Opc = self.Op.matvec(self.c)
193194
cOpc = self.ncp.abs(self.c.dot(Opc.conj()))
194195
a = self.kold / cOpc
195-
if self.isjax:
196+
if not self.preallocate:
196197
x += a * self.c
197198
self.r -= a * Opc
198199
else:
199-
self.ncp.multiply(self.c, a, out=self.c1 if self.preallocate else c1)
200-
self.ncp.add(x, self.c1 if self.preallocate else c1, out=x)
200+
self.ncp.multiply(self.c, a, out=self.c1)
201+
self.ncp.add(x, self.c1, out=x)
201202
self.ncp.multiply(Opc, a, out=Opc)
202203
self.ncp.subtract(self.r, Opc, out=self.r)
203204
k = self.ncp.abs(self.r.dot(self.r.conj()))
204205
b = k / self.kold
205-
if self.isjax:
206+
if not self.preallocate:
206207
self.c = self.r + b * self.c
207208
else:
208209
self.ncp.multiply(self.c, b, out=self.c)
@@ -401,11 +402,11 @@ def memory_usage(
401402
# Get number of bytes of dtype used in the solver
402403
nbytes = np.dtype(self.Op.dtype).itemsize
403404

404-
# Setup: x0, y, self.s, self.c, self.q
405+
# Setup: x0, self.c - y, self.s, self.q
405406
memuse = (2 * self.Op.shape[1] + 3 * self.Op.shape[0]) * nbytes
406407

407408
# Step (additional variables to those in setup): r, x1, c1
408-
memuse += (self.Op.shape[1] + 2 * self.Op.shape[0]) * nbytes
409+
memuse += (3 * self.Op.shape[1]) * nbytes
409410

410411
if show:
411412
print(f"CGLS predicted memory usage: {memuse / _units[unit]:.2f} {unit}")
@@ -455,9 +456,10 @@ def setup(
455456
self.damp = damp**2
456457
self.tol = tol
457458
self.niter = niter
458-
self.preallocate = preallocate
459+
459460
self.ncp = get_array_module(y)
460461
self.isjax = get_module_name(self.ncp) == "jax"
462+
self._setpreallocate(preallocate)
461463

462464
# initialize solver
463465
if x0 is None:
@@ -466,7 +468,7 @@ def setup(
466468
self.c = self.Op.rmatvec(self.s)
467469
else:
468470
x = x0.copy()
469-
if self.isjax:
471+
if not self.preallocate:
470472
self.s = self.y - self.Op.matvec(x)
471473
self.c = self.Op.rmatvec(self.s) - damp * x
472474
else:
@@ -512,40 +514,35 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
512514
Display iteration log
513515
514516
"""
515-
if not self.preallocate:
516-
c1 = self.ncp.empty_like(self.c)
517-
x1 = self.ncp.empty_like(x)
518-
r = self.ncp.empty_like(x)
519-
520517
a = self.kold / (
521518
self.q.dot(self.q.conj()) + self.damp * self.c.dot(self.c.conj())
522519
)
523-
if self.isjax:
520+
if not self.preallocate:
524521
x += a * self.c
525522
self.s = self.s - a * self.q
526523
r = self.Op.rmatvec(self.s) - self.damp * x
527524
else:
528-
self.ncp.multiply(self.c, a, out=self.c1 if self.preallocate else c1)
529-
self.ncp.add(x, self.c1 if self.preallocate else c1, out=x)
525+
self.ncp.multiply(self.c, a, out=self.c1)
526+
self.ncp.add(x, self.c1, out=x)
530527

531528
self.ncp.multiply(self.q, a, out=self.q)
532529
self.ncp.subtract(self.s, self.q, out=self.s)
533530

534-
self.ncp.multiply(x, self.damp, out=self.x1 if self.preallocate else x1)
531+
self.ncp.multiply(x, self.damp, out=self.x1)
535532
self.ncp.subtract(
536533
self.Op.rmatvec(self.s),
537-
self.x1 if self.preallocate else x1,
538-
out=self.r if self.preallocate else r,
534+
self.x1,
535+
out=self.r,
539536
)
540537
k = self.ncp.abs(
541538
self.r.dot(self.r.conj()) if self.preallocate else r.dot(r.conj())
542539
)
543540
b = k / self.kold
544-
if self.isjax:
541+
if not self.preallocate:
545542
self.c = r + b * self.c
546543
else:
547544
self.ncp.multiply(self.c, b, out=self.c)
548-
self.ncp.add(self.c, self.r if self.preallocate else r, out=self.c)
545+
self.ncp.add(self.c, self.r, out=self.c)
549546
self.q = self.Op.matvec(self.c)
550547
self.kold = k
551548
self.iiter += 1
@@ -818,7 +815,7 @@ def memory_usage(
818815
# Get number of bytes of dtype used in the solver
819816
nbytes = np.dtype(self.Op.dtype).itemsize
820817

821-
# Setup: x0, y, self.u, self.v, self.w, self.dk
818+
# Setup: x0, self.v, self.w, self.dk - y, self.u
822819
memuse = (4 * self.Op.shape[1] + 2 * self.Op.shape[0]) * nbytes
823820

824821
# Step (additional variables to those in setup): w1
@@ -890,9 +887,10 @@ def setup(
890887
self.conlim = conlim
891888
self.niter = niter
892889
self.calc_var = calc_var
893-
self.preallocate = preallocate
890+
894891
self.ncp = get_array_module(y)
895892
self.isjax = get_module_name(self.ncp) == "jax"
893+
self._setpreallocate(preallocate)
896894

897895
m, n = self.Op.shape
898896

@@ -924,22 +922,22 @@ def setup(
924922
self.u = y.copy()
925923
else:
926924
x = x0.copy()
927-
if self.isjax:
925+
if self.preallocate:
928926
self.u = self.y - self.Op.matvec(x0)
929927
else:
930928
self.u = self.ncp.empty_like(self.y)
931929
self.ncp.subtract(self.y, self.Op.matvec(x0), out=self.u)
932930
self.alfa = 0.0
933931
self.beta = self.ncp.linalg.norm(self.u)
934932
if self.beta > 0.0:
935-
if self.isjax:
933+
if self.preallocate:
936934
self.u = self.u / self.beta
937935
else:
938936
self.ncp.divide(self.u, self.beta, out=self.u)
939937
self.v = self.Op.rmatvec(self.u)
940938
self.alfa = self.ncp.linalg.norm(self.v)
941939
if self.alfa > 0:
942-
if self.isjax:
940+
if self.preallocate:
943941
self.v = self.v / self.alfa
944942
else:
945943
self.ncp.divide(self.v, self.alfa, out=self.v)
@@ -994,35 +992,32 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
994992
Estimated model of size :math:`[M \times 1]`
995993
996994
"""
997-
if not self.preallocate:
998-
w1 = self.ncp.empty_like(self.w)
999-
1000995
# perform the next step of the bidiagonalization to obtain the
1001996
# next beta, u, alfa, v. These satisfy the relations
1002997
# beta*u = Op*v - alfa*u,
1003998
# alfa*v = Op'*u - beta*v'
1004-
if self.isjax:
999+
if not self.preallocate:
10051000
self.u = self.Op.matvec(self.v) - self.alfa * self.u
10061001
else:
10071002
self.ncp.multiply(self.u, self.alfa, out=self.u)
10081003
self.ncp.subtract(self.Op.matvec(self.v), self.u, out=self.u)
10091004
self.beta = self.ncp.linalg.norm(self.u)
10101005
if self.beta > 0:
1011-
if self.isjax:
1006+
if not self.preallocate:
10121007
self.u = self.u / self.beta
10131008
else:
10141009
self.ncp.divide(self.u, self.beta, out=self.u)
10151010
self.anorm = np.linalg.norm(
10161011
[self.anorm, to_numpy(self.alfa), to_numpy(self.beta), self.damp]
10171012
)
1018-
if self.isjax:
1013+
if not self.preallocate:
10191014
self.v = self.Op.rmatvec(self.u) - self.beta * self.v
10201015
else:
10211016
self.ncp.multiply(self.v, self.beta, out=self.v)
10221017
self.ncp.subtract(self.Op.rmatvec(self.u), self.v, out=self.v)
10231018
self.alfa = self.ncp.linalg.norm(self.v)
10241019
if self.alfa > 0:
1025-
if self.isjax:
1020+
if not self.preallocate:
10261021
self.v = self.v / self.alfa
10271022
else:
10281023
self.ncp.divide(self.v, self.alfa, out=self.v)
@@ -1049,14 +1044,14 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
10491044
# update x and w.
10501045
self.t1 = self.phi / self.rho
10511046
self.t2 = -self.theta / self.rho
1052-
if self.isjax:
1047+
if not self.preallocate:
10531048
self.dk = self.w / self.rho
10541049
x = x + self.t1 * self.w
10551050
self.w = self.v + self.t2 * self.w
10561051
else:
10571052
self.ncp.divide(self.w, self.rho, out=self.dk)
1058-
self.ncp.multiply(self.w, self.t1, out=self.w1 if self.preallocate else w1)
1059-
self.ncp.add(x, self.w1 if self.preallocate else w1, out=x)
1053+
self.ncp.multiply(self.w, self.t1, out=self.w1)
1054+
self.ncp.add(x, self.w1, out=x)
10601055
self.ncp.multiply(self.w, self.t2, out=self.w)
10611056
self.ncp.add(self.v, self.w, out=self.w)
10621057
self.ddnorm = self.ddnorm + self.ncp.linalg.norm(self.dk) ** 2

0 commit comments

Comments
 (0)