|
| 1 | +from collections.abc import Callable |
| 2 | +from typing import Any |
| 3 | + |
| 4 | +from pylops.utils.backend import get_array_module |
| 5 | +from pylops.utils.typing import NDArray, ShapeLike |
| 6 | +from typing_extensions import Self |
| 7 | + |
| 8 | +from pyproximal.proximal.L1 import _current_sigma |
| 9 | +from pyproximal.ProxOperator import ProxOperator, _check_tau |
| 10 | +from pyproximal.utils.typing import FloatCallableLike |
| 11 | + |
| 12 | + |
| 13 | +class _Denoise: |
| 14 | + r"""Denoiser of choice |
| 15 | +
|
| 16 | + Parameters |
| 17 | + ---------- |
| 18 | + denoiser : :obj:`func` |
| 19 | + Denoiser (must be a function with two inputs, the first is the signal |
| 20 | + to be denoised, the second is the strenght of the denoiser `sigma`) |
| 21 | + dims : :obj:`tuple` |
| 22 | + Dimensions used to reshape the vector ``x`` in the ``prox`` method |
| 23 | + prior to calling the ``denoiser`` |
| 24 | +
|
| 25 | + """ |
| 26 | + |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + denoiser: Callable[[NDArray, float], NDArray], |
| 30 | + dims: ShapeLike, |
| 31 | + ) -> None: |
| 32 | + self.denoiser = denoiser |
| 33 | + self.dims = dims |
| 34 | + |
| 35 | + def __call__(self, x: NDArray, tau: float) -> NDArray: |
| 36 | + x = x.reshape(self.dims) |
| 37 | + xden = self.denoiser(x, tau) |
| 38 | + return xden.ravel() |
| 39 | + |
| 40 | + |
| 41 | +class RED(ProxOperator): |
| 42 | + r"""Regularization by Denoising (RED) |
| 43 | +
|
| 44 | + Regularization by Denoising loss: |
| 45 | + :math:`RED(\mathbf{x}) = \sigma\mathbf{x}^T (\mathbf{x} - |
| 46 | + f_{\sigma_d}(\mathbf{x}))` |
| 47 | +
|
| 48 | + Parameters |
| 49 | + ---------- |
| 50 | + denoiser : :obj:`func` |
| 51 | + Denoiser (must be a function with one input corresponding to |
| 52 | + the signal to be denoised) |
| 53 | + dims : :obj:`tuple` |
| 54 | + Dimensions used to reshape the vector ``x`` in the ``prox`` method |
| 55 | + prior to calling the ``denoiser`` |
| 56 | + sigma : :obj:`float`, optional |
| 57 | + Multiplicative coefficient of RED term |
| 58 | + sigmad : :obj:`float` or :obj:`numpy.ndarray` or :obj:`func`, optional |
| 59 | + Strenght of the denoiser. This can be a constant number or a function |
| 60 | + that is called passing a counter which keeps track of how many |
| 61 | + times the ``grad`` or ``prox`` methods has been invoked before and |
| 62 | + returns a scalar (or a list of) ``sigma`` to be used |
| 63 | + x0 : :obj:`numpy.ndarray`, optional |
| 64 | + Initial vector of iterative scheme used to compute the proximal |
| 65 | + niter : :obj:`int`, optional |
| 66 | + Number of iterations of iterative scheme used to compute the proximal |
| 67 | + warm : :obj:`bool`, optional |
| 68 | + Warm start (``True``) or not (``False``). Uses estimate from previous |
| 69 | + call of ``prox`` method. |
| 70 | + call : :obj:`bool`, optional |
| 71 | + Evalutate call method (``True``) or not (``False``) |
| 72 | +
|
| 73 | + Notes |
| 74 | + ----- |
| 75 | + The gradient of the RED loss is defined as: |
| 76 | +
|
| 77 | + .. math:: |
| 78 | +
|
| 79 | + \nabla_\mathbf{x} RED(\mathbf{x}) = |
| 80 | + \sigma (\mathbf{x} - f_{\sigma_d}(\mathbf{x})) |
| 81 | +
|
| 82 | + whilst the proximal operator is obtained by solving the |
| 83 | + minimization problem |
| 84 | +
|
| 85 | + .. math:: |
| 86 | +
|
| 87 | + prox_{\tau RED} (\mathbf{x}) = \argmin_{\mathbf{y}} RED(\mathbf{y}) + |
| 88 | + \frac{1}{2 \tau}||\mathbf{y} - \mathbf{x}||^2_2 |
| 89 | +
|
| 90 | + via the following fixed-point iteration: |
| 91 | +
|
| 92 | + .. math:: |
| 93 | +
|
| 94 | + \mathbf{y}^k = \frac{1}{\beta + \sigma} (\sigma f_{\sigma_d}(\mathbf{y}^{k-1}) |
| 95 | + + \beta \mathbf{x}) |
| 96 | +
|
| 97 | + where :math:`\beta=1/\tau`. |
| 98 | +
|
| 99 | + References |
| 100 | + ---------- |
| 101 | + .. [1] Romano, Y., Elad, M., and Milanfar, P. |
| 102 | + "The Little Engine that Could Regularization by |
| 103 | + Denoising (RED)", SIAM Journal on Imaging Science. |
| 104 | + 2017. |
| 105 | +
|
| 106 | + """ |
| 107 | + |
| 108 | + def __init__( |
| 109 | + self, |
| 110 | + denoiser: Callable[[NDArray, float], NDArray], |
| 111 | + dims: ShapeLike, |
| 112 | + sigma: float = 1.0, |
| 113 | + sigmad: FloatCallableLike = 1.0, |
| 114 | + x0: NDArray | None = None, |
| 115 | + niter: int = 10, |
| 116 | + warm: bool = True, |
| 117 | + call: bool = True, |
| 118 | + ) -> None: |
| 119 | + super().__init__(None, False) |
| 120 | + |
| 121 | + self.denoiser = _Denoise(denoiser, dims=dims) |
| 122 | + self.sigma = sigma |
| 123 | + self.sigmad = sigmad |
| 124 | + self.x0 = x0 |
| 125 | + self.niter = niter |
| 126 | + self.warm = warm |
| 127 | + self.call = call |
| 128 | + self.count = 0 |
| 129 | + |
| 130 | + def __call__(self, x: NDArray) -> bool | float: |
| 131 | + """Evaluate RED loss |
| 132 | +
|
| 133 | + Parameters |
| 134 | + ---------- |
| 135 | + x : :obj:`numpy.ndarray` |
| 136 | + Vector |
| 137 | +
|
| 138 | + Returns |
| 139 | + ------- |
| 140 | + :obj:`float` |
| 141 | + - return ``0.0`` immediately if ``call=False`` |
| 142 | + - return dot-product of the input and residual |
| 143 | + if ``call=True`` |
| 144 | + """ |
| 145 | + if not self.call: |
| 146 | + return 0.0 |
| 147 | + else: |
| 148 | + ncp = get_array_module(x) |
| 149 | + sigmad = _current_sigma(self.sigmad, self.count) |
| 150 | + res = self.sigma * (x - self.denoiser(x, sigmad)) |
| 151 | + return float(ncp.dot(x, res)) |
| 152 | + |
| 153 | + def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]: |
| 154 | + """Increment counter""" |
| 155 | + |
| 156 | + def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any: |
| 157 | + self.count += 1 |
| 158 | + return func(self, *args, **kwargs) |
| 159 | + |
| 160 | + return wrapped |
| 161 | + |
| 162 | + @_increment_count |
| 163 | + @_check_tau |
| 164 | + def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray: |
| 165 | + ncp = get_array_module(x) |
| 166 | + beta = 1.0 / tau |
| 167 | + sigmad = _current_sigma(self.sigmad, self.count) |
| 168 | + |
| 169 | + # Define starting guess |
| 170 | + if self.x0 is None: |
| 171 | + sol = ncp.zeros_like(x) |
| 172 | + else: |
| 173 | + sol = self.x0 |
| 174 | + |
| 175 | + # Fixed point iterations |
| 176 | + for _ in range(self.niter): |
| 177 | + den = self.denoiser(sol, sigmad) |
| 178 | + sol = (self.sigma * den + beta * x) / (self.sigma + beta) |
| 179 | + if self.warm: |
| 180 | + self.x0 = sol |
| 181 | + return sol |
| 182 | + |
| 183 | + @_increment_count |
| 184 | + def grad(self, x: NDArray) -> NDArray: |
| 185 | + sigmad = _current_sigma(self.sigmad, self.count) |
| 186 | + res = x - self.denoiser(x, sigmad) |
| 187 | + return self.sigma * res |
0 commit comments