55from typing import Self
66
77import jax
8- import jax .lax .linalg as jla
98import jax .numpy as jnp
9+ import jax .numpy .linalg as jli
1010import numpy as np
1111import optax
1212import optimistix
@@ -186,44 +186,35 @@ def b_solve(
186186
187187 N = len (k ) # noqa: N806
188188 bases = [b_basis (k [i ], x [i ]) for i in range (N )]
189- facts = [jla .qr (B .T , full_matrices = False ) for B in bases ] # noqa: N806
190- Q = [_ [0 ] for _ in facts ] # noqa: N806
191- R = [_ [1 ] for _ in facts ] # noqa: N806
189+ grams = [jnp .dot (B , B .T ) for B in bases ] # noqa: N806
192190
193- # compute the right hand side of the triangular equation
191+ # compute the right hand side of the normal equation
194192 rhs = y
195193 for i in range (N ):
196- rhs = jnp .tensordot (rhs , Q [i ], axes = (0 , 0 ))
197- # solve the triangular equation
194+ B = bases [i ] # noqa: N806
195+ rhs = jnp .tensordot (rhs , B , axes = (0 , 1 ))
196+ # solve the normal equation
198197 c_unconstrained = rhs
199- if N > 1 :
200- for i in range (N ):
201- solve = jax .vmap (
202- lambda a , b : jla .triangular_solve (a , b , left_side = True ),
203- in_axes = (None , i ),
204- out_axes = i ,
205- )
206- c_unconstrained = solve (R [i ], c_unconstrained )
207- else :
208- c_unconstrained = jla .triangular_solve (
209- R [0 ], c_unconstrained , left_side = True
198+ for i in range (N ):
199+ G = grams [i ] # noqa: N806
200+ c_unconstrained = jnp .tensordot (
201+ c_unconstrained , jli .pinv (G ), axes = (0 , 1 )
210202 )
211203
212204 def hvp (c : Array ):
213205 """The Hessian-vector product."""
214206 res = c
215207 for i in range (N ):
216- res = jnp .tensordot (res , R [i ], axes = (0 , 1 ))
217- for i in range (N ):
218- res = jnp .tensordot (res , R [i ], axes = (0 , 0 ))
208+ G = grams [i ] # noqa: N806
209+ res = jnp .tensordot (res , G , axes = (0 , 1 ))
219210 return res
220211
221- def nnls (c : Array , rhs : Array ):
212+ def nnls (c : Array ):
222213 """
223214 Non-negative least-squares solver.
224215
225- Applies a positive transformation and an L-BFGS
226- optimizer to ensure non-negativity.
216+ Applies a positive transformation and an L-BFGS optimizer
217+ to ensure non-negativity.
227218 """
228219
229220 def forward (u : Array ) -> Array :
@@ -250,10 +241,6 @@ def make_minimizer():
250241 optax .lbfgs (), atol = atol , rtol = rtol , norm = optimistix .max_norm
251242 )
252243
253- # compute the right hand side of the normal equation
254- for i in range (N ):
255- rhs = jnp .tensordot (rhs , R [i ], axes = (0 , 0 ))
256-
257244 u = inverse (jnp .abs (c ) + jnp .finfo (c .dtype ).eps )
258245 optimum = optimistix .minimise (
259246 misfit , make_minimizer (), u , max_steps = max_steps , throw = False
@@ -264,13 +251,7 @@ def make_minimizer():
264251 nnls_needed = jnp .logical_and (
265252 non_negative , jnp .any (c_unconstrained < 0.0 )
266253 )
267- return jax .lax .cond (
268- nnls_needed ,
269- nnls ,
270- lambda c , _ : c ,
271- c_unconstrained ,
272- rhs ,
273- )
254+ return jax .lax .cond (nnls_needed , nnls , lambda c : c , c_unconstrained )
274255
275256
276257def _lower_bounds (
0 commit comments