Skip to content

Commit b4bd8c9

Browse files
authored
Fix shape assertion and update comparison method
1 parent 5c35daf commit b4bd8c9

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

test/b/test_b_jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ def test_bernstein_poly(self):
180180
f = BernsteinPoly(c)
181181
y = f.eval(c, x)
182182
y_precalculated = np.asarray([19.8694, 32.0761, 19.6774])
183-
self.assertEqual(y_precalculated, y.shape)
184-
self.assertTrue(jnp.allclose(y, y_precalculated))
183+
self.assertEqual(y_precalculated.shape, y.shape)
184+
self.assertTrue(np.allclose(y, y_precalculated))
185185

186186
g = f.jac_p(c, x)
187187
self.assertEqual((3,) + d, g.shape)

0 commit comments

Comments
 (0)