Skip to content

Commit 8e1fa07

Browse files
committed
Grid Cell model files
1 parent c14ea09 commit 8e1fa07

15 files changed

Lines changed: 1476 additions & 0 deletions

scripts/Chris/DQN/ANN.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import pickle as pkl
2+
import random
3+
from collections import namedtuple, deque
4+
5+
from matplotlib import pyplot as plt
6+
from sklearn.metrics import confusion_matrix
7+
from torch.nn import Module, Linear, ReLU, Sequential
8+
from torch.optim import Adam
9+
import torch
10+
11+
# One step of agent experience, after the PyTorch DQN tutorial:
# https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])
13+
14+
# Experience-replay buffer, after the PyTorch DQN tutorial:
# https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
class ReplayMemory(object):
    """Bounded FIFO store of Transition tuples for DQN training."""

    def __init__(self, capacity):
        # A deque with maxlen silently evicts the oldest entry once full.
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Store one transition built from the given fields."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Draw a uniform random batch (without replacement) of stored transitions."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.memory)
29+
30+
31+
class ANN(Module):
    """Simple fully-connected network: input -> 1000 -> 100 -> output."""

    def __init__(self, input_dim, output_dim):
        super(ANN, self).__init__()
        hidden_a, hidden_b = 1000, 100
        self.sequence = Sequential(
            Linear(input_dim, hidden_a),
            ReLU(),
            Linear(hidden_a, hidden_b),
            ReLU(),
            Linear(hidden_b, output_dim),
        )

    def forward(self, x):
        """Run a forward pass; the input is cast to float32 first."""
        return self.sequence(x.to(torch.float32))
45+
46+
47+
class DQN:
    """Deep Q-learning agent wrapping a policy network, a target network
    and a replay buffer (after the PyTorch DQN tutorial).

    Args:
        input_dim: size of the flattened state vector.
        output_dim: number of discrete actions.
        gamma: discount factor for future rewards.
        batch_size: minibatch size drawn from replay memory.
        device: torch device string for created tensors.
    """

    def __init__(self, input_dim, output_dim, gamma=0.99, batch_size=128, device='cpu'):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.policy_net = ANN(input_dim, output_dim)
        self.target_net = ANN(input_dim, output_dim)
        # Start the target net as an exact copy of the policy net so the
        # first bootstrapped TD targets are consistent with the policy.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = Adam(self.policy_net.parameters())
        self.memory = ReplayMemory(10000)
        self.gamma = gamma
        self.batch_size = batch_size
        self.device = device

    def select_action(self, state, epsilon):
        """Epsilon-greedy action selection.

        With probability `epsilon` returns a uniformly random action as a
        (1, 1) float tensor; otherwise returns the greedy action index
        from the policy network.
        """
        if random.random() < epsilon:
            # Was hard-coded to 2 actions; use the configured action count
            # (identical behavior for the original output_dim=2 usage).
            # NOTE(review): this branch returns a (1, 1) float tensor while
            # the greedy branch returns a 0-d long tensor; optimize_model
            # re-coerces actions via torch.tensor(...), so both pass through,
            # but callers comparing shapes/dtypes should confirm.
            return torch.tensor([[random.randrange(self.output_dim)]], dtype=torch.float32)
        else:
            with torch.no_grad():
                return self.policy_net(state).argmax()

    def optimize_model(self):
        """Run one gradient step on a replay minibatch.

        No-op until the buffer holds at least `batch_size` transitions.
        """
        if len(self.memory) < self.batch_size:
            return
        transitions = self.memory.sample(self.batch_size)
        # Convert a batch of Transitions into a Transition of batches.
        batch = Transition(*zip(*transitions))

        # Mask of transitions whose episode did not end at this step.
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                batch.next_state)), device=self.device, dtype=torch.bool)
        # Was hard-coded to reshape(-1, 2); use the configured state size.
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None]).reshape(-1, self.input_dim)
        state_batch = torch.cat(batch.state).reshape(-1, self.input_dim)
        action_batch = torch.tensor(batch.action).to(torch.int64)
        reward_batch = torch.tensor(batch.reward)

        # Compute Q(s_t, a): gather the Q-value of the action actually taken.
        # BUG FIX: the original `self.policy_net(state_batch)[action_batch]`
        # indexed *rows* of the (batch, actions) output — selecting whole
        # Q-vectors of arbitrary batch entries — instead of one Q-value per
        # sample for its chosen action.
        state_action_values = self.policy_net(state_batch).gather(
            1, action_batch.reshape(-1, 1))

        # Compute V(s_{t+1}) for all next states; terminal states keep value 0.
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        with torch.no_grad():
            next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1).values

        # Compute the expected Q values (one-step TD target).
        expected_state_action_values = (next_state_values * self.gamma) + reward_batch

        # Huber loss is less sensitive to outlier TD errors than MSE.
        criterion = torch.nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()

    def update_target(self, tau=0.005):
        """Soft-update the target network: theta' <- tau*theta + (1-tau)*theta'."""
        target_net_state_dict = self.target_net.state_dict()
        policy_net_state_dict = self.policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*tau + target_net_state_dict[key] * (1 - tau)
        # BUG FIX: the blended weights were computed but never written back,
        # so the target network never changed during training.
        self.target_net.load_state_dict(target_net_state_dict)
108+
109+
110+
class Mem_Dataset(torch.utils.data.Dataset):
    """Dataset of (spike-train sample, position label) pairs."""

    def __init__(self, samples, labels):
        self.samples = samples  # sequence of spike-train tensors
        self.labels = labels    # matching labels, one per sample

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Compress spike train into windows for dimension reduction:
        # sum over the leading (time) axis, then drop singleton dims.
        sample = self.samples[idx]
        compressed = sample.sum(0).squeeze()
        return compressed, self.labels[idx]
121+
122+
123+
if __name__ == '__main__':
    ### ANN for input spike trains ###
    ## Load recalled memory samples ##
    with open('Data/grid_cell_spk_trains.pkl', 'rb') as f:
        samples, labels = pkl.load(f)

    ## Initialize ANN ##
    # Each sample's second dimension is the flattened input size.
    in_dim = samples[0].shape[1]
    model = ANN(in_dim, 2)
    optimizer = Adam(model.parameters())
    criterion = torch.nn.MSELoss()
    dataset = Mem_Dataset(samples, labels)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=True)

    ## Training ##
    loss_log = []
    accuracy_log = []
    for epoch in range(10):
        total_loss = 0
        correct = 0
        for memory_batch, positions in train_loader:
            optimizer.zero_grad()
            outputs = model(memory_batch)
            loss = criterion(outputs, positions.to(torch.float32))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # A prediction counts as correct only if both coordinates
            # round to the true (rounded) position.
            correct += torch.all(outputs.round() == positions.round(),
                                 dim=1).sum().item()
        accuracy_log.append(correct / len(train_set))
        loss_log.append(total_loss)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.plot(loss_log)
    plt.show()
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training Accuracy')
    plt.plot(accuracy_log)
    plt.show()

    ## Testing ##
    total = 0
    correct = 0
    # Renamed from `confusion_matrix`, which shadowed the sklearn import
    # of the same name at the top of the file.
    conf_mat = torch.zeros(25, 25)
    out_of_bounds = 0
    with torch.no_grad():
        # Loop variable renamed from `labels`, which shadowed the
        # dataset-level `labels` loaded above.
        for memories, targets in test_loader:
            outputs = model(memories)
            # Cast to float32 for consistency with the training loss
            # (MSELoss requires matching dtypes).
            loss = criterion(outputs, targets.to(torch.float32))
            total += len(targets)
            correct += torch.all(outputs.round() == targets.round(),
                                 dim=1).sum().item()  # Check if prediction for both x and y are correct
            for t, p in zip(targets, outputs):
                # Flatten the (x, y) grid position into a single bin index
                # on the assumed 5x5 position grid.
                label_ind = int(t[0].round() * 5 + t[1].round())
                pred_ind = int(p[0].round() * 5 + p[1].round())
                if label_ind < 0 or label_ind >= 25 or pred_ind < 0 or pred_ind >= 25:
                    out_of_bounds += 1
                else:
                    conf_mat[label_ind, pred_ind] += 1

    plt.imshow(conf_mat)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True Label')
    plt.colorbar()
    plt.show()

    # Multiply before rounding to avoid float artifacts such as 12.299999...
    print(f'Accuracy: {round(100 * correct / total, 1)}%')
199+

scripts/Chris/DQN/Environment.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from labyrinth.generate import DepthFirstSearchGenerator
2+
from labyrinth.grid import Cell, Direction
3+
from labyrinth.maze import Maze
4+
from labyrinth.solve import MazeSolver
5+
import pickle as pkl
6+
import matplotlib.pyplot as plt
7+
8+
class Maze_Environment(Maze):
    """RL-style wrapper around `labyrinth`'s Maze.

    Tracks an agent cell and exposes a gym-like step/reset interface:
    step() returns (next_state, reward, done, info).
    """

    def __init__(self, width, height):
        # Generate basic maze & solve
        super().__init__(width=width, height=height, generator=DepthFirstSearchGenerator())
        solver = MazeSolver()
        self.path = solver.solve(self)
        self.agent_cell = self.start_cell

    def plot(self):
        """Draw the maze outline, solution path, start/end cells and walls."""
        # Box around maze
        plt.plot([-0.5, self.width-1+0.5], [-0.5, -0.5], color='black')
        plt.plot([-0.5, self.width-1+0.5], [self.height-1+0.5, self.height-1+0.5], color='black')
        plt.plot([-0.5, -0.5], [-0.5, self.height-1+0.5], color='black')
        plt.plot([self.width-1+0.5, self.width-1+0.5], [-0.5, self.height-1+0.5], color='black')

        # Plot maze
        for row in range(self.height):
            for column in range(self.width):
                # Path
                cell = self[column, row]  # Transpose maze coordinates (just how the maze is stored)
                if cell == self.start_cell:
                    plt.plot(row, column, 'go')
                elif cell == self.end_cell:
                    plt.plot(row, column, 'bo')
                elif cell in self.path:
                    plt.plot(row, column, 'ro')

                # Walls: draw the south/east wall segments that are closed.
                if Direction.S not in cell.open_walls:
                    plt.plot([row-0.5, row+0.5], [column+0.5, column+0.5], color='black')
                if Direction.E not in cell.open_walls:
                    plt.plot([row+0.5, row+0.5], [column-0.5, column+0.5], color='black')

    def reset(self):
        """Return the agent to the start cell and return that cell.

        Previously a `pass` stub, although step() mutates agent_cell —
        episodic training needs the position restored between episodes.
        """
        self.agent_cell = self.start_cell
        return self.agent_cell

    # Takes action, returns next state, reward, done, info
    def step(self, action):
        """Apply `action` (a Direction); reward is -1 for hitting a wall,
        1 for reaching the end cell, 0 otherwise."""
        # Check if action runs into wall
        if action not in self.agent_cell.open_walls:
            return self.agent_cell, -1, False, {}

        # Move agent
        # BUG FIX: was `self.agent_pos.neighbor(action)` — `agent_pos` is
        # never defined on this class; the attribute set in __init__ is
        # `agent_cell`, so any legal move raised AttributeError.
        # NOTE(review): assumes labyrinth's Cell exposes neighbor(direction);
        # the original code path could never have run — confirm against the
        # labyrinth package API.
        self.agent_cell = self.agent_cell.neighbor(action)
        if self.agent_cell == self.end_cell:
            return self.agent_cell, 1, True, {}
        else:
            return self.agent_cell, 0, False, {}

    def save(self, filename):
        """Pickle the whole environment to `filename`."""
        with open(filename, 'wb') as f:
            pkl.dump(self, f)
62+
63+
64+
if __name__ == '__main__':
    # Build a maze environment, solve it, and show the endpoints.
    maze = Maze_Environment(width=25, height=25)
    maze.path = MazeSolver().solve(maze)
    print(maze)
    print(f'start: {maze.start_cell}')
    print(f'end: {maze.end_cell}')

scripts/Chris/DQN/Eval.ipynb

Lines changed: 252 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)