 from tensordict.nn import TensorDictModule
 from tensordict.utils import NestedKey

-from torchrl._utils import timeit
+from torchrl._utils import _maybe_record_function_decorator, _maybe_timeit
 from torchrl.envs.model_based.dreamer import DreamerEnv
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 from torchrl.objectives.common import LossModule
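The import swap replaces the always-on `timeit` context manager with opt-in variants. Their real definitions live in `torchrl/_utils.py` and are not part of this diff; the sketch below only illustrates the general pattern, with the `PROFILING_ENABLED` toggle and all implementation details being hypothetical:

```python
# Sketch only: torchrl's actual helpers may differ. PROFILING_ENABLED is a
# hypothetical toggle standing in for whatever switch torchrl really uses.
import contextlib
import functools

import torch

PROFILING_ENABLED = False  # hypothetical global toggle


def _maybe_timeit(name: str):
    """Return a profiler range when profiling is on, else a no-op context."""
    if PROFILING_ENABLED:
        return torch.profiler.record_function(name)
    return contextlib.nullcontext()


def _maybe_record_function_decorator(name: str):
    """Wrap a function in a profiler range only when profiling is enabled."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            with _maybe_timeit(name):
                return fn(*args, **kwargs)
        return wrapper
    return decorator
```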
@@ -123,49 +123,57 @@ def __init__(
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         pass

+    @_maybe_record_function_decorator("world_model_loss/forward")
     def forward(self, tensordict: TensorDict) -> torch.Tensor:
-        tensordict = tensordict.clone(recurse=False)
+        tensordict = tensordict.copy()
         tensordict.rename_key_(
             ("next", self.tensor_keys.reward),
             ("next", self.tensor_keys.true_reward),
         )
+
         tensordict = self.world_model(tensordict)
-        # compute model loss
+
+        prior_mean = tensordict.get(("next", self.tensor_keys.prior_mean))
+        prior_std = tensordict.get(("next", self.tensor_keys.prior_std))
+        posterior_mean = tensordict.get(("next", self.tensor_keys.posterior_mean))
+        posterior_std = tensordict.get(("next", self.tensor_keys.posterior_std))
+
         kl_loss = self.kl_loss(
-            tensordict.get(("next", self.tensor_keys.prior_mean)),
-            tensordict.get(("next", self.tensor_keys.prior_std)),
-            tensordict.get(("next", self.tensor_keys.posterior_mean)),
-            tensordict.get(("next", self.tensor_keys.posterior_std)),
+            prior_mean, prior_std, posterior_mean, posterior_std,
         ).unsqueeze(-1)
+
+        # Ensure contiguous layout for torch.compile compatibility.
+        # The gradient from distance_loss flows back through decoder convolutions.
+        pixels = tensordict.get(("next", self.tensor_keys.pixels)).contiguous()
+        reco_pixels = tensordict.get(("next", self.tensor_keys.reco_pixels)).contiguous()
         reco_loss = distance_loss(
-            tensordict.get(("next", self.tensor_keys.pixels)),
-            tensordict.get(("next", self.tensor_keys.reco_pixels)),
+            pixels,
+            reco_pixels,
             self.reco_loss,
         )
         if not self.global_average:
             reco_loss = reco_loss.sum((-3, -2, -1))
         reco_loss = reco_loss.mean().unsqueeze(-1)

+        true_reward = tensordict.get(("next", self.tensor_keys.true_reward))
+        pred_reward = tensordict.get(("next", self.tensor_keys.reward))
         reward_loss = distance_loss(
-            tensordict.get(("next", self.tensor_keys.true_reward)),
-            tensordict.get(("next", self.tensor_keys.reward)),
+            true_reward,
+            pred_reward,
             self.reward_loss,
         )
         if not self.global_average:
             reward_loss = reward_loss.squeeze(-1)
         reward_loss = reward_loss.mean().unsqueeze(-1)
-        # import ipdb; ipdb.set_trace()

         td_out = TensorDict(
             loss_model_kl=self.lambda_kl * kl_loss,
             loss_model_reco=self.lambda_reco * reco_loss,
             loss_model_reward=self.lambda_reward * reward_loss,
         )
         self._clear_weakrefs(tensordict, td_out)
-        return (
-            td_out,
-            tensordict.detach(),
-        )
+
+        return (td_out, tensordict.data)
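For reference, the three entries of `td_out` add up to the familiar Dreamer world-model objective, written here in the notation of this function, with $d$ the configured distance loss and hats denoting model predictions:

$$
\mathcal{L}_{\text{model}} = \lambda_{\text{kl}}\,\mathrm{KL}\big(\mathcal{N}(\mu_{\text{post}}, \sigma_{\text{post}})\,\|\,\mathcal{N}(\mu_{\text{prior}}, \sigma_{\text{prior}})\big) + \lambda_{\text{reco}}\, d(o, \hat{o}) + \lambda_{\text{reward}}\, d(r, \hat{r})
$$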
@@ -275,10 +283,11 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             value=self._tensor_keys.value,
         )

+    @_maybe_record_function_decorator("actor_loss/forward")
     def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
-        tensordict = tensordict.select("state", self.tensor_keys.belief).detach()
+        tensordict = tensordict.select("state", self.tensor_keys.belief).data

-        with timeit("actor_loss/time-rollout"), hold_out_net(
+        with _maybe_timeit("actor_loss/time-rollout"), hold_out_net(
             self.model_based_env
         ), set_exploration_type(ExplorationType.RANDOM):
             tensordict = self.model_based_env.reset(tensordict.copy())
@@ -288,7 +297,6 @@ def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
                 auto_reset=False,
                 tensordict=tensordict,
             )
-
             next_tensordict = step_mdp(fake_data, keep_other=True)
             with hold_out_net(self.value_model):
                 next_tensordict = self.value_model(next_tensordict)
@@ -308,7 +316,8 @@ def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
         actor_loss = -lambda_target.sum((-2, -1)).mean()
         loss_tensordict = TensorDict({"loss_actor": actor_loss}, [])
         self._clear_weakrefs(tensordict, loss_tensordict)
-        return loss_tensordict, fake_data.detach()
+
+        return loss_tensordict, fake_data.data

     def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
         done = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
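`lambda_target` is the TD($\lambda$) return Dreamer uses to trade bias against variance when scoring imagined trajectories. Its body is not part of this diff; for reference, the quantity it computes follows the standard recursion (textbook definition, not copied from the source), with rewards $r_t$, value estimates $v$, discount $\gamma$ and mixing parameter $\lambda$:

$$
V_\lambda(s_t) = r_t + \gamma\big[(1 - \lambda)\,v(s_{t+1}) + \lambda\,V_\lambda(s_{t+1})\big]
$$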
@@ -420,14 +429,15 @@ def __init__(
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         pass

+    @_maybe_record_function_decorator("value_loss/forward")
     def forward(self, fake_data) -> torch.Tensor:
         lambda_target = fake_data.get("lambda_target")
+
         tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False)
         self.value_model(tensordict_select)
+
         if self.discount_loss:
-            discount = self.gamma * torch.ones_like(
-                lambda_target, device=lambda_target.device
-            )
+            discount = self.gamma * torch.ones_like(lambda_target, device=lambda_target.device)
             discount[..., 0, :] = 1
             discount = discount.cumprod(dim=-2)
             value_loss = (
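When `discount_loss` is enabled, the block above builds per-step weights $1, \gamma, \gamma^2, \dots$ along the time axis: every entry starts at $\gamma$, the first step is reset to 1, and `cumprod` accumulates along `dim=-2`. A tiny standalone illustration, with hypothetical shapes and $\gamma = 0.99$:

```python
import torch

gamma = 0.99
T = 5  # hypothetical imagination horizon; lambda_target is [..., T, 1]
discount = gamma * torch.ones(T, 1)  # every step starts at gamma
discount[..., 0, :] = 1              # first step is undiscounted
discount = discount.cumprod(dim=-2)  # [1, g, g^2, g^3, g^4] along time
print(discount.squeeze())
# tensor([1.0000, 0.9900, 0.9801, 0.9703, 0.9606])
```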
@@ -452,6 +462,8 @@ def forward(self, fake_data) -> torch.Tensor:
                 .sum((-1, -2))
                 .mean()
             )
+
         loss_tensordict = TensorDict({"loss_value": value_loss})
         self._clear_weakrefs(fake_data, loss_tensordict)
+
         return loss_tensordict, fake_data
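The returns in this commit switch from `.detach()` to `.data`. For these call sites both hand back gradient-free leaves; a minimal sanity check of that equivalence on a toy `TensorDict` (illustrative only, not from the source):

```python
# Illustrative only: .data, like .detach(), yields leaves with
# requires_grad=False, which is all these return statements rely on.
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3, requires_grad=True)}, batch_size=[])
assert td["a"].requires_grad
assert not td.data["a"].requires_grad
assert not td.detach()["a"].requires_grad
```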