Skip to content

Commit bbfb0b3

Browse files
committed
Zerro out
1 parent 3bc3b00 commit bbfb0b3

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

bindsnet/network/topology.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,15 @@ def update(self, **kwargs) -> None:
209209
210210
:param bool learning: Whether to allow connection updates.
211211
"""
212-
pass
212+
mask = kwargs.get("mask", None)
213+
if mask is None:
214+
return
215+
216+
for f in self.pipeline:
217+
if type(f).__name__ != 'Weight':
218+
continue
219+
220+
f.value.masked_fill_(mask, 0)
213221

214222
@abstractmethod
215223
def reset_state_variables(self) -> None:
@@ -484,6 +492,8 @@ def update(self, **kwargs) -> None:
484492
"""
485493
learning = kwargs.get("learning", False)
486494
if learning and not self.manual_update:
495+
super().update(**kwargs)
496+
487497
# Pipeline learning
488498
for f in self.pipeline:
489499
f.update(**kwargs)

examples/dosidicus/network.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from bindsnet.network import Network
44
from bindsnet.network.nodes import LIFNodes
55
from bindsnet.network.topology import MulticompartmentConnection
6-
from bindsnet.network.topology_features import Weight, Bias, Mask
6+
from bindsnet.network.topology_features import Weight, Bias
77
from bindsnet.learning.MCC_learning import PostPre
88

99
network = Network(dt=1.0)
@@ -22,17 +22,22 @@
2222
)
2323
bias = Bias(name='bias_feature', value=torch.rand(5, 5))
2424

25-
mask = torch.tril(torch.ones((5, 5)), diagonal=-1).bool()
26-
2725
connection = MulticompartmentConnection(
2826
source=source_layer,
2927
target=target_layer,
30-
pipeline=[weight, Mask(name='mask', value=mask), bias],
28+
pipeline=[weight, bias],
3129
device='cpu'
3230
)
3331
network.add_connection(connection, source="input", target="output")
3432
print(connection.pipeline[0].value)
3533
network.run(
36-
inputs={"input": torch.bernoulli(torch.rand(250, 5)).byte()}, time=250
34+
inputs={"input": torch.bernoulli(torch.rand(250, 5)).byte()},
35+
time=250,
36+
masks={
37+
('input', 'output'): ~torch.tril(torch.ones((5, 5)), diagonal=-1).bool()
38+
}
3739
)
40+
41+
print(network.layers['input'].v)
42+
print(network.layers['output'].v)
3843
print(connection.pipeline[0].value)

0 commit comments

Comments
 (0)