@@ -92,7 +92,7 @@ def evm(
9292 :param covar: Use effective covariance, too.
9393 :returns: The fit result comprising: the optimized parameter
9494 values, the parameter uncertainty tensor, parameter standard
95- uncertainties, the misfit , and the convergence status.
95+ uncertainties, the cost , and the convergence status.
9696 """
9797
9898 def g (q : Array , x : Array ) -> Array :
@@ -183,10 +183,10 @@ def prior(q: Array) -> Array:
183183
184184 def S (q : Array ) -> Array : # noqa: N806
185185 """
186- The misfit function to minimize.
186+ The cost (or misfit) function to minimize.
187187
188188 :param q: The parameters.
189- :returns: The misfit .
189+ :returns: The cost .
190190 """
191191 loss_term = jnp .sum (
192192 jax .vmap (loss , in_axes = (None , 0 , 0 , 0 , 0 ))(q , x , y , ux , uy )
@@ -219,14 +219,28 @@ def body(carry: tuple[Any, ...]) -> tuple[Any, ...]:
219219 converged = jli .norm (grad , ord = jnp .inf ) < max_g # noqa
220220 return i + 1 , popt , tree , cost , grad , converged
221221
222- def optimize (p : Array ) -> tuple [Array , ...]:
222+ def optimize (p : Array ) -> tuple [Array , Array , Array ]:
223+ """
224+ Conducts the optimization.
225+
226+ :param p: The initial parameter values.
227+ :returns: The optimized parameter values, the cost, and the
228+ convergence status.
229+ """
223230 tree = optimizer .init (p )
224231 cost , grad = cost_and_grad (p , state = tree )
225232 init = (0 , p , tree , cost , grad , False )
226233 _ , popt , _ , cost , _ , converged = jax .lax .while_loop (cond , body , init )
227234 return popt , cost , converged
228235
229236 def post (p : Array ) -> tuple [Array , Array ]:
237+ """
238+ Conducts the post processing.
239+
240+ :param p: The optimized parameter values.
241+ :returns: The parameter uncertainty tensor and the parameter
242+ standard uncertainties.
243+ """
230244 hess = jax .hessian (S )
231245 pcov = jli .pinv (hess (p ).reshape (p .size , - 1 ))
232246 punc = jnp .sqrt (jnp .diag (pcov )).reshape (p .shape )
@@ -239,7 +253,7 @@ def post(p: Array) -> tuple[Array, Array]:
239253 optimizer = optax .lbfgs ()
240254 cost_and_grad = optax .value_and_grad_from_state (S )
241255 popt , cost , converged = optimize (p )
242- pcov , punc = post (p )
256+ pcov , punc = post (popt )
243257
244258 return popt , pcov , punc , cost , converged
245259
0 commit comments