Skip to content

Commit acdd3a2

Browse files
Fixed: CQL Training
1 parent a4c9607 commit acdd3a2

7 files changed

Lines changed: 196 additions & 85 deletions

File tree

configs/cql_acrobot.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,12 @@ learning_rate: 1e-4
88
epochs: 1000
99
eval_every: 10
1010
checkpoint_every: 50
11+
num_workers: 1
1112

12-
num_workers: 1
13+
cql:
14+
tau: 0.005
15+
hidden_size: 256
16+
with_lagrange: true
17+
temperature: 1.0
18+
cql_weight: 5.0
19+
target_action_gap: 10.0

configs/cql_cartpole.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,12 @@ learning_rate: 1e-4
88
epochs: 1000
99
eval_every: 10
1010
checkpoint_every: 50
11+
num_workers: 1
1112

12-
num_workers: 1
13+
cql:
14+
tau: 0.005
15+
hidden_size: 256
16+
with_lagrange: true
17+
temperature: 1.0
18+
cql_weight: 5.0
19+
target_action_gap: 10.0

configs/cql_mountaincar.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,12 @@ learning_rate: 1e-4
88
epochs: 1000
99
eval_every: 10
1010
checkpoint_every: 50
11+
num_workers: 1
1112

12-
num_workers: 1
13+
cql:
14+
tau: 0.005
15+
hidden_size: 256
16+
with_lagrange: true
17+
temperature: 1.0
18+
cql_weight: 5.0
19+
target_action_gap: 10.0

configs/cql_pendulum.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,12 @@ learning_rate: 1e-4
88
epochs: 1000
99
eval_every: 10
1010
checkpoint_every: 50
11+
num_workers: 1
1112

12-
num_workers: 1
13+
cql:
14+
tau: 0.005
15+
hidden_size: 256
16+
with_lagrange: true
17+
temperature: 1.0
18+
cql_weight: 5.0
19+
target_action_gap: 10.0
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
2025-11-07 08:48:41,523 [INFO] Checking for dataset...
2+
2025-11-07 08:48:41,524 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
3+
2025-11-07 08:48:41,525 [INFO] Starting training...
4+
2025-11-07 08:49:30,159 [INFO] Dataset size: 22826 clips
5+
2025-11-07 08:49:30,192 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
6+
2025-11-07 08:55:45,661 [INFO] Checking for dataset...
7+
2025-11-07 08:55:45,672 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
8+
2025-11-07 08:55:45,674 [INFO] Starting training...
9+
2025-11-07 08:57:17,019 [INFO] Dataset size: 22826 clips
10+
2025-11-07 08:57:17,068 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
11+
2025-11-07 08:58:27,942 [INFO] Checking for dataset...
12+
2025-11-07 08:58:27,943 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
13+
2025-11-07 08:58:27,944 [INFO] Starting training...
14+
2025-11-07 08:59:19,244 [INFO] Dataset size: 22826 clips
15+
2025-11-07 08:59:19,278 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
16+
2025-11-07 09:10:42,753 [INFO] Checking for dataset...
17+
2025-11-07 09:10:42,763 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
18+
2025-11-07 09:10:42,765 [INFO] Starting training...
19+
2025-11-07 09:11:54,307 [INFO] Dataset size: 22826 clips
20+
2025-11-07 09:11:54,364 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
21+
2025-11-07 09:12:13,379 [INFO] Starting training loop...
22+
2025-11-07 09:23:31,716 [INFO] Checking for dataset...
23+
2025-11-07 09:23:31,718 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
24+
2025-11-07 09:23:31,722 [INFO] Starting training...
25+
2025-11-07 09:24:22,292 [INFO] Dataset size: 22826 clips
26+
2025-11-07 09:24:22,320 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
27+
2025-11-07 09:28:19,691 [INFO] Checking for dataset...
28+
2025-11-07 09:28:19,725 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
29+
2025-11-07 09:28:19,726 [INFO] Starting training...
30+
2025-11-07 09:29:27,110 [INFO] Dataset size: 22826 clips
31+
2025-11-07 09:29:27,149 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
32+
2025-11-07 09:29:42,850 [INFO] Starting training loop...
33+
2025-11-07 09:53:58,753 [INFO] Checking for dataset...
34+
2025-11-07 09:53:58,895 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
35+
2025-11-07 09:53:58,896 [INFO] Starting training...
36+
2025-11-07 09:54:49,612 [INFO] Dataset size: 22826 clips
37+
2025-11-07 09:54:49,645 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
38+
2025-11-07 09:55:00,870 [INFO] Starting training loop...

snn-dt/scripts/train.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -60,31 +60,19 @@ class OfflineTransitionDataset(Dataset):
6060
def __init__(self, dataset_path):
6161
data = np.load(dataset_path, mmap_mode='r')
6262

63-
# Pre-calculate total number of transitions to pre-allocate memmapped arrays
6463
total_transitions = 0
6564
for i in range(data['mask'].shape[0]):
6665
total_transitions += int(data['mask'][i].sum())
6766

6867
state_dim = data['states'].shape[2]
6968
action_dim = data['actions'].shape[2] if 'actions' in data.keys() and data['actions'].shape[0] > 0 else 0
7069

71-
# Create a temporary directory for memory-mapped files
72-
self.temp_dir = tempfile.mkdtemp()
73-
atexit.register(self._cleanup)
74-
75-
self.states_mmap_path = os.path.join(self.temp_dir, 'states.mmap')
76-
self.actions_mmap_path = os.path.join(self.temp_dir, 'actions.mmap')
77-
self.rewards_mmap_path = os.path.join(self.temp_dir, 'rewards.mmap')
78-
self.next_states_mmap_path = os.path.join(self.temp_dir, 'next_states.mmap')
79-
self.dones_mmap_path = os.path.join(self.temp_dir, 'dones.mmap')
80-
81-
# Pre-allocate memory-mapped arrays
82-
self.states = np.memmap(self.states_mmap_path, dtype=np.float32, mode='w+', shape=(total_transitions, state_dim))
83-
# Use int64 for discrete actions
84-
self.actions = np.memmap(self.actions_mmap_path, dtype=np.int64, mode='w+', shape=(total_transitions, action_dim))
85-
self.rewards = np.memmap(self.rewards_mmap_path, dtype=np.float32, mode='w+', shape=(total_transitions, 1))
86-
self.next_states = np.memmap(self.next_states_mmap_path, dtype=np.float32, mode='w+', shape=(total_transitions, state_dim))
87-
self.dones = np.memmap(self.dones_mmap_path, dtype=np.float32, mode='w+', shape=(total_transitions, 1))
70+
# Pre-allocate arrays in memory
71+
self.states = np.empty((total_transitions, state_dim), dtype=np.float32)
72+
self.actions = np.empty((total_transitions, action_dim), dtype=np.int64)
73+
self.rewards = np.empty((total_transitions, 1), dtype=np.float32)
74+
self.next_states = np.empty((total_transitions, state_dim), dtype=np.float32)
75+
self.dones = np.empty((total_transitions, 1), dtype=np.float32)
8876

8977
current_idx = 0
9078
for i in range(data['states'].shape[0]):
@@ -94,16 +82,13 @@ def __init__(self, dataset_path):
9482
if clip_len == 0:
9583
continue
9684

97-
# Trajectory data
9885
traj_states = data['states'][i, :clip_len]
9986
traj_rtg = data['returns_to_go'][i, :clip_len]
10087

101-
# Actions
10288
traj_actions = np.zeros((clip_len, action_dim), dtype=np.int64)
10389
if clip_len > 1:
10490
traj_actions[:clip_len-1] = data['actions'][i, :clip_len-1].astype(np.int64)
10591

106-
# Rewards and next_states
10792
rewards = np.zeros((clip_len, 1), dtype=np.float32)
10893
next_states = np.zeros_like(traj_states)
10994
dones = np.zeros((clip_len, 1), dtype=np.float32)
@@ -115,7 +100,6 @@ def __init__(self, dataset_path):
115100
rewards[-1] = traj_rtg[-1]
116101
dones[-1] = 1.0
117102

118-
# Write to memmapped arrays
119103
self.states[current_idx:current_idx+clip_len] = traj_states
120104
self.actions[current_idx:current_idx+clip_len] = traj_actions
121105
self.rewards[current_idx:current_idx+clip_len] = rewards
@@ -124,16 +108,12 @@ def __init__(self, dataset_path):
124108

125109
current_idx += clip_len
126110

127-
# Convert numpy memmaps to torch tensors
128111
self.states = torch.from_numpy(self.states).float()
129112
self.actions = torch.from_numpy(self.actions).float()
130113
self.rewards = torch.from_numpy(self.rewards).float()
131114
self.next_states = torch.from_numpy(self.next_states).float()
132115
self.dones = torch.from_numpy(self.dones).float()
133116

134-
def _cleanup(self):
135-
shutil.rmtree(self.temp_dir)
136-
137117
def __len__(self):
138118
return len(self.states)
139119

@@ -383,6 +363,14 @@ def main():
383363
"temperature": cfg_raw.get("temperature", 3.0),
384364
"expectile": cfg_raw.get("expectile", 0.7),
385365
"hidden_size": cfg_raw.get("hidden_size", 256)
366+
},
367+
"cql": {
368+
"tau": cfg_raw.get("tau", 0.005),
369+
"temperature": cfg_raw.get("temperature", 1.0),
370+
"hidden_size": cfg_raw.get("hidden_size", 256),
371+
"with_lagrange": cfg_raw.get("with_lagrange", False),
372+
"cql_weight": cfg_raw.get("cql_weight", 1.0),
373+
"target_action_gap": cfg_raw.get("target_action_gap", 10.0)
386374
}
387375
}
388376

0 commit comments

Comments
 (0)