Skip to content

Commit 81a1365

Browse files
committed
feat: added RED operator
1 parent e2becfd commit 81a1365

4 files changed

Lines changed: 193 additions & 1 deletion

File tree

docs/source/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ Non-Convex
100100
Log1
101101
QuadraticEnvelopeCard
102102
QuadraticEnvelopeCardIndicator
103+
RED
103104
RelaxedMumfordShah
104105
SCAD
105106

pyproximal/optimization/pnp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ class _Denoise(ProxOperator):
1515
denoiser : :obj:`func`
1616
Denoiser (must be a function with two inputs, the first is the signal
1717
to be denoised, the second is the `tau` constant of the y-update in
18-
the PnP optimization)
18+
the PnP optimization, which should be interpreted as the strenght of
19+
the denoiser)
1920
dims : :obj:`tuple`
2021
Dimensions used to reshape the vector ``x`` in the ``prox`` method
2122
prior to calling the ``denoiser``

pyproximal/proximal/RED.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
from collections.abc import Callable
2+
from typing import Any
3+
4+
from pylops.utils.backend import get_array_module
5+
from pylops.utils.typing import NDArray, ShapeLike
6+
from typing_extensions import Self
7+
8+
from pyproximal.proximal.L1 import _current_sigma
9+
from pyproximal.ProxOperator import ProxOperator, _check_tau
10+
from pyproximal.utils.typing import FloatCallableLike
11+
12+
13+
class _Denoise:
14+
r"""Denoiser of choice
15+
16+
Parameters
17+
----------
18+
denoiser : :obj:`func`
19+
Denoiser (must be a function with two inputs, the first is the signal
20+
to be denoised, the second is the strenght of the denoiser `sigma`)
21+
dims : :obj:`tuple`
22+
Dimensions used to reshape the vector ``x`` in the ``prox`` method
23+
prior to calling the ``denoiser``
24+
25+
"""
26+
27+
def __init__(
28+
self,
29+
denoiser: Callable[[NDArray, float], NDArray],
30+
dims: ShapeLike,
31+
) -> None:
32+
self.denoiser = denoiser
33+
self.dims = dims
34+
35+
def __call__(self, x: NDArray, tau: float) -> NDArray:
36+
x = x.reshape(self.dims)
37+
xden = self.denoiser(x, tau)
38+
return xden.ravel()
39+
40+
41+
class RED(ProxOperator):
42+
r"""Regularization by Denoising (RED)
43+
44+
Regularization by Denoising loss:
45+
:math:`RED(\mathbf{x}) = \sigma\mathbf{x}^T (\mathbf{x} -
46+
f_{\sigma_d}(\mathbf{x}))`
47+
48+
Parameters
49+
----------
50+
denoiser : :obj:`func`
51+
Denoiser (must be a function with one input corresponding to
52+
the signal to be denoised)
53+
dims : :obj:`tuple`
54+
Dimensions used to reshape the vector ``x`` in the ``prox`` method
55+
prior to calling the ``denoiser``
56+
sigma : :obj:`float`, optional
57+
Multiplicative coefficient of RED term
58+
sigmad : :obj:`float` or :obj:`numpy.ndarray` or :obj:`func`, optional
59+
Strenght of the denoiser. This can be a constant number or a function
60+
that is called passing a counter which keeps track of how many
61+
times the ``grad`` or ``prox`` methods has been invoked before and
62+
returns a scalar (or a list of) ``sigma`` to be used
63+
x0 : :obj:`numpy.ndarray`, optional
64+
Initial vector of iterative scheme used to compute the proximal
65+
niter : :obj:`int`, optional
66+
Number of iterations of iterative scheme used to compute the proximal
67+
warm : :obj:`bool`, optional
68+
Warm start (``True``) or not (``False``). Uses estimate from previous
69+
call of ``prox`` method.
70+
call : :obj:`bool`, optional
71+
Evalutate call method (``True``) or not (``False``)
72+
73+
Notes
74+
-----
75+
The gradient of the RED loss is defined as:
76+
77+
.. math::
78+
79+
\nabla_\mathbf{x} RED(\mathbf{x}) =
80+
\sigma (\mathbf{x} - f_{\sigma_d}(\mathbf{x}))
81+
82+
whilst the proximal operator is obtained by solving the
83+
minimization problem
84+
85+
.. math::
86+
87+
prox_{\tau RED} (\mathbf{x}) = \argmin_{\mathbf{y}} RED(\mathbf{y}) +
88+
\frac{1}{2 \tau}||\mathbf{y} - \mathbf{x}||^2_2
89+
90+
via the following fixed-point iteration:
91+
92+
.. math::
93+
94+
\mathbf{y}^k = \frac{1}{\beta + \sigma} (\sigma f_{\sigma_d}(\mathbf{y}^{k-1})
95+
+ \beta \mathbf{x})
96+
97+
where :math:`\beta=1/\tau`.
98+
99+
References
100+
----------
101+
.. [1] Romano, Y., Elad, M., and Milanfar, P.
102+
"The Little Engine that Could Regularization by
103+
Denoising (RED)", SIAM Journal on Imaging Science.
104+
2017.
105+
106+
"""
107+
108+
def __init__(
109+
self,
110+
denoiser: Callable[[NDArray, float], NDArray],
111+
dims: ShapeLike,
112+
sigma: float = 1.0,
113+
sigmad: FloatCallableLike = 1.0,
114+
x0: NDArray | None = None,
115+
niter: int = 10,
116+
warm: bool = True,
117+
call: bool = True,
118+
) -> None:
119+
super().__init__(None, False)
120+
121+
self.denoiser = _Denoise(denoiser, dims=dims)
122+
self.sigma = sigma
123+
self.sigmad = sigmad
124+
self.x0 = x0
125+
self.niter = niter
126+
self.warm = warm
127+
self.call = call
128+
self.count = 0
129+
130+
def __call__(self, x: NDArray) -> bool | float:
131+
"""Evaluate RED loss
132+
133+
Parameters
134+
----------
135+
x : :obj:`numpy.ndarray`
136+
Vector
137+
138+
Returns
139+
-------
140+
:obj:`float`
141+
- return ``0.0`` immediately if ``call=False``
142+
- return dot-product of the input and residual
143+
if ``call=True``
144+
"""
145+
if not self.call:
146+
return 0.0
147+
else:
148+
ncp = get_array_module(x)
149+
sigmad = _current_sigma(self.sigmad, self.count)
150+
res = self.sigma * (x - self.denoiser(x, sigmad))
151+
return float(ncp.dot(x, res))
152+
153+
def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
154+
"""Increment counter"""
155+
156+
def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
157+
self.count += 1
158+
return func(self, *args, **kwargs)
159+
160+
return wrapped
161+
162+
@_increment_count
163+
@_check_tau
164+
def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
165+
ncp = get_array_module(x)
166+
beta = 1.0 / tau
167+
sigmad = _current_sigma(self.sigmad, self.count)
168+
169+
# Define starting guess
170+
if self.x0 is None:
171+
sol = ncp.zeros_like(x)
172+
else:
173+
sol = self.x0
174+
175+
# Fixed point iterations
176+
for _ in range(self.niter):
177+
den = self.denoiser(sol, sigmad)
178+
sol = (self.sigma * den + beta * x) / (self.sigma + beta)
179+
if self.warm:
180+
self.x0 = sol
181+
return sol
182+
183+
@_increment_count
184+
def grad(self, x: NDArray) -> NDArray:
185+
sigmad = _current_sigma(self.sigmad, self.count)
186+
res = x - self.denoiser(x, sigmad)
187+
return self.sigma * res

pyproximal/proximal/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
HalfSpace Half space indicator
4444
GenericIntersectionProx Indicator of projection onto a union of given sets
4545
Sum Proximal operator of the sum of proximable functions
46+
RED Regularization by Denoising
4647
"""
4748

4849
from .Box import *
@@ -73,6 +74,7 @@
7374
from .HalfSpace import *
7475
from .GenericIntersection import *
7576
from .Sum import *
77+
from .RED import *
7678

7779

7880
__all__ = [
@@ -115,4 +117,5 @@
115117
"HalfSpace",
116118
"GenericIntersectionProx",
117119
"Sum",
120+
"RED",
118121
]

0 commit comments

Comments
 (0)