Skip to content

Commit 86c7708

Browse files
committed
Update JAX import paths for compatibility with version 0.5.0 and update installation requirements in setup.py
1 parent d24ee76 commit 86c7708

6 files changed

Lines changed: 28 additions & 8 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu
2727

2828
BrainPy is based on Python (>=3.8) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms.
2929

30+
BrainPy requires ``jax<0.6.0``.
31+
3032
For detailed installation instructions, please refer to the documentation: [Quickstart/Installation](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html)
3133

3234

brainpy/_src/integrators/_jaxpr_to_source_code.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
import jax.numpy as jnp
2626
import numpy as np
2727
from jax._src.sharding_impls import UNSPECIFIED
28-
from jax.core import Literal, Var, Jaxpr
28+
if jax.__version__ >= '0.5.0':
29+
from jax.extend.core import Primitive, Literal, Var, Jaxpr
30+
else:
31+
from jax.core import Primitive, Literal, Var, Jaxpr
2932

3033
__all__ = [
3134
'fn_to_python_code',
@@ -187,7 +190,7 @@ def fn_to_python_code(fn, *args, **kwargs):
187190
return source
188191

189192

190-
def jaxpr_to_python_code(jaxpr: jax.core.Jaxpr,
193+
def jaxpr_to_python_code(jaxpr: Jaxpr,
191194
fn_name: str = "generated_function"):
192195
"""
193196
Given a JAX jaxpr, return the Python code that would be generated by JAX for that jaxpr.
@@ -367,7 +370,7 @@ def _maybe_wrap_fn_for_leaves(node, f, num_args):
367370

368371

369372
def jaxpr_to_py_ast(state: SourcerorState,
370-
jaxpr: jax.core.Jaxpr,
373+
jaxpr: Jaxpr,
371374
fn_name: str = "function"):
372375
# Generate argument declarations
373376
ast_args = [ast.arg(arg=state.str_name(var), annotation=None)
@@ -405,7 +408,7 @@ def jaxpr_to_py_ast(state: SourcerorState,
405408
return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[])
406409

407410

408-
def constant_fold_jaxpr(jaxpr: jax.core.Jaxpr):
411+
def constant_fold_jaxpr(jaxpr: Jaxpr):
409412
"""
410413
Given a jaxpr, return a new jaxpr with all constant folding done.
411414
"""

brainpy/_src/math/remove_vmap.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33

44
import jax.numpy as jnp
5-
from jax.core import Primitive, ShapedArray
5+
import jax
6+
if jax.__version__ >= '0.5.0':
7+
from jax.extend.core import Primitive
8+
else:
9+
from jax.core import Primitive
10+
from jax.core import ShapedArray
611
from jax.interpreters import batching, mlir, xla
712
from .ndarray import Array
813

brainpy/_src/math/sparse/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from functools import partial
55
from typing import Tuple
66

7+
8+
import jax
79
import numpy as np
810
from brainpy._src.math.interoperability import as_jax
911
from jax import core, numpy as jnp
@@ -12,6 +14,10 @@
1214
from jax.interpreters import mlir, ad
1315
from jax.tree_util import tree_flatten, tree_unflatten
1416
from jaxlib import gpu_sparse
17+
if jax.__version__ >= '0.5.0':
18+
from jax.extend.core import Primitive
19+
else:
20+
from jax.core import Primitive
1521

1622
__all__ = [
1723
'coo_to_csr',
@@ -171,7 +177,7 @@ def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape):
171177
return _csr_extract(indices, indptr, ct), indices, indptr
172178

173179

174-
csr_to_dense_p = core.Primitive('csr_to_dense')
180+
csr_to_dense_p = Primitive('csr_to_dense')
175181
csr_to_dense_p.def_impl(_csr_to_dense_impl)
176182
csr_to_dense_p.def_abstract_eval(_csr_to_dense_abstract_eval)
177183
ad.defjvp(csr_to_dense_p, _csr_to_dense_jvp, None, None)

brainpy/_src/math/surrogate/_one_input_new.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import jax
66
import jax.numpy as jnp
77
import jax.scipy as sci
8-
from jax.core import Primitive
8+
9+
if jax.__version__ >= '0.5.0':
10+
from jax.extend.core import Primitive
11+
else:
12+
from jax.core import Primitive
913
from jax.interpreters import batching, ad, mlir
1014

1115
from brainpy._src.math.interoperability import as_jax

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
author_email='chao.brain@qq.com',
5858
packages=packages,
5959
python_requires='>=3.9',
60-
install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'],
60+
install_requires=['numpy>=1.15', 'jax>=0.4.13,<0.6.0', 'tqdm'],
6161
url='https://github.com/brainpy/BrainPy',
6262
project_urls={
6363
"Bug Tracker": "https://github.com/brainpy/BrainPy/issues",

0 commit comments

Comments
 (0)