Skip to content

Commit 100d576

Browse files
Merge pull request #594 from ChrisRackauckas-Claude/fix-downstream-tests-odeq7
Fix downstream test sources for OrdinaryDiffEq 7
2 parents af4895e + 39c3708 commit 100d576

5 files changed

Lines changed: 18 additions & 5 deletions

File tree

src/vector_of_array.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,15 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg,
797797
end
798798
symtype = symbolic_type(_arg)
799799
elsymtype = symbolic_type(eltype(_arg))
800+
# For symbolic indices, `A[sym, :]` is semantically equivalent to `A[sym]`
801+
# (the no-args symbolic getindex already returns the full timeseries).
802+
# Routing through the no-args path here also avoids a broadcast shape bug
803+
# in SymbolicIndexingInterface's `GetStateIndex` when the underlying index
804+
# is a `Vector{Int}` (array-symbolic) combined with a `Colon` time index.
805+
if (symtype != NotSymbolic() || elsymtype != NotSymbolic()) &&
806+
length(args) == 1 && args[1] === Colon()
807+
return A[_arg]
808+
end
800809

801810
return if symtype == NotSymbolic() && elsymtype == NotSymbolic()
802811
if _arg isa Union{Tuple, AbstractArray} &&

test/downstream/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
44
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
55
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
66
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
7+
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
78
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
9+
RecursiveArrayToolsShorthandConstructors = "39fb7555-b4ad-4efd-8abe-30331df017d3"
810
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
911
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1012
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
@@ -18,6 +20,7 @@ ModelingToolkit = "8.33, 9, 10, 11"
1820
MonteCarloMeasurements = "1.1"
1921
NLsolve = "4"
2022
OrdinaryDiffEq = "6.31, 7"
23+
OrdinaryDiffEqRosenbrock = "1, 2"
2124
StaticArrays = "1"
2225
SymbolicIndexingInterface = "0.3"
2326
Tables = "1"

test/downstream/downstream_events.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using OrdinaryDiffEq, StaticArrays, RecursiveArrayTools
1+
using OrdinaryDiffEq, StaticArrays, RecursiveArrayTools, RecursiveArrayToolsShorthandConstructors
22
u0 = AP[SVector{1}(50.0), SVector{1}(0.0)]
33
tspan = (0.0, 15.0)
44

test/downstream/odesolve.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test, ArrayInterface, StaticArrays
1+
using OrdinaryDiffEq, OrdinaryDiffEqRosenbrock, NLsolve, RecursiveArrayTools,
2+
RecursiveArrayToolsShorthandConstructors, Test, ArrayInterface, StaticArrays
23
function lorenz(du, u, p, t)
34
du[1] = 10.0 * (u[2] - u[1])
45
du[2] = u[1] * (28.0 - u[3]) - u[2]
@@ -9,7 +10,7 @@ u0 = AP[[1.0, 0.0], [0.0]]
910
tspan = (0.0, 100.0)
1011
prob = ODEProblem(lorenz, u0, tspan)
1112
sol = solve(prob, Tsit5())
12-
sol = solve(prob, AutoTsit5(Rosenbrock23(autodiff = false)))
13+
sol = solve(prob, AutoTsit5(Rosenbrock23(autodiff = AutoFiniteDiff())))
1314
sol = solve(prob, AutoTsit5(Rosenbrock23()))
1415

1516
@test all(Array(sol) .== sol)
@@ -72,4 +73,4 @@ end
7273
u = fill(SVector{2}(ones(2)), 2, 3)
7374
ode = ODEProblem(rhs!, VectorOfArray(u), (0.0, 1.0))
7475
sol = solve(ode, Tsit5())
75-
@test SciMLBase.successful_retcode(sol)
76+
@test successful_retcode(sol)

test/downstream/symbol_indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ ts = 0:0.5:10
7777
sol_ts = sol(ts)
7878
@assert sol_ts isa DiffEqArray
7979
test_tables_interface(
80-
sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")],
80+
sol_ts, [:timestamp; Symbol.(string.(unknowns(lv)))],
8181
hcat(ts, Array(sol_ts)')
8282
)
8383

0 commit comments

Comments
 (0)