|
| 1 | +using FunctionWrappersWrappers |
| 2 | +using Mooncake |
| 3 | +using Test |
| 4 | + |
| 5 | +@testset "Mooncake reverse mode - single arg" begin |
| 6 | + f(x) = x^2 |
| 7 | + fww = FunctionWrappersWrapper(f, (Tuple{Float64},), (Float64,)) |
| 8 | + |
| 9 | + rule = Mooncake.build_rrule(fww, 3.0) |
| 10 | + val, (_, dx) = Mooncake.value_and_gradient!!(rule, fww, 3.0) |
| 11 | + @test val ≈ 9.0 |
| 12 | + @test dx ≈ 6.0 |
| 13 | +end |
| 14 | + |
| 15 | +@testset "Mooncake reverse mode - multi arg" begin |
| 16 | + g(x, y) = x * y + x^2 |
| 17 | + fww = FunctionWrappersWrapper(g, (Tuple{Float64, Float64},), (Float64,)) |
| 18 | + |
| 19 | + # g(x,y) = x*y + x^2 → ∂g/∂x = y + 2x, ∂g/∂y = x |
| 20 | + rule = Mooncake.build_rrule(fww, 3.0, 4.0) |
| 21 | + val, (_, dx, dy) = Mooncake.value_and_gradient!!(rule, fww, 3.0, 4.0) |
| 22 | + @test val ≈ 21.0 |
| 23 | + @test dx ≈ 10.0 # ∂g/∂x at (3,4) = 4 + 6 |
| 24 | + @test dy ≈ 3.0 # ∂g/∂y at (3,4) = 3 |
| 25 | +end |
| 26 | + |
| 27 | +@testset "Mooncake with trig functions" begin |
| 28 | + fww_sin = FunctionWrappersWrapper(sin, (Tuple{Float64},), (Float64,)) |
| 29 | + |
| 30 | + rule = Mooncake.build_rrule(fww_sin, 1.0) |
| 31 | + val, (_, dx) = Mooncake.value_and_gradient!!(rule, fww_sin, 1.0) |
| 32 | + @test val ≈ sin(1.0) |
| 33 | + @test dx ≈ cos(1.0) |
| 34 | +end |
| 35 | + |
| 36 | +@testset "Mooncake through loss function" begin |
| 37 | + # Test that Mooncake can differentiate a loss function that calls FunctionWrappersWrapper |
| 38 | + f(x) = x[1]^2 + x[2]^2 |
| 39 | + fww = FunctionWrappersWrapper(f, (Tuple{Vector{Float64}},), (Float64,)) |
| 40 | + |
| 41 | + loss(x) = fww(x) |
| 42 | + rule = Mooncake.build_rrule(loss, [3.0, 4.0]) |
| 43 | + val, (_, dx) = Mooncake.value_and_gradient!!(rule, loss, [3.0, 4.0]) |
| 44 | + @test val ≈ 25.0 |
| 45 | + @test dx ≈ [6.0, 8.0] |
| 46 | +end |
| 47 | + |
| 48 | +@testset "Mooncake in-place function" begin |
| 49 | + # In-place functions are common in SciML (f!(du, u, p, t)) |
| 50 | + function f!(du, u, p) |
| 51 | + du[1] = p[1] * u[1] + p[2] * u[2] |
| 52 | + du[2] = p[3] * u[1] - u[2] |
| 53 | + return nothing |
| 54 | + end |
| 55 | + fww = FunctionWrappersWrapper( |
| 56 | + f!, |
| 57 | + (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}},), |
| 58 | + (Nothing,), |
| 59 | + ) |
| 60 | + |
| 61 | + function loss(p) |
| 62 | + u = [1.0, 2.0] |
| 63 | + du = similar(u) |
| 64 | + fww(du, u, p) |
| 65 | + return sum(abs2, du) |
| 66 | + end |
| 67 | + |
| 68 | + rule = Mooncake.build_rrule(loss, [1.0, 2.0, 3.0]) |
| 69 | + val, (_, dp) = Mooncake.value_and_gradient!!(rule, loss, [1.0, 2.0, 3.0]) |
| 70 | + # f!(du, [1,2], [1,2,3]) → du = [1*1+2*2, 3*1-2] = [5, 1] |
| 71 | + # loss = 25 + 1 = 26 |
| 72 | + @test val ≈ 26.0 |
| 73 | + # ∂loss/∂p1 = 2*du[1]*u[1] = 2*5*1 = 10 |
| 74 | + # ∂loss/∂p2 = 2*du[1]*u[2] = 2*5*2 = 20 |
| 75 | + # ∂loss/∂p3 = 2*du[2]*u[1] = 2*1*1 = 2 |
| 76 | + @test dp ≈ [10.0, 20.0, 2.0] |
| 77 | +end |
0 commit comments