Skip to content

Commit 2708232

Browse files
author
Kevin Chang
committed
update
1 parent 3f5a586 commit 2708232

11 files changed

+500
-253
lines changed

bindsnet/network/topology_features.py

Lines changed: 213 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,41 @@
1414
import bindsnet.learning
1515

1616
class AbstractFeature(ABC):
17+
def update_weights(
    self,
    connection,
    feature_output: torch.Tensor,
    goodness: torch.Tensor,
    goodness_error: torch.Tensor,
    is_positive: bool,
    learning_rate: float = 0.03,
    alpha: float = 2.0,
    **kwargs,
):
    """
    General Forward-Forward weight update based on the loss:

        Loss = -alpha * delta / (1 + exp(alpha * delta))

    The applied step is proportional to this loss value evaluated at
    ``delta = goodness_error`` for each batch element, scaled by
    ``learning_rate`` and averaged over the batch.

    Args:
        connection: The connection whose weights to update; must expose a
            ``w`` tensor, otherwise the call is a no-op.
        feature_output: Output from feature computation.
            NOTE(review): assumed to be ``[batch, ...]`` so it broadcasts
            against ``grad.unsqueeze(-1)`` — confirm against callers.
        goodness: Computed goodness values (currently unused; kept for
            interface parity with subclasses).
        goodness_error: Goodness error (goodness - target_goodness),
            assumed shape ``[batch]``.
        is_positive: Whether this is a positive example (currently unused).
        learning_rate: Learning rate for the update.
        alpha: Steepness parameter for the loss (default 2.0).
        **kwargs: Additional arguments (ignored).
    """
    # Connections without a weight tensor (e.g. purely functional features)
    # are silently skipped.
    if not hasattr(connection, "w"):
        return
    with torch.no_grad():
        delta = goodness_error
        # -alpha*delta / (1 + exp(alpha*delta)) == -alpha*delta * sigmoid(-alpha*delta).
        # The sigmoid form is numerically stable: it cannot overflow for
        # large positive alpha*delta, unlike a direct torch.exp().
        grad = -alpha * delta * torch.sigmoid(-alpha * delta)
        weight_update = learning_rate * grad.unsqueeze(-1) * feature_output
        connection.w += weight_update.mean(0)  # average over the batch dimension
1752
# language=rst
1853
"""
1954
Features to operate on signals traversing a connection.
@@ -945,6 +980,7 @@ def __init__(
945980

946981
#FF Related
947982
class ArctangentSurrogateFeature(AbstractFeature):
983+
# Inherit update_weights from AbstractFeature for general FF update
948984
"""
949985
Arctangent surrogate gradient feature for spiking neural networks.
950986
@@ -1002,53 +1038,58 @@ def __init__(
10021038

10031039
def compute(self, conn_spikes) -> torch.Tensor:
    """
    Compute forward pass with arctangent surrogate gradients and optional batch normalization.

    Args:
        conn_spikes: Connection spikes tensor [batch_size, source_neurons * target_neurons]

    Returns:
        Target spikes with differentiable surrogate gradients [batch_size, source_neurons * target_neurons]

    Raises:
        RuntimeError: If the feature was never primed (``self.connection`` is None).
    """
    # Ensure connection is available
    if self.connection is None:
        raise RuntimeError("ArctangentSurrogateFeature not properly initialized. Call prime_feature first.")

    # Reshape conn_spikes to [batch_size, source_neurons, target_neurons]
    batch_size = conn_spikes.size(0)
    source_n = self.connection.source.n
    target_n = self.connection.target.n

    # Reshape connection spikes to matrix form
    conn_spikes_matrix = conn_spikes.view(batch_size, source_n, target_n)

    # Compute synaptic input (sum over source neurons)
    synaptic_input = conn_spikes_matrix.sum(dim=1)  # [batch_size, target_neurons]

    # Set the feature value for batch normalization.
    # NOTE(review): self.value holds the PRE-normalization input; anything that
    # reads self.value later (e.g. a goodness sub-feature) therefore sees the
    # un-normalized activity — confirm this is intended.
    self.value = synaptic_input

    # Optionally apply batch normalization if present.
    # batch_normalize() takes no argument: it reads the parent's self.value set above.
    if hasattr(self, 'batch_norm') and self.batch_norm is not None:
        synaptic_input = self.batch_norm.batch_normalize()

    # Step 2: Initialize membrane potential if needed (lazy, first call only)
    if not self.initialized:
        self._initialize_state(synaptic_input)

    # Step 3: Integrate membrane potential (simple Euler step scaled by dt)
    self.v_membrane = self.v_membrane + synaptic_input * self.dt

    # Step 4: Generate spikes with arctangent surrogate gradients
    spikes = self.arctangent_surrogate_spike(
        self.v_membrane,
        self.spike_threshold,
        self.alpha
    )

    # Step 5: Apply reset mechanism (mutates self.v_membrane where spikes fired)
    self._apply_reset(spikes)

    # Step 6: Broadcast spikes back to connection format
    # Each target spike affects all connections to that target
    spikes_broadcast = spikes.unsqueeze(1).expand(batch_size, source_n, target_n)

    # Apply spikes to incoming connections (elementwise gating of the input spikes)
    output_spikes = conn_spikes_matrix * spikes_broadcast

    # Reshape back to original format
    return output_spikes.view(batch_size, source_n * target_n)
10541095

@@ -1252,7 +1293,14 @@ def __init__(
12521293
self.network = network # <-- Store the network
12531294

12541295
def compute(self, sample: torch.Tensor) -> dict:
1255-
# Use self.network if provided, else fall back to parent_feature's connection
1296+
# For Forward-Forward learning, use the normalized feature values instead of raw spikes
1297+
# This allows gradients to flow through the batch normalization
1298+
1299+
# Get the pipeline that contains the features
1300+
# We'll compute goodness from the feature values that have been normalized
1301+
goodness_per_layer = {}
1302+
1303+
# Use parent feature's network if available
12561304
if self.network is not None:
12571305
network = self.network
12581306
else:
@@ -1262,27 +1310,86 @@ def compute(self, sample: torch.Tensor) -> dict:
12621310
raise RuntimeError("Connection must have a valid network attribute.")
12631311
network = self.parent.connection.network
12641312

1265-
network.reset_state_variables()
1266-
inputs = {self.input_layer: sample.unsqueeze(0) if sample.dim() == 1 else sample}
1267-
spike_record = {layer_name: [] for layer_name in network.layers}
1268-
1269-
for t in range(self.time):
1270-
network.run(inputs, time=1)
1313+
# For Forward-Forward, we compute goodness from the sum of squared normalized activities
1314+
# We assume the pipeline has already computed feature values during the forward pass
1315+
total_goodness = torch.tensor(0.0, requires_grad=True)
1316+
1317+
# If we have a parent feature with a value, use that for goodness computation
1318+
if hasattr(self.parent, 'value') and self.parent.value is not None:
1319+
# Compute goodness as sum of squared activities (common in Forward-Forward)
1320+
feature_goodness = (self.parent.value ** 2).sum()
1321+
goodness_per_layer[self.parent.name] = feature_goodness
1322+
total_goodness = total_goodness + feature_goodness
1323+
else:
1324+
# Fallback: run network and compute goodness from layer activities
1325+
network.reset_state_variables()
1326+
inputs = {self.input_layer: sample.unsqueeze(0) if sample.dim() == 1 else sample}
1327+
1328+
for t in range(self.time):
1329+
network.run(inputs, time=1)
1330+
12711331
for layer_name, layer in network.layers.items():
1272-
spike_record[layer_name].append(layer.s.clone().detach())
1332+
if layer_name != self.input_layer: # Skip input layer
1333+
# Convert spikes to float and compute goodness
1334+
layer_activity = layer.s.float()
1335+
if layer_activity.requires_grad:
1336+
goodness = (layer_activity ** 2).sum()
1337+
else:
1338+
goodness = (layer_activity ** 2).sum().requires_grad_(True)
1339+
goodness_per_layer[layer_name] = goodness
1340+
total_goodness = total_goodness + goodness
12731341

1274-
goodness_per_layer = {}
1275-
for layer_name, spikes_list in spike_record.items():
1276-
spikes = torch.stack(spikes_list, dim=0)
1277-
goodness = spikes.sum(dim=0).sum(dim=0)
1278-
goodness_per_layer[layer_name] = goodness
1279-
1280-
total_goodness = sum([v.sum() for v in goodness_per_layer.values()])
12811342
goodness_per_layer["total_goodness"] = total_goodness
1282-
12831343
return goodness_per_layer
12841344

1285-
1345+
class ForwardForwardUpdate(AbstractSubFeature):
    """
    SubFeature performing a Forward-Forward weight update based on the loss:

        Loss = -alpha * delta / (1 + exp(alpha * delta))

    The applied step is proportional to this loss value evaluated at
    ``delta = goodness_error`` per batch element.

    NOTE(review): this re-implements ``AbstractFeature.update_weights``
    verbatim; consider delegating to the parent feature so the two copies
    cannot drift apart.
    """

    def __init__(
        self,
        name: str,
        parent_feature: AbstractFeature,
    ) -> None:
        """
        Args:
            name: Name of this sub-feature.
            parent_feature: The feature this sub-feature is attached to.
        """
        super().__init__(name, parent_feature)
        # Optionally, you could set self.sub_feature = self.update_weights

    def update_weights(
        self,
        connection,
        feature_output: torch.Tensor,
        goodness: torch.Tensor,
        goodness_error: torch.Tensor,
        is_positive: bool,
        learning_rate: float = 0.03,
        alpha: float = 2.0,
        **kwargs,
    ):
        """
        Perform the Forward-Forward weight update.

        Args:
            connection: The connection whose weights to update; must expose a
                ``w`` tensor, otherwise the call is a no-op.
            feature_output: Output from feature computation.
                NOTE(review): assumed ``[batch, ...]`` so it broadcasts
                against ``grad.unsqueeze(-1)`` — confirm against callers.
            goodness: Computed goodness values (currently unused).
            goodness_error: Goodness error (goodness - target_goodness),
                assumed shape ``[batch]``.
            is_positive: Whether this is a positive example (currently unused).
            learning_rate: Learning rate for the update.
            alpha: Steepness parameter for the loss (default 2.0).
            **kwargs: Additional arguments (ignored).
        """
        if not hasattr(connection, "w"):
            return
        with torch.no_grad():
            delta = goodness_error
            # -alpha*delta / (1 + exp(alpha*delta)) == -alpha*delta * sigmoid(-alpha*delta);
            # the sigmoid form cannot overflow for large alpha*delta.
            grad = -alpha * delta * torch.sigmoid(-alpha * delta)
            weight_update = learning_rate * grad.unsqueeze(-1) * feature_output
            connection.w += weight_update.mean(0)  # average over batch
12861393
# Helper function for easy creation
12871394
def create_arctangent_surrogate_connection(
12881395
source,
@@ -1339,3 +1446,78 @@ def create_arctangent_surrogate_connection(
13391446
return connection
13401447

13411448

1449+
class BatchNormalization(AbstractSubFeature):
    """
    SubFeature that batch-normalizes the parent feature's ``value`` using
    PyTorch's ``nn.BatchNorm1d``.

    Normalizes across the batch (first) dimension and, when ``affine`` is
    True, carries learnable gamma (weight) and beta (bias) parameters.
    The underlying module is created lazily: at construction if the parent
    already holds a tensor ``value`` (its last dimension gives the feature
    width), otherwise on ``prime_feature``.
    """

    def __init__(
        self,
        name: str,
        parent_feature: AbstractFeature,
        eps: float = 1e-5,
        affine: bool = True,
        momentum: float = 0.1,
    ) -> None:
        """
        Args:
            name: Name of this sub-feature.
            parent_feature: Feature whose ``value`` tensor is normalized.
            eps: Numerical-stability constant added to the variance.
            affine: Whether to learn per-feature gamma/beta.
            momentum: Running-statistics momentum for BatchNorm1d.
        """
        super().__init__(name, parent_feature)
        self.eps = eps
        self.affine = affine
        self.momentum = momentum
        self.bn = None  # created lazily once the feature width is known
        self._pending_init = True  # cleared by _init_bn

        # Infer the feature width immediately if the parent already has a value.
        if hasattr(self.parent, 'value') and isinstance(self.parent.value, torch.Tensor):
            self._init_bn(self.parent.value.shape[-1])

        self.sub_feature = self.batch_normalize

    def _init_bn(self, num_features):
        # Build the BatchNorm1d module for `num_features` channels.
        self.bn = torch.nn.BatchNorm1d(
            num_features=num_features,
            eps=self.eps,
            affine=self.affine,
            momentum=self.momentum,
        )
        self._pending_init = False

    def prime_feature(self, connection, device, **kwargs):
        """Finish lazy initialization and move the BN module to `device`."""
        # If the BN module could not be sized at construction, try again now.
        if getattr(self, '_pending_init', False):
            if hasattr(self.parent, 'value') and isinstance(self.parent.value, torch.Tensor):
                self._init_bn(self.parent.value.shape[-1])
        if self.bn is not None:
            self.bn.to(device)
        super().prime_feature(connection, device, **kwargs)

    def batch_normalize(self):
        """
        Normalize ``self.parent.value`` and return the result.

        Expects ``value`` shaped [batch_size, num_features]; a 1-D input is
        treated as a single sample and returned with its original 1-D shape.

        Raises:
            RuntimeError: If the BN module was never initialized.
        """
        value = self.parent.value
        if self.bn is None:
            raise RuntimeError("BatchNorm not initialized. Call prime_feature first.")

        # BatchNorm1d needs a batch dimension; remember to undo the unsqueeze
        # so callers get back the shape they passed in.
        squeeze_back = False
        if value.dim() == 1:
            value = value.unsqueeze(0)
            squeeze_back = True

        if value.size(0) == 1 and self.bn.training:
            # A single sample has no batch statistics; temporarily use the
            # running estimates (eval mode). try/finally guarantees training
            # mode is restored even if the forward pass raises.
            self.bn.eval()
            try:
                result = self.bn(value)
            finally:
                self.bn.train()
        else:
            result = self.bn(value)

        return result.squeeze(0) if squeeze_back else result

    @property
    def gamma(self):
        # Learnable scale, or None when not affine / not yet initialized.
        return self.bn.weight if self.bn is not None and self.affine else None

    @property
    def beta(self):
        # Learnable shift, or None when not affine / not yet initialized.
        return self.bn.bias if self.bn is not None and self.affine else None

0 commit comments

Comments (0)