@@ -69,6 +69,91 @@ def jac(f: Callable[[Array, Array], Array], arg: int, rev: bool = True):
     return jax.jacrev(f, argnums=arg) if rev else jax.jacfwd(f, argnums=arg)
 
 
+@jax.jit(static_argnums=(0, 3))
+def lpu_p(d: int, g: Array, u: Array, diag: bool = False) -> Array:
+    r"""
+    Implementation of the law of propagation of uncertainty in
+    general tensor form.
+
+    Using Einstein's summation convention and the symmetry of the
+    parameter uncertainty tensor :math:`U`, the output uncertainty
+    tensor reads:
+
+    .. math::
+        V_{\dots ij} = G_{\dots ik} U_{\dots lk} G_{\dots jl}
+
+    with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
+    for some :math:`d \in \mathbb{N}`. The summation is taken over
+    all :math:`k, l \in D`.
+
+    Here, :math:`D` denotes the set of inner tensor indices
+    (multi-indices of length :math:`d`), and the trailing tensor
+    dimensions of :math:`G` and :math:`U` correspond to these
+    indices.
+
+    In what follows, we write :math:`\mathbb{R}^{\cdots \times D}`
+    for a tensor space whose trailing indices are labelled by the
+    index set :math:`D`.
+
+    :param d: The number of inner tensor dimensions.
+    :param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param u: Tensor :math:`U \in \mathbb{R}^{\cdots \times D}`.
+    :param diag: Whether to return only the variance (diagonal)
+        elements of :math:`V`.
+    :returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
+    """
+    return jax.vmap(make_lpu(d, diag), in_axes=(0, None))(g, u)
+
+
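For :math:`d = 1` the formula is the familiar matrix sandwich :math:`V = G U G^\mathsf{T}`, applied once per leading index. A minimal sketch checking `lpu_p` against an explicit einsum; the shapes (5 leading entries, 3 outputs, 4 parameters) are illustrative assumptions, and the sketch presumes the module's `lpu_p` is importable:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
g = jax.random.normal(key, (5, 3, 4))  # M=5 leading entries, 3 outputs, D=4 params
a = jax.random.normal(key, (4, 4))
u = a @ a.T                            # symmetric parameter covariance, D x D

v = lpu_p(1, g, u)                     # one 3 x 3 output covariance per leading entry
v_ref = jnp.einsum("mik,lk,mjl->mij", g, u, g)
assert v.shape == (5, 3, 3) and jnp.allclose(v, v_ref, atol=1e-5)

v_diag = lpu_p(1, g, u, diag=True)     # variances only, shape (5, 3)
assert jnp.allclose(v_diag, jnp.diagonal(v_ref, axis1=1, axis2=2), atol=1e-5)
```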
+@jax.jit(static_argnums=(0, 3))
+def lpu_x(d: int, g: Array, u: Array, diag: bool = False) -> Array:
+    r"""
+    Implementation of the law of propagation of uncertainty in
+    general tensor form (for input uncertainty tensors).
+
+    Using Einstein's summation convention and the symmetry of the
+    input uncertainty tensor :math:`U`, the output uncertainty
+    tensor reads:
+
+    .. math::
+        V_{\dots ij} = G_{\dots ik} U_{\dots lk} G_{\dots jl}
+
+    with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
+    for some :math:`d \in \mathbb{N}`. The summation is taken over
+    all :math:`k, l \in D`.
+
+    The notation is the same as in :meth:`lpu_p`.
+
+    :param d: The number of inner tensor dimensions.
+    :param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param u: Tensor :math:`U \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param diag: Whether to return only the variance (diagonal)
+        elements of :math:`V`.
+    :returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
+    """
+    return jax.vmap(make_lpu(d, diag), in_axes=(0, 0))(g, u)
+
+
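The only behavioural difference from `lpu_p` is `in_axes=(0, 0)`: each Jacobian slice is paired with its own uncertainty tensor rather than a shared one. A sketch under the same illustrative shapes, with one covariance per leading entry:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
g = jax.random.normal(key, (5, 3, 4))   # as above, but D indexes inputs
a = jax.random.normal(key, (5, 4, 4))
u = a @ jnp.swapaxes(a, -1, -2)         # one symmetric D x D covariance per entry

v = lpu_x(1, g, u)                      # shape (5, 3, 3)
v_ref = jnp.einsum("mik,mlk,mjl->mij", g, u, g)
assert jnp.allclose(v, v_ref, atol=1e-5)
```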
+def make_lpu(d: int, diag: bool = False) -> Callable[[Array, Array], Array]:
+    """
+    Returns a function implementing the law of propagation of
+    uncertainty for a single (unbatched) Jacobian.
+
+    :param d: The number of inner tensor dimensions.
+    :param diag: Whether to return only the diagonal elements.
+    :returns: The law of propagation of uncertainty as a callable.
+    """
+
+    def lpu(g: Array, u: Array) -> Array:
+        """The law of propagation of uncertainty."""
+        dims = tuple(range(-d, 0))
+        # Contract the inner tensor dimensions of G with U. If U holds
+        # only variances (u.ndim == d), the contraction degenerates to
+        # an element-wise product.
+        gu = jnp.tensordot(g, u, (dims, dims)) if u.ndim != d else g * u
+        # Contract with G a second time; for diag=True, keep only the
+        # variance (diagonal) elements of the result.
+        return (
+            jnp.tensordot(gu, g, (dims, dims))
+            if not diag
+            else jnp.sum(gu * g, dims)
+        )
+
+    return lpu
+
+
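Besides a full covariance, the kernel returned by `make_lpu` also accepts a variance-only `u` with `u.ndim == d`, in which case the first contraction degenerates to the element-wise product `g * u`. A sketch of that branch for `d = 1` (names and values are illustrative):

```python
import jax.numpy as jnp

lpu = make_lpu(1)                      # unbatched kernel, full output covariance
g = jnp.arange(12.0).reshape(3, 4)     # 3 outputs, D=4 inner dimensions
var = jnp.array([1.0, 2.0, 3.0, 4.0])  # variances only, u.ndim == d == 1

v = lpu(g, var)
assert jnp.allclose(v, g @ jnp.diag(var) @ g.T)  # G diag(var) G^T
```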
 @jax.jit(static_argnums=(0,))
 def vec_x(f: Callable[[Array, Array], Array], p: Array, x: Array) -> Array:
     r"""
@@ -139,34 +224,52 @@ def __init__(
         self._rev_x = rev_x
 
     def eval(self, p: np.ndarray, x: np.ndarray) -> np.ndarray:
-        p_j = jnp.asarray(p)
-        x_j = jnp.asarray(x)
-        y_j = (
-            vec_x(self._f, p_j, x_j)
+        p_ = jnp.asarray(p)
+        x_ = jnp.asarray(x)
+        y_ = (
+            vec_x(self._f, p_, x_)
             if self._jit
-            else vec_x_no_jit(self._f, p_j, x_j)
+            else vec_x_no_jit(self._f, p_, x_)
         )
-        return np.asarray(y_j)
+        return np.asarray(y_)
 
     def jac_p(self, p: np.ndarray, x: np.ndarray) -> np.ndarray:
-        p_j = jnp.asarray(p)
-        x_j = jnp.asarray(x)
-        g_j = (
-            jac_p(self._f, p_j, x_j, self._rev_p)
+        p_ = jnp.asarray(p)
+        x_ = jnp.asarray(x)
+        g_ = (
+            jac_p(self._f, p_, x_, self._rev_p)
             if self._jit
-            else jac_p_no_jit(self._f, p_j, x_j, self._rev_p)
+            else jac_p_no_jit(self._f, p_, x_, self._rev_p)
         )
-        return np.asarray(g_j)
+        return np.asarray(g_)
 
     def jac_x(self, p: np.ndarray, x: np.ndarray) -> np.ndarray:
-        p_j = jnp.asarray(p)
-        x_j = jnp.asarray(x)
-        g_j = (
-            jac_x(self._f, p_j, x_j, self._rev_x)
+        p_ = jnp.asarray(p)
+        x_ = jnp.asarray(x)
+        g_ = (
+            jac_x(self._f, p_, x_, self._rev_x)
             if self._jit
-            else jac_x_no_jit(self._f, p_j, x_j, self._rev_x)
+            else jac_x_no_jit(self._f, p_, x_, self._rev_x)
         )
-        return np.asarray(g_j)
+        return np.asarray(g_)
+
+    def lpu_p(
+        self, p: np.ndarray, u: np.ndarray, x: np.ndarray, diag: bool = False
+    ) -> np.ndarray:
+        p_ = jnp.asarray(p)
+        u_ = jnp.asarray(u)
+        x_ = jnp.asarray(x)
+        v_ = lpu_p(p_.ndim, jac_p(self._f, p_, x_, self._rev_p), u_, diag)
+        return np.asarray(v_)
+
+    def lpu_x(
+        self, p: np.ndarray, x: np.ndarray, u: np.ndarray, diag: bool = False
+    ) -> np.ndarray:
+        p_ = jnp.asarray(p)
+        x_ = jnp.asarray(x)
+        u_ = jnp.asarray(u)
+        v_ = lpu_x(x_.ndim - 1, jac_x(self._f, p_, x_, self._rev_x), u_, diag)
+        return np.asarray(v_)
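Taken together, the two new methods propagate either a parameter covariance or per-point input covariances through the wrapped model in a single call. A hypothetical usage sketch; `model`, `p`, `u_p`, `u_x`, and all shapes are assumptions for illustration, since the wrapper class itself is not shown in this hunk:

```python
import numpy as np

p = np.array([1.0, 0.5])                      # parameter estimate
u_p = np.diag([1e-4, 4e-4])                   # parameter covariance
x = np.linspace(0.0, 1.0, 50).reshape(-1, 1)  # 50 points, 1 input dimension

# model is assumed to be an instance of this wrapper around some f(p, x)
v_p = model.lpu_p(p, u_p, x, diag=True)       # variances from parameter uncertainty

u_x = np.tile(1e-6 * np.eye(1), (50, 1, 1))   # one input covariance per point
v_x = model.lpu_x(p, x, u_x, diag=True)       # variances from input uncertainty
```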
 
     @property
     def f(self) -> Callable[[Array, Array], Array]: