-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplay_env.py
More file actions
179 lines (146 loc) · 6.2 KB
/
play_env.py
File metadata and controls
179 lines (146 loc) · 6.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# TODO: I'm kinda using this project to pilot the whole config/network/example separation.
# The motivation is that the files are getting large and it's increasing cognitive load :(
# Import Python Standard Libraries
from threading import Thread, Lock
from argparse import ArgumentParser
from collections import namedtuple
from datetime import datetime
# Import Pytorch related packages for NNs
from numpy import array as np_array
from numpy import save as np_save
import torch
from torch.optim import Adam
# Import my custom RL library
import rltorch
from rltorch.memory import PrioritizedReplayMemory, ReplayMemory, iDQfDMemory
from rltorch.action_selector import EpsilonGreedySelector, ArgMaxSelector
import rltorch.env as E
import rltorch.network as rn
# Import OpenAI gym and related packages
from gym import make as makeEnv
from gym import Wrapper as GymWrapper
from gym.wrappers import Monitor as GymMonitor
import play
#
## Networks (Probably want to move this to config file)
#
from networks import Value
#
## Play Related Classes
#
class PlayClass(Thread):
    """Background thread that drives interactive play through ``play.Play``.

    Every constructor argument is forwarded unchanged to ``play.Play``; the
    thread body just calls that object's ``start`` method.
    """

    def __init__(self, env, action_selector, agent, sneaky_env, sneaky_actor, record_lock, config):
        super().__init__()
        # The driver is built here, on the constructing thread, so any error
        # raised by play.Play surfaces before the thread is started.
        self.play = play.Play(
            env,
            action_selector,
            agent,
            sneaky_env,
            sneaky_actor,
            record_lock,
            config,
        )

    def run(self):
        driver = self.play
        driver.start()
class Record(GymWrapper):
    """Gym wrapper that samples transitions into a shared buffer.

    Every ``args['skip']``-th call to :meth:`step` appends a
    ``(state, action, reward, next_state, done)`` tuple to ``memory`` while
    holding ``lock``. :meth:`log_transitions` dumps the buffer to ``.npy``
    files and clears it.

    Args:
        env: Environment to wrap. ``step`` reads the pre-action observation
            via ``env.env._get_obs()``, so the wrapped object must expose an
            inner ``env`` with a ``_get_obs`` method (presumably an Atari env
            under gym's default wrapper -- verify for other environments).
        memory: Mutable sequence (a ``list`` in this script) that receives the
            transition tuples; another thread may read/clear it via
            :meth:`log_transitions`.
        lock: Lock guarding ``memory`` against concurrent access.
        args: Mapping providing ``'skip'`` (record every n-th step; must be
            non-zero), ``'environment_name'`` and ``'logdir'`` (both used to
            build output file names).
    """

    def __init__(self, env, memory, lock, args):
        GymWrapper.__init__(self, env)
        self.memory = memory
        self.lock = lock  # Lock for memory access
        self.skipframes = args['skip']
        self.environment_name = args['environment_name']
        self.logdir = args['logdir']
        # Steps taken since the last recorded transition.
        self.current_i = 0

    def reset(self):
        """Reset the wrapped environment and return what it returns."""
        return self.env.reset()

    def step(self, action):
        """Step the wrapped env, recording every ``skipframes``-th transition.

        Returns the wrapped env's ``(next_state, reward, done, info)`` unchanged.
        """
        # Capture the observation *before* acting so the stored tuple is (s, a, r, s', done).
        state = self.env.env._get_obs()
        next_state, reward, done, info = self.env.step(action)
        self.current_i += 1
        # Don't add to memory until a certain number of frames is reached.
        if self.current_i % self.skipframes == 0:
            # `with` guarantees the lock is released even if append raises;
            # a lock left held would block the logging thread forever.
            with self.lock:
                self.memory.append((state, action, reward, next_state, done))
            self.current_i = 0
        return next_state, reward, done, info

    def log_transitions(self):
        """Write buffered transitions to ``.npy`` files and clear the buffer.

        Does nothing when the buffer is empty. Otherwise writes five files that
        share the prefix ``<logdir>/<environment_name>.<timestamp>`` with the
        suffixes ``-state``, ``-action``, ``-reward``, ``-nextstate`` and
        ``-done``. The buffer is cleared only after all five saves succeed.

        This method does not take ``self.lock`` itself (``threading.Lock`` is
        not reentrant); callers sharing ``memory`` with another thread must
        hold the lock while calling it.
        """
        if not self.memory:
            return
        import os

        # %S (zero-padded seconds) is the portable directive; the previous
        # lowercase %s is a platform-specific extension (epoch seconds on
        # glibc) and is rejected on some platforms such as Windows.
        timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        basename = os.path.join(self.logdir, "{}.{}".format(self.environment_name, timestamp))
        print("Base Filename: ", basename, flush = True)
        # Transpose the list of tuples into one column per field.
        state, action, reward, next_state, done = zip(*self.memory)
        columns = (
            ("state", state),
            ("action", action),
            ("reward", reward),
            ("nextstate", next_state),
            ("done", done),
        )
        for suffix, values in columns:
            np_save(basename + "-" + suffix + ".npy", np_array(values), allow_pickle = False)
        self.memory.clear()
## Parsing arguments
parser = ArgumentParser(description="Play and log the environment")
parser.add_argument("--environment_name", type=str, help="The environment name in OpenAI gym to play.")
parser.add_argument("--logdir", type=str, help="Directory to log video and (state, action, reward, next_state, done) in.")
# Number of frames to skip when recording and fps have sane defaults. They are
# declared on the parser because vars(Namespace) always contains every key
# (with value None when omitted), so a `'fps' not in args` check never fires.
parser.add_argument("--skip", type=int, default=3, help="Number of frames to skip logging.")
parser.add_argument("--fps", type=int, default=30, help="Number of frames per second")
parser.add_argument("--model", type=str, help = "The path location of the PyTorch model")
args = vars(parser.parse_args())
## Main configuration for script
from config import config
# Environment name and log directory are vital, so show the help message and exit if either is missing
if args['environment_name'] is None or args['logdir'] is None:
    parser.print_help()
    # SystemExit rather than the site-provided exit() helper, which is not
    # guaranteed to exist (e.g. when Python runs with -S).
    raise SystemExit(1)
def wrap_preprocessing(env, MaxAndSkipEnv = False):
    """Wrap ``env`` in the rltorch preprocessing stack used by this script.

    Wrappers are applied innermost first:
    EpisodicLifeEnv -> NoopResetEnv(noop_max=30) -> [MaxAndSkipEnv(skip=4)]
    -> FireResetEnv -> ProcessFrame84 -> TorchWrap -> FrameStack(4)
    -> ClippedRewardsWrapper.

    Args:
        env: The gym environment to wrap.
        MaxAndSkipEnv: When true, insert ``E.MaxAndSkipEnv(skip=4)`` right
            after the no-op reset wrapper.

    Returns:
        The fully wrapped environment (outermost: ``E.ClippedRewardsWrapper``).
    """
    wrapped = E.EpisodicLifeEnv(env)
    wrapped = E.NoopResetEnv(wrapped, noop_max = 30)
    if MaxAndSkipEnv:
        wrapped = E.MaxAndSkipEnv(wrapped, skip = 4)
    wrapped = E.FireResetEnv(wrapped)
    wrapped = E.ProcessFrame84(wrapped)
    wrapped = E.TorchWrap(wrapped)
    wrapped = E.FrameStack(wrapped, 4)
    return E.ClippedRewardsWrapper(wrapped)
## Set up environment to be recorded and preprocessed
# Transitions captured by Record land in record_memory; record_lock is shared
# with the play thread and the logging loop below.
record_memory, record_lock = [], Lock()
# Keep a direct handle on the Record wrapper so log_transitions stays reachable
# after further wrappers are stacked on top of it.
record_env = Record(makeEnv(args['environment_name']), record_memory, record_lock, args)
# Gym's native Monitor provides the video recording; the preprocessing stack
# sits outside it.
env = wrap_preprocessing(GymMonitor(record_env, args['logdir'], force=True))
# Set seeds: rltorch's global seeding plus the played env, both from config['seed'].
# NOTE(review): sneaky_env (created below) is not seeded here.
rltorch.set_seed(config['seed'])
env.seed(config['seed'])
# Use the first CUDA device when available, unless config['disable_cuda'] is set.
device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu")
# First dimension of the preprocessed (frame-stacked) observation shape --
# presumably the number of stacked frames/channels; TODO confirm against networks.Value.
state_size = env.observation_space.shape[0]
# Number of discrete actions.
action_size = env.action_space.n
# Set up the networks (online network plus its target copy).
# Constructed after seeding so that weight initialisation is reproducible.
net = rn.Network(Value(state_size, action_size),
    Adam, config, device = device)
target_net = rn.TargetNetwork(net, device = device)
# Relevant components from RLTorch.
# max_demo is a tenth of the total capacity -- presumably the share reserved
# for human demonstrations; verify in rltorch's iDQfDMemory.
memory = iDQfDMemory(capacity= config['memory_size'], max_demo = config['memory_size'] // 10)
# Action selector for the environment the human plays on (ArgMaxSelector, i.e. no exploration term).
actor = ArgMaxSelector(net, action_size, device = device)
agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)
# Use a separate environment when the computer trains on the side so that the current game state isn't manipulated.
# It also uses MaxAndSkipEnv to speed up processing.
sneaky_env = wrap_preprocessing(makeEnv(args['environment_name']), MaxAndSkipEnv = True)
# Epsilon-greedy selector for the side environment, with epsilon = config['exploration_rate'].
sneaky_actor = EpsilonGreedySelector(net, action_size, device = device, epsilon = config['exploration_rate'])
# Pass all this information into the thread that will handle the game play and start it
playThread = PlayClass(env, actor, agent, sneaky_env, sneaky_actor, record_lock, config)
playThread.start()
# While the play thread is running, periodically flush recorded transitions to disk.
# `with record_lock` guarantees the lock is released even if logging raises;
# a lock left held would block Record.step in the play thread forever.
try:
    while playThread.is_alive():
        # Wake up at least every 60 seconds, or as soon as the play thread ends.
        playThread.join(60)
        with record_lock:
            print("Logging....", end = " ", flush = True)
            record_env.log_transitions()
finally:
    # Save what's remaining once the play thread has finished -- or if the main
    # thread is interrupted (e.g. Ctrl-C during join) -- so buffered
    # transitions are not lost.
    with record_lock:
        record_env.log_transitions()