Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,15 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg,
end
symtype = symbolic_type(_arg)
elsymtype = symbolic_type(eltype(_arg))
# For symbolic indices, `A[sym, :]` is semantically equivalent to `A[sym]`
# (the no-args symbolic getindex already returns the full timeseries).
# Routing through the no-args path here also avoids a broadcast shape bug
# in SymbolicIndexingInterface's `GetStateIndex` when the underlying index
# is a `Vector{Int}` (array-symbolic) combined with a `Colon` time index.
if (symtype != NotSymbolic() || elsymtype != NotSymbolic()) &&
length(args) == 1 && args[1] === Colon()
return A[_arg]
end

return if symtype == NotSymbolic() && elsymtype == NotSymbolic()
if _arg isa Union{Tuple, AbstractArray} &&
Expand Down
3 changes: 3 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
RecursiveArrayToolsShorthandConstructors = "39fb7555-b4ad-4efd-8abe-30331df017d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Expand All @@ -18,6 +20,7 @@ ModelingToolkit = "8.33, 9, 10, 11"
MonteCarloMeasurements = "1.1"
NLsolve = "4"
OrdinaryDiffEq = "6.31, 7"
OrdinaryDiffEqRosenbrock = "1, 2"
StaticArrays = "1"
SymbolicIndexingInterface = "0.3"
Tables = "1"
Expand Down
2 changes: 1 addition & 1 deletion test/downstream/downstream_events.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OrdinaryDiffEq, StaticArrays, RecursiveArrayTools
using OrdinaryDiffEq, StaticArrays, RecursiveArrayTools, RecursiveArrayToolsShorthandConstructors
u0 = AP[SVector{1}(50.0), SVector{1}(0.0)]
tspan = (0.0, 15.0)

Expand Down
7 changes: 4 additions & 3 deletions test/downstream/odesolve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test, ArrayInterface, StaticArrays
using OrdinaryDiffEq, OrdinaryDiffEqRosenbrock, NLsolve, RecursiveArrayTools,
RecursiveArrayToolsShorthandConstructors, Test, ArrayInterface, StaticArrays
function lorenz(du, u, p, t)
du[1] = 10.0 * (u[2] - u[1])
du[2] = u[1] * (28.0 - u[3]) - u[2]
Expand All @@ -9,7 +10,7 @@ u0 = AP[[1.0, 0.0], [0.0]]
tspan = (0.0, 100.0)
prob = ODEProblem(lorenz, u0, tspan)
sol = solve(prob, Tsit5())
sol = solve(prob, AutoTsit5(Rosenbrock23(autodiff = false)))
sol = solve(prob, AutoTsit5(Rosenbrock23(autodiff = AutoFiniteDiff())))
sol = solve(prob, AutoTsit5(Rosenbrock23()))

@test all(Array(sol) .== sol)
Expand Down Expand Up @@ -72,4 +73,4 @@ end
u = fill(SVector{2}(ones(2)), 2, 3)
ode = ODEProblem(rhs!, VectorOfArray(u), (0.0, 1.0))
sol = solve(ode, Tsit5())
@test SciMLBase.successful_retcode(sol)
@test successful_retcode(sol)
2 changes: 1 addition & 1 deletion test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ ts = 0:0.5:10
sol_ts = sol(ts)
@assert sol_ts isa DiffEqArray
test_tables_interface(
sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")],
sol_ts, [:timestamp; Symbol.(string.(unknowns(lv)))],
hcat(ts, Array(sol_ts)')
)

Expand Down
Loading