@@ -118,15 +118,14 @@ def test_eval(self):
118118 self .assertEqual (y .shape + c .shape , g .shape )
119119 self .assertTrue (np .all (g > 0.0 ))
120120
121- u = to_var (0.1 * c )
122- u = f .lpu (c , u , diag = True )
123- self .assertEqual (y .shape , u .shape )
124- self .assertTrue (np .all (u > 0.0 ))
121+ u = np . square (0.1 * c )
122+ v = f .lpu (c , u , diag = True )
123+ self .assertEqual (y .shape , v .shape )
124+ self .assertTrue (np .all (v > 0.0 ))
125125
126- u = to_var (0.1 * c )
127- u = f .lpu (c , u )
128- self .assertEqual (y .shape + y .shape , u .shape )
129- self .assertTrue (np .all (u > 0.0 ))
126+ v = f .lpu (c , u )
127+ self .assertEqual (y .shape + y .shape , v .shape )
128+ self .assertTrue (np .all (v > 0.0 ))
130129
131130 def test_jac (self ):
132131 c = self .c
@@ -137,30 +136,28 @@ def test_jac(self):
137136 self .assertEqual (y .shape + c .shape , g .shape )
138137 self .assertTrue (np .all (g > 0.0 ))
139138
140- u = to_var (0.1 * c )
141- u = f .lpu (c , u , diag = True )
142- self .assertEqual (y .shape , u .shape )
143- self .assertTrue (np .all (u > 0.0 ))
139+ u = np . square (0.1 * c )
140+ v = f .lpu (c , u , diag = True )
141+ self .assertEqual (y .shape , v .shape )
142+ self .assertTrue (np .all (v > 0.0 ))
144143
145- u = to_var (0.1 * c )
146- u = f .lpu (c , u )
147- self .assertEqual (y .shape + y .shape , u .shape )
148- self .assertTrue (np .all (u > 0.0 ))
144+ v = f .lpu (c , u )
145+ self .assertEqual (y .shape + y .shape , v .shape )
146+ self .assertTrue (np .all (v > 0.0 ))
149147
150148 def test_lpu (self ):
151149 c = self .c
152150 f = self .f
153151 y = f .eval (c )
154152
155- u = to_var (0.1 * c )
156- u = f .lpu (c , u , diag = True )
157- self .assertEqual (y .shape , u .shape )
158- self .assertTrue (np .all (u > 0.0 ))
153+ u = np . square (0.1 * c )
154+ v = f .lpu (c , u , diag = True )
155+ self .assertEqual (y .shape , v .shape )
156+ self .assertTrue (np .all (v > 0.0 ))
159157
160- u = to_var (0.1 * c )
161- u = f .lpu (c , u )
162- self .assertEqual (y .shape + y .shape , u .shape )
163- self .assertTrue (np .all (u > 0.0 ))
158+ v = f .lpu (c , u )
159+ self .assertEqual (y .shape + y .shape , v .shape )
160+ self .assertTrue (np .all (v > 0.0 ))
164161
165162
166163class BernsteinPolyTest (unittest .TestCase ):
@@ -196,10 +193,8 @@ def test_bernstein_poly(self):
196193
197194 def test_from_lookup_table (self ):
198195 k = 5
199- x = np .array ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
200- y = np .array ( # y = x ** 2 + 2 x + 3
201- [3.00 , 3.44 , 3.96 , 4.56 , 5.24 , 6.00 ]
202- )
196+ x = np .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
197+ y = np .square (x ) + 2.0 * x + 3.0
203198
204199 f = BernsteinPoly .from_lookup_table ((k ,), (x ,), y , non_negative = True )
205200 c = f .prior ()
@@ -214,55 +209,240 @@ def test_from_lookup_table(self):
214209
215210
216211class BSolveTest (unittest .TestCase ):
217- """Tests the solving function."""
212+ """
213+ Tests the solving function by fitting coefficients to
214+ Bernstein basis polynomials.
215+ """
218216
219- def test_b_solve_degree_2 (self ):
217+ def test_b_solve_0_2 (self ):
218+ r"""Fit :math:`B_{0,2}(x)`."""
220219 k = 2
221- x = jnp .array ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
222- y = jnp .array ( # y = x ** 2 + 2 x + 3
223- [3.00 , 3.44 , 3.96 , 4.56 , 5.24 , 6.00 ]
224- )
220+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
221+ y = jnp .square (1.0 - x )
225222
226223 c = b_solve ((k ,), (x ,), y , non_negative = True )
227224 self .assertEqual ((k + 1 ,), c .shape )
228- self .assertAlmostEqual (3 .0 , c [0 ].item ())
229- self .assertAlmostEqual (4 .0 , c [1 ].item ())
230- self .assertAlmostEqual (6 .0 , c [2 ].item ())
225+ self .assertAlmostEqual (1 .0 , c [0 ].item ())
226+ self .assertAlmostEqual (0 .0 , c [1 ].item ())
227+ self .assertAlmostEqual (0 .0 , c [2 ].item ())
231228
232- def test_b_solve_degree_5 (self ):
233- k = 5
234- x = jnp .array ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
235- y = (
236- jnp .array ( # y = x ** 2 + 2 x - 1 / 100
237- [0.00 , 0.44 , 0.96 , 1.56 , 2.24 , 3.00 ]
238- )
239- - 0.01
240- )
229+ def test_b_solve_1_2 (self ):
230+ r"""Fit :math:`B_{1,2}(x)`."""
231+ k = 2
232+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
233+ y = 2.0 * x * (1.0 - x )
241234
242- c = b_solve ((k ,), (x ,), y )
235+ c = b_solve ((k ,), (x ,), y , non_negative = True )
243236 self .assertEqual ((k + 1 ,), c .shape )
244- self .assertAlmostEqual (0.00 , c [0 ].item () + 0.01 )
245- self .assertAlmostEqual (0.39 , c [1 ].item ())
246- self .assertAlmostEqual (0.89 , c [2 ].item ())
247- self .assertAlmostEqual (1.49 , c [3 ].item ())
248- self .assertAlmostEqual (2.19 , c [4 ].item ())
249- self .assertAlmostEqual (2.99 , c [5 ].item ())
237+ self .assertAlmostEqual (0.0 , c [0 ].item ())
238+ self .assertAlmostEqual (1.0 , c [1 ].item ())
239+ self .assertAlmostEqual (0.0 , c [2 ].item ())
240+
241+ def test_b_solve_2_2 (self ):
242+ r"""Fit :math:`B_{2,2}(x)`."""
243+ k = 2
244+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
245+ y = jnp .square (x )
250246
251247 c = b_solve ((k ,), (x ,), y , non_negative = True )
252248 self .assertEqual ((k + 1 ,), c .shape )
253- self .assertAlmostEqual (0.00 , c [0 ].item ())
254- self .assertAlmostEqual (0.38 , c [1 ].item (), places = 2 )
255- self .assertAlmostEqual (0.90 , c [2 ].item (), places = 2 )
256- self .assertAlmostEqual (1.48 , c [3 ].item (), places = 2 )
257- self .assertAlmostEqual (2.19 , c [4 ].item (), places = 2 )
258- self .assertAlmostEqual (2.99 , c [5 ].item (), places = 2 )
249+ self .assertAlmostEqual (0.0 , c [0 ].item ())
250+ self .assertAlmostEqual (0.0 , c [1 ].item ())
251+ self .assertAlmostEqual (1.0 , c [2 ].item ())
259252
253+ def test_b_solve_0_0_2_2 (self ):
254+ r"""Fit :math:`B_{(0,0),(2,2)}(x_0, x_1)`."""
255+ k = 2
256+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
257+ y = jnp .square (1.0 - x [jnp .newaxis , :]) * jnp .square (
258+ 1.0 - x [:, jnp .newaxis ]
259+ )
260260
261- def to_var (u : np .ndarray ) -> np .ndarray :
262- """
263- Converts standard uncertainty to a diagonal uncertainty tensor.
264- """
265- return np .square (u )
261+ c = b_solve ((k , k ), (x , x ), y , non_negative = True )
262+ self .assertEqual ((k + 1 , k + 1 ), c .shape )
263+ self .assertAlmostEqual (1.0 , c [0 , 0 ].item ())
264+ self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
265+ self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
266+ self .assertAlmostEqual (0.0 , c [1 , 0 ].item ())
267+ self .assertAlmostEqual (0.0 , c [1 , 1 ].item ())
268+ self .assertAlmostEqual (0.0 , c [1 , 2 ].item ())
269+ self .assertAlmostEqual (0.0 , c [2 , 0 ].item ())
270+ self .assertAlmostEqual (0.0 , c [2 , 1 ].item ())
271+ self .assertAlmostEqual (0.0 , c [2 , 2 ].item ())
272+
273+ def test_b_solve_1_0_2_2 (self ):
274+ r"""Fit :math:`B_{(1,0),(2,2)}(x_0, x_1)`."""
275+ k = 2
276+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
277+ y = (
278+ 2.0
279+ * x [:, jnp .newaxis ]
280+ * (1.0 - x [:, jnp .newaxis ])
281+ * jnp .square (1.0 - x [jnp .newaxis , :])
282+ )
283+
284+ c = b_solve ((k , k ), (x , x ), y , non_negative = True )
285+ self .assertEqual ((k + 1 , k + 1 ), c .shape )
286+ self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
287+ self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
288+ self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
289+ self .assertAlmostEqual (1.0 , c [1 , 0 ].item ())
290+ self .assertAlmostEqual (0.0 , c [1 , 1 ].item ())
291+ self .assertAlmostEqual (0.0 , c [1 , 2 ].item ())
292+ self .assertAlmostEqual (0.0 , c [2 , 0 ].item ())
293+ self .assertAlmostEqual (0.0 , c [2 , 1 ].item ())
294+ self .assertAlmostEqual (0.0 , c [2 , 2 ].item ())
295+
296+ def test_b_solve_2_0_2_2 (self ):
297+ r"""Fit :math:`B_{(2,0),(2,2)}(x_0, x_1)`."""
298+ k = 2
299+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
300+ y = jnp .square (x [:, jnp .newaxis ]) * jnp .square (
301+ 1.0 - x [jnp .newaxis , :]
302+ )
303+
304+ c = b_solve ((k , k ), (x , x ), y , non_negative = True )
305+ self .assertEqual ((k + 1 , k + 1 ), c .shape )
306+ self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
307+ self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
308+ self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
309+ self .assertAlmostEqual (0.0 , c [1 , 0 ].item ())
310+ self .assertAlmostEqual (0.0 , c [1 , 1 ].item ())
311+ self .assertAlmostEqual (0.0 , c [1 , 2 ].item ())
312+ self .assertAlmostEqual (1.0 , c [2 , 0 ].item ())
313+ self .assertAlmostEqual (0.0 , c [2 , 1 ].item ())
314+ self .assertAlmostEqual (0.0 , c [2 , 2 ].item ())
315+
316+ def test_b_solve_0_1_2_2 (self ):
317+ r"""Fit :math:`B_{(0,1),(2,2)}(x_0, x_1)`."""
318+ k = 2
319+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
320+ y = (
321+ 2.0
322+ * jnp .square (1.0 - x [:, jnp .newaxis ])
323+ * x [jnp .newaxis , :]
324+ * (1.0 - x [jnp .newaxis , :])
325+ )
326+
327+ c = b_solve ((k , k ), (x , x ), y , non_negative = True )
328+ self .assertEqual ((k + 1 , k + 1 ), c .shape )
329+ self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
330+ self .assertAlmostEqual (1.0 , c [0 , 1 ].item ())
331+ self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
332+ self .assertAlmostEqual (0.0 , c [1 , 0 ].item ())
333+ self .assertAlmostEqual (0.0 , c [1 , 1 ].item ())
334+ self .assertAlmostEqual (0.0 , c [1 , 2 ].item ())
335+ self .assertAlmostEqual (0.0 , c [2 , 0 ].item ())
336+ self .assertAlmostEqual (0.0 , c [2 , 1 ].item ())
337+ self .assertAlmostEqual (0.0 , c [2 , 2 ].item ())
338+
339+ def test_b_solve_1_1_2_2 (self ):
340+ r"""Fit :math:`B_{(1,1),(2,2)}(x_0, x_1)`."""
341+ k = 2
342+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
343+ y = (
344+ 4.0
345+ * x [:, jnp .newaxis ]
346+ * (1.0 - x [:, jnp .newaxis ])
347+ * x [jnp .newaxis , :]
348+ * (1.0 - x [jnp .newaxis , :])
349+ )
350+
351+ c = b_solve ((k , k ), (x , x ), y , non_negative = True )
352+ self .assertEqual ((k + 1 , k + 1 ), c .shape )
353+ self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
354+ self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
355+ self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
356+ self .assertAlmostEqual (0.0 , c [1 , 0 ].item ())
357+ self .assertAlmostEqual (1.0 , c [1 , 1 ].item ())
358+ self .assertAlmostEqual (0.0 , c [1 , 2 ].item ())
359+ self .assertAlmostEqual (0.0 , c [2 , 0 ].item ())
360+ self .assertAlmostEqual (0.0 , c [2 , 1 ].item ())
361+ self .assertAlmostEqual (0.0 , c [2 , 2 ].item ())
362+
363+ def test_b_solve_2_1_2_2 (self ):
364+ r"""Fit :math:`B_{(2,1),(2,2)}(x_0, x_1)`."""
365+ k = 2
366+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
367+ y = (
368+ 2.0
369+ * jnp .square (x [:, jnp .newaxis ])
370+ * x [jnp .newaxis , :]
371+ * (1.0 - x [jnp .newaxis , :])
372+ )
373+
374+ c = b_solve ((k , k ), (x , x ), y , non_negative = True )
375+ self .assertEqual ((k + 1 , k + 1 ), c .shape )
376+ self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
377+ self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
378+ self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
379+ self .assertAlmostEqual (0.0 , c [1 , 0 ].item ())
380+ self .assertAlmostEqual (0.0 , c [1 , 1 ].item ())
381+ self .assertAlmostEqual (0.0 , c [1 , 2 ].item ())
382+ self .assertAlmostEqual (0.0 , c [2 , 0 ].item ())
383+ self .assertAlmostEqual (1.0 , c [2 , 1 ].item ())
384+ self .assertAlmostEqual (0.0 , c [2 , 2 ].item ())
385+
386+ def test_b_solve_0_2_2_2 (self ):
387+ r"""Fit :math:`B_{(0,2),(2,2)}(x_0, x_1)`."""
388+ k = 2
389+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
390+ y = jnp .square (1.0 - x [:, jnp .newaxis ]) * jnp .square (
391+ x [jnp .newaxis , :]
392+ )
393+
394+ c = b_solve ((k , k ), (x , x ), y , non_negative = True )
395+ self .assertEqual ((k + 1 , k + 1 ), c .shape )
396+ self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
397+ self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
398+ self .assertAlmostEqual (1.0 , c [0 , 2 ].item ())
399+ self .assertAlmostEqual (0.0 , c [1 , 0 ].item ())
400+ self .assertAlmostEqual (0.0 , c [1 , 1 ].item ())
401+ self .assertAlmostEqual (0.0 , c [1 , 2 ].item ())
402+ self .assertAlmostEqual (0.0 , c [2 , 0 ].item ())
403+ self .assertAlmostEqual (0.0 , c [2 , 1 ].item ())
404+ self .assertAlmostEqual (0.0 , c [2 , 2 ].item ())
405+
406+ def test_b_solve_1_2_2_2 (self ):
407+ r"""Fit :math:`B_{(1,2),(2,2)}(x_0, x_1)`."""
408+ k = 2
409+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
410+ y = (
411+ 2.0
412+ * x [:, jnp .newaxis ]
413+ * (1.0 - x [:, jnp .newaxis ])
414+ * jnp .square (x [jnp .newaxis , :])
415+ )
416+
417+ c = b_solve ((k , k ), (x , x ), y , non_negative = True )
418+ self .assertEqual ((k + 1 , k + 1 ), c .shape )
419+ self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
420+ self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
421+ self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
422+ self .assertAlmostEqual (0.0 , c [1 , 0 ].item ())
423+ self .assertAlmostEqual (0.0 , c [1 , 1 ].item ())
424+ self .assertAlmostEqual (1.0 , c [1 , 2 ].item ())
425+ self .assertAlmostEqual (0.0 , c [2 , 0 ].item ())
426+ self .assertAlmostEqual (0.0 , c [2 , 1 ].item ())
427+ self .assertAlmostEqual (0.0 , c [2 , 2 ].item ())
428+
429+ def test_b_solve_2_2_2_2 (self ):
430+ r"""Fit :math:`B_{(2,2),(2,2)}(x_0, x_1)`."""
431+ k = 2
432+ x = jnp .asarray ([0.00 , 0.20 , 0.40 , 0.60 , 0.80 , 1.00 ])
433+ y = jnp .square (x [:, jnp .newaxis ]) * jnp .square (x [jnp .newaxis , :])
434+
435+ c = b_solve ((k , k ), (x , x ), y , non_negative = True )
436+ self .assertEqual ((k + 1 , k + 1 ), c .shape )
437+ self .assertAlmostEqual (0.0 , c [0 , 0 ].item ())
438+ self .assertAlmostEqual (0.0 , c [0 , 1 ].item ())
439+ self .assertAlmostEqual (0.0 , c [0 , 2 ].item ())
440+ self .assertAlmostEqual (0.0 , c [1 , 0 ].item ())
441+ self .assertAlmostEqual (0.0 , c [1 , 1 ].item ())
442+ self .assertAlmostEqual (0.0 , c [1 , 2 ].item ())
443+ self .assertAlmostEqual (0.0 , c [2 , 0 ].item ())
444+ self .assertAlmostEqual (0.0 , c [2 , 1 ].item ())
445+ self .assertAlmostEqual (1.0 , c [2 , 2 ].item ())
266446
267447
268448if __name__ == "__main__" :
0 commit comments