Commit d8c3887
[Refactor] Use compile-aware helpers in Dreamer objectives
- Replace timeit with _maybe_timeit
- Add _maybe_record_function_decorator to loss forward methods
- Use .data and .copy() instead of .detach() and .clone(recurse=False)
- Ensure contiguous layout for tensor operations

ghstack-source-id: 84bc237
Pull-Request: #3304
1 parent 3bdc7b1 commit d8c3887

1 file changed: torchrl/objectives/dreamer.py (35 additions & 23 deletions)
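
The imports swap the unconditional timing helper for compile-aware variants. Their implementations are not part of this diff; the sketch below is a hypothetical illustration of what such helpers typically do, namely fall back to no-ops while torch.compile is tracing, since Python-level timers and profiler ranges would otherwise cause graph breaks. All names ending in _sketch are made up.

import contextlib
from functools import wraps

import torch
from torch.profiler import record_function

from torchrl._utils import timeit  # the original timing context manager


def maybe_timeit_sketch(name):
    # Python-level timers graph-break under torch.compile, so return a
    # no-op context while compiling and the real timer otherwise.
    if torch.compiler.is_compiling():
        return contextlib.nullcontext()
    return timeit(name)


def maybe_record_function_sketch(name):
    # Emit a torch.profiler.record_function range only in eager mode.
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if torch.compiler.is_compiling():
                return fn(*args, **kwargs)
            with record_function(name):
                return fn(*args, **kwargs)
        return wrapper
    return decorator
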
@@ -11,7 +11,7 @@
 from tensordict.nn import TensorDictModule
 from tensordict.utils import NestedKey
 
-from torchrl._utils import timeit
+from torchrl._utils import _maybe_record_function_decorator, _maybe_timeit
 from torchrl.envs.model_based.dreamer import DreamerEnv
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 from torchrl.objectives.common import LossModule
@@ -123,49 +123,57 @@ def __init__(
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         pass
 
+    @_maybe_record_function_decorator("world_model_loss/forward")
     def forward(self, tensordict: TensorDict) -> torch.Tensor:
-        tensordict = tensordict.clone(recurse=False)
+        tensordict = tensordict.copy()
         tensordict.rename_key_(
             ("next", self.tensor_keys.reward),
             ("next", self.tensor_keys.true_reward),
         )
+
         tensordict = self.world_model(tensordict)
-        # compute model loss
+
+        prior_mean = tensordict.get(("next", self.tensor_keys.prior_mean))
+        prior_std = tensordict.get(("next", self.tensor_keys.prior_std))
+        posterior_mean = tensordict.get(("next", self.tensor_keys.posterior_mean))
+        posterior_std = tensordict.get(("next", self.tensor_keys.posterior_std))
+
         kl_loss = self.kl_loss(
-            tensordict.get(("next", self.tensor_keys.prior_mean)),
-            tensordict.get(("next", self.tensor_keys.prior_std)),
-            tensordict.get(("next", self.tensor_keys.posterior_mean)),
-            tensordict.get(("next", self.tensor_keys.posterior_std)),
+            prior_mean, prior_std, posterior_mean, posterior_std,
         ).unsqueeze(-1)
+
+        # Ensure contiguous layout for torch.compile compatibility
+        # The gradient from distance_loss flows back through decoder convolutions
+        pixels = tensordict.get(("next", self.tensor_keys.pixels)).contiguous()
+        reco_pixels = tensordict.get(("next", self.tensor_keys.reco_pixels)).contiguous()
         reco_loss = distance_loss(
-            tensordict.get(("next", self.tensor_keys.pixels)),
-            tensordict.get(("next", self.tensor_keys.reco_pixels)),
+            pixels,
+            reco_pixels,
             self.reco_loss,
         )
         if not self.global_average:
             reco_loss = reco_loss.sum((-3, -2, -1))
         reco_loss = reco_loss.mean().unsqueeze(-1)
 
+        true_reward = tensordict.get(("next", self.tensor_keys.true_reward))
+        pred_reward = tensordict.get(("next", self.tensor_keys.reward))
         reward_loss = distance_loss(
-            tensordict.get(("next", self.tensor_keys.true_reward)),
-            tensordict.get(("next", self.tensor_keys.reward)),
+            true_reward,
+            pred_reward,
             self.reward_loss,
         )
         if not self.global_average:
             reward_loss = reward_loss.squeeze(-1)
         reward_loss = reward_loss.mean().unsqueeze(-1)
-        # import ipdb; ipdb.set_trace()
 
         td_out = TensorDict(
             loss_model_kl=self.lambda_kl * kl_loss,
             loss_model_reco=self.lambda_reco * reco_loss,
             loss_model_reward=self.lambda_reward * reward_loss,
         )
         self._clear_weakrefs(tensordict, td_out)
-        return (
-            td_out,
-            tensordict.detach(),
-        )
+
+        return (td_out, tensordict.data)
 
     @staticmethod
     def normal_log_probability(x, mean, std):
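
A note on the .contiguous() calls added above: decoders that end in permute or transpose ops can return non-contiguous pixel tensors, and materializing a dense copy before the reconstruction loss gives torch.compile a single memory layout to specialize on. A toy illustration with made-up shapes, not the actual decoder output:

import torch

# A permute leaves the tensor non-contiguous (strides no longer dense).
reco_pixels = torch.randn(4, 64, 64, 3).permute(0, 3, 1, 2)
assert not reco_pixels.is_contiguous()

# .contiguous() materializes a dense copy with a plain NCHW layout.
reco_pixels = reco_pixels.contiguous()
assert reco_pixels.is_contiguous()
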
@@ -275,10 +283,11 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             value=self._tensor_keys.value,
         )
 
+    @_maybe_record_function_decorator("actor_loss/forward")
     def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
-        tensordict = tensordict.select("state", self.tensor_keys.belief).detach()
+        tensordict = tensordict.select("state", self.tensor_keys.belief).data
 
-        with timeit("actor_loss/time-rollout"), hold_out_net(
+        with _maybe_timeit("actor_loss/time-rollout"), hold_out_net(
             self.model_based_env
         ), set_exploration_type(ExplorationType.RANDOM):
             tensordict = self.model_based_env.reset(tensordict.copy())
@@ -288,7 +297,6 @@ def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
                 auto_reset=False,
                 tensordict=tensordict,
             )
-
             next_tensordict = step_mdp(fake_data, keep_other=True)
             with hold_out_net(self.value_model):
                 next_tensordict = self.value_model(next_tensordict)
@@ -308,7 +316,8 @@ def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
         actor_loss = -lambda_target.sum((-2, -1)).mean()
         loss_tensordict = TensorDict({"loss_actor": actor_loss}, [])
         self._clear_weakrefs(tensordict, loss_tensordict)
-        return loss_tensordict, fake_data.detach()
+
+        return loss_tensordict, fake_data.data
 
     def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
         done = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
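
The actor loss, like the model loss above, now detaches via .data and shallow-copies via .copy(). On a TensorDict these mirror .detach() and .clone(recurse=False) but avoid inserting detach nodes into the traced graph. A minimal illustration with a toy tensordict (shapes made up):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3, requires_grad=True)}, batch_size=[3])

# .copy() shallow-copies the structure only: the leaves are the same tensors.
shallow = td.copy()
assert shallow["a"] is td["a"]

# .data rewraps each leaf's .data, so the result requires no grad and adds
# no detach node to the autograd graph.
detached = td.data
assert not detached["a"].requires_grad
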
@@ -420,14 +429,15 @@ def __init__(
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         pass
 
+    @_maybe_record_function_decorator("value_loss/forward")
     def forward(self, fake_data) -> torch.Tensor:
         lambda_target = fake_data.get("lambda_target")
+
         tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False)
         self.value_model(tensordict_select)
+
         if self.discount_loss:
-            discount = self.gamma * torch.ones_like(
-                lambda_target, device=lambda_target.device
-            )
+            discount = self.gamma * torch.ones_like(lambda_target, device=lambda_target.device)
             discount[..., 0, :] = 1
             discount = discount.cumprod(dim=-2)
             value_loss = (
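
For intuition on the discount block above: the first step is forced to 1 before the cumulative product, so step t is weighted by gamma**t. A worked toy example (gamma and horizon made up):

import torch

gamma = 0.9
lambda_target = torch.zeros(4, 1)  # stand-in with a (time, 1) shape

discount = gamma * torch.ones_like(lambda_target, device=lambda_target.device)
discount[..., 0, :] = 1
discount = discount.cumprod(dim=-2)
print(discount.squeeze(-1))  # tensor([1.0000, 0.9000, 0.8100, 0.7290])
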
@@ -452,6 +462,8 @@ def forward(self, fake_data) -> torch.Tensor:
                 .sum((-1, -2))
                 .mean()
             )
+
         loss_tensordict = TensorDict({"loss_value": value_loss})
         self._clear_weakrefs(fake_data, loss_tensordict)
+
         return loss_tensordict, fake_data
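
After this change, the model and actor losses return (loss_td, detached_td) with the second element produced via .data, while the value loss returns fake_data as-is. A hedged usage sketch; world_model_loss is assumed to be a configured DreamerModelLoss and batch a sampled rollout tensordict, neither of which is defined in this diff:

# The loss keys match those written into td_out above.
loss_td, model_td = world_model_loss(batch)
model_loss = (
    loss_td["loss_model_kl"]
    + loss_td["loss_model_reco"]
    + loss_td["loss_model_reward"]
)
model_loss.backward()  # model_td carries no graph; only the losses backprop
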
