 import torch
 import torch.nn as nn
-from norse.torch.module.lif import LIFCell
+from norse.torch.module.lif import LIF, LIFCell, LIFParameters
 
 from src.models.base import BasePolicy
 
@@ -16,8 +16,14 @@ def __init__(self, d_model, n_heads, lif_tau, surrogate_k):
         self.k_proj = nn.Linear(d_model, d_model)
         self.v_proj = nn.Linear(d_model, d_model)
 
-        self.q_lif = LIFCell()
-        self.k_lif = LIFCell()
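+        # Shared neuron parameters: inverse membrane time constant 1/lif_tau, spike threshold 0.8,
+        # and the SuperSpike ("super") surrogate gradient with steepness alpha for backprop through spikes.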
+        p = LIFParameters(
+            tau_mem_inv=torch.tensor(1.0 / lif_tau),
+            v_th=torch.tensor(0.8),
+            method="super",
+            alpha=surrogate_k,
+        )
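+        # LIF wraps LIFCell and unrolls it over the leading time dimension,
+        # replacing the manual per-step loop previously used in forward().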
+        self.q_lif = LIF(p=p)
+        self.k_lif = LIF(p=p)
 
         self.spike_count = 0
 
@@ -28,15 +34,8 @@ def forward(self, x, state_q, state_k, attn_mask=None):
         k = self.k_proj(x)
         v = self.v_proj(x)
 
-        spikes_q_seq = []
-        spikes_k_seq = []
-        for t in range(seq_len):
-            spikes_q, state_q = self.q_lif(q[:, t], state_q)
-            spikes_k, state_k = self.k_lif(k[:, t], state_k)
-            spikes_q_seq.append(spikes_q)
-            spikes_k_seq.append(spikes_k)
-        spikes_q = torch.stack(spikes_q_seq, dim=1)
-        spikes_k = torch.stack(spikes_k_seq, dim=1)
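+        # norse's LIF consumes time-major input (T, batch, ...), so treat the sequence
+        # axis as time and transpose back afterwards.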
+        spikes_q, _ = self.q_lif(q.transpose(0, 1))
+        spikes_k, _ = self.k_lif(k.transpose(0, 1))
+        spikes_q, spikes_k = spikes_q.transpose(0, 1), spikes_k.transpose(0, 1)
 
         q_reshaped = spikes_q.view(batch_size, seq_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
         k_reshaped = spikes_k.view(batch_size, seq_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
@@ -49,8 +48,10 @@ def forward(self, x, state_q, state_k, attn_mask=None):
 
         attn_output = (attn_weights @ v_reshaped).permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)
 
-        self.spike_count = spikes_q.sum() + spikes_k.sum()
-        return attn_output, state_q, state_k
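+        # Accumulate spike counts across forward passes; cleared via DsFormer.reset_spike_counts().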
+        if not hasattr(self, "spike_count"):
+            self.spike_count = 0.0
+        self.spike_count += (spikes_q.sum() + spikes_k.sum()).item()
+        return attn_output
 
 
 class DsFormer(BasePolicy, nn.Module):
@@ -95,6 +96,18 @@ def forward(self, batch):
             action_input = actions
 
         action_embeddings = self.embed_action(action_input)
+
+        # Pad action embeddings to match state/return sequence length
+        if action_embeddings.shape[1] < seq_len:
+            padding_size = seq_len - action_embeddings.shape[1]
+            padding = torch.zeros(
+                action_embeddings.shape[0],
+                padding_size,
+                action_embeddings.shape[2],
+                device=action_embeddings.device,
+            )
+            action_embeddings = torch.cat([action_embeddings, padding], dim=1)
+
         return_embeddings = self.embed_return(batch["returns_to_go"])
         time_embeddings = self.embed_timestep(batch["timesteps"])
 
@@ -110,10 +123,8 @@ def forward(self, batch):
         x = self.embed_ln(stacked_inputs)
 
         attn_mask = nn.Transformer.generate_square_subsequent_mask(x.shape[1], device=x.device)
-        q_states = [None] * len(self.blocks)
-        k_states = [None] * len(self.blocks)
         for i, block in enumerate(self.blocks):
-            x, q_states[i], k_states[i] = block(x, q_states[i], k_states[i], attn_mask=attn_mask)
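+            # Per-layer LIF state is no longer threaded through; each block's LIF layers unroll internally.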
+            x = block(x, None, None, attn_mask=attn_mask)
 
         action_preds = self.action_predictor(x[:, 1::3])
         return action_preds
@@ -144,7 +155,12 @@ def load(self, path):
         self.load_state_dict(torch.load(path))
 
     def count_spikes(self):
-        return sum(b.spike_count for b in self.blocks)
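+        # Report the average accumulated spike count per block since the last reset.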
+        total_spikes = sum(block.spike_count for block in self.blocks)
+        return total_spikes / len(self.blocks) if len(self.blocks) > 0 else 0.0
+
+    def reset_spike_counts(self):
+        for block in self.blocks:
+            block.spike_count = 0.0
 
     def num_params(self):
         return sum(p.numel() for p in self.parameters() if p.requires_grad)