Skip to content

Commit 6bca2f0

Browse files
authored
[Fix] road_traffic map_type="2" crashes on GPU reset (#175)
* add tests failing for road_traffic map type 2 * fix failing tests for road_traffic map type 2
1 parent 7f936ab commit 6bca2f0

2 files changed

Lines changed: 73 additions & 2 deletions

File tree

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) ProrokLab.
2+
#
3+
# This source code is licensed under the license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import pytest
7+
import torch
8+
9+
from vmas import make_env
10+
11+
12+
class TestRoadTraffic:
13+
def setup_env(self, n_envs, device="cpu", **kwargs) -> None:
14+
self.env = make_env(
15+
scenario="road_traffic",
16+
num_envs=n_envs,
17+
device=device,
18+
continuous_actions=True,
19+
**kwargs,
20+
)
21+
self.env.seed(0)
22+
23+
def _seed_buffer(self, device):
24+
"""Seed initial_state_buffer with a real state and force it to always be used."""
25+
scenario = self.env.scenario
26+
buf = scenario.initial_state_buffer
27+
buf.add(scenario.state_buffer.get_latest(n=1)[0])
28+
buf.probability_use_recording = torch.tensor(1.0, device=device)
29+
30+
@pytest.mark.parametrize("map_type", ["1", "2"])
31+
def test_map_type_runs(self, map_type, n_envs=4, n_steps=10):
32+
self.setup_env(n_envs=n_envs, map_type=map_type)
33+
self.env.reset()
34+
for _ in range(n_steps):
35+
actions = [
36+
torch.zeros(n_envs, agent.action.action_size)
37+
for agent in self.env.agents
38+
]
39+
obs, rews, dones, _ = self.env.step(actions)
40+
if dones.any():
41+
for env_index, done in enumerate(dones):
42+
if done:
43+
self.env.reset_at(env_index)
44+
45+
def test_map_type_2_reset_uses_buffer(self, n_envs=4):
46+
self.setup_env(n_envs=n_envs, map_type="2")
47+
self.env.reset()
48+
actions = [
49+
torch.zeros(n_envs, agent.action.action_size)
50+
for agent in self.env.agents
51+
]
52+
self.env.step(actions)
53+
self._seed_buffer(device="cpu")
54+
self.env.reset_at(0)
55+
56+
@pytest.mark.skipif(
57+
not torch.cuda.is_available(),
58+
reason="GPU required to reproduce road_traffic map_type=2 device bugs",
59+
)
60+
def test_gpu_map_type_2_rand_device(self, n_envs=4):
61+
self.setup_env(n_envs=n_envs, device="cuda", map_type="2")
62+
self.env.reset()
63+
actions = [
64+
torch.zeros(n_envs, agent.action.action_size, device="cuda")
65+
for agent in self.env.agents
66+
]
67+
self.env.step(actions)
68+
self._seed_buffer(device="cuda")
69+
self.env.reset_at(0)

vmas/scenarios/road_traffic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,8 @@ def reset_world_at(self, env_index: int = None, agent_index: int = None):
948948
if (
949949
(self.parameters.map_type == "2")
950950
and (
951-
torch.rand(1) < self.initial_state_buffer.probability_use_recording
951+
torch.rand(1, device=self.world.device)
952+
< self.initial_state_buffer.probability_use_recording
952953
)
953954
and (self.initial_state_buffer.valid_size >= 1)
954955
):
@@ -1113,6 +1114,7 @@ def reset_init_state(
11131114
agents[i_agent].set_pos(initial_state[i_agent, 0:2], batch_index=env_i)
11141115
agents[i_agent].set_rot(initial_state[i_agent, 2], batch_index=env_i)
11151116
agents[i_agent].set_vel(initial_state[i_agent, 3:5], batch_index=env_i)
1117+
return ref_path, path_id
11161118
else:
11171119
is_feasible_initial_position_found = False
11181120
# Ramdomly generate initial states for each agent
@@ -2300,7 +2302,7 @@ def done(self):
23002302
is_collision_with_lanelets = self.collisions.with_lanelets.any(dim=-1)
23012303

23022304
if self.parameters.map_type == "2": # Record into the initial state buffer
2303-
if torch.rand(1) > (
2305+
if torch.rand(1, device=self.world.device) > (
23042306
1 - self.initial_state_buffer.probability_record
23052307
): # Only a certain probability to record
23062308
for env_collide in torch.where(is_collision_with_agents)[0]:

0 commit comments

Comments
 (0)