Skip to content

Commit 286bcf5

Browse files
authored
Update jax.py
1 parent 2506b11 commit 286bcf5

1 file changed

Lines changed: 19 additions & 5 deletions

File tree

uncertaintyx/fit/eiv/jax.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def evm(
     :param covar: Use effective covariance, too.
     :returns: The fit result comprising: the optimized parameter
         values, the parameter uncertainty tensor, parameter standard
-        uncertainties, the misfit, and the convergence status.
+        uncertainties, the cost, and the convergence status.
     """

     def g(q: Array, x: Array) -> Array:
@@ -183,10 +183,10 @@ def prior(q: Array) -> Array:

     def S(q: Array) -> Array:  # noqa: N806
         """
-        The misfit function to minimize.
+        The cost (or misfit) function to minimize.

         :param q: The parameters.
-        :returns: The misfit.
+        :returns: The cost.
         """
         loss_term = jnp.sum(
             jax.vmap(loss, in_axes=(None, 0, 0, 0, 0))(q, x, y, ux, uy)
@@ -219,14 +219,28 @@ def body(carry: tuple[Any, ...]) -> tuple[Any, ...]:
         converged = jli.norm(grad, ord=jnp.inf) < max_g  # noqa
         return i + 1, popt, tree, cost, grad, converged

-    def optimize(p: Array) -> tuple[Array, ...]:
+    def optimize(p: Array) -> tuple[Array, Array, Array]:
+        """
+        Conducts the optimization.
+
+        :param p: The initial parameter values.
+        :returns: The optimized parameter values, the cost, and the
+            convergence status.
+        """
         tree = optimizer.init(p)
         cost, grad = cost_and_grad(p, state=tree)
         init = (0, p, tree, cost, grad, False)
         _, popt, _, cost, _, converged = jax.lax.while_loop(cond, body, init)
         return popt, cost, converged

     def post(p: Array) -> tuple[Array, Array]:
+        """
+        Conducts the post processing.
+
+        :param p: The optimized parameter values.
+        :returns: The parameter uncertainty tensor and the parameter
+            standard uncertainties.
+        """
         hess = jax.hessian(S)
         pcov = jli.pinv(hess(p).reshape(p.size, -1))
         punc = jnp.sqrt(jnp.diag(pcov)).reshape(p.shape)
@@ -239,7 +253,7 @@ def post(p: Array) -> tuple[Array, Array]:
     optimizer = optax.lbfgs()
     cost_and_grad = optax.value_and_grad_from_state(S)
     popt, cost, converged = optimize(p)
-    pcov, punc = post(p)
+    pcov, punc = post(popt)

     return popt, pcov, punc, cost, converged

0 commit comments

Comments
 (0)