Skip to content

Commit 29c74ed

Browse files
authored
Refactor BernsteinPoly test for prior evaluation
1 parent 3269efd commit 29c74ed

1 file changed

Lines changed: 2 additions & 31 deletions

File tree

test/b/test_b_jax.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -202,37 +202,8 @@ def test_from_lookup_table(self):
202202
)
203203
y = BernsteinGrid(x).eval(c)
204204

205-
f = BernsteinPoly.from_lookup_table(k, x, y, non_negative=True)
206-
b = f.prior()
207-
self.assertEqual(c.shape, b.shape)
208-
self.assertFalse(np.any(b < 0.0))
209-
self.assertAlmostEqual(c[0, 0, 0], b[0, 0, 0])
210-
self.assertAlmostEqual(c[0, 0, 1], b[0, 0, 1])
211-
self.assertAlmostEqual(c[0, 0, 2], b[0, 0, 2])
212-
self.assertAlmostEqual(c[0, 1, 0], b[0, 1, 0])
213-
self.assertAlmostEqual(c[0, 1, 1], b[0, 1, 1])
214-
self.assertAlmostEqual(c[0, 1, 2], b[0, 1, 2])
215-
self.assertAlmostEqual(c[0, 2, 0], b[0, 2, 0])
216-
self.assertAlmostEqual(c[0, 2, 1], b[0, 2, 1])
217-
self.assertAlmostEqual(c[0, 2, 2], b[0, 2, 2])
218-
self.assertAlmostEqual(c[1, 0, 0], b[1, 0, 0])
219-
self.assertAlmostEqual(c[1, 0, 1], b[1, 0, 1])
220-
self.assertAlmostEqual(c[1, 0, 2], b[1, 0, 2])
221-
self.assertAlmostEqual(c[1, 1, 0], b[1, 1, 0])
222-
self.assertAlmostEqual(c[1, 1, 1], b[1, 1, 1])
223-
self.assertAlmostEqual(c[1, 1, 2], b[1, 1, 2])
224-
self.assertAlmostEqual(c[1, 2, 0], b[1, 2, 0])
225-
self.assertAlmostEqual(c[1, 2, 1], b[1, 2, 1])
226-
self.assertAlmostEqual(c[1, 2, 2], b[1, 2, 2])
227-
self.assertAlmostEqual(c[2, 0, 0], b[2, 0, 0])
228-
self.assertAlmostEqual(c[2, 0, 1], b[2, 0, 1])
229-
self.assertAlmostEqual(c[2, 0, 2], b[2, 0, 2])
230-
self.assertAlmostEqual(c[2, 1, 0], b[2, 1, 0])
231-
self.assertAlmostEqual(c[2, 1, 1], b[2, 1, 1])
232-
self.assertAlmostEqual(c[2, 1, 2], b[2, 1, 2])
233-
self.assertAlmostEqual(c[2, 2, 0], b[2, 2, 0])
234-
self.assertAlmostEqual(c[2, 2, 1], b[2, 2, 1])
235-
self.assertAlmostEqual(c[2, 2, 2], b[2, 2, 2])
205+
f = BernsteinPoly.from_lookup_table(k, x, y)
206+
self.assertTrue(np.allclose(BernsteinGrid(x).eval(f.prior())), y)
236207

237208

238209
class BSolveTest(unittest.TestCase):

0 commit comments

Comments
 (0)