Skip to content

Commit 6a3b80e

Browse files
committed
Fix the maxPool batch size issue
1 parent a9226d9 commit 6a3b80e

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

bindsnet/network/topology.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ def reset_state_variables(self) -> None:
813813
"""
814814
super().reset_state_variables()
815815

816-
self.firing_rates = torch.zeros(self.source.s.shape)
816+
self.firing_rates = torch.zeros(self.source.batch_size,*(self.source.s.shape[1:]))
817817

818818

819819
class MaxPool2dConnection(AbstractConnection):
@@ -901,7 +901,7 @@ def reset_state_variables(self) -> None:
901901
"""
902902
super().reset_state_variables()
903903

904-
self.firing_rates = torch.zeros(self.source.s.shape)
904+
self.firing_rates = torch.zeros(self.source.batch_size,*(self.source.s.shape[1:]))
905905

906906

907907
class MaxPoo3dConnection(AbstractConnection):
@@ -989,7 +989,7 @@ def reset_state_variables(self) -> None:
989989
"""
990990
super().reset_state_variables()
991991

992-
self.firing_rates = torch.zeros(self.source.s.shape)
992+
self.firing_rates = torch.zeros(self.source.batch_size,*(self.source.s.shape[1:]))
993993

994994

995995
class LocalConnection(AbstractConnection):

0 commit comments

Comments
 (0)