
Commit 02e9820

Fixed: IQL Training on Acrobot ENV
1 parent 7623086 commit 02e9820

3 files changed: 34 additions & 27 deletions


configs/iql_acrobot.yaml

Lines changed: 1 addition & 0 deletions
@@ -8,5 +8,6 @@ learning_rate: 1e-4
 epochs: 1000
 eval_every: 10
 checkpoint_every: 50
+batches_per_epoch: 100
 
 num_workers: 1
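
For context, here is how the new key is consumed — a minimal sketch assuming the project loads YAML configs with OmegaConf (the cfg.training.* accesses in train.py suggest an attribute-style config with these keys nested under a training block; the load path here is illustrative):

from omegaconf import OmegaConf

# Hypothetical loader; the real script may wire this up differently (e.g. via Hydra).
cfg = OmegaConf.load("configs/iql_acrobot.yaml")

# Each epoch now sees exactly this many batches, decoupling epoch length
# from dataset size (see the train.py hunk below).
print(cfg.training.batches_per_epoch)  # 100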

snn-dt/scripts/train.py

Lines changed: 9 additions & 5 deletions
@@ -80,7 +80,8 @@ def __init__(self, dataset_path):
 
         # Pre-allocate memory-mapped arrays
         self.states = np.memmap(self.states_mmap_path, dtype=np.float32, mode='w+', shape=(total_transitions, state_dim))
-        self.actions = np.memmap(self.actions_mmap_path, dtype=np.float32, mode='w+', shape=(total_transitions, action_dim))
+        # Use int64 for discrete actions
+        self.actions = np.memmap(self.actions_mmap_path, dtype=np.int64, mode='w+', shape=(total_transitions, action_dim))
         self.rewards = np.memmap(self.rewards_mmap_path, dtype=np.float32, mode='w+', shape=(total_transitions, 1))
         self.next_states = np.memmap(self.next_states_mmap_path, dtype=np.float32, mode='w+', shape=(total_transitions, state_dim))
         self.dones = np.memmap(self.dones_mmap_path, dtype=np.float32, mode='w+', shape=(total_transitions, 1))
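
Discrete actions are class indices, so storing them as int64 keeps them exact and ready for integer-indexed ops like nn.Embedding. A minimal round-trip sketch with illustrative sizes (Acrobot has 3 discrete actions):

import numpy as np
import torch

total_transitions, action_dim = 1000, 1  # illustrative sizes
actions = np.memmap("actions.dat", dtype=np.int64, mode="w+",
                    shape=(total_transitions, action_dim))
actions[:] = np.random.randint(0, 3, size=actions.shape)  # Acrobot: actions in {0, 1, 2}
actions.flush()

batch = torch.from_numpy(np.array(actions[:32]))  # dtype torch.int64 (torch.long)
assert batch.dtype == torch.long                  # valid nn.Embedding indices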
@@ -98,9 +99,9 @@ def __init__(self, dataset_path):
             traj_rtg = data['returns_to_go'][i, :clip_len]
 
             # Actions
-            traj_actions = np.zeros((clip_len, action_dim), dtype=np.float32)
+            traj_actions = np.zeros((clip_len, action_dim), dtype=np.int64)
             if clip_len > 1:
-                traj_actions[:clip_len-1] = data['actions'][i, :clip_len-1]
+                traj_actions[:clip_len-1] = data['actions'][i, :clip_len-1].astype(np.int64)
 
             # Rewards and next_states
             rewards = np.zeros((clip_len, 1), dtype=np.float32)
@@ -216,12 +217,15 @@ def train(cfg, logger):
         epoch_losses = []
 
         batch_iter = tqdm(
-            enumerate(train_loader),
-            total=len(train_loader),
+            enumerate(train_loader),
+            total=cfg.training.batches_per_epoch,
             desc=f"Epoch {epoch+1}/{cfg.training.epochs}"
         )
 
         for i, batch in batch_iter:
+            if i >= cfg.training.batches_per_epoch:
+                break
+
             model.train()
 
             for k, v in batch.items():
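
Capping the loop decouples epoch length from len(train_loader), so eval_every and checkpoint_every fire on a predictable cadence even over a large memmapped dataset. A standalone sketch of the pattern (names illustrative):

from tqdm import tqdm

def run_epoch(train_loader, batches_per_epoch, epoch, num_epochs, step_fn):
    batch_iter = tqdm(enumerate(train_loader), total=batches_per_epoch,
                      desc=f"Epoch {epoch+1}/{num_epochs}")
    for i, batch in batch_iter:
        if i >= batches_per_epoch:
            break  # stop early; the loader may hold far more batches
        step_fn(batch)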

snn-dt/src/models/iql.py

Lines changed: 24 additions & 22 deletions
@@ -63,14 +63,23 @@ def get_det_action(self, state):
 
 
 class Critic(nn.Module):
-    def __init__(self, state_size, action_size, hidden_size=256):
+    def __init__(self, state_size, action_size, hidden_size=256, is_discrete=False):
         super(Critic, self).__init__()
-        self.fc1 = nn.Linear(state_size + action_size, hidden_size)
+        self.is_discrete = is_discrete
+        if self.is_discrete:
+            self.action_embedding = nn.Embedding(action_size, hidden_size)
+            self.fc1 = nn.Linear(state_size + hidden_size, hidden_size)
+        else:
+            self.fc1 = nn.Linear(state_size + action_size, hidden_size)
         self.fc2 = nn.Linear(hidden_size, hidden_size)
         self.fc3 = nn.Linear(hidden_size, 1)
 
     def forward(self, state, action):
-        x = torch.cat((state, action), dim=-1)
+        if self.is_discrete:
+            action_emb = self.action_embedding(action.long().squeeze(-1))
+            x = torch.cat((state, action_emb), dim=-1)
+        else:
+            x = torch.cat((state, action), dim=-1)
         x = F.relu(self.fc1(x))
         x = F.relu(self.fc2(x))
         return self.fc3(x)
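
A quick shape check of the new discrete path, using the Critic as patched above with Acrobot's dimensions (6-dim observation, 3 discrete actions):

import torch

critic = Critic(state_size=6, action_size=3, hidden_size=256, is_discrete=True)
states = torch.randn(32, 6)
actions = torch.randint(0, 3, (32, 1))  # int64 indices of shape (B, 1)
q = critic(states, actions)             # squeeze -> (B,), embed -> (B, 256), cat -> (B, 262)
print(q.shape)                          # torch.Size([32, 1])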
@@ -106,12 +115,12 @@ def __init__(self, cfg):
         self.is_discrete = cfg.dataset.is_discrete
 
         self.actor = Actor(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
-        self.critic1 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
-        self.critic2 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
+        self.critic1 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
+        self.critic2 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
         self.value_net = Value(cfg.dataset.state_dim, cfg.iql.hidden_size).to(self.device)
 
-        self.critic1_target = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
-        self.critic2_target = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
+        self.critic1_target = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
+        self.critic2_target = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
         self.critic1_target.load_state_dict(self.critic1.state_dict())
         self.critic2_target.load_state_dict(self.critic2.state_dict())
 
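A plausible failure mode this hunk removes: load_state_dict requires matching architectures, and before the change the target critics were built without the discrete flag. A hypothetical repro with the patched Critic:

critic = Critic(state_size=6, action_size=3, is_discrete=True)
target = Critic(state_size=6, action_size=3, is_discrete=False)
target.load_state_dict(critic.state_dict())
# RuntimeError: unexpected key "action_embedding.weight", plus a size
# mismatch on fc1.weight ((256, 262) vs (256, 9))
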
@@ -135,13 +144,8 @@ def learn(self, batch):
 
         # Value loss
         with torch.no_grad():
-            if self.is_discrete:
-                actions_one_hot = F.one_hot(actions.squeeze().long(), num_classes=self.actor.logits.out_features).float()
-                q1 = self.critic1_target(states, actions_one_hot)
-                q2 = self.critic2_target(states, actions_one_hot)
-            else:
-                q1 = self.critic1_target(states, actions)
-                q2 = self.critic2_target(states, actions)
+            q1 = self.critic1_target(states, actions)
+            q2 = self.critic2_target(states, actions)
         min_q = torch.min(q1, q2)
         value = self.value_net(states)
         value_loss = loss_fn(min_q - value, self.expectile).mean()
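
loss_fn itself is not part of this diff; in standard IQL it is the expectile regression loss (Kostrikov et al., 2021), sketched here from the paper's definition:

import torch

def expectile_loss(diff, expectile=0.7):
    # |tau - 1(diff < 0)| * diff^2: over-weights positive residuals,
    # pushing V(s) toward an upper expectile of Q(s, a).
    weight = torch.abs(expectile - (diff < 0).float())
    return weight * diff.pow(2)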
@@ -155,7 +159,10 @@ def learn(self, batch):
         exp_a = torch.exp((min_q - v) * self.temperature)
         exp_a = torch.min(exp_a, torch.tensor(100.0, device=self.device))
         _, dist = self.actor.evaluate(states)
-        log_probs = dist.log_prob(actions)
+        if self.is_discrete:
+            log_probs = dist.log_prob(actions.squeeze(-1).long())
+        else:
+            log_probs = dist.log_prob(actions)
         actor_loss = -(exp_a * log_probs).mean()
         self.actor_optimizer.zero_grad()
         actor_loss.backward()
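
The branch exists because a discrete actor's distribution (presumably a Categorical) takes integer class indices of shape (B,) in log_prob, while the continuous path evaluates (B, act_dim) actions directly. A minimal check:

import torch
from torch.distributions import Categorical

dist = Categorical(logits=torch.randn(32, 3))
actions = torch.randint(0, 3, (32, 1))
log_probs = dist.log_prob(actions.squeeze(-1).long())
assert log_probs.shape == (32,)  # one log-probability per batch element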
@@ -166,13 +173,8 @@ def learn(self, batch):
             next_v = self.value_net(next_states)
         q_target = rewards + self.gamma * (1 - dones) * next_v
 
-        if self.is_discrete:
-            actions_one_hot = F.one_hot(actions.squeeze().long(), num_classes=self.actor.logits.out_features).float()
-            q1 = self.critic1(states, actions_one_hot)
-            q2 = self.critic2(states, actions_one_hot)
-        else:
-            q1 = self.critic1(states, actions)
-            q2 = self.critic2(states, actions)
+        q1 = self.critic1(states, actions)
+        q2 = self.critic2(states, actions)
 
         critic1_loss = F.mse_loss(q1, q_target)
         critic2_loss = F.mse_loss(q2, q_target)
