Skip to content

Commit e5a6d76

Browse files
Fixed: DSFormer Baseline
1 parent 834b62b commit e5a6d76

8 files changed

Lines changed: 250 additions & 49 deletions

File tree

debug_snn_dt.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import torch
2+
import sys
3+
import os
4+
5+
# Add snn-dt to the python path to allow imports from src
6+
sys.path.insert(0, os.path.abspath("snn-dt"))
7+
8+
from src.models.snn_dt import SnnDt
9+
10+
# A simplified MockConfig, similar to the one in the tests
class MockConfig:
    """Minimal stand-in for the project's config object.

    Exposes the nested ``model``, ``dataset``, ``snn`` and ``training``
    namespaces (plus ``env``) that the SnnDt constructor reads.
    """

    def __init__(self):
        self.model = self.Model()
        self.dataset = self.Dataset()
        self.snn = self.Snn()
        # Fix: Training was declared below but never instantiated, so
        # cfg.training did not exist and any consumer reading
        # cfg.training.device would raise AttributeError.
        self.training = self.Training()
        self.env = "dummy_env"

    class Model:
        # Transformer backbone hyper-parameters.
        name = "snn_dt"
        d_model = 128
        n_heads = 4
        n_layers = 2

    class Dataset:
        # Offline-RL dataset shapes (CartPole-like: 4-dim state, 1-dim action).
        state_dim = 4
        act_dim = 1
        max_timesteps = 100
        is_discrete = False

    class Snn:
        # LIF membrane time constant and surrogate-gradient steepness;
        # plasticity is disabled for this debug script.
        lif_tau = 20.0
        surrogate_k = 25.0
        use_plasticity = False

    class Training:
        # Presumably consumed as cfg.training.device by the model/trainer —
        # TODO(review): confirm against the MockConfig used in the tests.
        device = "cpu"
def debug_snn_dt():
    """Instantiate SnnDt, run two forward passes, and verify that spike
    counting accumulates across passes and can be reset."""
    print("--- Initializing SNN-DT Debug Script ---")

    # Build the model from the mock config and switch to inference mode,
    # which disables training-only behaviour such as plasticity updates.
    model = SnnDt(MockConfig())
    model.eval()
    print("SnnDt model instantiated successfully.")

    # Large-magnitude random inputs push membrane potentials over threshold,
    # making the LIF layers likely to emit spikes.
    n_batch, n_steps = 16, 20
    batch = {
        "states": torch.randn(n_batch, n_steps, 4) * 100,
        "actions": torch.randn(n_batch, n_steps, 1) * 100,
        "returns_to_go": torch.randn(n_batch, n_steps, 1) * 100,
        "timesteps": torch.randint(0, 100, (n_batch, n_steps)),
        "mask": torch.ones(n_batch, n_steps),
    }

    # First pass: spike counter should become non-trivial.
    print("Batch created. Running a single forward pass...")
    with torch.no_grad():
        model(batch)
    print(f"Spike count after first pass: {model.count_spikes()}")

    # Second pass: counter should grow, proving accumulation across calls.
    print("Running a second forward pass to check accumulation...")
    with torch.no_grad():
        model(batch)
    print(f"Spike count after second pass: {model.count_spikes()}")

    # Reset: counter should return to its initial value.
    print("Resetting spike counts...")
    model.reset_spike_counts()
    print(f"Spike count after reset: {model.count_spikes()}")

    print("--- Debug Script Finished ---")


if __name__ == "__main__":
    debug_snn_dt()
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
2025-11-09 10:11:25,647 [INFO] Checking for dataset...
2+
2025-11-09 10:11:25,655 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
3+
2025-11-09 10:11:25,656 [INFO] Starting training...
4+
2025-11-09 10:11:25,818 [INFO] Dataset size: 1000 clips
5+
2025-11-09 10:11:25,874 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
6+
2025-11-09 10:11:39,234 [INFO] Starting training loop...
7+
2025-11-09 10:26:07,900 [INFO] Checking for dataset...
8+
2025-11-09 10:26:07,902 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
9+
2025-11-09 10:26:07,903 [INFO] Starting training...
10+
2025-11-09 10:26:08,336 [INFO] Dataset size: 1000 clips
11+
2025-11-09 10:26:08,369 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
12+
2025-11-09 10:26:22,564 [INFO] Starting training loop...
13+
2025-11-09 14:47:13,401 [INFO] Checking for dataset...
14+
2025-11-09 14:47:13,403 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
15+
2025-11-09 14:47:13,404 [INFO] Starting training...
16+
2025-11-09 14:47:13,575 [INFO] Dataset size: 1000 clips
17+
2025-11-09 14:47:13,635 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
18+
2025-11-09 14:47:29,057 [INFO] Starting training loop...
19+
2025-11-09 14:50:54,988 [INFO] Checking for dataset...
20+
2025-11-09 14:50:55,000 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
21+
2025-11-09 14:50:55,001 [INFO] Starting training...
22+
2025-11-09 14:50:55,110 [INFO] Dataset size: 1000 clips
23+
2025-11-09 14:50:55,139 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
24+
2025-11-09 14:51:05,312 [INFO] Starting training loop...
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2025-11-09 10:06:36,149 [INFO] Checking for dataset...
2+
2025-11-09 10:06:36,159 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
3+
2025-11-09 10:06:36,164 [INFO] Starting training...
4+
2025-11-09 10:06:36,398 [INFO] Dataset size: 1000 clips
5+
2025-11-09 10:06:36,472 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
6+
2025-11-09 10:06:56,621 [INFO] Starting training loop...

run_tests.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import sys
2+
import os
3+
import pytest
4+
5+
if __name__ == "__main__":
6+
print("Current working directory:", os.getcwd())
7+
sys.path.insert(0, os.path.abspath("snn-dt"))
8+
print("sys.path:", sys.path)
9+
sys.exit(pytest.main(["-x", "snn-dt/tests/test_models.py"]))

snn-dt/scripts/train.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import warnings
2424
warnings.filterwarnings('ignore')
2525

26-
# from src.models.cql import CQL
27-
# from src.models.dt import DecisionTransformer
28-
# from src.models.dsformer import DsFormer
29-
# from src.models.iql import IQL
26+
from src.models.cql import CQL
27+
from src.models.dt import DecisionTransformer
28+
from src.models.dsformer import DsFormer
29+
from src.models.iql import IQL
3030
from src.models.snn_dt import SnnDt
3131
from src.utils.config import AttrDict
3232
from src.utils.models import get_model
@@ -195,6 +195,9 @@ def train(cfg, logger):
195195
for epoch in range(cfg.training.epochs):
196196
start_time = time.time()
197197
epoch_losses = []
198+
199+
if hasattr(model, "reset_spike_counts"):
200+
model.reset_spike_counts()
198201

199202
train_iter = iter(train_loader)
200203
pbar = tqdm(range(cfg.training.batches_per_epoch), desc=f"Epoch {epoch+1}/{cfg.training.epochs}")
@@ -250,10 +253,12 @@ def train(cfg, logger):
250253
log_str = f"Epoch {epoch+1}/{cfg.training.epochs} | Time: {epoch_time:.2f}s | Loss: {avg_loss:.4f}"
251254

252255
# Spike counting for SNN models
253-
if isinstance(model, SnnDt):
256+
if hasattr(model, "count_spikes"):
254257
spikes = model.count_spikes()
255-
log_str += f" | Spikes: {spikes}"
258+
log_str += f" | Spikes: {spikes:.2f}"
256259
eval_results["spikes"] = spikes
260+
else:
261+
eval_results["spikes"] = 0.0
257262

258263
metrics.append({"epoch": epoch + 1, "loss": avg_loss, **eval_results, "time_s": epoch_time})
259264
log_str += f" | Eval Return: {eval_results['return_mean']:.2f}"
@@ -376,6 +381,11 @@ def main():
376381

377382
# Convert to AttrDict for easy access
378383
cfg = AttrDict(cfg)
384+
385+
# Adaptive training controls for SNNs
386+
if "snn" in cfg.model.name or "dsformer" in cfg.model.name:
387+
cfg.training.batches_per_epoch = min(cfg.training.batches_per_epoch, cfg_raw.get("snn_batches_per_epoch", 100))
388+
cfg.training.eval_every = max(cfg.training.eval_every, cfg_raw.get("snn_eval_every", 50))
379389

380390
# Construct dataset path from env name, relative to project root
381391
cfg.dataset.path = str(project_root / f"data/{args.env}/dataset.npz")

snn-dt/src/models/dsformer.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import torch.nn as nn
3-
from norse.torch.module.lif import LIFCell
3+
from norse.torch.module.lif import LIF, LIFCell, LIFParameters
44

55
from src.models.base import BasePolicy
66

@@ -16,8 +16,14 @@ def __init__(self, d_model, n_heads, lif_tau, surrogate_k):
1616
self.k_proj = nn.Linear(d_model, d_model)
1717
self.v_proj = nn.Linear(d_model, d_model)
1818

19-
self.q_lif = LIFCell()
20-
self.k_lif = LIFCell()
19+
p = LIFParameters(
20+
tau_mem_inv=torch.tensor(1.0 / lif_tau),
21+
v_th=torch.tensor(0.8),
22+
method="super",
23+
alpha=surrogate_k,
24+
)
25+
self.q_lif = LIF(p=p)
26+
self.k_lif = LIF(p=p)
2127

2228
self.spike_count = 0
2329

@@ -28,15 +34,8 @@ def forward(self, x, state_q, state_k, attn_mask=None):
2834
k = self.k_proj(x)
2935
v = self.v_proj(x)
3036

31-
spikes_q_seq = []
32-
spikes_k_seq = []
33-
for t in range(seq_len):
34-
spikes_q, state_q = self.q_lif(q[:, t], state_q)
35-
spikes_k, state_k = self.k_lif(k[:, t], state_k)
36-
spikes_q_seq.append(spikes_q)
37-
spikes_k_seq.append(spikes_k)
38-
spikes_q = torch.stack(spikes_q_seq, dim=1)
39-
spikes_k = torch.stack(spikes_k_seq, dim=1)
37+
spikes_q, _ = self.q_lif(q)
38+
spikes_k, _ = self.k_lif(k)
4039

4140
q_reshaped = spikes_q.view(batch_size, seq_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
4241
k_reshaped = spikes_k.view(batch_size, seq_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
@@ -49,8 +48,10 @@ def forward(self, x, state_q, state_k, attn_mask=None):
4948

5049
attn_output = (attn_weights @ v_reshaped).permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)
5150

52-
self.spike_count = spikes_q.sum() + spikes_k.sum()
53-
return attn_output, state_q, state_k
51+
if not hasattr(self, "spike_count"):
52+
self.spike_count = 0.0
53+
self.spike_count += (spikes_q.sum() + spikes_k.sum()).item()
54+
return attn_output
5455

5556

5657
class DsFormer(BasePolicy, nn.Module):
@@ -95,6 +96,18 @@ def forward(self, batch):
9596
action_input = actions
9697

9798
action_embeddings = self.embed_action(action_input)
99+
100+
# Pad action embeddings to match state/return sequence length
101+
if action_embeddings.shape[1] < seq_len:
102+
padding_size = seq_len - action_embeddings.shape[1]
103+
padding = torch.zeros(
104+
action_embeddings.shape[0],
105+
padding_size,
106+
action_embeddings.shape[2],
107+
device=action_embeddings.device,
108+
)
109+
action_embeddings = torch.cat([action_embeddings, padding], dim=1)
110+
98111
return_embeddings = self.embed_return(batch["returns_to_go"])
99112
time_embeddings = self.embed_timestep(batch["timesteps"])
100113

@@ -110,10 +123,8 @@ def forward(self, batch):
110123
x = self.embed_ln(stacked_inputs)
111124

112125
attn_mask = nn.Transformer.generate_square_subsequent_mask(x.shape[1], device=x.device)
113-
q_states = [None] * len(self.blocks)
114-
k_states = [None] * len(self.blocks)
115126
for i, block in enumerate(self.blocks):
116-
x, q_states[i], k_states[i] = block(x, q_states[i], k_states[i], attn_mask=attn_mask)
127+
x = block(x, None, None, attn_mask=attn_mask)
117128

118129
action_preds = self.action_predictor(x[:, 1::3])
119130
return action_preds
@@ -144,7 +155,12 @@ def load(self, path):
144155
self.load_state_dict(torch.load(path))
145156

146157
def count_spikes(self):
147-
return sum(b.spike_count for b in self.blocks)
158+
total_spikes = sum(block.spike_count for block in self.blocks)
159+
return total_spikes / len(self.blocks) if len(self.blocks) > 0 else 0.0
160+
161+
def reset_spike_counts(self):
162+
for block in self.blocks:
163+
block.spike_count = 0.0
148164

149165
def num_params(self):
150166
return sum(p.numel() for p in self.parameters() if p.requires_grad)

snn-dt/src/models/snn_dt.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn as nn
33
from norse.torch.module.leaky_integrator import LICell
4-
from norse.torch.module.lif import LIFCell, LIFParameters
4+
from norse.torch.module.lif import LIF, LIFCell, LIFParameters
55

66
from src.models.base import BasePolicy
77

@@ -27,8 +27,8 @@ def __init__(self, cfg, d_model, n_heads, lif_tau, surrogate_k):
2727
alpha=surrogate_k,
2828
)
2929

30-
self.q_lif = LIFCell(p=p)
31-
self.k_lif = LIFCell(p=p)
30+
self.q_lif = LIF(p=p)
31+
self.k_lif = LIF(p=p)
3232
self.v_li = LICell()
3333

3434
self.use_plasticity = False # Will be set by SnnDt
@@ -52,16 +52,8 @@ def forward(self, x, state_q, state_k, attn_mask=None):
5252
k = self.k_proj(x)
5353
v = self.v_proj(x)
5454

55-
# Spiking Q and K
56-
spikes_q_seq = []
57-
spikes_k_seq = []
58-
for t in range(seq_len):
59-
spikes_q, state_q = self.q_lif(q[:, t], state_q)
60-
spikes_k, state_k = self.k_lif(k[:, t], state_k)
61-
spikes_q_seq.append(spikes_q)
62-
spikes_k_seq.append(spikes_k)
63-
spikes_q = torch.stack(spikes_q_seq, dim=1)
64-
spikes_k = torch.stack(spikes_k_seq, dim=1)
55+
spikes_q, _ = self.q_lif(q)
56+
spikes_k, _ = self.k_lif(k)
6557

6658
# Attention
6759
q_reshaped = spikes_q.view(batch_size, seq_len, self.n_heads, self.head_dim)
@@ -79,13 +71,15 @@ def forward(self, x, state_q, state_k, attn_mask=None):
7971
routing_gate = self.routing_mlp(attn_output)
8072
out = attn_output * routing_gate
8173

82-
self.spike_count = spikes_q.sum() + spikes_k.sum()
74+
if not hasattr(self, "spike_count"):
75+
self.spike_count = 0.0
76+
self.spike_count += (spikes_q.sum() + spikes_k.sum()).item()
8377

8478
# Three-factor plasticity
8579
if self.training and self.use_plasticity:
8680
self.update_eligibility_trace(spikes_q, v)
8781

88-
return out, state_q, state_k
82+
return out
8983

9084
def update_eligibility_trace(self, presynaptic_spikes, postsynaptic_potential):
9185
# Simplified eligibility trace update
@@ -193,10 +187,8 @@ def forward(self, batch):
193187

194188
# Spiking transformer blocks
195189
attn_mask = nn.Transformer.generate_square_subsequent_mask(x.shape[1], device=x.device)
196-
q_states = [None] * len(self.blocks)
197-
k_states = [None] * len(self.blocks)
198190
for i, block in enumerate(self.blocks):
199-
x, q_states[i], k_states[i] = block(x, q_states[i], k_states[i], attn_mask=attn_mask)
191+
x = block(x, None, None, attn_mask=attn_mask)
200192

201193
action_preds = self.action_predictor(x[:, 1::3])
202194
return action_preds
@@ -227,7 +219,12 @@ def load(self, path):
227219
self.load_state_dict(torch.load(path))
228220

229221
def count_spikes(self):
230-
return sum(b.spike_count for b in self.blocks)
222+
total_spikes = sum(block.spike_count for block in self.blocks)
223+
return total_spikes / len(self.blocks) if len(self.blocks) > 0 else 0.0
224+
225+
def reset_spike_counts(self):
226+
for block in self.blocks:
227+
block.spike_count = 0.0
231228

232229
def num_params(self):
233230
return sum(p.numel() for p in self.parameters() if p.requires_grad)

0 commit comments

Comments
 (0)