Skip to content

Commit e01375a

Browse files
authored
Refactor QR decomposition to use Gram matrices
1 parent 17ff4f1 commit e01375a

1 file changed

Lines changed: 16 additions & 35 deletions

File tree

uncertaintyx/b/jax.py

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from typing import Self
66

77
import jax
8-
import jax.lax.linalg as jla
98
import jax.numpy as jnp
9+
import jax.numpy.linalg as jli
1010
import numpy as np
1111
import optax
1212
import optimistix
@@ -186,44 +186,35 @@ def b_solve(
186186

187187
N = len(k) # noqa: N806
188188
bases = [b_basis(k[i], x[i]) for i in range(N)]
189-
facts = [jla.qr(B.T, full_matrices=False) for B in bases] # noqa: N806
190-
Q = [_[0] for _ in facts] # noqa: N806
191-
R = [_[1] for _ in facts] # noqa: N806
189+
grams = [jnp.dot(B, B.T) for B in bases] # noqa: N806
192190

193-
# compute the right hand side of the triangular equation
191+
# compute the right hand side of the normal equation
194192
rhs = y
195193
for i in range(N):
196-
rhs = jnp.tensordot(rhs, Q[i], axes=(0, 0))
197-
# solve the triangular equation
194+
B = bases[i] # noqa: N806
195+
rhs = jnp.tensordot(rhs, B, axes=(0, 1))
196+
# solve the normal equation
198197
c_unconstrained = rhs
199-
if N > 1:
200-
for i in range(N):
201-
solve = jax.vmap(
202-
lambda a, b: jla.triangular_solve(a, b, left_side=True),
203-
in_axes=(None, i),
204-
out_axes=i,
205-
)
206-
c_unconstrained = solve(R[i], c_unconstrained)
207-
else:
208-
c_unconstrained = jla.triangular_solve(
209-
R[0], c_unconstrained, left_side=True
198+
for i in range(N):
199+
G = grams[i] # noqa: N806
200+
c_unconstrained = jnp.tensordot(
201+
c_unconstrained, jli.pinv(G), axes=(0, 1)
210202
)
211203

212204
def hvp(c: Array):
213205
"""The Hessian-vector product."""
214206
res = c
215207
for i in range(N):
216-
res = jnp.tensordot(res, R[i], axes=(0, 1))
217-
for i in range(N):
218-
res = jnp.tensordot(res, R[i], axes=(0, 0))
208+
G = grams[i] # noqa: N806
209+
res = jnp.tensordot(res, G, axes=(0, 1))
219210
return res
220211

221-
def nnls(c: Array, rhs: Array):
212+
def nnls(c: Array):
222213
"""
223214
Non-negative least-squares solver.
224215
225-
Applies a positive transformation and an L-BFGS
226-
optimizer to ensure non-negativity.
216+
Applies a positive transformation and an L-BFGS optimizer
217+
to ensure non-negativity.
227218
"""
228219

229220
def forward(u: Array) -> Array:
@@ -250,10 +241,6 @@ def make_minimizer():
250241
optax.lbfgs(), atol=atol, rtol=rtol, norm=optimistix.max_norm
251242
)
252243

253-
# compute the right hand side of the normal equation
254-
for i in range(N):
255-
rhs = jnp.tensordot(rhs, R[i], axes=(0, 0))
256-
257244
u = inverse(jnp.abs(c) + jnp.finfo(c.dtype).eps)
258245
optimum = optimistix.minimise(
259246
misfit, make_minimizer(), u, max_steps=max_steps, throw=False
@@ -264,13 +251,7 @@ def make_minimizer():
264251
nnls_needed = jnp.logical_and(
265252
non_negative, jnp.any(c_unconstrained < 0.0)
266253
)
267-
return jax.lax.cond(
268-
nnls_needed,
269-
nnls,
270-
lambda c, _: c,
271-
c_unconstrained,
272-
rhs,
273-
)
254+
return jax.lax.cond(nnls_needed, nnls, lambda c: c, c_unconstrained)
274255

275256

276257
def _lower_bounds(

0 commit comments

Comments
 (0)