
Commit 6840ae0

Fix world-model training gradients
1 parent ef49663 commit 6840ae0

Showing 2 changed files with 25 additions and 23 deletions.

worldmodels/worldmodel_models/deterministic.py

Lines changed: 8 additions & 10 deletions
@@ -86,20 +86,18 @@ def update(self, batch: list[dict]) -> dict[str, float]:
 
         for item in batch:
             state = self.observe(self.init_state(batch_size=1), item["obs"])
-            next_state, _pred_obs, pred_reward, pred_done, _aux = self.predict(
-                state, int(item["action"])
-            )
-            del next_state
+            action_tensor = self._action_tensor(int(item["action"]))
+            transition_input = torch.cat([state["latent"], action_tensor], dim=-1)
+            next_latent = self.transition(transition_input)
+
+            pred_reward = self.reward_head(next_latent).squeeze(-1)
+            pred_done_prob = torch.sigmoid(self.done_head(next_latent)).squeeze(-1)
 
             target_reward = torch.tensor([item["reward"]], device=self.device)
             target_done = torch.tensor([float(item["done"])], device=self.device)
 
-            reward_loss = (
-                (torch.tensor([pred_reward], device=self.device) - target_reward).pow(2).mean()
-            )
-            done_loss = (
-                (torch.tensor([float(pred_done)], device=self.device) - target_done).pow(2).mean()
-            )
+            reward_loss = (pred_reward - target_reward).pow(2).mean()
+            done_loss = (pred_done_prob - target_done).pow(2).mean()
             loss = loss + reward_loss + done_loss
 
         loss = loss / len(batch)
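
Why the old version learned nothing from these losses: the values returned by predict() were rewrapped with torch.tensor([...]) (and float(...) for the done flag), which copies the numbers into fresh leaf tensors with no grad_fn. The MSE losses therefore had no autograd path back to the transition, reward, or done heads. A minimal standalone sketch of the failure mode (the Linear head and shapes are illustrative, not the repo's modules):

import torch

head = torch.nn.Linear(4, 1)
latent = torch.randn(1, 4)
pred = head(latent).squeeze(-1)      # carries grad_fn, connected to head.weight
target = torch.tensor([1.0])

# Old pattern: rewrapping copies the value into a leaf tensor, dropping the graph.
detached = torch.tensor([pred.item()])
loss_old = (detached - target).pow(2).mean()
# loss_old.backward() -> RuntimeError: element 0 of tensors does not require grad

# New pattern: use the head output directly so gradients reach the weights.
loss_new = (pred - target).pow(2).mean()
loss_new.backward()
print(head.weight.grad is not None)  # True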

worldmodels/worldmodel_models/stochastic.py

Lines changed: 17 additions & 13 deletions
@@ -95,24 +95,28 @@ def update(self, batch: list[dict]) -> dict[str, float]:
 
         for item in batch:
             posterior_state = self.observe(self.init_state(batch_size=1), item["obs"])
-            pred_state, _pred_obs, pred_reward, pred_done, _aux = self.predict(
-                posterior_state, int(item["action"])
-            )
-            del pred_state
+            action_tensor = self._action_tensor(int(item["action"]))
+            prior_stats = self.prior(torch.cat([posterior_state["h"], action_tensor], dim=-1))
+            mean, logvar = torch.chunk(prior_stats, 2, dim=-1)
+            std = torch.exp(0.5 * logvar).clamp(min=1e-4)
+            eps = torch.randn_like(std)
+            z = mean + eps * std
+            h = self.gru(z, posterior_state["h"])
+
+            pred_reward = self.reward_head(h).squeeze(-1)
+            pred_done_prob = torch.sigmoid(self.done_head(h)).squeeze(-1)
 
             target_reward = torch.tensor([item["reward"]], device=self.device)
             target_done = torch.tensor([float(item["done"])], device=self.device)
 
-            reward_loss = (
-                (torch.tensor([pred_reward], device=self.device) - target_reward).pow(2).mean()
-            )
-            done_loss = (
-                (torch.tensor([float(pred_done)], device=self.device) - target_done).pow(2).mean()
-            )
+            reward_loss = (pred_reward - target_reward).pow(2).mean()
+            done_loss = (pred_done_prob - target_done).pow(2).mean()
 
-            mean = posterior_state["mean"]
-            logvar = posterior_state["logvar"]
-            kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
+            posterior_mean = posterior_state["mean"]
+            posterior_logvar = posterior_state["logvar"]
+            kl = -0.5 * torch.mean(
+                1 + posterior_logvar - posterior_mean.pow(2) - posterior_logvar.exp()
+            )
 
             total = total + reward_loss + done_loss + 0.1 * kl
             kl_total = kl_total + kl
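
Two details in the stochastic path worth noting. The sample z = mean + eps * std is the reparameterization trick: randomness enters only through eps, which autograd treats as a constant, so gradients still flow into mean and logvar (and through them into the prior network). The kl expression is the closed-form KL divergence between the diagonal Gaussian posterior N(mean, exp(logvar)) and a standard normal; using torch.mean rather than a sum over latent dimensions scales the term by 1/latent_dim, which folds into the effective KL weight alongside the 0.1 coefficient. A small self-contained check (shapes are illustrative):

import torch

mean = torch.randn(1, 8, requires_grad=True)
logvar = torch.randn(1, 8, requires_grad=True)

# Reparameterized sample: z stays differentiable w.r.t. mean and logvar.
std = torch.exp(0.5 * logvar).clamp(min=1e-4)
eps = torch.randn_like(std)
z = mean + eps * std

# Closed-form KL( N(mean, diag(exp(logvar))) || N(0, I) ), averaged over
# dimensions, matching the expression used in the diff.
kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())

(z.sum() + kl).backward()
print(mean.grad is not None, logvar.grad is not None)  # True True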
