diff --git a/README.md b/README.md index 627488f0f..5233f50d2 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu BrainPy is based on Python (>=3.8) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. +BrainPy requires ``jax<0.6.0``. + For detailed installation instructions, please refer to the documentation: [Quickstart/Installation](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html) diff --git a/brainpy/_src/integrators/_jaxpr_to_source_code.py b/brainpy/_src/integrators/_jaxpr_to_source_code.py index 3fa1d9006..6c8c0c9e8 100644 --- a/brainpy/_src/integrators/_jaxpr_to_source_code.py +++ b/brainpy/_src/integrators/_jaxpr_to_source_code.py @@ -25,7 +25,10 @@ import jax.numpy as jnp import numpy as np from jax._src.sharding_impls import UNSPECIFIED -from jax.core import Literal, Var, Jaxpr +if jax.__version__ >= '0.5.0': + from jax.extend.core import Primitive, Literal, Var, Jaxpr +else: + from jax.core import Primitive, Literal, Var, Jaxpr __all__ = [ 'fn_to_python_code', @@ -187,7 +190,7 @@ def fn_to_python_code(fn, *args, **kwargs): return source -def jaxpr_to_python_code(jaxpr: jax.core.Jaxpr, +def jaxpr_to_python_code(jaxpr: Jaxpr, fn_name: str = "generated_function"): """ Given a JAX jaxpr, return the Python code that would be generated by JAX for that jaxpr. @@ -367,7 +370,7 @@ def _maybe_wrap_fn_for_leaves(node, f, num_args): def jaxpr_to_py_ast(state: SourcerorState, - jaxpr: jax.core.Jaxpr, + jaxpr: Jaxpr, fn_name: str = "function"): # Generate argument declarations ast_args = [ast.arg(arg=state.str_name(var), annotation=None) @@ -405,7 +408,7 @@ def jaxpr_to_py_ast(state: SourcerorState, return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[]) -def constant_fold_jaxpr(jaxpr: jax.core.Jaxpr): +def constant_fold_jaxpr(jaxpr: Jaxpr): """ Given a jaxpr, return a new jaxpr with all constant folding done. """ diff --git a/brainpy/_src/math/remove_vmap.py b/brainpy/_src/math/remove_vmap.py index ee81c0c17..25991e4aa 100644 --- a/brainpy/_src/math/remove_vmap.py +++ b/brainpy/_src/math/remove_vmap.py @@ -2,7 +2,12 @@ import jax.numpy as jnp -from jax.core import Primitive, ShapedArray +import jax +if jax.__version__ >= '0.5.0': + from jax.extend.core import Primitive +else: + from jax.core import Primitive +from jax.core import ShapedArray from jax.interpreters import batching, mlir, xla from .ndarray import Array diff --git a/brainpy/_src/math/sparse/utils.py b/brainpy/_src/math/sparse/utils.py index 38cfdb7b9..0c7f58348 100644 --- a/brainpy/_src/math/sparse/utils.py +++ b/brainpy/_src/math/sparse/utils.py @@ -4,6 +4,8 @@ from functools import partial from typing import Tuple + +import jax import numpy as np from brainpy._src.math.interoperability import as_jax from jax import core, numpy as jnp @@ -12,6 +14,10 @@ from jax.interpreters import mlir, ad from jax.tree_util import tree_flatten, tree_unflatten from jaxlib import gpu_sparse +if jax.__version__ >= '0.5.0': + from jax.extend.core import Primitive +else: + from jax.core import Primitive __all__ = [ 'coo_to_csr', @@ -171,7 +177,7 @@ def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape): return _csr_extract(indices, indptr, ct), indices, indptr -csr_to_dense_p = core.Primitive('csr_to_dense') +csr_to_dense_p = Primitive('csr_to_dense') csr_to_dense_p.def_impl(_csr_to_dense_impl) csr_to_dense_p.def_abstract_eval(_csr_to_dense_abstract_eval) ad.defjvp(csr_to_dense_p, _csr_to_dense_jvp, None, None) diff --git a/brainpy/_src/math/surrogate/_one_input_new.py b/brainpy/_src/math/surrogate/_one_input_new.py index ed9957261..142c2b695 100644 --- a/brainpy/_src/math/surrogate/_one_input_new.py +++ b/brainpy/_src/math/surrogate/_one_input_new.py @@ -5,7 +5,11 @@ import jax import jax.numpy as jnp import jax.scipy as sci -from jax.core import Primitive + +if jax.__version__ >= '0.5.0': + from jax.extend.core import Primitive +else: + from jax.core import Primitive from jax.interpreters import batching, ad, mlir from brainpy._src.math.interoperability import as_jax diff --git a/requirements-dev.txt b/requirements-dev.txt index 3931bd501..247b796ad 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ numpy -jax -jaxlib +jax<0.6.0 +jaxlib<0.6.0 absl-py<=2.1.0 brainstate<=0.1.0.post20241210 braintaichi<=0.0.4 diff --git a/requirements-doc.txt b/requirements-doc.txt index 5c6d440ee..e66303265 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -1,6 +1,6 @@ tqdm -jax -jaxlib +jax<0.6.0 +jaxlib<0.6.0 matplotlib numpy scipy diff --git a/requirements.txt b/requirements.txt index ab5665e73..5216d647c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ numpy -jax +jax<0.6.0 tqdm diff --git a/setup.py b/setup.py index e76727d70..73d4dc743 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.9', - install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'], + install_requires=['numpy>=1.15', 'jax>=0.4.13,<0.6.0', 'tqdm'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues",