Skip to content

Commit 09fbec3

Browse files
authored
Refactor test_from_lookup_table and b_solve tests
Refactor tests to use BernsteinGrid evaluation and update assertions for non-negativity checks.
1 parent f34b217 commit 09fbec3

1 file changed

Lines changed: 18 additions & 24 deletions

File tree

test/b/test_b_jax.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -194,37 +194,19 @@ def test_bernstein_poly(self):
194194
def test_from_lookup_table(self):
195195
k = (2, 2, 2)
196196
d = tuple([k_ + 1 for k_ in k])
197+
c = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0
197198
x = (
198199
np.asarray([0.2718, 0.5772, 0.3141]),
199200
np.asarray([0.5772, 0.3141, 0.2718]),
200201
np.asarray([0.3141, 0.2718, 0.5772]),
201202
)
202-
y = np.asarray(
203-
[
204-
[
205-
[19.8694, 19.7848, 20.3956],
206-
[17.5015, 17.4169, 18.0277],
207-
[17.1208, 17.0362, 17.6470],
208-
],
209-
[
210-
[34.5286, 34.4440, 35.0548],
211-
[32.1607, 32.0761, 32.6869],
212-
[31.7800, 31.6954, 32.3062],
213-
],
214-
[
215-
[21.8998, 21.8152, 22.4260],
216-
[19.5319, 19.4473, 20.0581],
217-
[19.1512, 19.0666, 19.6774],
218-
],
219-
]
220-
)
203+
y = BernsteinGrid(x).eval(c)
221204

222205
f = BernsteinPoly.from_lookup_table(k, x, y, non_negative=True)
223-
c = f.prior()
224-
c_expected = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0
225-
self.assertEqual(c_expected.shape, c.shape)
226-
self.assertTrue(np.all(c > 0.0))
227-
self.assertTrue(np.allclose(c, c_expected))
206+
b = f.prior()
207+
self.assertEqual(c.shape, b.shape)
208+
self.assertFalse(np.any(b < 0.0))
209+
self.assertTrue(np.allclose(b, c))
228210

229211

230212
class BSolveTest(unittest.TestCase):
@@ -241,6 +223,7 @@ def test_b_solve_0_2(self):
241223

242224
c = b_solve((k,), (x,), y, non_negative=True)
243225
self.assertEqual((k + 1,), c.shape)
226+
self.assertFalse(np.any(c < 0.0))
244227
self.assertAlmostEqual(1.0, c[0].item())
245228
self.assertAlmostEqual(0.0, c[1].item())
246229
self.assertAlmostEqual(0.0, c[2].item())
@@ -253,6 +236,7 @@ def test_b_solve_1_2(self):
253236

254237
c = b_solve((k,), (x,), y, non_negative=True)
255238
self.assertEqual((k + 1,), c.shape)
239+
self.assertFalse(np.any(c < 0.0))
256240
self.assertAlmostEqual(0.0, c[0].item())
257241
self.assertAlmostEqual(1.0, c[1].item())
258242
self.assertAlmostEqual(0.0, c[2].item())
@@ -265,6 +249,7 @@ def test_b_solve_2_2(self):
265249

266250
c = b_solve((k,), (x,), y, non_negative=True)
267251
self.assertEqual((k + 1,), c.shape)
252+
self.assertFalse(np.any(c < 0.0))
268253
self.assertAlmostEqual(0.0, c[0].item())
269254
self.assertAlmostEqual(0.0, c[1].item())
270255
self.assertAlmostEqual(1.0, c[2].item())
@@ -279,6 +264,7 @@ def test_b_solve_0_0_2_2(self):
279264

280265
c = b_solve((k, k), (x, x), y, non_negative=True)
281266
self.assertEqual((k + 1, k + 1), c.shape)
267+
self.assertFalse(np.any(c < 0.0))
282268
self.assertAlmostEqual(1.0, c[0, 0].item())
283269
self.assertAlmostEqual(0.0, c[0, 1].item())
284270
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -302,6 +288,7 @@ def test_b_solve_1_0_2_2(self):
302288

303289
c = b_solve((k, k), (x, x), y, non_negative=True)
304290
self.assertEqual((k + 1, k + 1), c.shape)
291+
self.assertFalse(np.any(c < 0.0))
305292
self.assertAlmostEqual(0.0, c[0, 0].item())
306293
self.assertAlmostEqual(0.0, c[0, 1].item())
307294
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -322,6 +309,7 @@ def test_b_solve_2_0_2_2(self):
322309

323310
c = b_solve((k, k), (x, x), y, non_negative=True)
324311
self.assertEqual((k + 1, k + 1), c.shape)
312+
self.assertFalse(np.any(c < 0.0))
325313
self.assertAlmostEqual(0.0, c[0, 0].item())
326314
self.assertAlmostEqual(0.0, c[0, 1].item())
327315
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -345,6 +333,7 @@ def test_b_solve_0_1_2_2(self):
345333

346334
c = b_solve((k, k), (x, x), y, non_negative=True)
347335
self.assertEqual((k + 1, k + 1), c.shape)
336+
self.assertFalse(np.any(c < 0.0))
348337
self.assertAlmostEqual(0.0, c[0, 0].item())
349338
self.assertAlmostEqual(1.0, c[0, 1].item())
350339
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -369,6 +358,7 @@ def test_b_solve_1_1_2_2(self):
369358

370359
c = b_solve((k, k), (x, x), y, non_negative=True)
371360
self.assertEqual((k + 1, k + 1), c.shape)
361+
self.assertFalse(np.any(c < 0.0))
372362
self.assertAlmostEqual(0.0, c[0, 0].item())
373363
self.assertAlmostEqual(0.0, c[0, 1].item())
374364
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -391,6 +381,7 @@ def test_b_solve_2_1_2_2(self):
391381
)
392382

393383
c = b_solve((k, k), (x, x), y, non_negative=True)
384+
self.assertFalse(np.any(c < 0.0))
394385
self.assertEqual((k + 1, k + 1), c.shape)
395386
self.assertAlmostEqual(0.0, c[0, 0].item())
396387
self.assertAlmostEqual(0.0, c[0, 1].item())
@@ -411,6 +402,7 @@ def test_b_solve_0_2_2_2(self):
411402
)
412403

413404
c = b_solve((k, k), (x, x), y, non_negative=True)
405+
self.assertFalse(np.any(c < 0.0))
414406
self.assertEqual((k + 1, k + 1), c.shape)
415407
self.assertAlmostEqual(0.0, c[0, 0].item())
416408
self.assertAlmostEqual(0.0, c[0, 1].item())
@@ -435,6 +427,7 @@ def test_b_solve_1_2_2_2(self):
435427

436428
c = b_solve((k, k), (x, x), y, non_negative=True)
437429
self.assertEqual((k + 1, k + 1), c.shape)
430+
self.assertFalse(np.any(c < 0.0))
438431
self.assertAlmostEqual(0.0, c[0, 0].item())
439432
self.assertAlmostEqual(0.0, c[0, 1].item())
440433
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -453,6 +446,7 @@ def test_b_solve_2_2_2_2(self):
453446

454447
c = b_solve((k, k), (x, x), y, non_negative=True)
455448
self.assertEqual((k + 1, k + 1), c.shape)
449+
self.assertFalse(np.any(c < 0.0))
456450
self.assertAlmostEqual(0.0, c[0, 0].item())
457451
self.assertAlmostEqual(0.0, c[0, 1].item())
458452
self.assertAlmostEqual(0.0, c[0, 2].item())

0 commit comments

Comments
 (0)