@@ -253,7 +253,8 @@ def record_gif(algo, rep_name, n_episodes=3, max_steps_per_ep=200, fps=10,
253253 panel_h = 60
254254 surface = pygame .Surface ((grid_w , grid_w + panel_h ))
255255
256- name = f"{ algo } _sarsa + { rep_name } "
256+ algo_label = "double_dqn" if algo == "mlp" else f"{ algo } _sarsa"
257+ name = f"{ algo_label } + { rep_name } "
257258 frames = []
258259
259260 pygame .init ()
@@ -279,7 +280,6 @@ def record_gif(algo, rep_name, n_episodes=3, max_steps_per_ep=200, fps=10,
279280 pygame .quit ()
280281
281282 if frames :
282- algo_label = "double_dqn" if algo == "mlp" else f"{ algo } _sarsa"
283283 filepath = os .path .join (RECORDINGS_DIR , f"{ algo_label } __{ rep_name } .gif" )
284284 frames [0 ].save (
285285 filepath ,
@@ -305,12 +305,13 @@ def watch_live(algo, rep_name, fps=10, weights_dir=WEIGHTS_DIR, seed=None):
305305 grid_w = GRID_SIZE * CELL_SIZE
306306 panel_h = 60
307307 screen = pygame .display .set_mode ((grid_w , grid_w + panel_h ))
308+ algo_label = "double_dqn" if algo == "mlp" else f"{ algo } _sarsa"
308309 seed_label = f" seed={ seed } " if seed is not None else ""
309- pygame .display .set_caption (f"Snake RL — { algo } _sarsa + { rep_name } { seed_label } " )
310+ pygame .display .set_caption (f"Snake RL — { algo_label } + { rep_name } { seed_label } " )
310311 clock = pygame .time .Clock ()
311312
312313 wdir_label = f" [{ weights_dir } ]" if weights_dir != WEIGHTS_DIR else ""
313- name = f"{ algo } _sarsa + { rep_name } { seed_label } { wdir_label } "
314+ name = f"{ algo_label } + { rep_name } { seed_label } { wdir_label } "
314315 obs , _ = env .reset ()
315316 running = True
316317 total_episodes = 0
0 commit comments