Skip to content

Commit 52ee253

Browse files
Update snn_dt.py
1 parent f12c9b5 commit 52ee253

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

snn-dt/src/models/snn_dt.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def forward(self, y_heads: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
213213
# -------------------------
214214
# Three-factor plasticity (skeleton)
215215
# -------------------------
216-
class ThreeFactorPlasticity:
216+
class ThreeFactorPlasticity(nn.Module):
217217
"""
218218
Minimal three-factor plasticity class:
219219
- maintains eligibility trace E (same shape as a weight matrix)
@@ -224,10 +224,11 @@ class ThreeFactorPlasticity:
224224
"""
225225

226226
def __init__(self, weight_shape: Tuple[int, int], eta: float = 1e-3, lambda_decay: float = 0.99, device: Optional[torch.device] = None):
227-
self.device = device
227+
super().__init__()
228228
self.eta = float(eta)
229229
self.lambda_decay = float(lambda_decay)
230-
self.E = torch.zeros(weight_shape, device=device)
230+
# Register E as a buffer so it is moved to device along with the model
231+
self.register_buffer("E", torch.zeros(weight_shape, device=device))
231232

232233
def update_trace(self, pre: torch.Tensor, post: torch.Tensor):
233234
"""
@@ -261,6 +262,9 @@ def apply(self, weight_param: nn.Parameter, reward: float):
261262
# shapes mismatch -> no-op (user must ensure shapes align)
262263
pass
263264

265+
def reset(self):
266+
self.E.zero_()
267+
264268

265269
# -------------------------
266270
# Spiking Transformer Block
@@ -334,6 +338,8 @@ def reset_state(self):
334338
self.lif_q.reset_state()
335339
self.lif_k.reset_state()
336340
self.lif_v.reset_state()
341+
if self.plasticity_rule is not None and hasattr(self.plasticity_rule, 'reset'):
342+
self.plasticity_rule.reset()
337343

338344
def forward(self, x: torch.Tensor, phase_mod: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
339345
if self.training:

0 commit comments

Comments
 (0)