Skip to content

Commit 5bf8b41

Browse files
Add Mooncake extension for FunctionWrappersWrapper
When Mooncake differentiates through code that calls a FunctionWrappersWrapper, it tries to create tangent types for each FunctionWrapper variant in the internal tuple. These variants have different type parameters (for different ForwardDiff Dual combinations), producing incompatible FunctionWrapperTangent types that can't be stored in a typed tuple — causing a convert error. The fix: make FunctionWrappersWrapper calls a Mooncake primitive that unwraps to the original function (via `unwrap`) and differentiates through that directly. This mirrors the existing Enzyme extension pattern. The FunctionWrappersWrapper itself gets NoTangent since it's runtime dispatch infrastructure, not differentiable data — the original function's derivatives are handled in the rrule. This enables Mooncake to differentiate through NonlinearProblem solves that use AutoSpecialize (FunctionWrappers), which is needed for SCCNonlinearProblem AD support in SciMLSensitivity.jl (#1358). Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ccf9ed1 commit 5bf8b41

2 files changed

Lines changed: 35 additions & 2 deletions

File tree

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FunctionWrappersWrappers"
22
uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf"
33
authors = ["Chris Elrod <elrodc@gmail.com> and contributors"]
4-
version = "1.1.0"
4+
version = "1.2.0"
55

66
[deps]
77
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
@@ -11,23 +11,27 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
1111
[weakdeps]
1212
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1313
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
14+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1415

1516
[extensions]
1617
FunctionWrappersWrappersEnzymeExt = ["Enzyme", "EnzymeCore"]
18+
FunctionWrappersWrappersMooncakeExt = "Mooncake"
1719

1820
[compat]
1921
Enzyme = "0.13"
2022
EnzymeCore = "0.8"
2123
FunctionWrappers = "1"
24+
Mooncake = "0.5"
2225
PrecompileTools = "1"
2326
TruncatedStacktraces = "1"
2427
julia = "1.10"
2528

2629
[extras]
2730
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2831
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
32+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2933
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3034
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3135

3236
[targets]
33-
test = ["Pkg", "Test", "Enzyme", "EnzymeCore"]
37+
test = ["Pkg", "Test", "Enzyme", "EnzymeCore", "Mooncake"]
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
module FunctionWrappersWrappersMooncakeExt
2+
3+
using FunctionWrappersWrappers
4+
import Mooncake
5+
using Mooncake: @is_primitive, MinimalCtx, CoDual, NoRData, zero_tangent, NoTangent
6+
7+
# Make calling a FunctionWrappersWrapper a Mooncake primitive.
8+
# Instead of differentiating through the FunctionWrapper dispatch machinery
9+
# (which fails because the tuple of differently-typed FunctionWrappers produces
10+
# incompatible FunctionWrapperTangent types), unwrap to the original function
11+
# and differentiate through that directly.
12+
13+
@is_primitive MinimalCtx Tuple{<:FunctionWrappersWrapper, Vararg}
14+
15+
function Mooncake.rrule!!(
16+
f::CoDual{<:FunctionWrappersWrapper}, args::Vararg{CoDual},
17+
)
18+
f_orig = unwrap(f.x)
19+
f_orig_codual = CoDual(f_orig, zero_tangent(f_orig))
20+
y, pb = Mooncake.rrule!!(f_orig_codual, args...)
21+
fww_pb(dy) = (NoRData(), Mooncake.Base.tail(pb(dy))...)
22+
return y, fww_pb
23+
end
24+
25+
# FunctionWrappersWrapper is not differentiable data itself — the wrapped function
26+
# is what carries the derivative information, and we handle that in the rrule above.
27+
Mooncake.tangent_type(::Type{<:FunctionWrappersWrapper}) = NoTangent
28+
29+
end

0 commit comments

Comments
 (0)