I am trying to integrate RL.jl with Dojo.jl in dojo-sim/Dojo.jl#9. However, the current MultiThreadEnv wrapper fails to work with the BoxSpace defined there, even though Base.in, Base.length, and Random.rand are defined for it. The issue seems to be the use of selectdim in the current implementation.
The wrapper's states are a Vector of Vectors:
N_ENV = 2
env_vec = [Dojo.DojoRLEnv("cartpole") for i in 1:N_ENV]
env = MultiThreadEnv(env_vec)
env.states
results in:
julia> env.states
2-element Vector{Vector{Float64}}:
[0.8223080697554248, 1.6284598982949312, 0.464146354303141, -0.2796568159199919]
[0.0, 0.0, 3.141592653589793, 0.0]
However, selectdim(env.states, 1, 1) gives a 0-dimensional view:
julia> selectdim(env.states, 1, 1)
0-dimensional view(::Vector{Vector{Float64}}, 1) with eltype Vector{Float64}:
[0.8223080697554248, 1.6284598982949312, 0.464146354303141, -0.2796568159199919]
which cannot take a vector observation through broadcast assignment, as the current implementation attempts:
julia> selectdim(env.states, 1, 1) .= randn(4)
ERROR: DimensionMismatch("cannot broadcast array to have fewer non-singleton dimensions")
Stacktrace:
[1] check_broadcast_shape
@ ./broadcast.jl:535 [inlined]
[2] check_broadcast_axes
@ ./broadcast.jl:543 [inlined]
[3] instantiate
@ ./broadcast.jl:284 [inlined]
[4] materialize!
@ ./broadcast.jl:871 [inlined]
[5] materialize!(dest::SubArray{Vector{Float64}, 0, Vector{Vector{Float64}}, Tuple{Int64}, true}, bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(identity), Tuple{Vector{Float64}}})
@ Base.Broadcast ./broadcast.jl:868
[6] top-level scope
@ REPL[28]:1
[7] top-level scope
@ ~/.julia/packages/CUDA/Axzxe/src/initialization.jl:52
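For reference, the failure can be reproduced without Dojo or RL.jl. The sketch below uses only Base Julia; the names `states` and `obs` are hypothetical stand-ins for `env.states` and a new observation, and the two assignments at the end are possible workarounds for illustration, not what the package currently does.

```julia
# Stand-in for env.states: a Vector of Vectors, one inner vector per environment.
states = [zeros(4) for _ in 1:2]
obs = [1.0, 2.0, 3.0, 4.0]

# selectdim on a 1-dimensional container drops the only dimension,
# so the result is a 0-dimensional view whose single element is the inner vector.
v = selectdim(states, 1, 1)
@show ndims(v)

# Broadcasting a length-4 vector into a 0-dimensional destination fails.
threw = try
    v .= obs
    false
catch err
    err isa DimensionMismatch
end
@show threw

# Possible workaround 1: replace the element of the outer vector directly.
states[1] = copy(obs)

# Possible workaround 2: broadcast into the inner vector itself (updates it in place).
states[2] .= obs

@show states
```

With a Matrix of states, selectdim along the environment dimension returns a 1-dimensional view, and the same broadcast assignment works, which may be why the current code only breaks for nested vectors.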