1+ import random
2+ import numpy as np
13from labyrinth .generate import DepthFirstSearchGenerator
24from labyrinth .grid import Cell , Direction
35from labyrinth .maze import Maze
46from labyrinth .solve import MazeSolver
7+ from matplotlib .pyplot import plot as plt
8+ from matplotlib .animation import FuncAnimation
9+
510import pickle as pkl
611import matplotlib .pyplot as plt
12+ from torch import optim
713
814class Maze_Environment ():
915 def __init__ (self , width , height ):
@@ -17,6 +23,8 @@ def __init__(self, width, height):
1723 self .maze .path = self .path # No idea why this is necessary
1824 self .agent_cell = self .maze .start_cell
1925 self .num_actions = 4
26+ self .history = [(self .agent_cell .coordinates , 0 , False , {})] # (state, reward, done, info)
27+ self .pos_history = []
2028
2129 def plot (self ):
2230 # Box around maze
@@ -29,12 +37,12 @@ def plot(self):
2937 for row in range (self .height ):
3038 for column in range (self .width ):
3139 # Path
32- cell = self [column , row ] # Tranpose maze coordinates (just how the maze is stored)
33- if cell == self .start_cell :
40+ cell = self . maze [column , row ] # Tranpose maze coordinates (just how the maze is stored)
41+ if cell == self .maze . start_cell :
3442 plt .plot (row , column , 'go' )
35- elif cell == self .end_cell :
43+ elif cell == self .maze . end_cell :
3644 plt .plot (row , column ,'bo' )
37- elif cell in self .path :
45+ elif cell in self .maze . path :
3846 plt .plot (row , column , 'ro' )
3947
4048 # Walls
@@ -44,16 +52,11 @@ def plot(self):
4452 plt .plot ([row + 0.5 , row + 0.5 ], [column - 0.5 , column + 0.5 ], color = 'black' )
4553
4654 def reset (self ):
47- # self.maze = Maze(width=self.width, height=self.height, generator=DepthFirstSearchGenerator())
48- # self.solver = MazeSolver()
49- # self.path = self.solver.solve(self.maze)
50- # self.maze.path = self.path # No idea why this is necessary
51- # self.agent_cell = self.maze.start_cell
52- # return self.agent_cell, {}
5355 self .agent_cell = self .maze .start_cell
56+ self .step_history = []
57+ self .pos_history = []
5458 return self .agent_cell , {}
5559
56-
5760 # Takes action
5861 # Returns next state, reward, done, info
5962 def step (self , action ):
@@ -69,29 +72,68 @@ def step(self, action):
6972
7073 # Check if action runs into wall
7174 if action not in self .agent_cell .open_walls :
75+ self .history .append ((self .agent_cell .coordinates , - 0.5 , False , {}))
7276 return self .agent_cell , - .5 , False , {}
7377
7478 # Move agent
7579 else :
7680 self .agent_cell = self .maze .neighbor (self .agent_cell , action )
7781 if self .agent_cell == self .maze .end_cell : # Check if agent has reached the end
82+ self .history .append (self .agent_cell .coordinates , 1 , True , {})
7883 return self .agent_cell , 1 , True , {}
7984 else :
85+ self .history .append ((self .agent_cell .coordinates , 0 , False , {}))
8086 return self .agent_cell , 0 , False , {}
8187
8288 def save (self , filename ):
8389 with open (filename , 'wb' ) as f :
8490 pkl .dump (self , f )
8591
92+ def animate_history (self ):
93+ def update (i ):
94+ plt .clf ()
95+ self .plot ()
96+ plt .plot (self .history [i ][0 ][1 ], self .history [i ][0 ][0 ], 'yo' )
97+ plt .title (f'Step { i } , Reward: { self .history [i ][1 ]} ' )
98+ ani = FuncAnimation (plt .gcf (), update , frames = len (self .history ), repeat = False )
99+ ani .save ('maze.gif' , writer = 'ffmpeg' , fps = 10 )
100+
class Grid_Cell_Maze_Environment(Maze_Environment):
    """Maze environment whose observations are recorded spike trains.

    Each maze cell (state) is presented to the agent as a randomly chosen
    spike-train sample recorded for that cell's coordinates, suitable for
    spiking-network agents.
    """

    def __init__(self, width, height,
                 samples_file='Data/preprocessed_recalls_sorted.pkl'):
        """Build the maze and load spike-train samples.

        Args:
            width: maze width in cells.
            height: maze height in cells.
            samples_file: path to a pickled dict mapping position tuples to
                lists of spike trains ({position: [spike_trains]}). Defaults
                to the previously hard-coded path for backward compatibility.
        """
        super().__init__(width, height)
        # NOTE(review): pickle.load executes arbitrary code if the file is
        # untrusted -- only load data files from trusted sources.
        with open(samples_file, 'rb') as f:
            self.samples = pkl.load(f)

    def reset(self):
        """Reset the maze; return (spike-train observation, info dict)."""
        cell, info = super().reset()
        return self.state_to_grid_cell_spikes(cell), info

    def step(self, action):
        """Step the maze; convert the next cell to a spike-train observation."""
        obs, reward, done, info = super().step(action)
        return self.state_to_grid_cell_spikes(obs), reward, done, info

    def state_to_grid_cell_spikes(self, cell):
        """Sample one recorded spike train for ``cell``'s coordinates."""
        return random.choice(self.samples[cell.coordinates])
87121
88122
if __name__ == '__main__':
    # NOTE(review): these two imports reach what looks like the same training
    # module via different paths ('train_DQN' vs 'scripts.Chris.DQN.train_DQN');
    # confirm which path is correct for this project layout and unify them.
    from train_DQN import DQN, ReplayMemory
    from scripts.Chris.DQN.train_DQN import run_episode

    # Hyperparameters for a small smoke-test run.
    device = 'cpu'
    n_actions = 4
    input_size = 300
    lr = 0.01

    policy_net = DQN(input_size, n_actions).to(device)
    target_net = DQN(input_size, n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    optimizer = optim.AdamW(policy_net.parameters(), lr=lr, amsgrad=True)
    memory = ReplayMemory(10000)
    env = Grid_Cell_Maze_Environment(width=5, height=5)

    # BUG FIX (consistency): pass the `device` variable instead of the
    # hard-coded 'cpu' literal so that changing `device` above takes effect.
    run_episode(env, policy_net, device, 100, eps=0.9)
    env.animate_history()
0 commit comments