-
Notifications
You must be signed in to change notification settings - Fork 736
Expand file tree
/
Copy path1-policy_iteration.py
More file actions
128 lines (103 loc) · 4.5 KB
/
1-policy_iteration.py
File metadata and controls
128 lines (103 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import random
from env import GraphicDisplay, PolicyEnv as Env
class PolicyIteration:
    """Tabular policy iteration agent for a grid-world environment.

    The agent keeps a value table and a stochastic policy table, both indexed
    as ``table[row][col]`` with ``env.height`` rows and ``env.width`` columns.
    ``policy_evaluation`` performs one synchronous sweep of the Bellman
    expectation backup; ``policy_improvement`` makes the policy greedy with
    respect to the current value table, splitting probability evenly between
    tied actions.

    Expected ``env`` interface (as used below): ``width``, ``height``,
    ``possible_actions`` (integer action indices into a policy entry),
    ``get_all_states()``, ``state_after_action(state, action)`` and
    ``get_reward(state, action)``.
    """

    def __init__(self, env, terminal_state=(2, 2), discount_factor=0.9):
        """Create an agent with zero values and a uniform random policy.

        Args:
            env: grid-world environment (see the class docstring).
            terminal_state: ``(row, col)`` of the absorbing terminal state;
                its value is pinned to 0 and its policy entry is empty.
            discount_factor: gamma used in every Bellman backup.
        """
        self.env = env
        self.terminal_state = list(terminal_state)
        self.discount_factor = discount_factor
        # 2-d list for the value function, all zeros initially.
        self.value_table = [[0.0] * env.width for _ in range(env.height)]
        # Uniform random policy over all actions. Every cell gets its OWN
        # list (no `[...] * width` aliasing), so editing one cell in place
        # can never silently change its neighbours.
        num_actions = len(env.possible_actions)
        self.policy_table = [
            [[1.0 / num_actions] * num_actions for _ in range(env.width)]
            for _ in range(env.height)
        ]
        # The terminal state has no actions.
        row, col = self.terminal_state
        self.policy_table[row][col] = []

    def _is_terminal(self, state):
        """Return True if ``state`` is the configured terminal cell."""
        return list(state) == self.terminal_state

    def policy_evaluation(self):
        """Run one synchronous sweep of iterative policy evaluation.

        For every non-terminal state ``s``::

            V'(s) = sum_a pi(a|s) * (r(s, a) + gamma * V(s'))

        where ``s'`` is the successor from ``env.state_after_action``. New
        values are computed from the old table only, rounded to two
        decimals, and then replace ``value_table``. The terminal state's
        value stays 0.
        """
        next_value_table = [[0.0] * self.env.width
                            for _ in range(self.env.height)]
        for state in self.env.get_all_states():
            if self._is_terminal(state):
                # The terminal state is absorbing: keep its value at 0.
                next_value_table[state[0]][state[1]] = 0.0
                continue
            policy = self.get_policy(state)
            value = 0.0
            for action in self.env.possible_actions:
                next_state = self.env.state_after_action(state, action)
                reward = self.env.get_reward(state, action)
                next_value = self.get_value(next_state)
                value += policy[action] * (
                    reward + self.discount_factor * next_value)
            next_value_table[state[0]][state[1]] = round(value, 2)
        self.value_table = next_value_table

    def policy_improvement(self):
        """Make the policy greedy with respect to the current value table.

        For each non-terminal state, ``q(s, a) = r(s, a) + gamma * V(s')`` is
        computed for every action. The probability mass is split evenly
        among all actions that attain the maximum, so ties keep some
        randomness instead of being broken arbitrarily.

        The policy table is updated in place. This is safe because the
        backup reads only ``value_table``, never the policy.
        """
        num_actions = len(self.env.possible_actions)
        for state in self.env.get_all_states():
            if self._is_terminal(state):
                continue
            # -inf (not a magic constant) guarantees the first action
            # always becomes the initial best.
            best_value = float("-inf")
            best_actions = []
            for index, action in enumerate(self.env.possible_actions):
                next_state = self.env.state_after_action(state, action)
                reward = self.env.get_reward(state, action)
                q_value = (reward
                           + self.discount_factor * self.get_value(next_state))
                # A greedy policy normally picks a single action; here every
                # action tied for the maximum shares the probability.
                if q_value > best_value:
                    best_value = q_value
                    best_actions = [index]
                elif q_value == best_value:
                    best_actions.append(index)
            result = [0.0] * num_actions
            prob = 1 / len(best_actions)
            for index in best_actions:
                result[index] = prob
            self.policy_table[state[0]][state[1]] = result

    def get_action(self, state):
        """Sample an action index according to the policy at ``state``.

        Returns:
            The sampled action index, or ``None`` if ``state`` has no actions
            (the terminal state, or an empty policy entry).
        """
        policy = self.get_policy(state)
        if not policy:
            # Terminal state: get_policy returns 0.0 / the table holds [].
            return None
        # random.choices samples in proportion to the weights and never
        # selects a zero-weight action. Unlike a percent-granularity draw it
        # has no rounding bias (e.g. for 1/3 probabilities), and it cannot
        # fall off the end when the probabilities sum to slightly below 1.
        return random.choices(range(len(policy)), weights=policy)[0]

    def get_policy(self, state):
        """Return the action-probability list for ``state``.

        For the terminal state this returns the sentinel ``0.0`` (as the
        original API did) rather than a list.
        """
        if self._is_terminal(state):
            return 0.0
        return self.policy_table[state[0]][state[1]]

    def get_value(self, state):
        """Return the value of ``state`` rounded to two decimals."""
        return round(self.value_table[state[0]][state[1]], 2)
if __name__ == "__main__":
    env = Env()
    policy_iteration = PolicyIteration(env)
    display = GraphicDisplay(policy_iteration, title="Policy Iteration")

    def unlocked_after(button_name):
        """Build a predicate that is true once ``button_name`` was clicked."""
        return lambda: display.click_count(button_name) > 0

    def on_evaluate():
        # One evaluation sweep, then redraw the value grid.
        policy_iteration.policy_evaluation()
        display.show_values(policy_iteration.value_table)

    def on_improve():
        # Greedy policy update, then redraw the policy arrows.
        policy_iteration.policy_improvement()
        display.show_arrows(policy_iteration.policy_table)

    def on_move():
        # Let the display step the agent using the current policy sampler.
        display.move_along_policy(policy_iteration.get_action)

    def on_reset():
        # Re-initialise the existing agent in place (rather than building a
        # new one) so any reference GraphicDisplay kept to it stays valid —
        # assumption: GraphicDisplay is defined in `env`, not visible here.
        policy_iteration.__init__(env)
        display.clear()
        display.agent_pos = [0, 0]
        display.clicks.clear()

    # Workflow: (Evaluate x several -> Improve) x several -> Move.
    # Improve unlocks after Evaluate; Move unlocks after Improve.
    display.buttons = [
        ("Evaluate", on_evaluate),
        ("Improve", on_improve, unlocked_after("Evaluate")),
        ("Move", on_move, unlocked_after("Improve")),
        ("Reset", on_reset),
    ]
    display.mainloop()