1010import torch .nn as nn
1111import bindsnet .learning
1212
13-
1413class AbstractFeature (ABC ):
1514 # language=rst
1615 """
@@ -938,3 +937,301 @@ def __init__(
938937 super ().__init__ (name , parent_feature )
939938
940939 self .sub_feature = self .parent .update
940+
941+
942+
943+
944+ class ForwardForwardWeight (AbstractFeature ):
945+ """
946+ Forward-Forward learning weight feature for MulticompartmentConnection.
947+
948+ Implements the Forward-Forward algorithm with surrogate gradients, enabling
949+ layer-wise learning without backpropagation through time. This feature adds:
950+ - Arctangent surrogate gradient computation
951+ - Membrane potential tracking for goodness scores
952+ - Forward-Forward loss computation capabilities
953+
954+ Compatible with the MCC architecture and composable with other features.
955+ """
956+
957+ def __init__ (
958+ self ,
959+ spike_threshold : float = 1.0 ,
960+ alpha : float = 2.0 ,
961+ alpha_loss : float = 0.6 ,
962+ dt : float = 1.0 ,
963+ ** kwargs
964+ ):
965+ """
966+ Initialize Forward-Forward weight feature.
967+
968+ Args:
969+ spike_threshold: Threshold for spike generation and surrogate gradient
970+ alpha: Arctangent surrogate gradient steepness parameter
971+ alpha_loss: Forward-Forward loss threshold parameter
972+ dt: Time step size for membrane potential integration
973+ **kwargs: Additional arguments passed to parent WeightFeature
974+ """
975+ super ().__init__ (** kwargs )
976+
977+ self .spike_threshold = spike_threshold
978+ self .alpha = alpha
979+ self .alpha_loss = alpha_loss
980+ self .dt = dt
981+
982+ # Membrane potential state for goodness computation
983+ self .v_membrane = None
984+
985+ def reset_state (self ):
986+ """Reset membrane potential state."""
987+ self .v_membrane = None
988+
989+ def forward (
990+ self ,
991+ s : torch .Tensor ,
992+ connection : 'MulticompartmentConnection' ,
993+ ** kwargs
994+ ) -> torch .Tensor :
995+ """
996+ Forward pass through weight feature with surrogate gradients.
997+
998+ This method integrates with the MCC forward pass pipeline.
999+
1000+ Args:
1001+ s: Input spikes [batch_size, source_neurons]
1002+ connection: Parent MulticompartmentConnection instance
1003+ **kwargs: Additional arguments from MCC forward pass
1004+
1005+ Returns:
1006+ Weighted synaptic input with surrogate gradient computation
1007+ """
1008+ # Get connection weights (handled by parent MCC)
1009+ w = connection .w
1010+
1011+ # Compute synaptic input: I = s * W
1012+ synaptic_input = torch .mm (s .float (), w )
1013+
1014+ # Track this for goodness score computation if needed
1015+ if hasattr (self , '_track_activity' ) and self ._track_activity :
1016+ self ._last_synaptic_input = synaptic_input .detach ()
1017+
1018+ return synaptic_input
1019+
1020+ def compute_spikes_with_surrogate (
1021+ self ,
1022+ synaptic_input : torch .Tensor ,
1023+ target_layer : 'AbstractPopulation'
1024+ ) -> torch .Tensor :
1025+ """
1026+ Generate spikes with surrogate gradients from synaptic input.
1027+
1028+ This method should be called after the weight forward pass to convert
1029+ synaptic input to spikes using the Forward-Forward surrogate gradient.
1030+
1031+ Args:
1032+ synaptic_input: Weighted input [batch_size, target_neurons]
1033+ target_layer: Target neuron population
1034+
1035+ Returns:
1036+ Spikes with surrogate gradients [batch_size, target_neurons]
1037+ """
1038+ # Initialize or update membrane potential
1039+ if self .v_membrane is None :
1040+ self .v_membrane = torch .zeros_like (synaptic_input )
1041+
1042+ # Integrate synaptic input (simple Euler integration)
1043+ self .v_membrane = self .v_membrane + synaptic_input * self .dt
1044+
1045+ # Generate spikes using arctangent surrogate gradient
1046+ spikes = ArctangentSurrogate .apply (
1047+ self .v_membrane ,
1048+ self .spike_threshold ,
1049+ self .alpha
1050+ )
1051+
1052+ # Optional: Reset membrane potential where spikes occurred
1053+ # self.v_membrane = self.v_membrane * (1 - spikes)
1054+
1055+ return spikes
1056+
1057+ def compute_goodness_score (self , spike_activity : torch .Tensor ) -> torch .Tensor :
1058+ """
1059+ Compute Forward-Forward goodness score from spike activity.
1060+
1061+ Args:
1062+ spike_activity: Spike traces [batch_size, time_steps, neurons] or
1063+ spike counts [batch_size, neurons]
1064+
1065+ Returns:
1066+ Goodness scores [batch_size]
1067+ """
1068+ if spike_activity .dim () == 3 :
1069+ # Sum over time dimension if time traces provided
1070+ spike_counts = torch .sum (spike_activity , dim = 1 ) # [batch_size, neurons]
1071+ else :
1072+ spike_counts = spike_activity # Already summed
1073+
1074+ # Forward-Forward goodness: mean squared spike activity
1075+ goodness = torch .mean (spike_counts ** 2 , dim = 1 ) # [batch_size]
1076+
1077+ return goodness
1078+
1079+ def compute_ff_loss (
1080+ self ,
1081+ goodness_pos : torch .Tensor ,
1082+ goodness_neg : torch .Tensor
1083+ ) -> torch .Tensor :
1084+ """
1085+ Compute Forward-Forward contrastive loss.
1086+
1087+ Loss = log(1 + exp(-g_pos + α)) + log(1 + exp(g_neg - α))
1088+
1089+ Args:
1090+ goodness_pos: Goodness scores for positive samples [batch_size]
1091+ goodness_neg: Goodness scores for negative samples [batch_size]
1092+
1093+ Returns:
1094+ Forward-Forward loss [batch_size]
1095+ """
1096+ # Positive loss: encourage high goodness for true labels
1097+ loss_pos = torch .log (1 + torch .exp (- goodness_pos + self .alpha_loss ))
1098+
1099+ # Negative loss: encourage low goodness for false labels
1100+ loss_neg = torch .log (1 + torch .exp (goodness_neg - self .alpha_loss ))
1101+
1102+ return loss_pos + loss_neg
1103+
1104+ def get_feature_info (self ) -> dict :
1105+ """Get information about this Forward-Forward feature."""
1106+ return {
1107+ 'feature_type' : 'ForwardForwardWeight' ,
1108+ 'spike_threshold' : self .spike_threshold ,
1109+ 'alpha_surrogate' : self .alpha ,
1110+ 'alpha_loss' : self .alpha_loss ,
1111+ 'dt' : self .dt ,
1112+ 'surrogate_function' : 'arctangent' ,
1113+ 'compatible_with' : ['MCC' , 'other_weight_features' ]
1114+ }
1115+
1116+
1117+ class ArctangentSurrogate (torch .autograd .Function ):
1118+ """
1119+ Arctangent surrogate gradient function for Forward-Forward training.
1120+
1121+ Forward pass: spikes = (membrane_potential >= threshold)
1122+ Backward pass: gradient = 1 / (α * |membrane_potential - threshold| + 1)
1123+
1124+ This enables gradient-based learning in spiking neural networks by
1125+ providing a smooth approximation of the non-differentiable spike function.
1126+ """
1127+
1128+ @staticmethod
1129+ def forward (
1130+ ctx ,
1131+ membrane_potential : torch .Tensor ,
1132+ threshold : float ,
1133+ alpha : float
1134+ ) -> torch .Tensor :
1135+ """
1136+ Forward pass: generate binary spikes.
1137+
1138+ Args:
1139+ membrane_potential: Neuron membrane potentials
1140+ threshold: Spike threshold
1141+ alpha: Surrogate gradient steepness parameter
1142+
1143+ Returns:
1144+ Binary spike tensor (0 or 1)
1145+ """
1146+ # Save tensors and parameters for backward pass
1147+ ctx .save_for_backward (membrane_potential )
1148+ ctx .threshold = threshold
1149+ ctx .alpha = alpha
1150+
1151+ # Generate spikes (heaviside step function)
1152+ spikes = (membrane_potential >= threshold ).float ()
1153+
1154+ return spikes
1155+
1156+ @staticmethod
1157+ def backward (
1158+ ctx ,
1159+ grad_output : torch .Tensor
1160+ ) -> Tuple [torch .Tensor , None , None ]:
1161+ """
1162+ Backward pass: compute surrogate gradients.
1163+
1164+ Uses arctangent-based surrogate: 1 / (α * |v - threshold| + 1)
1165+
1166+ Args:
1167+ grad_output: Gradient from subsequent layers
1168+
1169+ Returns:
1170+ Tuple of (grad_membrane_potential, None, None)
1171+ """
1172+ membrane_potential , = ctx .saved_tensors
1173+ threshold = ctx .threshold
1174+ alpha = ctx .alpha
1175+
1176+ # Compute arctangent surrogate gradient
1177+ # grad = 1 / (α * |v - v_th| + 1)
1178+ surrogate_grad = 1.0 / (alpha * torch .abs (membrane_potential - threshold ) + 1.0 )
1179+
1180+ # Apply chain rule with incoming gradients
1181+ grad_membrane_potential = grad_output * surrogate_grad
1182+
1183+ # Return gradients (only for first argument)
1184+ return grad_membrane_potential , None , None
1185+
1186+
1187+ # Add this helper function to create FF-enabled MCC connections
1188+ def create_ff_connection (
1189+ source : 'AbstractPopulation' ,
1190+ target : 'AbstractPopulation' ,
1191+ w : Optional [torch .Tensor ] = None ,
1192+ spike_threshold : float = 1.0 ,
1193+ alpha : float = 2.0 ,
1194+ alpha_loss : float = 0.6 ,
1195+ dt : float = 1.0 ,
1196+ ** mcc_kwargs
1197+ ) -> 'MulticompartmentConnection' :
1198+ """
1199+ Helper function to create MulticompartmentConnection with ForwardForwardWeight feature.
1200+
1201+ Args:
1202+ source: Source neuron population
1203+ target: Target neuron population
1204+ w: Connection weights (if None, will be initialized)
1205+ spike_threshold: FF spike threshold
1206+ alpha: FF surrogate gradient parameter
1207+ alpha_loss: FF loss threshold parameter
1208+ dt: Time step size
1209+ **mcc_kwargs: Additional arguments for MulticompartmentConnection
1210+
1211+ Returns:
1212+ MCC with ForwardForwardWeight feature attached
1213+ """
1214+ from bindsnet .network .topology import MulticompartmentConnection
1215+
1216+ # Create ForwardForwardWeight feature
1217+ ff_feature = ForwardForwardWeight (
1218+ spike_threshold = spike_threshold ,
1219+ alpha = alpha ,
1220+ alpha_loss = alpha_loss ,
1221+ dt = dt
1222+ )
1223+
1224+ # Initialize weights if not provided
1225+ if w is None :
1226+ w = 0.1 * torch .randn (source .n , target .n )
1227+
1228+ # Create MCC with FF feature
1229+ connection = MulticompartmentConnection (
1230+ source = source ,
1231+ target = target ,
1232+ w = w ,
1233+ features = [ff_feature ],
1234+ ** mcc_kwargs
1235+ )
1236+
1237+ return connection
0 commit comments