Skip to content

Commit a1526ee

Browse files
author
Dan Saunders
committed
Using faster indexing scheme (torch.Tensor.masked_fill); updating README.md.
1 parent 6c1433c commit a1526ee

2 files changed

Lines changed: 20 additions & 14 deletions

File tree

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# BindsNET
22

3-
A Python package used for simulating spiking neural networks (SNNs) using [PyTorch](http://pytorch.org/) GPU functionality.
3+
A Python package used for simulating spiking neural networks (SNNs) on CPUs or GPUs using [PyTorch](http://pytorch.org/) `Tensor` functionality.
44

5-
BindsNET is a spiking neural network simulation software, a critical component enabling the modeling of neural systems and the development of biologically inspired algorithms in the machine learning domain.
5+
BindsNET is a spiking neural network simulation library geared towards the development of biologically inspired algorithms for machine learning.
66

77
This package is used as part of ongoing research on applying SNNs to machine learning (ML) and reinforcement learning (RL) problems in the [Biologically Inspired Neural & Dynamical Systems (BINDS) lab](http://binds.cs.umass.edu/).
88

@@ -12,6 +12,12 @@ This package is used as part of ongoing research on applying SNNs to machine lea
1212
## Requirements
1313

1414
- Python 3.6
15+
- `torch`
16+
- `numpy`
17+
- `matplotlib`
18+
- `scikit_image`
19+
- `opencv-python`
20+
- `gym` (optional)
1521

1622
## Setting things up
1723

bindsnet/network/nodes.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def step(self, inpts, dt):
182182
self.s = (self.v >= self.thresh) * (self.refrac_count == 0)
183183

184184
# Refractoriness and voltage reset.
185-
self.refrac_count[self.s] = self.refrac
186-
self.v[self.s] = self.reset
185+
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
186+
self.v = self.v.masked_fill(self.s, self.reset)
187187

188188
# Integrate input and decay voltages.
189189
self.v += inpts
@@ -250,8 +250,8 @@ def step(self, inpts, dt):
250250
self.s = (self.v >= self.thresh) * (self.refrac_count == 0)
251251

252252
# Refractoriness and voltage reset.
253-
self.refrac_count[self.s] = self.refrac
254-
self.v[self.s] = self.reset
253+
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
254+
self.v = self.v.masked_fill(self.s, self.reset)
255255

256256
# Integrate inputs.
257257
self.v += inpts
@@ -263,8 +263,8 @@ def _reset(self):
263263
Resets relevant state variables.
264264
'''
265265
super()._reset()
266-
self.v = torch.zeros(self.v.size()) + self.rest # Neuron voltages.
267-
self.refrac_count[self.refrac_count != 0] = 0 # Refractory period counters.
266+
self.v = torch.zeros(self.v.size()) + self.rest # Neuron voltages.
267+
self.refrac_count = torch.zeros(self.v.size()) # Refractory period counters.
268268

269269

270270
class AdaptiveLIFNodes(Nodes):
@@ -325,8 +325,8 @@ def step(self, inpts, dt):
325325
self.s = (self.v >= self.thresh + self.theta) * (self.refrac_count == 0)
326326

327327
# Refractoriness, voltage reset, and adaptive thresholds.
328-
self.refrac_count[self.s] = self.refrac
329-
self.v[self.s] = self.reset
328+
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
329+
self.v = self.v.masked_fill(self.s, self.reset)
330330
self.theta += self.theta_plus * self.s.float()
331331

332332
# Integrate inputs.
@@ -400,8 +400,8 @@ def step(self, inpts, dt):
400400
self.s = (self.v >= self.thresh + self.theta) * (self.refrac_count == 0)
401401

402402
# Refractoriness, voltage reset, and adaptive thresholds.
403-
self.refrac_count[self.s] = self.refrac
404-
self.v[self.s] = self.reset
403+
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
404+
self.v = self.v.masked_fill(self.s, self.reset)
405405
self.theta += self.theta_plus * self.s.float()
406406

407407
# Choose only a single neuron to spike.
@@ -488,8 +488,8 @@ def step(self, inpts, dt):
488488
self.s = (self.v >= self.thresh) * (self.refrac_count == 0)
489489

490490
# Refractoriness and voltage reset.
491-
self.refrac_count[self.s] = self.refrac
492-
self.v[self.s] = self.reset
491+
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
492+
self.v = self.v.masked_fill(self.s, self.reset)
493493

494494
# Apply v and u updates.
495495
self.v += dt * (0.04 * (self.v ** 2) + 5 * self.v + 140 - self.u + inpts)

0 commit comments

Comments
 (0)