Skip to content

Commit 8487bf1

Browse files
committed
Renamed agent-level "MLP" references to "DQN" in code comments and messages.
- run_experiments.py, record_gameplay.py, record_all_gifs.py: print/warning messages and CLI usage examples - save_load.py: module docstring + save/load print messages - train.py: docstring renamed and corrected (was claiming all agents use semi-gradient SARSA — the DQN agent actually uses off-policy Q-learning with replay) - test_double_dqn.py: print label
1 parent efc2aaf commit 8487bf1

6 files changed

Lines changed: 18 additions & 16 deletions

File tree

record_all_gifs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
HAS_TORCH = True
3737
except ImportError:
3838
HAS_TORCH = False
39-
print("WARNING: PyTorch not installed. MLP GIFs will be skipped.")
39+
print("WARNING: PyTorch not installed. DQN GIFs will be skipped.")
4040

4141
GRID_SIZE = 20
4242
CELL_SIZE = 30

record_gameplay.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ def make_agent(algo, rep_name, weights_dir=WEIGHTS_DIR, name=None):
9898
"""
9999
Create a fresh agent for the given algo/rep.
100100
101-
For MLP agents, reads the architecture (hidden_dims) directly from the
102-
checkpoint file so v1 (hidden_dim=128) and v2 (hidden_dims=(256,128))
103-
weights both load correctly without any flags.
101+
For the DQN agent, reads the architecture (hidden_dims) directly from
102+
the checkpoint file so any saved architecture loads correctly without
103+
any flags.
104104
"""
105105
rep = REPRESENTATIONS[rep_name]()
106106

@@ -121,7 +121,7 @@ def make_agent(algo, rep_name, weights_dir=WEIGHTS_DIR, name=None):
121121

122122
elif algo == "mlp":
123123
if not HAS_TORCH:
124-
raise ImportError("PyTorch required for MLP")
124+
raise ImportError("PyTorch required for DQN")
125125
mlp_name = name or weight_name(algo, rep_name)
126126
pt_path = os.path.join(weights_dir, f"{mlp_name}.pt")
127127
if os.path.exists(pt_path):
@@ -392,7 +392,7 @@ def main():
392392
formatter_class=argparse.RawDescriptionHelpFormatter,
393393
epilog="""
394394
examples:
395-
# Watch MLP on compact representation
395+
# Watch DQN on compact representation
396396
python record_gameplay.py --watch mlp compact
397397
398398
# Watch a specific seed

run_experiments.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
3. max_steps_factor=3 (1,200 steps max)
77
4. Larger tile hash table (262,144)
88
5. Epsilon decay over 80% of training
9-
6. Default 20,000 episodes (MLP needs more steps to converge with replay)
9+
6. Default 20,000 episodes (DQN needs more steps to converge with replay)
1010
1111
Usage:
1212
python run_experiments.py # all configs
@@ -42,7 +42,7 @@
4242
HAS_TORCH = True
4343
except ImportError:
4444
HAS_TORCH = False
45-
print("WARNING: PyTorch not installed. MLP experiments will be skipped.")
45+
print("WARNING: PyTorch not installed. DQN experiments will be skipped.")
4646

4747

4848
# ============================================================
@@ -189,7 +189,7 @@ def make_agent(algo: str, rep_instance, config: ExperimentConfig, seed: int, env
189189
return agent
190190
elif algo == "mlp":
191191
if not HAS_TORCH:
192-
raise ImportError("PyTorch required for MLP agent")
192+
raise ImportError("PyTorch required for DQN agent")
193193
p = config.algo_params
194194
return DoubleDQNAgent(
195195
representation=rep_instance,

snake_rl/agents/train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
- Logging (per-episode metrics via RunLogger)
99
- Progress reporting
1010
11-
The same loop is used for linear FA, tile coding, and MLP agents,
12-
since they all use the semi-gradient SARSA update — only the
13-
function approximation differs.
11+
The same loop is used for linear FA, tile coding, and DQN agents.
12+
Linear FA and tile coding agents do a semi-gradient SARSA update on
13+
each step; the DQN agent stores the transition and runs a mini-batch
14+
Double-Q update internally. Only the agent's `.update()` method differs
15+
— the surrounding loop interface is identical.
1416
"""
1517

1618
import time

snake_rl/utils/save_load.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Save and load trained agent weights.
33
4-
Supports all three agent types: Linear FA, Tile Coding, MLP.
4+
Supports all three agent types: Linear FA, Tile Coding, DQN.
55
"""
66

77
import os
@@ -45,7 +45,7 @@ def save_agent(agent, name: str, directory: str = WEIGHTS_DIR):
4545
}, filepath)
4646
print(f"Saved {agent_type} weights to {filepath}")
4747
except ImportError:
48-
print("PyTorch required to save MLP weights")
48+
print("PyTorch required to save DQN weights")
4949

5050
else:
5151
raise ValueError(f"Unknown agent type: {agent_type}")
@@ -79,7 +79,7 @@ def load_agent_weights(agent, name: str, directory: str = WEIGHTS_DIR):
7979
agent.q_net.load_state_dict(checkpoint["q_net"])
8080
print(f"Loaded {agent_type} weights from {filepath}")
8181
except ImportError:
82-
print("PyTorch required to load MLP weights")
82+
print("PyTorch required to load DQN weights")
8383

8484
else:
8585
raise ValueError(f"Unknown agent type: {agent_type}")

tests/test_double_dqn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def test_training_improves(self):
249249
early = np.mean(logger.scores[:300])
250250
late = np.mean(logger.scores[-300:])
251251
assert late > early, f"Scores should improve: early={early:.2f}, late={late:.2f}"
252-
print(f" [MLP] Early: {early:.2f} → Late: {late:.2f}")
252+
print(f" [DQN] Early: {early:.2f} → Late: {late:.2f}")
253253

254254

255255
# ============================================================

0 commit comments

Comments
 (0)