Skip to content

Commit 7a2fdb9

Browse files
committed
fix
1 parent e13974b commit 7a2fdb9

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

bindsnet/network/topology.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
441441

442442
# Run through pipeline
443443
for f in self.pipeline:
444-
if type(f).__name__ == 'Weight' and self.mask:
444+
if type(f).__name__ == 'Weight' and self.mask is not None:
445445
f.value.masked_fill_(self.mask, 0)
446446
conn_spikes = f.compute(conn_spikes)
447447

examples/dosidicus/network.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,13 @@
2626
source=source_layer,
2727
target=target_layer,
2828
pipeline=[weight, bias],
29+
mask=~torch.tril(torch.ones((5, 5)), diagonal=-1).bool(),
2930
device='cpu'
3031
)
3132
network.add_connection(connection, source="input", target="output")
3233
print(connection.pipeline[0].value)
3334
network.run(
3435
inputs={"input": torch.bernoulli(torch.rand(250, 5)).byte()},
35-
time=250,
36-
masks={
37-
('input', 'output'): ~torch.tril(torch.ones((5, 5)), diagonal=-1).bool()
38-
}
36+
time=250
3937
)
4038
print(connection.pipeline[0].value)

0 commit comments

Comments
 (0)