Skip to content

Commit 9b682ab

Browse files
rdyroOptaxDev
authored andcommitted
Improve optax.linear_algebra.nnls tests
PiperOrigin-RevId: 727736758
1 parent db42abd commit 9b682ab

2 files changed

Lines changed: 31 additions & 18 deletions

File tree

optax/_src/linear_algebra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def f(x_p_c, _):
333333

334334
return (xn, pn, cn), None
335335

336-
x = jnp.zeros(A.shape[1])
336+
x = jnp.zeros_like(b, shape=b.shape[:-1] + A.shape[-1:])
337337
p = x
338338
c = 0.
339339

optax/_src/linear_algebra_test.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -209,30 +209,43 @@ def _gen_symmetrix_matrix(dim, condition_number):
209209
# No guarantee of success after e >= 7
210210
pass
211211

212-
@parameterized.product(n=[24, 32, 64], d=[24, 32, 64], zero_A=[False, True])
213-
def test_nnls(self, n, d, zero_A, atol=1e-3, iters=10**3): # pylint: disable=invalid-name
212+
@parameterized.product(n=[24, 32], d=[24, 32], zero_lhs=[False, True],
213+
seed=[0], dtype=[jnp.float32, jnp.bfloat16])
214+
def test_nnls(self, n, d, zero_lhs, seed, dtype, atol=1e-5):
214215
"""Test non-negative least squares solver."""
215-
A = jnp.zeros((n, d)) if zero_A else np.random.normal(size=(n, d)) # pylint: disable=invalid-name
216-
b = np.random.normal(size=(n,))
216+
keys = jax.random.split(jax.random.key(seed), 2)
217+
if zero_lhs:
218+
if not (dtype == jnp.float32 and n == 32 and d == 32 and seed == 0):
219+
self.skipTest('Only 1 test case for zero_lhs=True')
220+
A = jnp.zeros((n, d), dtype=dtype) # pylint: disable=invalid-name
221+
else:
222+
A = jax.random.normal(keys[0], (n, d), dtype=dtype) # pylint: disable=invalid-name
223+
b = jax.random.normal(keys[1], (n,), dtype=dtype)
217224

218-
x = linear_algebra.nnls(A, b, iters=iters)
225+
x10 = linear_algebra.nnls(A, b, iters=10)
226+
x100 = linear_algebra.nnls(A, b, iters=100)
227+
x1000 = linear_algebra.nnls(A, b, iters=100000)
219228

220-
with self.subTest('x is non-negative'):
221-
assert jnp.allclose(x.clip(max=0), 0, atol=atol)
229+
with self.subTest('x has the correct dtype'):
230+
self.assertEqual(x10.dtype, dtype)
231+
self.assertEqual(x100.dtype, dtype)
232+
self.assertEqual(x1000.dtype, dtype)
222233

223-
try:
224-
xr, _ = scipy.optimize.nnls(A, b, maxiter=iters)
225-
except RuntimeError:
226-
return
234+
with self.subTest('x is non-negative'):
235+
assert jnp.allclose(x10.clip(max=0), 0, atol=atol)
236+
assert jnp.allclose(x100.clip(max=0), 0, atol=atol)
237+
assert jnp.allclose(x1000.clip(max=0), 0, atol=atol)
227238

228-
with self.subTest('xr is non-negative'):
229-
assert jnp.allclose(xr.clip(max=0), 0, atol=atol)
239+
# we skip comparison to scipy.optimize.nnls as convergence is flaky (by
240+
# design, this is an iterative algorithm)
230241

231-
d = jnp.square(A @ x - b).sum()
232-
dr = jnp.square(A @ xr - b).sum()
242+
l10 = jnp.square(A @ x10 - b).sum()
243+
l100 = jnp.square(A @ x100 - b).sum()
244+
l1000 = jnp.square(A @ x1000 - b).sum()
233245

234-
with self.subTest('x is optimal'):
235-
np.testing.assert_allclose(d, dr.clip(max=d), atol=atol)
246+
with self.subTest('x is converging'):
247+
jnp.allclose((l100 - l10).clip(max=0), 0, atol=atol)
248+
jnp.allclose((l1000 - l100).clip(max=0), 0, atol=atol)
236249

237250

238251
if __name__ == '__main__':

0 commit comments

Comments
 (0)