-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFspaceEnv.jl
More file actions
113 lines (99 loc) · 3.1 KB
/
FspaceEnv.jl
File metadata and controls
113 lines (99 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
export FspaceEnv
"""
    FspaceEnvParams{T}

Immutable configuration for `FspaceEnv`.

# Fields
- `accuracy::T`
- `statesize::T`
- `embeddingsize::Int`
- `max_steps::Int`

Defaults are supplied by the keyword constructor `FspaceEnvParams{T}(; ...)`:
`accuracy = 0.7`, `statesize = 1`, `embeddingsize = 10`, `max_steps = 1e5`
(values are converted to the field types on construction).
"""
struct FspaceEnvParams{T}
    # Field declarations need `::`; a single `:` is not a valid field declaration.
    accuracy::T
    statesize::T
    embeddingsize::Int
    max_steps::Int
end
# Compact one-line display: `name=value` for every field, comma-separated, in the
# order the fields are declared.
function Base.show(io::IO, params::FspaceEnvParams)
    entries = (string(f, "=", getfield(params, f)) for f in fieldnames(FspaceEnvParams))
    print(io, join(entries, ","))
    return nothing
end
# Keyword constructor: every setting is optional. Values are forwarded positionally
# (in field order) to the default constructor, which converts them to the declared
# field types (so the `1e5` default for `max_steps` is stored as an `Int`).
FspaceEnvParams{T}(;
    accuracy=0.7,
    statesize=1,
    embeddingsize=10,
    max_steps=1e5,
) where {T} = FspaceEnvParams{T}(accuracy, statesize, embeddingsize, max_steps)
# Mutable environment object.
# - `params`: configuration (`FspaceEnvParams{T}`).
# - `state`: state vector; the keyword constructor below allocates it with length 4.
# - `action`: most recent action; `ACT` is `T` when built with `continuous = true`,
#   `Int` otherwise.
# - `done`: terminal flag.
# - `t`: integer counter, initialized to 0 by the keyword constructor.
# - `rng`: random number generator owned by this environment.
# NOTE(review): `rng::AbstractRNG` is an abstract field type, so reads of `env.rng`
# are not type-stable; consider a type parameter if this becomes a hot path.
mutable struct FspaceEnv{T,ACT} <: AbstractEnv
    params::FspaceEnvParams{T}
    state::Vector{T}
    action::ACT
    done::Bool
    t::Int
    rng::AbstractRNG
end

"""
    FspaceEnv(; T=Float64, continuous=false, rng=Random.GLOBAL_RNG, kwargs...)

Build an `FspaceEnv`, call `reset!` on it, and return it.

# Keyword arguments
- `T = Float64`: element type of the state vector and of the parameters.
- `continuous = false`: if `true` the stored action has type `T`, otherwise `Int`.
- `rng = Random.GLOBAL_RNG`: random number generator stored in `env.rng`.
- Any remaining keywords are forwarded to `FspaceEnvParams{T}`: `accuracy = 0.7`,
  `statesize = 1`, `embeddingsize = 10`, `max_steps = 1e5`.
"""
function FspaceEnv(; T=Float64, continuous=false, rng=Random.GLOBAL_RNG, kwargs...)
    params = FspaceEnvParams{T}(; kwargs...)
    # NOTE(review): the state length 4 is hard-coded here rather than derived from
    # `params`; confirm this is the intended state dimension.
    initial_state = zeros(T, 4)
    # The zero action fixes the `ACT` type parameter (continuous => T, discrete => Int).
    initial_action = continuous ? zero(T) : zero(Int)
    env = FspaceEnv(params, initial_state, initial_action, false, 0, rng)
    reset!(env)
    return env
end
# Constructor for a partially specified type, e.g. `FspaceEnv{Float32}(; accuracy=0.9)`:
# forwards `T` plus all caller keywords to the keyword constructor. The `;` is
# required; without it `kwargs...` is splatted as positional `Pair` arguments and
# the call fails with a MethodError whenever any keyword is given.
FspaceEnv{T}(; kwargs...) where {T} = FspaceEnv(; T=T, kwargs...)
# Seed the environment's own RNG (not the global one).
function Random.seed!(env::FspaceEnv, seed)
    return Random.seed!(env.rng, seed)
end

# Reward: `one(T)` while the episode is running, `zero(T)` once `env.done` is set.
function RLBase.reward(env::FspaceEnv{T}) where {T}
    return ifelse(env.done, zero(T), one(T))
end

# Terminal flag, as stored on the environment.
function RLBase.is_terminated(env::FspaceEnv)
    return env.done
end

# Current state vector; the stored array itself is returned (no copy).
function RLBase.state(env::FspaceEnv)
    return env.state
end
# State space as a product of four intervals, one per state component:
#   1. [-2 * statesize, 2 * statesize]
#   2. unbounded over `T` (typemin(T) .. typemax(T))
#   3. [-2 * embeddingsize, 2 * embeddingsize]
#   4. unbounded over `T`
# Reads `params.statesize`: `FspaceEnvParams` has no `spacesize` field, so the
# previous `env.params.spacesize` access always threw.
function RLBase.state_space(env::FspaceEnv{T}) where {T}
    statesize = env.params.statesize
    embeddingsize = env.params.embeddingsize
    unbounded = typemin(T) .. typemax(T)
    # Parentheses matter: `×` binds tighter than `..`.
    return ((-2 * statesize) .. (2 * statesize)) ×
           unbounded ×
           ((-2 * embeddingsize) .. (2 * embeddingsize)) ×
           unbounded
end
# Discrete variant (integer action type): actions are 1 and 2.
function RLBase.action_space(env::FspaceEnv{<:AbstractFloat,Int}, player)
    return Base.OneTo(2)
end

# Continuous variant (floating-point action type): the interval -1.0 .. 1.0.
function RLBase.action_space(env::FspaceEnv{<:AbstractFloat,<:AbstractFloat}, player)
    return -1.0 .. 1.0
end
# Reinitialize `env` in place for a new episode and return `nothing`.
# - Each state entry is redrawn as `T(0.1) * u - T(0.05)` with `u = rand(env.rng, T)`,
#   i.e. approximately uniform on [-0.05, 0.05).
# - The counter `t` is reset to 0 and the terminal flag is cleared.
# - `env.action` is set to a random element of `action_space(env)`.
# Must dispatch on `FspaceEnv`: the `FspaceEnv` keyword constructor calls
# `reset!(env)` before returning, and a `CartPoleEnv` method never matches it.
function RLBase.reset!(env::FspaceEnv{T}) where {T}
    # Size the draw from the stored state instead of hard-coding its length.
    env.state .= T(0.1) .* rand(env.rng, T, length(env.state)) .- T(0.05)
    env.t = 0
    env.action = rand(env.rng, action_space(env))
    env.done = false
    return nothing
end
# Continuous action: validate it against the action space, record it on the
# environment, then advance one step using the action value unchanged.
# NOTE(review): dispatches on `CartPoleEnv`, not `FspaceEnv`; this looks like a
# leftover from a CartPole template — confirm the intended environment type.
function (env::CartPoleEnv)(a::AbstractFloat)
    @assert a in action_space(env)
    env.action = a
    return _step!(env, a)
end
# Discrete action: validate it against the action space, record it, then advance
# one step with action 2 mapped to +1 and any other valid action to -1.
# NOTE(review): dispatches on `CartPoleEnv`, not `FspaceEnv`; confirm the intended
# environment type.
function (env::CartPoleEnv)(a::Int)
    @assert a in action_space(env)
    env.action = a
    direction = ifelse(a == 2, 1, -1)
    return _step!(env, direction)
end
# Advance `env` by one time step under the scalar action `a` and return `nothing`.
# The functor methods in this file pass either the raw continuous action or +1/-1.
#
# NOTE(review): this body reads `forcemag`, `polemasslength`, `totalmass`, `gravity`,
# `halflength`, `masspole`, `dt`, `xthreshold` and `thetathreshold` from `env.params`.
# None of these exist on `FspaceEnvParams`, and the method dispatches on
# `CartPoleEnv`. It appears to be cart-pole dynamics copied from a template; confirm
# the intended environment and dynamics before relying on it.
function _step!(env::CartPoleEnv, a)
    env.t += 1
    # Action scaled into a force.
    force = a * env.params.forcemag
    # Unpack the current state; names suggest cart position/velocity and pole
    # angle/angular velocity (assumed from naming — TODO confirm).
    x, xdot, theta, thetadot = env.state
    costheta = cos(theta)
    sintheta = sin(theta)
    # Shared intermediate term used by both accelerations below.
    tmp = (force + env.params.polemasslength * thetadot^2 * sintheta) / env.params.totalmass
    # Angular acceleration.
    thetaacc =
        (env.params.gravity * sintheta - costheta * tmp) / (
            env.params.halflength *
            (4 / 3 - env.params.masspole * costheta^2 / env.params.totalmass)
        )
    # Linear acceleration, depending on the angular acceleration just computed.
    xacc = tmp - env.params.polemasslength * thetaacc * costheta / env.params.totalmass
    # Explicit Euler update: every increment uses values read before any state entry
    # is modified (xdot/thetadot from the unpacked state, accelerations from above).
    env.state[1] += env.params.dt * xdot
    env.state[2] += env.params.dt * xacc
    env.state[3] += env.params.dt * thetadot
    env.state[4] += env.params.dt * thetaacc
    # Episode ends when component 1 or 3 leaves its threshold band, or when the
    # step counter exceeds `max_steps`.
    env.done =
        abs(env.state[1]) > env.params.xthreshold ||
        abs(env.state[3]) > env.params.thetathreshold ||
        env.t > env.params.max_steps
    nothing
end