Skip to content

Commit ab187f7

Browse files
committed
doc: added red tutorial
1 parent 81a1365 commit ab187f7

2 files changed

Lines changed: 223 additions & 2 deletions

File tree

pyproximal/proximal/RED.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ class RED(ProxOperator):
5151
Denoiser (must be a function with one input corresponding to
5252
the signal to be denoised)
5353
dims : :obj:`tuple`
54-
Dimensions used to reshape the vector ``x`` in the ``prox`` method
55-
prior to calling the ``denoiser``
54+
Dimensions used to reshape the vector ``x`` in the ``denoiser``
55+
method prior to applying the denoiser
5656
sigma : :obj:`float`, optional
5757
Multiplicative coefficient of RED term
5858
sigmad : :obj:`float` or :obj:`numpy.ndarray` or :obj:`func`, optional

tutorials/red.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
r"""
2+
Regularization by Denoising (RED)
3+
=================================
4+
This is a follow up tutorial to the :ref:`sphx_glr_tutorials_plugandplay.py` tutorial,
5+
showcasing an competitive technical of the famous Plug-and-Play method called
6+
Regularization by Denoising (RED).
7+
8+
The Plug-and-Play algorithm leverges a user-defined denoiser in place of the proximal
9+
operator of the regularization term in the solution of an inverse problem, ultimately
10+
acting as an implicit prior; RED, instead, defines an the following
11+
explicit regularization term
12+
13+
.. math::
14+
RED(\mathbf{x}) = \sigma\mathbf{x}^T (\mathbf{x} - f_{\sigma_d}(\mathbf{x}))
15+
16+
where the dot-product of the sought after model and residual from the action of
17+
the denoiser is minimized.
18+
19+
Let's consider again a simplified MRI experiment, where the
20+
data is created by appling a 2D Fourier Transform to the input model and
21+
by randomly sampling 60% of its values, and the
22+
`BM3D <https://pypi.org/project/bm3d>`_ method as the denoiser of choice.
23+
24+
Two different solvers will be compared, namely:
25+
26+
- Gradient descent, which simply uses the gradient of the data misfit term and that
27+
of the (now well defined and differentiable) regularization term;
28+
- ADMM, where the proximal of RED is solved using a fixed-point iteration.
29+
- Fixed-point method.
30+
31+
"""
32+
33+
import bm3d
34+
import matplotlib.pyplot as plt
35+
import numpy as np
36+
import pylops
37+
from pylops.config import set_ndarray_multiplication
38+
from pylops.utils.metrics import snr
39+
from scipy.sparse.linalg import lsqr
40+
41+
import pyproximal
42+
43+
plt.close("all")
44+
np.random.seed(0)
45+
set_ndarray_multiplication(False)
46+
47+
48+
###############################################################################
49+
# Let's first write a simple gradient descent solver and a fixed-point solver
50+
def GradientDescent(f, g, x0, xtrue, alpha=1.0, niter=100):
51+
x = x0.copy()
52+
errhist = []
53+
for _ in range(niter):
54+
grad = f.grad(x).real + g.grad(x)
55+
x -= alpha * grad
56+
errhist.append(np.linalg.norm(x - xtrue))
57+
return x, errhist
58+
59+
60+
def FixedPoint(Op, y, denoiser, x0, xtrue, sigma, sigmad, niter=100, niter_inner=10):
61+
x = x0.copy()
62+
yy = Op.H @ y
63+
sigmad = sigmad * np.ones(niter) if isinstance(sigmad, float) else sigmad
64+
errhist = []
65+
for i in range(niter):
66+
xden = denoiser(x, sigmad(i))
67+
Op1 = Op1 = sigma * pylops.Identity(Op.shape[1], dtype=Op.dtype) + Op.H * Op
68+
y1 = yy + sigma * xden
69+
x = x = lsqr(Op1, y1, iter_lim=niter_inner, x0=x)[0]
70+
errhist.append(np.linalg.norm(x - xtrue))
71+
return x, errhist
72+
73+
74+
###############################################################################
75+
# We start by loading the famous Shepp logan phantom and creating the
76+
# modelling operator
77+
x = np.load("../testdata/shepp_logan_phantom.npy")
78+
x = x / x.max()
79+
ny, nx = x.shape
80+
81+
perc_subsampling = 0.6
82+
nxsub = int(np.round(ny * nx * perc_subsampling))
83+
iava = np.sort(np.random.permutation(np.arange(ny * nx))[:nxsub])
84+
Rop = pylops.Restriction(ny * nx, iava, dtype=np.complex128)
85+
Fop = pylops.signalprocessing.FFT2D(dims=(ny, nx))
86+
87+
###############################################################################
88+
# We now create and display the data alongside the model
89+
y = Rop * Fop * x.ravel()
90+
yfft = Fop * x.ravel()
91+
yfft = np.fft.fftshift(yfft.reshape(ny, nx))
92+
93+
ymask = Rop.mask(Fop * x.ravel())
94+
ymask = ymask.reshape(ny, nx)
95+
ymask.data[:] = np.fft.fftshift(ymask.data)
96+
ymask.mask[:] = np.fft.fftshift(ymask.mask)
97+
98+
fig, axs = plt.subplots(1, 3, figsize=(14, 5))
99+
axs[0].imshow(x, vmin=0, vmax=1, cmap="gray")
100+
axs[0].set_title("Model")
101+
axs[0].axis("tight")
102+
axs[1].imshow(np.abs(yfft), vmin=0, vmax=1, cmap="rainbow")
103+
axs[1].set_title("Full data")
104+
axs[1].axis("tight")
105+
axs[2].imshow(np.abs(ymask), vmin=0, vmax=1, cmap="rainbow")
106+
axs[2].set_title("Sampled data")
107+
axs[2].axis("tight")
108+
plt.tight_layout()
109+
110+
###############################################################################
111+
# At this point we create a denoiser instance using the BM3D algorithm and use
112+
# the gradient descent solver that we wrote at the start
113+
114+
115+
def sigmad(iiter):
116+
return 0.1 * 0.99**iiter
117+
118+
119+
# BM3D denoiser
120+
denoiser = lambda x, sigma: bm3d.bm3d(
121+
np.real(x), sigma_psd=sigma, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING
122+
)
123+
124+
l2 = pyproximal.proximal.L2(Op=Rop * Fop, b=y.ravel())
125+
red = pyproximal.proximal.RED(denoiser, x.shape, sigma=0.4, sigmad=sigmad, call=False)
126+
127+
xredgd, errhistgd = GradientDescent(
128+
l2,
129+
red,
130+
x0=np.zeros(x.size),
131+
xtrue=x.ravel(),
132+
alpha=0.5,
133+
niter=50,
134+
)
135+
xredgd = np.real(xredgd.reshape(x.shape))
136+
137+
###############################################################################
138+
# And now we use the ADMM solver
139+
140+
141+
def callback(x, xtrue, errhist):
142+
errhist.append(np.linalg.norm(x - xtrue))
143+
144+
145+
Op = Rop * Fop
146+
L = np.real((Op.H * Op).eigs(neigs=1, which="LM")[0])
147+
tau = 1.0 / L
148+
149+
# BM3D denoiser
150+
denoiser = lambda x, sigma: bm3d.bm3d(
151+
np.real(x), sigma_psd=sigma, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING
152+
)
153+
154+
# ADMM-RED
155+
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=10, warm=True)
156+
red = pyproximal.proximal.RED(
157+
denoiser, x.shape, sigma=0.4, sigmad=sigmad, niter=5, warm=True, call=False
158+
)
159+
160+
errhistadmm = []
161+
xredadmm = pyproximal.optimization.pnp.ADMM(
162+
l2,
163+
red,
164+
tau=1.0,
165+
x0=np.zeros(x.size),
166+
niter=50,
167+
show=True,
168+
callback=lambda xx: callback(xx, x.ravel(), errhistadmm),
169+
)[0]
170+
xredadmm = np.real(xredadmm.reshape(x.shape))
171+
172+
###############################################################################
173+
# And finally we use the Fixed-Point solver
174+
175+
# BM3D
176+
xshape = x.shape
177+
den = lambda x, sigma: bm3d.bm3d(
178+
x.real.reshape(xshape), sigma_psd=sigma, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING
179+
).ravel()
180+
181+
# FP-RED
182+
xredfp, errhistfp = FixedPoint(
183+
Rop * Fop,
184+
y.ravel(),
185+
den,
186+
x0=np.zeros(x.size),
187+
xtrue=x.ravel(),
188+
sigma=0.4,
189+
sigmad=sigmad,
190+
niter=50,
191+
niter_inner=10,
192+
)
193+
xredfp = np.real(xredfp.reshape(x.shape))
194+
195+
###############################################################################
196+
# Let's finally compare the results and the error convergence of the three
197+
# variations of RED
198+
199+
fig, axs = plt.subplots(1, 4, sharey=True, figsize=(15, 5))
200+
axs[0].imshow(x, vmin=0, vmax=1, cmap="gray")
201+
axs[0].set_title("Model")
202+
axs[0].axis("tight")
203+
axs[1].imshow(xredgd, vmin=0, vmax=1, cmap="gray")
204+
axs[1].set_title(f"GD-RED (SNR={snr(x, xredgd):.2f} dB)")
205+
axs[1].axis("tight")
206+
axs[2].imshow(xredadmm, vmin=0, vmax=1, cmap="gray")
207+
axs[2].set_title(f"ADMM-RED (SNR={snr(x, xredadmm):.2f} dB)")
208+
axs[2].axis("tight")
209+
axs[3].imshow(xredfp, vmin=0, vmax=1, cmap="gray")
210+
axs[3].set_title(f"FP-RED (SNR={snr(x, xredfp):.2f} dB)")
211+
axs[3].axis("tight")
212+
plt.tight_layout()
213+
214+
plt.figure(figsize=(12, 3))
215+
plt.semilogy(errhistgd, "k", lw=2, label="GD")
216+
plt.semilogy(errhistadmm, "r", lw=2, label="ADMM")
217+
plt.semilogy(errhistfp, "b", lw=2, label="FP")
218+
219+
plt.title("Error norm")
220+
plt.legend()
221+
plt.tight_layout()

0 commit comments

Comments
 (0)