File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -538,7 +538,7 @@ def forward(self, batch: dict) -> torch.Tensor:
538538
539539 # reset block states and diagnostics
540540 for b in self .blocks :
541- if hasattr (b , "lif " ):
541+ if hasattr (b , "spike_count " ):
542542 # zero state handled internally if none provided; clear spike_count for each forward pass
543543 b .spike_count = 0.0
544544 b .last_alpha = None
@@ -557,9 +557,9 @@ def forward(self, batch: dict) -> torch.Tensor:
557557
558558 # diagnostics aggregation
559559 self .total_spike_count += block .spike_count
560- # total opportunities for normalization: B * seq * d_model * 3 (Q,K,V) * 1 (per block)
560+ # total opportunities for normalization: B * seq * d_model * 3 (Q,K,V) * 1 (per block) * T (time)
561561 # We accumulate across blocks for global normalization
562- self .total_spike_opportunities += float (B * x .shape [1 ] * (3 * self .d_model ))
562+ self .total_spike_opportunities += float (B * x .shape [1 ] * (3 * self .d_model ) * self . T )
563563 self .last_diagnostics [f"block_{ i } _spike_rate_q" ] = block .diagnostics .get ("spike_rate_q" , 0.0 )
564564 self .last_diagnostics [f"block_{ i } _spike_rate_v" ] = block .diagnostics .get ("spike_rate_v" , 0.0 )
565565 self .last_diagnostics [f"block_{ i } _q_min" ] = block .diagnostics .get ("q_min" , 0.0 )
You can’t perform that action at this time.
0 commit comments