Skip to content

Commit 8b0fdfd

Browse files
committed
doc: added tests and documentation about rtol in solvers
1 parent dbfe2e5 commit 8b0fdfd

3 files changed

Lines changed: 348 additions & 9 deletions

File tree

docs/source/addingsolver.rst

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ Implementing new solvers
44
========================
55
Users are welcome to create new solvers and add them to the PyLops library.
66

7-
In this tutorial, we will go through the key steps in the definition of a solver, using the
8-
:py:class:`pylops.optimization.basic.CG` as an example.
7+
In this tutorial, we will go through the key steps in the definition of a solver, using a
8+
sligthly simplified version of :py:class:`pylops.optimization.basic.CG` as an example.
99

1010
.. note::
1111
In case the solver that you are planning to create falls within the category of proximal solvers,
@@ -134,7 +134,26 @@ can add additional input parameters. For CG, the step is:
134134
135135
136136
Similarly, we also implement a ``run`` method that is in charge of running a number of iterations by repeatedly
137-
calling the ``step`` method. It is also usually convenient to implement a finalize method; this method can do any required post-processing that should
137+
calling the ``step`` method.
138+
139+
.. code-block:: python
140+
141+
def run(self, x, niter, show, itershow):
142+
while self.iiter < niter and self.kold > self.tol:
143+
x = self.step(x, showstep)
144+
self.callback(x)
145+
# check if any callback has raised a stop flag
146+
stop = _callback_stop(self.callbacks)
147+
if stop:
148+
break
149+
return x
150+
151+
It is worth noting that any number of callbacks can be attached to the solver; some of these
152+
callbacks can implement a stopping criterion and set the ``stop`` member to True when a given
153+
condition is met. The ``_callback_stop`` method is in change of checking if any of the callbacks
154+
has set ``stop`` to True and in the case break the iterations.
155+
156+
Finally, it is also usually convenient to implement a ``finalize`` method; this method can do any required post-processing that should
138157
not be applied at the end of each step, rather at the end of the entire optimization process. For CG, this is as simple
139158
as converting the ``cost`` variable from a list to a ``numpy`` array. For more details, see our implementations for CG.
140159

@@ -169,8 +188,10 @@ input and returns some of the most valuable properties of the class-based solver
169188

170189
.. code-block:: python
171190
172-
def cg(Op, y, x0, niter=10, tol=1e-4, show=False, itershow=(10, 10, 10), callback=None):
173-
cgsolve = CG(Op)
191+
def cg(Op, y, x0, niter=10, tol=1e-4, rtol=0.0,
192+
show=False, itershow=(10, 10, 10), callback=None):
193+
rcallback = ResidualNormCallback(rtol)
194+
cgsolve = CG(Op, callbacks=[rcallback, ])
174195
if callback is not None:
175196
cgsolve.callback = callback
176197
x, iiter, cost = cgsolve.solve(

pytests/test_solver.py

Lines changed: 214 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,38 @@ def test_cg_forceflat(par):
164164
assert_array_almost_equal(x.ravel(), xinv, decimal=4)
165165

166166

167+
@pytest.mark.parametrize(
168+
"par", [(par1), (par2), (par3), (par4), (par1j), (par2j), (par3j), (par3j)]
169+
)
170+
def test_cg_stopping(par):
171+
"""CG testing stopping criterion rtol"""
172+
np.random.seed(10)
173+
174+
A = np.random.normal(0, 10, (par["ny"], par["nx"])) + par[
175+
"imag"
176+
] * np.random.normal(0, 10, (par["ny"], par["nx"]))
177+
A = np.conj(A).T @ A # to ensure definite positive matrix
178+
Aop = MatrixMult(A, dtype=par["dtype"])
179+
180+
x = np.ones(par["nx"]) + par["imag"] * np.ones(par["nx"])
181+
if par["x0"]:
182+
x0 = np.random.normal(0, 10, par["nx"]) + par["imag"] * np.random.normal(
183+
0, 10, par["nx"]
184+
)
185+
else:
186+
x0 = None
187+
188+
y = Aop * x
189+
190+
for preallocate in [False, True]:
191+
rtol = 1e-2
192+
_, _, cost = cg(
193+
Aop, y, x0=x0, niter=par["nx"], tol=0, rtol=rtol, preallocate=preallocate
194+
)
195+
assert cost[-2] / cost[0] >= rtol
196+
assert cost[-1] / cost[0] < rtol
197+
198+
167199
@pytest.mark.parametrize(
168200
"par", [(par1), (par2), (par3), (par4), (par1j), (par2j), (par3j), (par3j)]
169201
)
@@ -193,14 +225,105 @@ def test_cgls(par):
193225
assert_array_almost_equal(x, xinv, decimal=4)
194226

195227

228+
@pytest.mark.parametrize(
229+
"par", [(par1), (par2), (par3), (par4), (par1j), (par2j), (par3j), (par3j)]
230+
)
231+
def test_cgls_ndarray(par):
232+
"""CGLS with linear operator (and ndarray as input and output)"""
233+
np.random.seed(10)
234+
235+
dims = dimsd = (par["nx"], par["ny"])
236+
x = np.ones(dims) + par["imag"] * np.ones(dims)
237+
238+
A = np.random.normal(0, 10, (x.size, x.size)) + par["imag"] * np.random.normal(
239+
0, 10, (x.size, x.size)
240+
)
241+
Aop = MatrixMult(A, dtype=par["dtype"])
242+
Aop.dims = dims
243+
Aop.dimsd = dimsd
244+
245+
if par["x0"]:
246+
x0 = np.random.normal(0, 10, dims) + par["imag"] * np.random.normal(0, 10, dims)
247+
else:
248+
x0 = None
249+
250+
y = Aop * x
251+
252+
for preallocate in [False, True]:
253+
xinv = cgls(Aop, y, x0=x0, niter=2 * x.size, tol=0, preallocate=preallocate)[0]
254+
assert xinv.shape == x.shape
255+
assert_array_almost_equal(x, xinv, decimal=4)
256+
257+
258+
@pytest.mark.parametrize(
259+
"par", [(par1), (par2), (par3), (par4), (par1j), (par2j), (par3j), (par3j)]
260+
)
261+
def test_cgls_forceflat(par):
262+
"""CGLS with linear operator (and forced 1darray as input and output)"""
263+
np.random.seed(10)
264+
265+
dims = dimsd = (par["nx"], par["ny"])
266+
x = np.ones(dims) + par["imag"] * np.ones(dims)
267+
268+
A = np.random.normal(0, 10, (x.size, x.size)) + par["imag"] * np.random.normal(
269+
0, 10, (x.size, x.size)
270+
)
271+
Aop = MatrixMult(A, dtype=par["dtype"], forceflat=True)
272+
Aop.dims = dims
273+
Aop.dimsd = dimsd
274+
275+
if par["x0"]:
276+
x0 = np.random.normal(0, 10, dims) + par["imag"] * np.random.normal(0, 10, dims)
277+
else:
278+
x0 = None
279+
280+
y = Aop * x
281+
282+
for preallocate in [False, True]:
283+
xinv = cgls(Aop, y, x0=x0, niter=2 * x.size, tol=0, preallocate=preallocate)[0]
284+
assert xinv.shape == x.ravel().shape
285+
assert_array_almost_equal(x.ravel(), xinv, decimal=4)
286+
287+
288+
@pytest.mark.parametrize(
289+
"par", [(par1), (par2), (par3), (par4), (par1j), (par2j), (par3j), (par3j)]
290+
)
291+
def test_cgls_stopping(par):
292+
"""CGLS testing stopping criterion rtol"""
293+
np.random.seed(10)
294+
295+
A = np.random.normal(0, 10, (par["ny"], par["nx"])) + par[
296+
"imag"
297+
] * np.random.normal(0, 10, (par["ny"], par["nx"]))
298+
Aop = MatrixMult(A, dtype=par["dtype"])
299+
300+
x = np.ones(par["nx"]) + par["imag"] * np.ones(par["nx"])
301+
if par["x0"]:
302+
x0 = np.random.normal(0, 10, par["nx"]) + par["imag"] * np.random.normal(
303+
0, 10, par["nx"]
304+
)
305+
else:
306+
x0 = None
307+
308+
y = Aop * x
309+
310+
for preallocate in [False, True]:
311+
rtol = 1e-2
312+
cost = cgls(
313+
Aop, y, x0=x0, niter=par["nx"], tol=0, rtol=rtol, preallocate=preallocate
314+
)[-1]
315+
assert cost[-2] / cost[0] >= rtol
316+
assert cost[-1] / cost[0] < rtol
317+
318+
196319
@pytest.mark.skipif(
197320
int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled"
198321
)
199322
@pytest.mark.parametrize(
200323
"par", [(par1), (par2), (par3), (par4), (par1j), (par2j), (par3j), (par3j)]
201324
)
202-
def test_lsqr(par):
203-
"""Compare local Pylops and scipy LSQR"""
325+
def test_lsqr_pylops_scipy(par):
326+
"""Compare Pylops and scipy LSQR"""
204327
np.random.seed(10)
205328

206329
A = np.random.normal(0, 10, (par["ny"], par["nx"])) + par[
@@ -230,3 +353,92 @@ def test_lsqr(par):
230353

231354
assert_array_almost_equal(xinv, x, decimal=4)
232355
assert_array_almost_equal(xinv_sp, x, decimal=4)
356+
357+
358+
@pytest.mark.parametrize(
359+
"par", [(par1), (par2), (par3), (par4), (par1j), (par2j), (par3j), (par3j)]
360+
)
361+
def test_lsqr(par):
362+
"""LSQR with linear operator"""
363+
np.random.seed(10)
364+
365+
A = np.random.normal(0, 10, (par["ny"], par["nx"])) + par[
366+
"imag"
367+
] * np.random.normal(0, 10, (par["ny"], par["nx"]))
368+
Aop = MatrixMult(A, dtype=par["dtype"])
369+
370+
x = np.ones(par["nx"]) + par["imag"] * np.ones(par["nx"])
371+
if par["x0"]:
372+
x0 = np.random.normal(0, 10, par["nx"]) + par["imag"] * np.random.normal(
373+
0, 10, par["nx"]
374+
)
375+
else:
376+
x0 = None
377+
378+
y = Aop * x
379+
380+
for preallocate in [False, True]:
381+
xinv = lsqr(Aop, y, x0=x0, niter=par["nx"], atol=1e-5, preallocate=preallocate)[
382+
0
383+
]
384+
assert_array_almost_equal(x, xinv, decimal=4)
385+
386+
387+
@pytest.mark.parametrize(
388+
"par", [(par1), (par2), (par3), (par4), (par1j), (par2j), (par3j), (par3j)]
389+
)
390+
def test_lsqr_ndarray(par):
391+
"""LSQR with linear operator (and ndarray as input and output)"""
392+
np.random.seed(10)
393+
394+
dims = dimsd = (par["nx"], par["ny"])
395+
x = np.ones(dims) + par["imag"] * np.ones(dims)
396+
397+
A = np.random.normal(0, 10, (x.size, x.size)) + par["imag"] * np.random.normal(
398+
0, 10, (x.size, x.size)
399+
)
400+
Aop = MatrixMult(A, dtype=par["dtype"])
401+
Aop.dims = dims
402+
Aop.dimsd = dimsd
403+
404+
if par["x0"]:
405+
x0 = np.random.normal(0, 10, dims) + par["imag"] * np.random.normal(0, 10, dims)
406+
else:
407+
x0 = None
408+
409+
y = Aop * x
410+
411+
for preallocate in [False, True]:
412+
xinv = lsqr(Aop, y, x0=x0, niter=2 * x.size, atol=0, preallocate=preallocate)[0]
413+
assert xinv.shape == x.shape
414+
assert_array_almost_equal(x, xinv, decimal=4)
415+
416+
417+
@pytest.mark.parametrize(
418+
"par", [(par1), (par2), (par3), (par4), (par1j), (par2j), (par3j), (par3j)]
419+
)
420+
def test_lsqr_forceflat(par):
421+
"""LSQR with linear operator (and forced 1darray as input and output)"""
422+
np.random.seed(10)
423+
424+
dims = dimsd = (par["nx"], par["ny"])
425+
x = np.ones(dims) + par["imag"] * np.ones(dims)
426+
427+
A = np.random.normal(0, 10, (x.size, x.size)) + par["imag"] * np.random.normal(
428+
0, 10, (x.size, x.size)
429+
)
430+
Aop = MatrixMult(A, dtype=par["dtype"], forceflat=True)
431+
Aop.dims = dims
432+
Aop.dimsd = dimsd
433+
434+
if par["x0"]:
435+
x0 = np.random.normal(0, 10, dims) + par["imag"] * np.random.normal(0, 10, dims)
436+
else:
437+
x0 = None
438+
439+
y = Aop * x
440+
441+
for preallocate in [False, True]:
442+
xinv = lsqr(Aop, y, x0=x0, niter=2 * x.size, atol=0, preallocate=preallocate)[0]
443+
assert xinv.shape == x.ravel().shape
444+
assert_array_almost_equal(x.ravel(), xinv, decimal=4)

0 commit comments

Comments
 (0)