Skip to content

Commit fd22b51

Browse files
Update snn_dt.py
1 parent 2f12974 commit fd22b51

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

snn-dt/src/models/snn_dt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def forward(self, batch: dict) -> torch.Tensor:
538538

539539
# reset block states and diagnostics
540540
for b in self.blocks:
541-
if hasattr(b, "lif"):
541+
if hasattr(b, "spike_count"):
542542
# zero state handled internally if none provided; clear spike_count for each forward pass
543543
b.spike_count = 0.0
544544
b.last_alpha = None
@@ -557,9 +557,9 @@ def forward(self, batch: dict) -> torch.Tensor:
557557

558558
# diagnostics aggregation
559559
self.total_spike_count += block.spike_count
560-
# total opportunities for normalization: B * seq * d_model * 3 (Q,K,V) * 1 (per block)
560+
# total opportunities for normalization: B * seq * d_model * 3 (Q,K,V) * 1 (per block) * T (time)
561561
# We accumulate across blocks for global normalization
562-
self.total_spike_opportunities += float(B * x.shape[1] * (3 * self.d_model))
562+
self.total_spike_opportunities += float(B * x.shape[1] * (3 * self.d_model) * self.T)
563563
self.last_diagnostics[f"block_{i}_spike_rate_q"] = block.diagnostics.get("spike_rate_q", 0.0)
564564
self.last_diagnostics[f"block_{i}_spike_rate_v"] = block.diagnostics.get("spike_rate_v", 0.0)
565565
self.last_diagnostics[f"block_{i}_q_min"] = block.diagnostics.get("q_min", 0.0)

0 commit comments

Comments
 (0)