Skip to content

Commit 64e31f2

Browse files
authored
Refactor STDP and event ops for new brainevent binary backend (#817)
chore: update version to 2.7.7, enhance README and requirements, and refactor event handling
1 parent a0b47d1 commit 64e31f2

File tree

12 files changed

+35
-36
lines changed

12 files changed

+35
-36
lines changed

brainpy/__init__.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
__version__ = "2.7.6"
17+
__version__ = "2.7.7"
1818
__version_info__ = tuple(map(int, __version__.split(".")))
1919

2020
from brainpy import _errors as errors
@@ -133,20 +133,14 @@
133133
synouts, # synaptic output
134134
synplast, # synaptic plasticity
135135
)
136-
from brainpy.math.object_transform.base import (
137-
Base as Base,
138-
)
136+
from brainpy.math.object_transform.base import Base as Base
139137

140138
from brainpy.math.object_transform.collectors import (
141139
ArrayCollector as ArrayCollector,
142140
Collector as Collector,
143141
)
144142

145-
from brainpy.deprecations import deprecation_getattr
146-
147143
optimizers = optim
148144

149-
150145
# New package
151146
from brainpy import state
152-

brainpy/dnn/linear.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@
1919
import jax
2020
import jax.numpy as jnp
2121
import numpy as np
22-
from brainevent import csr_on_pre, csr2csc_on_post
23-
from brainevent import dense_on_pre, dense_on_post
22+
from brainevent import (
23+
update_csr_on_binary_pre,
24+
update_csr_on_binary_post,
25+
update_dense_on_binary_pre,
26+
update_dense_on_binary_post,
27+
)
2428

2529
from brainpy import connect, initialize as init
2630
from brainpy import math as bm
@@ -226,11 +230,11 @@ def stdp_update(
226230
if on_pre is not None:
227231
spike = on_pre['spike']
228232
trace = on_pre['trace']
229-
self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max)
233+
self.W.value = update_dense_on_binary_pre(self.W.value, spike, trace, w_min, w_max)
230234
if on_post is not None:
231235
spike = on_post['spike']
232236
trace = on_post['trace']
233-
self.W.value = dense_on_post(self.W.value, trace, spike, w_min, w_max)
237+
self.W.value = update_dense_on_binary_post(self.W.value, trace, spike, w_min, w_max)
234238

235239

236240
Linear = Dense
@@ -321,11 +325,11 @@ def stdp_update(
321325
if on_pre is not None:
322326
spike = on_pre['spike']
323327
trace = on_pre['trace']
324-
self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
328+
self.weight.value = update_dense_on_binary_pre(self.weight.value, spike, trace, w_min, w_max)
325329
if on_post is not None:
326330
spike = on_post['spike']
327331
trace = on_post['trace']
328-
self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max)
332+
self.weight.value = update_dense_on_binary_post(self.weight.value, trace, spike, w_min, w_max)
329333

330334

331335
class OneToOne(Layer, SupportSTDP):
@@ -449,11 +453,11 @@ def stdp_update(
449453
if on_pre is not None:
450454
spike = on_pre['spike']
451455
trace = on_pre['trace']
452-
self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
456+
self.weight.value = update_dense_on_binary_pre(self.weight.value, spike, trace, w_min, w_max)
453457
if on_post is not None:
454458
spike = on_post['spike']
455459
trace = on_post['trace']
456-
self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max)
460+
self.weight.value = update_dense_on_binary_post(self.weight.value, trace, spike, w_min, w_max)
457461

458462

459463
class _CSRLayer(Layer, SupportSTDP):
@@ -500,7 +504,7 @@ def stdp_update(
500504
if on_pre is not None: # update on presynaptic spike
501505
spike = on_pre['spike']
502506
trace = on_pre['trace']
503-
self.weight.value = csr_on_pre(
507+
self.weight.value = update_csr_on_binary_pre(
504508
self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max,
505509
shape=(spike.shape[0], trace.shape[0]),
506510
)
@@ -512,7 +516,7 @@ def stdp_update(
512516
)
513517
spike = on_post['spike']
514518
trace = on_post['trace']
515-
self.weight.value = csr2csc_on_post(
519+
self.weight.value = update_csr_on_binary_post(
516520
self.weight.value, self._pre_ids, self._post_indptr,
517521
self.w_indices, trace, spike, w_min, w_max,
518522
shape=(trace.shape[0], spike.shape[0]),

brainpy/initialize/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,9 @@
1312
# See the License for the specific language governing permissions and
1413
# limitations under the License.
1514
# ==============================================================================
15+
16+
# -*- coding: utf-8 -*-
17+
1618
import abc
1719

1820
__all__ = [

brainpy/math/compat_numpy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
'lcm', 'gcd', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan',
3737
'arctan2', 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
3838
'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round',
39-
'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'prod',
39+
'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'prod',
4040
'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum',
4141
'cumprod', 'cumsum', 'ediff1d', 'cross', 'isfinite', 'isinf',
4242
'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve',
@@ -397,7 +397,6 @@ def msort(a):
397397
floor = _compatible_with_brainpy_array(jnp.floor)
398398
ceil = _compatible_with_brainpy_array(jnp.ceil)
399399
trunc = _compatible_with_brainpy_array(jnp.trunc)
400-
fix = _compatible_with_brainpy_array(jnp.fix)
401400
prod = _compatible_with_brainpy_array(jnp.prod)
402401

403402
sum = _compatible_with_brainpy_array(jnp.sum)

brainpy/math/event/csr_matmat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def csrmm(
5959
if isinstance(matrix, Array):
6060
matrix = matrix.value
6161

62-
matrix = brainevent.EventArray(matrix)
62+
matrix = brainevent.BinaryArray(matrix)
6363
csr = brainevent.CSR((data, indices, indptr), shape=shape)
6464
if transpose:
6565
return matrix @ csr

brainpy/math/event/csr_matvec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def csrmv(
8484
if isinstance(events, Array):
8585
events = events.value
8686

87-
events = brainevent.EventArray(events)
87+
events = brainevent.BinaryArray(events)
8888
csr = brainevent.CSR((data, indices, indptr), shape=shape)
8989
if transpose:
9090
return events @ csr

brainpy/math/jitconn/event_matvec.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def event_mv_prob_homo(
4949
if isinstance(weight, Array):
5050
weight = weight.value
5151

52-
events = brainevent.EventArray(events)
53-
csr = brainevent.JITCHomoR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel)
52+
events = brainevent.BinaryArray(events)
53+
csr = brainevent.JITCScalarR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel)
5454
if transpose:
5555
return events @ csr
5656
else:
@@ -75,7 +75,7 @@ def event_mv_prob_uniform(
7575
seed = np.random.randint(0, 1000000000)
7676
if isinstance(events, Array):
7777
events = events.value
78-
events = brainevent.EventArray(events)
78+
events = brainevent.BinaryArray(events)
7979
if isinstance(w_low, Array):
8080
w_low = w_low.value
8181
if isinstance(w_high, Array):
@@ -106,7 +106,7 @@ def event_mv_prob_normal(
106106
seed = np.random.randint(0, 1000000000)
107107
if isinstance(events, Array):
108108
events = events.value
109-
events = brainevent.EventArray(events)
109+
events = brainevent.BinaryArray(events)
110110
if isinstance(w_mu, Array):
111111
w_mu = w_mu.value
112112
if isinstance(w_sigma, Array):

brainpy/math/jitconn/matvec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def mv_prob_homo(
9696
if isinstance(weight, Array):
9797
weight = weight.value
9898

99-
csr = brainevent.JITCHomoR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel)
99+
csr = brainevent.JITCScalarR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel)
100100
if transpose:
101101
return vector @ csr
102102
else:
@@ -290,7 +290,7 @@ def get_homo_weight_matrix(
290290
"""
291291
if seed is None:
292292
seed = np.random.randint(0, 1000000000)
293-
csr = brainevent.JITCHomoR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel)
293+
csr = brainevent.JITCScalarR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel)
294294
if transpose:
295295
csr = csr.T
296296
return csr.todense()

brainpy/state/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# `brainpy.state` - State-based Brain Dynamics Programming
1+
# `brainpy.state`
22

33
## Overview
44

@@ -37,7 +37,7 @@ pip install brainpy -U
3737
For development or to install the state module separately:
3838

3939
```bash
40-
pip install brainpy_state -U
40+
pip install brainpy.state -U
4141
```
4242

4343
## Usage

brainpy/state/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@
1616
from brainpy_state import *
1717
from brainpy_state import __all__
1818

19-
20-
19+
if __name__ == '__main__':
20+
print(LIF)
21+
print(__all__)

0 commit comments

Comments
 (0)