Skip to content

Commit 530d88a

Browse files
author
Dan Saunders
committed
Change tab-spacing to four spaces spacing.
1 parent a1526ee commit 530d88a

33 files changed

Lines changed: 4665 additions & 4664 deletions

bindsnet/analysis/plotting.py

Lines changed: 437 additions & 437 deletions
Large diffs are not rendered by default.

bindsnet/analysis/visualization.py

Lines changed: 101 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -9,109 +9,109 @@
99

1010

1111
def plot_weights_movie(ws, sample_every=1):
12-
"""
13-
Create and plot movie of weights (:code:`ws`).
14-
15-
Inputs:
16-
17-
| :code:`ws` (:code:`numpy.array`): Numpy array of shape :code:`[N_examples, source, target, time]`
18-
| :code:`sample_every` (:code:`int`): Sub-sample using this parameter. For example if :code:`time` is
19-
too large (500), set this parameter to 20 to sample weights
20-
every 20 iterations.
21-
"""
22-
weights = []
23-
24-
# Obtain samples from the weights for every example
25-
for i in range(ws.shape[0]):
26-
sub_sampled_weight = ws[i, :, :, range(0, ws[i].shape[2], sample_every)]
27-
weights.append(sub_sampled_weight)
28-
else:
29-
weights = np.concatenate(weights, axis=0)
30-
31-
# Initialize plot
32-
fig = plt.figure()
33-
im = plt.imshow(weights[0, :, :], cmap='hot_r', animated=True, vmin=0, vmax=1)
34-
plt.axis('off'); plt.colorbar(im)
35-
36-
# Update function for the animation
37-
def update(j):
38-
im.set_data(weights[j, :, :])
39-
return [im]
40-
41-
# Initialize animatino
42-
global ani; ani=0
43-
ani = animation.FuncAnimation(fig, update, frames=weights.shape[-1], interval=1000, blit=True)
44-
plt.show()
45-
12+
"""
13+
Create and plot movie of weights (:code:`ws`).
14+
15+
Inputs:
16+
17+
| :code:`ws` (:code:`numpy.array`): Numpy array of shape :code:`[N_examples, source, target, time]`
18+
| :code:`sample_every` (:code:`int`): Sub-sample using this parameter. For example if :code:`time` is
19+
too large (500), set this parameter to 20 to sample weights
20+
every 20 iterations.
21+
"""
22+
weights = []
23+
24+
# Obtain samples from the weights for every example
25+
for i in range(ws.shape[0]):
26+
sub_sampled_weight = ws[i, :, :, range(0, ws[i].shape[2], sample_every)]
27+
weights.append(sub_sampled_weight)
28+
else:
29+
weights = np.concatenate(weights, axis=0)
30+
31+
# Initialize plot
32+
fig = plt.figure()
33+
im = plt.imshow(weights[0, :, :], cmap='hot_r', animated=True, vmin=0, vmax=1)
34+
plt.axis('off'); plt.colorbar(im)
35+
36+
# Update function for the animation
37+
def update(j):
38+
im.set_data(weights[j, :, :])
39+
return [im]
40+
41+
# Initialize animatino
42+
global ani; ani=0
43+
ani = animation.FuncAnimation(fig, update, frames=weights.shape[-1], interval=1000, blit=True)
44+
plt.show()
45+
4646
def plot_spike_trains_for_example(spikes, n_ex=None, top_k=None, indices=None):
47-
'''
48-
Plot spike trains for top-k neurons or for specific indices.
49-
50-
Inputs:
51-
52-
| :code:`spikes` (:code:`torch.Tensor (n_examples, n_neurons, time)`): Spiking train data for a population of neurons for one example.
53-
| :code:`n_ex` (:code:`int`): Allows user to pick which example to plot spikes for. Must be >= 0.
54-
| :code:`top_k` (:code:`int`): Plot k neurons that spiked the most for n_ex example.
55-
| :code:`indices` (:code:`list(int)`): Plot specific neurons' spiking activity instead of top_k. Meant to replace top_k.
56-
'''
47+
'''
48+
Plot spike trains for top-k neurons or for specific indices.
49+
50+
Inputs:
51+
52+
| :code:`spikes` (:code:`torch.Tensor (n_examples, n_neurons, time)`): Spiking train data for a population of neurons for one example.
53+
| :code:`n_ex` (:code:`int`): Allows user to pick which example to plot spikes for. Must be >= 0.
54+
| :code:`top_k` (:code:`int`): Plot k neurons that spiked the most for n_ex example.
55+
| :code:`indices` (:code:`list(int)`): Plot specific neurons' spiking activity instead of top_k. Meant to replace top_k.
56+
'''
5757

58-
assert (n_ex is not None and n_ex >= 0 and n_ex < spikes.shape[0])
59-
60-
plt.figure()
61-
62-
if top_k is None and indices is None: # Plot all neurons' spiking activity
63-
spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, :, :]]
64-
plt.title('Spiking activity for all %d neurons'%spikes.shape[1])
65-
66-
elif top_k is None: # Plot based on indices parameter
67-
assert (indices is not None)
68-
spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, indices, :]]
69-
70-
elif indices is None: # Plot based on top_k parameter
71-
assert (top_k is not None)
72-
# Obtain the top k neurons that fired the most
73-
top_k_loc = np.argsort(np.sum(spikes[n_ex,:,:], axis=1), axis=0)[::-1]
74-
spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :]]
75-
plt.title('Spiking activity for top %d neurons'%top_k)
76-
77-
plt.eventplot(spike_per_neuron, linelengths= [0.5]*len(spike_per_neuron))
78-
plt.xlabel('Simulation Time'); plt.ylabel('Neuron index')
79-
plt.show()
58+
assert (n_ex is not None and n_ex >= 0 and n_ex < spikes.shape[0])
59+
60+
plt.figure()
61+
62+
if top_k is None and indices is None: # Plot all neurons' spiking activity
63+
spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, :, :]]
64+
plt.title('Spiking activity for all %d neurons'%spikes.shape[1])
65+
66+
elif top_k is None: # Plot based on indices parameter
67+
assert (indices is not None)
68+
spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, indices, :]]
69+
70+
elif indices is None: # Plot based on top_k parameter
71+
assert (top_k is not None)
72+
# Obtain the top k neurons that fired the most
73+
top_k_loc = np.argsort(np.sum(spikes[n_ex,:,:], axis=1), axis=0)[::-1]
74+
spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :]]
75+
plt.title('Spiking activity for top %d neurons'%top_k)
76+
77+
plt.eventplot(spike_per_neuron, linelengths= [0.5]*len(spike_per_neuron))
78+
plt.xlabel('Simulation Time'); plt.ylabel('Neuron index')
79+
plt.show()
8080

8181
def plot_voltage(voltage, n_ex=0, n_neuron=0, time=None, threshold=None):
82-
'''
83-
Plot voltage for a single neuron on a specific example.
84-
85-
Inputs:
86-
87-
| :code:`voltage` (:code:`torch.Tensor` or :code:`numpy.array`): Tensor or array of shape :code:`[n_examples, n_neurons, time]`.
88-
| :code:`n_ex` (:code:`int`): Allows user to pick which example to plot voltage for.
89-
| :code:`n_neuron` (:code:`int`): Neuron index for which to plot voltages for.
90-
| :code:`time` (:code:`tuple(int)`): Plot spiking activity of neurons between the given range of time.
91-
| :code:`threshold` (:code:`float`): Neuron spiking threshold. Will be shown on the plot.
92-
'''
93-
94-
assert (n_ex >= 0 and n_neuron >= 0)
95-
assert (n_ex < voltage.shape[0] and n_neuron < voltage.shape[1])
82+
'''
83+
Plot voltage for a single neuron on a specific example.
84+
85+
Inputs:
86+
87+
| :code:`voltage` (:code:`torch.Tensor` or :code:`numpy.array`): Tensor or array of shape :code:`[n_examples, n_neurons, time]`.
88+
| :code:`n_ex` (:code:`int`): Allows user to pick which example to plot voltage for.
89+
| :code:`n_neuron` (:code:`int`): Neuron index for which to plot voltages for.
90+
| :code:`time` (:code:`tuple(int)`): Plot spiking activity of neurons between the given range of time.
91+
| :code:`threshold` (:code:`float`): Neuron spiking threshold. Will be shown on the plot.
92+
'''
93+
94+
assert (n_ex >= 0 and n_neuron >= 0)
95+
assert (n_ex < voltage.shape[0] and n_neuron < voltage.shape[1])
9696

97-
if time is None:
98-
time = (0, voltage.shape[-1])
99-
else:
100-
assert (time[0] < time[1])
101-
assert (time[1] <= voltage.shape[-1])
102-
103-
timer = np.arange(time[0], time[1])
104-
time_ticks = np.arange(time[0], time[1]+1, 10)
105-
106-
plt.figure()
107-
plt.plot(voltage[n_ex, n_neuron, timer])
108-
plt.xlabel('Simulation Time'); plt.ylabel('Voltage'); plt.title('Membrane voltage of neuron %d for example %d'%(n_neuron, n_ex+1))
109-
locs, labels = plt.xticks()
110-
locs = range(int(locs[1]), int(locs[-1]), 10)
111-
plt.xticks(locs, time_ticks)
112-
113-
# Draw threshold line only if given
114-
if threshold is not None:
115-
plt.axhline(threshold, linestyle='--', color='black', zorder=0)
116-
117-
plt.show()
97+
if time is None:
98+
time = (0, voltage.shape[-1])
99+
else:
100+
assert (time[0] < time[1])
101+
assert (time[1] <= voltage.shape[-1])
102+
103+
timer = np.arange(time[0], time[1])
104+
time_ticks = np.arange(time[0], time[1]+1, 10)
105+
106+
plt.figure()
107+
plt.plot(voltage[n_ex, n_neuron, timer])
108+
plt.xlabel('Simulation Time'); plt.ylabel('Voltage'); plt.title('Membrane voltage of neuron %d for example %d'%(n_neuron, n_ex+1))
109+
locs, labels = plt.xticks()
110+
locs = range(int(locs[1]), int(locs[-1]), 10)
111+
plt.xticks(locs, time_ticks)
112+
113+
# Draw threshold line only if given
114+
if threshold is not None:
115+
plt.axhline(threshold, linestyle='--', color='black', zorder=0)
116+
117+
plt.show()

0 commit comments

Comments
 (0)