
Commit 855d386

Update: step size convergence criterion
1 parent 3233e2c commit 855d386

1 file changed: uncertaintyx/fit/eiv/jax.py (16 additions & 7 deletions)

@@ -37,8 +37,11 @@
 from ...interface.core import M
 from ...interface.core import Result

-DEFAULT_MAX_G: Any = 1.0e-08
-"""The maximum gradient permitted."""
+DEFAULT_MAX_D: Any = 1.0e-08
+"""The maximum L2 norm of the parameter step allowed for convergence."""
+
+DEFAULT_MAX_G: Any = 1.0e-06
+"""The maximum infinity norm of the gradient allowed for convergence."""

 DEFAULT_MAX_I: int = 100
 """The maximum number of iterations permitted."""
@@ -55,6 +58,7 @@ def evm(
     up: Array | None = None,
     *,
     max_i: int = DEFAULT_MAX_I,
+    max_d: Any = DEFAULT_MAX_D,
     max_g: Any = DEFAULT_MAX_G,
     covar: bool = False,
 ) -> tuple[Array, Array, Array, Array, Array]:
@@ -87,8 +91,11 @@ def evm(
     :param ux: Uncertainty tensor :math:`U(X)`, full or diagonal.
     :param uy: Uncertainty tensor :math:`U(Y)`, full or diagonal.
     :param up: Uncertainty tensor :math:`U(p)`, full or diagonal.
-    :param max_i: The maximum number of iterations permitted.
-    :param max_g: The maximum gradient permitted.
+    :param max_i: The maximum number of iterations allowed.
+    :param max_d: The maximum L2 norm of the parameter step allowed
+        for convergence.
+    :param max_g: The maximum infinity norm of the gradient allowed
+        for convergence.
     :param covar: Use effective covariance, too.
     :returns: The fit result comprising: the optimized parameter
         values, the parameter uncertainty tensor, parameter standard
@@ -210,12 +217,14 @@ def body(carry: tuple[Any, ...]) -> tuple[Any, ...]:
         :returns: The updated loop state carrier.
         """
         i, popt, tree, cost, grad, _ = carry
-        u, tree = optimizer.update(
+        d, tree = optimizer.update(
             grad, tree, popt, value=cost, grad=grad, value_fn=S
         )
-        popt = optax.apply_updates(popt, u)
+        popt = optax.apply_updates(popt, d)
         cost, grad = cost_and_grad(popt, state=tree)
-        converged = jli.norm(grad, ord=jnp.inf) < max_g  # noqa
+        grad_norm = jli.norm(grad, ord=jnp.inf)  # noqa
+        step_norm = jli.norm(d)
+        converged = jnp.logical_or(grad_norm < max_g, step_norm < max_d)
         return i + 1, popt, tree, cost, grad, converged

     def opti(p: Array) -> tuple[Array, Array, Array]:
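
To see the criterion in isolation, here is a minimal, self-contained sketch of the same disjunctive stopping rule in a stand-alone optax loop. The toy cost_fn, the plain optax.sgd optimizer, and the constants are illustrative stand-ins, not code from uncertaintyx; the package itself passes value, grad, and value_fn to its optimizer, which this sketch omits to stay short.

import jax
import jax.numpy as jnp
import optax

MAX_D = 1.0e-08  # maximum L2 norm of the parameter step
MAX_G = 1.0e-06  # maximum infinity norm of the gradient
MAX_I = 100      # maximum number of iterations


def cost_fn(p):
    """Toy quadratic cost with its minimum at (1, -2)."""
    return jnp.sum((p - jnp.array([1.0, -2.0])) ** 2)


cost_and_grad = jax.value_and_grad(cost_fn)
optimizer = optax.sgd(learning_rate=0.4)


def cond(carry):
    i, _, _, _, converged = carry
    # Iterate until either criterion fires or the budget is exhausted.
    return jnp.logical_and(i < MAX_I, jnp.logical_not(converged))


def body(carry):
    i, popt, tree, grad, _ = carry
    d, tree = optimizer.update(grad, tree, popt)
    popt = optax.apply_updates(popt, d)
    _, grad = cost_and_grad(popt)
    # Converged when the gradient is flat OR the step has become tiny.
    grad_norm = jnp.linalg.norm(grad, ord=jnp.inf)
    step_norm = jnp.linalg.norm(d)
    converged = jnp.logical_or(grad_norm < MAX_G, step_norm < MAX_D)
    return i + 1, popt, tree, grad, converged


p0 = jnp.zeros(2)
_, grad0 = cost_and_grad(p0)
carry = (0, p0, optimizer.init(p0), grad0, jnp.asarray(False))
i, popt, *_ = jax.lax.while_loop(cond, body, carry)
print(i, popt)  # stops after roughly 10 iterations, popt near [1, -2]

On this quadratic the gradient shrinks geometrically, so the gradient test fires first; the step-norm test added by this commit matters when the optimizer stalls in a flat-curvature region where steps become tiny long before the gradient threshold is met.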
