Skip to content

Commit 2995eff

Browse files
committed
fix: update backend import for compatibility with jax>=0.8.0
1 parent 3e7eb8f commit 2995eff

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

brainpy/math/environment.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import brainstate.environ
2626
import jax
2727
from jax import config, numpy as jnp, devices
28-
from jax.lib import xla_bridge
2928

3029
from . import modes
3130
from . import scales
@@ -733,8 +732,13 @@ def clear_buffer_memory(
733732
Clear name cache. Default is True.
734733
735734
"""
735+
if jax.__version_info__ < (0, 8, 0):
736+
from jax.lib.xla_bridge import get_backend
737+
else:
738+
from jax.extend.backend import get_backend
739+
736740
if array:
737-
for buf in xla_bridge.get_backend(platform).live_buffers():
741+
for buf in get_backend(platform).live_buffers():
738742
buf.delete()
739743
if compilation:
740744
jax.clear_caches()

0 commit comments

Comments
 (0)