
Commit c18d672

Update: PyTorch LPU implementation

1 parent d400c44

2 files changed: 190 additions & 25 deletions

uncertaintyx/f/torch.py
67 additions & 7 deletions
@@ -39,6 +39,58 @@ def jac(f: Callable[[Tensor], Tensor], x: Tensor, rev: bool = True) -> Tensor:
     )
 
 
+@torch.compile
+def lpu(d: int, g: Tensor, u: Tensor, diag: bool = False) -> Tensor:
+    r"""
+    Implementation of the law of propagation of uncertainty in
+    general tensor form.
+
+    Using Einstein's summation convention and the symmetry of the
+    input uncertainty tensor :math:`U`, the output uncertainty
+    tensor reads:
+
+    .. math::
+        V_{\dots ij} = G_{\dots ik}U_{\dots lk}G_{\dots jl}
+
+    with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
+    for some :math:`d \in \mathbb{N}`. The summation is taken over
+    all :math:`k, l \in D`.
+
+    Under the same notation as :meth:`lpu_p`:
+
+    :param d: The number of inner tensor dimensions.
+    :param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param u: Tensor :math:`U \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param diag: Whether to return only the variance elements of :math:`V`.
+    :returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
+    """
+    return torch.vmap(make_lpu(d, diag), in_dims=(0, 0))(g, u)
+
+
+def make_lpu(
+    d: int, diag: bool = False
+) -> Callable[[Tensor, Tensor], Tensor]:
+    """
+    Returns the law of propagation of uncertainty.
+
+    :param d: The number of inner tensor dimensions.
+    :param diag: Whether to return only the variance elements.
+    :returns: The law of propagation of uncertainty.
+    """
+
+    def lpu(g: Tensor, u: Tensor) -> Tensor:
+        """The law of propagation of uncertainty."""
+        dims = list(range(-d, 0))
+        gu = torch.tensordot(g, u, (dims, dims)) if u.ndim != d else g * u
+        return (
+            torch.tensordot(gu, g, (dims, dims))
+            if not diag
+            else torch.sum(gu * g, dim=dims)
+        )
+
+    return lpu
+
+
 def vec(f: Callable[[Tensor], Tensor], x: Tensor) -> Tensor:
     r"""
     Evaluates :math:`f(X)`.
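
For reference (not part of the commit), the tensordot pipeline inside make_lpu reduces, for a single inner dimension d = 1, to the familiar matrix form V = G U Gᵀ. A minimal sketch checking that identity against an explicit einsum, assuming only stock PyTorch:

# Sketch only: verifies the tensordot formulation against
# V_ij = G_ik U_lk G_jl for the matrix case d = 1.
import torch

d = 1
g = torch.randn(4, 3)  # Jacobian G with one inner dimension (D = 3)
u = torch.randn(3, 3)
u = u @ u.T            # symmetric positive semi-definite U

dims = list(range(-d, 0))
gu = torch.tensordot(g, u, (dims, dims))       # G U (uses U = U^T)
v_full = torch.tensordot(gu, g, (dims, dims))  # (G U) G^T
v_diag = torch.sum(gu * g, dim=dims)           # only the variances

assert torch.allclose(v_full, torch.einsum("ik,lk,jl->ij", g, u, g))
assert torch.allclose(v_diag, torch.diagonal(v_full))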
@@ -98,18 +150,26 @@ def __init__(
         self._jit = jit
 
     def eval(self, x: np.ndarray) -> np.ndarray:
-        x_t = torch.from_numpy(x)
-        y_t = vec(self._f, x_t) if self._jit else vec_no_jit(self._f, x_t)
+        x_ = torch.from_numpy(x)
+        y_t = vec(self._f, x_) if self._jit else vec_no_jit(self._f, x_)
         return y_t.detach().numpy()
 
     def jac(self, x: np.ndarray) -> np.ndarray:
-        x_t = torch.from_numpy(x).requires_grad_(True)
-        g_t = (
-            jac(self._f, x_t, self._rev)
+        x_ = torch.from_numpy(x).requires_grad_(True)
+        g_ = (
+            jac(self._f, x_, self._rev)
             if self._jit
-            else jac_no_jit(self._f, x_t, self._rev)
+            else jac_no_jit(self._f, x_, self._rev)
         )
-        return g_t.detach().numpy()
+        return g_.detach().numpy()
+
+    def lpu(
+        self, x: np.ndarray, u: np.ndarray, diag: bool = False
+    ) -> np.ndarray:
+        x_ = torch.from_numpy(x).requires_grad_(True)
+        u_ = torch.from_numpy(u)
+        v_ = lpu(x_.ndim - 1, jac(self._f, x_, self._rev), u_, diag)
+        return v_.detach().numpy()
 
     @property
     def f(self) -> Callable[[Tensor], Tensor]:
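
The new lpu method chains the Jacobian computation with the propagation step. A hedged end-to-end sketch of that flow for a batch of inputs, using torch.func.jacrev directly (the wrapper class and its jac helper are not shown in this hunk, so the names f, x, u below are illustrative):

# Sketch only: per-sample Jacobians followed by V_i = G_i U_i G_i^T,
# which is what the lpu method computes for a batch of inputs.
import torch

def f(x: torch.Tensor) -> torch.Tensor:  # toy model, R^3 -> R^2
    return torch.stack([x.sum(), (x * x).sum()])

x = torch.randn(5, 3)                    # batch of M = 5 inputs
u = torch.eye(3).expand(5, 3, 3)         # per-sample input covariance

g = torch.vmap(torch.func.jacrev(f))(x)  # per-sample Jacobians, (5, 2, 3)

def propagate(g_i, u_i):                 # V_i = G_i U_i G_i^T
    return g_i @ u_i @ g_i.T

v = torch.vmap(propagate)(g, u)          # output uncertainties, (5, 2, 2)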

uncertaintyx/m/torch.py
123 additions & 18 deletions
@@ -78,6 +78,93 @@ def jac(f: Callable[[Tensor, Tensor], Tensor], arg: int, rev: bool = True):
     )
 
 
+@torch.compile
+def lpu_p(d: int, g: Tensor, u: Tensor, diag: bool = False) -> Tensor:
+    r"""
+    Implementation of the law of propagation of uncertainty in
+    general tensor form.
+
+    Using Einstein's summation convention and the symmetry of the
+    parameter uncertainty tensor :math:`U`, the output uncertainty
+    tensor reads:
+
+    .. math::
+        V_{\dots ij} = G_{\dots ik}U_{\dots lk}G_{\dots jl}
+
+    with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
+    for some :math:`d \in \mathbb{N}`. The summation is taken over
+    all :math:`k, l \in D`.
+
+    Here, :math:`D` denotes the set of inner tensor indices
+    (multi-indices of length :math:`d`), and the trailing tensor
+    dimensions of :math:`G` and :math:`U` correspond to these
+    indices.
+
+    In what follows, we write :math:`\mathbb{R}^{\cdots \times D}`
+    for a tensor space whose trailing indices are labelled by the
+    index set :math:`D`.
+
+    :param d: The number of inner tensor dimensions.
+    :param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param u: Tensor :math:`U \in \mathbb{R}^{\cdots \times D}`.
+    :param diag: Whether to return only the variance elements of :math:`V`.
+    :returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
+    """
+    return torch.vmap(make_lpu(d, diag), in_dims=(0, None))(g, u)
+
+
+@torch.compile
+def lpu_x(d: int, g: Tensor, u: Tensor, diag: bool = False) -> Tensor:
+    r"""
+    Implementation of the law of propagation of uncertainty in
+    general tensor form (for input uncertainty tensors).
+
+    Using Einstein's summation convention and the symmetry of the
+    input uncertainty tensor :math:`U`, the output uncertainty
+    tensor reads:
+
+    .. math::
+        V_{\dots ij} = G_{\dots ik}U_{\dots lk}G_{\dots jl}
+
+    with multi-indices :math:`k, l \in D \subset \mathbb{N}^d`
+    for some :math:`d \in \mathbb{N}`. The summation is taken over
+    all :math:`k, l \in D`.
+
+    Under the same notation as :meth:`lpu_p`:
+
+    :param d: The number of inner tensor dimensions.
+    :param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param u: Tensor :math:`U \in \mathbb{R}^{M \times \cdots \times D}`.
+    :param diag: Whether to return only the variance elements of :math:`V`.
+    :returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
+    """
+    return torch.vmap(make_lpu(d, diag), in_dims=(0, 0))(g, u)
+
+
+def make_lpu(
+    d: int, diag: bool = False
+) -> Callable[[Tensor, Tensor], Tensor]:
+    """
+    Returns the law of propagation of uncertainty.
+
+    :param d: The number of inner tensor dimensions.
+    :param diag: Whether to return only the diagonal elements.
+    :returns: The law of propagation of uncertainty.
+    """
+
+    def lpu(g: Tensor, u: Tensor) -> Tensor:
+        """The law of propagation of uncertainty."""
+        dims = list(range(-d, 0))
+        gu = torch.tensordot(g, u, (dims, dims)) if u.ndim != d else g * u
+        return (
+            torch.tensordot(gu, g, (dims, dims))
+            if not diag
+            else torch.sum(gu * g, dim=dims)
+        )
+
+    return lpu
+
+
 @torch.compile
 def vec_x(
     f: Callable[[Tensor, Tensor], Tensor], p: Tensor, x: Tensor
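
The only difference between lpu_p and lpu_x is the vmap broadcast: lpu_p shares one parameter covariance across the batch (in_dims=(0, None)), while lpu_x expects one input covariance per sample (in_dims=(0, 0)). A small sketch of that distinction, with an illustrative propagate function standing in for make_lpu at d = 1:

# Sketch only: the two broadcast modes coincide when the shared U
# is tiled across the batch.
import torch

def propagate(g_i, u_i):                     # V_i = G_i U_i G_i^T
    return g_i @ u_i @ g_i.T

g = torch.randn(5, 2, 3)                     # batched Jacobians
u_p = torch.randn(3, 3)
u_p = u_p @ u_p.T                            # one shared U, as in lpu_p
u_x = u_p.expand(5, 3, 3)                    # per-sample U, as in lpu_x

v_p = torch.vmap(propagate, in_dims=(0, None))(g, u_p)
v_x = torch.vmap(propagate, in_dims=(0, 0))(g, u_x)
assert torch.allclose(v_p, v_x)              # identical when U is shared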
@@ -150,34 +237,52 @@ def __init__(
         self._jit = jit
 
     def eval(self, p: np.ndarray, x: np.ndarray) -> np.ndarray:
-        p_t = torch.from_numpy(p)
-        x_t = torch.from_numpy(x)
-        y_t = (
-            vec_x(self._f, p_t, x_t)
+        p_ = torch.from_numpy(p)
+        x_ = torch.from_numpy(x)
+        y_ = (
+            vec_x(self._f, p_, x_)
             if self._jit
-            else vec_x_no_jit(self._f, p_t, x_t)
+            else vec_x_no_jit(self._f, p_, x_)
         )
-        return y_t.detach().numpy()
+        return y_.detach().numpy()
 
     def jac_p(self, p: np.ndarray, x: np.ndarray) -> np.ndarray:
-        p_t = torch.from_numpy(p).requires_grad_(True)
-        x_t = torch.from_numpy(x)
-        g_t = (
-            jac_p(self._f, p_t, x_t, self._rev_p)
+        p_ = torch.from_numpy(p).requires_grad_(True)
+        x_ = torch.from_numpy(x)
+        g_ = (
+            jac_p(self._f, p_, x_, self._rev_p)
             if self._jit
-            else jac_p_no_jit(self._f, p_t, x_t, self._rev_p)
+            else jac_p_no_jit(self._f, p_, x_, self._rev_p)
         )
-        return g_t.detach().numpy()
+        return g_.detach().numpy()
 
     def jac_x(self, p: np.ndarray, x: np.ndarray) -> np.ndarray:
-        p_t = torch.from_numpy(p)
-        x_t = torch.from_numpy(x).requires_grad_(True)
-        g_t = (
-            jac_x(self._f, p_t, x_t, self._rev_x)
+        p_ = torch.from_numpy(p)
+        x_ = torch.from_numpy(x).requires_grad_(True)
+        g_ = (
+            jac_x(self._f, p_, x_, self._rev_x)
             if self._jit
-            else jac_x_no_jit(self._f, p_t, x_t, self._rev_x)
+            else jac_x_no_jit(self._f, p_, x_, self._rev_x)
         )
-        return g_t.detach().numpy()
+        return g_.detach().numpy()
+
+    def lpu_p(
+        self, p: np.ndarray, u: np.ndarray, x: np.ndarray, diag: bool = False
+    ) -> np.ndarray:
+        p_ = torch.from_numpy(p).requires_grad_(True)
+        u_ = torch.from_numpy(u)
+        x_ = torch.from_numpy(x)
+        v_ = lpu_p(p_.ndim, jac_p(self._f, p_, x_, self._rev_p), u_, diag)
+        return v_.detach().numpy()
+
+    def lpu_x(
+        self, p: np.ndarray, x: np.ndarray, u: np.ndarray, diag: bool = False
+    ) -> np.ndarray:
+        p_ = torch.from_numpy(p)
+        x_ = torch.from_numpy(x).requires_grad_(True)
+        u_ = torch.from_numpy(u)
+        v_ = lpu_x(x_.ndim - 1, jac_x(self._f, p_, x_, self._rev_x), u_, diag)
+        return v_.detach().numpy()
 
     @property
     def f(self) -> Callable[[Tensor, Tensor], Tensor]:
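
Putting the pieces together, the new lpu_p method propagates a shared parameter covariance through per-sample parameter Jacobians. A hedged sketch of that flow for a toy straight-line model (the direct use of torch.func.jacrev is illustrative; the committed code routes this through jac_p and torch.compile):

# Sketch only: parameter uncertainty propagated to output variances,
# mirroring lpu_p with d = p.ndim = 1 and diag=True.
import torch

def f(p: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return p[0] + p[1] * x                   # straight line, per sample

p = torch.tensor([1.0, 2.0])
x = torch.linspace(0.0, 1.0, 5)
u = 0.01 * torch.eye(2)                      # shared parameter covariance

# per-sample Jacobian w.r.t. p, shape (5, 2)
g = torch.vmap(torch.func.jacrev(f, argnums=0), in_dims=(None, 0))(p, x)

gu = torch.tensordot(g, u, ([-1], [-1]))     # G U, with U shared
v_diag = torch.sum(gu * g, dim=-1)           # variances of f along x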
