Skip to content

Commit ef49663

Browse files
Hide demo rows and fix world-model agents
1 parent 9dac5cd commit ef49663

File tree

8 files changed

+76
-8
lines changed

8 files changed

+76
-8
lines changed

agents/worldmodel_agents/imagination_mpc.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
from __future__ import annotations
22

3-
import copy
4-
53
import numpy as np
4+
import torch
65
from worldmodel_models.registry import create_world_model
76
from worldmodel_planners.mpc_cem import MPCCEMPlanner
87

98
from worldmodel_agents.base import AgentConfig, BaseAgent
109

1110

11+
def _clone_state(value):
12+
if isinstance(value, torch.Tensor):
13+
return value.detach().clone()
14+
if isinstance(value, dict):
15+
return {key: _clone_state(item) for key, item in value.items()}
16+
if isinstance(value, list):
17+
return [_clone_state(item) for item in value]
18+
if isinstance(value, tuple):
19+
return tuple(_clone_state(item) for item in value)
20+
return value
21+
22+
1223
class ImaginationMPCAgent(BaseAgent):
1324
def __init__(self, config: AgentConfig | None = None):
1425
super().__init__(config=config)
@@ -40,7 +51,7 @@ def rollout_fn(state, action_seq):
4051
result = self.planner.plan(
4152
root_state=self.latent,
4253
rollout_fn=rollout_fn,
43-
clone_state_fn=copy.deepcopy,
54+
clone_state_fn=_clone_state,
4455
)
4556
self.last_imagined_transitions = result.imagined_transitions
4657
self.last_planner_trace = result.trace

agents/worldmodel_agents/search_mcts.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
from __future__ import annotations
22

3-
import copy
4-
53
import numpy as np
4+
import torch
65
from worldmodel_models.registry import create_world_model
76
from worldmodel_planners.mcts import MCTSPlanner
87

98
from worldmodel_agents.base import AgentConfig, BaseAgent
109

1110

11+
def _clone_state(value):
12+
if isinstance(value, torch.Tensor):
13+
return value.detach().clone()
14+
if isinstance(value, dict):
15+
return {key: _clone_state(item) for key, item in value.items()}
16+
if isinstance(value, list):
17+
return [_clone_state(item) for item in value]
18+
if isinstance(value, tuple):
19+
return tuple(_clone_state(item) for item in value)
20+
return value
21+
22+
1223
class SearchMCTSAgent(BaseAgent):
1324
"""Minimal MuZero-style skeleton: learned model + MCTS planning."""
1425

@@ -42,7 +53,7 @@ def transition_fn(state, action):
4253
result = self.planner.plan(
4354
root_state=self.latent,
4455
transition_fn=transition_fn,
45-
clone_state_fn=copy.deepcopy,
56+
clone_state_fn=_clone_state,
4657
)
4758

4859
self.last_imagined_transitions = result.imagined_transitions

server/worldmodel_server/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,16 @@ def leaderboard(
263263
track: str = Query(default="test"),
264264
env: str | None = Query(default=None),
265265
agent: str | None = Query(default=None),
266+
include_demo: bool = Query(default=False),
266267
session: Session = Depends(get_session),
267268
):
268269
q = select(RunEntry).where(RunEntry.status == "uploaded", RunEntry.track == track)
269270
if env:
270271
q = q.where(RunEntry.env == env)
271272
if agent:
272273
q = q.where(RunEntry.agent == agent)
274+
if not include_demo:
275+
q = q.where(RunEntry.created_by != "demo-seed")
273276

274277
rows = session.scalars(q.order_by(desc(RunEntry.created_at))).all()
275278
out: list[LeaderboardRow] = []

tests/test_agent_registry.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
from worldmodel_agents.registry import create_agent
4+
from worldmodel_gym.envs.registry import make_env
5+
6+
7+
def test_search_mcts_can_act_on_memory_maze_observation():
    """The search_mcts agent must produce an int action from a fresh memory_maze obs."""
    environment = make_env("memory_maze", obs_mode="both", max_steps=8)
    observation, info = environment.reset(seed=123)

    mcts_agent = create_agent("search_mcts")
    mcts_agent.reset(seed=123)

    chosen = mcts_agent.act(observation, info)
    assert isinstance(chosen, int)
16+
17+
18+
def test_imagination_mpc_can_act_on_switch_quest_observation():
    """The imagination_mpc agent must produce an int action from a fresh switch_quest obs."""
    environment = make_env("switch_quest", obs_mode="both", max_steps=8)
    observation, info = environment.reset(seed=123)

    mpc_agent = create_agent("imagination_mpc")
    mpc_agent.reset(seed=123)

    chosen = mpc_agent.act(observation, info)
    assert isinstance(chosen, int)

tests/test_server_app.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,26 @@ def test_seed_demo_data_populates_leaderboard(tmp_path, monkeypatch):
4444
app = modules["worldmodel_server.main"].app
4545

4646
with TestClient(app) as client:
47-
response = client.get("/api/leaderboard?track=test")
47+
response = client.get("/api/leaderboard?track=test&include_demo=true")
4848

4949
assert response.status_code == 200
5050
rows = response.json()
5151
assert len(rows) >= 2
5252
assert any(row["agent"] == "demo-mpc" for row in rows)
5353

5454

55+
def test_public_leaderboard_hides_seeded_demo_rows_by_default(tmp_path, monkeypatch):
56+
modules = load_test_modules(monkeypatch, tmp_path, seed_demo=True)
57+
app = modules["worldmodel_server.main"].app
58+
59+
with TestClient(app) as client:
60+
response = client.get("/api/leaderboard?track=test")
61+
62+
assert response.status_code == 200
63+
rows = response.json()
64+
assert rows == []
65+
66+
5567
def test_api_key_can_create_and_upload_run(tmp_path, monkeypatch):
5668
modules = load_test_modules(monkeypatch, tmp_path)
5769
app = modules["worldmodel_server.main"].app

worldmodels/worldmodel_models/common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@ def __init__(self, config: ModelConfig):
3333
super().__init__()
3434
self.config = config
3535
self.device = torch.device(config.device)
36+
self.optimizer: torch.optim.Optimizer | None = None
37+
38+
def initialize_optimizer(self) -> None:
3639
self.to(self.device)
37-
self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
40+
self.optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr)
3841

3942
def _obs_tensor(self, obs) -> torch.Tensor:
4043
arr = to_numpy_obs(obs)

worldmodels/worldmodel_models/deterministic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(self, config: ModelConfig | None = None):
2424
self.obs_head = torch.nn.Linear(c.latent_dim, c.obs_dim)
2525
self.reward_head = torch.nn.Linear(c.latent_dim, 1)
2626
self.done_head = torch.nn.Linear(c.latent_dim, 1)
27+
self.initialize_optimizer()
2728

2829
def init_state(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
2930
latent = torch.zeros(

worldmodels/worldmodel_models/stochastic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(self, config: ModelConfig | None = None):
2222
self.reward_head = torch.nn.Linear(c.latent_dim, 1)
2323
self.done_head = torch.nn.Linear(c.latent_dim, 1)
2424
self.obs_head = torch.nn.Linear(c.latent_dim, c.obs_dim)
25+
self.initialize_optimizer()
2526

2627
def init_state(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
2728
h = torch.zeros((batch_size, self.config.latent_dim), device=self.device)

0 commit comments

Comments
 (0)