-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathOregonTrailMini.py
More file actions
136 lines (112 loc) · 4.23 KB
/
OregonTrailMini.py
File metadata and controls
136 lines (112 loc) · 4.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import random
import json
import csv
import os
import matplotlib.pyplot as plt
from RandomAgent import QLearningAgent
class OregonTrailMini:
    """A minimal Oregon-Trail-style survival environment for RL agents.

    The state tracks the day count, distance toward a 500-mile goal, food,
    health, an alive flag, and a per-step event log. Supported actions are
    "rest", "hunt", and "travel"; anything else is logged as invalid and
    only the daily upkeep applies.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restore the game to its starting state."""
        self.state = {
            "day": 0,
            "distance": 0,
            "goal": 500,
            "food": 10,
            "health": 5,
            "alive": True,
            "events": [],
        }

    def _snapshot(self):
        # Independent copy of the state: dict.copy() alone would alias the
        # "events" list, which step() clears in place on the next call —
        # that silently mutated states already handed to the agent.
        snap = self.state.copy()
        snap["events"] = list(self.state["events"])
        return snap

    def step(self, action):
        """Advance the game by one day.

        Parameters:
            action: "rest", "hunt", or "travel" (other values are logged
                as invalid and skip the action phase).

        Returns:
            (state, reward, done, info) gym-style tuple. `state` is an
            independent snapshot; `done` is True once the party dies or
            the goal distance is reached.
        """
        s = self.state
        s["day"] += 1
        s["events"] = []
        reward = 0
        # Daily upkeep: starve (lose 1 health) when the larder is empty,
        # otherwise consume 5 food.
        if s["food"] <= 0:
            s["health"] -= 1
        else:
            s["food"] -= 5
        if action == "rest":
            # Resting may recover 1 health (50% chance), capped at 5.
            if s["health"] < 5:
                if random.random() < 0.5:
                    s["health"] += 1
        elif action == "hunt":
            s["food"] += random.randint(5, 15)  # Food found while hunting
            if random.random() < 0.3:  # Possibility of injury
                s["health"] -= 1
        elif action == "travel":
            dist = random.randint(5, 15)
            s["distance"] += dist
            s["events"].append(f"Traveled {dist} miles.")
            if random.random() < 0.1:
                s["health"] -= 1
        else:
            s["events"].append("Invalid action.")
        if s["health"] <= 0:
            s["alive"] = False
            reward -= 1000
        if s["distance"] >= s["goal"]:
            s["alive"] = False  # Changing state to dead is how the game ends
            s["events"].append("You reached your goal!")
            reward += 5000 - (s["day"] * 7)  # Faster finishes earn more
        return self._snapshot(), reward, not s["alive"], {}

    def get_state(self):
        """Return an independent snapshot of the current state."""
        return self._snapshot()

    def run_agent(self, agent, episodes=1):
        """Train `agent` for `episodes` full games and persist the results.

        Writes one row per episode to the next free numbered CSV under
        trial/ and saves matching training graphs. Does nothing when
        `episodes` is 0 (the original crashed on `results[0]`).
        """
        results = []
        for ep in range(episodes):
            self.reset()
            state = self.get_state()
            done = False
            while not done:
                action = agent.choose_action(state)
                next_state, reward, done, _ = self.step(action)
                agent.learn(state, action, reward, next_state)
                state = next_state
            results.append({
                "episode": ep,
                "survived_days": self.state["day"],
                "distance": self.state["distance"],
                "result": "win" if self.state["distance"] >= self.state["goal"] else "death",
            })
        if not results:
            return  # Nothing to write or plot
        # Write results to the next unused numbered CSV for this run
        os.makedirs("trial", exist_ok=True)
        existing_files = os.listdir("trial")
        i = 0
        while f"training_results_{i:02}.csv" in existing_files:
            i += 1
        filename = f"training_results_{i:02}.csv"
        filepath = os.path.join("trial", filename)
        with open(filepath, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=results[0].keys())
            writer.writeheader()
            writer.writerows(results)
        self.generate_graphs(results, i)

    def generate_graphs(self, results, run_number):
        """Plot rolling-average distance and survival-days curves.

        Saves the figure as trial/graph_<run_number>.png.
        """
        episodes = [r["episode"] for r in results]
        distances = [r["distance"] for r in results]
        days = [r["survived_days"] for r in results]
        # Rolling average to smooth the curves
        window = 500

        def rolling_avg(data):
            # FIX: a window ending at i starts at i-window+1 (inclusive),
            # giving at most `window` items. The original sliced from
            # i-window (window+1 items) while dividing by `window`,
            # skewing every smoothed point past the warm-up region.
            return [
                sum(data[max(0, i - window + 1):i + 1]) / min(i + 1, window)
                for i in range(len(data))
            ]

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
        ax1.plot(episodes, rolling_avg(distances), color='blue')
        ax1.axhline(y=500, color='green', linestyle='--', label='Goal (500 miles)')
        ax1.set_xlabel('Episode')
        ax1.set_ylabel('Distance')
        ax1.set_title('Distance Traveled Over Training')
        ax1.legend()
        ax2.plot(episodes, rolling_avg(days), color='orange')
        ax2.set_xlabel('Episode')
        ax2.set_ylabel('Days Survived')
        ax2.set_title('Survival Days Over Training')
        plt.tight_layout()
        # Use os.path.join for the graph path, consistent with the CSV path
        graph_path = os.path.join("trial", f"graph_{run_number:02}.png")
        plt.savefig(graph_path)
        plt.close()
        print(f"Saved: {graph_path}")
if __name__ == "__main__":
    # Entry point: train a Q-learning agent on the mini trail game.
    learner = QLearningAgent(actions=["rest", "hunt", "travel"])
    trail = OregonTrailMini()
    trail.run_agent(learner, episodes=50000)