Skip to content

Commit 11e22ee

Browse files
committed
[BugFix] Fix TDLambdaEstimator for torch.compile compatibility
Make the `vectorized` property return `False` while the code is being compiled with `torch.compile`, so that the vectorized code paths that may cause issues are avoided. Also ensure that the device of `lmbda` matches the device of the input reward tensor.
ghstack-source-id: bcbffdb
Pull-Request: #3302
1 parent d57fdec commit 11e22ee

1 file changed

Lines changed: 12 additions & 0 deletions

File tree

torchrl/objectives/value/advantages.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,16 @@ def __init__(
10811081
self.vectorized = vectorized
10821082
self.time_dim = time_dim
10831083

1084+
@property
1085+
def vectorized(self):
1086+
if is_dynamo_compiling():
1087+
return False
1088+
return self._vectorized
1089+
1090+
@vectorized.setter
1091+
def vectorized(self, value):
1092+
self._vectorized = value
1093+
10841094
@_self_set_skip_existing
10851095
@_self_set_grad_enabled
10861096
@dispatch
@@ -1206,6 +1216,8 @@ def value_estimate(
12061216
if steps_to_next_obs is not None:
12071217
gamma = gamma ** steps_to_next_obs.view_as(reward)
12081218

1219+
if self.lmbda.device != device:
1220+
self.lmbda = self.lmbda.to(device)
12091221
lmbda = self.lmbda
12101222
if self.average_rewards:
12111223
reward = reward - reward.mean()

0 commit comments

Comments
 (0)