Skip to content

Commit ad3ad63

Browse files
authored
Use jax.numpy.sum instead of numpy.sum
Replaced numpy sum with jax.numpy sum for consistency.
1 parent cf5c0ea commit ad3ad63

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

test/b/test_b_jax.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import jax
77
import jax.numpy as jnp
8-
import numpy as np
98

109
from uncertaintyx.b.jax import b_basis
1110
from uncertaintyx.b.jax import b_poly
@@ -39,7 +38,7 @@ def test_b_basis_of_degree_1(self):
3938
self.assertTrue(jnp.allclose(y[1, 2], 0.5))
4039
self.assertTrue(jnp.allclose(y[0, -1], 0.0))
4140
self.assertTrue(jnp.allclose(y[1, -1], 1.0))
42-
self.assertTrue(jnp.allclose(np.sum(y, axis=0), 1.0))
41+
self.assertTrue(jnp.allclose(jnp.sum(y, axis=0), 1.0))
4342

4443
def test_b_basis_of_degree_2(self):
4544
m = 5
@@ -55,7 +54,7 @@ def test_b_basis_of_degree_2(self):
5554
self.assertTrue(jnp.allclose(y[2, 2], 0.25))
5655
self.assertTrue(jnp.allclose(y[:-1, -1], 0.0))
5756
self.assertTrue(jnp.allclose(y[-1:, -1], 1.0))
58-
self.assertTrue(jnp.allclose(np.sum(y, axis=0), 1.0))
57+
self.assertTrue(jnp.allclose(jnp.sum(y, axis=0), 1.0))
5958

6059
def test_b_basis_of_degree_5(self):
6160
m = 5
@@ -68,7 +67,7 @@ def test_b_basis_of_degree_5(self):
6867
self.assertTrue(jnp.allclose(y[1:, 0], 0.0))
6968
self.assertTrue(jnp.allclose(y[:-1, -1], 0.0))
7069
self.assertTrue(jnp.allclose(y[-1:, -1], 1.0))
71-
self.assertTrue(jnp.allclose(np.sum(y, axis=0), 1.0))
70+
self.assertTrue(jnp.allclose(jnp.sum(y, axis=0), 1.0))
7271

7372

7473
class BPolyTest(unittest.TestCase):

0 commit comments

Comments
 (0)