@@ -92,7 +92,7 @@ def test_eval(self):
9292 c = self .c
9393 f = self .f
9494 y = f .eval (c )
95- precalculated = np .asarray (
95+ y_precalculated = np .asarray (
9696 [
9797 [
9898 [19.8694 , 19.7848 , 20.3956 ],
@@ -111,8 +111,8 @@ def test_eval(self):
111111 ],
112112 ]
113113 )
114- self .assertEqual (( 3 , 3 , 3 ) , y .shape )
115- self .assertTrue (np .allclose (y , precalculated ))
114+ self .assertEqual (y_precalculated . shape , y .shape )
115+ self .assertTrue (np .allclose (y , y_precalculated ))
116116
117117 g = f .jac (c )
118118 self .assertEqual (y .shape + c .shape , g .shape )
@@ -179,9 +179,9 @@ def test_bernstein_poly(self):
179179 )
180180 f = BernsteinPoly (c )
181181 y = f .eval (c , x )
182- precalculated = np .asarray ([19.8694 , 32.0761 , 19.6774 ])
183- self .assertEqual (( 3 ,) , y .shape )
184- self .assertTrue (jnp .allclose (y , precalculated ))
182+ y_precalculated = np .asarray ([19.8694 , 32.0761 , 19.6774 ])
183+ self .assertEqual (y_precalculated . shape , y .shape )
184+ self .assertTrue (np .allclose (y , y_precalculated ))
185185
186186 g = f .jac_p (c , x )
187187 self .assertEqual ((3 ,) + d , g .shape )
@@ -192,20 +192,20 @@ def test_bernstein_poly(self):
192192 self .assertTrue (np .all (g > 0.0 ))
193193
194194 def test_from_lookup_table (self ):
195- k = 5
196- x = np . asarray ([ 0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
197- y = np .square ( x ) + 2.0 * x + 3 .0
198-
199- f = BernsteinPoly . from_lookup_table (( k ,), ( x ,), y , non_negative = True )
200- c = f . prior ()
201- self . assertEqual (( k + 1 ,), c . shape )
202- self . assertAlmostEqual ( 3.0 , c [ 0 ] )
203- self . assertAlmostEqual ( 3.4 , c [ 1 ] )
204- self . assertAlmostEqual ( 3.9 , c [ 2 ])
205- self . assertAlmostEqual ( 4.5 , c [ 3 ] )
206- self . assertAlmostEqual ( 5.2 , c [ 4 ] )
207- self .assertAlmostEqual ( 6.0 , c [ 5 ] )
208- self .assertTrue (jnp .allclose (f . eval ( c , x ), y ))
195+ k = ( 3 , 4 , 2 )
196+ d = tuple ([ k_ + 1 for k_ in k ])
197+ c = np .arange ( np . prod ( np . asarray ( d ))). reshape ( d ) + 1 .0
198+ x = (
199+ np . asarray ([ 0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ]),
200+ np . asarray ([ 0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ]),
201+ np . asarray ([ 0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ]),
202+ )
203+ y = BernsteinGrid ( x ). eval ( c )
204+
205+ f = BernsteinPoly . from_lookup_table ( k , x , y )
206+ b = f . prior ( )
207+ self .assertEqual ( c . shape , b . shape )
208+ self .assertTrue (np .allclose (b , c ))
209209
210210
211211class BSolveTest (unittest .TestCase ):
@@ -217,49 +217,53 @@ class BSolveTest(unittest.TestCase):
217217 def test_b_solve_0_2 (self ):
218218 r"""Fit :math:`B_{0,2}(x)`."""
219219 k = 2
220- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
220+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
221221 y = jnp .square (1.0 - x )
222222
223223 c = b_solve ((k ,), (x ,), y , non_negative = True )
224224 self .assertEqual ((k + 1 ,), c .shape )
225+ self .assertFalse (np .any (c < 0.0 ))
225226 self .assertAlmostEqual (1.0 , c [0 ].item ())
226227 self .assertAlmostEqual (0.0 , c [1 ].item ())
227228 self .assertAlmostEqual (0.0 , c [2 ].item ())
228229
229230 def test_b_solve_1_2 (self ):
230231 r"""Fit :math:`B_{1,2}(x)`."""
231232 k = 2
232- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
233+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
233234 y = 2.0 * x * (1.0 - x )
234235
235236 c = b_solve ((k ,), (x ,), y , non_negative = True )
236237 self .assertEqual ((k + 1 ,), c .shape )
238+ self .assertFalse (np .any (c < 0.0 ))
237239 self .assertAlmostEqual (0.0 , c [0 ].item ())
238240 self .assertAlmostEqual (1.0 , c [1 ].item ())
239241 self .assertAlmostEqual (0.0 , c [2 ].item ())
240242
241243 def test_b_solve_2_2 (self ):
242244 r"""Fit :math:`B_{2,2}(x)`."""
243245 k = 2
244- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
246+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
245247 y = jnp .square (x )
246248
247249 c = b_solve ((k ,), (x ,), y , non_negative = True )
248250 self .assertEqual ((k + 1 ,), c .shape )
251+ self .assertFalse (np .any (c < 0.0 ))
249252 self .assertAlmostEqual (0.0 , c [0 ].item ())
250253 self .assertAlmostEqual (0.0 , c [1 ].item ())
251254 self .assertAlmostEqual (1.0 , c [2 ].item ())
252255
253256 def test_b_solve_0_0_2_2 (self ):
254257 r"""Fit :math:`B_{(0,0),(2,2)}(x_0, x_1)`."""
255258 k = 2
256- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
259+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
257260 y = jnp .square (1.0 - x [jnp .newaxis , :]) * jnp .square (
258261 1.0 - x [:, jnp .newaxis ]
259262 )
260263
261264 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
262265 self .assertEqual ((k + 1 , k + 1 ), c .shape )
266+ self .assertFalse (np .any (c < 0.0 ))
263267 self .assertAlmostEqual (1.0 , c [0 , 0 ].item ())
264268 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
265269 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -273,7 +277,7 @@ def test_b_solve_0_0_2_2(self):
273277 def test_b_solve_1_0_2_2 (self ):
274278 r"""Fit :math:`B_{(1,0),(2,2)}(x_0, x_1)`."""
275279 k = 2
276- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
280+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
277281 y = (
278282 2.0
279283 * x [:, jnp .newaxis ]
@@ -283,6 +287,7 @@ def test_b_solve_1_0_2_2(self):
283287
284288 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
285289 self .assertEqual ((k + 1 , k + 1 ), c .shape )
290+ self .assertFalse (np .any (c < 0.0 ))
286291 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
287292 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
288293 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -296,13 +301,14 @@ def test_b_solve_1_0_2_2(self):
296301 def test_b_solve_2_0_2_2 (self ):
297302 r"""Fit :math:`B_{(2,0),(2,2)}(x_0, x_1)`."""
298303 k = 2
299- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
304+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
300305 y = jnp .square (x [:, jnp .newaxis ]) * jnp .square (
301306 1.0 - x [jnp .newaxis , :]
302307 )
303308
304309 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
305310 self .assertEqual ((k + 1 , k + 1 ), c .shape )
311+ self .assertFalse (np .any (c < 0.0 ))
306312 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
307313 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
308314 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -316,7 +322,7 @@ def test_b_solve_2_0_2_2(self):
316322 def test_b_solve_0_1_2_2 (self ):
317323 r"""Fit :math:`B_{(0,1),(2,2)}(x_0, x_1)`."""
318324 k = 2
319- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
325+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
320326 y = (
321327 2.0
322328 * jnp .square (1.0 - x [:, jnp .newaxis ])
@@ -326,6 +332,7 @@ def test_b_solve_0_1_2_2(self):
326332
327333 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
328334 self .assertEqual ((k + 1 , k + 1 ), c .shape )
335+ self .assertFalse (np .any (c < 0.0 ))
329336 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
330337 self .assertAlmostEqual (1.0 , c [0 , 1 ].item ())
331338 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -339,7 +346,7 @@ def test_b_solve_0_1_2_2(self):
339346 def test_b_solve_1_1_2_2 (self ):
340347 r"""Fit :math:`B_{(1,1),(2,2)}(x_0, x_1)`."""
341348 k = 2
342- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
349+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
343350 y = (
344351 4.0
345352 * x [:, jnp .newaxis ]
@@ -350,6 +357,7 @@ def test_b_solve_1_1_2_2(self):
350357
351358 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
352359 self .assertEqual ((k + 1 , k + 1 ), c .shape )
360+ self .assertFalse (np .any (c < 0.0 ))
353361 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
354362 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
355363 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -363,7 +371,7 @@ def test_b_solve_1_1_2_2(self):
363371 def test_b_solve_2_1_2_2 (self ):
364372 r"""Fit :math:`B_{(2,1),(2,2)}(x_0, x_1)`."""
365373 k = 2
366- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
374+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
367375 y = (
368376 2.0
369377 * jnp .square (x [:, jnp .newaxis ])
@@ -372,6 +380,7 @@ def test_b_solve_2_1_2_2(self):
372380 )
373381
374382 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
383+ self .assertFalse (np .any (c < 0.0 ))
375384 self .assertEqual ((k + 1 , k + 1 ), c .shape )
376385 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
377386 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
@@ -386,12 +395,13 @@ def test_b_solve_2_1_2_2(self):
386395 def test_b_solve_0_2_2_2 (self ):
387396 r"""Fit :math:`B_{(0,2),(2,2)}(x_0, x_1)`."""
388397 k = 2
389- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
398+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
390399 y = jnp .square (1.0 - x [:, jnp .newaxis ]) * jnp .square (
391400 x [jnp .newaxis , :]
392401 )
393402
394403 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
404+ self .assertFalse (np .any (c < 0.0 ))
395405 self .assertEqual ((k + 1 , k + 1 ), c .shape )
396406 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
397407 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
@@ -406,7 +416,7 @@ def test_b_solve_0_2_2_2(self):
406416 def test_b_solve_1_2_2_2 (self ):
407417 r"""Fit :math:`B_{(1,2),(2,2)}(x_0, x_1)`."""
408418 k = 2
409- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
419+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
410420 y = (
411421 2.0
412422 * x [:, jnp .newaxis ]
@@ -416,6 +426,7 @@ def test_b_solve_1_2_2_2(self):
416426
417427 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
418428 self .assertEqual ((k + 1 , k + 1 ), c .shape )
429+ self .assertFalse (np .any (c < 0.0 ))
419430 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
420431 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
421432 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
@@ -429,11 +440,12 @@ def test_b_solve_1_2_2_2(self):
429440 def test_b_solve_2_2_2_2 (self ):
430441 r"""Fit :math:`B_{(2,2),(2,2)}(x_0, x_1)`."""
431442 k = 2
432- x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
443+ x = jnp .asarray ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
433444 y = jnp .square (x [:, jnp .newaxis ]) * jnp .square (x [jnp .newaxis , :])
434445
435446 c = b_solve ((k , k ), (x , x ), y , non_negative = True )
436447 self .assertEqual ((k + 1 , k + 1 ), c .shape )
448+ self .assertFalse (np .any (c < 0.0 ))
437449 self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
438450 self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
439451 self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
0 commit comments