@@ -78,6 +78,93 @@ def jac(f: Callable[[Tensor, Tensor], Tensor], arg: int, rev: bool = True):
7878 )
7979
8080
@torch.compile
def lpu_p(d: int, g: Tensor, u: Tensor, diag: bool = False) -> Tensor:
    r"""
    Law of propagation of uncertainty in general tensor form, for a
    parameter uncertainty tensor shared across all outputs.

    Using Einstein's summation convention and the symmetry of the
    parameter uncertainty tensor :math:`U`, the output uncertainty
    tensor reads:

    .. math::
        V_{\dots ij} = G_{\dots ik}U_{\dots lk}G_{\dots jl}

    with multi-indices :math:`k, l \in D \subset \mathbb{N}^d` for
    some :math:`d \in \mathbb{N}`; the summation runs over all
    :math:`k, l \in D`. Here :math:`D` is the set of inner tensor
    indices, labelling the trailing dimensions of :math:`G` and
    :math:`U`, and :math:`\mathbb{R}^{\cdots \times D}` denotes a
    tensor space whose trailing indices are drawn from :math:`D`.

    :param d: The number of inner tensor dimensions.
    :param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
    :param u: Tensor :math:`U \in \mathbb{R}^{\cdots \times D}`.
    :param diag: To return only variance elements of :math:`V`.
    :returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
    """
    propagate = make_lpu(d, diag)
    # The single uncertainty tensor u is broadcast (in_dims=None)
    # across the leading output dimension of the Jacobian g.
    return torch.vmap(propagate, in_dims=(0, None))(g, u)
114+
115+
@torch.compile
def lpu_x(d: int, g: Tensor, u: Tensor, diag: bool = False) -> Tensor:
    r"""
    Law of propagation of uncertainty in general tensor form, for
    input uncertainty tensors carried per output element.

    Using Einstein's summation convention and the symmetry of the
    input uncertainty tensor :math:`U`, the output uncertainty
    tensor reads:

    .. math::
        V_{\dots ij} = G_{\dots ik}U_{\dots lk}G_{\dots jl}

    with multi-indices :math:`k, l \in D \subset \mathbb{N}^d` for
    some :math:`d \in \mathbb{N}`; the summation runs over all
    :math:`k, l \in D`. Notation as in :meth:`lpu_p`.

    :param d: The number of inner tensor dimensions.
    :param g: Jacobian :math:`G \in \mathbb{R}^{M \times \cdots \times D}`.
    :param u: Tensor :math:`U \in \mathbb{R}^{M \times \cdots \times D}`.
    :param diag: To return only variance elements of :math:`V`.
    :returns: Tensor :math:`V \in \mathbb{R}^{M \times \cdots}`.
    """
    propagate = make_lpu(d, diag)
    # Unlike lpu_p, u carries a leading dimension matched with g,
    # so both arguments are mapped over (in_dims=(0, 0)).
    return torch.vmap(propagate, in_dims=(0, 0))(g, u)
142+
143+
144+ def make_lpu (
145+ d : int , diag : bool = False
146+ ) -> Callable [[Tensor , Tensor ], Tensor ]:
147+ """
148+ Returns the law of propagation of uncertainty.
149+
150+ :param d: The number of inner tensor dimensions.
151+ :param diag: To return only the diagonal elements .
152+ :returns: The law of propagation of uncertainty.
153+ """
154+
155+ def lpu (g : Tensor , u : Tensor ) -> Tensor :
156+ """The law of propagation of uncertainty."""
157+ dims = list (range (- d , 0 ))
158+ gu = torch .tensordot (g , u , (dims , dims )) if u .ndim != d else g * u
159+ return (
160+ torch .tensordot (gu , g , (dims , dims ))
161+ if not diag
162+ else torch .sum (gu * g , dim = dims )
163+ )
164+
165+ return lpu
166+
167+
81168@torch .compile
82169def vec_x (
83170 f : Callable [[Tensor , Tensor ], Tensor ], p : Tensor , x : Tensor
@@ -150,34 +237,52 @@ def __init__(
150237 self ._jit = jit
151238
152239 def eval (self , p : np .ndarray , x : np .ndarray ) -> np .ndarray :
153- p_t = torch .from_numpy (p )
154- x_t = torch .from_numpy (x )
155- y_t = (
156- vec_x (self ._f , p_t , x_t )
240+ p_ = torch .from_numpy (p )
241+ x_ = torch .from_numpy (x )
242+ y_ = (
243+ vec_x (self ._f , p_ , x_ )
157244 if self ._jit
158- else vec_x_no_jit (self ._f , p_t , x_t )
245+ else vec_x_no_jit (self ._f , p_ , x_ )
159246 )
160- return y_t .detach ().numpy ()
247+ return y_ .detach ().numpy ()
161248
162249 def jac_p (self , p : np .ndarray , x : np .ndarray ) -> np .ndarray :
163- p_t = torch .from_numpy (p ).requires_grad_ (True )
164- x_t = torch .from_numpy (x )
165- g_t = (
166- jac_p (self ._f , p_t , x_t , self ._rev_p )
250+ p_ = torch .from_numpy (p ).requires_grad_ (True )
251+ x_ = torch .from_numpy (x )
252+ g_ = (
253+ jac_p (self ._f , p_ , x_ , self ._rev_p )
167254 if self ._jit
168- else jac_p_no_jit (self ._f , p_t , x_t , self ._rev_p )
255+ else jac_p_no_jit (self ._f , p_ , x_ , self ._rev_p )
169256 )
170- return g_t .detach ().numpy ()
257+ return g_ .detach ().numpy ()
171258
172259 def jac_x (self , p : np .ndarray , x : np .ndarray ) -> np .ndarray :
173- p_t = torch .from_numpy (p )
174- x_t = torch .from_numpy (x ).requires_grad_ (True )
175- g_t = (
176- jac_x (self ._f , p_t , x_t , self ._rev_x )
260+ p_ = torch .from_numpy (p )
261+ x_ = torch .from_numpy (x ).requires_grad_ (True )
262+ g_ = (
263+ jac_x (self ._f , p_ , x_ , self ._rev_x )
177264 if self ._jit
178- else jac_x_no_jit (self ._f , p_t , x_t , self ._rev_x )
265+ else jac_x_no_jit (self ._f , p_ , x_ , self ._rev_x )
179266 )
180- return g_t .detach ().numpy ()
267+ return g_ .detach ().numpy ()
268+
269+ def lpu_p (
270+ self , p : np .ndarray , u : np .ndarray , x : np .ndarray , diag : bool = False
271+ ) -> np .ndarray :
272+ p_ = torch .from_numpy (p ).requires_grad_ (True )
273+ u_ = torch .from_numpy (u )
274+ x_ = torch .from_numpy (x )
275+ v_ = lpu_p (p_ .ndim , jac_p (self ._f , p_ , x_ , self ._rev_p ), u_ , diag )
276+ return v_ .detach ().numpy ()
277+
278+ def lpu_x (
279+ self , p : np .ndarray , x : np .ndarray , u : np .ndarray , diag : bool = False
280+ ) -> np .ndarray :
281+ p_ = torch .from_numpy (p )
282+ x_ = torch .from_numpy (x ).requires_grad_ (True )
283+ u_ = torch .from_numpy (u )
284+ v_ = lpu_x (x_ .ndim - 1 , jac_x (self ._f , p_ , x_ , self ._rev_x ), u_ , diag )
285+ return v_ .detach ().numpy ()
181286
182287 @property
183288 def f (self ) -> Callable [[Tensor , Tensor ], Tensor ]:
0 commit comments