55
66import jax
77import jax .numpy as jnp
8- import numpy as np
98
109from uncertaintyx .b .jax import b_basis
1110from uncertaintyx .b .jax import b_poly
@@ -39,7 +38,7 @@ def test_b_basis_of_degree_1(self):
3938 self .assertTrue (jnp .allclose (y [1 , 2 ], 0.5 ))
4039 self .assertTrue (jnp .allclose (y [0 , - 1 ], 0.0 ))
4140 self .assertTrue (jnp .allclose (y [1 , - 1 ], 1.0 ))
42- self .assertTrue (jnp .allclose (np .sum (y , axis = 0 ), 1.0 ))
41+ self .assertTrue (jnp .allclose (jnp .sum (y , axis = 0 ), 1.0 ))
4342
4443 def test_b_basis_of_degree_2 (self ):
4544 m = 5
@@ -55,7 +54,7 @@ def test_b_basis_of_degree_2(self):
5554 self .assertTrue (jnp .allclose (y [2 , 2 ], 0.25 ))
5655 self .assertTrue (jnp .allclose (y [:- 1 , - 1 ], 0.0 ))
5756 self .assertTrue (jnp .allclose (y [- 1 :, - 1 ], 1.0 ))
58- self .assertTrue (jnp .allclose (np .sum (y , axis = 0 ), 1.0 ))
57+ self .assertTrue (jnp .allclose (jnp .sum (y , axis = 0 ), 1.0 ))
5958
6059 def test_b_basis_of_degree_5 (self ):
6160 m = 5
@@ -68,7 +67,7 @@ def test_b_basis_of_degree_5(self):
6867 self .assertTrue (jnp .allclose (y [1 :, 0 ], 0.0 ))
6968 self .assertTrue (jnp .allclose (y [:- 1 , - 1 ], 0.0 ))
7069 self .assertTrue (jnp .allclose (y [- 1 :, - 1 ], 1.0 ))
71- self .assertTrue (jnp .allclose (np .sum (y , axis = 0 ), 1.0 ))
70+ self .assertTrue (jnp .allclose (jnp .sum (y , axis = 0 ), 1.0 ))
7271
7372
7473class BPolyTest (unittest .TestCase ):
0 commit comments