4444"""The maximum number of iterations permitted."""
4545
4646
47- @jax .jit (static_argnums = (0 ,), static_argnames = ("diagonalize " ,))
48- def evm_fit (
47+ @jax .jit (static_argnums = (0 ,), static_argnames = ("covar " ,))
48+ def evm (
4949 f : Callable [[Array , Array ], Array ],
5050 p : Array ,
5151 x : Array ,
@@ -56,7 +56,7 @@ def evm_fit(
5656 * ,
5757 max_i : int = DEFAULT_MAX_I ,
5858 max_g : Any = DEFAULT_MAX_G ,
59- diagonalize : bool = True ,
59+ covar : bool = False ,
6060) -> tuple [Any , ...]:
6161 r"""
6262 Implementation of the effective variance method (EVM) with
@@ -89,7 +89,7 @@ def evm_fit(
8989 :param up: Uncertainty tensor :math:`U(p)`, full or diagonal.
9090 :param max_i: The maximum number of iterations permitted.
9191 :param max_g: The maximum gradient permitted.
92- :param diagonalize : Use only diagonal of uncertainty propagation output .
92+ :param covar : Use effective variance-covariance .
9393 :returns: The minimization loop state.
9494 """
9595
@@ -147,7 +147,7 @@ def loss(q: Array, x: Array, y: Array, ux: Array, uy: Array) -> Array:
147147 """
148148 d = f (q , x ) - y
149149 G = g (q , x ) # noqa: N806
150- if diagonalize :
150+ if not covar :
151151 U = upd (x .ndim , G , ux ) + ( # noqa: N806
152152 uy
153153 if uy .ndim == y .ndim
0 commit comments