-
Notifications
You must be signed in to change notification settings - Fork 109
Expand file tree
/
Copy pathmake_env.py
More file actions
101 lines (88 loc) · 4.67 KB
/
make_env.py
File metadata and controls
101 lines (88 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) ProrokLab.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Union
from vmas import scenarios
from vmas.simulator.environment import Environment, Wrapper
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.utils import DEVICE_TYPING
def make_env(
    scenario: Union[str, BaseScenario],
    num_envs: int,
    device: DEVICE_TYPING = "cpu",
    continuous_actions: bool = True,
    wrapper: Optional[Union[Wrapper, str]] = None,
    max_steps: Optional[int] = None,
    seed: Optional[int] = None,
    dict_spaces: bool = False,
    multidiscrete_actions: bool = False,
    clamp_actions: bool = False,
    grad_enabled: bool = False,
    terminated_truncated: bool = False,
    wrapper_kwargs: Optional[dict] = None,
    **kwargs,
):
    """Create a vmas environment.

    Args:
        scenario (Union[str, BaseScenario]): Scenario to load.
            Can be the name of a file in the `vmas.scenarios` folder (the ``.py`` suffix is optional)
            or an instance of :class:`~vmas.simulator.scenario.BaseScenario`.
        num_envs (int): Number of vectorized simulation environments. VMAS performs vectorized simulations using PyTorch.
            This argument indicates the number of vectorized environments that should be simulated in a batch. It will also
            determine the batch size of the environment.
        device (Union[str, int, torch.device], optional): Device for simulation. All the tensors created by VMAS
            will be placed on this device. Default is ``"cpu"``.
        continuous_actions (bool, optional): Whether to use continuous actions. If ``False``, actions
            will be discrete. The number of actions and their size will depend on the chosen scenario. Default is ``True``.
        wrapper (Union[Wrapper, str], optional): Wrapper to apply to the environment. Either a
            :class:`~vmas.simulator.environment.Wrapper` member or its case-insensitive name, for example
            ``"rllib"``, ``"gym"``, ``"gymnasium"``, ``"gymnasium_vec"``. Default is ``None`` (no wrapper).
        max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon). Each VMAS scenario can
            be terminating or not. If ``max_steps`` is specified,
            the scenario is also terminated whenever this horizon is reached.
        seed (int, optional): Seed for the environment. Defaults to ``None``.
        dict_spaces (bool, optional): Whether to use dictionary spaces with format ``{"agent_name": tensor, ...}``
            for obs, rewards, and info instead of tuples. Defaults to ``False``: obs, rewards, info are tuples
            with length equal to the number of agents.
        multidiscrete_actions (bool, optional): Whether to use multidiscrete action spaces when ``continuous_actions=False``.
            Default is ``False``: the action space will be ``Discrete``, and it will be the cartesian product of the
            discrete action spaces available to an agent.
        clamp_actions (bool, optional): Whether to clamp input actions to their range instead of throwing
            an error when ``continuous_actions==True`` and actions are out of bounds. Default is ``False``.
        grad_enabled (bool, optional): If ``True`` the simulator will not call ``detach()`` on input actions and gradients can
            be taken from the simulator output. Default is ``False``.
        terminated_truncated (bool, optional): Whether to use terminated and truncated flags in the output of the
            step method (instead of a single done). Default is ``False``.
        wrapper_kwargs (dict, optional): Keyword arguments to pass to the wrapper's ``get_env``. Only used when
            ``wrapper`` is given. Default is ``None`` (no extra arguments).
        **kwargs (dict, optional): Keyword arguments to pass to the :class:`~vmas.simulator.scenario.BaseScenario` class.

    Returns:
        The :class:`~vmas.simulator.environment.Environment` when ``wrapper`` is ``None``,
        otherwise the object returned by ``wrapper.get_env(env, **wrapper_kwargs)``.

    Raises:
        KeyError: If ``wrapper`` is a string that does not name a
            :class:`~vmas.simulator.environment.Wrapper` member.

    Examples:
        >>> from vmas import make_env
        >>> env = make_env(
        ...     "waterfall",
        ...     num_envs=3,
        ...     num_agents=2,
        ... )
        >>> print(env.reset())
    """
    # Load the scenario from a script in ``vmas.scenarios`` when it is given by name.
    if isinstance(scenario, str):
        if not scenario.endswith(".py"):
            scenario += ".py"
        scenario = scenarios.load(scenario).Scenario()

    env = Environment(
        scenario,
        num_envs=num_envs,
        device=device,
        continuous_actions=continuous_actions,
        max_steps=max_steps,
        seed=seed,
        dict_spaces=dict_spaces,
        multidiscrete_actions=multidiscrete_actions,
        clamp_actions=clamp_actions,
        grad_enabled=grad_enabled,
        terminated_truncated=terminated_truncated,
        **kwargs,
    )

    if wrapper is None:
        return env

    # Resolve a wrapper name such as "gym" to the matching ``Wrapper`` member.
    if isinstance(wrapper, str):
        try:
            wrapper = Wrapper[wrapper.upper()]
        except KeyError as err:
            # Keep KeyError for callers that catch it, but make the message actionable.
            valid = ", ".join(repr(member.name.lower()) for member in Wrapper)
            raise KeyError(
                f"Unknown wrapper {wrapper!r}. Expected one of: {valid}."
            ) from err

    if wrapper_kwargs is None:
        wrapper_kwargs = {}
    return wrapper.get_env(env, **wrapper_kwargs)