|
1 | 1 | using RecursiveArrayTools, Zygote, ForwardDiff, Test |
2 | 2 | using SciMLBase |
3 | 3 |
|
4 | | -function loss(x) |
5 | | - return sum(abs2, Array(VectorOfArray([x .* i for i in 1:5]))) |
6 | | -end |
7 | | - |
8 | | -function loss2(x) |
9 | | - return sum(abs2, Array(DiffEqArray([x .* i for i in 1:5], 1:5))) |
10 | | -end |
11 | | - |
12 | | -function loss3(x) |
13 | | - y = VectorOfArray([x .* i for i in 1:5]) |
14 | | - tmp = 0.0 |
15 | | - for i in 1:5, j in 1:5 |
16 | | - |
17 | | - tmp += y[i, j] |
18 | | - end |
19 | | - return tmp |
20 | | -end |
21 | | - |
22 | | -function loss4(x) |
23 | | - y = DiffEqArray([x .* i for i in 1:5], 1:5) |
24 | | - tmp = 0.0 |
25 | | - for i in 1:5, j in 1:5 |
26 | | - |
27 | | - tmp += y[i, j] |
28 | | - end |
29 | | - return tmp |
30 | | -end |
31 | | - |
32 | | -function loss5(x) |
33 | | - return sum(abs2, Array(ArrayPartition([x .* i for i in 1:5]...))) |
34 | | -end |
35 | | - |
36 | | -function loss6(x) |
| 4 | +# Test that ArrayPartition works through ODEProblem construction |
| 5 | +# (requires SciMLBase, so this is a downstream test) |
| 6 | +function loss_odeproblem(x) |
37 | 7 | _x = ArrayPartition([x .* i for i in 1:5]...) |
38 | 8 | _prob = ODEProblem((u, p, t) -> u, _x, (0, 1)) |
39 | 9 | return sum(abs2, Array(_prob.u0)) |
40 | 10 | end |
41 | 11 |
|
42 | | -function loss7(x) |
43 | | - _x = VectorOfArray([x .* i for i in 1:5]) |
44 | | - return sum(abs2, _x .- 1) |
45 | | -end |
46 | | - |
47 | | -# use a bunch of broadcasts to test all the adjoints |
48 | | -function loss8(x) |
49 | | - _x = VectorOfArray([x .* i for i in 1:5]) |
50 | | - res = copy(_x) |
51 | | - res = res .+ _x |
52 | | - res = res .+ 1 |
53 | | - res = res .* _x |
54 | | - res = res .* 2.0 |
55 | | - res = res .* res |
56 | | - res = res ./ 2.0 |
57 | | - res = res ./ _x |
58 | | - res = 3.0 .- res |
59 | | - res = .-res |
60 | | - res = identity.(Base.literal_pow.(^, res, Val(2))) |
61 | | - res = tanh.(res) |
62 | | - res = res .+ im .* res |
63 | | - res = conj.(res) .+ real.(res) .+ imag.(res) .+ abs2.(res) |
64 | | - return sum(abs2, res) |
65 | | -end |
66 | | - |
67 | | -function loss9(x) |
68 | | - return VectorOfArray([collect((3i):(3i + 3)) .* x for i in 1:5]) |
69 | | -end |
70 | | - |
71 | | -function loss10(x) |
72 | | - voa = VectorOfArray([i * x for i in 1:5]) |
73 | | - return sum(view(voa, 2:4, 3:5)) |
74 | | -end |
75 | | - |
76 | | -function loss11(x) |
77 | | - voa = VectorOfArray([i * x for i in 1:5]) |
78 | | - return sum(view(voa, :, :)) |
79 | | -end |
80 | | - |
81 | 12 | x = float.(6:10) |
82 | | -loss(x) |
83 | | -@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x) |
84 | | -@test Zygote.gradient(loss2, x)[1] == ForwardDiff.gradient(loss2, x) |
85 | | -@test Zygote.gradient(loss3, x)[1] == ForwardDiff.gradient(loss3, x) |
86 | | -@test Zygote.gradient(loss4, x)[1] == ForwardDiff.gradient(loss4, x) |
87 | | -@test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x) |
88 | | -@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x) |
89 | | -@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x) |
90 | | -@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x) |
91 | | -@test ForwardDiff.derivative(loss9, 0.0) == |
92 | | - VectorOfArray([collect((3i):(3i + 3)) for i in 1:5]) |
93 | | -@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x) |
94 | | -@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x) |
95 | | - |
96 | | -voa = RecursiveArrayTools.VectorOfArray(fill(rand(3), 3)) |
97 | | -voa_gs, = Zygote.gradient(voa) do x |
98 | | - sum(sum.(x.u)) |
99 | | -end |
100 | | -@test voa_gs isa RecursiveArrayTools.VectorOfArray |
| 13 | +@test Zygote.gradient(loss_odeproblem, x)[1] == ForwardDiff.gradient(loss_odeproblem, x) |
0 commit comments