Skip to content

Commit bcf350f

Browse files
committed
Fix and test array_api_obj, is_writable_array, is_lazy_array
1 parent 4f1e7f0 commit bcf350f

File tree

3 files changed

+24
-12
lines changed

3 files changed

+24
-12
lines changed

array_api_compat/common/_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ def _is_array_api_cls(cls: type) -> bool:
300300
or _issubclass_fast(cls, "sparse", "SparseArray")
301301
# TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
302302
or _issubclass_fast(cls, "jax", "Array")
303+
or _issubclass_fast(cls, "jax.core", "Tracer")
303304
)
304305

305306

@@ -938,6 +939,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
938939
if (
939940
_issubclass_fast(cls, "numpy", "generic")
940941
or _issubclass_fast(cls, "jax", "Array")
942+
or _issubclass_fast(cls, "jax.core", "Tracer")
941943
or _issubclass_fast(cls, "sparse", "SparseArray")
942944
):
943945
return False
@@ -977,6 +979,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
977979
return False
978980
if (
979981
_issubclass_fast(cls, "jax", "Array")
982+
or _issubclass_fast(cls, "jax.core", "Tracer")
980983
or _issubclass_fast(cls, "dask.array", "Array")
981984
or _issubclass_fast(cls, "ndonnx", "Array")
982985
):

tests/test_common.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,6 @@ def test_is_xp_array(library, func):
5656
assert is_array_api_obj(x)
5757

5858

59-
def test_is_jax_array_jitted():
60-
jax = pytest.importorskip("jax")
61-
import jax.numpy as jnp
62-
63-
x = jnp.asarray([1, 2, 3])
64-
assert is_jax_array(x)
65-
assert jax.jit(lambda y: is_jax_array(y))(x)
66-
67-
6859
@pytest.mark.parametrize('library', is_namespace_functions.keys())
6960
@pytest.mark.parametrize('func', is_namespace_functions.values())
7061
def test_is_xp_namespace(library, func):

tests/test_jax.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from numpy.testing import assert_equal
22
import pytest
33

4-
from array_api_compat import device, to_device
4+
from array_api_compat import (
5+
device,
6+
to_device,
7+
is_jax_array,
8+
is_lazy_array,
9+
is_array_api_obj,
10+
is_writeable_array,
11+
)
512

613
try:
714
import jax
@@ -13,7 +20,7 @@
1320

1421

1522
@pytest.mark.parametrize(
16-
"func",
23+
"func",
1724
[
1825
lambda x: jnp.zeros(1, device=device(x)),
1926
lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))),
@@ -26,7 +33,7 @@
2633
),
2734
),
2835
lambda x: to_device(jnp.zeros(1), device(x)),
29-
]
36+
],
3037
)
3138
def test_device_jit(func):
3239
# Test work around to https://github.com/jax-ml/jax/issues/26000
@@ -36,3 +43,14 @@ def test_device_jit(func):
3643
x = jnp.ones(1)
3744
assert_equal(func(x), jnp.asarray([0]))
3845
assert_equal(jax.jit(func)(x), jnp.asarray([0]))
46+
47+
48+
def test_inside_jit():
49+
jax = pytest.importorskip("jax")
50+
import jax.numpy as jnp
51+
52+
x = jnp.asarray([1, 2, 3])
53+
assert jax.jit(is_jax_array)(x)
54+
assert jax.jit(is_array_api_obj)(x)
55+
assert not jax.jit(is_writeable_array)(x)
56+
assert jax.jit(is_lazy_array)(x)

0 commit comments

Comments
 (0)