99from mpl_toolkits .axes_grid1 import make_axes_locatable
1010from torch .nn .modules .utils import _pair
1111
12- from bindsnet .utils import reshape_conv2d_weights , reshape_locally_connected_weights , reshape_local_connection_2d_weights
12+ from bindsnet .utils import (
13+ reshape_conv2d_weights ,
14+ reshape_locally_connected_weights ,
15+ reshape_local_connection_2d_weights ,
16+ )
1317
1418plt .ion ()
1519
@@ -377,15 +381,17 @@ def plot_locally_connected_weights(
377381
378382 return im
379383
380- def plot_local_connection_2d_weights (lc : object ,
384+
385+ def plot_local_connection_2d_weights (
386+ lc : object ,
381387 input_channel : int = 0 ,
382388 output_channel : int = None ,
383389 im : Optional [AxesImage ] = None ,
384390 lines : bool = True ,
385391 figsize : Tuple [int , int ] = (5 , 5 ),
386392 cmap : str = "hot_r" ,
387- color : str = 'r' ,
388- ) -> AxesImage :
393+ color : str = "r" ,
394+ ) -> AxesImage :
389395 # language=rst
390396 """
391397 Plot a connection weight matrix of a :code:`Connection` with `locally connected
@@ -400,23 +406,34 @@ def plot_local_connection_2d_weights(lc : object,
400406 """
401407
402408 n_sqrt = int (np .ceil (np .sqrt (lc .n_filters )))
403- sel_slice = lc .w .view (lc .in_channels , lc .n_filters , lc .conv_size [0 ], lc .conv_size [1 ], lc .kernel_size [0 ], lc .kernel_size [1 ]).cpu ()
409+ sel_slice = lc .w .view (
410+ lc .in_channels ,
411+ lc .n_filters ,
412+ lc .conv_size [0 ],
413+ lc .conv_size [1 ],
414+ lc .kernel_size [0 ],
415+ lc .kernel_size [1 ],
416+ ).cpu ()
404417 input_size = _pair (int (np .sqrt (lc .source .n )))
405418 if output_channel is None :
406419 sel_slice = sel_slice [input_channel , ...]
407- reshaped = reshape_local_connection_2d_weights (sel_slice , lc .n_filters , lc .kernel_size , lc .conv_size , input_size )
420+ reshaped = reshape_local_connection_2d_weights (
421+ sel_slice , lc .n_filters , lc .kernel_size , lc .conv_size , input_size
422+ )
408423 else :
409424 sel_slice = sel_slice [input_channel , output_channel , ...]
410425 sel_slice = sel_slice .unsqueeze (0 )
411- reshaped = reshape_local_connection_2d_weights (sel_slice , 1 , lc .kernel_size , lc .conv_size , input_size )
426+ reshaped = reshape_local_connection_2d_weights (
427+ sel_slice , 1 , lc .kernel_size , lc .conv_size , input_size
428+ )
412429 if im == None :
413430 fig , ax = plt .subplots (figsize = figsize )
414431
415432 im = ax .imshow (reshaped .cpu (), cmap = cmap , vmin = lc .wmin , vmax = lc .wmax )
416433 div = make_axes_locatable (ax )
417434 cax = div .append_axes ("right" , size = "5%" , pad = 0.05 )
418435
419- if lines and output_channel is None :
436+ if lines and output_channel is None :
420437 for i in range (
421438 n_sqrt * lc .kernel_size [0 ],
422439 n_sqrt * lc .conv_size [0 ] * lc .kernel_size [0 ],
@@ -430,7 +447,7 @@ def plot_local_connection_2d_weights(lc : object,
430447 n_sqrt * lc .kernel_size [1 ],
431448 ):
432449 ax .axvline (i - 0.5 , color = color , linestyle = "--" )
433-
450+
434451 ax .set_xticks (())
435452 ax .set_yticks (())
436453 ax .set_aspect ("auto" )
@@ -441,7 +458,7 @@ def plot_local_connection_2d_weights(lc : object,
441458 else :
442459 im .set_data (reshaped .cpu ())
443460 return im
444-
461+
445462
446463def plot_assignments (
447464 assignments : torch .Tensor ,
@@ -825,5 +842,3 @@ def plot_voltages(
825842 plt .tight_layout ()
826843
827844 return ims , axes
828-
829-
0 commit comments