@@ -46,7 +46,7 @@ def step(self, inpts, dt):
4646 if self .traces :
4747 # Decay and set spike traces.
4848 self .x -= dt * self .trace_tc * self .x
49- self .x = self . x . masked_fill (self .s , 1 )
49+ self .x . masked_fill_ (self .s , 1 )
5050
5151 @abstractmethod
5252 def _reset (self ):
@@ -179,11 +179,11 @@ def step(self, inpts, dt):
179179 self .refrac_count [self .refrac_count != 0 ] -= dt
180180
181181 # Check for spiking neurons.
182- self .s = (self .v >= self .thresh ) * (self .refrac_count == 0 )
182+ self .s = (self .v >= self .thresh ) & (self .refrac_count == 0 )
183183
184184 # Refractoriness and voltage reset.
185- self .refrac_count = self . refrac_count . masked_fill (self .s , self .refrac )
186- self .v = self . v . masked_fill (self .s , self .reset )
185+ self .refrac_count . masked_fill_ (self .s , self .refrac )
186+ self .v . masked_fill_ (self .s , self .reset )
187187
188188 # Integrate input and decay voltages.
189189 self .v += inpts
@@ -222,14 +222,14 @@ def __init__(self, n=None, shape=None, traces=False, thresh=-52.0, rest=-65.0,
222222 '''
223223 super ().__init__ (n , shape , traces , trace_tc )
224224
225- self .rest = rest # Rest voltage.
226- self .reset = reset # Post-spike reset voltage.
227- self .thresh = thresh # Spike threshold voltage.
228- self .refrac = refrac # Post-spike refractory period.
229- self .decay = decay # Rate of decay of neuron voltage.
225+ self .rest = rest # Rest voltage.
226+ self .reset = reset # Post-spike reset voltage.
227+ self .thresh = thresh # Spike threshold voltage.
228+ self .refrac = refrac # Post-spike refractory period.
229+ self .decay = decay # Rate of decay of neuron voltage.
230230
231- self .v = torch .zeros (self .shape ) + self . rest # Neuron voltages.
232- self .refrac_count = torch .zeros (self .shape ) # Refractory period counters.
231+ self .v = self . rest * torch .ones (self .shape ) # Neuron voltages.
232+ self .refrac_count = torch .zeros (self .shape ) # Refractory period counters.
233233
234234 def step (self , inpts , dt ):
235235 '''
@@ -247,11 +247,11 @@ def step(self, inpts, dt):
247247 self .refrac_count [self .refrac_count != 0 ] -= dt
248248
249249 # Check for spiking neurons.
250- self .s = (self .v >= self .thresh ) * (self .refrac_count == 0 )
250+ self .s = (self .v >= self .thresh ) & (self .refrac_count == 0 )
251251
252252 # Refractoriness and voltage reset.
253- self .refrac_count = self . refrac_count . masked_fill (self .s , self .refrac )
254- self .v = self . v . masked_fill (self .s , self .reset )
253+ self .refrac_count . masked_fill_ (self .s , self .refrac )
254+ self .v . masked_fill_ (self .s , self .reset )
255255
256256 # Integrate inputs.
257257 self .v += inpts
@@ -267,6 +267,80 @@ def _reset(self):
267267 self .refrac_count = torch .zeros (self .shape ) # Refractory period counters.
268268
269269
270+ class CurrentLIFNodes (Nodes ):
271+ '''
272+ Layer of current-based leaky integrate-and-fire (LIF) neurons.
273+ '''
274+ def __init__ (self , n = None , shape = None , traces = False , thresh = - 52.0 , rest = - 65.0 ,
275+ reset = - 65.0 , refrac = 5 , decay = 1e-2 , i_decay = 2e-2 , trace_tc = 5e-2 ):
276+ '''
277+ Instantiates a layer of synaptic input current-based LIF neurons.
278+
279+ Inputs:
280+
281+ | :code:`n` (:code:`int`): The number of neurons in the layer.
282+ | :code:`shape` (:code:`iterable[int]`): The dimensionality of the layer.
283+ | :code:`traces` (:code:`bool`): Whether to record spike traces.
284+ | :code:`thresh` (:code:`float`): Spike threshold voltage.
285+ | :code:`rest` (:code:`float`): Resting membrane voltage.
286+ | :code:`reset` (:code:`float`): Post-spike reset voltage.
287+ | :code:`refrac` (:code:`int`): Refractory (non-firing) period of the neuron.
288+ | :code:`decay` (:code:`float`): Time constant of neuron voltage decay.
289+ | :code:`i_decay` (:code:`float`): Time constant of synaptic input current decay.
290+ | :code:`trace_tc` (:code:`float`): Time constant of spike trace decay.
291+ '''
292+ super ().__init__ (n , shape , traces , trace_tc )
293+
294+ self .rest = rest # Rest voltage.
295+ self .reset = reset # Post-spike reset voltage.
296+ self .thresh = thresh # Spike threshold voltage.
297+ self .refrac = refrac # Post-spike refractory period.
298+ self .decay = decay # Rate of decay of neuron voltage.
299+ self .i_decay = i_decay # Rate of decay of synaptic input current.
300+
301+ self .v = self .rest * torch .ones (self .shape ) # Neuron voltages.
302+ self .i = torch .zeros (self .shape ) # Synaptic input currents.
303+ self .refrac_count = torch .zeros (self .shape ) # Refractory period counters.
304+
305+ def step (self , inpts , dt ):
306+ '''
307+ Runs a single simulation step.
308+
309+ Inputs:
310+
311+ | :code:`inpts` (:code:`torch.Tensor`): Inputs to the layer.
312+ | :code:`dt` (:code:`float`): Simulation time step.
313+ '''
314+ # Decay voltages and current.
315+ self .v -= dt * self .decay * (self .v - self .rest )
316+ self .i -= dt * self .i_decay * self .i
317+
318+ # Decrement refrac counters.
319+ self .refrac_count [self .refrac_count != 0 ] -= dt
320+
321+ # Check for spiking neurons.
322+ self .s = (self .v >= self .thresh ) & (self .refrac_count == 0 )
323+
324+ # Refractoriness and voltage reset.
325+ self .refrac_count .masked_fill_ (self .s , self .refrac )
326+ self .v .masked_fill_ (self .s , self .reset )
327+
328+ # Integrate inputs.
329+ self .i += inpts
330+ self .v += self .i
331+
332+ super ().step (inpts , dt )
333+
334+ def _reset (self ):
335+ '''
336+ Resets relevant state variables.
337+ '''
338+ super ()._reset ()
339+ self .v = self .rest * torch .ones (self .shape ) # Neuron voltages.
340+ self .i = torch .zeros (self .shape ) # Synaptic input currents.
341+ self .refrac_count = torch .zeros (self .shape ) # Refractory period counters.
342+
343+
270344class AdaptiveLIFNodes (Nodes ):
271345 '''
272346 Layer of leaky integrate-and-fire (LIF) neurons with adaptive thresholds.
@@ -296,7 +370,7 @@ def __init__(self, n=None, shape=None, traces=False, rest=-65.0, reset=-65.0, th
296370 self .reset = reset # Post-spike reset voltage.
297371 self .thresh = thresh # Spike threshold voltage.
298372 self .refrac = refrac # Post-spike refractory period.
299- self .decay = decay # Rate of decay of neuron voltage.
373+ self .decay = decay # Rate of decay of neuron voltage.
300374 self .theta_plus = theta_plus # Constant threshold increase on spike.
301375 self .theta_decay = theta_decay # Rate of decay of adaptive thresholds.
302376
@@ -322,11 +396,11 @@ def step(self, inpts, dt):
322396 self .refrac_count [self .refrac_count != 0 ] -= dt
323397
324398 # Check for spiking neurons.
325- self .s = (self .v >= self .thresh + self .theta ) * (self .refrac_count == 0 )
399+ self .s = (self .v >= self .thresh + self .theta ) & (self .refrac_count == 0 )
326400
327401 # Refractoriness, voltage reset, and adaptive thresholds.
328- self .refrac_count = self . refrac_count . masked_fill (self .s , self .refrac )
329- self .v = self . v . masked_fill (self .s , self .reset )
402+ self .refrac_count . masked_fill_ (self .s , self .refrac )
403+ self .v . masked_fill_ (self .s , self .reset )
330404 self .theta += self .theta_plus * self .s .float ()
331405
332406 # Integrate inputs.
@@ -372,7 +446,7 @@ def __init__(self, n=None, shape=None, traces=False, rest=-65.0, reset=-65.0, th
372446 self .reset = reset # Post-spike reset voltage.
373447 self .thresh = thresh # Spike threshold voltage.
374448 self .refrac = refrac # Post-spike refractory period.
375- self .decay = decay # Rate of decay of neuron voltage.
449+ self .decay = decay # Rate of decay of neuron voltage.
376450 self .theta_plus = theta_plus # Constant threshold increase on spike.
377451 self .theta_decay = theta_decay # Rate of decay of adaptive thresholds.
378452
@@ -397,11 +471,11 @@ def step(self, inpts, dt):
397471 self .refrac_count [self .refrac_count != 0 ] -= dt
398472
399473 # Check for spiking neurons.
400- self .s = (self .v >= self .thresh + self .theta ) * (self .refrac_count == 0 )
474+ self .s = (self .v >= self .thresh + self .theta ) & (self .refrac_count == 0 )
401475
402476 # Refractoriness, voltage reset, and adaptive thresholds.
403- self .refrac_count = self . refrac_count . masked_fill (self .s , self .refrac )
404- self .v = self . v . masked_fill (self .s , self .reset )
477+ self .refrac_count . masked_fill_ (self .s , self .refrac )
478+ self .v . masked_fill_ (self .s , self .reset )
405479 self .theta += self .theta_plus * self .s .float ()
406480
407481 # Choose only a single neuron to spike.
@@ -449,11 +523,11 @@ def __init__(self, n=None, shape=None, traces=False, excitatory=True, rest=-65.0
449523 '''
450524 super ().__init__ (n , shape , traces , trace_tc )
451525
452- self .rest = rest # Rest voltage.
453- self .reset = reset # Post-spike reset voltage.
454- self .thresh = thresh # Spike threshold voltage.
455- self .refrac = refrac # Post-spike refractory period.
456- self .decay = decay # Rate of decay of neuron voltage.
526+ self .rest = rest # Rest voltage.
527+ self .reset = reset # Post-spike reset voltage.
528+ self .thresh = thresh # Spike threshold voltage.
529+ self .refrac = refrac # Post-spike refractory period.
530+ self .decay = decay # Rate of decay of neuron voltage.
457531
458532 if excitatory :
459533 self .r = torch .rand (n )
@@ -485,11 +559,11 @@ def step(self, inpts, dt):
485559 self .refrac_count [self .refrac_count != 0 ] -= dt
486560
487561 # Check for spiking neurons.
488- self .s = (self .v >= self .thresh ) * (self .refrac_count == 0 )
562+ self .s = (self .v >= self .thresh ) & (self .refrac_count == 0 )
489563
490564 # Refractoriness and voltage reset.
491- self .refrac_count = self . refrac_count . masked_fill (self .s , self .refrac )
492- self .v = self . v . masked_fill (self .s , self .reset )
565+ self .refrac_count . masked_fill_ (self .s , self .refrac )
566+ self .v . masked_fill_ (self .s , self .reset )
493567
494568 # Apply v and u updates.
495569 self .v += dt * (0.04 * (self .v ** 2 ) + 5 * self .v + 140 - self .u + inpts )
0 commit comments