@@ -133,40 +133,24 @@ of length $d$), and the trailing tensor dimensions of the Jacobian
133133tensor $G$ and the input uncertainty tensor $U$ correspond to
134134these indices. The code below provides an implementation.
135135
136- def propagate(d: int, g: np.ndarray, u: np.ndarray) -> np.ndarray:
137- r"""
138- Default implementation of the law of propagation of uncertainty
139- in general tensor form.
140-
141- Using Einstein's summation convention and the symmetry of the
142- input uncertainty tensor :math:`U`, the output uncertainty
143- tensor reads:
144-
145- .. math::
146- V_{\dots ij} = G_{\dots ik}U_{\dots lk}G_{\dots jl}
147-
148- with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
149- for some :math:`d \in \mathbb{N}`. The summation is taken over
150- all :math:`k, l \in D`.
151-
152- Here, :math:`D` denotes the set of inner tensor indices
153- (multi-indices of length :math:`d`), and the trailing tensor
154- dimensions of :math:`G` and :math:`U` correspond to these
155- indices.
156-
157- In what follows, we write :math:`\mathbb{R}^{\cdots \times D}`
158- for a tensor space whose trailing indices are labelled by the
159- index set :math:`D`.
160-
161- :param d: The number of inner tensor dimensions.
162- :param g: Jacobian tensor :math:`G \in \mathbb{R}^{\cdots \times D}`.
163- :param u: Uncertainty tensor :math:`U \in \mathbb{R}^{\cdots \times D}`.
164- :returns: Uncertainty tensor :math:`V \in \mathbb{R}^{\cdots}`.
165- """
136+ import jax.numpy as jnp
137+ from jax import Array
138+
139+ def make_lpu(d: int) -> Callable[[Array, Array], Array]:
140+ """
141+ Returns the law of propagation of uncertainty.
142+
143+ :param d: The number of inner tensor dimensions.
144+ :returns: The law of propagation of uncertainty.
145+ """
146+
147+ def lpu(g: Array, u: Array) -> Array:
148+ """The law of propagation of uncertainty."""
166149 dims = tuple(range(-d, 0))
167- return np.tensordot(
168- np.tensordot(g, u, axes=(dims, dims)), g, axes=(dims, dims)
169- )
150+ return jnp.tensordot(jnp.tensordot(g, u, (dims, dims)), g, (dims, dims))
151+
152+ return lpu
153+
170154
171155[ ![ CodeQL Advanced] ( https://github.com/bcdev/uncertaintyx/actions/workflows/codeql.yml/badge.svg )] ( https://github.com/bcdev/uncertaintyx/actions/workflows/codeql.yml )
172156[ ![ Python package] ( https://github.com/bcdev/uncertaintyx/actions/workflows/python-package.yml/badge.svg )] ( https://github.com/bcdev/uncertaintyx/actions/workflows/python-package.yml )
0 commit comments