|
9 | 9 | from mpl_toolkits.axes_grid1 import make_axes_locatable |
10 | 10 | from torch.nn.modules.utils import _pair |
11 | 11 |
|
12 | | -from bindsnet.utils import reshape_conv2d_weights, reshape_locally_connected_weights |
| 12 | +from bindsnet.utils import ( |
| 13 | + reshape_conv2d_weights, |
| 14 | + reshape_locally_connected_weights, |
| 15 | + reshape_local_connection_2d_weights, |
| 16 | +) |
13 | 17 |
|
14 | 18 | plt.ion() |
15 | 19 |
|
@@ -378,6 +382,84 @@ def plot_locally_connected_weights( |
378 | 382 | return im |
379 | 383 |
|
380 | 384 |
|
| 385 | +def plot_local_connection_2d_weights( |
| 386 | + lc: object, |
| 387 | + input_channel: int = 0, |
| 388 | + output_channel: int = None, |
| 389 | + im: Optional[AxesImage] = None, |
| 390 | + lines: bool = True, |
| 391 | + figsize: Tuple[int, int] = (5, 5), |
| 392 | + cmap: str = "hot_r", |
| 393 | + color: str = "r", |
| 394 | +) -> AxesImage: |
| 395 | + # language=rst |
| 396 | + """ |
| 397 | + Plot a connection weight matrix of a :code:`Connection` with `locally connected |
| 398 | + structure <http://yann.lecun.com/exdb/publis/pdf/gregor-nips-11.pdf>_. |
| 399 | + :param lc: An object of the class LocalConnection2D |
| 400 | + :param input_channel: The input channel to plot its corresponding weights, default is the first channel |
| 401 | + :param output_channel: If not None, will only plot the weights corresponding to this output channel (filter) |
| 402 | + :param lines: Indicates whether or not draw horizontal and vertical lines separating input regions. |
| 403 | + :param figsize: Horizontal and vertical figure size in inches. |
| 404 | + :param cmap: Matplotlib colormap. |
| 405 | + :return: ``ims, axes``: Used for re-drawing the plots. |
| 406 | + """ |
| 407 | + |
| 408 | + n_sqrt = int(np.ceil(np.sqrt(lc.n_filters))) |
| 409 | + sel_slice = lc.w.view( |
| 410 | + lc.in_channels, |
| 411 | + lc.n_filters, |
| 412 | + lc.conv_size[0], |
| 413 | + lc.conv_size[1], |
| 414 | + lc.kernel_size[0], |
| 415 | + lc.kernel_size[1], |
| 416 | + ).cpu() |
| 417 | + input_size = _pair(int(np.sqrt(lc.source.n))) |
| 418 | + if output_channel is None: |
| 419 | + sel_slice = sel_slice[input_channel, ...] |
| 420 | + reshaped = reshape_local_connection_2d_weights( |
| 421 | + sel_slice, lc.n_filters, lc.kernel_size, lc.conv_size, input_size |
| 422 | + ) |
| 423 | + else: |
| 424 | + sel_slice = sel_slice[input_channel, output_channel, ...] |
| 425 | + sel_slice = sel_slice.unsqueeze(0) |
| 426 | + reshaped = reshape_local_connection_2d_weights( |
| 427 | + sel_slice, 1, lc.kernel_size, lc.conv_size, input_size |
| 428 | + ) |
| 429 | + if im == None: |
| 430 | + fig, ax = plt.subplots(figsize=figsize) |
| 431 | + |
| 432 | + im = ax.imshow(reshaped.cpu(), cmap=cmap, vmin=lc.wmin, vmax=lc.wmax) |
| 433 | + div = make_axes_locatable(ax) |
| 434 | + cax = div.append_axes("right", size="5%", pad=0.05) |
| 435 | + |
| 436 | + if lines and output_channel is None: |
| 437 | + for i in range( |
| 438 | + n_sqrt * lc.kernel_size[0], |
| 439 | + n_sqrt * lc.conv_size[0] * lc.kernel_size[0], |
| 440 | + n_sqrt * lc.kernel_size[0], |
| 441 | + ): |
| 442 | + ax.axhline(i - 0.5, color=color, linestyle="--") |
| 443 | + |
| 444 | + for i in range( |
| 445 | + n_sqrt * lc.kernel_size[1], |
| 446 | + n_sqrt * lc.conv_size[1] * lc.kernel_size[1], |
| 447 | + n_sqrt * lc.kernel_size[1], |
| 448 | + ): |
| 449 | + ax.axvline(i - 0.5, color=color, linestyle="--") |
| 450 | + |
| 451 | + ax.set_xticks(()) |
| 452 | + ax.set_yticks(()) |
| 453 | + ax.set_aspect("auto") |
| 454 | + |
| 455 | + plt.colorbar(im, cax=cax) |
| 456 | + fig.tight_layout() |
| 457 | + |
| 458 | + else: |
| 459 | + im.set_data(reshaped.cpu()) |
| 460 | + return im |
| 461 | + |
| 462 | + |
381 | 463 | def plot_assignments( |
382 | 464 | assignments: torch.Tensor, |
383 | 465 | im: Optional[AxesImage] = None, |
|
0 commit comments