Skip to content

Commit d9dbdfb

Browse files
authored
Merge pull request #4 from bcdev/heptaflar-test
Refactor Bernstein polynomial tests for clarity
2 parents f116809 + 27745b4 commit d9dbdfb

2 files changed

Lines changed: 64 additions & 69 deletions

File tree

test/b/test_b_jax.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def test_eval(self):
9292
c = self.c
9393
f = self.f
9494
y = f.eval(c)
95-
precalculated = np.asarray(
95+
y_precalculated = np.asarray(
9696
[
9797
[
9898
[19.8694, 19.7848, 20.3956],
@@ -111,8 +111,8 @@ def test_eval(self):
111111
],
112112
]
113113
)
114-
self.assertEqual((3, 3, 3), y.shape)
115-
self.assertTrue(np.allclose(y, precalculated))
114+
self.assertEqual(y_precalculated.shape, y.shape)
115+
self.assertTrue(np.allclose(y, y_precalculated))
116116

117117
g = f.jac(c)
118118
self.assertEqual(y.shape + c.shape, g.shape)
@@ -179,9 +179,9 @@ def test_bernstein_poly(self):
179179
)
180180
f = BernsteinPoly(c)
181181
y = f.eval(c, x)
182-
precalculated = np.asarray([19.8694, 32.0761, 19.6774])
183-
self.assertEqual((3,), y.shape)
184-
self.assertTrue(jnp.allclose(y, precalculated))
182+
y_precalculated = np.asarray([19.8694, 32.0761, 19.6774])
183+
self.assertEqual(y_precalculated.shape, y.shape)
184+
self.assertTrue(np.allclose(y, y_precalculated))
185185

186186
g = f.jac_p(c, x)
187187
self.assertEqual((3,) + d, g.shape)
@@ -192,20 +192,20 @@ def test_bernstein_poly(self):
192192
self.assertTrue(np.all(g > 0.0))
193193

194194
def test_from_lookup_table(self):
195-
k = 5
196-
x = np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
197-
y = np.square(x) + 2.0 * x + 3.0
198-
199-
f = BernsteinPoly.from_lookup_table((k,), (x,), y, non_negative=True)
200-
c = f.prior()
201-
self.assertEqual((k + 1,), c.shape)
202-
self.assertAlmostEqual(3.0, c[0])
203-
self.assertAlmostEqual(3.4, c[1])
204-
self.assertAlmostEqual(3.9, c[2])
205-
self.assertAlmostEqual(4.5, c[3])
206-
self.assertAlmostEqual(5.2, c[4])
207-
self.assertAlmostEqual(6.0, c[5])
208-
self.assertTrue(jnp.allclose(f.eval(c, x), y))
195+
k = (3, 4, 2)
196+
d = tuple([k_ + 1 for k_ in k])
197+
c = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0
198+
x = (
199+
np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]),
200+
np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]),
201+
np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]),
202+
)
203+
y = BernsteinGrid(x).eval(c)
204+
205+
f = BernsteinPoly.from_lookup_table(k, x, y)
206+
b = f.prior()
207+
self.assertEqual(c.shape, b.shape)
208+
self.assertTrue(np.allclose(b, c))
209209

210210

211211
class BSolveTest(unittest.TestCase):
@@ -217,49 +217,53 @@ class BSolveTest(unittest.TestCase):
217217
def test_b_solve_0_2(self):
218218
r"""Fit :math:`B_{0,2}(x)`."""
219219
k = 2
220-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
220+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
221221
y = jnp.square(1.0 - x)
222222

223223
c = b_solve((k,), (x,), y, non_negative=True)
224224
self.assertEqual((k + 1,), c.shape)
225+
self.assertFalse(np.any(c < 0.0))
225226
self.assertAlmostEqual(1.0, c[0].item())
226227
self.assertAlmostEqual(0.0, c[1].item())
227228
self.assertAlmostEqual(0.0, c[2].item())
228229

229230
def test_b_solve_1_2(self):
230231
r"""Fit :math:`B_{1,2}(x)`."""
231232
k = 2
232-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
233+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
233234
y = 2.0 * x * (1.0 - x)
234235

235236
c = b_solve((k,), (x,), y, non_negative=True)
236237
self.assertEqual((k + 1,), c.shape)
238+
self.assertFalse(np.any(c < 0.0))
237239
self.assertAlmostEqual(0.0, c[0].item())
238240
self.assertAlmostEqual(1.0, c[1].item())
239241
self.assertAlmostEqual(0.0, c[2].item())
240242

241243
def test_b_solve_2_2(self):
242244
r"""Fit :math:`B_{2,2}(x)`."""
243245
k = 2
244-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
246+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
245247
y = jnp.square(x)
246248

247249
c = b_solve((k,), (x,), y, non_negative=True)
248250
self.assertEqual((k + 1,), c.shape)
251+
self.assertFalse(np.any(c < 0.0))
249252
self.assertAlmostEqual(0.0, c[0].item())
250253
self.assertAlmostEqual(0.0, c[1].item())
251254
self.assertAlmostEqual(1.0, c[2].item())
252255

253256
def test_b_solve_0_0_2_2(self):
254257
r"""Fit :math:`B_{(0,0),(2,2)}(x_0, x_1)`."""
255258
k = 2
256-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
259+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
257260
y = jnp.square(1.0 - x[jnp.newaxis, :]) * jnp.square(
258261
1.0 - x[:, jnp.newaxis]
259262
)
260263

261264
c = b_solve((k, k), (x, x), y, non_negative=True)
262265
self.assertEqual((k + 1, k + 1), c.shape)
266+
self.assertFalse(np.any(c < 0.0))
263267
self.assertAlmostEqual(1.0, c[0, 0].item())
264268
self.assertAlmostEqual(0.0, c[0, 1].item())
265269
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -273,7 +277,7 @@ def test_b_solve_0_0_2_2(self):
273277
def test_b_solve_1_0_2_2(self):
274278
r"""Fit :math:`B_{(1,0),(2,2)}(x_0, x_1)`."""
275279
k = 2
276-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
280+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
277281
y = (
278282
2.0
279283
* x[:, jnp.newaxis]
@@ -283,6 +287,7 @@ def test_b_solve_1_0_2_2(self):
283287

284288
c = b_solve((k, k), (x, x), y, non_negative=True)
285289
self.assertEqual((k + 1, k + 1), c.shape)
290+
self.assertFalse(np.any(c < 0.0))
286291
self.assertAlmostEqual(0.0, c[0, 0].item())
287292
self.assertAlmostEqual(0.0, c[0, 1].item())
288293
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -296,13 +301,14 @@ def test_b_solve_1_0_2_2(self):
296301
def test_b_solve_2_0_2_2(self):
297302
r"""Fit :math:`B_{(2,0),(2,2)}(x_0, x_1)`."""
298303
k = 2
299-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
304+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
300305
y = jnp.square(x[:, jnp.newaxis]) * jnp.square(
301306
1.0 - x[jnp.newaxis, :]
302307
)
303308

304309
c = b_solve((k, k), (x, x), y, non_negative=True)
305310
self.assertEqual((k + 1, k + 1), c.shape)
311+
self.assertFalse(np.any(c < 0.0))
306312
self.assertAlmostEqual(0.0, c[0, 0].item())
307313
self.assertAlmostEqual(0.0, c[0, 1].item())
308314
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -316,7 +322,7 @@ def test_b_solve_2_0_2_2(self):
316322
def test_b_solve_0_1_2_2(self):
317323
r"""Fit :math:`B_{(0,1),(2,2)}(x_0, x_1)`."""
318324
k = 2
319-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
325+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
320326
y = (
321327
2.0
322328
* jnp.square(1.0 - x[:, jnp.newaxis])
@@ -326,6 +332,7 @@ def test_b_solve_0_1_2_2(self):
326332

327333
c = b_solve((k, k), (x, x), y, non_negative=True)
328334
self.assertEqual((k + 1, k + 1), c.shape)
335+
self.assertFalse(np.any(c < 0.0))
329336
self.assertAlmostEqual(0.0, c[0, 0].item())
330337
self.assertAlmostEqual(1.0, c[0, 1].item())
331338
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -339,7 +346,7 @@ def test_b_solve_0_1_2_2(self):
339346
def test_b_solve_1_1_2_2(self):
340347
r"""Fit :math:`B_{(1,1),(2,2)}(x_0, x_1)`."""
341348
k = 2
342-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
349+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
343350
y = (
344351
4.0
345352
* x[:, jnp.newaxis]
@@ -350,6 +357,7 @@ def test_b_solve_1_1_2_2(self):
350357

351358
c = b_solve((k, k), (x, x), y, non_negative=True)
352359
self.assertEqual((k + 1, k + 1), c.shape)
360+
self.assertFalse(np.any(c < 0.0))
353361
self.assertAlmostEqual(0.0, c[0, 0].item())
354362
self.assertAlmostEqual(0.0, c[0, 1].item())
355363
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -363,7 +371,7 @@ def test_b_solve_1_1_2_2(self):
363371
def test_b_solve_2_1_2_2(self):
364372
r"""Fit :math:`B_{(2,1),(2,2)}(x_0, x_1)`."""
365373
k = 2
366-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
374+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
367375
y = (
368376
2.0
369377
* jnp.square(x[:, jnp.newaxis])
@@ -372,6 +380,7 @@ def test_b_solve_2_1_2_2(self):
372380
)
373381

374382
c = b_solve((k, k), (x, x), y, non_negative=True)
383+
self.assertFalse(np.any(c < 0.0))
375384
self.assertEqual((k + 1, k + 1), c.shape)
376385
self.assertAlmostEqual(0.0, c[0, 0].item())
377386
self.assertAlmostEqual(0.0, c[0, 1].item())
@@ -386,12 +395,13 @@ def test_b_solve_2_1_2_2(self):
386395
def test_b_solve_0_2_2_2(self):
387396
r"""Fit :math:`B_{(0,2),(2,2)}(x_0, x_1)`."""
388397
k = 2
389-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
398+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
390399
y = jnp.square(1.0 - x[:, jnp.newaxis]) * jnp.square(
391400
x[jnp.newaxis, :]
392401
)
393402

394403
c = b_solve((k, k), (x, x), y, non_negative=True)
404+
self.assertFalse(np.any(c < 0.0))
395405
self.assertEqual((k + 1, k + 1), c.shape)
396406
self.assertAlmostEqual(0.0, c[0, 0].item())
397407
self.assertAlmostEqual(0.0, c[0, 1].item())
@@ -406,7 +416,7 @@ def test_b_solve_0_2_2_2(self):
406416
def test_b_solve_1_2_2_2(self):
407417
r"""Fit :math:`B_{(1,2),(2,2)}(x_0, x_1)`."""
408418
k = 2
409-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
419+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
410420
y = (
411421
2.0
412422
* x[:, jnp.newaxis]
@@ -416,6 +426,7 @@ def test_b_solve_1_2_2_2(self):
416426

417427
c = b_solve((k, k), (x, x), y, non_negative=True)
418428
self.assertEqual((k + 1, k + 1), c.shape)
429+
self.assertFalse(np.any(c < 0.0))
419430
self.assertAlmostEqual(0.0, c[0, 0].item())
420431
self.assertAlmostEqual(0.0, c[0, 1].item())
421432
self.assertAlmostEqual(0.0, c[0, 2].item())
@@ -429,11 +440,12 @@ def test_b_solve_1_2_2_2(self):
429440
def test_b_solve_2_2_2_2(self):
430441
r"""Fit :math:`B_{(2,2),(2,2)}(x_0, x_1)`."""
431442
k = 2
432-
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
443+
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
433444
y = jnp.square(x[:, jnp.newaxis]) * jnp.square(x[jnp.newaxis, :])
434445

435446
c = b_solve((k, k), (x, x), y, non_negative=True)
436447
self.assertEqual((k + 1, k + 1), c.shape)
448+
self.assertFalse(np.any(c < 0.0))
437449
self.assertAlmostEqual(0.0, c[0, 0].item())
438450
self.assertAlmostEqual(0.0, c[0, 1].item())
439451
self.assertAlmostEqual(0.0, c[0, 2].item())

uncertaintyx/b/jax.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from typing import Self
66

77
import jax
8-
import jax.lax.linalg as jla
98
import jax.numpy as jnp
9+
import jax.numpy.linalg as jli
1010
import numpy as np
1111
import optax
1212
import optimistix
@@ -186,44 +186,35 @@ def b_solve(
186186

187187
N = len(k) # noqa: N806
188188
bases = [b_basis(k[i], x[i]) for i in range(N)]
189-
facts = [jla.qr(B.T, full_matrices=False) for B in bases] # noqa: N806
190-
Q = [_[0] for _ in facts] # noqa: N806
191-
R = [_[1] for _ in facts] # noqa: N806
189+
grams = [jnp.dot(B, B.T) for B in bases] # noqa: N806
192190

193-
# compute the right hand side of the triangular equation
191+
# compute the right hand side of the normal equation
194192
rhs = y
195193
for i in range(N):
196-
rhs = jnp.tensordot(rhs, Q[i], axes=(0, 0))
197-
# solve the triangular equation
194+
B = bases[i] # noqa: N806
195+
rhs = jnp.tensordot(rhs, B, axes=(0, 1))
196+
# solve the normal equation
198197
c_unconstrained = rhs
199-
if N > 1:
200-
for i in range(N):
201-
solve = jax.vmap(
202-
lambda a, b: jla.triangular_solve(a, b, left_side=True),
203-
in_axes=(None, i),
204-
out_axes=i,
205-
)
206-
c_unconstrained = solve(R[i], c_unconstrained)
207-
else:
208-
c_unconstrained = jla.triangular_solve(
209-
R[0], c_unconstrained, left_side=True
198+
for i in range(N):
199+
G = grams[i] # noqa: N806
200+
c_unconstrained = jnp.tensordot(
201+
c_unconstrained, jli.pinv(G), axes=(0, 1)
210202
)
211203

212204
def hvp(c: Array):
213205
"""The Hessian-vector product."""
214206
res = c
215207
for i in range(N):
216-
res = jnp.tensordot(res, R[i], axes=(0, 1))
217-
for i in range(N):
218-
res = jnp.tensordot(res, R[i], axes=(0, 0))
208+
G = grams[i] # noqa: N806
209+
res = jnp.tensordot(res, G, axes=(0, 1))
219210
return res
220211

221-
def nnls(c: Array, rhs: Array):
212+
def nnls(c: Array):
222213
"""
223214
Non-negative least-squares solver.
224215
225-
Applies a positive transformation and an L-BFGS
226-
optimizer to ensure non-negativity.
216+
Applies a positive transformation and an L-BFGS optimizer
217+
to ensure non-negativity.
227218
"""
228219

229220
def forward(u: Array) -> Array:
@@ -250,10 +241,6 @@ def make_minimizer():
250241
optax.lbfgs(), atol=atol, rtol=rtol, norm=optimistix.max_norm
251242
)
252243

253-
# compute the right hand side of the normal equation
254-
for i in range(N):
255-
rhs = jnp.tensordot(rhs, R[i], axes=(0, 0))
256-
257244
u = inverse(jnp.abs(c) + jnp.finfo(c.dtype).eps)
258245
optimum = optimistix.minimise(
259246
misfit, make_minimizer(), u, max_steps=max_steps, throw=False
@@ -264,13 +251,7 @@ def make_minimizer():
264251
nnls_needed = jnp.logical_and(
265252
non_negative, jnp.any(c_unconstrained < 0.0)
266253
)
267-
return jax.lax.cond(
268-
nnls_needed,
269-
nnls,
270-
lambda c, _: c,
271-
c_unconstrained,
272-
rhs,
273-
)
254+
return jax.lax.cond(nnls_needed, nnls, lambda c: c, c_unconstrained)
274255

275256

276257
def _lower_bounds(
@@ -327,9 +308,10 @@ def __init__(
327308
:param a: The lower bounds of the grid coordinates.
328309
:param b: The upper bounds of the grid coordinates.
329310
"""
311+
N = len(x) # noqa: : N806
330312
a = _lower_bounds(a, x)
331313
b = _upper_bounds(b, x)
332-
x = tuple(jnp.asarray((x_ - a) / (b - a)) for x_ in x)
314+
x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N))
333315

334316
def f(c: Array) -> Array:
335317
r"""
@@ -420,9 +402,10 @@ def from_lookup_table(
420402
:param rtol: The relative tolerance for terminating the solver.
421403
:param max_steps: The maximum number of steps the solver can take.
422404
"""
405+
N = len(k) # noqa: : N806
423406
a = _lower_bounds(a, x)
424407
b = _upper_bounds(b, x)
425-
x_ = tuple(jnp.asarray((x_ - a) / (b - a)) for x_ in x)
408+
x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N))
426409
y_ = jnp.asarray(y)
427410
c_ = b_solve(
428411
k,

0 commit comments

Comments
 (0)