Skip to content

Commit efc2aaf

Browse files
committed
Updated README with updated embedded GIFs
1 parent 4719cac commit efc2aaf

5 files changed

Lines changed: 7 additions & 6 deletions

File tree

record_all_gifs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ def record_one_gif(algo, rep_name, n_episodes=3, max_steps_per_ep=300, fps=10, s
7272
pygame.font.init()
7373
surface = pygame.Surface((grid_w, grid_w + panel_h))
7474

75-
label = f"{algo}_sarsa + {rep_name}"
75+
algo_label = "double_dqn" if algo == "mlp" else f"{algo}_sarsa"
76+
label = f"{algo_label} + {rep_name}"
7677
frames = []
7778

7879
for ep in range(n_episodes):
@@ -92,7 +93,6 @@ def record_one_gif(algo, rep_name, n_episodes=3, max_steps_per_ep=300, fps=10, s
9293
pygame.quit()
9394

9495
if frames:
95-
algo_label = "double_dqn" if algo == "mlp" else f"{algo}_sarsa"
9696
filepath = os.path.join(RECORDINGS_DIR, f"{algo_label}__{rep_name}.gif")
9797
frames[0].save(
9898
filepath,

record_gameplay.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,8 @@ def record_gif(algo, rep_name, n_episodes=3, max_steps_per_ep=200, fps=10,
253253
panel_h = 60
254254
surface = pygame.Surface((grid_w, grid_w + panel_h))
255255

256-
name = f"{algo}_sarsa + {rep_name}"
256+
algo_label = "double_dqn" if algo == "mlp" else f"{algo}_sarsa"
257+
name = f"{algo_label} + {rep_name}"
257258
frames = []
258259

259260
pygame.init()
@@ -279,7 +280,6 @@ def record_gif(algo, rep_name, n_episodes=3, max_steps_per_ep=200, fps=10,
279280
pygame.quit()
280281

281282
if frames:
282-
algo_label = "double_dqn" if algo == "mlp" else f"{algo}_sarsa"
283283
filepath = os.path.join(RECORDINGS_DIR, f"{algo_label}__{rep_name}.gif")
284284
frames[0].save(
285285
filepath,
@@ -305,12 +305,13 @@ def watch_live(algo, rep_name, fps=10, weights_dir=WEIGHTS_DIR, seed=None):
305305
grid_w = GRID_SIZE * CELL_SIZE
306306
panel_h = 60
307307
screen = pygame.display.set_mode((grid_w, grid_w + panel_h))
308+
algo_label = "double_dqn" if algo == "mlp" else f"{algo}_sarsa"
308309
seed_label = f" seed={seed}" if seed is not None else ""
309-
pygame.display.set_caption(f"Snake RL — {algo}_sarsa + {rep_name}{seed_label}")
310+
pygame.display.set_caption(f"Snake RL — {algo_label} + {rep_name}{seed_label}")
310311
clock = pygame.time.Clock()
311312

312313
wdir_label = f" [{weights_dir}]" if weights_dir != WEIGHTS_DIR else ""
313-
name = f"{algo}_sarsa + {rep_name}{seed_label}{wdir_label}"
314+
name = f"{algo_label} + {rep_name}{seed_label}{wdir_label}"
314315
obs, _ = env.reset()
315316
running = True
316317
total_episodes = 0

recordings/double_dqn__compact.gif

166 Bytes
Loading
-4 Bytes
Loading

recordings/double_dqn__local.gif

107 Bytes
Loading

0 commit comments

Comments
 (0)