Skip to content

Commit 12d060c

Browse files
Merge pull request #536 from hafezgh/localconnections
Localconnections
2 parents b850dd1 + 9049bc6 commit 12d060c

File tree

7 files changed

+2568
-26
lines changed

7 files changed

+2568
-26
lines changed

bindsnet/analysis/plotting.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
from mpl_toolkits.axes_grid1 import make_axes_locatable
1010
from torch.nn.modules.utils import _pair
1111

12-
from bindsnet.utils import reshape_conv2d_weights, reshape_locally_connected_weights
12+
from bindsnet.utils import (
13+
reshape_conv2d_weights,
14+
reshape_locally_connected_weights,
15+
reshape_local_connection_2d_weights,
16+
)
1317

1418
plt.ion()
1519

@@ -378,6 +382,84 @@ def plot_locally_connected_weights(
378382
return im
379383

380384

385+
def plot_local_connection_2d_weights(
386+
lc: object,
387+
input_channel: int = 0,
388+
output_channel: int = None,
389+
im: Optional[AxesImage] = None,
390+
lines: bool = True,
391+
figsize: Tuple[int, int] = (5, 5),
392+
cmap: str = "hot_r",
393+
color: str = "r",
394+
) -> AxesImage:
395+
# language=rst
396+
"""
397+
Plot a connection weight matrix of a :code:`Connection` with `locally connected
398+
structure <http://yann.lecun.com/exdb/publis/pdf/gregor-nips-11.pdf>_.
399+
:param lc: An object of the class LocalConnection2D
400+
:param input_channel: The input channel to plot its corresponding weights, default is the first channel
401+
:param output_channel: If not None, will only plot the weights corresponding to this output channel (filter)
402+
:param lines: Indicates whether or not draw horizontal and vertical lines separating input regions.
403+
:param figsize: Horizontal and vertical figure size in inches.
404+
:param cmap: Matplotlib colormap.
405+
:return: ``ims, axes``: Used for re-drawing the plots.
406+
"""
407+
408+
n_sqrt = int(np.ceil(np.sqrt(lc.n_filters)))
409+
sel_slice = lc.w.view(
410+
lc.in_channels,
411+
lc.n_filters,
412+
lc.conv_size[0],
413+
lc.conv_size[1],
414+
lc.kernel_size[0],
415+
lc.kernel_size[1],
416+
).cpu()
417+
input_size = _pair(int(np.sqrt(lc.source.n)))
418+
if output_channel is None:
419+
sel_slice = sel_slice[input_channel, ...]
420+
reshaped = reshape_local_connection_2d_weights(
421+
sel_slice, lc.n_filters, lc.kernel_size, lc.conv_size, input_size
422+
)
423+
else:
424+
sel_slice = sel_slice[input_channel, output_channel, ...]
425+
sel_slice = sel_slice.unsqueeze(0)
426+
reshaped = reshape_local_connection_2d_weights(
427+
sel_slice, 1, lc.kernel_size, lc.conv_size, input_size
428+
)
429+
if im == None:
430+
fig, ax = plt.subplots(figsize=figsize)
431+
432+
im = ax.imshow(reshaped.cpu(), cmap=cmap, vmin=lc.wmin, vmax=lc.wmax)
433+
div = make_axes_locatable(ax)
434+
cax = div.append_axes("right", size="5%", pad=0.05)
435+
436+
if lines and output_channel is None:
437+
for i in range(
438+
n_sqrt * lc.kernel_size[0],
439+
n_sqrt * lc.conv_size[0] * lc.kernel_size[0],
440+
n_sqrt * lc.kernel_size[0],
441+
):
442+
ax.axhline(i - 0.5, color=color, linestyle="--")
443+
444+
for i in range(
445+
n_sqrt * lc.kernel_size[1],
446+
n_sqrt * lc.conv_size[1] * lc.kernel_size[1],
447+
n_sqrt * lc.kernel_size[1],
448+
):
449+
ax.axvline(i - 0.5, color=color, linestyle="--")
450+
451+
ax.set_xticks(())
452+
ax.set_yticks(())
453+
ax.set_aspect("auto")
454+
455+
plt.colorbar(im, cax=cax)
456+
fig.tight_layout()
457+
458+
else:
459+
im.set_data(reshaped.cpu())
460+
return im
461+
462+
381463
def plot_assignments(
382464
assignments: torch.Tensor,
383465
im: Optional[AxesImage] = None,

0 commit comments

Comments
 (0)