Skip to content

Commit 9049bc6

Browse files
committed
New local connection classes (1D, 2D, and 3D)
Apply fix for #537
1 parent 6a3b80e commit 9049bc6

12 files changed

Lines changed: 915 additions & 592 deletions

File tree

bindsnet/analysis/plotting.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
from mpl_toolkits.axes_grid1 import make_axes_locatable
1010
from torch.nn.modules.utils import _pair
1111

12-
from bindsnet.utils import reshape_conv2d_weights, reshape_locally_connected_weights, reshape_local_connection_2d_weights
12+
from bindsnet.utils import (
13+
reshape_conv2d_weights,
14+
reshape_locally_connected_weights,
15+
reshape_local_connection_2d_weights,
16+
)
1317

1418
plt.ion()
1519

@@ -377,15 +381,17 @@ def plot_locally_connected_weights(
377381

378382
return im
379383

380-
def plot_local_connection_2d_weights(lc : object,
384+
385+
def plot_local_connection_2d_weights(
386+
lc: object,
381387
input_channel: int = 0,
382388
output_channel: int = None,
383389
im: Optional[AxesImage] = None,
384390
lines: bool = True,
385391
figsize: Tuple[int, int] = (5, 5),
386392
cmap: str = "hot_r",
387-
color: str='r',
388-
) -> AxesImage:
393+
color: str = "r",
394+
) -> AxesImage:
389395
# language=rst
390396
"""
391397
Plot a connection weight matrix of a :code:`Connection` with `locally connected
@@ -400,23 +406,34 @@ def plot_local_connection_2d_weights(lc : object,
400406
"""
401407

402408
n_sqrt = int(np.ceil(np.sqrt(lc.n_filters)))
403-
sel_slice = lc.w.view(lc.in_channels, lc.n_filters, lc.conv_size[0], lc.conv_size[1], lc.kernel_size[0], lc.kernel_size[1]).cpu()
409+
sel_slice = lc.w.view(
410+
lc.in_channels,
411+
lc.n_filters,
412+
lc.conv_size[0],
413+
lc.conv_size[1],
414+
lc.kernel_size[0],
415+
lc.kernel_size[1],
416+
).cpu()
404417
input_size = _pair(int(np.sqrt(lc.source.n)))
405418
if output_channel is None:
406419
sel_slice = sel_slice[input_channel, ...]
407-
reshaped = reshape_local_connection_2d_weights(sel_slice, lc.n_filters, lc.kernel_size, lc.conv_size, input_size)
420+
reshaped = reshape_local_connection_2d_weights(
421+
sel_slice, lc.n_filters, lc.kernel_size, lc.conv_size, input_size
422+
)
408423
else:
409424
sel_slice = sel_slice[input_channel, output_channel, ...]
410425
sel_slice = sel_slice.unsqueeze(0)
411-
reshaped = reshape_local_connection_2d_weights(sel_slice, 1, lc.kernel_size, lc.conv_size, input_size)
426+
reshaped = reshape_local_connection_2d_weights(
427+
sel_slice, 1, lc.kernel_size, lc.conv_size, input_size
428+
)
412429
if im == None:
413430
fig, ax = plt.subplots(figsize=figsize)
414431

415432
im = ax.imshow(reshaped.cpu(), cmap=cmap, vmin=lc.wmin, vmax=lc.wmax)
416433
div = make_axes_locatable(ax)
417434
cax = div.append_axes("right", size="5%", pad=0.05)
418435

419-
if lines and output_channel is None:
436+
if lines and output_channel is None:
420437
for i in range(
421438
n_sqrt * lc.kernel_size[0],
422439
n_sqrt * lc.conv_size[0] * lc.kernel_size[0],
@@ -430,7 +447,7 @@ def plot_local_connection_2d_weights(lc : object,
430447
n_sqrt * lc.kernel_size[1],
431448
):
432449
ax.axvline(i - 0.5, color=color, linestyle="--")
433-
450+
434451
ax.set_xticks(())
435452
ax.set_yticks(())
436453
ax.set_aspect("auto")
@@ -441,7 +458,7 @@ def plot_local_connection_2d_weights(lc : object,
441458
else:
442459
im.set_data(reshaped.cpu())
443460
return im
444-
461+
445462

446463
def plot_assignments(
447464
assignments: torch.Tensor,
@@ -825,5 +842,3 @@ def plot_voltages(
825842
plt.tight_layout()
826843

827844
return ims, axes
828-
829-

bindsnet/datasets/spoken_mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def process_data(
250250
# Fast Fourier Transform and Power Spectrum
251251
NFFT = 512
252252
mag_frames = np.absolute(np.fft.rfft(frames, NFFT)) # Magnitude of the FFT
253-
pow_frames = (1.0 / NFFT) * (mag_frames ** 2) # Power Spectrum
253+
pow_frames = (1.0 / NFFT) * (mag_frames**2) # Power Spectrum
254254

255255
# Log filter banks
256256
nfilt = 40

0 commit comments

Comments
 (0)