
Commit b3aec14

feat: Implement hybrid plasticity components and integrate with SNNDT
This commit introduces the core components for a hybrid learning setup combining backpropagation with a three-factor local learning rule:

1. `CustomLIFCell` (`novel_phases/phase-2/custom_lif.py`): a LIF neuron cell that maintains an eligibility trace based on pre- and post-synaptic activity.
2. `apply_three_factor_update` (`novel_phases/phase-2/three_factor_updater.py`): a function that applies weight updates based on an eligibility trace, a third factor (e.g., return-to-go), and a local learning rate, with stability mechanisms such as clipping and normalization.
3. `SNNDecisionTransformer` modifications (`src/models/snn_dt.py`): a hook-based mechanism that captures the pre-synaptic input and post-synaptic output (logits) around the `predict_action` layer, providing the tensors the training loop needs to compute and apply the three-factor update to the action head.
4. Conceptual training and test scripts (`novel_phases/phase-2/main_training.py`, `novel_phases/phase-2/test_snn_dt_plasticity.py`), though running `test_snn_dt_plasticity.py` was blocked by persistent environment issues.

The changes to `SNNDecisionTransformer` are controlled by an `enable_action_head_plasticity` flag, so the new local update rule can be applied optionally alongside standard backpropagation.
1 parent 88c3c30 commit b3aec14
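The diff for `novel_phases/phase-2/three_factor_updater.py` is not expanded on this page, so here is a minimal sketch of the rule as the commit message describes it (local learning rate × third factor × eligibility trace, with clipping and normalization for stability). The exact signature, parameter names, and clipping scheme are assumptions, not the committed implementation:

    import torch

    def apply_three_factor_update(
        weight: torch.Tensor,            # e.g. nn.Linear.weight, shape (hidden_size, input_size)
        eligibility_trace: torch.Tensor, # shape (input_size, hidden_size), as kept by CustomLIFCell
        third_factor: float,             # scalar neuromodulatory signal, e.g. return-to-go
        local_lr: float = 1e-3,          # local learning rate
        clip_value: float = 1.0,         # stability: element-wise clipping bound
        max_norm: float = 1.0,           # stability: renormalize oversized updates
    ) -> None:
        # Three-factor rule: delta_W = local_lr * third_factor * trace.
        # The trace is transposed to match nn.Linear's (out, in) weight layout.
        delta = local_lr * third_factor * eligibility_trace.t()
        delta = delta.clamp(-clip_value, clip_value)  # element-wise clipping
        norm = torch.linalg.norm(delta)
        if norm > max_norm:                           # norm-based normalization
            delta = delta * (max_norm / norm)
        with torch.no_grad():                         # bypass autograd for the local update
            weight.add_(delta)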

5 files changed

Lines changed: 932 additions & 1 deletion
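Likewise, the `src/models/snn_dt.py` changes are not expanded below; only `custom_lif.py` is. A minimal sketch of a hook-based capture around the `predict_action` layer, gated by `enable_action_head_plasticity`, both of which the commit message names. The helper name `attach_action_head_hooks` and the `plasticity_pre`/`plasticity_post` attributes are hypothetical:

    import torch
    import torch.nn as nn

    def attach_action_head_hooks(model: nn.Module) -> None:
        """Capture the tensors flowing through the action head for the local rule."""
        if not getattr(model, "enable_action_head_plasticity", False):
            return  # plasticity disabled: behave exactly like the stock model

        def hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
            # Detach so the captured tensors are pure bookkeeping for the
            # three-factor update, not extra paths in the autograd graph.
            model.plasticity_pre = inputs[0].detach()  # pre-synaptic input
            model.plasticity_post = output.detach()    # post-synaptic output (logits)

        # Assumes the model exposes its action head as `model.predict_action`.
        model.predict_action.register_forward_hook(hook)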

File tree

novel_phases/phase-2/custom_lif.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
import torch
import torch.nn as nn
from norse.torch import LIFCell
from norse.torch.functional.lif import LIFParameters
from typing import Optional, Tuple, NamedTuple


class CustomLIFCellState(NamedTuple):
    v: torch.Tensor
    i: torch.Tensor
    eligibility_trace: torch.Tensor

class CustomLIFCell(LIFCell):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        p: LIFParameters = LIFParameters(),
        dt: float = 0.001,
        trace_decay: float = 0.95,
        name: Optional[str] = None,  # Added for compatibility with Norse's base cell
        **kwargs
    ):
        # Pop tensor-factory kwargs: they are only needed for the trace buffer
        # below and should not be forwarded to the parent cell.
        device = kwargs.pop("device", None)
        dtype = kwargs.pop("dtype", None)
        super().__init__(p=p, dt=dt, name=name, **kwargs)  # Pass name and remaining kwargs to parent
        self.input_size = input_size
        self.hidden_size = hidden_size  # Corresponds to output_size for a single layer
        self.trace_decay = trace_decay

        # Initialize the eligibility trace. It is associated with the weights
        # of a feed-forward connection: this cell's input is the pre-synaptic
        # activity and its output the post-synaptic activity, so the trace has
        # shape (input_size, hidden_size). A recurrent or more complex layer
        # would need a different shape.
        self.register_buffer(
            "eligibility_trace",
            torch.zeros(input_size, hidden_size, device=device, dtype=dtype)
        )

    def get_initial_state(self, batch_size: int, inputs: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        # The eligibility trace is a module buffer that accumulates across
        # calls; unlike v and i, it is not part of the recurrent state passed
        # from step to step, so the parent's state is returned unchanged.
        # Per-sequence traces would require extending this state (e.g. with
        # CustomLIFCellState above).
        s_prev = super().get_initial_state(batch_size, inputs)
        return s_prev  # (v, i)

    def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Performs a forward pass through the CustomLIFCell.

        Args:
            x (torch.Tensor): Input tensor (typically spikes) of shape (batch_size, input_size).
            state (Optional[Tuple[torch.Tensor, torch.Tensor]]): Previous state (v, i).
                If None, it is initialized.

        Returns:
            Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
                - Output spikes (s_out) of shape (batch_size, hidden_size).
                - Next state (v_next, i_next).
        """
        if state is None:
            # Initialize (v, i) from the parent cell; fine for non-recurrent
            # use. For recurrent use, the state should be managed explicitly.
            # The eligibility trace is handled separately, as it persists
            # across calls rather than per step.
            state = super().get_initial_state(batch_size=x.shape[0], inputs=x)

        # Perform the standard LIF cell computation.
        s_out, next_state = super().forward(x, state)  # next_state is (v_next, i_next)

        # Update the eligibility trace from pre-synaptic spikes x
        # (batch_size, input_size) and post-synaptic spikes s_out
        # (batch_size, hidden_size). Both are detached so this local
        # bookkeeping does not influence the autograd graph.
        pre_spikes = x.detach()
        post_spikes = s_out.detach()

        # Average the per-sample outer products over the batch:
        # (input_size, batch_size) @ (batch_size, hidden_size) -> (input_size, hidden_size).
        batch_trace_update = torch.matmul(pre_spikes.t(), post_spikes) / x.shape[0]

        # Decay the existing trace and accumulate the new co-activity in place.
        self.eligibility_trace.mul_(self.trace_decay).add_(batch_trace_update)

        # Return the same (spikes, state) pair as the parent LIFCell; the
        # eligibility trace lives in the module, not in the state tuple.
        return s_out, next_state

    def reset_trace(self):
        """Resets the eligibility trace to zeros."""
        self.eligibility_trace.zero_()

# Example Usage (Illustrative)
if __name__ == '__main__':
    batch_size = 10
    input_features = 20
    output_features = 5  # hidden_size for the cell

    # Create a CustomLIFCell
    custom_lif_cell = CustomLIFCell(input_features, output_features)

    # Dummy binary input spikes; in practice these would come from a previous
    # layer or a Poisson encoder.
    input_spikes = (torch.rand(batch_size, input_features) > 0.8).float()

    # Get the initial (v, i) state from the cell itself; Norse also manages
    # the state internally if none is provided.
    initial_state = custom_lif_cell.get_initial_state(batch_size=batch_size, inputs=input_spikes)

    # Simulate a few time steps
    print(f"Initial eligibility trace:\n{custom_lif_cell.eligibility_trace}")

    # First step
    print("\n--- Step 1 ---")
    s_out, next_state = custom_lif_cell(input_spikes, initial_state)
    print(f"Output spikes (shape: {s_out.shape}):\n{s_out}")
    print(f"Updated eligibility trace (shape: {custom_lif_cell.eligibility_trace.shape}):\n{custom_lif_cell.eligibility_trace}")

    # Second step (using the state from the previous step)
    print("\n--- Step 2 ---")
    input_spikes_2 = (torch.rand(batch_size, input_features) > 0.7).float()
    s_out_2, next_state_2 = custom_lif_cell(input_spikes_2, next_state)
    print(f"Output spikes 2 (shape: {s_out_2.shape}):\n{s_out_2}")
    print(f"Updated eligibility trace:\n{custom_lif_cell.eligibility_trace}")

    # Reset trace
    custom_lif_cell.reset_trace()
    print(f"\nAfter reset, eligibility trace:\n{custom_lif_cell.eligibility_trace}")

    # Test on a different device if available
    if torch.cuda.is_available():
        print("\n--- CUDA Test ---")
        device = torch.device("cuda")
        custom_lif_cell_cuda = CustomLIFCell(input_features, output_features, device=device, dtype=torch.float32)
        input_spikes_cuda = input_spikes.to(device)
        initial_state_cuda = custom_lif_cell_cuda.get_initial_state(batch_size=batch_size, inputs=input_spikes_cuda)

        s_out_cuda, _ = custom_lif_cell_cuda(input_spikes_cuda, initial_state_cuda)
        print(f"CUDA Output spikes (shape: {s_out_cuda.shape}) on device: {s_out_cuda.device}")
        print(f"CUDA Eligibility trace (shape: {custom_lif_cell_cuda.eligibility_trace.shape}) on device: {custom_lif_cell_cuda.eligibility_trace.device}:\n{custom_lif_cell_cuda.eligibility_trace}")

    print("\nNote: The eligibility trace accumulates. It is typically consumed by a learning rule that applies it (and potentially resets it) after a learning episode/batch.")
    print("The CustomLIFCell does not return the eligibility trace in its forward pass's state tuple, as it is a module buffer.")
    print("If this cell is used in an nn.Sequential or Norse's SequentialState, the state passed around is (v, i).")
    print("The eligibility trace must be accessed directly from the module instance.")
