Skip to content

Commit e993ac4

Browse files
committed
Fix: triangular solve issue for N=1
1 parent 0224e19 commit e993ac4

2 files changed

Lines changed: 254 additions & 69 deletions

File tree

test/b/test_b_jax.py

Lines changed: 244 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,14 @@ def test_eval(self):
118118
self.assertEqual(y.shape + c.shape, g.shape)
119119
self.assertTrue(np.all(g > 0.0))
120120

121-
u = to_var(0.1 * c)
122-
u = f.lpu(c, u, diag=True)
123-
self.assertEqual(y.shape, u.shape)
124-
self.assertTrue(np.all(u > 0.0))
121+
u = np.square(0.1 * c)
122+
v = f.lpu(c, u, diag=True)
123+
self.assertEqual(y.shape, v.shape)
124+
self.assertTrue(np.all(v > 0.0))
125125

126-
u = to_var(0.1 * c)
127-
u = f.lpu(c, u)
128-
self.assertEqual(y.shape + y.shape, u.shape)
129-
self.assertTrue(np.all(u > 0.0))
126+
v = f.lpu(c, u)
127+
self.assertEqual(y.shape + y.shape, v.shape)
128+
self.assertTrue(np.all(v > 0.0))
130129

131130
def test_jac(self):
132131
c = self.c
@@ -137,30 +136,28 @@ def test_jac(self):
137136
self.assertEqual(y.shape + c.shape, g.shape)
138137
self.assertTrue(np.all(g > 0.0))
139138

140-
u = to_var(0.1 * c)
141-
u = f.lpu(c, u, diag=True)
142-
self.assertEqual(y.shape, u.shape)
143-
self.assertTrue(np.all(u > 0.0))
139+
u = np.square(0.1 * c)
140+
v = f.lpu(c, u, diag=True)
141+
self.assertEqual(y.shape, v.shape)
142+
self.assertTrue(np.all(v > 0.0))
144143

145-
u = to_var(0.1 * c)
146-
u = f.lpu(c, u)
147-
self.assertEqual(y.shape + y.shape, u.shape)
148-
self.assertTrue(np.all(u > 0.0))
144+
v = f.lpu(c, u)
145+
self.assertEqual(y.shape + y.shape, v.shape)
146+
self.assertTrue(np.all(v > 0.0))
149147

150148
def test_lpu(self):
151149
c = self.c
152150
f = self.f
153151
y = f.eval(c)
154152

155-
u = to_var(0.1 * c)
156-
u = f.lpu(c, u, diag=True)
157-
self.assertEqual(y.shape, u.shape)
158-
self.assertTrue(np.all(u > 0.0))
153+
u = np.square(0.1 * c)
154+
v = f.lpu(c, u, diag=True)
155+
self.assertEqual(y.shape, v.shape)
156+
self.assertTrue(np.all(v > 0.0))
159157

160-
u = to_var(0.1 * c)
161-
u = f.lpu(c, u)
162-
self.assertEqual(y.shape + y.shape, u.shape)
163-
self.assertTrue(np.all(u > 0.0))
158+
v = f.lpu(c, u)
159+
self.assertEqual(y.shape + y.shape, v.shape)
160+
self.assertTrue(np.all(v > 0.0))
164161

165162

166163
class BernsteinPolyTest(unittest.TestCase):
@@ -196,10 +193,8 @@ def test_bernstein_poly(self):
196193

197194
def test_from_lookup_table(self):
198195
k = 5
199-
x = np.array([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
200-
y = np.array( # y = x ** 2 + 2 x + 3
201-
[3.00, 3.44, 3.96, 4.56, 5.24, 6.00]
202-
)
196+
x = np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
197+
y = np.square(x) + 2.0 * x + 3.0
203198

204199
f = BernsteinPoly.from_lookup_table((k,), (x,), y, non_negative=True)
205200
c = f.prior()
@@ -214,55 +209,240 @@ def test_from_lookup_table(self):
214209

215210

216211
class BSolveTest(unittest.TestCase):
217-
"""Tests the solving function."""
212+
"""
213+
Tests the solving function by fitting coefficients to
214+
Bernstein basis polynomials.
215+
"""
218216

219-
def test_b_solve_degree_2(self):
217+
def test_b_solve_0_2(self):
218+
r"""Fit :math:`B_{0,2}(x)`."""
220219
k = 2
221-
x = jnp.array([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
222-
y = jnp.array( # y = x ** 2 + 2 x + 3
223-
[3.00, 3.44, 3.96, 4.56, 5.24, 6.00]
224-
)
220+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
221+
y = jnp.square(1.0 - x)
225222

226223
c = b_solve((k,), (x,), y, non_negative=True)
227224
self.assertEqual((k + 1,), c.shape)
228-
self.assertAlmostEqual(3.0, c[0].item())
229-
self.assertAlmostEqual(4.0, c[1].item())
230-
self.assertAlmostEqual(6.0, c[2].item())
225+
self.assertAlmostEqual(1.0, c[0].item())
226+
self.assertAlmostEqual(0.0, c[1].item())
227+
self.assertAlmostEqual(0.0, c[2].item())
231228

232-
def test_b_solve_degree_5(self):
233-
k = 5
234-
x = jnp.array([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
235-
y = (
236-
jnp.array( # y = x ** 2 + 2 x - 1 / 100
237-
[0.00, 0.44, 0.96, 1.56, 2.24, 3.00]
238-
)
239-
- 0.01
240-
)
229+
def test_b_solve_1_2(self):
230+
r"""Fit :math:`B_{1,2}(x)`."""
231+
k = 2
232+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
233+
y = 2.0 * x * (1.0 - x)
241234

242-
c = b_solve((k,), (x,), y)
235+
c = b_solve((k,), (x,), y, non_negative=True)
243236
self.assertEqual((k + 1,), c.shape)
244-
self.assertAlmostEqual(0.00, c[0].item() + 0.01)
245-
self.assertAlmostEqual(0.39, c[1].item())
246-
self.assertAlmostEqual(0.89, c[2].item())
247-
self.assertAlmostEqual(1.49, c[3].item())
248-
self.assertAlmostEqual(2.19, c[4].item())
249-
self.assertAlmostEqual(2.99, c[5].item())
237+
self.assertAlmostEqual(0.0, c[0].item())
238+
self.assertAlmostEqual(1.0, c[1].item())
239+
self.assertAlmostEqual(0.0, c[2].item())
240+
241+
def test_b_solve_2_2(self):
242+
r"""Fit :math:`B_{2,2}(x)`."""
243+
k = 2
244+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
245+
y = jnp.square(x)
250246

251247
c = b_solve((k,), (x,), y, non_negative=True)
252248
self.assertEqual((k + 1,), c.shape)
253-
self.assertAlmostEqual(0.00, c[0].item())
254-
self.assertAlmostEqual(0.38, c[1].item(), places=2)
255-
self.assertAlmostEqual(0.90, c[2].item(), places=2)
256-
self.assertAlmostEqual(1.48, c[3].item(), places=2)
257-
self.assertAlmostEqual(2.19, c[4].item(), places=2)
258-
self.assertAlmostEqual(2.99, c[5].item(), places=2)
249+
self.assertAlmostEqual(0.0, c[0].item())
250+
self.assertAlmostEqual(0.0, c[1].item())
251+
self.assertAlmostEqual(1.0, c[2].item())
259252

253+
def test_b_solve_0_0_2_2(self):
254+
r"""Fit :math:`B_{(0,0),(2,2)}(x_0, x_1)`."""
255+
k = 2
256+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
257+
y = jnp.square(1.0 - x[jnp.newaxis, :]) * jnp.square(
258+
1.0 - x[:, jnp.newaxis]
259+
)
260260

261-
def to_var(u: np.ndarray) -> np.ndarray:
262-
"""
263-
Converts standard uncertainty to a diagonal uncertainty tensor.
264-
"""
265-
return np.square(u)
261+
c = b_solve((k, k), (x, x), y, non_negative=True)
262+
self.assertEqual((k + 1, k + 1), c.shape)
263+
self.assertAlmostEqual(1.0, c[0, 0].item())
264+
self.assertAlmostEqual(0.0, c[0, 1].item())
265+
self.assertAlmostEqual(0.0, c[0, 2].item())
266+
self.assertAlmostEqual(0.0, c[1, 0].item())
267+
self.assertAlmostEqual(0.0, c[1, 1].item())
268+
self.assertAlmostEqual(0.0, c[1, 2].item())
269+
self.assertAlmostEqual(0.0, c[2, 0].item())
270+
self.assertAlmostEqual(0.0, c[2, 1].item())
271+
self.assertAlmostEqual(0.0, c[2, 2].item())
272+
273+
def test_b_solve_1_0_2_2(self):
274+
r"""Fit :math:`B_{(1,0),(2,2)}(x_0, x_1)`."""
275+
k = 2
276+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
277+
y = (
278+
2.0
279+
* x[:, jnp.newaxis]
280+
* (1.0 - x[:, jnp.newaxis])
281+
* jnp.square(1.0 - x[jnp.newaxis, :])
282+
)
283+
284+
c = b_solve((k, k), (x, x), y, non_negative=True)
285+
self.assertEqual((k + 1, k + 1), c.shape)
286+
self.assertAlmostEqual(0.0, c[0, 0].item())
287+
self.assertAlmostEqual(0.0, c[0, 1].item())
288+
self.assertAlmostEqual(0.0, c[0, 2].item())
289+
self.assertAlmostEqual(1.0, c[1, 0].item())
290+
self.assertAlmostEqual(0.0, c[1, 1].item())
291+
self.assertAlmostEqual(0.0, c[1, 2].item())
292+
self.assertAlmostEqual(0.0, c[2, 0].item())
293+
self.assertAlmostEqual(0.0, c[2, 1].item())
294+
self.assertAlmostEqual(0.0, c[2, 2].item())
295+
296+
def test_b_solve_2_0_2_2(self):
297+
r"""Fit :math:`B_{(2,0),(2,2)}(x_0, x_1)`."""
298+
k = 2
299+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
300+
y = jnp.square(x[:, jnp.newaxis]) * jnp.square(
301+
1.0 - x[jnp.newaxis, :]
302+
)
303+
304+
c = b_solve((k, k), (x, x), y, non_negative=True)
305+
self.assertEqual((k + 1, k + 1), c.shape)
306+
self.assertAlmostEqual(0.0, c[0, 0].item())
307+
self.assertAlmostEqual(0.0, c[0, 1].item())
308+
self.assertAlmostEqual(0.0, c[0, 2].item())
309+
self.assertAlmostEqual(0.0, c[1, 0].item())
310+
self.assertAlmostEqual(0.0, c[1, 1].item())
311+
self.assertAlmostEqual(0.0, c[1, 2].item())
312+
self.assertAlmostEqual(1.0, c[2, 0].item())
313+
self.assertAlmostEqual(0.0, c[2, 1].item())
314+
self.assertAlmostEqual(0.0, c[2, 2].item())
315+
316+
def test_b_solve_0_1_2_2(self):
317+
r"""Fit :math:`B_{(0,1),(2,2)}(x_0, x_1)`."""
318+
k = 2
319+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
320+
y = (
321+
2.0
322+
* jnp.square(1.0 - x[:, jnp.newaxis])
323+
* x[jnp.newaxis, :]
324+
* (1.0 - x[jnp.newaxis, :])
325+
)
326+
327+
c = b_solve((k, k), (x, x), y, non_negative=True)
328+
self.assertEqual((k + 1, k + 1), c.shape)
329+
self.assertAlmostEqual(0.0, c[0, 0].item())
330+
self.assertAlmostEqual(1.0, c[0, 1].item())
331+
self.assertAlmostEqual(0.0, c[0, 2].item())
332+
self.assertAlmostEqual(0.0, c[1, 0].item())
333+
self.assertAlmostEqual(0.0, c[1, 1].item())
334+
self.assertAlmostEqual(0.0, c[1, 2].item())
335+
self.assertAlmostEqual(0.0, c[2, 0].item())
336+
self.assertAlmostEqual(0.0, c[2, 1].item())
337+
self.assertAlmostEqual(0.0, c[2, 2].item())
338+
339+
def test_b_solve_1_1_2_2(self):
340+
r"""Fit :math:`B_{(1,1),(2,2)}(x_0, x_1)`."""
341+
k = 2
342+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
343+
y = (
344+
4.0
345+
* x[:, jnp.newaxis]
346+
* (1.0 - x[:, jnp.newaxis])
347+
* x[jnp.newaxis, :]
348+
* (1.0 - x[jnp.newaxis, :])
349+
)
350+
351+
c = b_solve((k, k), (x, x), y, non_negative=True)
352+
self.assertEqual((k + 1, k + 1), c.shape)
353+
self.assertAlmostEqual(0.0, c[0, 0].item())
354+
self.assertAlmostEqual(0.0, c[0, 1].item())
355+
self.assertAlmostEqual(0.0, c[0, 2].item())
356+
self.assertAlmostEqual(0.0, c[1, 0].item())
357+
self.assertAlmostEqual(1.0, c[1, 1].item())
358+
self.assertAlmostEqual(0.0, c[1, 2].item())
359+
self.assertAlmostEqual(0.0, c[2, 0].item())
360+
self.assertAlmostEqual(0.0, c[2, 1].item())
361+
self.assertAlmostEqual(0.0, c[2, 2].item())
362+
363+
def test_b_solve_2_1_2_2(self):
364+
r"""Fit :math:`B_{(2,1),(2,2)}(x_0, x_1)`."""
365+
k = 2
366+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
367+
y = (
368+
2.0
369+
* jnp.square(x[:, jnp.newaxis])
370+
* x[jnp.newaxis, :]
371+
* (1.0 - x[jnp.newaxis, :])
372+
)
373+
374+
c = b_solve((k, k), (x, x), y, non_negative=True)
375+
self.assertEqual((k + 1, k + 1), c.shape)
376+
self.assertAlmostEqual(0.0, c[0, 0].item())
377+
self.assertAlmostEqual(0.0, c[0, 1].item())
378+
self.assertAlmostEqual(0.0, c[0, 2].item())
379+
self.assertAlmostEqual(0.0, c[1, 0].item())
380+
self.assertAlmostEqual(0.0, c[1, 1].item())
381+
self.assertAlmostEqual(0.0, c[1, 2].item())
382+
self.assertAlmostEqual(0.0, c[2, 0].item())
383+
self.assertAlmostEqual(1.0, c[2, 1].item())
384+
self.assertAlmostEqual(0.0, c[2, 2].item())
385+
386+
def test_b_solve_0_2_2_2(self):
387+
r"""Fit :math:`B_{(0,2),(2,2)}(x_0, x_1)`."""
388+
k = 2
389+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
390+
y = jnp.square(1.0 - x[:, jnp.newaxis]) * jnp.square(
391+
x[jnp.newaxis, :]
392+
)
393+
394+
c = b_solve((k, k), (x, x), y, non_negative=True)
395+
self.assertEqual((k + 1, k + 1), c.shape)
396+
self.assertAlmostEqual(0.0, c[0, 0].item())
397+
self.assertAlmostEqual(0.0, c[0, 1].item())
398+
self.assertAlmostEqual(1.0, c[0, 2].item())
399+
self.assertAlmostEqual(0.0, c[1, 0].item())
400+
self.assertAlmostEqual(0.0, c[1, 1].item())
401+
self.assertAlmostEqual(0.0, c[1, 2].item())
402+
self.assertAlmostEqual(0.0, c[2, 0].item())
403+
self.assertAlmostEqual(0.0, c[2, 1].item())
404+
self.assertAlmostEqual(0.0, c[2, 2].item())
405+
406+
def test_b_solve_1_2_2_2(self):
407+
r"""Fit :math:`B_{(1,2),(2,2)}(x_0, x_1)`."""
408+
k = 2
409+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
410+
y = (
411+
2.0
412+
* x[:, jnp.newaxis]
413+
* (1.0 - x[:, jnp.newaxis])
414+
* jnp.square(x[jnp.newaxis, :])
415+
)
416+
417+
c = b_solve((k, k), (x, x), y, non_negative=True)
418+
self.assertEqual((k + 1, k + 1), c.shape)
419+
self.assertAlmostEqual(0.0, c[0, 0].item())
420+
self.assertAlmostEqual(0.0, c[0, 1].item())
421+
self.assertAlmostEqual(0.0, c[0, 2].item())
422+
self.assertAlmostEqual(0.0, c[1, 0].item())
423+
self.assertAlmostEqual(0.0, c[1, 1].item())
424+
self.assertAlmostEqual(1.0, c[1, 2].item())
425+
self.assertAlmostEqual(0.0, c[2, 0].item())
426+
self.assertAlmostEqual(0.0, c[2, 1].item())
427+
self.assertAlmostEqual(0.0, c[2, 2].item())
428+
429+
def test_b_solve_2_2_2_2(self):
430+
r"""Fit :math:`B_{(2,2),(2,2)}(x_0, x_1)`."""
431+
k = 2
432+
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
433+
y = jnp.square(x[:, jnp.newaxis]) * jnp.square(x[jnp.newaxis, :])
434+
435+
c = b_solve((k, k), (x, x), y, non_negative=True)
436+
self.assertEqual((k + 1, k + 1), c.shape)
437+
self.assertAlmostEqual(0.0, c[0, 0].item())
438+
self.assertAlmostEqual(0.0, c[0, 1].item())
439+
self.assertAlmostEqual(0.0, c[0, 2].item())
440+
self.assertAlmostEqual(0.0, c[1, 0].item())
441+
self.assertAlmostEqual(0.0, c[1, 1].item())
442+
self.assertAlmostEqual(0.0, c[1, 2].item())
443+
self.assertAlmostEqual(0.0, c[2, 0].item())
444+
self.assertAlmostEqual(0.0, c[2, 1].item())
445+
self.assertAlmostEqual(1.0, c[2, 2].item())
266446

267447

268448
if __name__ == "__main__":

uncertaintyx/b/jax.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,17 @@ def b_solve(
196196
rhs = jnp.tensordot(rhs, Q[i], axes=(0, 0))
197197
# solve the triangular equation
198198
c_unconstrained = rhs
199-
for i in range(N):
199+
if N > 1:
200+
for i in range(N):
201+
solve = jax.vmap(
202+
lambda a, b: jla.triangular_solve(a, b, left_side=True),
203+
in_axes=(None, i),
204+
out_axes=i,
205+
)
206+
c_unconstrained = solve(R[i], c_unconstrained)
207+
else:
200208
c_unconstrained = jla.triangular_solve(
201-
R[i], c_unconstrained, left_side=True
202-
)
203-
c_unconstrained = jnp.moveaxis( # like the tensor dot product
204-
c_unconstrained, 0, -1
209+
R[0], c_unconstrained, left_side=True
205210
)
206211

207212
def hvp(c: Array):

0 commit comments

Comments
 (0)