@@ -32,6 +32,7 @@ def __init__(
3232 decay : float = 0.0 ,
3333 parent_feature = None ,
3434 sparse : Optional [bool ] = False ,
35+ batch_size : int = 1 ,
3536 ** kwargs ,
3637 ) -> None :
3738 # language=rst
@@ -49,6 +50,7 @@ def __init__(
4950 :param decay: Constant multiple to decay weights by on each iteration
5051 :param parent_feature: Parent feature to inherit :code:`value` from
5152 :param sparse: Should :code:`value` parameter be sparse tensor or not
53+ :param batch_size: Mini-batch size.
5254 """
5355
5456 #### Initialize class variables ####
@@ -64,6 +66,7 @@ def __init__(
6466 self .decay = decay
6567 self .parent_feature = parent_feature
6668 self .sparse = sparse
69+ self .batch_size = batch_size
6770 self .kwargs = kwargs
6871
6972 ## Backend ##
@@ -120,12 +123,19 @@ def __init__(
120123 )
121124
122125 self .assert_valid_range ()
123- if value is not None :
124- self .assert_feature_in_range ()
125- if self .sparse :
126- self .value = self .value .to_sparse ()
127- assert not getattr (self , 'enforce_polarity' , False ), \
128- "enforce_polarity isn't supported for sparse tensors"
126+ if value is None :
127+ return
128+
129+ self .assert_feature_in_range ()
130+ if not self .sparse :
131+ return
132+
133+ if len (self .value .shape ) == 2 :
134+ self .value = self .value .unsqueeze (0 ).repeat (self .batch_size , 1 , 1 )
135+
136+ self .value = self .value .to_sparse ()
137+ assert not getattr (self , 'enforce_polarity' , False ), \
138+ "enforce_polarity isn't supported for sparse tensors"
129139
130140 @abstractmethod
131141 def reset_state_variables (self ) -> None :
@@ -341,7 +351,8 @@ def __init__(
341351 reduction : Optional [callable ] = None ,
342352 decay : float = 0.0 ,
343353 parent_feature = None ,
344- sparse : Optional [bool ] = False
354+ sparse : Optional [bool ] = False ,
355+ batch_size : int = 1
345356 ) -> None :
346357 # language=rst
347358 """
@@ -360,6 +371,7 @@ def __init__(
360371 :param decay: Constant multiple to decay weights by on each iteration
361372 :param parent_feature: Parent feature to inherit :code:`value` from
362373 :param sparse: Should :code:`value` parameter be sparse tensor or not
374+ :param batch_size: Mini-batch size.
363375 """
364376
365377 ### Assertions ###
@@ -373,7 +385,8 @@ def __init__(
373385 reduction = reduction ,
374386 decay = decay ,
375387 parent_feature = parent_feature ,
376- sparse = sparse
388+ sparse = sparse ,
389+ batch_size = batch_size
377390 )
378391
379392 def sparse_bernoulli (self ):
@@ -434,14 +447,16 @@ def __init__(
434447 self ,
435448 name : str ,
436449 value : Union [torch .Tensor , float , int ] = None ,
437- sparse : Optional [bool ] = False
450+ sparse : Optional [bool ] = False ,
451+ batch_size : int = 1
438452 ) -> None :
439453 # language=rst
440454 """
441455 Boolean mask which determines whether or not signals are allowed to traverse certain synapses.
442456 :param name: Name of the feature
443457 :param value: Boolean mask. :code:`True` means a signal can pass, :code:`False` means the synapse is impassable
444458 :param sparse: Should :code:`value` parameter be sparse tensor or not
459+ :param batch_size: Mini-batch size.
445460 """
446461
447462 ### Assertions ###
@@ -460,7 +475,8 @@ def __init__(
460475 super ().__init__ (
461476 name = name ,
462477 value = value ,
463- sparse = sparse
478+ sparse = sparse ,
479+ batch_size = batch_size
464480 )
465481
466482 def compute (self , conn_spikes ) -> torch .Tensor :
@@ -544,7 +560,8 @@ def __init__(
544560 reduction : Optional [callable ] = None ,
545561 enforce_polarity : Optional [bool ] = False ,
546562 decay : float = 0.0 ,
547- sparse : Optional [bool ] = False
563+ sparse : Optional [bool ] = False ,
564+ batch_size : int = 1
548565 ) -> None :
549566 # language=rst
550567 """
@@ -564,6 +581,7 @@ def __init__(
564581 :param enforce_polarity: Will prevent synapses from changing signs if :code:`True`
565582 :param decay: Constant multiple to decay weights by on each iteration
566583 :param sparse: Should :code:`value` parameter be sparse tensor or not
584+ :param batch_size: Mini-batch size.
567585 """
568586
569587 self .norm_frequency = norm_frequency
@@ -577,7 +595,8 @@ def __init__(
577595 nu = nu ,
578596 reduction = reduction ,
579597 decay = decay ,
580- sparse = sparse
598+ sparse = sparse ,
599+ batch_size = batch_size
581600 )
582601
583602 def reset_state_variables (self ) -> None :
@@ -631,7 +650,8 @@ def __init__(
631650 value : Union [torch .Tensor , float , int ] = None ,
632651 range : Optional [Sequence [float ]] = None ,
633652 norm : Optional [Union [torch .Tensor , float , int ]] = None ,
634- sparse : Optional [bool ] = False
653+ sparse : Optional [bool ] = False ,
654+ batch_size : int = 1
635655 ) -> None :
636656 # language=rst
637657 """
@@ -642,14 +662,16 @@ def __init__(
642662 :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
643663 and after the value has been updated by the learning rule (if there is one)
644664 :param sparse: Should :code:`value` parameter be sparse tensor or not
665+ :param batch_size: Mini-batch size.
645666 """
646667
647668 super ().__init__ (
648669 name = name ,
649670 value = value ,
650671 range = [- torch .inf , + torch .inf ] if range is None else range ,
651672 norm = norm ,
652- sparse = sparse
673+ sparse = sparse ,
674+ batch_size = batch_size
653675 )
654676
655677 def reset_state_variables (self ) -> None :
@@ -674,17 +696,19 @@ def __init__(
674696 name : str ,
675697 value : Union [torch .Tensor , float , int ] = None ,
676698 range : Optional [Sequence [float ]] = None ,
677- sparse : Optional [bool ] = False
699+ sparse : Optional [bool ] = False ,
700+ batch_size : int = 1
678701 ) -> None :
679702 # language=rst
680703 """
681704 Adds scalars to signals
682705 :param name: Name of the feature
683706 :param value: Values to scale signals by
684707 :param sparse: Should :code:`value` parameter be sparse tensor or not
708+ :param batch_size: Mini-batch size.
685709 """
686710
687- super ().__init__ (name = name , value = value , range = range , sparse = sparse )
711+ super ().__init__ (name = name , value = value , range = range , sparse = sparse , batch_size = batch_size )
688712
689713 def reset_state_variables (self ) -> None :
690714 pass
@@ -713,7 +737,8 @@ def __init__(
713737 value : Union [torch .Tensor , float , int ] = None ,
714738 degrade_function : callable = None ,
715739 parent_feature : Optional [AbstractFeature ] = None ,
716- sparse : Optional [bool ] = False
740+ sparse : Optional [bool ] = False ,
741+ batch_size : int = 1
717742 ) -> None :
718743 # language=rst
719744 """
@@ -725,10 +750,11 @@ def __init__(
725750 constant to be *subtracted* from the propagating spikes.
726751 :param parent_feature: Parent feature with desired :code:`value` to inherit
727752 :param sparse: Should :code:`value` parameter be sparse tensor or not
753+ :param batch_size: Mini-batch size.
728754 """
729755
730756 # Note: parent_feature will override value. See abstract constructor
731- super ().__init__ (name = name , value = value , parent_feature = parent_feature , sparse = sparse )
757+ super ().__init__ (name = name , value = value , parent_feature = parent_feature , sparse = sparse , batch_size = batch_size )
732758
733759 self .degrade_function = degrade_function
734760
@@ -747,7 +773,8 @@ def __init__(
747773 ann_values : Union [list , tuple ] = None ,
748774 const_update_rate : float = 0.1 ,
749775 const_decay : float = 0.001 ,
750- sparse : Optional [bool ] = False
776+ sparse : Optional [bool ] = False ,
777+ batch_size : int = 1
751778 ) -> None :
752779 # language=rst
753780 """
@@ -759,6 +786,7 @@ def __init__(
759786 :param const_update_rate: The mask upatate rate of the ANN decision.
760787 :param const_decay: The spontaneous activation of the synapses.
761788 :param sparse: Should :code:`value` parameter be sparse tensor or not
789+ :param batch_size: Mini-batch size.
762790 """
763791
764792 # Define the ANN
@@ -794,7 +822,7 @@ def forward(self, x):
794822 self .const_update_rate = const_update_rate
795823 self .const_decay = const_decay
796824
797- super ().__init__ (name = name , value = value , sparse = sparse )
825+ super ().__init__ (name = name , value = value , sparse = sparse , batch_size = batch_size )
798826
799827 def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
800828
@@ -843,7 +871,8 @@ def __init__(
843871 ann_values : Union [list , tuple ] = None ,
844872 const_update_rate : float = 0.1 ,
845873 const_decay : float = 0.01 ,
846- sparse : Optional [bool ] = False
874+ sparse : Optional [bool ] = False ,
875+ batch_size : int = 1
847876 ) -> None :
848877 # language=rst
849878 """
@@ -855,6 +884,7 @@ def __init__(
855884 :param const_update_rate: The mask upatate rate of the ANN decision.
856885 :param const_decay: The spontaneous activation of the synapses.
857886 :param sparse: Should :code:`value` parameter be sparse tensor or not
887+ :param batch_size: Mini-batch size.
858888 """
859889
860890 # Define the ANN
@@ -890,7 +920,7 @@ def forward(self, x):
890920 self .const_update_rate = const_update_rate
891921 self .const_decay = const_decay
892922
893- super ().__init__ (name = name , value = value , sparse = sparse )
923+ super ().__init__ (name = name , value = value , sparse = sparse , batch_size = batch_size )
894924
895925 def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
896926
0 commit comments