Skip to content

Commit 77592ef

Browse files
Merge pull request #39 from ChrisRackauckas-Claude/mooncake-extension
Add Mooncake extension for FunctionWrappersWrapper
2 parents ccf9ed1 + b943a46 commit 77592ef

4 files changed

Lines changed: 118 additions & 2 deletions

File tree

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FunctionWrappersWrappers"
22
uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf"
33
authors = ["Chris Elrod <elrodc@gmail.com> and contributors"]
4-
version = "1.1.0"
4+
version = "1.2.0"
55

66
[deps]
77
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
@@ -11,23 +11,27 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
1111
[weakdeps]
1212
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1313
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
14+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1415

1516
[extensions]
1617
FunctionWrappersWrappersEnzymeExt = ["Enzyme", "EnzymeCore"]
18+
FunctionWrappersWrappersMooncakeExt = "Mooncake"
1719

1820
[compat]
1921
Enzyme = "0.13"
2022
EnzymeCore = "0.8"
2123
FunctionWrappers = "1"
24+
Mooncake = "0.5"
2225
PrecompileTools = "1"
2326
TruncatedStacktraces = "1"
2427
julia = "1.10"
2528

2629
[extras]
2730
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2831
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
32+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2933
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3034
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3135

3236
[targets]
33-
test = ["Pkg", "Test", "Enzyme", "EnzymeCore"]
37+
test = ["Pkg", "Test", "Enzyme", "EnzymeCore", "Mooncake"]
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
module FunctionWrappersWrappersMooncakeExt
2+
3+
using FunctionWrappersWrappers
4+
import Mooncake
5+
using Mooncake: @is_primitive, MinimalCtx, CoDual, NoRData, zero_tangent, NoTangent
6+
7+
# Make calling a FunctionWrappersWrapper a Mooncake primitive.
8+
# Instead of differentiating through the FunctionWrapper dispatch machinery
9+
# (which fails because the tuple of differently-typed FunctionWrappers produces
10+
# incompatible FunctionWrapperTangent types), unwrap to the original function
11+
# and differentiate through that directly.
12+
13+
@is_primitive MinimalCtx Tuple{<:FunctionWrappersWrapper, Vararg}
14+
15+
function Mooncake.rrule!!(
16+
f::CoDual{<:FunctionWrappersWrapper}, args::Vararg{CoDual},
17+
)
18+
f_orig = unwrap(f.x)
19+
f_orig_codual = CoDual(f_orig, zero_tangent(f_orig))
20+
y, pb = Mooncake.rrule!!(f_orig_codual, args...)
21+
fww_pb(dy) = (NoRData(), Mooncake.Base.tail(pb(dy))...)
22+
return y, fww_pb
23+
end
24+
25+
# FunctionWrappersWrapper is not differentiable data itself — the wrapped function
26+
# is what carries the derivative information, and we handle that in the rrule above.
27+
Mooncake.tangent_type(::Type{<:FunctionWrappersWrapper}) = NoTangent
28+
29+
end

test/mooncake_tests.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
using FunctionWrappersWrappers
2+
using Mooncake
3+
using Test
4+
5+
@testset "Mooncake reverse mode - single arg" begin
6+
f(x) = x^2
7+
fww = FunctionWrappersWrapper(f, (Tuple{Float64},), (Float64,))
8+
9+
rule = Mooncake.build_rrule(fww, 3.0)
10+
val, (_, dx) = Mooncake.value_and_gradient!!(rule, fww, 3.0)
11+
@test val 9.0
12+
@test dx 6.0
13+
end
14+
15+
@testset "Mooncake reverse mode - multi arg" begin
16+
g(x, y) = x * y + x^2
17+
fww = FunctionWrappersWrapper(g, (Tuple{Float64, Float64},), (Float64,))
18+
19+
# g(x,y) = x*y + x^2 → ∂g/∂x = y + 2x, ∂g/∂y = x
20+
rule = Mooncake.build_rrule(fww, 3.0, 4.0)
21+
val, (_, dx, dy) = Mooncake.value_and_gradient!!(rule, fww, 3.0, 4.0)
22+
@test val 21.0
23+
@test dx 10.0 # ∂g/∂x at (3,4) = 4 + 6
24+
@test dy 3.0 # ∂g/∂y at (3,4) = 3
25+
end
26+
27+
@testset "Mooncake with trig functions" begin
28+
fww_sin = FunctionWrappersWrapper(sin, (Tuple{Float64},), (Float64,))
29+
30+
rule = Mooncake.build_rrule(fww_sin, 1.0)
31+
val, (_, dx) = Mooncake.value_and_gradient!!(rule, fww_sin, 1.0)
32+
@test val sin(1.0)
33+
@test dx cos(1.0)
34+
end
35+
36+
@testset "Mooncake through loss function" begin
37+
# Test that Mooncake can differentiate a loss function that calls FunctionWrappersWrapper
38+
f(x) = x[1]^2 + x[2]^2
39+
fww = FunctionWrappersWrapper(f, (Tuple{Vector{Float64}},), (Float64,))
40+
41+
loss(x) = fww(x)
42+
rule = Mooncake.build_rrule(loss, [3.0, 4.0])
43+
val, (_, dx) = Mooncake.value_and_gradient!!(rule, loss, [3.0, 4.0])
44+
@test val 25.0
45+
@test dx [6.0, 8.0]
46+
end
47+
48+
@testset "Mooncake in-place function" begin
49+
# In-place functions are common in SciML (f!(du, u, p, t))
50+
function f!(du, u, p)
51+
du[1] = p[1] * u[1] + p[2] * u[2]
52+
du[2] = p[3] * u[1] - u[2]
53+
return nothing
54+
end
55+
fww = FunctionWrappersWrapper(
56+
f!,
57+
(Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}},),
58+
(Nothing,),
59+
)
60+
61+
function loss(p)
62+
u = [1.0, 2.0]
63+
du = similar(u)
64+
fww(du, u, p)
65+
return sum(abs2, du)
66+
end
67+
68+
rule = Mooncake.build_rrule(loss, [1.0, 2.0, 3.0])
69+
val, (_, dp) = Mooncake.value_and_gradient!!(rule, loss, [1.0, 2.0, 3.0])
70+
# f!(du, [1,2], [1,2,3]) → du = [1*1+2*2, 3*1-2] = [5, 1]
71+
# loss = 25 + 1 = 26
72+
@test val 26.0
73+
# ∂loss/∂p1 = 2*du[1]*u[1] = 2*5*1 = 10
74+
# ∂loss/∂p2 = 2*du[1]*u[2] = 2*5*2 = 20
75+
# ∂loss/∂p3 = 2*du[2]*u[1] = 2*1*1 = 2
76+
@test dp [10.0, 20.0, 2.0]
77+
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,9 @@ if GROUP == "All" || GROUP == "Enzyme"
5959
include("enzyme_tests.jl")
6060
end
6161
end
62+
63+
if GROUP == "All" || GROUP == "Mooncake"
64+
@testset "Mooncake extension" begin
65+
include("mooncake_tests.jl")
66+
end
67+
end

0 commit comments

Comments
 (0)