Skip to content

Commit e48cbdd

Browse files
author
Dan Saunders
committed
Major refactoring to more closely conform to Python style best practices.
1 parent 7895b44 commit e48cbdd

File tree

13 files changed

+482
-269
lines changed

13 files changed

+482
-269
lines changed

bindsnet/analysis/plotting.py

Lines changed: 105 additions & 47 deletions
Large diffs are not rendered by default.

bindsnet/analysis/visualization.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@ def plot_weights_movie(ws, sample_every=1):
1414
1515
Inputs:
1616
17-
| :code:`ws` (:code:`numpy.array`): Numpy array of shape :code:`[N_examples, source, target, time]`
18-
| :code:`sample_every` (:code:`int`): Sub-sample using this parameter. For example if :code:`time` is
19-
too large (500), set this parameter to 20 to sample weights
20-
every 20 iterations.
17+
| :code:`ws` (:code:`numpy.array`): Numpy array
18+
of shape :code:`[n_examples, source, target, time]`
19+
| :code:`sample_every` (:code:`int`): Sub-sample using this parameter.
2120
"""
2221
weights = []
2322

@@ -49,10 +48,13 @@ def plot_spike_trains_for_example(spikes, n_ex=None, top_k=None, indices=None):
4948
5049
Inputs:
5150
52-
| :code:`spikes` (:code:`torch.Tensor (n_examples, n_neurons, time)`): Spiking train data for a population of neurons for one example.
53-
| :code:`n_ex` (:code:`int`): Allows user to pick which example to plot spikes for. Must be >= 0.
51+
| :code:`spikes` (:code:`torch.Tensor (n_examples, n_neurons, time)`):
52+
Spiking train data for a population of neurons for one example.
53+
| :code:`n_ex` (:code:`int`): Allows user to pick
54+
which example to plot spikes for. Must be >= 0.
5455
| :code:`top_k` (:code:`int`): Plot k neurons that spiked the most for n_ex example.
55-
| :code:`indices` (:code:`list(int)`): Plot specific neurons' spiking activity instead of top_k. Meant to replace top_k.
56+
| :code:`indices` (:code:`list(int)`): Plot specific neurons'
57+
spiking activity instead of top_k. Meant to replace top_k.
5658
'''
5759

5860
assert (n_ex is not None and n_ex >= 0 and n_ex < spikes.shape[0])
@@ -84,11 +86,16 @@ def plot_voltage(voltage, n_ex=0, n_neuron=0, time=None, threshold=None):
8486
8587
Inputs:
8688
87-
| :code:`voltage` (:code:`torch.Tensor` or :code:`numpy.array`): Tensor or array of shape :code:`[n_examples, n_neurons, time]`.
88-
| :code:`n_ex` (:code:`int`): Allows user to pick which example to plot voltage for.
89-
| :code:`n_neuron` (:code:`int`): Neuron index for which to plot voltages for.
90-
| :code:`time` (:code:`tuple(int)`): Plot spiking activity of neurons between the given range of time.
91-
| :code:`threshold` (:code:`float`): Neuron spiking threshold. Will be shown on the plot.
89+
| :code:`voltage` (:code:`torch.Tensor` or :code:`numpy.array`):
90+
Tensor or array of shape :code:`[n_examples, n_neurons, time]`.
91+
| :code:`n_ex` (:code:`int`): Allows user
92+
to pick which example to plot voltage for.
93+
| :code:`n_neuron` (:code:`int`): Neuron
94+
index for which to plot voltages for.
95+
| :code:`time` (:code:`tuple(int)`): Plot spiking
96+
activity of neurons between the given range of time.
97+
| :code:`threshold` (:code:`float`): Neuron
98+
spiking threshold. Will be shown on the plot.
9299
'''
93100

94101
assert (n_ex >= 0 and n_neuron >= 0)
@@ -105,7 +112,9 @@ def plot_voltage(voltage, n_ex=0, n_neuron=0, time=None, threshold=None):
105112

106113
plt.figure()
107114
plt.plot(voltage[n_ex, n_neuron, timer])
108-
plt.xlabel('Simulation Time'); plt.ylabel('Voltage'); plt.title('Membrane voltage of neuron %d for example %d'%(n_neuron, n_ex+1))
115+
plt.xlabel('Simulation Time')
116+
plt.ylabel('Voltage')
117+
plt.title('Membrane voltage of neuron %d for example %d' % (n_neuron, n_ex + 1))
109118
locs, labels = plt.xticks()
110119
locs = range(int(locs[1]), int(locs[-1]), 10)
111120
plt.xticks(locs, time_ticks)

bindsnet/datasets/__init__.py

Lines changed: 86 additions & 53 deletions
Large diffs are not rendered by default.

bindsnet/datasets/preprocess.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@ def gray_scale(im):
1212
1313
| :code:`im` (:code:`numpy.array`): Grayscaled image
1414
'''
15-
im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
16-
return im
15+
return cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
1716

1817

1918
def crop(im, x1, x2, y1, y2):
2019
return im[x1:x2, y1:y2, :]
21-
20+
2221

2322
def binary_image(im):
2423
'''
@@ -32,7 +31,7 @@ def binary_image(im):
3231
3332
| :code:`im` (:code:`numpy.array`): Black and white image.
3433
'''
35-
ret, im = cv2.threshold(im, 0, 1, cv2.THRESH_BINARY)
34+
_, im = cv2.threshold(im, 0, 1, cv2.THRESH_BINARY)
3635
return im
3736

3837

@@ -50,6 +49,4 @@ def subsample(im, x, y):
5049
5150
| :code:`im` (:code:`numpy.array`): Rescaled image.
5251
'''
53-
im = cv2.resize(im, (x, y))
54-
return im
55-
52+
return cv2.resize(im, (x, y))

bindsnet/encoding/__init__.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
def bernoulli(datum, time=None, **kwargs):
66
'''
7-
Generates Bernoulli-distributed spike trains based on input intensity. Inputs must be non-negative. Spikes correspond to successful Bernoulli trials, with success probability equal to (normalized in [0, 1]) input value.
7+
Generates Bernoulli-distributed spike trains based on input intensity.
8+
Inputs must be non-negative. Spikes correspond to successful Bernoulli
9+
trials, with success probability equal to (normalized in [0, 1]) input value.
810
911
Inputs:
1012
@@ -13,11 +15,13 @@ def bernoulli(datum, time=None, **kwargs):
1315
1416
Keyword arguments:
1517
16-
| :code:`max_prob` (:code:`float`): Maximum probability of spike per Bernoulli trial.
18+
| :code:`max_prob` (:code:`float`): Maximum
19+
probability of spike per Bernoulli trial.
1720
1821
Returns:
1922
20-
| (:code:`torch.Tensor`): Tensor of shape :code:`[time, n_1, ..., n_k]` of Bernoulli-distributed spikes.
23+
| (:code:`torch.Tensor`): Tensor of shape :code:`[time, n_1, ..., n_k]`
24+
of Bernoulli-distributed spikes.
2125
'''
2226
# Setting kwargs.
2327
max_prob = kwargs.get('max_prob', 1.0)
@@ -48,7 +52,8 @@ def bernoulli_loader(data, time=None, **kwargs):
4852
4953
Inputs:
5054
51-
| :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s): Tensor of shape :code:`[n_samples, n_1, ..., n_k]`.
55+
| :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s):
56+
Tensor of shape :code:`[n_samples, n_1, ..., n_k]`.
5257
| :code:`time` (:code:`int`): Length of Bernoulli spike train per input variable.
5358
5459
Keyword arguments:
@@ -57,7 +62,8 @@ def bernoulli_loader(data, time=None, **kwargs):
5762
5863
Yields:
5964
60-
| (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]` of Bernoulli-distributed spikes.
65+
| (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]`
66+
of Bernoulli-distributed spikes.
6167
'''
6268

6369
# Setting kwargs.
@@ -68,7 +74,8 @@ def bernoulli_loader(data, time=None, **kwargs):
6874

6975
def poisson(datum, time, **kwargs):
7076
'''
71-
Generates Poisson-distributed spike trains based on input intensity. Inputs must be non-negative.
77+
Generates Poisson-distributed spike trains based
78+
on input intensity. Inputs must be non-negative.
7279
7380
Inputs:
7481
@@ -77,7 +84,8 @@ def poisson(datum, time, **kwargs):
7784
7885
Returns:
7986
80-
| (:code:`torch.Tensor`): Tensor of shape :code:`[time, n_1, ..., n_k]` of Poisson-distributed spikes.
87+
| (:code:`torch.Tensor`): Tensor of shape :code:`[time, n_1, ..., n_k]`
88+
of Poisson-distributed spikes.
8189
'''
8290
datum = np.copy(datum)
8391
shape, size = datum.shape, datum.size
@@ -107,29 +115,36 @@ def poisson_loader(data, time, **kwargs):
107115
108116
Inputs:
109117
110-
| :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s): Tensor of shape :code:`[n_samples, n_1, ..., n_k]`
118+
| :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s):
119+
Tensor of shape :code:`[n_samples, n_1, ..., n_k]`
111120
| :code:`time` (:code:`int`): Length of Poisson spike train per input variable.
112121
113122
Yields:
114123
115-
| (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]` of Poisson-distributed spikes.
124+
| (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]`
125+
of Poisson-distributed spikes.
116126
'''
117127
for i in range(len(data)):
118128
yield poisson(data[i], time) # Encode datum as Poisson spike trains.
119129

120130

121131
def rank_order(datum, time, **kwargs):
122132
'''
123-
Encodes data via a rank order coding-like representation. One spike per neuron, temporally ordered by decreasing intensity. Inputs must be non-negative.
133+
Encodes data via a rank order coding-like representation. One
134+
spike per neuron, temporally ordered by decreasing intensity.
135+
Inputs must be non-negative.
124136
125137
Inputs:
126138
127-
| :code:`data` (:code:`torch.Tensor`): Tensor of shape :code:`[n_samples, n_1, ..., n_k]`
128-
| :code:`time` (:code:`int`): Length of Poisson spike train per input variable.
139+
| :code:`data` (:code:`torch.Tensor`): Tensor
140+
of shape :code:`[n_samples, n_1, ..., n_k]`
141+
| :code:`time` (:code:`int`): Length of rank
142+
order-encoded spike train per input variable.
129143
130144
Returns:
131145
132-
| (:code:`torch.Tensor`): Tensor of shape :code:`[time, n_1, ..., n_k]` of Poisson-distributed spikes.
146+
| (:code:`torch.Tensor`): Tensor of shape
147+
:code:`[time, n_1, ..., n_k]` of rank order-encoded spikes.
133148
'''
134149
datum = np.copy(datum)
135150
shape, size = datum.shape, datum.size
@@ -155,16 +170,20 @@ def rank_order(datum, time, **kwargs):
155170

156171
def rank_order_loader(data, time, **kwargs):
157172
'''
158-
Lazily invokes :code:`bindsnet.encoding.rank_order` to iteratively encode a sequence of data.
173+
Lazily invokes :code:`bindsnet.encoding.rank_order`
174+
to iteratively encode a sequence of data.
159175
160176
Inputs:
161177
162-
| :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s): Tensor of shape :code:`[n_samples, n_1, ..., n_k]`
163-
| :code:`time` (:code:`int`): Length of Poisson spike train per input variable.
178+
| :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s):
179+
Tensor of shape :code:`[n_samples, n_1, ..., n_k]`
180+
| :code:`time` (:code:`int`): Length of rank
181+
order-encoded spike train per input variable.
164182
165183
Yields:
166184
167-
| (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]` of rank order-encoded spikes.
185+
| (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]`
186+
of rank order-encoded spikes.
168187
'''
169188
for i in range(len(data)):
170189
yield rank_order(data[i], time) # Encode datum as rank order-encoded spike trains.

bindsnet/environment/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def __init__(self, dataset, train=True, time=350, **kwargs):
3333
self.intensity = kwargs.get('intensity', 1)
3434
self.max_prob = kwargs.get('max_prob', 1)
3535

36-
assert self.max_prob > 0 and self.max_prob <= 1, 'Maximum spiking probability must be in (0, 1].'
36+
assert self.max_prob > 0 and self.max_prob <= 1, \
37+
'Maximum spiking probability must be in (0, 1].'
3738

3839
if train:
3940
self.data, self.labels = self.dataset.get_train()
@@ -50,14 +51,14 @@ def step(self, a=None):
5051
5152
Inputs:
5253
53-
| :code:`a` (:code:`None`): There is no interaction of the network with the MNIST dataset.
54+
| :code:`a` (:code:`None`): There is no interaction of the network the dataset.
5455
5556
Returns:
5657
57-
| :code:`obs` (:code:`torch.Tensor`): Observation from the environment (spike train-encoded MNIST digit).
58+
| :code:`obs` (:code:`torch.Tensor`): Observation from the environment.
5859
| :code:`reward` (:code:`float`): Fixed to :code:`0`.
5960
| :code:`done` (:code:`bool`): Fixed to :code:`False`.
60-
| :code:`info` (:code:`dict`): Contains label of MNIST digit.
61+
| :code:`info` (:code:`dict`): Contains label of data item.
6162
'''
6263
try:
6364
# Attempt to fetch the next observation.
@@ -140,7 +141,8 @@ def __init__(self, name, **kwargs):
140141
# Keyword arguments.
141142
self.max_prob = kwargs.get('max_prob', 1)
142143

143-
assert self.max_prob > 0 and self.max_prob <= 1, 'Maximum spiking probability must be in (0, 1].'
144+
assert self.max_prob > 0 and self.max_prob <= 1, \
145+
'Maximum spiking probability must be in (0, 1].'
144146

145147
def step(self, a):
146148
'''

bindsnet/evaluation/__init__.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,21 @@ def assign_labels(spikes, labels, n_labels, rates=None, alpha=1.0):
77
88
Inputs:
99
10-
| :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape :code:`(n_samples, time, n_neurons)` of a single layer's spiking activity.
11-
| :code:`labels` (:code:`torch.Tensor`): Vector of shape :code:`(n_samples,)` with data labels corresponding to spiking activity.
10+
| :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape
11+
:code:`(n_samples, time, n_neurons)` of a single layer's spiking activity.
12+
| :code:`labels` (:code:`torch.Tensor`): Vector of shape :code:`(n_samples,)`
13+
with data labels corresponding to spiking activity.
1214
| :code:`n_labels` (:code:`int`): The number of target labels in the data.
13-
| :code:`rates` (:code:`torch.Tensor`): If passed, these represent spike rates from a previous :code:`assign_labels()` call.
15+
| :code:`rates` (:code:`torch.Tensor`): If passed, these represent spike
16+
rates from a previous :code:`assign_labels()` call.
1417
| :code:`alpha` (:code:`float`): Rate of decay of label assignments.
1518
1619
Returns:
1720
18-
| (:code:`torch.Tensor`): Vector of shape :code:`(n_neurons,)` of neuron label assignments.
19-
| (:code:`torch.Tensor`): Vector of shape :code:`(n_neurons, n_labels)` of proportions of firing activity per neuron, per data label.
21+
| (:code:`torch.Tensor`): Vector of shape
22+
:code:`(n_neurons,)` of neuron label assignments.
23+
| (:code:`torch.Tensor`): Vector of shape :code:`(n_neurons, n_labels)`
24+
of proportions of firing activity per neuron, per data label.
2025
'''
2126
n_neurons = spikes.size(2)
2227

@@ -53,13 +58,16 @@ def all_activity(spikes, assignments, n_labels):
5358
5459
Inputs:
5560
56-
| :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape :code:`(n_samples, time, n_neurons)` of a layer's spiking activity.
57-
| :code:`assignments` (:code:`torch.Tensor`): A vector of shape :code:`(n_neurons,)` of neuron label assignments.
61+
| :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape
62+
:code:`(n_samples, time, n_neurons)` of a layer's spiking activity.
63+
| :code:`assignments` (:code:`torch.Tensor`): A vector of shape
64+
:code:`(n_neurons,)` of neuron label assignments.
5865
| :code:`n_labels` (:code:`int`): The number of target labels in the data.
5966
6067
Returns:
6168
62-
| (:code:`torch.Tensor`): Predictions tensor of shape :code:`(n_samples,)` resulting from the "all activity" classification scheme.
69+
| (:code:`torch.Tensor`): Predictions tensor of shape :code:`(n_samples,)`
70+
resulting from the "all activity" classification scheme.
6371
'''
6472
n_samples = spikes.size(0)
6573

@@ -88,18 +96,23 @@ def all_activity(spikes, assignments, n_labels):
8896

8997
def proportion_weighting(spikes, assignments, proportions, n_labels):
9098
'''
91-
Classify data with the label with highest average spiking activity over all neurons, weighted by class-wise proportion..
99+
Classify data with the label with highest average spiking
100+
activity over all neurons, weighted by class-wise proportion.
92101
93102
Inputs:
94103
95-
| :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape :code:`(n_samples, time, n_neurons)` of a single layer's spiking activity.
96-
| :code:`assignments` (:code:`torch.Tensor`): A vector of shape :code:`(n_neurons,)` of neuron label assignments.
97-
| :code:`proportions` (torch.Tensor): A matrix of shape :code:`(n_neurons, n_labels)` giving the per-class proportions of neuron spiking activity.
104+
| :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape
105+
:code:`(n_samples, time, n_neurons)` of a single layer's spiking activity.
106+
| :code:`assignments` (:code:`torch.Tensor`): A vector of shape
107+
:code:`(n_neurons,)` of neuron label assignments.
108+
| :code:`proportions` (torch.Tensor): A matrix of shape :code:`(n_neurons, n_labels)`
109+
giving the per-class proportions of neuron spiking activity.
98110
| :code:`n_labels` (:code:`int`): The number of target labels in the data.
99111
100112
Returns:
101113
102-
| (:code:`torch.Tensor`): Predictions tensor of shape :code:`(n_samples,)` resulting from the "proportion weighting" classification scheme.
114+
| (:code:`torch.Tensor`): Predictions tensor of shapez:code:`(n_samples,)`
115+
resulting from the "proportion weighting" classification scheme.
103116
'''
104117
n_samples = spikes.size(0)
105118

0 commit comments

Comments
 (0)