-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathexample_agent.py
More file actions
177 lines (150 loc) · 4.9 KB
/
example_agent.py
File metadata and controls
177 lines (150 loc) · 4.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
#
# Repository:
# https://github.com/biological-alignment-benchmarks/biological-alignment-gridworlds-benchmarks
import csv
import logging
from typing import List, Optional, Tuple
from collections import defaultdict
from gymnasium.spaces import Discrete
from omegaconf import DictConfig
import numpy as np
import numpy.typing as npt
from aintelope.environments.savanna_safetygrid import ACTION_RELATIVE_COORDINATE_MAP
from aintelope.agents.abstract_agent import Agent
from aintelope.aintelope_typing import ObservationFloat, PettingZooEnv
from aintelope.training.dqn_training import Trainer
from typing import Union
import gymnasium as gym
from pettingzoo import AECEnv, ParallelEnv
logger = logging.getLogger(name="aintelope.agents.example_agent")

# Type aliases used in this module's annotations.
# NOTE(review): this rebinds the ``PettingZooEnv`` name imported from
# ``aintelope.aintelope_typing`` above; confirm the two definitions agree.
PettingZooEnv = Union[AECEnv, ParallelEnv]
# Any environment the agent may be handed: a Gymnasium env or either
# PettingZoo API flavour (nested Unions flatten, so this equals
# Union[gym.Env, PettingZooEnv]).
Environment = Union[gym.Env, AECEnv, ParallelEnv]
class ExampleAgent(Agent):
    """Example agent that acts uniformly at random.

    Serves as a minimal template for writing agents: actions are sampled from
    the environment's action space for this agent, ``update`` records the
    resulting transition without learning from it, and model initialisation
    and saving are forwarded to the trainer.
    """

    def __init__(
        self,
        agent_id: str,
        trainer: Trainer,
        env: Optional[Environment] = None,
        cfg: Optional[DictConfig] = None,
        **kwargs,
    ) -> None:
        """Store the agent's identity, trainer, environment and config.

        Args:
            agent_id: Identifier of this agent; used to look up its action
                space and as the first element of events from ``update``.
            trainer: Trainer that owns this agent's model.
            env: Environment whose ``action_spaces`` mapping is used by
                ``get_action``; must be set before ``get_action`` is called.
            cfg: Configuration object; stored but not read by this class.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        self.id = agent_id
        self.trainer = trainer
        self.env = env
        self.cfg = cfg
        self.done = False
        self.last_action = None
        # Initialised here (not only in reset()) so that reading these before
        # the first reset() yields None instead of raising AttributeError.
        self.state = None
        self.info = None
        self.env_class = None

    def reset(self, state, info, env_class) -> None:
        """Clear per-episode flags and store the initial state.

        Args:
            state: Initial observation for this agent.
            info: Initial environment info.
            env_class: Environment class in use; stored for later reference.
        """
        self.done = False
        self.last_action = None
        self.state = state
        self.info = info
        self.env_class = env_class

    def get_action(
        self,
        observation: Tuple[
            npt.NDArray[ObservationFloat], npt.NDArray[ObservationFloat]
        ] = None,
        info: Optional[dict] = None,
        step: int = 0,
        env_layout_seed: int = 0,
        episode: int = 0,
        pipeline_cycle: int = 0,
        test_mode: bool = False,
        *args,
        **kwargs,
    ) -> Optional[int]:
        """Choose the next action for this agent.

        This example ignores ``observation`` and samples from the agent's
        action space. A learning agent would feed the observation into its
        model here. The observation is passed in explicitly because other
        agents may have changed the environment since this agent last acted.

        Args:
            observation: Current observation for this agent (unused here).
            info: Extra environment info (unused here).
            step, env_layout_seed, episode, pipeline_cycle, test_mode:
                Bookkeeping values supplied by the caller (unused here).

        Returns:
            The sampled action index, or ``None`` if the agent is done.
        """
        if self.done:
            return None

        # NOTE(review): relies on self.env being set and exposing a
        # per-agent ``action_spaces`` mapping.
        action_space = self.env.action_spaces[self.id]
        # For a ``Discrete`` space the valid actions are
        # action_space.start .. action_space.start + action_space.n - 1;
        # a learning agent would map its model output into that range.
        action = action_space.sample()
        logger.debug("Action: %s", action)
        self.last_action = action
        return action

    def update(
        self,
        env: Optional[PettingZooEnv] = None,
        observation: Tuple[
            npt.NDArray[ObservationFloat], npt.NDArray[ObservationFloat]
        ] = None,
        info: Optional[dict] = None,
        score: float = 0.0,
        done: bool = False,
        test_mode: bool = False,
    ) -> list:
        """Record the transition produced by the last action.

        No learning happens in this example. The transition is returned so the
        caller can log it, and the stored state/info advance to the new
        observation.

        Args:
            env: Environment (unused here).
            observation: Observation received after the last action; becomes
                the new ``self.state``.
            info: Environment info for the new observation; stored as
                ``self.info`` (``None`` is stored as an empty dict).
            score: Score for the last step, used directly as the reward.
            done: Whether the agent is done; included in the returned event.
            test_mode: Unused here.

        Returns:
            ``[agent_id, state, action, reward, done, next_state]``, where
            ``state`` is the observation the last action was chosen from and
            ``next_state`` is ``observation``.

        Raises:
            AssertionError: (when assertions are enabled) if no action has
                been taken since construction or the last ``reset``.
        """
        assert (
            self.last_action is not None
        ), "update() called before get_action() produced an action"
        next_state = observation
        # TODO: implement any learning mechanisms here
        reward = score  # This will be sent to the log file
        event = [
            self.id,
            self.state,
            self.last_action,
            reward,
            done,
            next_state,
        ]
        self.state = next_state
        # NOTE(review): self.done is not updated from ``done`` here;
        # presumably the caller sets it — confirm.
        self.info = {} if info is None else info
        return event

    def init_model(
        self,
        observation_shape,
        action_space,
        unit_test_mode: bool,
        checkpoint: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """Register this agent's model with the trainer.

        All arguments are forwarded unchanged to ``self.trainer.add_agent``,
        preceded by this agent's id.
        """
        self.trainer.add_agent(
            self.id,
            observation_shape,
            action_space,
            unit_test_mode,
            checkpoint,
            *args,
            **kwargs,
        )

    def save_model(
        self,
        i_episode,
        path,
        experiment_name,
        use_separate_models_for_each_experiment,
        *args,
        **kwargs,
    ):
        """Ask the trainer to save this agent's model.

        All arguments are forwarded unchanged to ``self.trainer.save_model``,
        preceded by this agent's id.
        """
        self.trainer.save_model(
            self.id,
            i_episode,
            path,
            experiment_name,
            use_separate_models_for_each_experiment,
            *args,
            **kwargs,
        )