55
66import jax
77import jax .numpy as jnp
8+ import numpy as np
89
10+ from uncertaintyx .b .jax import BernsteinGrid
11+ from uncertaintyx .b .jax import BernsteinPoly
912from uncertaintyx .b .jax import b_basis
1013from uncertaintyx .b .jax import b_poly
11- from uncertaintyx .b .jax import b_poly_grid
12- from uncertaintyx .b .jax import b_poly_point
13- from uncertaintyx .b .jax import b_poly_points
1414
1515
1616class BBasisTest (unittest .TestCase ):
@@ -156,23 +156,24 @@ def b_poly_grad(b, x):
156156 self .assertTrue (jnp .allclose (g , 1.0 ))
157157
158158
159- class BPolyGridTest (unittest .TestCase ):
159+ class BernsteinGridTest (unittest .TestCase ):
160160 """
161161 Tests the evaluation of multivariate Bernstein polynomials
162162 against values precalculated with Mathematica.
163163 """
164164
165- def test_b_poly_grid (self ):
165+ def test_bernstein_grid (self ):
166166 k = (4 , 3 , 2 )
167167 d = tuple ([k_ + 1 for k_ in k ])
168- b = jnp .arange (jnp .prod (jnp .asarray (d ))).reshape (d ) + 1.0
168+ b = np .arange (np .prod (np .asarray (d ))).reshape (d ) + 1.0
169169 x = (
170- jnp .asarray ([0.2718 , 0.5772 , 0.3141 ]),
171- jnp .asarray ([0.5772 , 0.3141 , 0.2718 ]),
172- jnp .asarray ([0.3141 , 0.2718 , 0.5772 ]),
170+ np .asarray ([0.2718 , 0.5772 , 0.3141 ]),
171+ np .asarray ([0.5772 , 0.3141 , 0.2718 ]),
172+ np .asarray ([0.3141 , 0.2718 , 0.5772 ]),
173173 )
174- y = b_poly_grid (b , x )
175- precalculated = jnp .asarray (
174+ f = BernsteinGrid (k , x )
175+ y = f .eval (b )
176+ precalculated = np .asarray (
176177 [
177178 [
178179 [19.8694 , 19.7848 , 20.3956 ],
@@ -192,41 +193,57 @@ def test_b_poly_grid(self):
192193 ]
193194 )
194195 self .assertEqual ((3 , 3 , 3 ), y .shape )
195- self .assertTrue (jnp .allclose (y , precalculated ))
196+ self .assertTrue (np .allclose (y , precalculated ))
197+
198+ g = f .jac (b )
199+ self .assertEqual (y .shape + b .shape , g .shape )
200+ self .assertTrue (np .all (g > 0.0 ))
201+
202+ u = to_var (0.1 * b )
203+ u = f .lpu (b , u , diag = True )
204+ self .assertEqual (y .shape , u .shape )
205+ self .assertTrue (np .all (u > 0.0 ))
206+
207+ u = to_var (0.1 * b )
208+ u = f .lpu (b , u )
209+ self .assertEqual (y .shape + y .shape , u .shape )
210+ self .assertTrue (np .all (u > 0.0 ))
196211
197212
198- class BPolyPointsTest (unittest .TestCase ):
213+ class BernsteinPolyTest (unittest .TestCase ):
199214 """
200215 Tests the evaluation of multivariate Bernstein polynomials
201216 against values precalculated with Mathematica.
202217 """
203218
204- def test_b_poly_point (self ):
219+ def test_bernstein_poly (self ):
205220 k = (4 , 3 , 2 )
206221 d = tuple ([k_ + 1 for k_ in k ])
207- b = jnp .arange (jnp .prod (jnp .asarray (d ))).reshape (d ) + 1.0
208- x = jnp .asarray ([0.2718 , 0.5772 , 0.3141 ])
209- y = b_poly_point (b , x )
210- precalculated = 19.8694
211- self .assertEqual ((), y .shape )
212- self .assertTrue (jnp .allclose (y , precalculated ))
213-
214- def test_b_poly_points (self ):
215- k = (4 , 3 , 2 )
216- d = tuple ([k_ + 1 for k_ in k ])
217- b = jnp .arange (jnp .prod (jnp .asarray (d ))).reshape (d ) + 1.0
218- x = jnp .asarray (
222+ b = np .arange (np .prod (np .asarray (d ))).reshape (d ) + 1.0
223+ x = np .asarray (
219224 [
220225 [0.2718 , 0.5772 , 0.3141 ],
221226 [0.5772 , 0.3141 , 0.2718 ],
222227 [0.3141 , 0.2718 , 0.5772 ],
223228 ]
224229 )
225- y = b_poly_points (b , x )
226- precalculated = jnp .asarray ([19.8694 , 32.0761 , 19.6774 ])
230+ f = BernsteinPoly (b )
231+ y = f .eval (b , x )
232+ precalculated = np .asarray ([19.8694 , 32.0761 , 19.6774 ])
227233 self .assertEqual ((3 ,), y .shape )
228234 self .assertTrue (jnp .allclose (y , precalculated ))
229235
236+ g = f .jac_p (b , x )
237+ self .assertEqual ((3 ,) + d , g .shape )
238+ self .assertTrue (np .all (g > 0.0 ))
239+
240+
241+ def to_var (u : np .ndarray ) -> np .ndarray :
242+ """
243+ Converts standard uncertainty to a diagonal uncertainty tensor.
244+ """
245+ return np .square (u )
246+
230247
231248if __name__ == "__main__" :
232249 unittest .main ()
0 commit comments