Skip to content

Commit 5b6f3d2

Browse files
committed
Final version of RL model
1 parent 701a18e commit 5b6f3d2

6 files changed

Lines changed: 78 additions & 113 deletions

File tree

scripts/Chris/DQN/Environment.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def __init__(self, width, height):
2424
self.agent_cell = self.maze.start_cell
2525
self.num_actions = 4
2626
self.history = [(self.agent_cell.coordinates, 0, False, {})] # (state, reward, done, info)
27-
self.pos_history = []
2827

2928
def plot(self):
3029
# Box around maze
@@ -53,8 +52,7 @@ def plot(self):
5352

5453
def reset(self):
5554
self.agent_cell = self.maze.start_cell
56-
self.step_history = []
57-
self.pos_history = []
55+
self.history = [(self.agent_cell.coordinates, 0, False, {})]
5856
return self.agent_cell, {}
5957

6058
# Takes action
@@ -72,14 +70,14 @@ def step(self, action):
7270

7371
# Check if action runs into wall
7472
if action not in self.agent_cell.open_walls:
75-
self.history.append((self.agent_cell.coordinates, -0.5, False, {}))
76-
return self.agent_cell, -.5, False, {}
73+
self.history.append((self.agent_cell.coordinates, -0.1, False, {}))
74+
return self.agent_cell, -0.01, False, {}
7775

7876
# Move agent
7977
else:
8078
self.agent_cell = self.maze.neighbor(self.agent_cell, action)
8179
if self.agent_cell == self.maze.end_cell: # Check if agent has reached the end
82-
self.history.append(self.agent_cell.coordinates, 1, True, {})
80+
self.history.append((self.agent_cell.coordinates, 1, True, {}))
8381
return self.agent_cell, 1, True, {}
8482
else:
8583
self.history.append((self.agent_cell.coordinates, 0, False, {}))
@@ -89,14 +87,14 @@ def save(self, filename):
8987
with open(filename, 'wb') as f:
9088
pkl.dump(self, f)
9189

92-
def animate_history(self):
90+
def animate_history(self, file_name='maze.gif'):
9391
def update(i):
9492
plt.clf()
9593
self.plot()
9694
plt.plot(self.history[i][0][1], self.history[i][0][0], 'yo')
9795
plt.title(f'Step {i}, Reward: {self.history[i][1]}')
9896
ani = FuncAnimation(plt.gcf(), update, frames=len(self.history), repeat=False)
99-
ani.save('maze.gif', writer='ffmpeg', fps=10)
97+
ani.save(file_name, writer='ffmpeg', fps=5)
10098

10199
class Grid_Cell_Maze_Environment(Maze_Environment):
102100
def __init__(self, width, height):
@@ -132,7 +130,7 @@ def state_to_grid_cell_spikes(self, cell):
132130
target_net = DQN(input_size, n_actions).to(device)
133131
target_net.load_state_dict(policy_net.state_dict())
134132
optimizer = optim.AdamW(policy_net.parameters(), lr=lr, amsgrad=True)
135-
memory = ReplayMemory(10000)
133+
memory = ReplayMemory(1000)
136134
env = Grid_Cell_Maze_Environment(width=5, height=5)
137135

138136
run_episode(env, policy_net, 'cpu', 100, eps=0.9)

scripts/Chris/DQN/Eval.ipynb

Lines changed: 9 additions & 48 deletions
Large diffs are not rendered by default.

scripts/Chris/DQN/pipeline_executor.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@
1212
## Constants ##
1313
WIDTH = 5
1414
HEIGHT = 5
15-
SAMPLES_PER_POS = 5000
15+
SAMPLES_PER_POS = 10
1616
NOISE = 0.1 # Noise in sampling
1717
WINDOW_FREQ = 10
1818
WINDOW_SIZE = 10
1919
NUM_CELLS = 20
20-
X_RANGE = (0, 5)
21-
Y_RANGE = (0, 5)
20+
X_RANGE = (0, WIDTH)
21+
Y_RANGE = (0, HEIGHT)
2222
SIM_TIME = 50
2323
MAX_SPIKE_FREQ = 0.8
2424
GC_MULTIPLES = 1
2525
EXC_SIZE = 250
2626
INH_SIZE = 50
27-
STORE_SAMPLES = 100
27+
STORE_SAMPLES = 0
2828
WINDOW_FREQ = 10
2929
WINDOW_SIZE = 10
3030
OUT_DIM = 2
@@ -66,28 +66,30 @@
6666
# # Spike Train Generation ##
6767
# spike_trains, labels, sorted_spike_trains = spike_train_generator(samples, labels, SIM_TIME, GC_MULTIPLES, MAX_SPIKE_FREQ)
6868
#
69-
# # ## Association (Store) ##
69+
# ## Association (Store) ##
7070
# store_reservoir(EXC_SIZE, INH_SIZE, STORE_SAMPLES, NUM_CELLS, GC_MULTIPLES, SIM_TIME, hyper_params, PLOT)
7171
#
7272
# # ## Association (Recall) ##
7373
# recall_reservoir(EXC_SIZE, INH_SIZE, SIM_TIME, PLOT)
7474
#
7575
# # Preprocess Recalls ##
76-
# recalled_mem_preprocessing(WINDOW_FREQ, WINDOW_SIZE, PLOT)
76+
# recalled_mem_preprocessing(WIDTH, HEIGHT, PLOT)
7777

7878
## Train DQN ##
7979
LR = 0.01
8080
EPS_START = 0.9
8181
EPS_END = 0.05
82-
EPS_DECAY = 1000
82+
DECAY_INTENSITY = 3 # higher value -> faster epsilon decay over the training run
8383
TAU = 0.005
8484
GAMMA = 0.99
85-
MAX_STEPS_PER_EP = 10
86-
MAX_TOTAL_STEPS = 10
87-
MAX_EPS = 3000
88-
BATCH_SIZE = 128
85+
MAX_STEPS_PER_EP = 100
86+
MAX_TOTAL_STEPS = 15000
87+
MAX_EPS = 500
88+
BATCH_SIZE = 256
8989
INPUT_SIZE = EXC_SIZE + INH_SIZE
90-
train_DQN(INPUT_SIZE, LR, BATCH_SIZE, EPS_START, EPS_END, EPS_DECAY, TAU, GAMMA, MAX_STEPS_PER_EP, MAX_TOTAL_STEPS, MAX_EPS)
90+
train_DQN(INPUT_SIZE, WIDTH, HEIGHT, LR, BATCH_SIZE, EPS_START,
91+
EPS_END, DECAY_INTENSITY, TAU, GAMMA, MAX_STEPS_PER_EP,
92+
MAX_TOTAL_STEPS, MAX_EPS, PLOT)
9193

9294
## Train ANN ##
9395
# classify_recalls(OUT_DIM, TRAIN_RATIO, BATCH_SIZE, TRAIN_EPOCHS)

scripts/Chris/DQN/recall_reservoir.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,25 @@ def recall_reservoir(exc_size, inh_size, sim_time, plot=False):
3232
pkl.dump(recalled_memories_sorted, f)
3333

3434
# Plot recalls
35-
if plot:
36-
positions = np.array([key for key in recalled_memories_sorted.keys()])
37-
rand_inds = np.random.choice(range(len(positions)), 5)
38-
for pos in positions[rand_inds]:
39-
fig = plt.figure(figsize=(10, 3))
40-
gs = fig.add_gridspec(1, 6)
41-
ax1 = fig.add_subplot(gs[0, 0])
42-
ax1.set_title(f"Position: {pos}")
43-
avg_mem = np.mean(recalled_memories_sorted[tuple(pos)], axis=0)
44-
ax1.imshow(avg_mem.T)
45-
random_inds = np.random.choice(range(len(recalled_memories_sorted[tuple(pos)])), 5)
46-
random_samples = np.array(recalled_memories_sorted[tuple(pos)])[random_inds]
47-
vmin = np.min(random_samples)
48-
vmax = np.max(random_samples)
49-
for i in range(1, 5):
50-
ax = fig.add_subplot(gs[0, i])
51-
rand_sample = recalled_memories_sorted[tuple(pos)][random_inds[i]]
52-
im = ax.imshow(np.expand_dims(rand_sample.T, axis=1).squeeze(), vmin=vmin, vmax=vmax)
53-
ax.set_title(f"S{i}")
54-
ax.set(xticklabels=[])
55-
ax.set(yticklabels=[])
56-
plt.show()
35+
# if plot:
36+
# positions = np.array([key for key in recalled_memories_sorted.keys()])
37+
# rand_inds = np.random.choice(range(len(positions)), 5)
38+
# for pos in positions[rand_inds]:
39+
# fig = plt.figure(figsize=(10, 3))
40+
# gs = fig.add_gridspec(1, 6)
41+
# ax1 = fig.add_subplot(gs[0, 0])
42+
# ax1.set_title(f"Position: {pos}")
43+
# avg_mem = np.mean(recalled_memories_sorted[tuple(pos)], axis=0)
44+
# ax1.imshow(avg_mem.T)
45+
# random_inds = np.random.choice(range(len(recalled_memories_sorted[tuple(pos)])), 5)
46+
# random_samples = np.array(recalled_memories_sorted[tuple(pos)])[random_inds]
47+
# vmin = np.min(random_samples)
48+
# vmax = np.max(random_samples)
49+
# for i in range(1, 5):
50+
# ax = fig.add_subplot(gs[0, i])
51+
# rand_sample = recalled_memories_sorted[tuple(pos)][random_inds[i]]
52+
# im = ax.imshow(np.expand_dims(rand_sample.T, axis=1).squeeze(), vmin=vmin, vmax=vmax)
53+
# ax.set_title(f"S{i}")
54+
# ax.set(xticklabels=[])
55+
# ax.set(yticklabels=[])
56+
# plt.show()

scripts/Chris/DQN/recalled_mem_preprocessing.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55

6-
def recalled_mem_preprocessing(window_freq, window_size, plot):
6+
def recalled_mem_preprocessing(width, height, plot):
77
print('Preprocessing recalled memories...')
88

99
## Load recalled memory spike-trains ##
@@ -46,21 +46,9 @@ def recalled_mem_preprocessing(window_freq, window_size, plot):
4646
pkl.dump(new_samples_sorted, f)
4747

4848
if plot:
49-
# positions = np.array([key for key in new_samples_sorted.keys()])
50-
# fig = plt.figure(figsize=(10, 10))
51-
# gs = fig.add_gridspec(nrows=5, ncols=5)
52-
# for i, pos in enumerate(positions):
53-
# ax = fig.add_subplot(gs[int(pos[0]), int(pos[1])])
54-
# avg_mem = np.mean(new_samples_sorted[tuple(pos)], axis=0)
55-
# ax.set_title(f"Conf-Mat: {pos[0] * 5 + pos[1]}")
56-
# im = ax.imshow(np.expand_dims(avg_mem, axis=0))
57-
# ax.set_aspect('auto')
58-
# plt.tight_layout()
59-
# plt.show()
60-
6149
positions = np.array([key for key in new_samples_sorted.keys()])
62-
fig = plt.figure(figsize=(10, 10))
63-
gs = fig.add_gridspec(nrows=5, ncols=5)
50+
fig = plt.figure(figsize=(50, 50))
51+
gs = fig.add_gridspec(nrows=width, ncols=height)
6452
for i, pos in enumerate(positions):
6553
ax = fig.add_subplot(gs[int(pos[0]), int(pos[1])])
6654
avg_mem = np.mean(recalled_memories_sorted[tuple(pos)], axis=0)

scripts/Chris/DQN/train_DQN.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,26 +135,32 @@ def run_episode(env, policy_net, device, max_steps, eps=0):
135135

136136

137137

138-
def train_DQN(input_size, lr, batch_size, eps_start, eps_end, eps_decay, tau, gamma, max_steps_per_ep, max_total_steps, max_eps):
138+
def train_DQN(input_size, env_width, env_height, lr, batch_size, eps_start,
139+
eps_end, decay_intensity, tau, gamma, max_steps_per_ep, max_total_steps, max_eps, plot):
139140
device = 'cpu'
140141
n_actions = 4
141142
policy_net = DQN(input_size, n_actions).to(device)
142143
target_net = DQN(input_size, n_actions).to(device)
143144
target_net.load_state_dict(policy_net.state_dict())
144145
optimizer = optim.AdamW(policy_net.parameters(), lr=lr, amsgrad=True)
145-
memory = ReplayMemory(10000)
146-
env = Grid_Cell_Maze_Environment(width=5, height=5)
146+
memory = ReplayMemory(1000)
147+
env = Grid_Cell_Maze_Environment(width=env_width, height=env_height)
148+
149+
## Pre-training recording ##
150+
if plot:
151+
run_episode(env, policy_net, device, 100, eps=0.9)
152+
env.animate_history("pre_training.gif")
147153

148154
episode_durations = []
149155
episodes = 0
150156
total_steps = 0
151157
print(env.maze)
152-
while total_steps < max_total_steps and episodes < max_eps:
158+
while total_steps < max_total_steps: # and episodes < max_eps:
153159
# Initialize the environment and get its state
154160
state, info = env.reset()
155161
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
156162
for t in count():
157-
eps = eps_end + (eps_start - eps_end) * math.exp(-1. * total_steps / eps_decay)
163+
eps = eps_end + (eps_start - eps_end) * math.exp(-decay_intensity * total_steps / (max_total_steps))
158164
action = select_action(state, t, eps, policy_net, env)
159165
observation, reward, terminated, _ = env.step(action.item())
160166
reward = torch.tensor([reward], device=device)
@@ -188,5 +194,15 @@ def train_DQN(input_size, lr, batch_size, eps_start, eps_end, eps_decay, tau, ga
188194
print(f"Episode {episodes} lasted {t + 1} steps, eps = {round(eps, 2)} total steps = {total_steps}")
189195
episodes += 1
190196

191-
plt.plot(episode_durations)
192-
plt.show()
197+
## Post-training recording ##
198+
if plot:
199+
env.reset()
200+
run_episode(env, policy_net, device, 100, eps=0) # eps = 0 -> no exploration
201+
env.animate_history("post_training.gif")
202+
plt.clf()
203+
204+
plt.plot(episode_durations)
205+
plt.title("Episode durations")
206+
plt.ylabel("Duration")
207+
plt.xlabel("Episode")
208+
plt.show()

0 commit comments

Comments (0)