@@ -1720,19 +1720,22 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17201720 if self .SOp is not None :
17211721 x = self .SOpmatvec (x )
17221722
1723- # model update
1723+ # check model update
17241724 if not self .preallocate :
17251725 xupdate = np .linalg .norm (x - xold )
17261726 else :
17271727 self .ncp .subtract (
17281728 x ,
1729- self .xold if self . preallocate else xold ,
1730- out = self .xold if self . preallocate else xold ,
1729+ self .xold ,
1730+ out = self .xold ,
17311731 )
1732- xupdate = np .linalg .norm (self .xold if self .preallocate else xold )
1732+ xupdate = np .linalg .norm (self .xold )
1733+
1734+ # cost functions
17331735 costdata = 0.5 * np .linalg .norm (self .res if self .preallocate else res ) ** 2
17341736 costreg = self .eps * np .linalg .norm (x , ord = 1 )
17351737 self .cost .append (float (costdata + costreg ))
1738+
17361739 self .iiter += 1
17371740 if show :
17381741 self ._print_step (x , costdata , costreg , xupdate )
@@ -2028,9 +2031,7 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20282031 if not self .preallocate :
20292032 res : NDArray = self .y - self .Opmatvec (z )
20302033 else :
2031- self .ncp .subtract (
2032- self .y , self .Opmatvec (z ), out = self .res if self .preallocate else res
2033- )
2034+ self .ncp .subtract (self .y , self .Opmatvec (z ), out = self .res )
20342035
20352036 if self .monitorres :
20362037 self .normres = np .linalg .norm (self .res if self .preallocate else res )
@@ -2049,14 +2050,14 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20492050 x_unthesh : NDArray = z + grad
20502051 else :
20512052 self .ncp .multiply (
2052- self .Oprmatvec (self .res if self . preallocate else res ),
2053+ self .Oprmatvec (self .res ),
20532054 self .alpha ,
2054- out = self .grad if self . preallocate else grad ,
2055+ out = self .grad ,
20552056 )
20562057 self .ncp .add (
20572058 z ,
2058- self .grad if self . preallocate else grad ,
2059- out = self .x_unthesh if self . preallocate else x_unthesh ,
2059+ self .grad ,
2060+ out = self .x_unthesh ,
20602061 )
20612062
20622063 # apply SOp.H to current x
@@ -2065,6 +2066,7 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20652066 self .SOpx_unthesh [:] = self .SOprmatvec (self .x_unthesh )
20662067 else :
20672068 SOpx_unthesh = self .SOprmatvec (x_unthesh )
2069+
20682070 # threshold current solution or current solution projected onto SOp.H space
20692071 if self .SOp is None :
20702072 x_unthesh_or_SOpx_unthesh = (
@@ -2081,6 +2083,7 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20812083 )
20822084 else :
20832085 x = self .threshf (x_unthesh_or_SOpx_unthesh , 100 - self .perc )
2086+
20842087 # apply SOp to thresholded x
20852088 if self .SOp is not None :
20862089 x = self .SOpmatvec (x )
@@ -2095,20 +2098,25 @@ def step(self, x: NDArray, z: NDArray, show: bool = False) -> NDArray:
20952098 else :
20962099 self .ncp .subtract (
20972100 x ,
2098- self .xold if self .preallocate else xold ,
2099- out = self .xold if self .preallocate else xold ,
2100- )
2101- self .ncp .multiply (
2102- self .xold if self .preallocate else xold , ((told - 1.0 ) / self .t ), out = z
2101+ self .xold ,
2102+ out = self .xold ,
21032103 )
2104+ self .ncp .multiply (self .xold , ((told - 1.0 ) / self .t ), out = z )
21042105 self .ncp .add (x , z , out = z )
21052106
2106- # xupdate = np.linalg.norm(x - xold)
2107- xupdate = np .linalg .norm (self .xold if self .preallocate else xold )
2107+ # check model update
2108+ if not self .preallocate :
2109+ xupdate = np .linalg .norm (x - xold )
2110+ else :
2111+ # note that x - xold has been already computed as part of the
2112+ # intermediate calculation of x in model update step
2113+ xupdate = np .linalg .norm (self .xold )
21082114
2115+ # cost functions
21092116 costdata = 0.5 * np .linalg .norm (self .y - self .Op @ x ) ** 2
21102117 costreg = self .eps * np .linalg .norm (x , ord = 1 )
21112118 self .cost .append (float (costdata + costreg ))
2119+
21122120 self .iiter += 1
21132121 if show :
21142122 self ._print_step (x , costdata , costreg , xupdate )
@@ -2229,6 +2237,13 @@ def _print_finalize(self) -> None:
22292237 print (f"\n Total time (s) = { self .telapsed :.2f} " )
22302238 print ("-" * 80 + "\n " )
22312239
2240+ def memory_usage (
2241+ self ,
2242+ show : bool = False ,
2243+ unit : str = "B" ,
2244+ ) -> float :
2245+ pass
2246+
22322247 def setup (
22332248 self ,
22342249 y : NDArray ,
0 commit comments