|
34 | 34 |
|
35 | 35 |
|
36 | 36 | def _logical_or(a, b): |
37 | | - a = a.value if isinstance(a, bm.Array) else a |
38 | | - b = b.value if isinstance(b, bm.Array) else b |
| 37 | + a = a.value if isinstance(a, bm.BaseArray) else a |
| 38 | + b = b.value if isinstance(b, bm.BaseArray) else b |
39 | 39 | return jnp.logical_or(a, b) |
40 | 40 |
|
41 | 41 |
|
42 | 42 | def _logical_and(a, b): |
43 | | - a = a.value if isinstance(a, bm.Array) else a |
44 | | - b = b.value if isinstance(b, bm.Array) else b |
| 43 | + a = a.value if isinstance(a, bm.BaseArray) else a |
| 44 | + b = b.value if isinstance(b, bm.BaseArray) else b |
45 | 45 | return jnp.logical_and(a, b) |
46 | 46 |
|
47 | 47 |
|
48 | 48 | def _where(p, a, b): |
49 | | - p = p.value if isinstance(p, bm.Array) else p |
50 | | - a = a.value if isinstance(a, bm.Array) else a |
51 | | - b = b.value if isinstance(b, bm.Array) else b |
| 49 | + p = p.value if isinstance(p, bm.BaseArray) else p |
| 50 | + a = a.value if isinstance(a, bm.BaseArray) else a |
| 51 | + b = b.value if isinstance(b, bm.BaseArray) else b |
52 | 52 | return jnp.where(p, a, b) |
53 | 53 |
|
54 | 54 |
|
@@ -175,7 +175,7 @@ def get_brentq_candidates(f, xs, ys): |
175 | 175 |
|
176 | 176 | def brentq_candidates(vmap_f, *values, args=()): |
177 | 177 | # change the position of meshgrid values |
178 | | - values = tuple((v.value if isinstance(v, bm.Array) else v) for v in values) |
| 178 | + values = tuple((v.value if isinstance(v, bm.BaseArray) else v) for v in values) |
179 | 179 | xs = values[0] |
180 | 180 | mesh_values = jnp.meshgrid(*values) |
181 | 181 | if jnp.ndim(mesh_values[0]) > 1: |
@@ -348,7 +348,7 @@ def scipy_minimize_with_jax(fun, x0, |
348 | 348 | def fun_wrapper(x_flat, *args): |
349 | 349 | x = unravel(x_flat) |
350 | 350 | r = fun(x, *args) |
351 | | - r = r.value if isinstance(r, bm.Array) else r |
| 351 | + r = r.value if isinstance(r, bm.BaseArray) else r |
352 | 352 | return float(r) |
353 | 353 |
|
354 | 354 | # Wrap the gradient in a similar manner |
@@ -386,8 +386,8 @@ def roots_of_1d_by_x(f, candidates, args=()): |
386 | 386 | """Find the roots of the given function by numerical methods. |
387 | 387 | """ |
388 | 388 | f = f_without_jaxarray_return(f) |
389 | | - candidates = candidates.value if isinstance(candidates, bm.Array) else candidates |
390 | | - args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args) |
| 389 | + candidates = candidates.value if isinstance(candidates, bm.BaseArray) else candidates |
| 390 | + args = tuple(a.value if isinstance(candidates, bm.BaseArray) else a for a in args) |
391 | 391 | vals = f(candidates, *args) |
392 | 392 | signs = jnp.sign(vals) |
393 | 393 | zero_sign_idx = jnp.where(signs == 0)[0] |
|
0 commit comments