|
9 | 9 |
|
10 | 10 |
|
11 | 11 | def plot_weights_movie(ws, sample_every=1): |
12 | | - """ |
13 | | - Create and plot movie of weights (:code:`ws`). |
14 | | - |
15 | | - Inputs: |
16 | | - |
17 | | - | :code:`ws` (:code:`numpy.array`): Numpy array of shape :code:`[N_examples, source, target, time]` |
18 | | - | :code:`sample_every` (:code:`int`): Sub-sample using this parameter. For example if :code:`time` is |
19 | | - too large (500), set this parameter to 20 to sample weights |
20 | | - every 20 iterations. |
21 | | - """ |
22 | | - weights = [] |
23 | | - |
24 | | - # Obtain samples from the weights for every example |
25 | | - for i in range(ws.shape[0]): |
26 | | - sub_sampled_weight = ws[i, :, :, range(0, ws[i].shape[2], sample_every)] |
27 | | - weights.append(sub_sampled_weight) |
28 | | - else: |
29 | | - weights = np.concatenate(weights, axis=0) |
30 | | - |
31 | | - # Initialize plot |
32 | | - fig = plt.figure() |
33 | | - im = plt.imshow(weights[0, :, :], cmap='hot_r', animated=True, vmin=0, vmax=1) |
34 | | - plt.axis('off'); plt.colorbar(im) |
35 | | - |
36 | | - # Update function for the animation |
37 | | - def update(j): |
38 | | - im.set_data(weights[j, :, :]) |
39 | | - return [im] |
40 | | - |
41 | | - # Initialize animatino |
42 | | - global ani; ani=0 |
43 | | - ani = animation.FuncAnimation(fig, update, frames=weights.shape[-1], interval=1000, blit=True) |
44 | | - plt.show() |
45 | | - |
| 12 | + """ |
| 13 | + Create and plot movie of weights (:code:`ws`). |
| 14 | + |
| 15 | + Inputs: |
| 16 | + |
| 17 | + | :code:`ws` (:code:`numpy.array`): Numpy array of shape :code:`[N_examples, source, target, time]` |
| 18 | + | :code:`sample_every` (:code:`int`): Sub-sample using this parameter. For example if :code:`time` is |
| 19 | + too large (500), set this parameter to 20 to sample weights |
| 20 | + every 20 iterations. |
| 21 | + """ |
| 22 | + weights = [] |
| 23 | + |
| 24 | + # Obtain samples from the weights for every example |
| 25 | + for i in range(ws.shape[0]): |
| 26 | + sub_sampled_weight = ws[i, :, :, range(0, ws[i].shape[2], sample_every)] |
| 27 | + weights.append(sub_sampled_weight) |
| 28 | + else: |
| 29 | + weights = np.concatenate(weights, axis=0) |
| 30 | + |
| 31 | + # Initialize plot |
| 32 | + fig = plt.figure() |
| 33 | + im = plt.imshow(weights[0, :, :], cmap='hot_r', animated=True, vmin=0, vmax=1) |
| 34 | + plt.axis('off'); plt.colorbar(im) |
| 35 | + |
| 36 | + # Update function for the animation |
| 37 | + def update(j): |
| 38 | + im.set_data(weights[j, :, :]) |
| 39 | + return [im] |
| 40 | + |
| 41 | + # Initialize animatino |
| 42 | + global ani; ani=0 |
| 43 | + ani = animation.FuncAnimation(fig, update, frames=weights.shape[-1], interval=1000, blit=True) |
| 44 | + plt.show() |
| 45 | + |
46 | 46 | def plot_spike_trains_for_example(spikes, n_ex=None, top_k=None, indices=None): |
47 | | - ''' |
48 | | - Plot spike trains for top-k neurons or for specific indices. |
49 | | - |
50 | | - Inputs: |
51 | | - |
52 | | - | :code:`spikes` (:code:`torch.Tensor (n_examples, n_neurons, time)`): Spiking train data for a population of neurons for one example. |
53 | | - | :code:`n_ex` (:code:`int`): Allows user to pick which example to plot spikes for. Must be >= 0. |
54 | | - | :code:`top_k` (:code:`int`): Plot k neurons that spiked the most for n_ex example. |
55 | | - | :code:`indices` (:code:`list(int)`): Plot specific neurons' spiking activity instead of top_k. Meant to replace top_k. |
56 | | - ''' |
| 47 | + ''' |
| 48 | + Plot spike trains for top-k neurons or for specific indices. |
| 49 | + |
| 50 | + Inputs: |
| 51 | + |
| 52 | + | :code:`spikes` (:code:`torch.Tensor (n_examples, n_neurons, time)`): Spiking train data for a population of neurons for one example. |
| 53 | + | :code:`n_ex` (:code:`int`): Allows user to pick which example to plot spikes for. Must be >= 0. |
| 54 | + | :code:`top_k` (:code:`int`): Plot k neurons that spiked the most for n_ex example. |
| 55 | + | :code:`indices` (:code:`list(int)`): Plot specific neurons' spiking activity instead of top_k. Meant to replace top_k. |
| 56 | + ''' |
57 | 57 |
|
58 | | - assert (n_ex is not None and n_ex >= 0 and n_ex < spikes.shape[0]) |
59 | | - |
60 | | - plt.figure() |
61 | | - |
62 | | - if top_k is None and indices is None: # Plot all neurons' spiking activity |
63 | | - spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, :, :]] |
64 | | - plt.title('Spiking activity for all %d neurons'%spikes.shape[1]) |
65 | | - |
66 | | - elif top_k is None: # Plot based on indices parameter |
67 | | - assert (indices is not None) |
68 | | - spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, indices, :]] |
69 | | - |
70 | | - elif indices is None: # Plot based on top_k parameter |
71 | | - assert (top_k is not None) |
72 | | - # Obtain the top k neurons that fired the most |
73 | | - top_k_loc = np.argsort(np.sum(spikes[n_ex,:,:], axis=1), axis=0)[::-1] |
74 | | - spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :]] |
75 | | - plt.title('Spiking activity for top %d neurons'%top_k) |
76 | | - |
77 | | - plt.eventplot(spike_per_neuron, linelengths= [0.5]*len(spike_per_neuron)) |
78 | | - plt.xlabel('Simulation Time'); plt.ylabel('Neuron index') |
79 | | - plt.show() |
| 58 | + assert (n_ex is not None and n_ex >= 0 and n_ex < spikes.shape[0]) |
| 59 | + |
| 60 | + plt.figure() |
| 61 | + |
| 62 | + if top_k is None and indices is None: # Plot all neurons' spiking activity |
| 63 | + spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, :, :]] |
| 64 | + plt.title('Spiking activity for all %d neurons'%spikes.shape[1]) |
| 65 | + |
| 66 | + elif top_k is None: # Plot based on indices parameter |
| 67 | + assert (indices is not None) |
| 68 | + spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, indices, :]] |
| 69 | + |
| 70 | + elif indices is None: # Plot based on top_k parameter |
| 71 | + assert (top_k is not None) |
| 72 | + # Obtain the top k neurons that fired the most |
| 73 | + top_k_loc = np.argsort(np.sum(spikes[n_ex,:,:], axis=1), axis=0)[::-1] |
| 74 | + spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :]] |
| 75 | + plt.title('Spiking activity for top %d neurons'%top_k) |
| 76 | + |
| 77 | + plt.eventplot(spike_per_neuron, linelengths= [0.5]*len(spike_per_neuron)) |
| 78 | + plt.xlabel('Simulation Time'); plt.ylabel('Neuron index') |
| 79 | + plt.show() |
80 | 80 |
|
81 | 81 | def plot_voltage(voltage, n_ex=0, n_neuron=0, time=None, threshold=None): |
82 | | - ''' |
83 | | - Plot voltage for a single neuron on a specific example. |
84 | | - |
85 | | - Inputs: |
86 | | - |
87 | | - | :code:`voltage` (:code:`torch.Tensor` or :code:`numpy.array`): Tensor or array of shape :code:`[n_examples, n_neurons, time]`. |
88 | | - | :code:`n_ex` (:code:`int`): Allows user to pick which example to plot voltage for. |
89 | | - | :code:`n_neuron` (:code:`int`): Neuron index for which to plot voltages for. |
90 | | - | :code:`time` (:code:`tuple(int)`): Plot spiking activity of neurons between the given range of time. |
91 | | - | :code:`threshold` (:code:`float`): Neuron spiking threshold. Will be shown on the plot. |
92 | | - ''' |
93 | | - |
94 | | - assert (n_ex >= 0 and n_neuron >= 0) |
95 | | - assert (n_ex < voltage.shape[0] and n_neuron < voltage.shape[1]) |
| 82 | + ''' |
| 83 | + Plot voltage for a single neuron on a specific example. |
| 84 | + |
| 85 | + Inputs: |
| 86 | + |
| 87 | + | :code:`voltage` (:code:`torch.Tensor` or :code:`numpy.array`): Tensor or array of shape :code:`[n_examples, n_neurons, time]`. |
| 88 | + | :code:`n_ex` (:code:`int`): Allows user to pick which example to plot voltage for. |
| 89 | + | :code:`n_neuron` (:code:`int`): Neuron index for which to plot voltages for. |
| 90 | + | :code:`time` (:code:`tuple(int)`): Plot spiking activity of neurons between the given range of time. |
| 91 | + | :code:`threshold` (:code:`float`): Neuron spiking threshold. Will be shown on the plot. |
| 92 | + ''' |
| 93 | + |
| 94 | + assert (n_ex >= 0 and n_neuron >= 0) |
| 95 | + assert (n_ex < voltage.shape[0] and n_neuron < voltage.shape[1]) |
96 | 96 |
|
97 | | - if time is None: |
98 | | - time = (0, voltage.shape[-1]) |
99 | | - else: |
100 | | - assert (time[0] < time[1]) |
101 | | - assert (time[1] <= voltage.shape[-1]) |
102 | | - |
103 | | - timer = np.arange(time[0], time[1]) |
104 | | - time_ticks = np.arange(time[0], time[1]+1, 10) |
105 | | - |
106 | | - plt.figure() |
107 | | - plt.plot(voltage[n_ex, n_neuron, timer]) |
108 | | - plt.xlabel('Simulation Time'); plt.ylabel('Voltage'); plt.title('Membrane voltage of neuron %d for example %d'%(n_neuron, n_ex+1)) |
109 | | - locs, labels = plt.xticks() |
110 | | - locs = range(int(locs[1]), int(locs[-1]), 10) |
111 | | - plt.xticks(locs, time_ticks) |
112 | | - |
113 | | - # Draw threshold line only if given |
114 | | - if threshold is not None: |
115 | | - plt.axhline(threshold, linestyle='--', color='black', zorder=0) |
116 | | - |
117 | | - plt.show() |
| 97 | + if time is None: |
| 98 | + time = (0, voltage.shape[-1]) |
| 99 | + else: |
| 100 | + assert (time[0] < time[1]) |
| 101 | + assert (time[1] <= voltage.shape[-1]) |
| 102 | + |
| 103 | + timer = np.arange(time[0], time[1]) |
| 104 | + time_ticks = np.arange(time[0], time[1]+1, 10) |
| 105 | + |
| 106 | + plt.figure() |
| 107 | + plt.plot(voltage[n_ex, n_neuron, timer]) |
| 108 | + plt.xlabel('Simulation Time'); plt.ylabel('Voltage'); plt.title('Membrane voltage of neuron %d for example %d'%(n_neuron, n_ex+1)) |
| 109 | + locs, labels = plt.xticks() |
| 110 | + locs = range(int(locs[1]), int(locs[-1]), 10) |
| 111 | + plt.xticks(locs, time_ticks) |
| 112 | + |
| 113 | + # Draw threshold line only if given |
| 114 | + if threshold is not None: |
| 115 | + plt.axhline(threshold, linestyle='--', color='black', zorder=0) |
| 116 | + |
| 117 | + plt.show() |
0 commit comments