@@ -24,7 +24,6 @@ def __init__(self, width, height):
2424 self .agent_cell = self .maze .start_cell
2525 self .num_actions = 4
2626 self .history = [(self .agent_cell .coordinates , 0 , False , {})] # (state, reward, done, info)
27- self .pos_history = []
2827
2928 def plot (self ):
3029 # Box around maze
@@ -53,8 +52,7 @@ def plot(self):
5352
5453 def reset (self ):
5554 self .agent_cell = self .maze .start_cell
56- self .step_history = []
57- self .pos_history = []
55+ self .history = [(self .agent_cell .coordinates , 0 , False , {})]
5856 return self .agent_cell , {}
5957
6058 # Takes action
@@ -72,14 +70,14 @@ def step(self, action):
7270
7371 # Check if action runs into wall
7472 if action not in self .agent_cell .open_walls :
75- self .history .append ((self .agent_cell .coordinates , - 0.5 , False , {}))
76- return self .agent_cell , - .5 , False , {}
73+ self .history .append ((self .agent_cell .coordinates , - 0.1 , False , {}))
74+ return self .agent_cell , - 0.01 , False , {}
7775
7876 # Move agent
7977 else :
8078 self .agent_cell = self .maze .neighbor (self .agent_cell , action )
8179 if self .agent_cell == self .maze .end_cell : # Check if agent has reached the end
82- self .history .append (self .agent_cell .coordinates , 1 , True , {})
80+ self .history .append (( self .agent_cell .coordinates , 1 , True , {}) )
8381 return self .agent_cell , 1 , True , {}
8482 else :
8583 self .history .append ((self .agent_cell .coordinates , 0 , False , {}))
@@ -89,14 +87,14 @@ def save(self, filename):
8987 with open (filename , 'wb' ) as f :
9088 pkl .dump (self , f )
9189
92- def animate_history (self ):
90+ def animate_history (self , file_name = 'maze.gif' ):
9391 def update (i ):
9492 plt .clf ()
9593 self .plot ()
9694 plt .plot (self .history [i ][0 ][1 ], self .history [i ][0 ][0 ], 'yo' )
9795 plt .title (f'Step { i } , Reward: { self .history [i ][1 ]} ' )
9896 ani = FuncAnimation (plt .gcf (), update , frames = len (self .history ), repeat = False )
99- ani .save ('maze.gif' , writer = 'ffmpeg' , fps = 10 )
97+ ani .save (file_name , writer = 'ffmpeg' , fps = 5 )
10098
10199class Grid_Cell_Maze_Environment (Maze_Environment ):
102100 def __init__ (self , width , height ):
@@ -132,7 +130,7 @@ def state_to_grid_cell_spikes(self, cell):
132130 target_net = DQN (input_size , n_actions ).to (device )
133131 target_net .load_state_dict (policy_net .state_dict ())
134132 optimizer = optim .AdamW (policy_net .parameters (), lr = lr , amsgrad = True )
135- memory = ReplayMemory (10000 )
133+ memory = ReplayMemory (1000 )
136134 env = Grid_Cell_Maze_Environment (width = 5 , height = 5 )
137135
138136 run_episode (env , policy_net , 'cpu' , 100 , eps = 0.9 )
0 commit comments