Skip to content

Commit 325cce0

Browse files
committed
Update: compute parameter Hessian and covariance
1 parent 3a1f3d5 commit 325cce0

1 file changed

Lines changed: 31 additions & 15 deletions

File tree

uncertaintyx/fit/eiv/jax.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,10 @@ def evm(
8989
:param up: Uncertainty tensor :math:`U(p)`, full or diagonal.
9090
:param max_i: The maximum number of iterations permitted.
9191
:param max_g: The maximum gradient permitted.
92-
:param covar: Use effective variance-covariance.
93-
:returns: The minimization loop state.
92+
:param covar: Whether to also apply the effective covariance terms.
93+
:returns: The fit result comprising: the optimized parameter
94+
values, the parameter uncertainty tensor, parameter standard
95+
uncertainties, the misfit, and the convergence status.
9496
"""
9597

9698
def g(q: Array, x: Array) -> Array:
@@ -217,13 +219,31 @@ def body(carry: tuple[Any, ...]) -> tuple[Any, ...]:
217219
converged = jli.norm(grad, ord=jnp.inf) < max_g # noqa
218220
return i + 1, popt, state, cost, grad, converged
219221

222+
def optimize(
    p: Array, state: Any, cost: Array, grad: Array
) -> tuple[Any, ...]:
    """Drive the minimization loop to convergence.

    The loop carry is ``(iteration, params, optimizer state, cost,
    gradient, converged)``, starting at iteration zero with the
    convergence flag cleared; ``cond`` and ``body`` from the enclosing
    scope implement the termination test and the update step.

    :returns: The optimized parameters, the final cost, and the
        convergence flag.
    """
    initial = (0, p, state, cost, grad, False)
    final = jax.lax.while_loop(cond, body, initial)
    _, p_best, _, best_cost, _, done = final
    return p_best, best_cost, done
228+
229+
def post(p: Array) -> tuple[Array, Array]:
    """Estimate the parameter covariance and standard uncertainties.

    The covariance is the pseudo-inverse of the Hessian of the
    objective ``S`` evaluated at ``p``; the standard uncertainties are
    the square roots of its diagonal.

    NOTE(review): if ``S`` is a plain chi-square (no 1/2 prefactor),
    the covariance is ``2 * pinv(H)`` rather than ``pinv(H)`` — confirm
    against the definition of ``S``.

    :param p: Parameter values at which to evaluate the Hessian
        (should be the optimized parameters).
    :returns: The covariance tensor with shape ``p.shape + p.shape``
        and the standard-uncertainty tensor with shape ``p.shape``.
    """
    # Flatten the Hessian (shape p.shape + p.shape) to a square matrix
    # before taking the pseudo-inverse.
    hess = jax.hessian(S)(p).reshape(p.size, p.size)
    pcov = jli.pinv(hess)
    # punc already has shape p.shape here; the original applied a
    # second, redundant reshape in the return statement.
    punc = jnp.sqrt(jnp.diag(pcov)).reshape(p.shape)
    return pcov.reshape(p.shape + p.shape), punc
234+
235+
if ux is None:
236+
ux = jnp.broadcast_to(1.0, x.shape)
237+
if uy is None:
238+
uy = jnp.broadcast_to(1.0, y.shape)
220239
mizer = optax.lbfgs()
221240
state = mizer.init(p)
222241
cost_and_grad = optax.value_and_grad_from_state(S)
223242
cost, grad = cost_and_grad(p, state=state)
224-
carry = (0, p, state, cost, grad, False)
243+
popt, cost, converged = optimize(p, state, cost, grad)
244+
pcov, punc = post(p)
225245

226-
return jax.lax.while_loop(cond, body, carry) # noqa
246+
return popt, pcov, punc, cost, converged
227247

228248

229249
class EIV(Fitting):
@@ -262,30 +282,26 @@ def fit(
262282
:param max_iter: The maximum number of iterations conducted.
:returns: The fit result.
"""
# Run the errors-in-variables minimization. Uncertainties are passed
# through as variances (their squares), or None so that evm applies
# its unit defaults.
popt, pcov, punc, cost, converged = evm(
    f.f,
    jnp.asarray(f.estimate(x, y)),
    jnp.asarray(x),
    jnp.asarray(y),
    jnp.square(ux) if ux is not None else None,
    jnp.square(uy) if uy is not None else None,
    jnp.square(up) if up is not None else None,
    max_i=max_iter,
    **kwargs,
)
popt = np.asarray(popt)
# Residual variance of the fit, with ddof equal to the number of
# fitted parameters.
rvar = np.var(f.eval(popt, x) - y, axis=0, ddof=popt.size)

return Result(
    f,
    popt=popt,
    punc=np.asarray(punc),
    pcov=np.asarray(pcov),
    rvar=rvar,
    cost=np.asarray(cost),
    # info == 0 signals convergence, 1 signals failure to converge.
    info=0 if converged.item() else 1,
)

0 commit comments

Comments
 (0)