Skip to content

Commit 9cdf52c

Browse files
Move SciMLBase to downstream test env, move adjoints test
SciMLBase can't resolve against RAT v4 until the companion SciMLBase PR (SciML/SciMLBase.jl#1297) is merged. Move it out of the main test deps into the downstream test Project.toml. - Remove SciMLBase from [extras] and [targets] in Project.toml - Add SciMLBase and ForwardDiff to test/downstream/Project.toml - Move adjoints.jl to test/downstream/ (uses ODEProblem from SciMLBase) - Core tests now pass without SciMLBase dependency Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 90a71e2 commit 9cdf52c

File tree

4 files changed

+105
-3
lines changed

4 files changed

+105
-3
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
8585
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8686
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
8787
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
88-
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
8988
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
9089
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
9190
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -96,4 +95,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
9695
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9796

9897
[targets]
99-
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
98+
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]

test/downstream/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
[deps]
22
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
34
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
45
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
56
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
67
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
78
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
9+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
810
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
911
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1012
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
@@ -14,6 +16,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1416

1517
[compat]
1618
ArrayInterface = "7"
19+
ForwardDiff = "0.10, 1"
1720
ModelingToolkit = "8.33, 9"
1821
MonteCarloMeasurements = "1.1"
1922
NLsolve = "4"

test/downstream/adjoints.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
using RecursiveArrayTools, Zygote, ForwardDiff, Test
2+
using SciMLBase
3+
4+
function loss(x)
5+
return sum(abs2, Array(VectorOfArray([x .* i for i in 1:5])))
6+
end
7+
8+
function loss2(x)
9+
return sum(abs2, Array(DiffEqArray([x .* i for i in 1:5], 1:5)))
10+
end
11+
12+
function loss3(x)
13+
y = VectorOfArray([x .* i for i in 1:5])
14+
tmp = 0.0
15+
for i in 1:5, j in 1:5
16+
17+
tmp += y[i, j]
18+
end
19+
return tmp
20+
end
21+
22+
function loss4(x)
23+
y = DiffEqArray([x .* i for i in 1:5], 1:5)
24+
tmp = 0.0
25+
for i in 1:5, j in 1:5
26+
27+
tmp += y[i, j]
28+
end
29+
return tmp
30+
end
31+
32+
function loss5(x)
33+
return sum(abs2, Array(ArrayPartition([x .* i for i in 1:5]...)))
34+
end
35+
36+
function loss6(x)
37+
_x = ArrayPartition([x .* i for i in 1:5]...)
38+
_prob = ODEProblem((u, p, t) -> u, _x, (0, 1))
39+
return sum(abs2, Array(_prob.u0))
40+
end
41+
42+
function loss7(x)
43+
_x = VectorOfArray([x .* i for i in 1:5])
44+
return sum(abs2, _x .- 1)
45+
end
46+
47+
# use a bunch of broadcasts to test all the adjoints
48+
function loss8(x)
49+
_x = VectorOfArray([x .* i for i in 1:5])
50+
res = copy(_x)
51+
res = res .+ _x
52+
res = res .+ 1
53+
res = res .* _x
54+
res = res .* 2.0
55+
res = res .* res
56+
res = res ./ 2.0
57+
res = res ./ _x
58+
res = 3.0 .- res
59+
res = .-res
60+
res = identity.(Base.literal_pow.(^, res, Val(2)))
61+
res = tanh.(res)
62+
res = res .+ im .* res
63+
res = conj.(res) .+ real.(res) .+ imag.(res) .+ abs2.(res)
64+
return sum(abs2, res)
65+
end
66+
67+
function loss9(x)
68+
return VectorOfArray([collect((3i):(3i + 3)) .* x for i in 1:5])
69+
end
70+
71+
function loss10(x)
72+
voa = VectorOfArray([i * x for i in 1:5])
73+
return sum(view(voa, 2:4, 3:5))
74+
end
75+
76+
function loss11(x)
77+
voa = VectorOfArray([i * x for i in 1:5])
78+
return sum(view(voa, :, :))
79+
end
80+
81+
x = float.(6:10)
82+
loss(x)
83+
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
84+
@test Zygote.gradient(loss2, x)[1] == ForwardDiff.gradient(loss2, x)
85+
@test Zygote.gradient(loss3, x)[1] == ForwardDiff.gradient(loss3, x)
86+
@test Zygote.gradient(loss4, x)[1] == ForwardDiff.gradient(loss4, x)
87+
@test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x)
88+
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
89+
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
90+
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
91+
@test ForwardDiff.derivative(loss9, 0.0) ==
92+
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
93+
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
94+
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)
95+
96+
voa = RecursiveArrayTools.VectorOfArray(fill(rand(3), 3))
97+
voa_gs, = Zygote.gradient(voa) do x
98+
sum(sum.(x.u))
99+
end
100+
@test voa_gs isa RecursiveArrayTools.VectorOfArray

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ end
3737
@time @safetestset "Table traits" include("tabletraits.jl")
3838
@time @safetestset "StaticArrays Tests" include("copy_static_array_test.jl")
3939
@time @safetestset "Linear Algebra Tests" include("linalg.jl")
40-
@time @safetestset "Adjoint Tests" include("adjoints.jl")
4140
@time @safetestset "Measurement Tests" include("measurements.jl")
4241
end
4342

@@ -51,6 +50,7 @@ end
5150
@time @safetestset "Event Tests with ArrayPartition" include("downstream/downstream_events.jl")
5251
@time @safetestset "Measurements and Units" include("downstream/measurements_and_units.jl")
5352
@time @safetestset "TrackerExt" include("downstream/TrackerExt.jl")
53+
@time @safetestset "Adjoint Tests" include("downstream/adjoints.jl")
5454
end
5555

5656
if GROUP == "SymbolicIndexingInterface" || GROUP == "Downstream"

0 commit comments

Comments
 (0)