Skip to content

Commit 2c0760d

Browse files
committed
Add batch dimension in case if sparse == True
1 parent 12d8e03 commit 2c0760d

2 files changed

Lines changed: 55 additions & 27 deletions

File tree

bindsnet/models/models.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,7 @@ def __init__(
175175
)
176176

177177
# Connections
178-
if sparse:
179-
w = 0.3 * torch.rand(batch_size, self.n_inpt, self.n_neurons)
180-
else:
181-
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
178+
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
182179
input_exc_conn = MulticompartmentConnection(
183180
source=input_layer,
184181
target=exc_layer,
@@ -192,7 +189,8 @@ def __init__(
192189
reduction=reduction,
193190
nu=nu,
194191
learning_rule=MMCPostPre,
195-
sparse=sparse
192+
sparse=sparse,
193+
batch_size=batch_size
196194
)
197195
]
198196
)

bindsnet/network/topology_features.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
decay: float = 0.0,
3333
parent_feature=None,
3434
sparse: Optional[bool] = False,
35+
batch_size: int = 1,
3536
**kwargs,
3637
) -> None:
3738
# language=rst
@@ -49,6 +50,7 @@ def __init__(
4950
:param decay: Constant multiple to decay weights by on each iteration
5051
:param parent_feature: Parent feature to inherit :code:`value` from
5152
:param sparse: Should :code:`value` parameter be sparse tensor or not
53+
:param batch_size: Mini-batch size.
5254
"""
5355

5456
#### Initialize class variables ####
@@ -64,6 +66,7 @@ def __init__(
6466
self.decay = decay
6567
self.parent_feature = parent_feature
6668
self.sparse = sparse
69+
self.batch_size = batch_size
6770
self.kwargs = kwargs
6871

6972
## Backend ##
@@ -120,12 +123,19 @@ def __init__(
120123
)
121124

122125
self.assert_valid_range()
123-
if value is not None:
124-
self.assert_feature_in_range()
125-
if self.sparse:
126-
self.value = self.value.to_sparse()
127-
assert not getattr(self, 'enforce_polarity', False), \
128-
"enforce_polarity isn't supported for sparse tensors"
126+
if value is None:
127+
return
128+
129+
self.assert_feature_in_range()
130+
if not self.sparse:
131+
return
132+
133+
if len(self.value.shape) == 2:
134+
self.value = self.value.unsqueeze(0).repeat(self.batch_size, 1, 1)
135+
136+
self.value = self.value.to_sparse()
137+
assert not getattr(self, 'enforce_polarity', False), \
138+
"enforce_polarity isn't supported for sparse tensors"
129139

130140
@abstractmethod
131141
def reset_state_variables(self) -> None:
@@ -341,7 +351,8 @@ def __init__(
341351
reduction: Optional[callable] = None,
342352
decay: float = 0.0,
343353
parent_feature=None,
344-
sparse: Optional[bool] = False
354+
sparse: Optional[bool] = False,
355+
batch_size: int = 1
345356
) -> None:
346357
# language=rst
347358
"""
@@ -360,6 +371,7 @@ def __init__(
360371
:param decay: Constant multiple to decay weights by on each iteration
361372
:param parent_feature: Parent feature to inherit :code:`value` from
362373
:param sparse: Should :code:`value` parameter be sparse tensor or not
374+
:param batch_size: Mini-batch size.
363375
"""
364376

365377
### Assertions ###
@@ -373,7 +385,8 @@ def __init__(
373385
reduction=reduction,
374386
decay=decay,
375387
parent_feature=parent_feature,
376-
sparse=sparse
388+
sparse=sparse,
389+
batch_size=batch_size
377390
)
378391

379392
def sparse_bernoulli(self):
@@ -434,14 +447,16 @@ def __init__(
434447
self,
435448
name: str,
436449
value: Union[torch.Tensor, float, int] = None,
437-
sparse: Optional[bool] = False
450+
sparse: Optional[bool] = False,
451+
batch_size: int = 1
438452
) -> None:
439453
# language=rst
440454
"""
441455
Boolean mask which determines whether or not signals are allowed to traverse certain synapses.
442456
:param name: Name of the feature
443457
:param value: Boolean mask. :code:`True` means a signal can pass, :code:`False` means the synapse is impassable
444458
:param sparse: Should :code:`value` parameter be sparse tensor or not
459+
:param batch_size: Mini-batch size.
445460
"""
446461

447462
### Assertions ###
@@ -460,7 +475,8 @@ def __init__(
460475
super().__init__(
461476
name=name,
462477
value=value,
463-
sparse=sparse
478+
sparse=sparse,
479+
batch_size=batch_size
464480
)
465481

466482
def compute(self, conn_spikes) -> torch.Tensor:
@@ -544,7 +560,8 @@ def __init__(
544560
reduction: Optional[callable] = None,
545561
enforce_polarity: Optional[bool] = False,
546562
decay: float = 0.0,
547-
sparse: Optional[bool] = False
563+
sparse: Optional[bool] = False,
564+
batch_size: int = 1
548565
) -> None:
549566
# language=rst
550567
"""
@@ -564,6 +581,7 @@ def __init__(
564581
:param enforce_polarity: Will prevent synapses from changing signs if :code:`True`
565582
:param decay: Constant multiple to decay weights by on each iteration
566583
:param sparse: Should :code:`value` parameter be sparse tensor or not
584+
:param batch_size: Mini-batch size.
567585
"""
568586

569587
self.norm_frequency = norm_frequency
@@ -577,7 +595,8 @@ def __init__(
577595
nu=nu,
578596
reduction=reduction,
579597
decay=decay,
580-
sparse=sparse
598+
sparse=sparse,
599+
batch_size=batch_size
581600
)
582601

583602
def reset_state_variables(self) -> None:
@@ -631,7 +650,8 @@ def __init__(
631650
value: Union[torch.Tensor, float, int] = None,
632651
range: Optional[Sequence[float]] = None,
633652
norm: Optional[Union[torch.Tensor, float, int]] = None,
634-
sparse: Optional[bool] = False
653+
sparse: Optional[bool] = False,
654+
batch_size: int = 1
635655
) -> None:
636656
# language=rst
637657
"""
@@ -642,14 +662,16 @@ def __init__(
642662
:param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
643663
and after the value has been updated by the learning rule (if there is one)
644664
:param sparse: Should :code:`value` parameter be sparse tensor or not
665+
:param batch_size: Mini-batch size.
645666
"""
646667

647668
super().__init__(
648669
name=name,
649670
value=value,
650671
range=[-torch.inf, +torch.inf] if range is None else range,
651672
norm=norm,
652-
sparse=sparse
673+
sparse=sparse,
674+
batch_size=batch_size
653675
)
654676

655677
def reset_state_variables(self) -> None:
@@ -674,17 +696,19 @@ def __init__(
674696
name: str,
675697
value: Union[torch.Tensor, float, int] = None,
676698
range: Optional[Sequence[float]] = None,
677-
sparse: Optional[bool] = False
699+
sparse: Optional[bool] = False,
700+
batch_size: int = 1
678701
) -> None:
679702
# language=rst
680703
"""
681704
Adds scalars to signals
682705
:param name: Name of the feature
683706
:param value: Values to scale signals by
684707
:param sparse: Should :code:`value` parameter be sparse tensor or not
708+
:param batch_size: Mini-batch size.
685709
"""
686710

687-
super().__init__(name=name, value=value, range=range, sparse=sparse)
711+
super().__init__(name=name, value=value, range=range, sparse=sparse, batch_size=batch_size)
688712

689713
def reset_state_variables(self) -> None:
690714
pass
@@ -713,7 +737,8 @@ def __init__(
713737
value: Union[torch.Tensor, float, int] = None,
714738
degrade_function: callable = None,
715739
parent_feature: Optional[AbstractFeature] = None,
716-
sparse: Optional[bool] = False
740+
sparse: Optional[bool] = False,
741+
batch_size: int = 1
717742
) -> None:
718743
# language=rst
719744
"""
@@ -725,10 +750,11 @@ def __init__(
725750
constant to be *subtracted* from the propagating spikes.
726751
:param parent_feature: Parent feature with desired :code:`value` to inherit
727752
:param sparse: Should :code:`value` parameter be sparse tensor or not
753+
:param batch_size: Mini-batch size.
728754
"""
729755

730756
# Note: parent_feature will override value. See abstract constructor
731-
super().__init__(name=name, value=value, parent_feature=parent_feature, sparse=sparse)
757+
super().__init__(name=name, value=value, parent_feature=parent_feature, sparse=sparse, batch_size=batch_size)
732758

733759
self.degrade_function = degrade_function
734760

@@ -747,7 +773,8 @@ def __init__(
747773
ann_values: Union[list, tuple] = None,
748774
const_update_rate: float = 0.1,
749775
const_decay: float = 0.001,
750-
sparse: Optional[bool] = False
776+
sparse: Optional[bool] = False,
777+
batch_size: int = 1
751778
) -> None:
752779
# language=rst
753780
"""
@@ -759,6 +786,7 @@ def __init__(
759786
:param const_update_rate: The mask upatate rate of the ANN decision.
760787
:param const_decay: The spontaneous activation of the synapses.
761788
:param sparse: Should :code:`value` parameter be sparse tensor or not
789+
:param batch_size: Mini-batch size.
762790
"""
763791

764792
# Define the ANN
@@ -794,7 +822,7 @@ def forward(self, x):
794822
self.const_update_rate = const_update_rate
795823
self.const_decay = const_decay
796824

797-
super().__init__(name=name, value=value, sparse=sparse)
825+
super().__init__(name=name, value=value, sparse=sparse, batch_size=batch_size)
798826

799827
def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
800828

@@ -843,7 +871,8 @@ def __init__(
843871
ann_values: Union[list, tuple] = None,
844872
const_update_rate: float = 0.1,
845873
const_decay: float = 0.01,
846-
sparse: Optional[bool] = False
874+
sparse: Optional[bool] = False,
875+
batch_size: int = 1
847876
) -> None:
848877
# language=rst
849878
"""
@@ -855,6 +884,7 @@ def __init__(
855884
:param const_update_rate: The mask upatate rate of the ANN decision.
856885
:param const_decay: The spontaneous activation of the synapses.
857886
:param sparse: Should :code:`value` parameter be sparse tensor or not
887+
:param batch_size: Mini-batch size.
858888
"""
859889

860890
# Define the ANN
@@ -890,7 +920,7 @@ def forward(self, x):
890920
self.const_update_rate = const_update_rate
891921
self.const_decay = const_decay
892922

893-
super().__init__(name=name, value=value, sparse=sparse)
923+
super().__init__(name=name, value=value, sparse=sparse, batch_size=batch_size)
894924

895925
def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
896926

0 commit comments

Comments
 (0)