|
4 | 4 | import pytest |
5 | 5 | from numpy.testing import assert_array_almost_equal |
6 | 6 | from pylops.basicoperators import Identity, MatrixMult |
| 7 | +from pylops.optimization.leastsquares import regularized_inversion |
7 | 8 | from pylops.optimization.sparsity import fista, ista |
8 | 9 |
|
9 | 10 | from pyproximal.optimization.primal import ( |
@@ -274,6 +275,57 @@ def test_PG_GPG(par): |
274 | 275 | assert_array_almost_equal(xpg, xgpg, decimal=2) |
275 | 276 |
|
276 | 277 |
|
| 278 | +@pytest.mark.parametrize("par", [(par1), (par2)]) |
| 279 | +def test_HQS_ADMM_L2(par): |
| 280 | + """Check that HQS/ADMM can be used to solved a pure L2-based objective function |
| 281 | + (and compare with LSQR - note that despite the trajectory will be different, |
| 282 | + they should converge to the same solution) |
| 283 | + """ |
| 284 | + np.random.seed(0) |
| 285 | + n, m = par["n"], par["m"] |
| 286 | + |
| 287 | + # Define sparse model |
| 288 | + x = np.random.normal(0.0, 1.0, m).astype(par["dtype"]) |
| 289 | + |
| 290 | + # Random mixing matrix |
| 291 | + R = np.random.normal(0.0, 1.0, (n, m)).astype(par["dtype"]) |
| 292 | + Rop = MatrixMult(R, dtype=par["dtype"]) |
| 293 | + |
| 294 | + y = Rop @ x |
| 295 | + |
| 296 | + # Step size |
| 297 | + L = (Rop.H * Rop).eigs(1).real |
| 298 | + tau = 0.99 / L |
| 299 | + eps = 1e-1 |
| 300 | + |
| 301 | + # L2 |
| 302 | + Iop = Identity(m, dtype=par["dtype"]) |
| 303 | + xl2 = regularized_inversion( |
| 304 | + Rop, |
| 305 | + y, |
| 306 | + Regs=[ |
| 307 | + Iop, |
| 308 | + ], |
| 309 | + epsRs=[ |
| 310 | + np.sqrt(eps), |
| 311 | + ], |
| 312 | + iter_lim=1000, |
| 313 | + )[0] |
| 314 | + |
| 315 | + # HQS |
| 316 | + l2 = L2(Op=Rop, b=y, niter=10, warm=True) |
| 317 | + l2reg = L2(sigma=eps) |
| 318 | + xhqs = HQS(l2, l2reg, x0=np.zeros(m), tau=tau, niter=1000)[0] |
| 319 | + |
| 320 | + # ADMM |
| 321 | + l2 = L2(Op=Rop, b=y, niter=10, warm=True) |
| 322 | + l2reg = L2(sigma=eps) |
| 323 | + xadmm = ADMM(l2, l2reg, x0=np.zeros(m), tau=tau, niter=1000)[0] |
| 324 | + |
| 325 | + assert_array_almost_equal(xl2, xhqs, decimal=2) |
| 326 | + assert_array_almost_equal(xl2, xadmm, decimal=2) |
| 327 | + |
| 328 | + |
277 | 329 | @pytest.mark.parametrize("par", [(par1), (par2)]) |
278 | 330 | def test_ADMM_DRS(par): |
279 | 331 | """Check equivalency of ADMM and DouglasRachfordSplitting |
|
0 commit comments