Skip to content

Commit ac5aac9

Browse files
committed
Update: grid and Bernstein functions
1 parent 8ddeb42 commit ac5aac9

6 files changed

Lines changed: 352 additions & 41 deletions

File tree

test/b/test_b_jax.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
import jax
77
import jax.numpy as jnp
8+
import numpy as np
89

10+
from uncertaintyx.b.jax import BernsteinGrid
11+
from uncertaintyx.b.jax import BernsteinPoly
912
from uncertaintyx.b.jax import b_basis
1013
from uncertaintyx.b.jax import b_poly
11-
from uncertaintyx.b.jax import b_poly_grid
12-
from uncertaintyx.b.jax import b_poly_point
13-
from uncertaintyx.b.jax import b_poly_points
1414

1515

1616
class BBasisTest(unittest.TestCase):
@@ -156,23 +156,24 @@ def b_poly_grad(b, x):
156156
self.assertTrue(jnp.allclose(g, 1.0))
157157

158158

159-
class BPolyGridTest(unittest.TestCase):
159+
class BernsteinGridTest(unittest.TestCase):
160160
"""
161161
Tests the evaluation of multivariate Bernstein polynomials
162162
against values precalculated with Mathematica.
163163
"""
164164

165-
def test_b_poly_grid(self):
165+
def test_bernstein_grid(self):
166166
k = (4, 3, 2)
167167
d = tuple([k_ + 1 for k_ in k])
168-
b = jnp.arange(jnp.prod(jnp.asarray(d))).reshape(d) + 1.0
168+
b = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0
169169
x = (
170-
jnp.asarray([0.2718, 0.5772, 0.3141]),
171-
jnp.asarray([0.5772, 0.3141, 0.2718]),
172-
jnp.asarray([0.3141, 0.2718, 0.5772]),
170+
np.asarray([0.2718, 0.5772, 0.3141]),
171+
np.asarray([0.5772, 0.3141, 0.2718]),
172+
np.asarray([0.3141, 0.2718, 0.5772]),
173173
)
174-
y = b_poly_grid(b, x)
175-
precalculated = jnp.asarray(
174+
f = BernsteinGrid(k, x)
175+
y = f.eval(b)
176+
precalculated = np.asarray(
176177
[
177178
[
178179
[19.8694, 19.7848, 20.3956],
@@ -192,41 +193,57 @@ def test_b_poly_grid(self):
192193
]
193194
)
194195
self.assertEqual((3, 3, 3), y.shape)
195-
self.assertTrue(jnp.allclose(y, precalculated))
196+
self.assertTrue(np.allclose(y, precalculated))
197+
198+
g = f.jac(b)
199+
self.assertEqual(y.shape + b.shape, g.shape)
200+
self.assertTrue(np.all(g > 0.0))
201+
202+
u = to_var(0.1 * b)
203+
u = f.lpu(b, u, diag=True)
204+
self.assertEqual(y.shape, u.shape)
205+
self.assertTrue(np.all(u > 0.0))
206+
207+
u = to_var(0.1 * b)
208+
u = f.lpu(b, u)
209+
self.assertEqual(y.shape + y.shape, u.shape)
210+
self.assertTrue(np.all(u > 0.0))
196211

197212

198-
class BPolyPointsTest(unittest.TestCase):
213+
class BernsteinPolyTest(unittest.TestCase):
199214
"""
200215
Tests the evaluation of multivariate Bernstein polynomials
201216
against values precalculated with Mathematica.
202217
"""
203218

204-
def test_b_poly_point(self):
219+
def test_bernstein_poly(self):
205220
k = (4, 3, 2)
206221
d = tuple([k_ + 1 for k_ in k])
207-
b = jnp.arange(jnp.prod(jnp.asarray(d))).reshape(d) + 1.0
208-
x = jnp.asarray([0.2718, 0.5772, 0.3141])
209-
y = b_poly_point(b, x)
210-
precalculated = 19.8694
211-
self.assertEqual((), y.shape)
212-
self.assertTrue(jnp.allclose(y, precalculated))
213-
214-
def test_b_poly_points(self):
215-
k = (4, 3, 2)
216-
d = tuple([k_ + 1 for k_ in k])
217-
b = jnp.arange(jnp.prod(jnp.asarray(d))).reshape(d) + 1.0
218-
x = jnp.asarray(
222+
b = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0
223+
x = np.asarray(
219224
[
220225
[0.2718, 0.5772, 0.3141],
221226
[0.5772, 0.3141, 0.2718],
222227
[0.3141, 0.2718, 0.5772],
223228
]
224229
)
225-
y = b_poly_points(b, x)
226-
precalculated = jnp.asarray([19.8694, 32.0761, 19.6774])
230+
f = BernsteinPoly(b)
231+
y = f.eval(b, x)
232+
precalculated = np.asarray([19.8694, 32.0761, 19.6774])
227233
self.assertEqual((3,), y.shape)
228234
self.assertTrue(jnp.allclose(y, precalculated))
229235

236+
g = f.jac_p(b, x)
237+
self.assertEqual((3,) + d, g.shape)
238+
self.assertTrue(np.all(g > 0.0))
239+
240+
241+
def to_var(u: np.ndarray) -> np.ndarray:
242+
"""
243+
Converts standard uncertainty to a diagonal uncertainty tensor.
244+
"""
245+
return np.square(u)
246+
230247

231248
if __name__ == "__main__":
232249
unittest.main()

uncertaintyx/b/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) Brockmann Consult GmbH, 2026.
22
# License: MIT
33
"""
4-
A package to provide univariate and multi-variate Bernstein polynomials.
4+
A package to provide univariate and N-variate Bernstein polynomials.
55
"""

uncertaintyx/b/jax.py

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
# License: MIT
33
import jax
44
import jax.numpy as jnp
5+
import numpy as np
56
from jax import Array
67
from jax.scipy.special import gammaln
78

9+
from uncertaintyx.g.jax import ToG
10+
from uncertaintyx.m.jax import ToM
11+
812

913
def _binom(i: Array, k: int) -> Array:
1014
"""
@@ -154,19 +158,78 @@ def b_poly_point(b: Array, x: Array) -> Array:
154158
return b
155159

156160

157-
@jax.jit
158-
def b_poly_points(b: Array, x: Array) -> Array:
161+
class BernsteinGrid(ToG):
159162
r"""
160-
Evaluates a multivariate Bernstein polynomial on a batch of
161-
irregularly distributed N-dimensional points.
163+
Evaluates an N-variate Bernstein polynomial on a regular
164+
grid of points.
162165
163-
Under the same notation as :meth:`b_poly_point` let
164-
:math:`X \in \mathbb{R}^{M \times N}` be a batch of
165-
points over the outer batch dimension :math:`M `.
166-
Then:
166+
Encapsulates the multivariate Bernstein polynomial
167167
168-
:param b: The Bernstein coefficients :math:`b \in \mathbb{R}^{k + 1}`.
169-
:param x: The points :math:`X \in \mathbb{R}^{M \times N}`.
170-
:returns: The polynomial values :math:`B(x) \in \mathbb{R}^{M}`.
168+
.. math::
169+
B: \mathbb{R}^{k + 1} \to \mathbb{R}^{m},
170+
b \mapsto B_{k}(b, x)
171+
172+
where :math:`N \in \mathbb{N}` is the arity of the Bernstein
173+
polynomial, :math:`k = (k_{1}, \dots, k_{N}) \in \mathbb{N}^{N}`
174+
are its degrees, :math:`\mathbb{R}^{k + 1}` is the tensor space
175+
with dimensions :math:`(k_{1} + 1, \dots, k_{N} + 1)`,
176+
:math:`m = (m_{1}, \dots , m_{N}) \in \mathbb{N}^{N}` are the
177+
dimensions of the grid coordinates :math:`x = (x_{1}, \dots, x_{N})`,
178+
and :math:`\mathbb{R}^{m}` is the tensor space with dimensions
179+
:math:`m`.
180+
"""
181+
182+
def __init__(self, k: tuple[int, ...], x: tuple[np.ndarray, ...]):
183+
"""
184+
Creates a new instance of this class.
185+
186+
:param k: The degrees :math:`k`.
187+
:param x: The grid coordinates :math:`x`.
188+
"""
189+
self._d = tuple([k_ + 1 for k_ in k])
190+
self._x = tuple([jnp.asarray(x_) for x_ in x])
191+
192+
def f(b: Array) -> Array:
193+
return b_poly_grid(b, self._x)
194+
195+
super().__init__(f, rev=False)
196+
197+
def prior(self, preset: str | None = None) -> np.ndarray:
198+
return np.ones(self._d)
199+
200+
201+
class BernsteinPoly(ToM):
202+
r"""
203+
Evaluates an N-variate Bernstein polynomial on a batch of
204+
irregularly distributed points.
205+
206+
Encapsulates the multivariate Bernstein polynomial
207+
208+
.. math::
209+
B: \mathbb{R}^{k + 1} \times \mathbb{R}^{N} \to
210+
\mathbb{R},
211+
(b, x) \mapsto B(b, x)
212+
213+
where :math:`N \in \mathbb{N}` is the arity of the Bernstein
214+
polynomial, :math:`k = (k_{1}, \dots, k_{N}) \in \mathbb{N}^{N}`
215+
denotes its degrees, and :math:`\mathbb{R}^{k + 1}` denotes the
216+
tensor space with dimensions :math:`(k_{1} + 1, \dots, k_{N} + 1)`.
171217
"""
172-
return jax.vmap(b_poly_point, in_axes=(None, 0))(b, x)
218+
219+
def __init__(self, b: np.ndarray):
220+
"""
221+
Creates a new instance of this class.
222+
223+
:param b: The Bernstein coefficients :math:`b \in \mathbb{R}^{k + 1}`.
224+
"""
225+
self._b = b
226+
227+
super().__init__(b_poly_point)
228+
229+
def prior(
230+
self,
231+
x: np.ndarray | None = None,
232+
y: np.ndarray | None = None,
233+
preset: str | None = None,
234+
) -> np.ndarray:
235+
return self._b

uncertaintyx/g/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) Brockmann Consult GmbH, 2026.
2+
# License: MIT

uncertaintyx/g/jax.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright (c) Brockmann Consult GmbH, 2026.
2+
# License: MIT
3+
"""
4+
Interface adapters for pure JAX functions.
5+
6+
Adapters employ JAX algorithmic differentiation to compute
7+
derivatives of generic model functions.
8+
"""
9+
10+
from abc import ABC
11+
from typing import Callable
12+
13+
import jax
14+
import jax.numpy as jnp
15+
import numpy as np
16+
from jax import Array
17+
18+
from ..tyx import G
19+
20+
21+
@jax.jit(static_argnums=(0, 2))
22+
def jac(f: Callable[[Array], Array], p: Array, rev: bool = True) -> Array:
23+
"""Returns the Jacobian (does not belong to public API)."""
24+
return jax.jacrev(f)(p) if rev else jax.jacfwd(f)(p)
25+
26+
27+
@jax.jit(static_argnums=(0, 3))
28+
def lpu(d: int, g: Array, u: Array, diag: bool = True) -> Array:
29+
r"""
30+
Implementation of the law of propagation of uncertainty in
31+
general tensor form.
32+
33+
Using Einstein's summation convention and the symmetry of the
34+
parameter uncertainty tensor :math:`U`:, the output uncertainty
35+
tensor reads:
36+
37+
.. math::
38+
V_{\dots ij} = G_{\dots ik}U_{\dots lk}G_{\dots jl}
39+
40+
with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
41+
for some :math:`d \in \mathbb{N}`. The summation is taken over
42+
all :math:`k, l \in D`.
43+
44+
Here, :math:`D` denotes the set of inner tensor indices
45+
(multi-indices of length :math:`d`), and the trailing tensor
46+
dimensions of :math:`G` and :math:`U` correspond to these
47+
indices.
48+
49+
In what follows, we write :math:`\mathbb{R}^{\cdots \times D}`
50+
for a tensor space whose trailing indices are labelled by the
51+
index set :math:`D`.
52+
53+
:param d: The number of inner tensor dimensions.
54+
:param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
55+
:param u: Tensor :math:`U \in \mathbb{R}^{\cdots \times D}`.
56+
:param diag: To return only variance elements of :math:`V`.
57+
:returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
58+
"""
59+
return make_lpu(d, diag)(g, u)
60+
61+
62+
def make_lpu(d: int, diag: bool = False) -> Callable[[Array, Array], Array]:
63+
"""
64+
Returns the law of propagation of uncertainty.
65+
66+
:param d: The number of inner tensor dimensions.
67+
:param diag: To return only the diagonal elements .
68+
:returns: The law of propagation of uncertainty.
69+
"""
70+
71+
def lpu(g: Array, u: Array) -> Array:
72+
"""The law of propagation of uncertainty."""
73+
dims = tuple(range(-d, 0))
74+
gu = jnp.tensordot(g, u, (dims, dims)) if u.ndim != d else g * u
75+
return (
76+
jnp.tensordot(gu, g, (dims, dims))
77+
if not diag
78+
else jnp.sum(gu * g, dims)
79+
)
80+
81+
return lpu
82+
83+
84+
def jac_no_jit( # pragma: no cover
85+
f: Callable[[Array], Array], p: Array, rev: bool = False
86+
) -> Array:
87+
"""Noncompiled version of :meth:`jac` for debugging."""
88+
return jax.jacrev(f)(p) if rev else jax.jacfwd(f)(p)
89+
90+
91+
class ToG(G, ABC):
92+
r"""
93+
Adapts a pure function
94+
95+
.. math::
96+
f: \mathbb{R}^{k} \to \mathbb{R}^{m}, \quad
97+
p \mapsto f(p)
98+
99+
where :math:`k, m` are shapes (natural numbers or tuples
100+
of natural numbers) to the function interface ``G``.
101+
"""
102+
103+
def __init__(
104+
self,
105+
f: Callable[[Array], Array],
106+
rev: bool = True,
107+
jit: bool = True,
108+
):
109+
"""
110+
Creates a new instance of this class.
111+
112+
:param f: The function :math:`f`.
113+
:param rev: Use reverse mode for the Jacobian.
114+
:param jit: Switches JIT compilation on and off (for debugging).
115+
"""
116+
self._f = jax.jit(f) if jit else f
117+
self._jit = jit
118+
self._rev = rev
119+
120+
def eval(self, p: np.ndarray) -> np.ndarray:
121+
p_ = jnp.asarray(p)
122+
y_ = self._f(p_)
123+
return np.asarray(y_)
124+
125+
def jac(self, p: np.ndarray) -> np.ndarray:
126+
p_ = jnp.asarray(p)
127+
g_ = (
128+
jac(self._f, p_, self._rev)
129+
if self._jit
130+
else jac_no_jit(self._f, p_, self._rev)
131+
)
132+
return np.asarray(g_)
133+
134+
def lpu(
135+
self, p: np.ndarray, u: np.ndarray, diag: bool = False
136+
) -> np.ndarray:
137+
p_ = jnp.asarray(p)
138+
u_ = jnp.asarray(u)
139+
v_ = lpu(p_.ndim, jac(self._f, p_, self._rev), u_, diag)
140+
return np.asarray(v_)
141+
142+
@property
143+
def f(self) -> Callable[[Array, Array], Array]:
144+
return self._f

0 commit comments

Comments
 (0)