|
1 | 1 | __all__ = [ |
2 | 2 | "Callbacks", |
| 3 | + "CostNanInfCallback", |
| 4 | + "CostToDataCallback", |
| 5 | + "CostToInitialCallback", |
3 | 6 | "MetricsCallback", |
4 | | - "ResidualNormCallback", |
5 | 7 | ] |
6 | 8 |
|
7 | 9 | from typing import TYPE_CHECKING, Dict, List, Optional, Sequence |
8 | 10 |
|
| 11 | +import numpy as np |
| 12 | + |
9 | 13 | from pylops.utils.metrics import mae, mse, psnr, snr |
10 | 14 | from pylops.utils.typing import NDArray |
11 | 15 |
|
@@ -141,6 +145,83 @@ def on_run_end(self, solver: "Solver", x: NDArray) -> None: |
141 | 145 | pass |
142 | 146 |
|
143 | 147 |
|
| 148 | +class CostToDataCallback(Callbacks): |
| 149 | + """Cost to data callback |
| 150 | +
|
| 151 | + This callback can be used to stop the solver when the ``cost`` parameter |
| 152 | + of the solver is below a certain threshold defined as a percentage of the |
| 153 | + Euclidean norm of the data. |
| 154 | +
|
| 155 | + Note that the meaning of ``cost`` can change from solver to solver - e.g., |
| 156 | + it can represent the misfit of the data term or the total cost function. |
| 157 | +
|
| 158 | + Parameters |
| 159 | + ---------- |
| 160 | + rtol : :obj:`float` |
| 161 | + Percentage of the initial cost below which the solver |
| 162 | + will stop iterating. For example, if ``rtol`` is 0.1, the solver |
| 163 | + will stop when the cost is below 10% of the Euclidean norm of |
| 164 | + the data. |
| 165 | +
|
| 166 | + """ |
| 167 | + |
| 168 | + def __init__(self, rtol: float) -> None: |
| 169 | + self.rtol = rtol |
| 170 | + self.stop = False |
| 171 | + |
| 172 | + def on_setup_end(self, solver: "Solver", x: NDArray) -> None: |
| 173 | + self.ynorm = solver.ncp.linalg.norm(solver.y) |
| 174 | + |
| 175 | + def on_step_end(self, solver: "Solver", x: NDArray) -> None: |
| 176 | + if solver.cost[-1] < self.rtol * self.ynorm: |
| 177 | + self.stop = True |
| 178 | + |
| 179 | + |
| 180 | +class CostToInitialCallback(Callbacks): |
| 181 | + """Cost to initial callback |
| 182 | +
|
| 183 | + This callback can be used to stop the solver when the ``cost`` |
| 184 | + parameter of the solver is below a certain threshold defined as a |
| 185 | + percentage of the initial residual norm. |
| 186 | +
|
| 187 | + Note that the meaning of ``cost`` can change from solver to solver - e.g., |
| 188 | + it can represent the misfit of the data term or the total cost function. |
| 189 | +
|
| 190 | + Parameters |
| 191 | + ---------- |
| 192 | + rtol : :obj:`float` |
| 193 | + Percentage of the initial cost below which the solver |
| 194 | + will stop iterating. For example, if ``rtol`` is 0.1, the solver |
| 195 | + will stop when the cost is below 10% of the initial |
| 196 | + cost. |
| 197 | +
|
| 198 | + """ |
| 199 | + |
| 200 | + def __init__(self, rtol: float) -> None: |
| 201 | + self.rtol = rtol |
| 202 | + self.stop = False |
| 203 | + |
| 204 | + def on_step_end(self, solver: "Solver", x: NDArray) -> None: |
| 205 | + if solver.cost[-1] < self.rtol * solver.cost[0]: |
| 206 | + self.stop = True |
| 207 | + |
| 208 | + |
| 209 | +class CostNanInfCallback(Callbacks): |
| 210 | + """Cost Nan/Inf callback |
| 211 | +
|
| 212 | + This callback can be used to stop the solver when the ``cost`` |
| 213 | + becomes either ``np.nan`` or ``np.inf`` |
| 214 | +
|
| 215 | + """ |
| 216 | + |
| 217 | + def __init__(self) -> None: |
| 218 | + self.stop = False |
| 219 | + |
| 220 | + def on_step_end(self, solver: "Solver", x: NDArray) -> None: |
| 221 | + if np.isnan(solver.cost[-1]) or np.isinf(solver.cost[-1]): |
| 222 | + self.stop = True |
| 223 | + |
| 224 | + |
144 | 225 | class MetricsCallback(Callbacks): |
145 | 226 | r"""Metrics callback |
146 | 227 |
|
@@ -191,32 +272,6 @@ def on_step_end(self, solver: "Solver", x: NDArray) -> None: |
191 | 272 | self.metrics["psnr"].append(psnr(self.xtrue, x)) |
192 | 273 |
|
193 | 274 |
|
194 | | -class ResidualNormCallback(Callbacks): |
195 | | - """Residual norm callback |
196 | | -
|
197 | | - This callback can be used to stop the solver when the residual norm |
198 | | - is below a certain threshold defined as a percentage of the |
199 | | - initial residual norm. |
200 | | -
|
201 | | - Parameters |
202 | | - ---------- |
203 | | - rtol : :obj:`float` |
204 | | - Percentage of the initial residual norm below which the solver |
205 | | - will stop iterating. For example, if `rtol` is 0.1, the solver |
206 | | - will stop when the residual norm is below 10% of the initial |
207 | | - residual norm. |
208 | | -
|
209 | | - """ |
210 | | - |
211 | | - def __init__(self, rtol: float) -> None: |
212 | | - self.rtol = rtol |
213 | | - self.stop = False |
214 | | - |
215 | | - def on_step_end(self, solver: "Solver", x: NDArray) -> None: |
216 | | - if solver.cost[-1] < self.rtol * solver.cost[0]: |
217 | | - self.stop = True |
218 | | - |
219 | | - |
220 | 275 | def _callback_stop(callbacks: Sequence[Callbacks]) -> bool: |
221 | 276 | """Check if any callback has raised a stop flag |
222 | 277 |
|
|
0 commit comments