|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import copy |
4 | | - |
5 | 3 | import numpy as np |
| 4 | +import torch |
6 | 5 | from worldmodel_models.registry import create_world_model |
7 | 6 | from worldmodel_planners.mpc_cem import MPCCEMPlanner |
8 | 7 |
|
9 | 8 | from worldmodel_agents.base import AgentConfig, BaseAgent |
10 | 9 |
|
11 | 10 |
|
| 11 | +def _clone_state(value): |
| 12 | + if isinstance(value, torch.Tensor): |
| 13 | + return value.detach().clone() |
| 14 | + if isinstance(value, dict): |
| 15 | + return {key: _clone_state(item) for key, item in value.items()} |
| 16 | + if isinstance(value, list): |
| 17 | + return [_clone_state(item) for item in value] |
| 18 | + if isinstance(value, tuple): |
| 19 | + return tuple(_clone_state(item) for item in value) |
| 20 | + return value |
| 21 | + |
| 22 | + |
12 | 23 | class ImaginationMPCAgent(BaseAgent): |
13 | 24 | def __init__(self, config: AgentConfig | None = None): |
14 | 25 | super().__init__(config=config) |
@@ -40,7 +51,7 @@ def rollout_fn(state, action_seq): |
40 | 51 | result = self.planner.plan( |
41 | 52 | root_state=self.latent, |
42 | 53 | rollout_fn=rollout_fn, |
43 | | - clone_state_fn=copy.deepcopy, |
| 54 | + clone_state_fn=_clone_state, |
44 | 55 | ) |
45 | 56 | self.last_imagined_transitions = result.imagined_transitions |
46 | 57 | self.last_planner_trace = result.trace |
|
0 commit comments