Skip to content

Commit 7895b44

Browse files
author
Dan Saunders
committed
Adding current-based LIF neurons and example Jupyter notebook.
1 parent 48f3c2b commit 7895b44

File tree

3 files changed

+342
-31
lines changed

3 files changed

+342
-31
lines changed

bindsnet/network/nodes.py

Lines changed: 104 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def step(self, inpts, dt):
4646
if self.traces:
4747
# Decay and set spike traces.
4848
self.x -= dt * self.trace_tc * self.x
49-
self.x = self.x.masked_fill(self.s, 1)
49+
self.x.masked_fill_(self.s, 1)
5050

5151
@abstractmethod
5252
def _reset(self):
@@ -179,11 +179,11 @@ def step(self, inpts, dt):
179179
self.refrac_count[self.refrac_count != 0] -= dt
180180

181181
# Check for spiking neurons.
182-
self.s = (self.v >= self.thresh) * (self.refrac_count == 0)
182+
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)
183183

184184
# Refractoriness and voltage reset.
185-
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
186-
self.v = self.v.masked_fill(self.s, self.reset)
185+
self.refrac_count.masked_fill_(self.s, self.refrac)
186+
self.v.masked_fill_(self.s, self.reset)
187187

188188
# Integrate input and decay voltages.
189189
self.v += inpts
@@ -222,14 +222,14 @@ def __init__(self, n=None, shape=None, traces=False, thresh=-52.0, rest=-65.0,
222222
'''
223223
super().__init__(n, shape, traces, trace_tc)
224224

225-
self.rest = rest # Rest voltage.
226-
self.reset = reset # Post-spike reset voltage.
227-
self.thresh = thresh # Spike threshold voltage.
228-
self.refrac = refrac # Post-spike refractory period.
229-
self.decay = decay # Rate of decay of neuron voltage.
225+
self.rest = rest # Rest voltage.
226+
self.reset = reset # Post-spike reset voltage.
227+
self.thresh = thresh # Spike threshold voltage.
228+
self.refrac = refrac # Post-spike refractory period.
229+
self.decay = decay # Rate of decay of neuron voltage.
230230

231-
self.v = torch.zeros(self.shape) + self.rest # Neuron voltages.
232-
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
231+
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
232+
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
233233

234234
def step(self, inpts, dt):
235235
'''
@@ -247,11 +247,11 @@ def step(self, inpts, dt):
247247
self.refrac_count[self.refrac_count != 0] -= dt
248248

249249
# Check for spiking neurons.
250-
self.s = (self.v >= self.thresh) * (self.refrac_count == 0)
250+
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)
251251

252252
# Refractoriness and voltage reset.
253-
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
254-
self.v = self.v.masked_fill(self.s, self.reset)
253+
self.refrac_count.masked_fill_(self.s, self.refrac)
254+
self.v.masked_fill_(self.s, self.reset)
255255

256256
# Integrate inputs.
257257
self.v += inpts
@@ -267,6 +267,80 @@ def _reset(self):
267267
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
268268

269269

270+
class CurrentLIFNodes(Nodes):
271+
'''
272+
Layer of current-based leaky integrate-and-fire (LIF) neurons.
273+
'''
274+
def __init__(self, n=None, shape=None, traces=False, thresh=-52.0, rest=-65.0,
275+
reset=-65.0, refrac=5, decay=1e-2, i_decay=2e-2, trace_tc=5e-2):
276+
'''
277+
Instantiates a layer of synaptic input current-based LIF neurons.
278+
279+
Inputs:
280+
281+
| :code:`n` (:code:`int`): The number of neurons in the layer.
282+
| :code:`shape` (:code:`iterable[int]`): The dimensionality of the layer.
283+
| :code:`traces` (:code:`bool`): Whether to record spike traces.
284+
| :code:`thresh` (:code:`float`): Spike threshold voltage.
285+
| :code:`rest` (:code:`float`): Resting membrane voltage.
286+
| :code:`reset` (:code:`float`): Post-spike reset voltage.
287+
| :code:`refrac` (:code:`int`): Refractory (non-firing) period of the neuron.
288+
| :code:`decay` (:code:`float`): Time constant of neuron voltage decay.
289+
| :code:`i_decay` (:code:`float`): Time constant of synaptic input current decay.
290+
| :code:`trace_tc` (:code:`float`): Time constant of spike trace decay.
291+
'''
292+
super().__init__(n, shape, traces, trace_tc)
293+
294+
self.rest = rest # Rest voltage.
295+
self.reset = reset # Post-spike reset voltage.
296+
self.thresh = thresh # Spike threshold voltage.
297+
self.refrac = refrac # Post-spike refractory period.
298+
self.decay = decay # Rate of decay of neuron voltage.
299+
self.i_decay = i_decay # Rate of decay of synaptic input current.
300+
301+
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
302+
self.i = torch.zeros(self.shape) # Synaptic input currents.
303+
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
304+
305+
def step(self, inpts, dt):
306+
'''
307+
Runs a single simulation step.
308+
309+
Inputs:
310+
311+
| :code:`inpts` (:code:`torch.Tensor`): Inputs to the layer.
312+
| :code:`dt` (:code:`float`): Simulation time step.
313+
'''
314+
# Decay voltages and current.
315+
self.v -= dt * self.decay * (self.v - self.rest)
316+
self.i -= dt * self.i_decay * self.i
317+
318+
# Decrement refrac counters.
319+
self.refrac_count[self.refrac_count != 0] -= dt
320+
321+
# Check for spiking neurons.
322+
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)
323+
324+
# Refractoriness and voltage reset.
325+
self.refrac_count.masked_fill_(self.s, self.refrac)
326+
self.v.masked_fill_(self.s, self.reset)
327+
328+
# Integrate inputs.
329+
self.i += inpts
330+
self.v += self.i
331+
332+
super().step(inpts, dt)
333+
334+
def _reset(self):
335+
'''
336+
Resets relevant state variables.
337+
'''
338+
super()._reset()
339+
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
340+
self.i = torch.zeros(self.shape) # Synaptic input currents.
341+
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
342+
343+
270344
class AdaptiveLIFNodes(Nodes):
271345
'''
272346
Layer of leaky integrate-and-fire (LIF) neurons with adaptive thresholds.
@@ -296,7 +370,7 @@ def __init__(self, n=None, shape=None, traces=False, rest=-65.0, reset=-65.0, th
296370
self.reset = reset # Post-spike reset voltage.
297371
self.thresh = thresh # Spike threshold voltage.
298372
self.refrac = refrac # Post-spike refractory period.
299-
self.decay = decay # Rate of decay of neuron voltage.
373+
self.decay = decay # Rate of decay of neuron voltage.
300374
self.theta_plus = theta_plus # Constant threshold increase on spike.
301375
self.theta_decay = theta_decay # Rate of decay of adaptive thresholds.
302376

@@ -322,11 +396,11 @@ def step(self, inpts, dt):
322396
self.refrac_count[self.refrac_count != 0] -= dt
323397

324398
# Check for spiking neurons.
325-
self.s = (self.v >= self.thresh + self.theta) * (self.refrac_count == 0)
399+
self.s = (self.v >= self.thresh + self.theta) & (self.refrac_count == 0)
326400

327401
# Refractoriness, voltage reset, and adaptive thresholds.
328-
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
329-
self.v = self.v.masked_fill(self.s, self.reset)
402+
self.refrac_count.masked_fill_(self.s, self.refrac)
403+
self.v.masked_fill_(self.s, self.reset)
330404
self.theta += self.theta_plus * self.s.float()
331405

332406
# Integrate inputs.
@@ -372,7 +446,7 @@ def __init__(self, n=None, shape=None, traces=False, rest=-65.0, reset=-65.0, th
372446
self.reset = reset # Post-spike reset voltage.
373447
self.thresh = thresh # Spike threshold voltage.
374448
self.refrac = refrac # Post-spike refractory period.
375-
self.decay = decay # Rate of decay of neuron voltage.
449+
self.decay = decay # Rate of decay of neuron voltage.
376450
self.theta_plus = theta_plus # Constant threshold increase on spike.
377451
self.theta_decay = theta_decay # Rate of decay of adaptive thresholds.
378452

@@ -397,11 +471,11 @@ def step(self, inpts, dt):
397471
self.refrac_count[self.refrac_count != 0] -= dt
398472

399473
# Check for spiking neurons.
400-
self.s = (self.v >= self.thresh + self.theta) * (self.refrac_count == 0)
474+
self.s = (self.v >= self.thresh + self.theta) & (self.refrac_count == 0)
401475

402476
# Refractoriness, voltage reset, and adaptive thresholds.
403-
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
404-
self.v = self.v.masked_fill(self.s, self.reset)
477+
self.refrac_count.masked_fill_(self.s, self.refrac)
478+
self.v.masked_fill_(self.s, self.reset)
405479
self.theta += self.theta_plus * self.s.float()
406480

407481
# Choose only a single neuron to spike.
@@ -449,11 +523,11 @@ def __init__(self, n=None, shape=None, traces=False, excitatory=True, rest=-65.0
449523
'''
450524
super().__init__(n, shape, traces, trace_tc)
451525

452-
self.rest = rest # Rest voltage.
453-
self.reset = reset # Post-spike reset voltage.
454-
self.thresh = thresh # Spike threshold voltage.
455-
self.refrac = refrac # Post-spike refractory period.
456-
self.decay = decay # Rate of decay of neuron voltage.
526+
self.rest = rest # Rest voltage.
527+
self.reset = reset # Post-spike reset voltage.
528+
self.thresh = thresh # Spike threshold voltage.
529+
self.refrac = refrac # Post-spike refractory period.
530+
self.decay = decay # Rate of decay of neuron voltage.
457531

458532
if excitatory:
459533
self.r = torch.rand(n)
@@ -485,11 +559,11 @@ def step(self, inpts, dt):
485559
self.refrac_count[self.refrac_count != 0] -= dt
486560

487561
# Check for spiking neurons.
488-
self.s = (self.v >= self.thresh) * (self.refrac_count == 0)
562+
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)
489563

490564
# Refractoriness and voltage reset.
491-
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
492-
self.v = self.v.masked_fill(self.s, self.reset)
565+
self.refrac_count.masked_fill_(self.s, self.refrac)
566+
self.v.masked_fill_(self.s, self.reset)
493567

494568
# Apply v and u updates.
495569
self.v += dt * (0.04 * (self.v ** 2) + 5 * self.v + 140 - self.u + inpts)

bindsnet/network/topology.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ def __init__(self, source, target, nu=1e-2, nu_pre=1e-4, nu_post=1e-2, **kwargs)
120120
self.w = kwargs.get('w', None)
121121

122122
if self.w is None:
123-
self.w = self.wmin + torch.rand(*source.shape, *target.shape) * (self.wmin - self.wmin)
123+
if self.wmin == -np.inf or self.wmax == np.inf:
124+
self.w = torch.rand(*source.shape, *target.shape)
125+
else:
126+
self.w = self.wmin + torch.rand(*source.shape, *target.shape) * (self.wmin - self.wmin)
124127
else:
125128
if torch.max(self.w) > self.wmax or torch.min(self.w) < self.wmin:
126129
warnings.warn('Weight matrix will be clamped between [%f, %f]; values may be biased to interval values.' % (self.wmin, self.wmax))

examples/notebooks/LIFNodes vs. CurrentLIFNodes.ipynb

Lines changed: 234 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)