@@ -46,17 +46,17 @@ def step(self, inpts, dt):
4646 if self .traces :
4747 # Decay and set spike traces.
4848 self .x -= dt * self .trace_tc * self .x
49- self .x [ self .s ] = 1
49+ self .x = self .x . masked_fill ( self . s , 1 )
5050
5151 @abstractmethod
5252 def _reset (self ):
5353 '''
5454 Abstract base class method for resetting state variables.
5555 '''
56- self .s [ self . s != 0 ] = 0 # Spike occurences.
56+ self .s = torch . zeros ( self . shape ). byte () # Spike occurences.
5757
5858 if self .traces :
59- self .x = torch .zeros (self .shape ) # Firing traces.
59+ self .x = torch .zeros (self .shape ) # Firing traces.
6060
6161
6262class Input (Nodes ):
@@ -195,8 +195,8 @@ def _reset(self):
195195 Resets relevant state variables.
196196 '''
197197 super ()._reset ()
198- self .v = torch . zeros ( self .v . size ()) + self .rest # Neuron voltages.
199- self .refrac_count [ self . refrac_count != 0 ] = 0 # Refractory period counters.
198+ self .v = self .rest * torch . ones ( self .shape ) # Neuron voltages.
199+ self .refrac_count = torch . zeros ( self . shape ) # Refractory period counters.
200200
201201
202202class LIFNodes (Nodes ):
@@ -263,8 +263,8 @@ def _reset(self):
263263 Resets relevant state variables.
264264 '''
265265 super ()._reset ()
266- self .v = torch .zeros (self .v . size ()) + self . rest # Neuron voltages.
267- self .refrac_count = torch .zeros (self .v . size () ) # Refractory period counters.
266+ self .v = self . rest * torch .ones (self .shape ) # Neuron voltages.
267+ self .refrac_count = torch .zeros (self .shape ) # Refractory period counters.
268268
269269
270270class AdaptiveLIFNodes (Nodes ):
@@ -339,8 +339,8 @@ def _reset(self):
339339 Resets relevant state variables.
340340 '''
341341 super ()._reset ()
342- self .v = torch . zeros ( self .v . size ()) + self .rest # Neuron voltages.
343- self .refrac_count [ self . refrac_count != 0 ] = 0 # Refractory period counters.
342+ self .v = self .rest * torch . ones ( self .shape ) # Neuron voltages.
343+ self .refrac_count = torch . zeros ( self . shape ) # Refractory period counters.
344344
345345
346346class DiehlAndCookNodes (Nodes ):
@@ -421,8 +421,8 @@ def _reset(self):
421421 Resets relevant state variables.
422422 '''
423423 super ()._reset ()
424- self .v = torch . zeros ( self .v . size ()) + self .rest # Neuron voltages.
425- self .refrac_count [ self . refrac_count != 0 ] = 0 # Refractory period counters.
424+ self .v = self .rest * torch . ones ( self .shape ) # Neuron voltages.
425+ self .refrac_count = torch . zeros ( self . shape ) # Refractory period counters.
426426
427427
428428class IzhikevichNodes (Nodes ):
@@ -502,6 +502,6 @@ def _reset(self):
502502 Resets relevant state variables.
503503 '''
504504 super ()._reset ()
505- self .v = self .rest * torch .ones (self .n ) # Neuron voltages.
506- self .u = self .b * self .v # Neuron recovery.
507- self .refrac_count = torch .zeros (self .n ) # Refractory period counters.
505+ self .v = self .rest * torch .ones (self .shape ) # Neuron voltages.
506+ self .u = self .b * self .v # Neuron recovery.
507+ self .refrac_count = torch .zeros (self .shape ) # Refractory period counters.
0 commit comments