
Commit d400c44 (parent: b90a471)

Update: JAX LPU implementation

2 files changed: 187 additions & 26 deletions


uncertaintyx/f/jax.py

Lines changed: 66 additions & 8 deletions
@@ -39,6 +39,56 @@ def jac(f: Callable[[Array], Array], x: Array, rev: bool = True) -> Array:
     return jax.vmap(jax.jacrev(f) if rev else jax.jacfwd(f))(x)


+@jax.jit(static_argnums=(0, 3))
+def lpu(d: int, g: Array, u: Array, diag: bool = False) -> Array:
+    r"""
+    Implementation of the law of propagation of uncertainty in
+    general tensor form.
+
+    Using Einstein's summation convention and the symmetry of the
+    input uncertainty tensor :math:`U`, the output uncertainty
+    tensor reads:
+
+    .. math::
+        V_{\dots ij} = G_{\dots ik} U_{\dots lk} G_{\dots jl}
+
+    with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
+    for some :math:`d \in \mathbb{N}`. The summation is taken over
+    all :math:`k, l \in D`.
+
+    Under the same notation as :func:`uncertaintyx.m.jax.lpu_p`:
+
+    :param d: The number of inner tensor dimensions.
+    :param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param u: Tensor :math:`U \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param diag: To return only the variance elements of :math:`V`.
+    :returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
+    """
+    return jax.vmap(make_lpu(d, diag), in_axes=(0, 0))(g, u)
+
+
+def make_lpu(d: int, diag: bool = False) -> Callable[[Array, Array], Array]:
+    """
+    Returns the law of propagation of uncertainty.
+
+    :param d: The number of inner tensor dimensions.
+    :param diag: To return only the variance elements.
+    :returns: The law of propagation of uncertainty.
+    """
+
+    def lpu(g: Array, u: Array) -> Array:
+        """The law of propagation of uncertainty."""
+        dims = tuple(range(-d, 0))
+        gu = jnp.tensordot(g, u, (dims, dims)) if u.ndim != d else g * u
+        return (
+            jnp.tensordot(gu, g, (dims, dims))
+            if not diag
+            else jnp.sum(gu * g, dims)
+        )
+
+    return lpu
+
+
 @jax.jit(static_argnums=(0,))
 def vec(f: Callable[[Array], Array], x: Array) -> Array:
     r"""
@@ -94,18 +144,26 @@ def __init__(
         self._jit = jit

     def eval(self, x: np.ndarray) -> np.ndarray:
-        x_j = jnp.asarray(x)
-        y_j = vec(self._f, x_j) if self._jit else vec_no_jit(self._f, x_j)
-        return np.asarray(y_j)
+        x_ = jnp.asarray(x)
+        y_ = vec(self._f, x_) if self._jit else vec_no_jit(self._f, x_)
+        return np.asarray(y_)

     def jac(self, x: np.ndarray) -> np.ndarray:
-        x_j = jnp.asarray(x)
-        g_j = (
-            jac(self._f, x_j, self._rev)
+        x_ = jnp.asarray(x)
+        g_ = (
+            jac(self._f, x_, self._rev)
             if self._jit
-            else jac_no_jit(self._f, x_j, self._rev)
+            else jac_no_jit(self._f, x_, self._rev)
         )
-        return np.asarray(g_j)
+        return np.asarray(g_)
+
+    def lpu(
+        self, x: np.ndarray, u: np.ndarray, diag: bool = False
+    ) -> np.ndarray:
+        x_ = jnp.asarray(x)
+        u_ = jnp.asarray(u)
+        v_ = lpu(x_.ndim - 1, jac(self._f, x_, self._rev), u_, diag)
+        return np.asarray(v_)

     @property
     def f(self) -> Callable[[Array], Array]:
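
Intended call pattern for the new module-level lpu (a hedged sketch: the import path is inferred from the file location uncertaintyx/f/jax.py, and the model f, shapes, and uncertainty values are illustrative assumptions; only lpu itself comes from this diff):

    import jax
    import jax.numpy as jnp
    from uncertaintyx.f.jax import lpu   # import path assumed from the file name

    def f(x):                            # illustrative model, R^3 -> R^2
        return jnp.stack([jnp.sum(x ** 2), jnp.prod(x)])

    x = jnp.ones((5, 3))                          # batch of M = 5 inputs, D = 3
    u = jnp.tile(0.01 * jnp.eye(3), (5, 1, 1))    # one input covariance per sample

    g = jax.vmap(jax.jacrev(f))(x)       # batched Jacobian, shape (5, 2, 3)
    v = lpu(1, g, u)                     # d = x.ndim - 1 = 1, as in the new method; (5, 2, 2)
    v_diag = lpu(1, g, u, True)          # variances only, (5, 2); diag passed
                                         # positionally since it is a static argument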

uncertaintyx/m/jax.py

Lines changed: 121 additions & 18 deletions
@@ -69,6 +69,91 @@ def jac(f: Callable[[Array, Array], Array], arg: int, rev: bool = True):
     return jax.jacrev(f, argnums=arg) if rev else jax.jacfwd(f, argnums=arg)


+@jax.jit(static_argnums=(0, 3))
+def lpu_p(d: int, g: Array, u: Array, diag: bool = False) -> Array:
+    r"""
+    Implementation of the law of propagation of uncertainty in
+    general tensor form.
+
+    Using Einstein's summation convention and the symmetry of the
+    parameter uncertainty tensor :math:`U`, the output uncertainty
+    tensor reads:
+
+    .. math::
+        V_{\dots ij} = G_{\dots ik} U_{\dots lk} G_{\dots jl}
+
+    with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
+    for some :math:`d \in \mathbb{N}`. The summation is taken over
+    all :math:`k, l \in D`.
+
+    Here, :math:`D` denotes the set of inner tensor indices
+    (multi-indices of length :math:`d`), and the trailing tensor
+    dimensions of :math:`G` and :math:`U` correspond to these
+    indices.
+
+    In what follows, we write :math:`\mathbb{R}^{\cdots \times D}`
+    for a tensor space whose trailing indices are labelled by the
+    index set :math:`D`.
+
+    :param d: The number of inner tensor dimensions.
+    :param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param u: Tensor :math:`U \in \mathbb{R}^{\cdots \times D}`.
+    :param diag: To return only the variance elements of :math:`V`.
+    :returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
+    """
+    return jax.vmap(make_lpu(d, diag), in_axes=(0, None))(g, u)
+
+
+@jax.jit(static_argnums=(0, 3))
+def lpu_x(d: int, g: Array, u: Array, diag: bool = False) -> Array:
+    r"""
+    Implementation of the law of propagation of uncertainty in
+    general tensor form (for input uncertainty tensors).
+
+    Using Einstein's summation convention and the symmetry of the
+    input uncertainty tensor :math:`U`, the output uncertainty
+    tensor reads:
+
+    .. math::
+        V_{\dots ij} = G_{\dots ik} U_{\dots lk} G_{\dots jl}
+
+    with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
+    for some :math:`d \in \mathbb{N}`. The summation is taken over
+    all :math:`k, l \in D`.
+
+    Under the same notation as :func:`lpu_p`:
+
+    :param d: The number of inner tensor dimensions.
+    :param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param u: Tensor :math:`U \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param diag: To return only the variance elements of :math:`V`.
+    :returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
+    """
+    return jax.vmap(make_lpu(d, diag), in_axes=(0, 0))(g, u)
+
+
+def make_lpu(d: int, diag: bool = False) -> Callable[[Array, Array], Array]:
+    """
+    Returns the law of propagation of uncertainty.
+
+    :param d: The number of inner tensor dimensions.
+    :param diag: To return only the diagonal elements.
+    :returns: The law of propagation of uncertainty.
+    """
+
+    def lpu(g: Array, u: Array) -> Array:
+        """The law of propagation of uncertainty."""
+        dims = tuple(range(-d, 0))
+        gu = jnp.tensordot(g, u, (dims, dims)) if u.ndim != d else g * u
+        return (
+            jnp.tensordot(gu, g, (dims, dims))
+            if not diag
+            else jnp.sum(gu * g, dims)
+        )
+
+    return lpu
+
+
 @jax.jit(static_argnums=(0,))
 def vec_x(f: Callable[[Array, Array], Array], p: Array, x: Array) -> Array:
     r"""
@@ -139,34 +224,52 @@ def __init__(
         self._rev_x = rev_x

     def eval(self, p: np.ndarray, x: np.ndarray) -> np.ndarray:
-        p_j = jnp.asarray(p)
-        x_j = jnp.asarray(x)
-        y_j = (
-            vec_x(self._f, p_j, x_j)
+        p_ = jnp.asarray(p)
+        x_ = jnp.asarray(x)
+        y_ = (
+            vec_x(self._f, p_, x_)
             if self._jit
-            else vec_x_no_jit(self._f, p_j, x_j)
+            else vec_x_no_jit(self._f, p_, x_)
         )
-        return np.asarray(y_j)
+        return np.asarray(y_)

     def jac_p(self, p: np.ndarray, x: np.ndarray) -> np.ndarray:
-        p_j = jnp.asarray(p)
-        x_j = jnp.asarray(x)
-        g_j = (
-            jac_p(self._f, p_j, x_j, self._rev_p)
+        p_ = jnp.asarray(p)
+        x_ = jnp.asarray(x)
+        g_ = (
+            jac_p(self._f, p_, x_, self._rev_p)
             if self._jit
-            else jac_p_no_jit(self._f, p_j, x_j, self._rev_p)
+            else jac_p_no_jit(self._f, p_, x_, self._rev_p)
         )
-        return np.asarray(g_j)
+        return np.asarray(g_)

     def jac_x(self, p: np.ndarray, x: np.ndarray) -> np.ndarray:
-        p_j = jnp.asarray(p)
-        x_j = jnp.asarray(x)
-        g_j = (
-            jac_x(self._f, p_j, x_j, self._rev_x)
+        p_ = jnp.asarray(p)
+        x_ = jnp.asarray(x)
+        g_ = (
+            jac_x(self._f, p_, x_, self._rev_x)
             if self._jit
-            else jac_x_no_jit(self._f, p_j, x_j, self._rev_x)
+            else jac_x_no_jit(self._f, p_, x_, self._rev_x)
         )
-        return np.asarray(g_j)
+        return np.asarray(g_)
+
+    def lpu_p(
+        self, p: np.ndarray, u: np.ndarray, x: np.ndarray, diag: bool = False
+    ) -> np.ndarray:
+        p_ = jnp.asarray(p)
+        u_ = jnp.asarray(u)
+        x_ = jnp.asarray(x)
+        v_ = lpu_p(p_.ndim, jac_p(self._f, p_, x_, self._rev_p), u_, diag)
+        return np.asarray(v_)
+
+    def lpu_x(
+        self, p: np.ndarray, x: np.ndarray, u: np.ndarray, diag: bool = False
+    ) -> np.ndarray:
+        p_ = jnp.asarray(p)
+        x_ = jnp.asarray(x)
+        u_ = jnp.asarray(u)
+        v_ = lpu_x(x_.ndim - 1, jac_x(self._f, p_, x_, self._rev_x), u_, diag)
+        return np.asarray(v_)

     @property
     def f(self) -> Callable[[Array, Array], Array]:
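
One more behaviour worth noting: when u has only d inner dimensions (pure variances, no cross terms), the u.ndim != d test in make_lpu replaces the first tensordot with the elementwise product g * u, which amounts to propagating diag(u). A minimal sketch under that assumption (illustrative values, d = 1):

    import jax.numpy as jnp

    g = jnp.arange(6.0).reshape(2, 3)    # Jacobian G for one sample, d = 1
    u_var = jnp.array([0.1, 0.2, 0.3])   # variances only, so u.ndim == d

    gu = g * u_var                       # the shortcut branch: broadcast over the inner axis
    v = jnp.tensordot(gu, g, ((-1,), (-1,)))

    assert jnp.allclose(v, g @ jnp.diag(u_var) @ g.T)   # equals G diag(u) G^T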
