Skip to content

Commit a35540b

Browse files
committed
bug: fix handling of xupdate in ista/fista for preallocate case
1 parent 5ec9c78 commit a35540b

1 file changed

Lines changed: 33 additions & 18 deletions

File tree

pylops/optimization/cls_sparsity.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,19 +1720,22 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17201720
if self.SOp is not None:
17211721
x = self.SOpmatvec(x)
17221722

1723-
# model update
1723+
# check model update
17241724
if not self.preallocate:
17251725
xupdate = np.linalg.norm(x - xold)
17261726
else:
17271727
self.ncp.subtract(
17281728
x,
1729-
self.xold if self.preallocate else xold,
1730-
out=self.xold if self.preallocate else xold,
1729+
self.xold,
1730+
out=self.xold,
17311731
)
1732-
xupdate = np.linalg.norm(self.xold if self.preallocate else xold)
1732+
xupdate = np.linalg.norm(self.xold)
1733+
1734+
# cost functions
17331735
costdata = 0.5 * np.linalg.norm(self.res if self.preallocate else res) ** 2
17341736
costreg = self.eps * np.linalg.norm(x, ord=1)
17351737
self.cost.append(float(costdata + costreg))
1738+
17361739
self.iiter += 1
17371740
if show:
17381741
self._print_step(x, costdata, costreg, xupdate)
@@ -2028,9 +2031,7 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20282031
if not self.preallocate:
20292032
res: NDArray = self.y - self.Opmatvec(z)
20302033
else:
2031-
self.ncp.subtract(
2032-
self.y, self.Opmatvec(z), out=self.res if self.preallocate else res
2033-
)
2034+
self.ncp.subtract(self.y, self.Opmatvec(z), out=self.res)
20342035

20352036
if self.monitorres:
20362037
self.normres = np.linalg.norm(self.res if self.preallocate else res)
@@ -2049,14 +2050,14 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20492050
x_unthesh: NDArray = z + grad
20502051
else:
20512052
self.ncp.multiply(
2052-
self.Oprmatvec(self.res if self.preallocate else res),
2053+
self.Oprmatvec(self.res),
20532054
self.alpha,
2054-
out=self.grad if self.preallocate else grad,
2055+
out=self.grad,
20552056
)
20562057
self.ncp.add(
20572058
z,
2058-
self.grad if self.preallocate else grad,
2059-
out=self.x_unthesh if self.preallocate else x_unthesh,
2059+
self.grad,
2060+
out=self.x_unthesh,
20602061
)
20612062

20622063
# apply SOp.H to current x
@@ -2065,6 +2066,7 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20652066
self.SOpx_unthesh[:] = self.SOprmatvec(self.x_unthesh)
20662067
else:
20672068
SOpx_unthesh = self.SOprmatvec(x_unthesh)
2069+
20682070
# threshold current solution or current solution projected onto SOp.H space
20692071
if self.SOp is None:
20702072
x_unthesh_or_SOpx_unthesh = (
@@ -2081,6 +2083,7 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20812083
)
20822084
else:
20832085
x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc)
2086+
20842087
# apply SOp to thresholded x
20852088
if self.SOp is not None:
20862089
x = self.SOpmatvec(x)
@@ -2095,20 +2098,25 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20952098
else:
20962099
self.ncp.subtract(
20972100
x,
2098-
self.xold if self.preallocate else xold,
2099-
out=self.xold if self.preallocate else xold,
2100-
)
2101-
self.ncp.multiply(
2102-
self.xold if self.preallocate else xold, ((told - 1.0) / self.t), out=z
2101+
self.xold,
2102+
out=self.xold,
21032103
)
2104+
self.ncp.multiply(self.xold, ((told - 1.0) / self.t), out=z)
21042105
self.ncp.add(x, z, out=z)
21052106

2106-
# xupdate = np.linalg.norm(x - xold)
2107-
xupdate = np.linalg.norm(self.xold if self.preallocate else xold)
2107+
# check model update
2108+
if not self.preallocate:
2109+
xupdate = np.linalg.norm(x - xold)
2110+
else:
2111+
# note that x - xold has been already computed as part of the
2112+
# intermediate calculation of x in model update step
2113+
xupdate = np.linalg.norm(self.xold)
21082114

2115+
# cost functions
21092116
costdata = 0.5 * np.linalg.norm(self.y - self.Op @ x) ** 2
21102117
costreg = self.eps * np.linalg.norm(x, ord=1)
21112118
self.cost.append(float(costdata + costreg))
2119+
21122120
self.iiter += 1
21132121
if show:
21142122
self._print_step(x, costdata, costreg, xupdate)
@@ -2229,6 +2237,13 @@ def _print_finalize(self) -> None:
22292237
print(f"\nTotal time (s) = {self.telapsed:.2f}")
22302238
print("-" * 80 + "\n")
22312239

2240+
def memory_usage(
2241+
self,
2242+
show: bool = False,
2243+
unit: str = "B",
2244+
) -> float:
2245+
pass
2246+
22322247
def setup(
22332248
self,
22342249
y: NDArray,

0 commit comments

Comments
 (0)