@@ -89,8 +89,10 @@ def evm(
8989 :param up: Uncertainty tensor :math:`U(p)`, full or diagonal.
9090 :param max_i: The maximum number of iterations permitted.
9191 :param max_g: The maximum gradient permitted.
92- :param covar: Use effective variance-covariance.
93- :returns: The minimization loop state.
92+ :param covar: Use effective covariance, too.
93+ :returns: The fit result comprising: the optimized parameter
94+ values, the parameter uncertainty tensor, parameter standard
95+ uncertainties, the misfit, and the convergence status.
9496 """
9597
9698 def g (q : Array , x : Array ) -> Array :
@@ -217,13 +219,31 @@ def body(carry: tuple[Any, ...]) -> tuple[Any, ...]:
217219 converged = jli .norm (grad , ord = jnp .inf ) < max_g # noqa
218220 return i + 1 , popt , state , cost , grad , converged
219221
222+ def optimize (
223+ p : Array , state : Any , cost : Array , grad : Array
224+ ) -> tuple [Any , ...]:
225+ carry = (0 , p , state , cost , grad , False )
226+ _ , popt , _ , cost , _ , converged = jax .lax .while_loop (cond , body , carry )
227+ return popt , cost , converged
228+
229+ def post (p : Array ) -> tuple [Array , Array ]:
230+ hess = jax .hessian (S )
231+ pcov = jli .pinv (hess (p ).reshape (p .size , - 1 ))
232+ punc = jnp .sqrt (jnp .diag (pcov )).reshape (p .shape )
233+ return pcov .reshape (p .shape + p .shape ), punc .reshape (p .shape )
234+
235+ if ux is None :
236+ ux = jnp .broadcast_to (1.0 , x .shape )
237+ if uy is None :
238+ uy = jnp .broadcast_to (1.0 , y .shape )
220239 mizer = optax .lbfgs ()
221240 state = mizer .init (p )
222241 cost_and_grad = optax .value_and_grad_from_state (S )
223242 cost , grad = cost_and_grad (p , state = state )
224- carry = (0 , p , state , cost , grad , False )
243+ popt , cost , converged = optimize (p , state , cost , grad )
244+ pcov , punc = post (p )
225245
226- return jax . lax . while_loop ( cond , body , carry ) # noqa
246+ return popt , pcov , punc , cost , converged
227247
228248
229249class EIV (Fitting ):
@@ -262,30 +282,26 @@ def fit(
262282 :param max_iter: The maximum number of iterations conducted.
263283 :returns: The fit result.
264284 """
265- i , popt , state , cost , g , converged = evm (
285+ popt , pcov , punc , cost , converged = evm (
266286 f .f ,
267287 jnp .asarray (f .estimate (x , y )),
268288 jnp .asarray (x ),
269289 jnp .asarray (y ),
270- jnp .square (ux ) if ux is not None else jnp . ones_like ( x ) ,
271- jnp .square (uy ) if uy is not None else jnp . ones_like ( y ) ,
290+ jnp .square (ux ) if ux is not None else None ,
291+ jnp .square (uy ) if uy is not None else None ,
272292 jnp .square (up ) if up is not None else None ,
273293 max_i = max_iter ,
274294 ** kwargs ,
275295 )
276296 popt = np .asarray (popt )
277- punc = np .zeros_like (popt )
278- pcov = np .zeros_like (popt , shape = popt .shape + popt .shape )
279297 rvar = np .var (f .eval (popt , x ) - y , axis = 0 , ddof = popt .size )
280- cost = np .asarray (cost )
281- conv = np .asarray (converged )
282298
283299 return Result (
284300 f ,
285301 popt = popt ,
286- punc = punc ,
287- pcov = pcov ,
302+ punc = np . asarray ( punc ) ,
303+ pcov = np . asarray ( pcov ) ,
288304 rvar = rvar ,
289- cost = cost ,
290- info = 0 if conv else 1 ,
305+ cost = np . asarray ( cost ) ,
306+ info = 0 if converged . item () else 1 ,
291307 )
0 commit comments