Skip to content

Commit 014158a

Browse files
committed
feat: added HQS tests
1 parent cd4f78d commit 014158a

1 file changed

Lines changed: 52 additions & 0 deletions

File tree

pytests/test_solver.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from numpy.testing import assert_array_almost_equal
66
from pylops.basicoperators import Identity, MatrixMult
7+
from pylops.optimization.leastsquares import regularized_inversion
78
from pylops.optimization.sparsity import fista, ista
89

910
from pyproximal.optimization.primal import (
@@ -274,6 +275,57 @@ def test_PG_GPG(par):
274275
assert_array_almost_equal(xpg, xgpg, decimal=2)
275276

276277

278+
@pytest.mark.parametrize("par", [(par1), (par2)])
279+
def test_HQS_ADMM_L2(par):
280+
"""Check that HQS/ADMM can be used to solved a pure L2-based objective function
281+
(and compare with LSQR - note that despite the trajectory will be different,
282+
they should converge to the same solution)
283+
"""
284+
np.random.seed(0)
285+
n, m = par["n"], par["m"]
286+
287+
# Define sparse model
288+
x = np.random.normal(0.0, 1.0, m).astype(par["dtype"])
289+
290+
# Random mixing matrix
291+
R = np.random.normal(0.0, 1.0, (n, m)).astype(par["dtype"])
292+
Rop = MatrixMult(R, dtype=par["dtype"])
293+
294+
y = Rop @ x
295+
296+
# Step size
297+
L = (Rop.H * Rop).eigs(1).real
298+
tau = 0.99 / L
299+
eps = 1e-1
300+
301+
# L2
302+
Iop = Identity(m, dtype=par["dtype"])
303+
xl2 = regularized_inversion(
304+
Rop,
305+
y,
306+
Regs=[
307+
Iop,
308+
],
309+
epsRs=[
310+
np.sqrt(eps),
311+
],
312+
iter_lim=1000,
313+
)[0]
314+
315+
# HQS
316+
l2 = L2(Op=Rop, b=y, niter=10, warm=True)
317+
l2reg = L2(sigma=eps)
318+
xhqs = HQS(l2, l2reg, x0=np.zeros(m), tau=tau, niter=1000)[0]
319+
320+
# ADMM
321+
l2 = L2(Op=Rop, b=y, niter=10, warm=True)
322+
l2reg = L2(sigma=eps)
323+
xadmm = ADMM(l2, l2reg, x0=np.zeros(m), tau=tau, niter=1000)[0]
324+
325+
assert_array_almost_equal(xl2, xhqs, decimal=2)
326+
assert_array_almost_equal(xl2, xadmm, decimal=2)
327+
328+
277329
@pytest.mark.parametrize("par", [(par1), (par2)])
278330
def test_ADMM_DRS(par):
279331
"""Check equivalency of ADMM and DouglasRachfordSplitting

0 commit comments

Comments
 (0)