Skip to content

Commit 138b93b

Browse files
committed
Fix: prevent possible singularity
1 parent 96cf21a commit 138b93b

1 file changed

Lines changed: 24 additions & 21 deletions

File tree

uncertaintyx/b/jax.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,26 @@ def b_solve(
184184
:returns: The Bernstein coefficients.
185185
"""
186186

187+
N = len(k) # noqa: N806
188+
bases = [b_basis(k[i], x[i]) for i in range(N)]
189+
facts = [jla.qr(B.T, full_matrices=False) for B in bases] # noqa: N806
190+
Q = [_[0] for _ in facts] # noqa: N806
191+
R = [_[1] for _ in facts] # noqa: N806
192+
193+
# compute the right hand side of the triangular equation
194+
rhs = y
195+
for i in range(N):
196+
rhs = jnp.tensordot(rhs, Q[i], axes=(0, 0))
197+
# solve the triangular equation
198+
c_unconstrained = rhs
199+
for i in range(N):
200+
c_unconstrained = jla.triangular_solve(
201+
R[i], c_unconstrained, left_side=True
202+
)
203+
c_unconstrained = jnp.moveaxis( # like the tensor dot product
204+
c_unconstrained, 0, -1
205+
)
206+
187207
def hvp(c: Array):
188208
"""The Hessian-vector product."""
189209
res = c
@@ -230,33 +250,16 @@ def make_minimizer():
230250
for i in range(N):
231251
rhs = jnp.tensordot(rhs, R[i], axes=(0, 0))
232252

233-
u = inverse(jnp.abs(c))
253+
u = inverse(jnp.abs(c) + jnp.finfo(c.dtype).eps)
234254
optimum = optimistix.minimise(
235255
misfit, make_minimizer(), u, max_steps=max_steps, throw=False
236256
)
237257
return forward(optimum.value)
238258

239-
N = len(k) # noqa: N806
240-
bases = [b_basis(k[i], x[i]) for i in range(N)]
241-
facts = [jla.qr(B.T, full_matrices=False) for B in bases] # noqa: N806
242-
Q = [_[0] for _ in facts] # noqa: N806
243-
R = [_[1] for _ in facts] # noqa: N806
244-
245-
# compute the right hand side of the triangular equation
246-
rhs = y
247-
for i in range(N):
248-
rhs = jnp.tensordot(rhs, Q[i], axes=(0, 0))
249-
# solve the triangular equation
250-
c_unconstrained = rhs
251-
for i in range(N):
252-
c_unconstrained = jla.triangular_solve(
253-
R[i], c_unconstrained, left_side=True
254-
)
255-
c_unconstrained = jnp.moveaxis( # like the tensor dot product
256-
c_unconstrained, 0, -1
257-
)
258259
# solve iteratively with non-negativity constraint, if needed
259-
nnls_needed = non_negative and jnp.any(c_unconstrained < 0.0)
260+
nnls_needed = jnp.logical_and(
261+
non_negative, jnp.any(c_unconstrained < 0.0)
262+
)
260263
return jax.lax.cond(
261264
nnls_needed,
262265
nnls,

0 commit comments

Comments
 (0)