Skip to content

Commit a39f1f8

Browse files
author
Dan Saunders
committed
Making Nodes code more self-consistent.
1 parent 4f28c39 commit a39f1f8

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

bindsnet/network/nodes.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,17 @@ def step(self, inpts, dt):
4646
if self.traces:
4747
# Decay and set spike traces.
4848
self.x -= dt * self.trace_tc * self.x
49-
self.x[self.s] = 1
49+
self.x = self.x.masked_fill(self.s, 1)
5050

5151
@abstractmethod
5252
def _reset(self):
5353
'''
5454
Abstract base class method for resetting state variables.
5555
'''
56-
self.s[self.s != 0] = 0 # Spike occurences.
56+
self.s = torch.zeros(self.shape).byte() # Spike occurences.
5757

5858
if self.traces:
59-
self.x = torch.zeros(self.shape) # Firing traces.
59+
self.x = torch.zeros(self.shape) # Firing traces.
6060

6161

6262
class Input(Nodes):
@@ -195,8 +195,8 @@ def _reset(self):
195195
Resets relevant state variables.
196196
'''
197197
super()._reset()
198-
self.v = torch.zeros(self.v.size()) + self.rest # Neuron voltages.
199-
self.refrac_count[self.refrac_count != 0] = 0 # Refractory period counters.
198+
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
199+
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
200200

201201

202202
class LIFNodes(Nodes):
@@ -263,8 +263,8 @@ def _reset(self):
263263
Resets relevant state variables.
264264
'''
265265
super()._reset()
266-
self.v = torch.zeros(self.v.size()) + self.rest # Neuron voltages.
267-
self.refrac_count = torch.zeros(self.v.size()) # Refractory period counters.
266+
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
267+
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
268268

269269

270270
class AdaptiveLIFNodes(Nodes):
@@ -339,8 +339,8 @@ def _reset(self):
339339
Resets relevant state variables.
340340
'''
341341
super()._reset()
342-
self.v = torch.zeros(self.v.size()) + self.rest # Neuron voltages.
343-
self.refrac_count[self.refrac_count != 0] = 0 # Refractory period counters.
342+
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
343+
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
344344

345345

346346
class DiehlAndCookNodes(Nodes):
@@ -421,8 +421,8 @@ def _reset(self):
421421
Resets relevant state variables.
422422
'''
423423
super()._reset()
424-
self.v = torch.zeros(self.v.size()) + self.rest # Neuron voltages.
425-
self.refrac_count[self.refrac_count != 0] = 0 # Refractory period counters.
424+
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
425+
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
426426

427427

428428
class IzhikevichNodes(Nodes):
@@ -502,6 +502,6 @@ def _reset(self):
502502
Resets relevant state variables.
503503
'''
504504
super()._reset()
505-
self.v = self.rest * torch.ones(self.n) # Neuron voltages.
506-
self.u = self.b * self.v # Neuron recovery.
507-
self.refrac_count = torch.zeros(self.n) # Refractory period counters.
505+
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
506+
self.u = self.b * self.v # Neuron recovery.
507+
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

0 commit comments

Comments
 (0)