Skip to content

Commit 701a18e

Browse files
committed
Environment animation
1 parent ce07cdd commit 701a18e

6 files changed

Lines changed: 130 additions & 145 deletions

File tree

scripts/Chris/DQN/ANN.py

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -43,70 +43,6 @@ def forward(self, x):
4343
x = x.to(torch.float32)
4444
return self.sequence(x)
4545

46-
47-
class DQN:
48-
def __init__(self, input_dim, output_dim, gamma=0.99, batch_size=128, device='cpu'):
49-
self.policy_net = ANN(input_dim, output_dim)
50-
self.target_net = ANN(input_dim, output_dim)
51-
self.optimizer = Adam(self.policy_net.parameters())
52-
self.memory = ReplayMemory(10000)
53-
self.gamma = gamma
54-
self.batch_size = batch_size
55-
self.device = device
56-
57-
def select_action(self, state, epsilon):
58-
# Random action
59-
if random.random() < epsilon:
60-
return torch.tensor([[random.randrange(2)]], dtype=torch.float32)
61-
62-
# ANN action
63-
else:
64-
with torch.no_grad():
65-
return self.policy_net(state).argmax()
66-
67-
def optimize_model(self):
68-
if len(self.memory) < self.batch_size:
69-
return
70-
transitions = self.memory.sample(self.batch_size)
71-
batch = Transition(*zip(*transitions))
72-
73-
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
74-
batch.next_state)), device=self.device, dtype=torch.bool)
75-
non_final_next_states = torch.cat([s for s in batch.next_state
76-
if s is not None]).reshape(-1, 2)
77-
state_batch = torch.cat(batch.state).reshape(-1, 2)
78-
action_batch = torch.tensor(batch.action).to(torch.int64)
79-
reward_batch = torch.tensor(batch.reward)
80-
81-
# Compute Q(s_t, a)
82-
state_action_values = self.policy_net(state_batch)[action_batch]
83-
84-
# Compute V(s_{t+1}) for all next states.
85-
next_state_values = torch.zeros(self.batch_size, device=self.device)
86-
with torch.no_grad():
87-
next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1).values
88-
89-
# Compute the expected Q values
90-
expected_state_action_values = (next_state_values * self.gamma) + reward_batch
91-
92-
# Compute Loss
93-
criterion = torch.nn.SmoothL1Loss()
94-
loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
95-
96-
# Optimize the model
97-
self.optimizer.zero_grad()
98-
loss.backward()
99-
# In-place gradient clipping
100-
torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
101-
self.optimizer.step()
102-
103-
def update_target(self, tau=0.005):
104-
target_net_state_dict = self.target_net.state_dict()
105-
policy_net_state_dict = self.policy_net.state_dict()
106-
for key in policy_net_state_dict:
107-
target_net_state_dict[key] = policy_net_state_dict[key]*tau + target_net_state_dict[key] * (1 - tau)
108-
109-
11046
class Mem_Dataset(torch.utils.data.Dataset):
11147
def __init__(self, samples, labels):
11248
self.samples = samples

scripts/Chris/DQN/Environment.py

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
import random
2+
import numpy as np
13
from labyrinth.generate import DepthFirstSearchGenerator
24
from labyrinth.grid import Cell, Direction
35
from labyrinth.maze import Maze
46
from labyrinth.solve import MazeSolver
7+
from matplotlib.pyplot import plot as plt
8+
from matplotlib.animation import FuncAnimation
9+
510
import pickle as pkl
611
import matplotlib.pyplot as plt
12+
from torch import optim
713

814
class Maze_Environment():
915
def __init__(self, width, height):
@@ -17,6 +23,8 @@ def __init__(self, width, height):
1723
self.maze.path = self.path # No idea why this is necessary
1824
self.agent_cell = self.maze.start_cell
1925
self.num_actions = 4
26+
self.history = [(self.agent_cell.coordinates, 0, False, {})] # (state, reward, done, info)
27+
self.pos_history = []
2028

2129
def plot(self):
2230
# Box around maze
@@ -29,12 +37,12 @@ def plot(self):
2937
for row in range(self.height):
3038
for column in range(self.width):
3139
# Path
32-
cell = self[column, row] # Tranpose maze coordinates (just how the maze is stored)
33-
if cell == self.start_cell:
40+
cell = self.maze[column, row] # Tranpose maze coordinates (just how the maze is stored)
41+
if cell == self.maze.start_cell:
3442
plt.plot(row, column, 'go')
35-
elif cell == self.end_cell:
43+
elif cell == self.maze.end_cell:
3644
plt.plot(row, column,'bo')
37-
elif cell in self.path:
45+
elif cell in self.maze.path:
3846
plt.plot(row, column, 'ro')
3947

4048
# Walls
@@ -44,16 +52,11 @@ def plot(self):
4452
plt.plot([row+0.5, row+0.5], [column-0.5, column+0.5], color='black')
4553

4654
def reset(self):
47-
# self.maze = Maze(width=self.width, height=self.height, generator=DepthFirstSearchGenerator())
48-
# self.solver = MazeSolver()
49-
# self.path = self.solver.solve(self.maze)
50-
# self.maze.path = self.path # No idea why this is necessary
51-
# self.agent_cell = self.maze.start_cell
52-
# return self.agent_cell, {}
5355
self.agent_cell = self.maze.start_cell
56+
self.step_history = []
57+
self.pos_history = []
5458
return self.agent_cell, {}
5559

56-
5760
# Takes action
5861
# Returns next state, reward, done, info
5962
def step(self, action):
@@ -69,29 +72,68 @@ def step(self, action):
6972

7073
# Check if action runs into wall
7174
if action not in self.agent_cell.open_walls:
75+
self.history.append((self.agent_cell.coordinates, -0.5, False, {}))
7276
return self.agent_cell, -.5, False, {}
7377

7478
# Move agent
7579
else:
7680
self.agent_cell = self.maze.neighbor(self.agent_cell, action)
7781
if self.agent_cell == self.maze.end_cell: # Check if agent has reached the end
82+
self.history.append(self.agent_cell.coordinates, 1, True, {})
7883
return self.agent_cell, 1, True, {}
7984
else:
85+
self.history.append((self.agent_cell.coordinates, 0, False, {}))
8086
return self.agent_cell, 0, False, {}
8187

8288
def save(self, filename):
8389
with open(filename, 'wb') as f:
8490
pkl.dump(self, f)
8591

92+
def animate_history(self):
93+
def update(i):
94+
plt.clf()
95+
self.plot()
96+
plt.plot(self.history[i][0][1], self.history[i][0][0], 'yo')
97+
plt.title(f'Step {i}, Reward: {self.history[i][1]}')
98+
ani = FuncAnimation(plt.gcf(), update, frames=len(self.history), repeat=False)
99+
ani.save('maze.gif', writer='ffmpeg', fps=10)
100+
101+
class Grid_Cell_Maze_Environment(Maze_Environment):
102+
def __init__(self, width, height):
103+
super().__init__(width, height)
86104

105+
# Load spike train samples
106+
# {position: [spike_trains]}
107+
with open('Data/preprocessed_recalls_sorted.pkl', 'rb') as f:
108+
self.samples = pkl.load(f)
109+
110+
def reset(self):
111+
cell, info = super().reset()
112+
return self.state_to_grid_cell_spikes(cell), info
113+
114+
def step(self, action):
115+
obs, reward, done, info = super().step(action)
116+
obs = self.state_to_grid_cell_spikes(obs)
117+
return obs, reward, done, info
118+
119+
def state_to_grid_cell_spikes(self, cell):
120+
return random.choice(self.samples[cell.coordinates])
87121

88122

89123
if __name__ == '__main__':
90-
maze_env = Maze_Environment(width=25, height=25)
91-
print(maze_env.maze)
92-
print(f'start: {maze_env.maze.start_cell}')
93-
print(f'end: {maze_env.maze.end_cell}')
94-
maze_env.reset()
95-
print(maze_env.maze)
96-
print(f'start: {maze_env.maze.start_cell}')
97-
print(f'end: {maze_env.maze.end_cell}')
124+
from train_DQN import DQN, ReplayMemory
125+
from scripts.Chris.DQN.train_DQN import run_episode
126+
127+
device = 'cpu'
128+
n_actions = 4
129+
input_size = 300
130+
lr = 0.01
131+
policy_net = DQN(input_size, n_actions).to(device)
132+
target_net = DQN(input_size, n_actions).to(device)
133+
target_net.load_state_dict(policy_net.state_dict())
134+
optimizer = optim.AdamW(policy_net.parameters(), lr=lr, amsgrad=True)
135+
memory = ReplayMemory(10000)
136+
env = Grid_Cell_Maze_Environment(width=5, height=5)
137+
138+
run_episode(env, policy_net, 'cpu', 100, eps=0.9)
139+
env.animate_history()

scripts/Chris/DQN/pipeline_executor.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pickle as pkl
3+
from train_DQN import train_DQN
34
from sample_generator import sample_generator
45
from spike_train_generator import spike_train_generator
56
from store_reservoir import store_reservoir
@@ -74,5 +75,19 @@
7475
# # Preprocess Recalls ##
7576
# recalled_mem_preprocessing(WINDOW_FREQ, WINDOW_SIZE, PLOT)
7677

78+
## Train DQN ##
79+
LR = 0.01
80+
EPS_START = 0.9
81+
EPS_END = 0.05
82+
EPS_DECAY = 1000
83+
TAU = 0.005
84+
GAMMA = 0.99
85+
MAX_STEPS_PER_EP = 10
86+
MAX_TOTAL_STEPS = 10
87+
MAX_EPS = 3000
88+
BATCH_SIZE = 128
89+
INPUT_SIZE = EXC_SIZE + INH_SIZE
90+
train_DQN(INPUT_SIZE, LR, BATCH_SIZE, EPS_START, EPS_END, EPS_DECAY, TAU, GAMMA, MAX_STEPS_PER_EP, MAX_TOTAL_STEPS, MAX_EPS)
91+
7792
## Train ANN ##
78-
classify_recalls(OUT_DIM, TRAIN_RATIO, BATCH_SIZE, TRAIN_EPOCHS)
93+
# classify_recalls(OUT_DIM, TRAIN_RATIO, BATCH_SIZE, TRAIN_EPOCHS)

scripts/Chris/DQN/recalled_mem_preprocessing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def recalled_mem_preprocessing(window_freq, window_size, plot):
4242
## Save transformed samples ##
4343
with open('Data/preprocessed_recalls.pkl', 'wb') as f:
4444
pkl.dump((new_samples, labels), f)
45+
with open('Data/preprocessed_recalls_sorted.pkl', 'wb') as f:
46+
pkl.dump(new_samples_sorted, f)
4547

4648
if plot:
4749
# positions = np.array([key for key in new_samples_sorted.keys()])

scripts/Chris/DQN/sample_generator.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ def inter_positional_spread(env_to_gc):
2626
return spread
2727

2828
# Generate grid cell activity for all integer coordinate positions in environment
29-
def sample_generator(scales, offsets, vars, x_range, y_range, samples_per_pos, noise=0.1, padding=2, plot=False):
29+
def sample_generator(scales, offsets, vars, x_range, y_range,
30+
samples_per_pos, noise=0.1, padding=2, plot=False):
3031
print('Generating samples...')
3132
sorted_samples = {}
3233
samples = np.zeros((x_range[1] * y_range[1] * samples_per_pos, len(scales)))
@@ -62,24 +63,3 @@ def sample_generator(scales, offsets, vars, x_range, y_range, samples_per_pos, n
6263
plt.show()
6364

6465
return samples, labels, sorted_samples
65-
66-
if __name__ == '__main__':
67-
## Constants ##
68-
WIDTH = 5
69-
HEIGHT = 5
70-
SAMPLES_PER_POS = 1000
71-
WINDOW_FREQ = 10
72-
WINDOW_SIZE = 10
73-
# Grid Cells
74-
num_cells_ = 20
75-
x_range_ = (0, 5)
76-
y_range_ = (0, 5)
77-
x_offsets_ = np.random.uniform(-1, 1, num_cells_)
78-
y_offsets_ = np.random.uniform(-1, 1, num_cells_)
79-
offsets_ = list(zip(x_offsets_, y_offsets_))
80-
scales_ = [1 + 0.01 * i for i in range(num_cells_)]
81-
vars_ = [0.85]*num_cells_
82-
83-
# Test spread for set of parameters
84-
# Shape = (num_samples, num_cells)
85-
samples_, labels_, sorted_samples_ = sample_generator(scales_, offsets_, vars_, x_range_, y_range_, SAMPLES_PER_POS)

0 commit comments

Comments
 (0)