2727import numpy as np
2828import pylops
2929from pylops .config import set_ndarray_multiplication
30+ from pylops .utils .metrics import snr
3031
3132import pyproximal
3233
7273
7374###############################################################################
7475# At this point we create a denoiser instance using the BM3D algorithm and use
75- # as Plug-and-Play Prior to the PG and ADMM algorithms
76+ # as Plug-and-Play Prior to the ADMM, PG and HQS algorithms
7677
7778
7879def callback (x , xtrue , errhist ):
@@ -84,14 +85,31 @@ def callback(x, xtrue, errhist):
8485tau = 1.0 / L
8586sigma = 0.05
8687
87- l2 = pyproximal .proximal .L2 (Op = Op , b = y .ravel (), niter = 50 , warm = True )
88-
8988# BM3D denoiser
9089denoiser = lambda x , tau : bm3d .bm3d (
9190 np .real (x ), sigma_psd = sigma * tau , stage_arg = bm3d .BM3DStages .HARD_THRESHOLDING
9291)
9392
93+ # ADMM-PnP
94+ l2 = pyproximal .proximal .L2 (Op = Op , b = y .ravel (), niter = 50 , warm = True )
95+
96+ errhistadmm = []
97+ xpnpadmm = pyproximal .optimization .pnp .PlugAndPlay (
98+ l2 ,
99+ denoiser ,
100+ x .shape ,
101+ solver = pyproximal .optimization .primal .ADMM ,
102+ tau = tau ,
103+ x0 = np .zeros (x .size ),
104+ niter = 40 ,
105+ show = True ,
106+ callback = lambda xx : callback (xx , x .ravel (), errhistadmm ),
107+ )[0 ]
108+ xpnpadmm = np .real (xpnpadmm .reshape (x .shape ))
109+
94110# PG-Pnp
111+ l2 = pyproximal .proximal .L2 (Op = Op , b = y .ravel (), niter = 50 , warm = True )
112+
95113errhistpg = []
96114xpnppg = pyproximal .optimization .pnp .PlugAndPlay (
97115 l2 ,
@@ -107,39 +125,111 @@ def callback(x, xtrue, errhist):
107125)
108126xpnppg = np .real (xpnppg .reshape (x .shape ))
109127
110- # ADMM-PnP
111- errhistadmm = []
112- xpnpadmm = pyproximal .optimization .pnp .PlugAndPlay (
128+ # HQS-PnP
129+ l2 = pyproximal .proximal .L2 (Op = Op , b = y .ravel (), niter = 50 , warm = True )
130+
131+ tau_hqs = 1.0 / L * 0.99 ** (np .arange (40 ))
132+ errhisthqs = []
133+ xpnphqs = pyproximal .optimization .pnp .PlugAndPlay (
113134 l2 ,
114135 denoiser ,
115136 x .shape ,
137+ solver = pyproximal .optimization .primal .HQS ,
138+ tau = tau_hqs ,
139+ x0 = np .zeros (x .size ),
140+ niter = 40 ,
141+ show = True ,
142+ callback = lambda xx : callback (xx , x .ravel (), errhisthqs ),
143+ )[0 ]
144+ xpnphqs = np .real (xpnphqs .reshape (x .shape ))
145+
146+ fig , axs = plt .subplots (1 , 4 , sharey = True , figsize = (15 , 5 ))
147+ axs [0 ].imshow (x , vmin = 0 , vmax = 1 , cmap = "gray" )
148+ axs [0 ].set_title ("Model" )
149+ axs [0 ].axis ("tight" )
150+ axs [1 ].imshow (xpnpadmm , vmin = 0 , vmax = 1 , cmap = "gray" )
151+ axs [1 ].set_title (f"ADMM-PnP (SNR={ snr (x , xpnpadmm ):.2f} dB)" )
152+ axs [1 ].axis ("tight" )
153+ axs [2 ].imshow (xpnppg , vmin = 0 , vmax = 1 , cmap = "gray" )
154+ axs [2 ].set_title (f"PG-PnP (SNR={ snr (x , xpnppg ):.2f} dB)" )
155+ axs [2 ].axis ("tight" )
156+ axs [3 ].imshow (xpnphqs , vmin = 0 , vmax = 1 , cmap = "gray" )
157+ axs [3 ].set_title (f"HQS-PnP (SNR={ snr (x , xpnphqs ):.2f} dB)" )
158+ axs [3 ].axis ("tight" )
159+ plt .tight_layout ()
160+
161+ ###############################################################################
162+ # Finally, the attentive reader may have noticed that in the HQS server a
163+ # continuation strategy was used for the `tau` parameter; whilst this is
164+ # strictly needed for HQS to converge, there is a consensus in the literature
165+ # that also other solvers should benefit from adopting the same strategy
166+ # when used with a PnP prior. This can be in fact interpreted as reducing
167+ # the strength of the denoiser as iterations progress and the estimate comes
168+ # closer to the true solution.
169+ #
170+ # While our :func:`pyproximal.optimization.primal.ADMM` solver does currently
171+ # not offer relaxation out-of-the-box, this can be achieved pretty easily
172+ # by creating an auxiliary `Denoiser` class with a `decay` parameter as
173+ # shown below.
174+
175+
176+ class Denoiser :
177+ def __init__ (self , sigma , decay ):
178+ self .sigma = sigma
179+ self .decay = decay
180+ self .iiter = 0
181+
182+ def denoise (self , x , tau ):
183+ xden = bm3d .bm3d (
184+ np .real (x ),
185+ sigma_psd = self .decay [self .iiter ] * self .sigma * tau ,
186+ stage_arg = bm3d .BM3DStages .HARD_THRESHOLDING ,
187+ )
188+ self .iiter += 1
189+ return xden
190+
191+
192+ # ADMM-PnP with relaxation
193+ denoiser = Denoiser (sigma , decay = 0.99 ** (np .arange (40 )))
194+ l2 = pyproximal .proximal .L2 (Op = Op , b = y .ravel (), niter = 50 , warm = True )
195+
196+ errhistadmm1 = []
197+ xpnpadmm1 = pyproximal .optimization .pnp .PlugAndPlay (
198+ l2 ,
199+ denoiser .denoise ,
200+ x .shape ,
116201 solver = pyproximal .optimization .primal .ADMM ,
117202 tau = tau ,
118203 x0 = np .zeros (x .size ),
119204 niter = 40 ,
120205 show = True ,
121- callback = lambda xx : callback (xx , x .ravel (), errhistadmm ),
206+ callback = lambda xx : callback (xx , x .ravel (), errhistadmm1 ),
122207)[0 ]
123- xpnpadmm = np .real (xpnpadmm .reshape (x .shape ))
208+ xpnpadmm1 = np .real (xpnpadmm1 .reshape (x .shape ))
124209
125- fig , axs = plt .subplots (1 , 3 , figsize = (14 , 5 ))
210+ fig , axs = plt .subplots (1 , 3 , sharey = True , figsize = (15 , 5 ))
126211axs [0 ].imshow (x , vmin = 0 , vmax = 1 , cmap = "gray" )
127212axs [0 ].set_title ("Model" )
128213axs [0 ].axis ("tight" )
129- axs [1 ].imshow (xpnppg , vmin = 0 , vmax = 1 , cmap = "gray" )
130- axs [1 ].set_title ("PG -PnP Inversion " )
214+ axs [1 ].imshow (xpnpadmm , vmin = 0 , vmax = 1 , cmap = "gray" )
215+ axs [1 ].set_title (f"ADMM -PnP (SNR= { snr ( x , xpnpadmm ):.2f } dB) " )
131216axs [1 ].axis ("tight" )
132- axs [2 ].imshow (xpnpadmm , vmin = 0 , vmax = 1 , cmap = "gray" )
133- axs [2 ].set_title ("ADMM-PnP Inversion " )
217+ axs [2 ].imshow (xpnpadmm1 , vmin = 0 , vmax = 1 , cmap = "gray" )
218+ axs [2 ].set_title (f "ADMM-PnP with rel. (SNR= { snr ( x , xpnpadmm1 ):.2f } dB) " )
134219axs [2 ].axis ("tight" )
135220plt .tight_layout ()
136221
137222###############################################################################
138- # Finally, let 's compare the error convergence of the two variations of PnP
223+ # Let 's finally compare the error convergence of the four variations of PnP
139224
140225plt .figure (figsize = (12 , 3 ))
141- plt .plot (errhistpg , "k" , lw = 2 , label = "PG" )
142- plt .plot (errhistadmm , "r" , lw = 2 , label = "ADMM" )
226+ plt .semilogy (errhistadmm , "k" , lw = 2 , label = "ADMM" )
227+ plt .semilogy (errhistpg , "r" , lw = 2 , label = "PG" )
228+ plt .semilogy (errhisthqs , "b" , lw = 2 , label = "HQS" )
229+ plt .semilogy (errhistadmm1 , "--b" , lw = 2 , label = "ADMM with rel." )
143230plt .title ("Error norm" )
144231plt .legend ()
145232plt .tight_layout ()
233+
234+ ###############################################################################
235+ # This final results clearly shows the importance of relaxation also for ADMM.
0 commit comments