Commit 59fb47e

Fixed: IQL Training for Baseline
1 parent 08f4680 commit 59fb47e

3 files changed

Lines changed: 82 additions & 52 deletions

results/all_runs/iql_CartPole-v1/training.log

Lines changed: 6 additions & 0 deletions
@@ -18,3 +18,9 @@
 2025-11-06 11:37:53,843 [INFO] Checking for dataset...
 2025-11-06 11:37:54,177 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
 2025-11-06 11:37:54,178 [INFO] Starting training...
+2025-11-06 12:36:23,691 [INFO] Checking for dataset...
+2025-11-06 12:36:23,705 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
+2025-11-06 12:36:23,706 [INFO] Starting training...
+2025-11-06 12:37:13,769 [INFO] Dataset size: 22826 clips
+2025-11-06 12:37:14,136 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
+2025-11-06 12:37:30,656 [INFO] Starting training loop...

snn-dt/scripts/train.py

Lines changed: 54 additions & 43 deletions
@@ -57,62 +57,73 @@ class OfflineTransitionDataset(Dataset):
     def __init__(self, dataset_path):
         data = np.load(dataset_path)
 
-        # First, calculate the total number of transitions
-        total_transitions = 0
-        for i in range(data["states"].shape[0]):
-            mask = data["mask"][i]
-            clip_len = int(mask.sum())
-            clip_len = min(clip_len, data["states"].shape[1])  # Cap clip_len
-            total_transitions += clip_len
-
-        # Pre-allocate numpy arrays
-        state_dim = data["states"].shape[2]
-        action_dim = data["actions"].shape[2]
+        states_list, actions_list, rewards_list, next_states_list, dones_list = [], [], [], [], []
 
-        self.states = np.zeros((total_transitions, state_dim), dtype=np.float32)
-        self.actions = np.zeros((total_transitions, action_dim), dtype=np.float32)
-        self.rewards = np.zeros((total_transitions, 1), dtype=np.float32)
-        self.next_states = np.zeros((total_transitions, state_dim), dtype=np.float32)
-        self.dones = np.zeros((total_transitions, 1), dtype=np.float32)
-
-        current_idx = 0
+        # Check if actions exist and get action_dim, otherwise handle gracefully
+        if "actions" not in data or data["actions"].shape[0] == 0:
+            action_dim = 0  # Placeholder, this dataset will be empty
+        else:
+            action_dim = data["actions"].shape[2]
+
         for i in range(data["states"].shape[0]):
             mask = data["mask"][i]
             clip_len = int(mask.sum())
-            clip_len = min(clip_len, data["states"].shape[1])  # Cap clip_len
-
+
             if clip_len == 0:
                 continue
 
-            # Get the trajectory data
+            # Trajectory data
             traj_states = data["states"][i, :clip_len]
-            traj_actions = data["actions"][i, :clip_len]
             traj_rtg = data["returns_to_go"][i, :clip_len]
 
-            # Populate states and actions
-            self.states[current_idx : current_idx + clip_len] = traj_states
-            self.actions[current_idx : current_idx + clip_len] = traj_actions
+            # Add all states for this trajectory to the list
+            states_list.append(traj_states)
 
-            # Populate rewards, next_states, and dones
+            # Actions: N-1 real actions, plus one dummy action for the terminal state
+            traj_actions_list = []
             if clip_len > 1:
-                self.rewards[current_idx : current_idx + clip_len - 1] = (traj_rtg[:-1] - traj_rtg[1:]).reshape(-1, 1)
-                self.next_states[current_idx : current_idx + clip_len - 1] = traj_states[1:]
-                self.dones[current_idx : current_idx + clip_len - 1] = 0.0
-
-            # Final transition
-            if clip_len > 0:
-                self.rewards[current_idx + clip_len - 1] = traj_rtg[-1]
-                self.next_states[current_idx + clip_len - 1] = np.zeros_like(traj_states[-1])
-                self.dones[current_idx + clip_len - 1] = 1.0
+                traj_actions_list.append(data["actions"][i, :clip_len - 1])
+
+            # Add a dummy action for the terminal state
+            dummy_action = np.zeros((1, action_dim), dtype=np.float32)
+            traj_actions_list.append(dummy_action)
+            actions_list.append(np.concatenate(traj_actions_list, axis=0))
+
+            # Rewards and next_states
+            rewards = np.zeros((clip_len, 1), dtype=np.float32)
+            next_states = np.zeros_like(traj_states)
+            dones = np.zeros((clip_len, 1), dtype=np.float32)
+
+            if clip_len > 1:
+                # Rewards for non-terminal states
+                rewards[:-1] = (traj_rtg[:-1] - traj_rtg[1:]).reshape(-1, 1)
+                # Next_states for non-terminal states
+                next_states[:-1] = traj_states[1:]
+
+            # Handle terminal transition
+            rewards[-1] = traj_rtg[-1]  # This is RTG, not a true reward, but matches original logic
+            # next_states[-1] is already zeros
+            dones[-1] = 1.0
 
-            current_idx += clip_len
-
-        # Convert to torch tensors
-        self.states = torch.from_numpy(self.states).float()
-        self.actions = torch.from_numpy(self.actions).float()
-        self.rewards = torch.from_numpy(self.rewards).float()
-        self.next_states = torch.from_numpy(self.next_states).float()
-        self.dones = torch.from_numpy(self.dones).float()
+            rewards_list.append(rewards)
+            next_states_list.append(next_states)
+            dones_list.append(dones)
+
+        # Handle case where no valid trajectories were found
+        if not states_list:
+            self.states = torch.empty(0, data["states"].shape[2], dtype=torch.float32)
+            self.actions = torch.empty(0, action_dim, dtype=torch.float32)
+            self.rewards = torch.empty(0, 1, dtype=torch.float32)
+            self.next_states = torch.empty(0, data["states"].shape[2], dtype=torch.float32)
+            self.dones = torch.empty(0, 1, dtype=torch.float32)
+            return
+
+        # Concatenate and convert to torch tensors
+        self.states = torch.from_numpy(np.concatenate(states_list, axis=0)).float()
+        self.actions = torch.from_numpy(np.concatenate(actions_list, axis=0)).float()
+        self.rewards = torch.from_numpy(np.concatenate(rewards_list, axis=0)).float()
+        self.next_states = torch.from_numpy(np.concatenate(next_states_list, axis=0)).float()
+        self.dones = torch.from_numpy(np.concatenate(dones_list, axis=0)).float()
 
     def __len__(self):
         return len(self.states)
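The rebuilt OfflineTransitionDataset grows Python lists per trajectory and concatenates once at the end, instead of pre-allocating arrays from a transition count that could disagree with the per-clip masks. For orientation only, a minimal sketch of how such a dataset might be consumed: the num_workers=0 / pin_memory=False settings mirror the training log above, while the dataset path, batch size, and the assumption that __getitem__ yields (state, action, reward, next_state, done) tuples are illustrative and not taken from this commit.

    # Sketch only -- not the literal train.py code from this commit.
    from torch.utils.data import DataLoader

    dataset = OfflineTransitionDataset("data/CartPole-v1/dataset.npz")  # path assumed
    loader = DataLoader(dataset, batch_size=256, shuffle=True,          # batch size assumed
                        num_workers=0, pin_memory=False)                # matches the log

    for states, actions, rewards, next_states, dones in loader:
        pass  # one IQL update per batch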

snn-dt/src/models/iql.py

Lines changed: 22 additions & 9 deletions
@@ -103,8 +103,9 @@ def __init__(self, cfg):
         self.tau = cfg.iql.tau
         self.temperature = cfg.iql.temperature
         self.expectile = cfg.iql.expectile
+        self.is_discrete = cfg.dataset.is_discrete
 
-        self.actor = Actor(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
+        self.actor = Actor(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
         self.critic1 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
         self.critic2 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
         self.value_net = Value(cfg.dataset.state_dim, cfg.iql.hidden_size).to(self.device)
@@ -114,10 +115,10 @@ def __init__(self, cfg):
         self.critic1_target.load_state_dict(self.critic1.state_dict())
         self.critic2_target.load_state_dict(self.critic2.state_dict())
 
-        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=cfg.training.lr)
-        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=cfg.training.lr)
-        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=cfg.training.lr)
-        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=cfg.training.lr)
+        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=float(cfg.training.lr))
+        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=float(cfg.training.lr))
+        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=float(cfg.training.lr))
+        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=float(cfg.training.lr))
 
     def forward(self, batch):
         # IQL has a custom learn method, so forward is a no-op for now
@@ -134,8 +135,13 @@ def learn(self, batch):
 
         # Value loss
         with torch.no_grad():
-            q1 = self.critic1_target(states, actions)
-            q2 = self.critic2_target(states, actions)
+            if self.is_discrete:
+                actions_one_hot = F.one_hot(actions.squeeze().long(), num_classes=self.actor.logits.out_features).float()
+                q1 = self.critic1_target(states, actions_one_hot)
+                q2 = self.critic2_target(states, actions_one_hot)
+            else:
+                q1 = self.critic1_target(states, actions)
+                q2 = self.critic2_target(states, actions)
             min_q = torch.min(q1, q2)
         value = self.value_net(states)
         value_loss = loss_fn(min_q - value, self.expectile).mean()
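For a discrete-action environment such as CartPole-v1, the F.one_hot conversion above widens an integer action batch into the fixed-size vector the critics were built for. A small standalone illustration (the two-action count is CartPole's; the tensor values are made up):

    # Illustration of the one-hot conversion, detached from the agent code.
    import torch
    import torch.nn.functional as F

    actions = torch.tensor([[0], [1], [1]])                  # (batch, 1) integer actions
    actions_one_hot = F.one_hot(actions.squeeze().long(),
                                num_classes=2).float()       # -> shape (batch, 2)
    # tensor([[1., 0.],
    #         [0., 1.],
    #         [0., 1.]])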
@@ -159,8 +165,15 @@ def learn(self, batch):
         with torch.no_grad():
             next_v = self.value_net(next_states)
             q_target = rewards + self.gamma * (1 - dones) * next_v
-        q1 = self.critic1(states, actions)
-        q2 = self.critic2(states, actions)
+
+        if self.is_discrete:
+            actions_one_hot = F.one_hot(actions.squeeze().long(), num_classes=self.actor.logits.out_features).float()
+            q1 = self.critic1(states, actions_one_hot)
+            q2 = self.critic2(states, actions_one_hot)
+        else:
+            q1 = self.critic1(states, actions)
+            q2 = self.critic2(states, actions)
+
         critic1_loss = F.mse_loss(q1, q_target)
         critic2_loss = F.mse_loss(q2, q_target)
 
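The value loss calls an expectile regression helper, loss_fn(min_q - value, self.expectile), whose body lies outside these hunks. For reference, the standard IQL expectile loss has the form below; this is a sketch of the usual definition, not necessarily the exact code in iql.py:

    import torch

    def loss_fn(diff, expectile):
        # Asymmetric squared error: |expectile - 1{diff < 0}| * diff^2.
        # expectile = 0.5 reduces to ordinary squared error; values near 1
        # penalise underestimating the Q-values more heavily.
        weight = torch.abs(expectile - (diff < 0).float())
        return weight * diff.pow(2)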