Skip to content

Commit 46b33c1

Browse files
dt-baseline with acrobot fixing
1 parent f3d02a1 commit 46b33c1

6 files changed

Lines changed: 1164 additions & 4 deletions

File tree

baseline_comparisons_results/01_dt/dt_cartpole.log

Lines changed: 1121 additions & 0 deletions
Large diffs are not rendered by default.

data/Acrobot-v1/dataset.npz

12.7 MB
Binary file not shown.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2025-11-05 17:12:14,765 [INFO] Checking for dataset...
2+
2025-11-05 17:12:14,827 [INFO] Dataset not found at D:\Github\neuromorphic_decision_transformer\data\Acrobot-v1\dataset.npz. Generating new dataset...
3+
2025-11-05 17:15:33,625 [INFO] Dataset generated and saved to D:\Github\neuromorphic_decision_transformer\data\Acrobot-v1\dataset.npz
4+
2025-11-05 17:15:33,625 [INFO] Starting training...
5+
2025-11-05 17:15:34,115 [INFO] Dataset size: 1000 clips
6+
2025-11-05 17:15:34,120 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
7+
2025-11-05 17:15:45,399 [INFO] Starting training loop...
8+
2025-11-05 17:18:55,700 [INFO] Checking for dataset...
9+
2025-11-05 17:18:55,702 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\Acrobot-v1\dataset.npz.
10+
2025-11-05 17:18:56,038 [INFO] Starting training...
11+
2025-11-05 17:18:56,638 [INFO] Dataset size: 1000 clips
12+
2025-11-05 17:18:56,646 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
13+
2025-11-05 17:19:09,156 [INFO] Starting training loop...
14+
2025-11-05 17:22:25,707 [INFO] Checking for dataset...
15+
2025-11-05 17:22:26,140 [INFO] Dataset not found at D:\Github\neuromorphic_decision_transformer\data\Acrobot-v1\dataset.npz. Generating new dataset...
16+
2025-11-05 17:26:25,223 [INFO] Dataset generated and saved to D:\Github\neuromorphic_decision_transformer\data\Acrobot-v1\dataset.npz
17+
2025-11-05 17:26:25,231 [INFO] Starting training...
18+
2025-11-05 17:26:26,038 [INFO] Dataset size: 1000 clips
19+
2025-11-05 17:26:26,063 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
20+
2025-11-05 17:26:44,925 [INFO] Starting training loop...
3.54 MB
Binary file not shown.

results/all_runs/dt_CartPole-v1/training.log

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,9 @@
9191
2025-11-05 14:33:37,577 [INFO] Dataset size: 1000 clips
9292
2025-11-05 14:33:37,582 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
9393
2025-11-05 14:33:48,431 [INFO] Starting training loop...
94+
2025-11-05 15:15:08,400 [INFO] Epoch 10/1000 | Time: 241.47s | Loss: 0.6863 | Eval Return: 9.50
95+
2025-11-05 15:15:09,288 [INFO] New best eval return: 9.50. Saved best model.
96+
2025-11-05 15:49:05,018 [INFO] Epoch 20/1000 | Time: 200.64s | Loss: 0.6754 | Eval Return: 9.60
97+
2025-11-05 15:49:05,635 [INFO] New best eval return: 9.60. Saved best model.
98+
2025-11-05 16:21:15,718 [INFO] Epoch 30/1000 | Time: 191.88s | Loss: 0.6314 | Eval Return: 9.30
99+
2025-11-05 16:53:07,656 [INFO] Epoch 40/1000 | Time: 192.28s | Loss: 0.6089 | Eval Return: 9.30

snn-dt/scripts/generate_dataset.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,35 @@ def generate_random_trajectories(env_name, num_trajectories=1000, max_steps=1000
4747

4848
def process_trajectories(trajectories, max_timesteps=1000):
4949
# Find the maximum trajectory length
50-
max_len = min(max([t['length'] for t in trajectories]), max_timesteps)
50+
max_len = min(max([t['length'] for t in trajectories if t['length'] > 0], default=0), max_timesteps)
5151

52+
if not trajectories or max_len == 0:
53+
# Handle case with no trajectories or all empty trajectories
54+
return {
55+
'states': np.array([]), 'actions': np.array([]), 'returns_to_go': np.array([]),
56+
'timesteps': np.array([]), 'mask': np.array([]),
57+
'metadata': {'state_dim': 0, 'act_dim': 0, 'max_timesteps': max_timesteps}
58+
}
59+
5260
# Initialize arrays
5361
num_trajectories = len(trajectories)
5462
state_dim = trajectories[0]['states'][0].shape[0]
55-
act_dim = 1 # For discrete actions
5663

64+
# Determine action dimension from data
65+
all_actions = np.concatenate([t['actions'] for t in trajectories if t['length'] > 0])
66+
act_dim = int(all_actions.max()) + 1 if all_actions.size > 0 else 1
67+
5768
states = np.zeros((num_trajectories, max_len + 1, state_dim))
58-
actions = np.zeros((num_trajectories, max_len, act_dim))
69+
actions = np.zeros((num_trajectories, max_len, 1)) # Store actions as scalars
5970
returns_to_go = np.zeros((num_trajectories, max_len + 1, 1))
6071
timesteps = np.zeros((num_trajectories, max_len + 1))
6172
mask = np.zeros((num_trajectories, max_len + 1))
6273

6374
# Fill arrays
6475
for i, traj in enumerate(trajectories):
6576
length = min(traj['length'], max_len)
77+
if length == 0:
78+
continue
6679

6780
# States include the final state
6881
states[i, :length + 1] = traj['states'][:length + 1]
@@ -79,7 +92,7 @@ def process_trajectories(trajectories, max_timesteps=1000):
7992
# Create metadata
8093
metadata = {
8194
'state_dim': state_dim,
82-
'act_dim': 2, # CartPole has 2 actions
95+
'act_dim': act_dim,
8396
'max_timesteps': max_timesteps
8497
}
8598

0 commit comments

Comments
 (0)