
Commit ce07cdd

Working DQN
1 parent 59c732b commit ce07cdd

4 files changed

Lines changed: 227 additions & 20 deletions


scripts/Chris/DQN/Environment.py

Lines changed: 43 additions & 17 deletions
@@ -5,14 +5,18 @@
 import pickle as pkl
 import matplotlib.pyplot as plt

-class Maze_Environment(Maze):
+class Maze_Environment():
   def __init__(self, width, height):

     # Generate basic maze & solve
-    super().__init__(width=width, height=height, generator=DepthFirstSearchGenerator())
-    solver = MazeSolver()
-    self.path = solver.solve(self)
-    self.agent_cell = self.start_cell
+    self.width = width
+    self.height = height
+    self.maze = Maze(width=width, height=height, generator=DepthFirstSearchGenerator())
+    self.solver = MazeSolver()
+    self.path = self.solver.solve(self.maze)
+    self.maze.path = self.path  # No idea why this is necessary
+    self.agent_cell = self.maze.start_cell
+    self.num_actions = 4

   def plot(self):
     # Box around maze
@@ -40,18 +44,37 @@ def plot(self):
       plt.plot([row+0.5, row+0.5], [column-0.5, column+0.5], color='black')

   def reset(self):
-    pass
+    # self.maze = Maze(width=self.width, height=self.height, generator=DepthFirstSearchGenerator())
+    # self.solver = MazeSolver()
+    # self.path = self.solver.solve(self.maze)
+    # self.maze.path = self.path  # No idea why this is necessary
+    # self.agent_cell = self.maze.start_cell
+    # return self.agent_cell, {}
+    self.agent_cell = self.maze.start_cell
+    return self.agent_cell, {}

-  # Takes action, returns next state, reward, done, info
+
+  # Takes action
+  # Returns next state, reward, done, info
   def step(self, action):
+    # Transform action into Direction
+    if action == 0:
+      action = Direction.N
+    elif action == 1:
+      action = Direction.E
+    elif action == 2:
+      action = Direction.S
+    elif action == 3:
+      action = Direction.W
+
     # Check if action runs into wall
     if action not in self.agent_cell.open_walls:
-      return self.agent_cell, -1, False, {}
+      return self.agent_cell, -.5, False, {}

     # Move agent
     else:
-      self.agent_cell = self.agent_pos.neighbor(action)
-      if self.agent_cell == self.end_cell:
+      self.agent_cell = self.maze.neighbor(self.agent_cell, action)
+      if self.agent_cell == self.maze.end_cell:  # Check if agent has reached the end
         return self.agent_cell, 1, True, {}
       else:
         return self.agent_cell, 0, False, {}
@@ -61,11 +84,14 @@ def save(self, filename):
     pkl.dump(self, f)


+
+
 if __name__ == '__main__':
-  maze = Maze_Environment(width=25, height=25)
-  solver = MazeSolver()
-  path = solver.solve(maze)
-  maze.path = path
-  print(maze)
-  print(f'start: {maze.start_cell}')
-  print(f'end: {maze.end_cell}')
+  maze_env = Maze_Environment(width=25, height=25)
+  print(maze_env.maze)
+  print(f'start: {maze_env.maze.start_cell}')
+  print(f'end: {maze_env.maze.end_cell}')
+  maze_env.reset()
+  print(maze_env.maze)
+  print(f'start: {maze_env.maze.start_cell}')
+  print(f'end: {maze_env.maze.end_cell}')
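
Note (not part of the commit): the reworked Maze_Environment now exposes a Gym-style interface, where reset() returns (start cell, info dict) and step() maps integer actions 0-3 to Direction.N/E/S/W and returns (next cell, reward, done, info), with -0.5 for walking into a wall, +1 for reaching the end cell, and 0 otherwise. A minimal sketch of driving it with random actions, assuming the package import path used in train_DQN.py below:

# Sketch only: random-walk sanity check of the environment's reset()/step() API.
# The import path is assumed from train_DQN.py, not introduced by this commit.
import random
from scripts.Chris.DQN.Environment import Maze_Environment

env = Maze_Environment(width=5, height=5)
state, info = env.reset()                       # state is the agent's start cell
for _ in range(100):
  action = random.randrange(env.num_actions)    # 0=N, 1=E, 2=S, 3=W
  state, reward, done, info = env.step(action)
  if done:                                      # agent reached maze.end_cell (reward = 1)
    break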

scripts/Chris/DQN/pipeline_executor.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 ## Constants ##
 WIDTH = 5
 HEIGHT = 5
-SAMPLES_PER_POS = 1000
+SAMPLES_PER_POS = 5000
 NOISE = 0.1  # Noise in sampling
 WINDOW_FREQ = 10
 WINDOW_SIZE = 10
@@ -74,5 +74,5 @@
 # # Preprocess Recalls ##
 # recalled_mem_preprocessing(WINDOW_FREQ, WINDOW_SIZE, PLOT)

-# Train ANN ##
+## Train ANN ##
 classify_recalls(OUT_DIM, TRAIN_RATIO, BATCH_SIZE, TRAIN_EPOCHS)

scripts/Chris/DQN/recall_reservoir.py

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@ def recall_reservoir(exc_size, inh_size, sim_time, plot=False):
   memory_keys, labels = pkl.load(f)

   ## Recall memories ##
-  # TODO: Plot output spikes according to inh exc populations
   recalled_memories = np.zeros((len(memory_keys), sim_time, exc_size + inh_size))
   recalled_memories_sorted = {}
   for i, (key, label) in enumerate(zip(memory_keys, labels)):

scripts/Chris/DQN/train_DQN.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
+import math
+import random
+import matplotlib
+import matplotlib.pyplot as plt
+from collections import namedtuple, deque
+from itertools import count
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+
+from scripts.Chris.DQN.Environment import Maze_Environment
+
+Transition = namedtuple('Transition',
+                        ('state', 'action', 'next_state', 'reward'))
+
+class ReplayMemory(object):
+  def __init__(self, capacity):
+    self.memory = deque([], maxlen=capacity)
+
+  def push(self, *args):
+    """Save a transition"""
+    self.memory.append(Transition(*args))
+
+  def sample(self, batch_size):
+    return random.sample(self.memory, batch_size)
+
+  def __len__(self):
+    return len(self.memory)
+
+class DQN(nn.Module):
+
+  def __init__(self, n_observations, n_actions):
+    super(DQN, self).__init__()
+    self.layer1 = nn.Linear(n_observations, 128)
+    self.layer2 = nn.Linear(128, 128)
+    self.layer3 = nn.Linear(128, n_actions)
+
+  # Called with either one element to determine next action, or a batch
+  # during optimization. Returns tensor([[left0exp,right0exp]...]).
+  def forward(self, x):
+    x = F.relu(self.layer1(x))
+    x = F.relu(self.layer2(x))
+    return self.layer3(x)
+
+
+# Select action using epsilon-greedy policy
+def select_action(state, step, eps, policy_net, env):
+  # eps_threshold = EPS_END + (EPS_START - EPS_END) * \
+  #   math.exp(-1. * step / EPS_DECAY)
+
+  # Select action from policy net
+  if random.random() > eps:
+    with torch.no_grad():
+      # t.max(1) will return the largest column value of each row.
+      # second column on max result is index of where max element was
+      # found, so we pick action with the larger expected reward.
+      return policy_net(state).max(1).indices.view(1, 1)
+
+  # Select random action (exploration)
+  else:
+    return torch.tensor(np.random.choice(env.num_actions)).view(1, 1)
+
+
+# Optimize DQN
+def optimize_model(memory, batch_size, policy_net, target_net, optimizer, gamma, device):
+  if len(memory) < batch_size:
+    return
+  transitions = memory.sample(batch_size)
+  # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
+  # detailed explanation). This converts batch-array of Transitions
+  # to Transition of batch-arrays.
+  batch = Transition(*zip(*transitions))
+
+  # Compute a mask of non-final states and concatenate the batch elements
+  # (a final state would've been the one after which simulation ended)
+  non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
+                                          batch.next_state)), device=device, dtype=torch.bool)
+  non_final_next_states = torch.cat([s for s in batch.next_state
+                                     if s is not None])
+  state_batch = torch.cat(batch.state)
+  action_batch = torch.cat(batch.action)
+  reward_batch = torch.cat(batch.reward)
+
+  # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
+  # columns of actions taken. These are the actions which would've been taken
+  # for each batch state according to policy_net
+  state_action_values = policy_net(state_batch).gather(1, action_batch)
+
+  # Compute V(s_{t+1}) for all next states.
+  # Expected values of actions for non_final_next_states are computed based
+  # on the "older" target_net; selecting their best reward with max(1).values
+  # This is merged based on the mask, such that we'll have either the expected
+  # state value or 0 in case the state was final.
+  next_state_values = torch.zeros(batch_size, device=device)
+  with torch.no_grad():
+    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
+  # Compute the expected Q values
+  expected_state_action_values = (next_state_values * gamma) + reward_batch
+
+  # Compute Huber loss
+  criterion = nn.SmoothL1Loss()
+  loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
+
+  # Optimize the model
+  optimizer.zero_grad()
+  loss.backward()
+  # In-place gradient clipping
+  torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
+  optimizer.step()
+
+
+if __name__ == '__main__':
+  device = 'cpu'
+  n_actions = 4
+  n_observations = 2
+  LR = 0.01
+  EPS_START = 0.9
+  EPS_END = 0.05
+  EPS_DECAY = 1000
+  TAU = 0.005
+  GAMMA = 0.99
+  MAX_STEPS_PER_EP = 1000
+  TOTAL_STEPS = 10000
+  MAX_EPS = 300
+  BATCH_SIZE = 128
+
+  policy_net_ = DQN(n_observations, n_actions).to(device)
+  target_net_ = DQN(n_observations, n_actions).to(device)
+  target_net_.load_state_dict(policy_net_.state_dict())
+  optimizer_ = optim.AdamW(policy_net_.parameters(), lr=LR, amsgrad=True)
+  memory_ = ReplayMemory(10000)
+  env_ = Maze_Environment(width=5, height=5)
+
+  episode_durations = []
+  episodes = 0
+  total_steps = 0
+  print(env_.maze)
+  while total_steps < TOTAL_STEPS and episodes < MAX_EPS:
+    # Initialize the environment and get its state
+    state, info = env_.reset()
+    state = torch.tensor(state.coordinates, dtype=torch.float32, device=device).unsqueeze(0)
+    # print(f"Episode {i_episode}")
+    for t in count():
+      eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * total_steps / EPS_DECAY)
+      action = select_action(state, t, eps, policy_net_, env_)
+      observation, reward, terminated, _ = env_.step(action.item())
+      reward = torch.tensor([reward], device=device)
+
+      if terminated:
+        next_state = None
+      else:
+        next_state = torch.tensor(observation.coordinates, dtype=torch.float32, device=device).unsqueeze(0)
+
+      # Store the transition in memory
+      memory_.push(state, action, next_state, reward)
+
+      # Move to the next state
+      state = next_state
+
+      # Perform one step of the optimization (on the policy network)
+      optimize_model(memory_, BATCH_SIZE, policy_net_, target_net_, optimizer_, gamma=GAMMA, device=device)
+
+      # Soft update of the target network's weights
+      # θ′ ← τ θ + (1 − τ)θ′
+      target_net_state_dict = target_net_.state_dict()
+      policy_net_state_dict = policy_net_.state_dict()
+      for key in policy_net_state_dict:
+        target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
+      target_net_.load_state_dict(target_net_state_dict)
+
+      total_steps += 1
+      if terminated or t > MAX_STEPS_PER_EP:
+        episode_durations.append(t + 1)
+        break
+    print(f"Episode {episodes} lasted {t+1} steps, eps = {round(eps, 2)} total steps = {total_steps}")
+    episodes += 1
+
+  plt.plot(episode_durations)
+  plt.show()
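
Note (not part of the commit): exploration in the loop above is annealed with an exponential schedule, eps = EPS_END + (EPS_START - EPS_END) * exp(-total_steps / EPS_DECAY), and the target network is soft-updated each step with θ′ ← τθ + (1 - τ)θ′. A small sketch of the values this epsilon schedule produces for the constants used here:

# Sketch only: evaluate the epsilon-greedy schedule from the training loop above.
import math

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 1000
for step in (0, 500, 1000, 2000, 5000, 10000):
  eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * step / EPS_DECAY)
  print(f"step {step:>5}: eps = {eps:.3f}")
# Roughly: 0.90 at step 0, ~0.36 at 1000, ~0.17 at 2000, ~0.06 at 5000,
# flattening toward EPS_END = 0.05 well before TOTAL_STEPS = 10000.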
