Skip to content

Commit 928d482

Browse files
authored
Refactor normalization of grid coordinates in jax.py
1 parent 54297d8 commit 928d482

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

uncertaintyx/b/jax.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,11 @@ def __init__(
308308
:param a: The lower bounds of the grid coordinates.
309309
:param b: The upper bounds of the grid coordinates.
310310
"""
311-
N = len(x) # noqa: : N806
312311
a = _lower_bounds(a, x)
313312
b = _upper_bounds(b, x)
314-
x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N))
313+
x_ = tuple(
314+
jnp.asarray((_ - a_) / (b_ - a_)) for _, a_, b_ in zip(x, a, b)
315+
)
315316

316317
def f(c: Array) -> Array:
317318
r"""
@@ -320,7 +321,7 @@ def f(c: Array) -> Array:
320321
321322
:param c: The coefficients :math:`c \in \mathbb{R}^{k + 1}`.
322323
"""
323-
return b_poly_grid(c, x)
324+
return b_poly_grid(c, x_)
324325

325326
super().__init__(f, rev=False)
326327

@@ -402,10 +403,11 @@ def from_lookup_table(
402403
:param rtol: The relative tolerance for terminating the solver.
403404
:param max_steps: The maximum number of steps the solver can take.
404405
"""
405-
N = len(k) # noqa: : N806
406406
a = _lower_bounds(a, x)
407407
b = _upper_bounds(b, x)
408-
x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N))
408+
x_ = tuple(
409+
jnp.asarray((_ - a_) / (b_ - a_)) for _, a_, b_ in zip(x, a, b)
410+
)
409411
y_ = jnp.asarray(y)
410412
c_ = b_solve(
411413
k,

0 commit comments

Comments
 (0)