1414import bindsnet .learning
1515
1616class AbstractFeature (ABC ):
17+ def update_weights (
18+ self ,
19+ connection ,
20+ feature_output : torch .Tensor ,
21+ goodness : torch .Tensor ,
22+ goodness_error : torch .Tensor ,
23+ is_positive : bool ,
24+ learning_rate : float = 0.03 ,
25+ alpha : float = 2.0 ,
26+ ** kwargs
27+ ):
28+ """
29+ General Forward-Forward weight update using the loss:
30+ Loss = -alpha * delta / (1 + exp(alpha * delta))
31+ The update is proportional to the gradient of this loss w.r.t. delta.
32+ Args:
33+ connection: The connection whose weights to update
34+ feature_output: Output from feature computation
35+ goodness: Computed goodness values
36+ goodness_error: Goodness error (goodness - target_goodness)
37+ is_positive: Whether this is a positive example
38+ learning_rate: Learning rate for update
39+ alpha: Steepness parameter for loss (default 2.0)
40+ **kwargs: Additional arguments
41+ """
42+ if not hasattr (connection , 'w' ):
43+ return
44+ with torch .no_grad ():
45+ delta = goodness_error
46+ exp_term = torch .exp (alpha * delta )
47+ denom = (1 + exp_term )
48+ numer = - alpha * delta
49+ grad = numer / denom
50+ weight_update = learning_rate * grad .unsqueeze (- 1 ) * feature_output
51+ connection .w += weight_update .mean (0 ) # Average over batch
1752 # language=rst
1853 """
1954 Features to operate on signals traversing a connection.
@@ -945,6 +980,7 @@ def __init__(
945980
946981#FF Related
947982class ArctangentSurrogateFeature (AbstractFeature ):
983+ # Inherit update_weights from AbstractFeature for general FF update
948984 """
949985 Arctangent surrogate gradient feature for spiking neural networks.
950986
@@ -1002,53 +1038,58 @@ def __init__(
10021038
    def compute(self, conn_spikes) -> torch.Tensor:
        """
        Compute forward pass with arctangent surrogate gradients and optional
        batch normalization of the per-target synaptic input.

        Args:
            conn_spikes: Connection spikes tensor
                [batch_size, source_neurons * target_neurons]

        Returns:
            Target spikes with differentiable surrogate gradients
            [batch_size, source_neurons * target_neurons]

        Raises:
            RuntimeError: if called before ``prime_feature`` set
                ``self.connection``.
        """
        # Ensure connection is available
        if self.connection is None:
            raise RuntimeError("ArctangentSurrogateFeature not properly initialized. Call prime_feature first.")

        # Reshape conn_spikes to [batch_size, source_neurons, target_neurons]
        batch_size = conn_spikes.size(0)
        source_n = self.connection.source.n
        target_n = self.connection.target.n

        # Reshape connection spikes to matrix form
        conn_spikes_matrix = conn_spikes.view(batch_size, source_n, target_n)

        # Compute synaptic input (sum over source neurons)
        synaptic_input = conn_spikes_matrix.sum(dim=1)  # [batch_size, target_neurons]

        # Expose the raw synaptic input so sub-features (e.g. a
        # BatchNormalization sub-feature) can read it through self.value.
        self.value = synaptic_input

        # Optionally apply batch normalization if present.
        # NOTE(review): batch_normalize() takes no argument and presumably
        # reads self.value set just above — confirm against the sub-feature.
        if hasattr(self, 'batch_norm') and self.batch_norm is not None:
            synaptic_input = self.batch_norm.batch_normalize()

        # Step 2: Initialize membrane potential if needed (lazy, first call only)
        if not self.initialized:
            self._initialize_state(synaptic_input)

        # Step 3: Integrate membrane potential (Euler step scaled by dt)
        self.v_membrane = self.v_membrane + synaptic_input * self.dt

        # Step 4: Generate spikes with arctangent surrogate gradients
        spikes = self.arctangent_surrogate_spike(
            self.v_membrane,
            self.spike_threshold,
            self.alpha
        )

        # Step 5: Apply reset mechanism (helper resets membrane where spikes fired)
        self._apply_reset(spikes)

        # Step 6: Broadcast spikes back to connection format
        # Each target spike affects all connections to that target
        spikes_broadcast = spikes.unsqueeze(1).expand(batch_size, source_n, target_n)

        # Apply spikes to incoming connections (gate input spikes by target spikes)
        output_spikes = conn_spikes_matrix * spikes_broadcast

        # Reshape back to original format
        return output_spikes.view(batch_size, source_n * target_n)
10541095
@@ -1252,7 +1293,14 @@ def __init__(
12521293 self .network = network # <-- Store the network
12531294
12541295 def compute (self , sample : torch .Tensor ) -> dict :
1255- # Use self.network if provided, else fall back to parent_feature's connection
1296+ # For Forward-Forward learning, use the normalized feature values instead of raw spikes
1297+ # This allows gradients to flow through the batch normalization
1298+
1299+ # Get the pipeline that contains the features
1300+ # We'll compute goodness from the feature values that have been normalized
1301+ goodness_per_layer = {}
1302+
1303+ # Use parent feature's network if available
12561304 if self .network is not None :
12571305 network = self .network
12581306 else :
@@ -1262,27 +1310,86 @@ def compute(self, sample: torch.Tensor) -> dict:
12621310 raise RuntimeError ("Connection must have a valid network attribute." )
12631311 network = self .parent .connection .network
12641312
1265- network .reset_state_variables ()
1266- inputs = {self .input_layer : sample .unsqueeze (0 ) if sample .dim () == 1 else sample }
1267- spike_record = {layer_name : [] for layer_name in network .layers }
1268-
1269- for t in range (self .time ):
1270- network .run (inputs , time = 1 )
1313+ # For Forward-Forward, we compute goodness from the sum of squared normalized activities
1314+ # We assume the pipeline has already computed feature values during the forward pass
1315+ total_goodness = torch .tensor (0.0 , requires_grad = True )
1316+
1317+ # If we have a parent feature with a value, use that for goodness computation
1318+ if hasattr (self .parent , 'value' ) and self .parent .value is not None :
1319+ # Compute goodness as sum of squared activities (common in Forward-Forward)
1320+ feature_goodness = (self .parent .value ** 2 ).sum ()
1321+ goodness_per_layer [self .parent .name ] = feature_goodness
1322+ total_goodness = total_goodness + feature_goodness
1323+ else :
1324+ # Fallback: run network and compute goodness from layer activities
1325+ network .reset_state_variables ()
1326+ inputs = {self .input_layer : sample .unsqueeze (0 ) if sample .dim () == 1 else sample }
1327+
1328+ for t in range (self .time ):
1329+ network .run (inputs , time = 1 )
1330+
12711331 for layer_name , layer in network .layers .items ():
1272- spike_record [layer_name ].append (layer .s .clone ().detach ())
1332+ if layer_name != self .input_layer : # Skip input layer
1333+ # Convert spikes to float and compute goodness
1334+ layer_activity = layer .s .float ()
1335+ if layer_activity .requires_grad :
1336+ goodness = (layer_activity ** 2 ).sum ()
1337+ else :
1338+ goodness = (layer_activity ** 2 ).sum ().requires_grad_ (True )
1339+ goodness_per_layer [layer_name ] = goodness
1340+ total_goodness = total_goodness + goodness
12731341
1274- goodness_per_layer = {}
1275- for layer_name , spikes_list in spike_record .items ():
1276- spikes = torch .stack (spikes_list , dim = 0 )
1277- goodness = spikes .sum (dim = 0 ).sum (dim = 0 )
1278- goodness_per_layer [layer_name ] = goodness
1279-
1280- total_goodness = sum ([v .sum () for v in goodness_per_layer .values ()])
12811342 goodness_per_layer ["total_goodness" ] = total_goodness
1282-
12831343 return goodness_per_layer
12841344
1285-
class ForwardForwardUpdate(AbstractSubFeature):
    """
    Sub-feature implementing the Forward-Forward weight update for the loss:

        Loss = -alpha * delta / (1 + exp(alpha * delta))

    The applied update is proportional to this quantity evaluated at
    ``delta = goodness_error``.
    """

    def __init__(
        self,
        name: str,
        parent_feature: AbstractFeature,
    ) -> None:
        super().__init__(name, parent_feature)
        # self.sub_feature is intentionally left unset; callers invoke
        # update_weights directly. (Optionally it could be bound to
        # self.update_weights.)

    def update_weights(
        self,
        connection,
        feature_output: torch.Tensor,
        goodness: torch.Tensor,
        goodness_error: torch.Tensor,
        is_positive: bool,
        learning_rate: float = 0.03,
        alpha: float = 2.0,
        **kwargs,
    ):
        """
        Perform the Forward-Forward weight update on ``connection.w``.

        Args:
            connection: Connection whose weights to update; connections
                without a ``w`` tensor are silently skipped.
            feature_output: Output from feature computation,
                shape [batch, ...feature dims].
            goodness: Computed goodness values (accepted for interface
                compatibility; not read by this rule).
            goodness_error: Goodness error (goodness - target_goodness).
            is_positive: Whether this is a positive example (accepted for
                interface compatibility; not read by this rule).
            learning_rate: Step size for the update.
            alpha: Steepness parameter for the loss (default 2.0).
            **kwargs: Ignored.
        """
        if not hasattr(connection, 'w'):
            return

        with torch.no_grad():
            # Per-sample magnitude: -alpha*delta / (1 + exp(alpha*delta)).
            scaled = alpha * goodness_error
            step = -scaled / (1 + torch.exp(scaled))
            # Scale the feature output per sample, then average over batch.
            per_sample = learning_rate * step.unsqueeze(-1) * feature_output
            connection.w += per_sample.mean(0)
12861393# Helper function for easy creation
12871394def create_arctangent_surrogate_connection (
12881395 source ,
@@ -1339,3 +1446,78 @@ def create_arctangent_surrogate_connection(
13391446 return connection
13401447
13411448
class BatchNormalization(AbstractSubFeature):
    """
    SubFeature to perform batch normalization on the parent feature's value
    using PyTorch's ``nn.BatchNorm1d``.

    Normalizes across the batch (first) dimension and, when ``affine`` is
    True, carries learnable gamma (weight) and beta (bias) parameters.
    """

    def __init__(
        self,
        name: str,
        parent_feature: AbstractFeature,
        eps: float = 1e-5,
        affine: bool = True,
        momentum: float = 0.1,
    ) -> None:
        """
        Args:
            name: Name of this sub-feature.
            parent_feature: Feature whose ``value`` tensor is normalized.
            eps: Numerical-stability constant added to the variance.
            affine: Whether the BatchNorm layer has learnable gamma/beta.
            momentum: Running-statistics momentum passed to BatchNorm1d.
        """
        super().__init__(name, parent_feature)
        self.eps = eps
        self.affine = affine
        self.momentum = momentum
        self.bn = None  # Will be initialized after parent_feature is primed
        # Always define the flag so later getattr checks are unnecessary;
        # _init_bn flips it to False once the layer exists.
        self._pending_init = True

        # Try to infer feature size now if the parent already has a value.
        if hasattr(self.parent, 'value') and isinstance(self.parent.value, torch.Tensor):
            self._init_bn(self.parent.value.shape[-1])

        self.sub_feature = self.batch_normalize

    def _init_bn(self, num_features):
        # Build the BatchNorm1d layer once the feature width is known.
        self.bn = torch.nn.BatchNorm1d(
            num_features=num_features,
            eps=self.eps,
            affine=self.affine,
            momentum=self.momentum,
        )
        self._pending_init = False

    def prime_feature(self, connection, device, **kwargs):
        """Finish deferred BatchNorm construction and move it to ``device``."""
        # If not already initialized, do so now (parent may have a value by now).
        if self._pending_init:
            if hasattr(self.parent, 'value') and isinstance(self.parent.value, torch.Tensor):
                self._init_bn(self.parent.value.shape[-1])
        if self.bn is not None:
            self.bn.to(device)
        super().prime_feature(connection, device, **kwargs)

    def batch_normalize(self):
        """
        Normalize ``self.parent.value`` across the batch dimension.

        Returns:
            The normalized tensor, shape [batch_size, num_features].

        Raises:
            RuntimeError: if the BatchNorm layer has not been initialized.
        """
        if self.bn is None:
            raise RuntimeError("BatchNorm not initialized. Call prime_feature first.")
        value = self.parent.value
        # BatchNorm1d needs a 2-D [batch, features] input.
        if value.dim() == 1:
            value = value.unsqueeze(0)

        # BatchNorm1d cannot compute batch statistics from a single sample in
        # training mode, so temporarily fall back to running statistics.
        # try/finally guarantees training mode is restored even if the forward
        # pass raises (the original re-check of the saved flag was dead code:
        # it was always True inside this branch).
        if value.size(0) == 1 and self.bn.training:
            self.bn.eval()
            try:
                return self.bn(value)
            finally:
                self.bn.train()
        return self.bn(value)

    @property
    def gamma(self):
        # Learnable scale, or None when affine=False or bn is uninitialized.
        return self.bn.weight if self.bn is not None and self.affine else None

    @property
    def beta(self):
        # Learnable shift, or None when affine=False or bn is uninitialized.
        return self.bn.bias if self.bn is not None and self.affine else None
0 commit comments