@@ -184,6 +184,26 @@ def b_solve(
184184 :returns: The Bernstein coefficients.
185185 """
186186
187+ N = len (k ) # noqa: N806
188+ bases = [b_basis (k [i ], x [i ]) for i in range (N )]
189+ facts = [jla .qr (B .T , full_matrices = False ) for B in bases ] # noqa: N806
190+ Q = [_ [0 ] for _ in facts ] # noqa: N806
191+ R = [_ [1 ] for _ in facts ] # noqa: N806
192+
193+ # compute the right hand side of the triangular equation
194+ rhs = y
195+ for i in range (N ):
196+ rhs = jnp .tensordot (rhs , Q [i ], axes = (0 , 0 ))
197+ # solve the triangular equation
198+ c_unconstrained = rhs
199+ for i in range (N ):
200+ c_unconstrained = jla .triangular_solve (
201+ R [i ], c_unconstrained , left_side = True
202+ )
203+ c_unconstrained = jnp .moveaxis ( # like the tensor dot product
204+ c_unconstrained , 0 , - 1
205+ )
206+
187207 def hvp (c : Array ):
188208 """The Hessian-vector product."""
189209 res = c
@@ -230,33 +250,16 @@ def make_minimizer():
230250 for i in range (N ):
231251 rhs = jnp .tensordot (rhs , R [i ], axes = (0 , 0 ))
232252
233- u = inverse (jnp .abs (c ))
253+ u = inverse (jnp .abs (c ) + jnp . finfo ( c . dtype ). eps )
234254 optimum = optimistix .minimise (
235255 misfit , make_minimizer (), u , max_steps = max_steps , throw = False
236256 )
237257 return forward (optimum .value )
238258
239- N = len (k ) # noqa: N806
240- bases = [b_basis (k [i ], x [i ]) for i in range (N )]
241- facts = [jla .qr (B .T , full_matrices = False ) for B in bases ] # noqa: N806
242- Q = [_ [0 ] for _ in facts ] # noqa: N806
243- R = [_ [1 ] for _ in facts ] # noqa: N806
244-
245- # compute the right hand side of the triangular equation
246- rhs = y
247- for i in range (N ):
248- rhs = jnp .tensordot (rhs , Q [i ], axes = (0 , 0 ))
249- # solve the triangular equation
250- c_unconstrained = rhs
251- for i in range (N ):
252- c_unconstrained = jla .triangular_solve (
253- R [i ], c_unconstrained , left_side = True
254- )
255- c_unconstrained = jnp .moveaxis ( # like the tensor dot product
256- c_unconstrained , 0 , - 1
257- )
258259 # solve iteratively with non-negativity constraint, if needed
259- nnls_needed = non_negative and jnp .any (c_unconstrained < 0.0 )
260+ nnls_needed = jnp .logical_and (
261+ non_negative , jnp .any (c_unconstrained < 0.0 )
262+ )
260263 return jax .lax .cond (
261264 nnls_needed ,
262265 nnls ,
0 commit comments