@@ -219,9 +219,9 @@ def body(carry: tuple[Any, ...]) -> tuple[Any, ...]:
219219 converged = jli .norm (grad , ord = jnp .inf ) < max_g # noqa
220220 return i + 1 , popt , tree , cost , grad , converged
221221
222- def optimize (p : Array ) -> tuple [Array , Array , Array ]:
222+ def opti (p : Array ) -> tuple [Array , Array , Array ]:
223223 """
224- Conducts the optimization .
224+ Optimizes the parameters .
225225
226226 :param p: The initial parameter values.
227227 :returns: The optimized parameter values, the cost, and the
@@ -235,10 +235,10 @@ def optimize(p: Array) -> tuple[Array, Array, Array]:
235235
236236 def post (p : Array ) -> tuple [Array , Array ]:
237237 """
238- Conducts the post processing .
238+ Computes posterior parameter uncertainty .
239239
240240 :param p: The optimized parameter values.
241- :returns: The parameter uncertainty tensor and the parameter
241+ :returns: The posterior parameter uncertainty tensor and parameter
242242 standard uncertainties.
243243 """
244244 hess = jax .hessian (S )
@@ -252,7 +252,7 @@ def post(p: Array) -> tuple[Array, Array]:
252252 uy = jnp .broadcast_to (1.0 , y .shape )
253253 optimizer = optax .lbfgs ()
254254 cost_and_grad = optax .value_and_grad_from_state (S )
255- popt , cost , converged = optimize (p )
255+ popt , cost , converged = opti (p )
256256 pcov , punc = post (popt )
257257
258258 return popt , pcov , punc , cost , converged
0 commit comments