1+ using ReinforcementLearning
2+ using Flux
3+ using Flux. Losses
4+
5+ using Random
6+ using Dojo
7+
"""
    RL.Experiment(::Val{:JuliaRL}, ::Val{:DDPG}, ::Val{:DojoCartpole}, ::Nothing,
                  save_dir = nothing, seed = 42)

Build (but do not run) a DDPG experiment on the Dojo `"cartpole"` environment.

# Arguments
- `save_dir`: optional positional argument; accepted for signature compatibility
  and not used anywhere in this method.
- `seed`: optional positional argument used to seed both the `MersenneTwister`
  shared by the networks/policies and the Dojo environment itself.

# Returns
An `Experiment` bundling the agent, the action-transformed environment, a
`StopAfterStep(10_000)` stop condition and a `TotalRewardPerEpisode` hook.
"""
function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:DDPG},
    ::Val{:DojoCartpole},
    ::Nothing,
    save_dir = nothing,
    seed = 42,
)
    rng = MersenneTwister(seed)
    inner_env = Dojo.DojoRLEnv("cartpole")
    Random.seed!(inner_env, seed)

    # TODO: the action bounds are hard-coded here; presumably they should be
    # derived from the environment instead -- confirm.
    low = -5.0
    high = 5.0

    ns, na = length(state(inner_env)), length(action_space(inner_env))
    # Diagnostic only: visible when debug logging is enabled, silent otherwise.
    @debug "DojoCartpole DDPG action dimension" na

    A = Dojo.BoxSpace(na)
    env = ActionTransformedEnv(
        inner_env;
        # Affine rescale: x = -1 maps to `low`, x = 1 maps to `high`. The actor
        # network ends in `tanh`, so its outputs lie in [-1, 1].
        action_mapping = x -> low .+ (x .+ 1) .* 0.5 .* (high .- low),
        # The agent always sees `A`, whatever space the inner env reports.
        action_space_mapping = _ -> A,
    )

    # One initializer bound to `rng`, shared by every layer of every network.
    init = glorot_uniform(rng)

    # Actor: state (ns) -> action (na), squashed by `tanh`.
    create_actor() = Chain(
        Dense(ns, 30, relu; init = init),
        Dense(30, 30, relu; init = init),
        Dense(30, na, tanh; init = init),
    )
    # Critic: input of size ns + na -> a single scalar output.
    create_critic() = Chain(
        Dense(ns + na, 30, relu; init = init),
        Dense(30, 30, relu; init = init),
        Dense(30, 1; init = init),
    )

    agent = Agent(
        policy = DDPGPolicy(
            behavior_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            behavior_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            target_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            target_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            γ = 0.99f0,
            ρ = 0.995f0,
            na = na,
            batch_size = 64,
            start_steps = 1000,
            start_policy = RandomPolicy(A; rng = rng),
            update_after = 1000,
            update_freq = 1,
            act_limit = 1.0,
            act_noise = 0.1,
            rng = rng,
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 10000,
            state = Vector{Float32} => (ns,),
            action = Float32 => (na,),
        ),
    )

    # Show the progress meter only when the `CI` environment variable is unset.
    stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    return Experiment(agent, env, stop_condition, hook, " # Dojo Cartpole with DDPG")
end
86+
# Construct the experiment via the `E` command-literal macro and run it when
# this file is executed. The name is presumably parsed into
# source/method/environment parts (JuliaRL / DDPG / DojoCartpole) that select
# the matching `RL.Experiment` method -- confirm against ReinforcementLearning.jl.
ex = E`JuliaRL_DDPG_DojoCartpole`
run(ex)
0 commit comments