Skip to content

Commit e2becfd

Browse files
authored
Merge pull request #255 from mrava87/doc-pnptutorialextension
Doc: improved PnP tutorial
2 parents ecc9966 + a12c725 commit e2becfd

1 file changed

Lines changed: 106 additions & 16 deletions

File tree

tutorials/plugandplay.py

Lines changed: 106 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import numpy as np
2828
import pylops
2929
from pylops.config import set_ndarray_multiplication
30+
from pylops.utils.metrics import snr
3031

3132
import pyproximal
3233

@@ -72,7 +73,7 @@
7273

7374
###############################################################################
7475
# At this point we create a denoiser instance using the BM3D algorithm and use
75-
# as Plug-and-Play Prior to the PG and ADMM algorithms
76+
# as Plug-and-Play Prior to the ADMM, PG and HQS algorithms
7677

7778

7879
def callback(x, xtrue, errhist):
@@ -84,14 +85,31 @@ def callback(x, xtrue, errhist):
8485
tau = 1.0 / L
8586
sigma = 0.05
8687

87-
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)
88-
8988
# BM3D denoiser
9089
denoiser = lambda x, tau: bm3d.bm3d(
9190
np.real(x), sigma_psd=sigma * tau, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING
9291
)
9392

93+
# ADMM-PnP
94+
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)
95+
96+
errhistadmm = []
97+
xpnpadmm = pyproximal.optimization.pnp.PlugAndPlay(
98+
l2,
99+
denoiser,
100+
x.shape,
101+
solver=pyproximal.optimization.primal.ADMM,
102+
tau=tau,
103+
x0=np.zeros(x.size),
104+
niter=40,
105+
show=True,
106+
callback=lambda xx: callback(xx, x.ravel(), errhistadmm),
107+
)[0]
108+
xpnpadmm = np.real(xpnpadmm.reshape(x.shape))
109+
94110
# PG-Pnp
111+
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)
112+
95113
errhistpg = []
96114
xpnppg = pyproximal.optimization.pnp.PlugAndPlay(
97115
l2,
@@ -107,39 +125,111 @@ def callback(x, xtrue, errhist):
107125
)
108126
xpnppg = np.real(xpnppg.reshape(x.shape))
109127

110-
# ADMM-PnP
111-
errhistadmm = []
112-
xpnpadmm = pyproximal.optimization.pnp.PlugAndPlay(
128+
# HQS-PnP
129+
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)
130+
131+
tau_hqs = 1.0 / L * 0.99 ** (np.arange(40))
132+
errhisthqs = []
133+
xpnphqs = pyproximal.optimization.pnp.PlugAndPlay(
113134
l2,
114135
denoiser,
115136
x.shape,
137+
solver=pyproximal.optimization.primal.HQS,
138+
tau=tau_hqs,
139+
x0=np.zeros(x.size),
140+
niter=40,
141+
show=True,
142+
callback=lambda xx: callback(xx, x.ravel(), errhisthqs),
143+
)[0]
144+
xpnphqs = np.real(xpnphqs.reshape(x.shape))
145+
146+
fig, axs = plt.subplots(1, 4, sharey=True, figsize=(15, 5))
147+
axs[0].imshow(x, vmin=0, vmax=1, cmap="gray")
148+
axs[0].set_title("Model")
149+
axs[0].axis("tight")
150+
axs[1].imshow(xpnpadmm, vmin=0, vmax=1, cmap="gray")
151+
axs[1].set_title(f"ADMM-PnP (SNR={snr(x, xpnpadmm):.2f} dB)")
152+
axs[1].axis("tight")
153+
axs[2].imshow(xpnppg, vmin=0, vmax=1, cmap="gray")
154+
axs[2].set_title(f"PG-PnP (SNR={snr(x, xpnppg):.2f} dB)")
155+
axs[2].axis("tight")
156+
axs[3].imshow(xpnphqs, vmin=0, vmax=1, cmap="gray")
157+
axs[3].set_title(f"HQS-PnP (SNR={snr(x, xpnphqs):.2f} dB)")
158+
axs[3].axis("tight")
159+
plt.tight_layout()
160+
161+
###############################################################################
162+
# Finally, the attentive reader may have noticed that in the HQS server a
163+
# continuation strategy was used for the `tau` parameter; whilst this is
164+
# strictly needed for HQS to converge, there is a consensus in the literature
165+
# that also other solvers should benefit from adopting the same strategy
166+
# when used with a PnP prior. This can be in fact interpreted as reducing
167+
# the strength of the denoiser as iterations progress and the estimate comes
168+
# closer to the true solution.
169+
#
170+
# While our :func:`pyproximal.optimization.primal.ADMM` solver does currently
171+
# not offer relaxation out-of-the-box, this can be achieved pretty easily
172+
# by creating an auxiliary `Denoiser` class with a `decay` parameter as
173+
# shown below.
174+
175+
176+
class Denoiser:
177+
def __init__(self, sigma, decay):
178+
self.sigma = sigma
179+
self.decay = decay
180+
self.iiter = 0
181+
182+
def denoise(self, x, tau):
183+
xden = bm3d.bm3d(
184+
np.real(x),
185+
sigma_psd=self.decay[self.iiter] * self.sigma * tau,
186+
stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING,
187+
)
188+
self.iiter += 1
189+
return xden
190+
191+
192+
# ADMM-PnP with relaxation
193+
denoiser = Denoiser(sigma, decay=0.99 ** (np.arange(40)))
194+
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)
195+
196+
errhistadmm1 = []
197+
xpnpadmm1 = pyproximal.optimization.pnp.PlugAndPlay(
198+
l2,
199+
denoiser.denoise,
200+
x.shape,
116201
solver=pyproximal.optimization.primal.ADMM,
117202
tau=tau,
118203
x0=np.zeros(x.size),
119204
niter=40,
120205
show=True,
121-
callback=lambda xx: callback(xx, x.ravel(), errhistadmm),
206+
callback=lambda xx: callback(xx, x.ravel(), errhistadmm1),
122207
)[0]
123-
xpnpadmm = np.real(xpnpadmm.reshape(x.shape))
208+
xpnpadmm1 = np.real(xpnpadmm1.reshape(x.shape))
124209

125-
fig, axs = plt.subplots(1, 3, figsize=(14, 5))
210+
fig, axs = plt.subplots(1, 3, sharey=True, figsize=(15, 5))
126211
axs[0].imshow(x, vmin=0, vmax=1, cmap="gray")
127212
axs[0].set_title("Model")
128213
axs[0].axis("tight")
129-
axs[1].imshow(xpnppg, vmin=0, vmax=1, cmap="gray")
130-
axs[1].set_title("PG-PnP Inversion")
214+
axs[1].imshow(xpnpadmm, vmin=0, vmax=1, cmap="gray")
215+
axs[1].set_title(f"ADMM-PnP (SNR={snr(x, xpnpadmm):.2f} dB)")
131216
axs[1].axis("tight")
132-
axs[2].imshow(xpnpadmm, vmin=0, vmax=1, cmap="gray")
133-
axs[2].set_title("ADMM-PnP Inversion")
217+
axs[2].imshow(xpnpadmm1, vmin=0, vmax=1, cmap="gray")
218+
axs[2].set_title(f"ADMM-PnP with rel. (SNR={snr(x, xpnpadmm1):.2f} dB)")
134219
axs[2].axis("tight")
135220
plt.tight_layout()
136221

137222
###############################################################################
138-
# Finally, let's compare the error convergence of the two variations of PnP
223+
# Let's finally compare the error convergence of the four variations of PnP
139224

140225
plt.figure(figsize=(12, 3))
141-
plt.plot(errhistpg, "k", lw=2, label="PG")
142-
plt.plot(errhistadmm, "r", lw=2, label="ADMM")
226+
plt.semilogy(errhistadmm, "k", lw=2, label="ADMM")
227+
plt.semilogy(errhistpg, "r", lw=2, label="PG")
228+
plt.semilogy(errhisthqs, "b", lw=2, label="HQS")
229+
plt.semilogy(errhistadmm1, "--b", lw=2, label="ADMM with rel.")
143230
plt.title("Error norm")
144231
plt.legend()
145232
plt.tight_layout()
233+
234+
###############################################################################
235+
# This final results clearly shows the importance of relaxation also for ADMM.

0 commit comments

Comments
 (0)