File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -813,7 +813,7 @@ def reset_state_variables(self) -> None:
813813 """
814814 super ().reset_state_variables ()
815815
816- self .firing_rates = torch .zeros (self .source .s .shape )
816+ self .firing_rates = torch .zeros (self .source .batch_size , * ( self . source . s .shape [ 1 :]) )
817817
818818
819819class MaxPool2dConnection (AbstractConnection ):
@@ -901,7 +901,7 @@ def reset_state_variables(self) -> None:
901901 """
902902 super ().reset_state_variables ()
903903
904- self .firing_rates = torch .zeros (self .source .s .shape )
904+ self .firing_rates = torch .zeros (self .source .batch_size , * ( self . source . s .shape [ 1 :]) )
905905
906906
907907class MaxPoo3dConnection (AbstractConnection ):
@@ -989,7 +989,7 @@ def reset_state_variables(self) -> None:
989989 """
990990 super ().reset_state_variables ()
991991
992- self .firing_rates = torch .zeros (self .source .s .shape )
992+ self .firing_rates = torch .zeros (self .source .batch_size , * ( self . source . s .shape [ 1 :]) )
993993
994994
995995class LocalConnection (AbstractConnection ):
You can’t perform that action at this time.
0 commit comments