@@ -213,7 +213,7 @@ def forward(self, y_heads: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
213213# -------------------------
214214# Three-factor plasticity (skeleton)
215215# -------------------------
216- class ThreeFactorPlasticity :
216+ class ThreeFactorPlasticity ( nn . Module ) :
217217 """
218218 Minimal three-factor plasticity class:
219219 - maintains eligibility trace E (same shape as a weight matrix)
@@ -224,10 +224,11 @@ class ThreeFactorPlasticity:
224224 """
225225
def __init__(
    self,
    weight_shape: Tuple[int, int],
    eta: float = 1e-3,
    lambda_decay: float = 0.99,
    device: Optional[torch.device] = None,
) -> None:
    """Set up the plasticity hyper-parameters and a zeroed eligibility trace.

    Args:
        weight_shape: Shape of the eligibility trace ``E``; it is allocated
            as an all-zeros tensor of exactly this shape.
        eta: Stored as ``self.eta`` after coercion to ``float``
            (presumably the plasticity learning rate -- confirm against
            ``apply``).
        lambda_decay: Stored as ``self.lambda_decay`` after coercion to
            ``float`` (presumably the per-step trace decay factor --
            confirm against ``update_trace``).
        device: Device on which the initial trace tensor is created;
            ``None`` defers to the current PyTorch default device.
    """
    # Must run before any buffer/parameter registration on an nn.Module.
    super().__init__()
    self.eta = float(eta)
    self.lambda_decay = float(lambda_decay)
    # Register E as a buffer (not a plain attribute) so that it follows the
    # module through .to()/.cuda()/.cpu() and is saved in its state_dict.
    self.register_buffer("E", torch.zeros(weight_shape, device=device))
231232
232233 def update_trace (self , pre : torch .Tensor , post : torch .Tensor ):
233234 """
@@ -261,6 +262,9 @@ def apply(self, weight_param: nn.Parameter, reward: float):
261262 # shapes mismatch -> no-op (user must ensure shapes align)
262263 pass
263264
def reset(self) -> None:
    """Clear the eligibility trace ``E`` by zeroing it in place.

    The existing tensor object is kept (same shape, dtype and device);
    only its contents are overwritten with zeros.
    """
    self.E.zero_()
267+
264268
265269# -------------------------
266270# Spiking Transformer Block
@@ -334,6 +338,8 @@ def reset_state(self):
334338 self .lif_q .reset_state ()
335339 self .lif_k .reset_state ()
336340 self .lif_v .reset_state ()
341+ if self .plasticity_rule is not None and hasattr (self .plasticity_rule , 'reset' ):
342+ self .plasticity_rule .reset ()
337343
338344 def forward (self , x : torch .Tensor , phase_mod : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
339345 if self .training :
0 commit comments