@@ -209,30 +209,43 @@ def _gen_symmetrix_matrix(dim, condition_number):
209209 # No guarantee of success after e >= 7
210210 pass
211211
212- @parameterized .product (n = [24 , 32 , 64 ], d = [24 , 32 , 64 ], zero_A = [False , True ])
213- def test_nnls (self , n , d , zero_A , atol = 1e-3 , iters = 10 ** 3 ): # pylint: disable=invalid-name
212+ @parameterized .product (n = [24 , 32 ], d = [24 , 32 ], zero_lhs = [False , True ],
213+ seed = [0 ], dtype = [jnp .float32 , jnp .bfloat16 ])
214+ def test_nnls (self , n , d , zero_lhs , seed , dtype , atol = 1e-5 ):
214215 """Test non-negative least squares solver."""
215- A = jnp .zeros ((n , d )) if zero_A else np .random .normal (size = (n , d )) # pylint: disable=invalid-name
216- b = np .random .normal (size = (n ,))
216+ keys = jax .random .split (jax .random .key (seed ), 2 )
217+ if zero_lhs :
218+ if not (dtype == jnp .float32 and n == 32 and d == 32 and seed == 0 ):
219+ self .skipTest ('Only 1 test case for zero_lhs=True' )
220+ A = jnp .zeros ((n , d ), dtype = dtype ) # pylint: disable=invalid-name
221+ else :
222+ A = jax .random .normal (keys [0 ], (n , d ), dtype = dtype ) # pylint: disable=invalid-name
223+ b = jax .random .normal (keys [1 ], (n ,), dtype = dtype )
217224
218- x = linear_algebra .nnls (A , b , iters = iters )
225+ x10 = linear_algebra .nnls (A , b , iters = 10 )
226+ x100 = linear_algebra .nnls (A , b , iters = 100 )
227+ x1000 = linear_algebra .nnls (A , b , iters = 100000 )
219228
220- with self .subTest ('x is non-negative' ):
221- assert jnp .allclose (x .clip (max = 0 ), 0 , atol = atol )
229+ with self .subTest ('x has the correct dtype' ):
230+ self .assertEqual (x10 .dtype , dtype )
231+ self .assertEqual (x100 .dtype , dtype )
232+ self .assertEqual (x1000 .dtype , dtype )
222233
223- try :
224- xr , _ = scipy . optimize . nnls ( A , b , maxiter = iters )
225- except RuntimeError :
226- return
234+ with self . subTest ( 'x is non-negative' ) :
235+ assert jnp . allclose ( x10 . clip ( max = 0 ), 0 , atol = atol )
236+ assert jnp . allclose ( x100 . clip ( max = 0 ), 0 , atol = atol )
237+ assert jnp . allclose ( x1000 . clip ( max = 0 ), 0 , atol = atol )
227238
228- with self . subTest ( 'xr is non-negative' ):
229- assert jnp . allclose ( xr . clip ( max = 0 ), 0 , atol = atol )
239+ # we skip comparison to scipy.optimize.nnls as convergence is flaky (by
240+ # design, this is an iterative algorithm )
230241
231- d = jnp .square (A @ x - b ).sum ()
232- dr = jnp .square (A @ xr - b ).sum ()
242+ l10 = jnp .square (A @ x10 - b ).sum ()
243+ l100 = jnp .square (A @ x100 - b ).sum ()
244+ l1000 = jnp .square (A @ x1000 - b ).sum ()
233245
234- with self .subTest ('x is optimal' ):
235- np .testing .assert_allclose (d , dr .clip (max = d ), atol = atol )
246+ with self .subTest ('x is converging' ):
247+ jnp .allclose ((l100 - l10 ).clip (max = 0 ), 0 , atol = atol )
248+ jnp .allclose ((l1000 - l100 ).clip (max = 0 ), 0 , atol = atol )
236249
237250
238251if __name__ == '__main__' :
0 commit comments