Skip to content

Commit 6bf27db

Browse files
[Feature] Update scenarios (#78)
* Amend * Amend * Amend
1 parent b41c141 commit 6bf27db

File tree

7 files changed

+270
-128
lines changed

7 files changed

+270
-128
lines changed

vmas/scenarios/dispersion.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022-2023.
1+
# Copyright (c) 2022-2024.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44

@@ -15,11 +15,14 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
1515
n_agents = kwargs.get("n_agents", 4)
1616
self.share_reward = kwargs.get("share_reward", False)
1717
self.penalise_by_time = kwargs.get("penalise_by_time", False)
18-
19-
n_food = n_agents
18+
self.food_radius = kwargs.get("food_radius", 0.05)
19+
self.pos_range = kwargs.get("pos_range", 1.0)
20+
n_food = kwargs.get("n_food", n_agents)
2021

2122
# Make world
22-
world = World(batch_dim, device)
23+
world = World(
24+
batch_dim, device, x_semidim=self.pos_range, y_semidim=self.pos_range
25+
)
2326
# Add agents
2427
for i in range(n_agents):
2528
# Constraint: all agents have same action range and multiplier
@@ -32,9 +35,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
3235
# Add landmarks
3336
for i in range(n_food):
3437
food = Landmark(
35-
name=f"food {i}",
38+
name=f"food_{i}",
3639
collide=False,
37-
shape=Sphere(radius=0.02),
40+
shape=Sphere(radius=self.food_radius),
3841
color=Color.GREEN,
3942
)
4043
world.add_landmark(food)
@@ -58,8 +61,8 @@ def reset_world_at(self, env_index: int = None):
5861
device=self.world.device,
5962
dtype=torch.float32,
6063
).uniform_(
61-
-1.0,
62-
1.0,
64+
-self.pos_range,
65+
self.pos_range,
6366
),
6467
batch_index=env_index,
6568
)

vmas/scenarios/dropout.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1-
# Copyright (c) 2022-2023.
1+
# Copyright (c) 2022-2024.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
import math
55
from typing import Dict
66

77
import torch
88
from torch import Tensor
9-
109
from vmas import render_interactively
1110
from vmas.simulator.core import Agent, Landmark, Sphere, World
1211
from vmas.simulator.scenario import BaseScenario
13-
from vmas.simulator.utils import Color
12+
from vmas.simulator.utils import Color, ScenarioUtils
1413

1514
DEFAULT_ENERGY_COEFF = 0.02
1615

@@ -21,19 +20,24 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
2120
self.energy_coeff = kwargs.get(
2221
"energy_coeff", DEFAULT_ENERGY_COEFF
2322
) # Weight of team energy penalty
23+
self.start_same_point = kwargs.get("start_same_point", False)
24+
self.agent_radius = 0.05
25+
self.goal_radius = 0.03
2426

2527
# Make world
2628
world = World(batch_dim, device)
2729
# Add agents
2830
for i in range(n_agents):
2931
# Constraint: all agents have same action range and multiplier
30-
agent = Agent(name=f"agent_{i}", collide=False)
32+
agent = Agent(
33+
name=f"agent_{i}", collide=False, shape=Sphere(radius=self.agent_radius)
34+
)
3135
world.add_agent(agent)
3236
# Add landmarks
3337
goal = Landmark(
3438
name="goal",
3539
collide=False,
36-
shape=Sphere(radius=0.03),
40+
shape=Sphere(radius=self.goal_radius),
3741
color=Color.GREEN,
3842
)
3943
world.add_landmark(goal)
@@ -45,36 +49,42 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
4549
return world
4650

4751
def reset_world_at(self, env_index: int = None):
48-
for agent in self.world.agents:
49-
# Random pos between -1 and 1
50-
agent.set_pos(
51-
torch.zeros(
52-
(1, self.world.dim_p)
53-
if env_index is not None
54-
else (self.world.batch_dim, self.world.dim_p),
52+
if self.start_same_point:
53+
for agent in self.world.agents:
54+
agent.set_pos(
55+
torch.zeros(
56+
(1, 2) if env_index is not None else (self.world.batch_dim, 2),
57+
device=self.world.device,
58+
dtype=torch.float,
59+
),
60+
batch_index=env_index,
61+
)
62+
ScenarioUtils.spawn_entities_randomly(
63+
self.world.landmarks,
64+
self.world,
65+
env_index,
66+
self.goal_radius + self.agent_radius + 0.01,
67+
x_bounds=(-1, 1),
68+
y_bounds=(-1, 1),
69+
occupied_positions=torch.zeros(
70+
1 if env_index is not None else self.world.batch_dim,
71+
1,
72+
2,
5573
device=self.world.device,
56-
dtype=torch.float32,
57-
).uniform_(
58-
-1.0,
59-
1.0,
74+
dtype=torch.float,
6075
),
61-
batch_index=env_index,
6276
)
63-
for landmark in self.world.landmarks:
64-
# Random pos between -1 and 1
65-
landmark.set_pos(
66-
torch.zeros(
67-
(1, self.world.dim_p)
68-
if env_index is not None
69-
else (self.world.batch_dim, self.world.dim_p),
70-
device=self.world.device,
71-
dtype=torch.float32,
72-
).uniform_(
73-
-1.0,
74-
1.0,
75-
),
76-
batch_index=env_index,
77+
else:
78+
ScenarioUtils.spawn_entities_randomly(
79+
self.world.policy_agents + self.world.landmarks,
80+
self.world,
81+
env_index,
82+
self.goal_radius + self.agent_radius + 0.01,
83+
x_bounds=(-1, 1),
84+
y_bounds=(-1, 1),
7785
)
86+
87+
for landmark in self.world.landmarks:
7888
if env_index is None:
7989
landmark.eaten = torch.full(
8090
(self.world.batch_dim,), False, device=self.world.device

vmas/scenarios/give_way.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
# Copyright (c) 2022-2023.
1+
# Copyright (c) 2022-2024.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
import math
55

66
import torch
7-
87
from vmas import render_interactively
98
from vmas.simulator.core import Agent, World, Landmark, Sphere, Line, Box
109
from vmas.simulator.scenario import BaseScenario
@@ -21,6 +20,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
2120
self.linear_friction = kwargs.get("linear_friction", 0.1)
2221
self.mirror_passage = kwargs.get("mirror_passage", False)
2322
self.done_on_completion = kwargs.get("done_on_completion", False)
23+
self.observe_rel_pos = kwargs.get("observe_rel_pos", False)
2424

2525
# Reward params
2626
self.pos_shaping_factor = kwargs.get("pos_shaping_factor", 1.0)
@@ -63,7 +63,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
6363

6464
# Add agents
6565
blue_agent = Agent(
66-
name="blue_agent_0",
66+
name="agent_0",
6767
rotatable=False,
6868
linear_friction=self.linear_friction,
6969
shape=Sphere(radius=self.agent_radius)
@@ -79,7 +79,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
7979
blue_agent, world, controller_params, "standard"
8080
)
8181
blue_goal = Landmark(
82-
name="blue goal",
82+
name="goal_0",
8383
collide=False,
8484
shape=Sphere(radius=self.agent_radius / 2),
8585
color=Color.BLUE,
@@ -89,7 +89,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
8989
world.add_landmark(blue_goal)
9090

9191
green_agent = Agent(
92-
name="green_agent_0",
92+
name="agent_1",
9393
color=Color.GREEN,
9494
linear_friction=self.linear_friction,
9595
shape=Sphere(radius=self.agent_radius)
@@ -106,7 +106,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
106106
green_agent, world, controller_params, "standard"
107107
)
108108
green_goal = Landmark(
109-
name="green goal",
109+
name="goal_1",
110110
collide=False,
111111
shape=Sphere(radius=self.agent_radius / 2),
112112
color=Color.GREEN,
@@ -302,11 +302,17 @@ def reward(self, agent: Agent):
302302
)
303303

304304
def observation(self, agent: Agent):
305+
rel = []
306+
for a in self.world.agents:
307+
if a != agent:
308+
rel.append(agent.state.pos - a.state.pos)
309+
305310
observations = [
306311
agent.state.pos,
307312
agent.state.vel,
308-
agent.state.pos,
309313
]
314+
if self.observe_rel_pos:
315+
observations += rel
310316

311317
if self.obs_noise > 0:
312318
for i, obs in enumerate(observations):

0 commit comments

Comments
 (0)