Skip to content

Commit feb05a2

Browse files
committed
fix: correct internal behaviour of ISTA and FISTA with SOp
1 parent 3b86136 commit feb05a2

1 file changed

Lines changed: 46 additions & 17 deletions

File tree

pylops/optimization/cls_sparsity.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,6 +1618,10 @@ def setup(
16181618
self.grad = self.ncp.empty_like(x)
16191619
self.x_unthesh = self.ncp.empty_like(x)
16201620
self.xold = self.ncp.empty_like(x)
1621+
if self.SOp is not None:
1622+
self.SOpx_unthesh: NDArray = self.ncp.zeros(
1623+
self.SOp.shape[1], dtype=self.SOp.dtype
1624+
)
16211625

16221626
# create variable to track residual
16231627
if monitorres:
@@ -1657,6 +1661,10 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
16571661
res: NDArray = self.ncp.zeros_like(self.y)
16581662
grad: NDArray = self.ncp.zeros_like(x)
16591663
x_unthesh: NDArray = self.ncp.zeros_like(x)
1664+
if self.SOp is not None:
1665+
SOpx_unthesh: NDArray = self.ncp.zeros(
1666+
self.SOp.shape[1], dtype=self.SOp.dtype
1667+
)
16601668

16611669
# store old vector
16621670
if self.preallocate:
@@ -1703,20 +1711,29 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17031711
out=self.x_unthesh if self.preallocate else x_unthesh,
17041712
)
17051713

1714+
# apply SOp.H to current x
17061715
if self.SOp is not None:
17071716
if self.preallocate:
1708-
self.x_unthesh[:] = self.SOprmatvec(self.x_unthesh)
1717+
self.SOpx_unthesh[:] = self.SOprmatvec(self.x_unthesh)
17091718
else:
1710-
x_unthesh[:] = self.SOprmatvec(x_unthesh)
1711-
if self.perc is None and self.decay is not None:
1712-
x = self.threshf(
1713-
self.x_unthesh if self.preallocate else x_unthesh,
1714-
self.decay[self.iiter] * self.thresh,
1719+
SOpx_unthesh[:] = self.SOprmatvec(x_unthesh)
1720+
# threshold current solution or current solution projected onto SOp.H space
1721+
if self.SOp is None:
1722+
x_unthesh_or_SOpx_unthesh = (
1723+
self.x_unthesh if self.preallocate else x_unthesh
1724+
)
1725+
else:
1726+
x_unthesh_or_SOpx_unthesh = (
1727+
self.SOpx_unthesh if self.preallocate else SOpx_unthesh
17151728
)
1716-
elif self.perc is not None:
1729+
if self.perc is None:
17171730
x = self.threshf(
1718-
self.x_unthesh if self.preallocate else x_unthesh, 100 - self.perc
1731+
x_unthesh_or_SOpx_unthesh,
1732+
self.decay[self.iiter] * self.thresh,
17191733
)
1734+
else:
1735+
x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc)
1736+
# apply SOp to thresholded x
17201737
if self.SOp is not None:
17211738
x = self.SOpmatvec(x)
17221739

@@ -2022,6 +2039,10 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20222039
res: NDArray = self.ncp.zeros_like(self.y)
20232040
grad: NDArray = self.ncp.zeros_like(x)
20242041
x_unthesh: NDArray = self.ncp.zeros_like(x)
2042+
if self.SOp is not None:
2043+
SOpx_unthesh: NDArray = self.ncp.zeros(
2044+
self.SOp.shape[1], dtype=self.SOp.dtype
2045+
)
20252046

20262047
# store old vector
20272048
if self.preallocate:
@@ -2064,21 +2085,29 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20642085
out=self.x_unthesh if self.preallocate else x_unthesh,
20652086
)
20662087

2067-
# apply regularization operator
2088+
# apply SOp.H to current x
20682089
if self.SOp is not None:
20692090
if self.preallocate:
2070-
self.x_unthesh[:] = self.SOprmatvec(self.x_unthesh)
2091+
self.SOpx_unthesh[:] = self.SOprmatvec(self.x_unthesh)
20712092
else:
2072-
x_unthesh[:] = self.SOprmatvec(x_unthesh)
2073-
if self.perc is None and self.decay is not None:
2074-
x = self.threshf(
2075-
self.x_unthesh if self.preallocate else x_unthesh,
2076-
self.decay[self.iiter] * self.thresh,
2093+
SOpx_unthesh[:] = self.SOprmatvec(x_unthesh)
2094+
# threshold current solution or current solution projected onto SOp.H space
2095+
if self.SOp is None:
2096+
x_unthesh_or_SOpx_unthesh = (
2097+
self.x_unthesh if self.preallocate else x_unthesh
20772098
)
2078-
elif self.perc is not None:
2099+
else:
2100+
x_unthesh_or_SOpx_unthesh = (
2101+
self.SOpx_unthesh if self.preallocate else SOpx_unthesh
2102+
)
2103+
if self.perc is None:
20792104
x = self.threshf(
2080-
self.x_unthesh if self.preallocate else x_unthesh, 100 - self.perc
2105+
x_unthesh_or_SOpx_unthesh,
2106+
self.decay[self.iiter] * self.thresh,
20812107
)
2108+
else:
2109+
x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc)
2110+
# apply SOp to thresholded x
20822111
if self.SOp is not None:
20832112
x = self.SOpmatvec(x)
20842113

0 commit comments

Comments
 (0)