1- # Copyright (c) 2022-2023 .
1+ # Copyright (c) 2022-2024 .
22# ProrokLab (https://www.proroklab.org/)
33# All rights reserved.
44import math
55from typing import Dict
66
77import torch
88from torch import Tensor
9-
109from vmas import render_interactively
1110from vmas .simulator .core import Agent , Landmark , Sphere , World
1211from vmas .simulator .scenario import BaseScenario
13- from vmas .simulator .utils import Color
12+ from vmas .simulator .utils import Color , ScenarioUtils
1413
1514DEFAULT_ENERGY_COEFF = 0.02
1615
@@ -21,19 +20,24 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
2120 self .energy_coeff = kwargs .get (
2221 "energy_coeff" , DEFAULT_ENERGY_COEFF
2322 ) # Weight of team energy penalty
23+ self .start_same_point = kwargs .get ("start_same_point" , False )
24+ self .agent_radius = 0.05
25+ self .goal_radius = 0.03
2426
2527 # Make world
2628 world = World (batch_dim , device )
2729 # Add agents
2830 for i in range (n_agents ):
2931 # Constraint: all agents have same action range and multiplier
30- agent = Agent (name = f"agent_{ i } " , collide = False )
32+ agent = Agent (
33+ name = f"agent_{ i } " , collide = False , shape = Sphere (radius = self .agent_radius )
34+ )
3135 world .add_agent (agent )
3236 # Add landmarks
3337 goal = Landmark (
3438 name = "goal" ,
3539 collide = False ,
36- shape = Sphere (radius = 0.03 ),
40+ shape = Sphere (radius = self . goal_radius ),
3741 color = Color .GREEN ,
3842 )
3943 world .add_landmark (goal )
@@ -45,36 +49,42 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
4549 return world
4650
4751 def reset_world_at (self , env_index : int = None ):
48- for agent in self .world .agents :
49- # Random pos between -1 and 1
50- agent .set_pos (
51- torch .zeros (
52- (1 , self .world .dim_p )
53- if env_index is not None
54- else (self .world .batch_dim , self .world .dim_p ),
52+ if self .start_same_point :
53+ for agent in self .world .agents :
54+ agent .set_pos (
55+ torch .zeros (
56+ (1 , 2 ) if env_index is not None else (self .world .batch_dim , 2 ),
57+ device = self .world .device ,
58+ dtype = torch .float ,
59+ ),
60+ batch_index = env_index ,
61+ )
62+ ScenarioUtils .spawn_entities_randomly (
63+ self .world .landmarks ,
64+ self .world ,
65+ env_index ,
66+ self .goal_radius + self .agent_radius + 0.01 ,
67+ x_bounds = (- 1 , 1 ),
68+ y_bounds = (- 1 , 1 ),
69+ occupied_positions = torch .zeros (
70+ 1 if env_index is not None else self .world .batch_dim ,
71+ 1 ,
72+ 2 ,
5573 device = self .world .device ,
56- dtype = torch .float32 ,
57- ).uniform_ (
58- - 1.0 ,
59- 1.0 ,
74+ dtype = torch .float ,
6075 ),
61- batch_index = env_index ,
6276 )
63- for landmark in self .world .landmarks :
64- # Random pos between -1 and 1
65- landmark .set_pos (
66- torch .zeros (
67- (1 , self .world .dim_p )
68- if env_index is not None
69- else (self .world .batch_dim , self .world .dim_p ),
70- device = self .world .device ,
71- dtype = torch .float32 ,
72- ).uniform_ (
73- - 1.0 ,
74- 1.0 ,
75- ),
76- batch_index = env_index ,
77+ else :
78+ ScenarioUtils .spawn_entities_randomly (
79+ self .world .policy_agents + self .world .landmarks ,
80+ self .world ,
81+ env_index ,
82+ self .goal_radius + self .agent_radius + 0.01 ,
83+ x_bounds = (- 1 , 1 ),
84+ y_bounds = (- 1 , 1 ),
7785 )
86+
87+ for landmark in self .world .landmarks :
7888 if env_index is None :
7989 landmark .eaten = torch .full (
8090 (self .world .batch_dim ,), False , device = self .world .device
0 commit comments