Overview
This RFC proposes adding torchrl.modules.WorldModel, a composable, TensorDict-native abstraction for learned environment dynamics, to serve as a general foundation for model-based RL workflows in TorchRL.
Motivation
TorchRL has strong abstractions for environments (EnvBase), policies (TensorDictModule), data collection (Collector), replay buffers, and loss modules. Model-based RL, however, still requires users to write substantial glue code.
What exists today
TorchRL ships Dreamer-specific components:
- RSSMPrior, RSSMPosterior, RSSMRollout: RSSM-specific dynamics (also V3 discrete variants)
- ObsEncoder, ObsDecoder: Conv stacks for pixel observations
- DreamerActor: distribution head for the imagination policy
- WorldModelWrapper: a thin TensorDictSequential(transition_model, reward_model) with no rollout interface
- ModelBasedEnvBase / DreamerEnv: environment wrapper that calls world_model(td) in _step()
- DreamerModelLoss, DreamerActorLoss, DreamerValueLoss: losses tightly coupled to the RSSM + DreamerEnv stack
The gap
There is no general abstraction for learned dynamics. Users implementing any world-model-based algorithm (MBPO, TD-MPC, DreamerV3, PlaNet, Iris, CWVAE, or a custom learned env) must:
- Manually wrap each component with TensorDictModule and wire in_keys/out_keys by hand
- Write a custom multi-step rollout loop (there is no standard rollout(policy, horizon) interface)
- Accept that the losses (e.g. DreamerActorLoss) are tightly coupled to DreamerEnv rather than to a generic dynamics interface
- Write ~150 lines of boilerplate (make_dreamer() in sota-implementations/dreamer/dreamer_utils.py) before the training loop starts
The result: world-model workflows are effectively siloed from the rest of TorchRL. Imagined rollouts cannot be dropped into standard collectors, sequence replay buffers, or existing loss modules without significant adaptation.
Proposed API
WorldModel
```python
from torchrl.modules import WorldModel

world_model = WorldModel(
    encoder=encoder,          # TensorDictModule: obs → latent
    dynamics=dynamics,        # TensorDictModule: (latent, action) → next_latent
    reward_head=reward_head,  # TensorDictModule: next_latent → reward
    done_head=done_head,      # TensorDictModule: next_latent → done (optional)
    decoder=decoder,          # TensorDictModule: latent → obs_recon (optional)
    observation_key="observation",
    action_key="action",
    latent_key=("latent", "state"),
    next_latent_key=("next", "latent", "state"),
)
```
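Because the constructor receives every component together with the key names, it is in a position to check the wiring once instead of leaving users to match in_keys/out_keys by hand. The RFC does not say whether WorldModel performs such a check; the sketch below is one hypothetical form of it. `Component` and `check_wiring` are illustrative names, not TorchRL APIs, and `Component` only mimics the `in_keys`/`out_keys` attributes that `TensorDictModule` exposes, so the example runs without TorchRL installed.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


@dataclass
class Component:
    """Hypothetical stand-in exposing in_keys/out_keys like TensorDictModule."""

    in_keys: List[NestedKey]
    out_keys: List[NestedKey]


def check_wiring(
    encoder: Component,
    dynamics: Component,
    reward_head: Component,
    done_head: Optional[Component] = None,
    decoder: Optional[Component] = None,
    observation_key: NestedKey = "observation",
    action_key: NestedKey = "action",
    latent_key: NestedKey = ("latent", "state"),
    next_latent_key: NestedKey = ("next", "latent", "state"),
) -> None:
    """Raise ValueError if a component does not read or write its expected keys."""
    # (name, component, keys it must read, key it must write or None)
    expected = [
        ("encoder", encoder, [observation_key], latent_key),
        ("dynamics", dynamics, [latent_key, action_key], next_latent_key),
        ("reward_head", reward_head, [next_latent_key], None),
    ]
    if done_head is not None:
        expected.append(("done_head", done_head, [next_latent_key], None))
    if decoder is not None:
        expected.append(("decoder", decoder, [latent_key], None))
    for name, comp, must_read, must_write in expected:
        missing = [key for key in must_read if key not in comp.in_keys]
        if missing:
            raise ValueError(f"{name} does not read {missing}")
        if must_write is not None and must_write not in comp.out_keys:
            raise ValueError(f"{name} does not write {must_write!r}")
```

Checking at construction time surfaces a mis-wired key before training starts, rather than as a KeyError deep inside the first imagined rollout.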
Key methods:
```python
# Encode a real observation into the latent space
latent_td = world_model.encode(tensordict)

# Take one imagined step given the current latent and action
next_td = world_model.step(tensordict)

# Decode a latent back to observation space (requires a decoder)
recon_td = world_model.decode(tensordict)

# Run an imagined rollout for `horizon` steps with a given policy.
# Returns a TensorDict of shape [batch, horizon], the same layout as a real rollout
imagined_td = world_model.rollout(
    start_td=start_td,
    policy=actor,
    horizon=15,
)
```
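To make the rollout contract concrete, the sketch below shows one way `rollout` could chain the components: the policy reads the current latent and writes an action, the dynamics writes the next latent, the reward head writes the reward, and the next latent is then shifted into the current-latent slot (what `step_mdp` does for real environments). This is an assumption about the internals, not the proposed implementation. Plain dicts with tuple keys stand in for TensorDicts (mirroring TensorDict's tuple-based nested keys) and lists of floats stand in for batched tensors; `td_module` and `imagine` are hypothetical helpers. Every output entry is indexed `[batch][time]`, mirroring the `[batch, horizon]` layout.

```python
from typing import Callable, Dict, List, Sequence, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]
Batch = Dict[NestedKey, List[float]]  # one float per batch element


def td_module(
    fn: Callable[..., float], in_keys: Sequence[NestedKey], out_key: NestedKey
) -> Callable[[Batch], Batch]:
    """Hypothetical TensorDictModule stand-in: read in_keys, write out_key."""

    def module(td: Batch) -> Batch:
        # Apply fn element-wise over the batch dimension.
        td[out_key] = [fn(*args) for args in zip(*(td[k] for k in in_keys))]
        return td

    return module


def imagine(
    start_td: Batch,
    policy: Callable[[Batch], Batch],
    dynamics: Callable[[Batch], Batch],
    reward_head: Callable[[Batch], Batch],
    horizon: int,
    latent_key: NestedKey = ("latent", "state"),
    next_latent_key: NestedKey = ("next", "latent", "state"),
) -> Dict[NestedKey, List[List[float]]]:
    """Roll the model forward `horizon` steps; entries are indexed [batch][time]."""
    if horizon < 1:
        raise ValueError("horizon must be >= 1")
    td: Batch = {latent_key: list(start_td[latent_key])}
    steps: List[Batch] = []
    for _ in range(horizon):
        td = policy(td)  # latent -> action
        td = dynamics(td)  # (latent, action) -> next latent
        td = reward_head(td)  # next latent -> ("next", "reward")
        steps.append(dict(td))  # snapshot this step's entries
        # Shift, as step_mdp does for real envs: next latent becomes current.
        td = {latent_key: td[next_latent_key]}
    batch_size = len(start_td[latent_key])
    # Stack per-step snapshots into a [batch][horizon] layout.
    return {
        key: [[step[key][b] for step in steps] for b in range(batch_size)]
        for key in steps[0]
    }


# Toy components: linear policy, additive dynamics, negative squared latent as reward.
policy = td_module(lambda z: -0.5 * z, [("latent", "state")], "action")
dynamics = td_module(
    lambda z, a: z + a, [("latent", "state"), "action"], ("next", "latent", "state")
)
reward_head = td_module(lambda z: -z * z, [("next", "latent", "state")], ("next", "reward"))
out = imagine({("latent", "state"): [1.0, 2.0]}, policy, dynamics, reward_head, horizon=3)
# out[("latent", "state")] == [[1.0, 0.5, 0.25], [2.0, 1.0, 0.5]]
```

Because each step writes its outputs under the same keys a real environment step would, the stacked result can be handed to code that already consumes real rollouts, which is the property the compatibility table relies on.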
WorldModelLoss
```python
from torchrl.objectives import WorldModelLoss

loss_module = WorldModelLoss(
    world_model,
    losses=["latent", "reward", "done", "reconstruction"],
    kl_weight=1.0,
    reconstruction_weight=1.0,
)

real_batch = replay_buffer.sample(batch_size=256)
loss_td = loss_module(real_batch)
# loss_td contains: loss_latent, loss_reward, loss_done, loss_reconstruction
```
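The RFC leaves the exact loss terms open. One plausible reading, sketched below, is that `losses` selects which terms are computed and that `kl_weight` and `reconstruction_weight` scale the latent and reconstruction terms. The sketch assumes a Dreamer-style latent term, KL(posterior || prior) between diagonal Gaussians, plus mean squared error for reward and reconstruction and binary cross-entropy for done; a deterministic model such as TD-MPC would swap the KL for a latent-consistency MSE. `world_model_loss_terms` and the batch key names are hypothetical, and plain floats stand in for tensors.

```python
import math
from typing import Dict, Iterable, List


def gaussian_kl(mu_q: float, std_q: float, mu_p: float, std_p: float) -> float:
    """KL(N(mu_q, std_q^2) || N(mu_p, std_p^2)) for a single dimension."""
    return (
        math.log(std_p / std_q)
        + (std_q**2 + (mu_q - mu_p) ** 2) / (2 * std_p**2)
        - 0.5
    )


def mse(pred: List[float], target: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def bce(prob: List[float], target: List[float], eps: float = 1e-7) -> float:
    total = 0.0
    for p, t in zip(prob, target):
        p = min(max(p, eps), 1 - eps)  # clamp away from 0 and 1 before log
        total -= t * math.log(p) + (1 - t) * math.log(1 - p)
    return total / len(prob)


def world_model_loss_terms(
    batch: Dict[str, List[float]],
    losses: Iterable[str] = ("latent", "reward", "done", "reconstruction"),
    kl_weight: float = 1.0,
    reconstruction_weight: float = 1.0,
) -> Dict[str, float]:
    """Return one "loss_<name>" entry per requested term (hypothetical layout)."""
    requested = set(losses)
    out: Dict[str, float] = {}
    if "latent" in requested:
        kls = [
            gaussian_kl(mq, sq, mp, sp)
            for mq, sq, mp, sp in zip(
                batch["post_mean"], batch["post_std"],
                batch["prior_mean"], batch["prior_std"],
            )
        ]
        out["loss_latent"] = kl_weight * sum(kls) / len(kls)
    if "reward" in requested:
        out["loss_reward"] = mse(batch["reward_pred"], batch["reward"])
    if "done" in requested:
        out["loss_done"] = bce(batch["done_prob"], batch["done"])
    if "reconstruction" in requested:
        out["loss_reconstruction"] = reconstruction_weight * mse(
            batch["obs_recon"], batch["observation"]
        )
    return out
```

Returning the terms separately rather than as one scalar matches the `loss_td` layout above and lets callers log each component before summing them.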
End-to-end training sketch
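```python
# World model update (on real data)
# Sequence sampling (e.g. SliceSampler) returns a flat batch; reshape it to
# [num_slices, slice_len], where slice_len is the sampler's slice length
real_batch = replay_buffer.sample(batch_size=256).reshape(-1, slice_len)
loss_td = loss_module(real_batch)
# Sum every loss_* entry so optional terms (done, reconstruction) are included
total_model_loss = sum(v for k, v in loss_td.items() if k.startswith("loss_"))
model_opt.zero_grad()
total_model_loss.backward()
model_opt.step()

# Actor/critic update (on imagined rollouts), starting from the first step of each slice
start_td = real_batch[:, 0]
imagined_td = world_model.rollout(start_td, policy=actor, horizon=15)
actor_loss_td = actor_loss(imagined_td)
actor_opt.zero_grad()
actor_loss_td["loss_objective"].backward()
actor_opt.step()
```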
Compatibility with Existing TorchRL Infrastructure
| Component | Compatibility |
|---|---|
| TensorDictReplayBuffer | Imagined rollouts stored directly (same TensorDict layout) |
| SliceSampler | Imagined trajectories are sequences; sample subsequences for recurrent training |
| GAE / MultiAgentGAE | Value targets computed from imagined rollouts |
| ClipPPOLoss, SACLoss | Consume imagined TensorDicts without modification |
| EnvBase specs | WorldModel optionally validates against observation_spec, action_spec |
| Collector / SyncDataCollector | Imagined rollout interface mirrors real rollout interface |
| DreamerEnv (existing) | Remains supported; WorldModel.rollout is an alternative that avoids the EnvBase overhead |
Migration Path for Dreamer Users
Existing DreamerModelLoss, DreamerActorLoss, DreamerValueLoss are unchanged. Users may optionally migrate:
```python
# Before (current Dreamer setup, ~150 LOC in make_dreamer())
world_model = WorldModelWrapper(
    TensorDictSequential(rssm_prior, rssm_posterior, ...),
    reward_model,
)
dreamer_env = DreamerEnv(world_model, ...)
# Imagine from real start states instead of a fresh reset
imagined_td = dreamer_env.rollout(
    max_steps=15, policy=actor, auto_reset=False, tensordict=start_td
)

# After (with WorldModel abstraction)
world_model = WorldModel(
    encoder=obs_encoder,
    dynamics=TensorDictSequential(rssm_prior, rssm_posterior),
    reward_head=reward_mlp,
    done_head=None,
)
imagined_td = world_model.rollout(start_td, actor, horizon=15)
```
Scope of Implementation
New files:
- torchrl/modules/world_model.py: WorldModel
- torchrl/objectives/world_model_loss.py: WorldModelLoss
- test/test_world_model.py: unit tests
- tutorials/sphinx-tutorials/world_model.py: tutorial

Modified files:
- torchrl/modules/__init__.py: export WorldModel
- torchrl/objectives/__init__.py: export WorldModelLoss
- docs/source/reference/modules_models.rst: add entry
- docs/source/reference/objectives_other.rst: add entry