@@ -194,37 +194,19 @@ def test_bernstein_poly(self):
194194 def test_from_lookup_table (self ):
195195 k = (2 , 2 , 2 )
196196 d = tuple ([k_ + 1 for k_ in k ])
197+ c = np .arange (np .prod (np .asarray (d ))).reshape (d ) + 1.0
197198 x = (
198199 np .asarray ([0.2718 , 0.5772 , 0.3141 ]),
199200 np .asarray ([0.5772 , 0.3141 , 0.2718 ]),
200201 np .asarray ([0.3141 , 0.2718 , 0.5772 ]),
201202 )
202- y = np .asarray (
203- [
204- [
205- [19.8694 , 19.7848 , 20.3956 ],
206- [17.5015 , 17.4169 , 18.0277 ],
207- [17.1208 , 17.0362 , 17.6470 ],
208- ],
209- [
210- [34.5286 , 34.4440 , 35.0548 ],
211- [32.1607 , 32.0761 , 32.6869 ],
212- [31.7800 , 31.6954 , 32.3062 ],
213- ],
214- [
215- [21.8998 , 21.8152 , 22.4260 ],
216- [19.5319 , 19.4473 , 20.0581 ],
217- [19.1512 , 19.0666 , 19.6774 ],
218- ],
219- ]
220- )
203+ y = BernsteinGrid (x ).eval (c )
221204
222205 f = BernsteinPoly .from_lookup_table (k , x , y , non_negative = True )
223- c = f .prior ()
224- c_expected = np .arange (np .prod (np .asarray (d ))).reshape (d ) + 1.0
225- self .assertEqual (c_expected .shape , c .shape )
226- self .assertTrue (np .all (c > 0.0 ))
227- self .assertTrue (np .allclose (c , c_expected ))
206+ b = f .prior ()
207+ self .assertEqual (c .shape , b .shape )
208+ self .assertFalse (np .any (b < 0.0 ))
209+ self .assertTrue (np .allclose (b , c ))
228210
229211
230212class BSolveTest (unittest .TestCase ):
@@ -241,6 +223,7 @@ def test_b_solve_0_2(self):
241223
242224 c = b_solve ((k ,), (x ,), y , non_negative = True )
243225 self .assertEqual ((k + 1 ,), c .shape )
226+ self .assertFalse (np .any (c < 0.0 ))
244227 self .assertAlmostEqual (1.0 , c [0 ].item ())
245228 self .assertAlmostEqual (0.0 , c [1 ].item ())
246229 self .assertAlmostEqual (0.0 , c [2 ].item ())
@@ -253,6 +236,7 @@ def test_b_solve_1_2(self):
253236
254237 c = b_solve ((k ,), (x ,), y , non_negative = True )
255238 self .assertEqual ((k + 1 ,), c .shape )
239+ self .assertFalse (np .any (c < 0.0 ))
256240 self .assertAlmostEqual (0.0 , c [0 ].item ())
257241 self .assertAlmostEqual (1.0 , c [1 ].item ())
258242 self .assertAlmostEqual (0.0 , c [2 ].item ())
@@ -265,6 +249,7 @@ def test_b_solve_2_2(self):
265249
266250 c = b_solve ((k ,), (x ,), y , non_negative = True )
267251 self .assertEqual ((k + 1 ,), c .shape )
252+ self .assertFalse (np .any (c < 0.0 ))
268253 self .assertAlmostEqual (0.0 , c [0 ].item ())
269254 self .assertAlmostEqual (0.0 , c [1 ].item ())
270255 self .assertAlmostEqual (1.0 , c [2 ].item ())
@@ -279,6 +264,7 @@ def test_b_solve_0_0_2_2(self):
279264
280265 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
281266 self .assertEqual ((k + 1 , k + 1 ), c .shape )
267+ self .assertFalse (np .any (c < 0.0 ))
282268 self .assertAlmostEqual (1.0 , c [0 , 0 ].item ())
283269 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
284270 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -302,6 +288,7 @@ def test_b_solve_1_0_2_2(self):
302288
303289 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
304290 self .assertEqual ((k + 1 , k + 1 ), c .shape )
291+ self .assertFalse (np .any (c < 0.0 ))
305292 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
306293 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
307294 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -322,6 +309,7 @@ def test_b_solve_2_0_2_2(self):
322309
323310 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
324311 self .assertEqual ((k + 1 , k + 1 ), c .shape )
312+ self .assertFalse (np .any (c < 0.0 ))
325313 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
326314 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
327315 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -345,6 +333,7 @@ def test_b_solve_0_1_2_2(self):
345333
346334 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
347335 self .assertEqual ((k + 1 , k + 1 ), c .shape )
336+ self .assertFalse (np .any (c < 0.0 ))
348337 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
349338 self .assertAlmostEqual (1.0 , c [0 , 1 ].item ())
350339 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -369,6 +358,7 @@ def test_b_solve_1_1_2_2(self):
369358
370359 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
371360 self .assertEqual ((k + 1 , k + 1 ), c .shape )
361+ self .assertFalse (np .any (c < 0.0 ))
372362 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
373363 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
374364 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -391,6 +381,7 @@ def test_b_solve_2_1_2_2(self):
391381 )
392382
393383 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
384+ self .assertFalse (np .any (c < 0.0 ))
394385 self .assertEqual ((k + 1 , k + 1 ), c .shape )
395386 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
396387 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
@@ -411,6 +402,7 @@ def test_b_solve_0_2_2_2(self):
411402 )
412403
413404 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
405+ self .assertFalse (np .any (c < 0.0 ))
414406 self .assertEqual ((k + 1 , k + 1 ), c .shape )
415407 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
416408 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
@@ -435,6 +427,7 @@ def test_b_solve_1_2_2_2(self):
435427
436428 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
437429 self .assertEqual ((k + 1 , k + 1 ), c .shape )
430+ self .assertFalse (np .any (c < 0.0 ))
438431 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
439432 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
440433 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -453,6 +446,7 @@ def test_b_solve_2_2_2_2(self):
453446
454447 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
455448 self .assertEqual ((k + 1 , k + 1 ), c .shape )
449+ self .assertFalse (np .any (c < 0.0 ))
456450 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
457451 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
458452 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
0 commit comments