@@ -1618,6 +1618,10 @@ def setup(
16181618 self .grad = self .ncp .empty_like (x )
16191619 self .x_unthesh = self .ncp .empty_like (x )
16201620 self .xold = self .ncp .empty_like (x )
1621+ if self .SOp is not None :
1622+ self .SOpx_unthesh : NDArray = self .ncp .zeros (
1623+ self .SOp .shape [1 ], dtype = self .SOp .dtype
1624+ )
16211625
16221626 # create variable to track residual
16231627 if monitorres :
@@ -1657,6 +1661,10 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
16571661 res : NDArray = self .ncp .zeros_like (self .y )
16581662 grad : NDArray = self .ncp .zeros_like (x )
16591663 x_unthesh : NDArray = self .ncp .zeros_like (x )
1664+ if self .SOp is not None :
1665+ SOpx_unthesh : NDArray = self .ncp .zeros (
1666+ self .SOp .shape [1 ], dtype = self .SOp .dtype
1667+ )
16601668
16611669 # store old vector
16621670 if self .preallocate :
@@ -1703,20 +1711,29 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17031711 out = self .x_unthesh if self .preallocate else x_unthesh ,
17041712 )
17051713
1714+ # apply SOp.H to current x
17061715 if self .SOp is not None :
17071716 if self .preallocate :
1708- self .x_unthesh [:] = self .SOprmatvec (self .x_unthesh )
1717+ self .SOpx_unthesh [:] = self .SOprmatvec (self .x_unthesh )
17091718 else :
1710- x_unthesh [:] = self .SOprmatvec (x_unthesh )
1711- if self .perc is None and self .decay is not None :
1712- x = self .threshf (
1713- self .x_unthesh if self .preallocate else x_unthesh ,
1714- self .decay [self .iiter ] * self .thresh ,
1719+ SOpx_unthesh [:] = self .SOprmatvec (x_unthesh )
1720+ # threshold current solution or current solution projected onto SOp.H space
1721+ if self .SOp is None :
1722+ x_unthesh_or_SOpx_unthesh = (
1723+ self .x_unthesh if self .preallocate else x_unthesh
1724+ )
1725+ else :
1726+ x_unthesh_or_SOpx_unthesh = (
1727+ self .SOpx_unthesh if self .preallocate else SOpx_unthesh
17151728 )
1716- elif self .perc is not None :
1729+ if self .perc is None :
17171730 x = self .threshf (
1718- self .x_unthesh if self .preallocate else x_unthesh , 100 - self .perc
1731+ x_unthesh_or_SOpx_unthesh ,
1732+ self .decay [self .iiter ] * self .thresh ,
17191733 )
1734+ else :
1735+ x = self .threshf (x_unthesh_or_SOpx_unthesh , 100 - self .perc )
1736+ # apply SOp to thresholded x
17201737 if self .SOp is not None :
17211738 x = self .SOpmatvec (x )
17221739
@@ -2022,6 +2039,10 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20222039 res : NDArray = self .ncp .zeros_like (self .y )
20232040 grad : NDArray = self .ncp .zeros_like (x )
20242041 x_unthesh : NDArray = self .ncp .zeros_like (x )
2042+ if self .SOp is not None :
2043+ SOpx_unthesh : NDArray = self .ncp .zeros (
2044+ self .SOp .shape [1 ], dtype = self .SOp .dtype
2045+ )
20252046
20262047 # store old vector
20272048 if self .preallocate :
@@ -2064,21 +2085,29 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20642085 out = self .x_unthesh if self .preallocate else x_unthesh ,
20652086 )
20662087
2067- # apply regularization operator
2088+ # apply SOp.H to current x
20682089 if self .SOp is not None :
20692090 if self .preallocate :
2070- self .x_unthesh [:] = self .SOprmatvec (self .x_unthesh )
2091+ self .SOpx_unthesh [:] = self .SOprmatvec (self .x_unthesh )
20712092 else :
2072- x_unthesh [:] = self .SOprmatvec (x_unthesh )
2073- if self . perc is None and self . decay is not None :
2074- x = self .threshf (
2075- self . x_unthesh if self . preallocate else x_unthesh ,
2076- self .decay [ self . iiter ] * self .thresh ,
2093+ SOpx_unthesh [:] = self .SOprmatvec (x_unthesh )
2094+ # threshold current solution or current solution projected onto SOp.H space
2095+ if self .SOp is None :
2096+ x_unthesh_or_SOpx_unthesh = (
2097+ self .x_unthesh if self .preallocate else x_unthesh
20772098 )
2078- elif self .perc is not None :
2099+ else :
2100+ x_unthesh_or_SOpx_unthesh = (
2101+ self .SOpx_unthesh if self .preallocate else SOpx_unthesh
2102+ )
2103+ if self .perc is None :
20792104 x = self .threshf (
2080- self .x_unthesh if self .preallocate else x_unthesh , 100 - self .perc
2105+ x_unthesh_or_SOpx_unthesh ,
2106+ self .decay [self .iiter ] * self .thresh ,
20812107 )
2108+ else :
2109+ x = self .threshf (x_unthesh_or_SOpx_unthesh , 100 - self .perc )
2110+ # apply SOp to thresholded x
20822111 if self .SOp is not None :
20832112 x = self .SOpmatvec (x )
20842113
0 commit comments