Skip to content

Commit f56e808

Browse files
authored
Improve compatibility with jax (#351)
2 parents 3a41997 + fca0712 commit f56e808

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

pysages/typing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,16 @@
88
from importlib import import_module
99

1010
import jax
11-
import jaxlib.xla_extension as xe
1211

1312
from pysages._compat import _jax_version_tuple, _plum_version_tuple
1413

1514
# Compatibility for jax >=0.4.1
1615

1716
# https://github.com/google/jax/releases/tag/jax-v0.4.1
1817
if _jax_version_tuple < (0, 4, 1):
18+
xe = import_module("jaxlib.xla_extension")
1919
JaxArray = xe.DeviceArray
20+
del xe
2021
else:
2122
JaxArray = jax.Array
2223

@@ -47,7 +48,6 @@
4748

4849
# Remove namespace noise
4950
del jax
50-
del xe
5151
del import_module
5252
del _jax_version_tuple
5353
del _plum_version_tuple

0 commit comments

Comments
 (0)