Skip to content

Commit 5fe569d

Browse files
committed
Update: code for LPU
1 parent c304064 commit 5fe569d

1 file changed

Lines changed: 17 additions & 33 deletions

File tree

README.md

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -133,40 +133,24 @@ of length $d$), and the trailing tensor dimensions of the Jacobian
133133
tensor $G$ and the input uncertainty tensor $U$ correspond to
134134
these indices. The code below provides an implementation.
135135

136-
def propagate(d: int, g: np.ndarray, u: np.ndarray) -> np.ndarray:
137-
r"""
138-
Default implementation of the law of propagation of uncertainty
139-
in general tensor form.
140-
141-
Using Einstein's summation convention and the symmetry of the
142-
input uncertainty tensor :math:`U`, the output uncertainty
143-
tensor reads:
144-
145-
.. math::
146-
V_{\dots ij} = G_{\dots ik}U_{\dots lk}G_{\dots jl}
147-
148-
with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
149-
for some :math:`d \in \mathbb{N}`. The summation is taken over
150-
all :math:`k, l \in D`.
151-
152-
Here, :math:`D` denotes the set of inner tensor indices
153-
(multi-indices of length :math:`d`), and the trailing tensor
154-
dimensions of :math:`G` and :math:`U` correspond to these
155-
indices.
156-
157-
In what follows, we write :math:`\mathbb{R}^{\cdots \times D}`
158-
for a tensor space whose trailing indices are labelled by the
159-
index set :math:`D`.
160-
161-
:param d: The number of inner tensor dimensions.
162-
:param g: Jacobian tensor :math:`G \in \mathbb{R}^{\cdots \times D}`.
163-
:param u: Uncertainty tensor :math:`U \in \mathbb{R}^{\cdots \times D}`.
164-
:returns: Uncertainty tensor :math:`V \in \mathbb{R}^{\cdots}`.
165-
"""
136+
import jax.numpy as jnp
137+
from jax import Array
138+
139+
def make_lpu(d: int) -> Callable[[Array, Array], Array]:
140+
"""
141+
Returns the law of propagation of uncertainty.
142+
143+
:param d: The number of inner tensor dimensions.
144+
:returns: The law of propagation of uncertainty.
145+
"""
146+
147+
def lpu(g: Array, u: Array) -> Array:
148+
"""The law of propagation of uncertainty."""
166149
dims = tuple(range(-d, 0))
167-
return np.tensordot(
168-
np.tensordot(g, u, axes=(dims, dims)), g, axes=(dims, dims)
169-
)
150+
return jnp.tensordot(jnp.tensordot(g, u, (dims, dims)), g, (dims, dims))
151+
152+
return lpu
153+
170154

171155
[![CodeQL Advanced](https://github.com/bcdev/uncertaintyx/actions/workflows/codeql.yml/badge.svg)](https://github.com/bcdev/uncertaintyx/actions/workflows/codeql.yml)
172156
[![Python package](https://github.com/bcdev/uncertaintyx/actions/workflows/python-package.yml/badge.svg)](https://github.com/bcdev/uncertaintyx/actions/workflows/python-package.yml)

0 commit comments

Comments
 (0)