Skip to content

Commit f917281

Browse files
authored
Fix neuron, training, and ODE bugs and tighten dependencies (#818)
* chore: update version to 2.7.7, enhance README and requirements, and refactor event handling
* fix: update required Python version to 3.11
* fix: update synaptic variable updates and error handling in multiple files
1 parent 64e31f2 commit f917281

File tree

14 files changed

+49
-39
lines changed

14 files changed

+49
-39
lines changed

brainpy/dnn/conv.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,7 @@ def update(self, x):
179179
if self.mask is not None:
180180
try:
181181
lax.broadcast_shapes(self.w.shape, self.mask.shape)
182-
except:
182+
except (ValueError, TypeError):
183183
raise ValueError(f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}")
184184
w = w * self.mask
185185
y = lax.conv_general_dilated(lhs=bm.as_jax(x),
@@ -566,7 +566,7 @@ def update(self, x):
566566
if self.mask is not None:
567567
try:
568568
lax.broadcast_shapes(self.w.shape, self.mask.shape)
569-
except:
569+
except (ValueError, TypeError):
570570
raise ValueError(f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}")
571571
w = w * self.mask
572572
y = lax.conv_transpose(lhs=bm.as_jax(x),

brainpy/dnn/interoperation_flax.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
try:
2727
import flax # noqa
2828
from flax.linen.recurrent import RNNCellBase
29-
except:
29+
except (ImportError, ModuleNotFoundError):
3030
flax = None
3131
RNNCellBase = object
3232

brainpy/dyn/neurons/lif.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -299,7 +299,7 @@ def update(self, x=None):
299299
elif self.spk_reset == 'hard':
300300
V += (self.V_reset - V) * spike
301301
else:
302-
raise ValueError
302+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
303303

304304
else:
305305
spike = V >= self.V_th
@@ -509,7 +509,7 @@ def update(self, x=None):
509509
elif self.spk_reset == 'hard':
510510
V += (self.V_reset - V) * spike_no_grad
511511
else:
512-
raise ValueError
512+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
513513
spike_ = spike_no_grad > 0.
514514
# will be used in other place, like Delta Synapse, so stop its gradient
515515
if self.ref_var:
@@ -785,7 +785,7 @@ def update(self, x=None):
785785
elif self.spk_reset == 'hard':
786786
V += (self.V_reset - V) * spike
787787
else:
788-
raise ValueError
788+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
789789

790790
else:
791791
spike = V >= self.V_th
@@ -1142,7 +1142,7 @@ def update(self, x=None):
11421142
elif self.spk_reset == 'hard':
11431143
V += (self.V_reset - V) * spike_no_grad
11441144
else:
1145-
raise ValueError
1145+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
11461146
spike_ = spike_no_grad > 0.
11471147
# will be used in other place, like Delta Synapse, so stop its gradient
11481148
if self.ref_var:
@@ -1497,7 +1497,7 @@ def update(self, x=None):
14971497
elif self.spk_reset == 'hard':
14981498
V += (self.V_reset - V) * spike
14991499
else:
1500-
raise ValueError
1500+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
15011501
w += self.b * spike
15021502

15031503
else:
@@ -1843,7 +1843,7 @@ def update(self, x=None):
18431843
elif self.spk_reset == 'hard':
18441844
V += (self.V_reset - V) * spike_no_grad
18451845
else:
1846-
raise ValueError
1846+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
18471847
w += self.b * spike_no_grad
18481848
spike_ = spike_no_grad > 0.
18491849
# will be used in other place, like Delta Synapse, so stop its gradient
@@ -2142,7 +2142,7 @@ def update(self, x=None):
21422142
elif self.spk_reset == 'hard':
21432143
V += (self.V_reset - V) * spike
21442144
else:
2145-
raise ValueError
2145+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
21462146

21472147
else:
21482148
spike = V >= self.V_th
@@ -2431,7 +2431,7 @@ def update(self, x=None):
24312431
elif self.spk_reset == 'hard':
24322432
V += (self.V_reset - V) * spike_no_grad
24332433
else:
2434-
raise ValueError
2434+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
24352435
spike_ = spike_no_grad > 0.
24362436
# will be used in other place, like Delta Synapse, so stop its gradient
24372437
if self.ref_var:
@@ -2734,7 +2734,7 @@ def update(self, x=None):
27342734
elif self.spk_reset == 'hard':
27352735
V += (self.V_reset - V) * spike
27362736
else:
2737-
raise ValueError
2737+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
27382738
w += self.b * spike
27392739

27402740
else:
@@ -3054,7 +3054,7 @@ def update(self, x=None):
30543054
elif self.spk_reset == 'hard':
30553055
V += (self.V_reset - V) * spike_no_grad
30563056
else:
3057-
raise ValueError
3057+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
30583058
w += self.b * spike_no_grad
30593059
spike_ = spike_no_grad > 0.
30603060
# will be used in other place, like Delta Synapse, so stop its gradient
@@ -3417,7 +3417,7 @@ def update(self, x=None):
34173417
elif self.spk_reset == 'hard':
34183418
V += (self.V_reset - V) * spike
34193419
else:
3420-
raise ValueError
3420+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
34213421
I1 += spike * (self.R1 * I1 + self.A1 - I1)
34223422
I2 += spike * (self.R2 * I2 + self.A2 - I2)
34233423
V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike
@@ -3810,7 +3810,7 @@ def update(self, x=None):
38103810
elif self.spk_reset == 'hard':
38113811
V += (self.V_reset - V) * spike_no_grad
38123812
else:
3813-
raise ValueError
3813+
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
38143814
I1 += spike * (self.R1 * I1 + self.A1 - I1)
38153815
I2 += spike * (self.R2 * I2 + self.A2 - I2)
38163816
V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike_no_grad

brainpy/dyn/synapses/abstract_models.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -285,7 +285,7 @@ def update(self, x):
285285

286286
# update synaptic variables
287287
self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
288-
self.h += self.a * x
288+
self.h.value = self.h.value + self.a * x
289289
return self.g.value
290290

291291
def return_info(self):
@@ -552,7 +552,7 @@ def dg(self, g, t, h):
552552
def update(self, x):
553553
# update synaptic variables
554554
self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
555-
self.h += x
555+
self.h.value = self.h.value + x
556556
return self.g.value
557557

558558
def return_info(self):
@@ -737,7 +737,7 @@ def update(self, pre_spike):
737737
t = share.load('t')
738738
dt = share.load('dt')
739739
self.g.value, self.x.value = self.integral(self.g.value, self.x.value, t, dt=dt)
740-
self.x += pre_spike
740+
self.x.value = self.x.value + pre_spike
741741
return self.g.value
742742

743743
def return_info(self):

brainpy/dynsys.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -966,11 +966,18 @@ def _slice_to_num(slice_: slice, length: int):
966966
step = slice_.step
967967
if step is None:
968968
step = 1
969+
if step == 0:
970+
raise ValueError("slice step cannot be zero")
969971
# number
970972
num = 0
971-
while start < stop:
972-
start += step
973-
num += 1
973+
if step > 0:
974+
while start < stop:
975+
start += step
976+
num += 1
977+
else:
978+
while start > stop:
979+
start += step
980+
num += 1
974981
return num
975982

976983

brainpy/helpers.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,9 @@ def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs):
116116
missing_keys = []
117117
unexpected_keys = []
118118
for name, node in nodes.items():
119+
if name not in state_dict:
120+
missing_keys.append(name)
121+
continue
119122
r = node.load_state(state_dict[name], **kwargs)
120123
if r is not None:
121124
missing, unexpected = r

brainpy/integrators/ode/adaptive_rk.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,7 @@
6767

6868
import jax.numpy as jnp
6969

70+
from brainpy import _errors as errors
7071
from brainpy.integrators import constants as C, utils
7172
from brainpy.integrators.ode import common
7273
from brainpy.integrators.ode.base import ODEIntegrator
@@ -456,7 +457,7 @@ class BogackiShampine(AdaptiveRKIntegrator):
456457
A = [(),
457458
(0.5,),
458459
(0., 0.75),
459-
('2/9', '1/3', '4/0'), ]
460+
('2/9', '1/3', '4/9'), ]
460461
B1 = ['2/9', '1/3', '4/9', 0]
461462
B2 = ['7/24', 0.25, '1/3', 0.125]
462463
C = [0, 0.5, 0.75, 1]

brainpy/integrators/ode/generic.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -134,7 +134,7 @@ def set_default_odeint(method):
134134
raise ValueError(f'Unsupported ODE_INT numerical method: {method}.')
135135

136136
global _DEFAULT_DDE_METHOD
137-
_DEFAULT_ODE_METHOD = method
137+
_DEFAULT_DDE_METHOD = method
138138

139139

140140
def get_default_odeint():

brainpy/math/ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -272,7 +272,7 @@ def value(self):
272272

273273
@value.setter
274274
def value(self, value):
275-
self_value = self._check_tracer()
275+
self_value = self._value
276276

277277
if isinstance(value, Array):
278278
value = value.value

brainpy/math/object_transform/base.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -773,16 +773,16 @@ def update(self, *args, **kwargs) -> 'NodeDict':
773773
self[k] = v
774774
elif isinstance(arg, tuple):
775775
assert len(arg) == 2
776-
self[arg[0]] = args[1]
776+
self[arg[0]] = arg[1]
777777
for k, v in kwargs.items():
778778
self[k] = v
779779
return self
780780

781781
def __setitem__(self, key, value) -> 'NodeDict':
782782
if self.check_unique:
783783
exist = self.get(key, None)
784-
if id(exist) != id(value):
785-
raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {value}.')
784+
if exist is not None and id(exist) != id(value):
785+
raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {exist}.')
786786
super().__setitem__(key, value)
787787
return self
788788

0 commit comments

Comments (0)