We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3e7eb8f commit 2995effCopy full SHA for 2995eff
1 file changed
brainpy/math/environment.py
@@ -25,7 +25,6 @@
25
import brainstate.environ
26
import jax
27
from jax import config, numpy as jnp, devices
28
-from jax.lib import xla_bridge
29
30
from . import modes
31
from . import scales
@@ -733,8 +732,13 @@ def clear_buffer_memory(
733
732
Clear name cache. Default is True.
734
735
"""
+ if jax.__version_info__ < (0, 8, 0):
736
+ from jax.lib.xla_bridge import get_backend
737
+ else:
738
+ from jax.extend.backend import get_backend
739
+
740
if array:
- for buf in xla_bridge.get_backend(platform).live_buffers():
741
+ for buf in get_backend(platform).live_buffers():
742
buf.delete()
743
if compilation:
744
jax.clear_caches()
0 commit comments