Skip to content

Commit 583bff9

Browse files
committed
[refactor] Improve JAX configuration handling and clean up setup.py
1 parent 013007f commit 583bff9

5 files changed

Lines changed: 94 additions & 121 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu
1919
- **Website (documentation and APIs)**: https://brainpy.readthedocs.io/en/latest
2020
- **Source**: https://github.com/brainpy/BrainPy
2121
- **Bug reports**: https://github.com/brainpy/BrainPy/issues
22+
- **Ecosystem**: https://brainmodeling.readthedocs.io/
2223
- **Source on OpenI**: https://git.openi.org.cn/OpenI/BrainPy
2324

2425

brainpy/__init__.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,11 @@
153153

154154
del deprecation_getattr2
155155

156-
# jax config
157-
import os
158-
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
159-
import jax
160-
jax.config.update('jax_cpu_enable_async_dispatch', False)
156+
try:
157+
# jax config
158+
import os
159+
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
160+
import jax
161+
jax.config.update('jax_cpu_enable_async_dispatch', False)
162+
except:
163+
pass

examples/dynamics_simulation/hh_model.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
import numpy as np
4-
53
import brainpy as bp
6-
from jax import pmap
74
import brainpy.math as bm
85

96
bm.set_host_device_count(20)

examples/dynamics_simulation/stdp.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,52 +13,52 @@
1313

1414

1515
class STDPNet(bp.DynSysGroup):
16-
def __init__(self, num_poisson, num_lif=1, g_max=0.01):
17-
super().__init__()
16+
def __init__(self, num_poisson, num_lif=1, g_max=0.01):
17+
super().__init__()
1818

19-
self.g_max = g_max
19+
self.g_max = g_max
2020

21-
# neuron groups
22-
self.noise = bp.dyn.PoissonGroup(num_poisson, freqs=15.)
23-
self.group = bp.dyn.Lif(num_lif, V_reset=-60., V_rest=-74, V_th=-54, tau=10.,
24-
V_initializer=bp.init.Normal(-60., 1.))
21+
# neuron groups
22+
self.noise = bp.dyn.PoissonGroup(num_poisson, freqs=15.)
23+
self.group = bp.dyn.Lif(num_lif, V_reset=-60., V_rest=-74, V_th=-54, tau=10.,
24+
V_initializer=bp.init.Normal(-60., 1.))
2525

26-
# synapses
27-
syn = bp.dyn.Expon.desc(num_lif, tau=5.)
28-
out = bp.dyn.COBA.desc(E=0.)
29-
comm = bp.dnn.AllToAll(num_poisson, num_lif, bp.init.Uniform(0., g_max))
30-
self.syn = bp.dyn.STDP_Song2000(self.noise, None, syn, comm, out, self.group,
31-
tau_s=20, tau_t=20, W_max=g_max, W_min=0.,
32-
A1=0.01 * g_max, A2=0.0105 * g_max)
26+
# synapses
27+
syn = bp.dyn.Expon.desc(num_lif, tau=5.)
28+
out = bp.dyn.COBA.desc(E=0.)
29+
comm = bp.dnn.AllToAll(num_poisson, num_lif, bp.init.Uniform(0., g_max))
30+
self.syn = bp.dyn.STDP_Song2000(self.noise, None, syn, comm, out, self.group,
31+
tau_s=20, tau_t=20, W_max=g_max, W_min=0.,
32+
A1=0.01 * g_max, A2=0.0105 * g_max)
3333

34-
def update(self, *args, **kwargs):
35-
self.noise()
36-
self.syn()
37-
self.group()
38-
return self.syn.comm.weight.flatten()[:10]
34+
def update(self, *args, **kwargs):
35+
self.noise()
36+
self.syn()
37+
self.group()
38+
return self.syn.comm.weight.flatten()[:10]
3939

4040

4141
def run_model():
42-
net = STDPNet(1000, 1)
43-
indices = np.arange(int(100.0e3 / bm.dt)) # 100 s
44-
ws = bm.for_loop(net.step_run, indices, progress_bar=True)
45-
weight = bm.as_numpy(net.syn.comm.weight.flatten())
42+
net = STDPNet(1000, 1)
43+
indices = np.arange(int(100.0e3 / bm.dt)) # 100 s
44+
ws = bm.for_loop(net.step_run, indices, progress_bar=True)
45+
weight = bm.as_numpy(net.syn.comm.weight.flatten())
4646

47-
fig, gs = bp.visualize.get_figure(3, 1, 3, 10)
48-
fig.add_subplot(gs[0, 0])
49-
plt.plot(weight / net.g_max, '.k')
50-
plt.xlabel('Weight / gmax')
47+
fig, gs = bp.visualize.get_figure(3, 1, 3, 10)
48+
fig.add_subplot(gs[0, 0])
49+
plt.plot(weight / net.g_max, '.k')
50+
plt.xlabel('Weight / gmax')
5151

52-
fig.add_subplot(gs[1, 0])
53-
plt.hist(weight / net.g_max, 20)
54-
plt.xlabel('Weight / gmax')
52+
fig.add_subplot(gs[1, 0])
53+
plt.hist(weight / net.g_max, 20)
54+
plt.xlabel('Weight / gmax')
5555

56-
fig.add_subplot(gs[2, 0])
57-
plt.plot(indices * bm.dt, bm.as_numpy(ws) / net.g_max)
58-
plt.xlabel('Time (s)')
59-
plt.ylabel('Weight / gmax')
60-
plt.show()
56+
fig.add_subplot(gs[2, 0])
57+
plt.plot(indices * bm.dt, bm.as_numpy(ws) / net.g_max)
58+
plt.xlabel('Time (s)')
59+
plt.ylabel('Weight / gmax')
60+
plt.show()
6161

6262

6363
if __name__ == '__main__':
64-
run_model()
64+
run_model()

setup.py

Lines changed: 49 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -3,96 +3,68 @@
33
import io
44
import os
55
import re
6-
import time
7-
import sys
8-
9-
from setuptools import find_packages
10-
from setuptools import setup
11-
12-
try:
13-
# require users to uninstall previous brainpy releases.
14-
import pkg_resources
15-
16-
installed_packages = pkg_resources.working_set
17-
for i in installed_packages:
18-
if i.key == 'brainpy-simulator':
19-
raise SystemError('Please uninstall the older version of brainpy '
20-
f'package "brainpy-simulator={i.version}" '
21-
f'(located in {i.location}) first. \n'
22-
'>>> pip uninstall brainpy-simulator')
23-
if i.key == 'brain-py':
24-
raise SystemError('Please uninstall the older version of brainpy '
25-
f'package "brain-py={i.version}" '
26-
f'(located in {i.location}) first. \n'
27-
'>>> pip uninstall brain-py')
28-
except ModuleNotFoundError:
29-
pass
306

7+
from setuptools import find_packages, setup
318

329
# version
3310
here = os.path.abspath(os.path.dirname(__file__))
3411
with open(os.path.join(here, 'brainpy', '__init__.py'), 'r') as f:
35-
init_py = f.read()
12+
init_py = f.read()
3613
version = re.search('__version__ = "(.*)"', init_py).groups()[0]
37-
if len(sys.argv) > 2 and sys.argv[2] == '--python-tag=py3':
38-
version = version
39-
else:
40-
version += '.post{}'.format(time.strftime("%Y%m%d", time.localtime()))
4114

4215
# obtain long description from README
4316
with io.open(os.path.join(here, 'README.md'), 'r', encoding='utf-8') as f:
44-
README = f.read()
17+
README = f.read()
4518

4619
# installation packages
4720
packages = find_packages(exclude=['lib*', 'docs', 'tests'])
4821

4922
# setup
5023
setup(
51-
name='brainpy',
52-
version=version,
53-
description='BrainPy: Brain Dynamics Programming in Python',
54-
long_description=README,
55-
long_description_content_type="text/markdown",
56-
author='BrainPy Team',
57-
author_email='chao.brain@qq.com',
58-
packages=packages,
59-
python_requires='>=3.10',
60-
install_requires=['numpy>=1.15', 'jax>=0.4.13,<0.6.0', 'tqdm'],
61-
url='https://github.com/brainpy/BrainPy',
62-
project_urls={
63-
"Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
64-
"Documentation": "https://brainpy.readthedocs.io/",
65-
"Source Code": "https://github.com/brainpy/BrainPy",
66-
},
67-
dependency_links=[
68-
'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html',
69-
],
70-
extras_require={
71-
'cpu': ['jaxlib>=0.4.13', 'numba', 'braintaichi'],
72-
'cuda12': ['jaxlib[cuda12_pip]', 'numba', 'braintaichi'],
73-
'tpu': ['jaxlib[tpu]', 'numba',],
74-
'cpu_mini': ['jaxlib>=0.4.13'],
75-
'cuda12_mini': ['jaxlib[cuda12_pip]'],
76-
},
77-
keywords=('computational neuroscience, '
78-
'brain-inspired computation, '
79-
'brain modeling, '
80-
'brain dynamics modeling, '
81-
'brain dynamics programming'),
82-
classifiers=[
83-
'Natural Language :: English',
84-
'Operating System :: OS Independent',
85-
'Programming Language :: Python',
86-
'Programming Language :: Python :: 3',
87-
'Programming Language :: Python :: 3.10',
88-
'Programming Language :: Python :: 3.11',
89-
'Programming Language :: Python :: 3.12',
90-
'Intended Audience :: Science/Research',
91-
'License :: OSI Approved :: Apache Software License',
92-
'Topic :: Scientific/Engineering :: Bio-Informatics',
93-
'Topic :: Scientific/Engineering :: Mathematics',
94-
'Topic :: Scientific/Engineering :: Artificial Intelligence',
95-
'Topic :: Software Development :: Libraries',
96-
],
97-
license='GPL-3.0 license',
24+
name='brainpy',
25+
version=version,
26+
description='BrainPy: Brain Dynamics Programming in Python',
27+
long_description=README,
28+
long_description_content_type="text/markdown",
29+
author='BrainPy Team',
30+
author_email='chao.brain@qq.com',
31+
packages=packages,
32+
python_requires='>=3.10',
33+
install_requires=['numpy>=1.15', 'jax', 'tqdm'],
34+
url='https://github.com/brainpy/BrainPy',
35+
project_urls={
36+
"Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
37+
"Documentation": "https://brainpy.readthedocs.io/",
38+
"Source Code": "https://github.com/brainpy/BrainPy",
39+
},
40+
dependency_links=[
41+
'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html',
42+
],
43+
extras_require={
44+
'cpu': ['jax[cpu]<0.5.0', 'numba', 'braintaichi'],
45+
'cuda12': ['jax[cuda12]<0.5.0', 'numba', 'braintaichi'],
46+
'tpu': ['jax[tpu]<0.5.0', 'numba', ],
47+
},
48+
keywords=('computational neuroscience, '
49+
'brain-inspired computation, '
50+
'brain modeling, '
51+
'brain dynamics modeling, '
52+
'brain dynamics programming'),
53+
classifiers=[
54+
'Natural Language :: English',
55+
'Operating System :: OS Independent',
56+
'Programming Language :: Python',
57+
'Programming Language :: Python :: 3',
58+
'Programming Language :: Python :: 3.10',
59+
'Programming Language :: Python :: 3.11',
60+
'Programming Language :: Python :: 3.12',
61+
'Programming Language :: Python :: 3.13',
62+
'Intended Audience :: Science/Research',
63+
'License :: OSI Approved :: Apache Software License',
64+
'Topic :: Scientific/Engineering :: Bio-Informatics',
65+
'Topic :: Scientific/Engineering :: Mathematics',
66+
'Topic :: Scientific/Engineering :: Artificial Intelligence',
67+
'Topic :: Software Development :: Libraries',
68+
],
69+
license='GPL-3.0 license',
9870
)

0 commit comments

Comments
 (0)