3737from ...interface .core import M
3838from ...interface .core import Result
3939
40- DEFAULT_MAX_G : Any = 1.0e-08
41- """The maximum gradient permitted."""
40+ DEFAULT_MAX_D : Any = 1.0e-08
41+ """The maximum L2 norm of the parameter step allowed for convergence."""
42+
43+ DEFAULT_MAX_G : Any = 1.0e-06
44+ """The maximum infinity norm of the gradient allowed for convergence."""
4245
4346DEFAULT_MAX_I : int = 100
4447"""The maximum number of iterations permitted."""
@@ -55,6 +58,7 @@ def evm(
5558 up : Array | None = None ,
5659 * ,
5760 max_i : int = DEFAULT_MAX_I ,
61+ max_d : Any = DEFAULT_MAX_D ,
5862 max_g : Any = DEFAULT_MAX_G ,
5963 covar : bool = False ,
6064) -> tuple [Array , Array , Array , Array , Array ]:
@@ -87,8 +91,11 @@ def evm(
8791 :param ux: Uncertainty tensor :math:`U(X)`, full or diagonal.
8892 :param uy: Uncertainty tensor :math:`U(Y)`, full or diagonal.
8993 :param up: Uncertainty tensor :math:`U(p)`, full or diagonal.
90- :param max_i: The maximum number of iterations permitted.
91- :param max_g: The maximum gradient permitted.
94+ :param max_i: The maximum number of iterations allowed.
95+ :param max_d: The maximum L2 norm of the parameter step allowed
96+ for convergence
97+ :param max_g: The maximum infinity norm of the gradient allowed
98+ for convergence.
9299 :param covar: Use effective covariance, too.
93100 :returns: The fit result comprising: the optimized parameter
94101 values, the parameter uncertainty tensor, parameter standard
@@ -210,12 +217,14 @@ def body(carry: tuple[Any, ...]) -> tuple[Any, ...]:
210217 :returns: The updated loop state carrier.
211218 """
212219 i , popt , tree , cost , grad , _ = carry
213- u , tree = optimizer .update (
220+ d , tree = optimizer .update (
214221 grad , tree , popt , value = cost , grad = grad , value_fn = S
215222 )
216- popt = optax .apply_updates (popt , u )
223+ popt = optax .apply_updates (popt , d )
217224 cost , grad = cost_and_grad (popt , state = tree )
218- converged = jli .norm (grad , ord = jnp .inf ) < max_g # noqa
225+ grad_norm = jli .norm (grad , ord = jnp .inf ) # noqa
226+ step_norm = jli .norm (d )
227+ converged = jnp .logical_or (grad_norm < max_g , step_norm < max_d )
219228 return i + 1 , popt , tree , cost , grad , converged
220229
221230 def opti (p : Array ) -> tuple [Array , Array , Array ]:
0 commit comments