Skip to content

Commit 8d89292

Browse files
committed
Update
[ghstack-poisoned]
2 parents 09a5ddf + cdffc9b commit 8d89292

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

  • torchrl/modules/tensordict_module

torchrl/modules/tensordict_module/rnn.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2461,10 +2461,17 @@ def __init__(self):
24612461
)
24622462

24632463
def get_mode(self) -> bool | None:
2464+
# Dynamo can't trace ContextVar.get; fall back to the parent's plain
2465+
# attribute under torch.compile. set_mode keeps both in sync so this
2466+
# stays correct (compile traces a single thread).
2467+
if is_compiling():
2468+
return self._mode
24642469
return self._context_mode.get()
24652470

24662471
def set_mode(self, mode: bool | None) -> None:
2467-
self._context_mode.set(mode)
2472+
self._mode = mode
2473+
if not is_compiling():
2474+
self._context_mode.set(mode)
24682475

24692476

24702477
recurrent_mode_state_manager = _RecurrentModeContextManager()

0 commit comments

Comments
 (0)