Skip to content

Commit 8f4b3b3

Browse files
authored
Update jax.py
1 parent 286bcf5 commit 8f4b3b3

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

uncertaintyx/fit/eiv/jax.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,9 @@ def body(carry: tuple[Any, ...]) -> tuple[Any, ...]:
219219
converged = jli.norm(grad, ord=jnp.inf) < max_g # noqa
220220
return i + 1, popt, tree, cost, grad, converged
221221

222-
def optimize(p: Array) -> tuple[Array, Array, Array]:
222+
def opti(p: Array) -> tuple[Array, Array, Array]:
223223
"""
224-
Conducts the optimization.
224+
Optimizes the parameters.
225225
226226
:param p: The initial parameter values.
227227
:returns: The optimized parameter values, the cost, and the
@@ -235,10 +235,10 @@ def optimize(p: Array) -> tuple[Array, Array, Array]:
235235

236236
def post(p: Array) -> tuple[Array, Array]:
237237
"""
238-
Conducts the post processing.
238+
Computes posterior parameter uncertainty.
239239
240240
:param p: The optimized parameter values.
241-
:returns: The parameter uncertainty tensor and the parameter
241+
:returns: The posterior parameter uncertainty tensor and parameter
242242
standard uncertainties.
243243
"""
244244
hess = jax.hessian(S)
@@ -252,7 +252,7 @@ def post(p: Array) -> tuple[Array, Array]:
252252
uy = jnp.broadcast_to(1.0, y.shape)
253253
optimizer = optax.lbfgs()
254254
cost_and_grad = optax.value_and_grad_from_state(S)
255-
popt, cost, converged = optimize(p)
255+
popt, cost, converged = opti(p)
256256
pcov, punc = post(popt)
257257

258258
return popt, pcov, punc, cost, converged

0 commit comments

Comments
 (0)