Skip to content

Commit b1d8401

Browse files
Merge pull request #49 from ChrisRackauckas-Claude/fix-issue-48-duplicated-fww
Accept Annotation{<:FunctionWrappersWrapper} in Enzyme rules (closes #48)
2 parents a03f55f + b102863 commit b1d8401

2 files changed

Lines changed: 127 additions & 11 deletions

File tree

ext/FunctionWrappersWrappersEnzymeExt.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,19 @@ end
3636
# =============================================================================
3737

3838
# Shadow only (Forward mode, no primal)
39+
#
40+
# `func` is `Annotation{<:FunctionWrappersWrapper}` rather than
41+
# `Const{<:FunctionWrappersWrapper}` so that callers passing
42+
# `Duplicated{<:FunctionWrappersWrapper}` also dispatch here. Enzyme drives
43+
# the rule that way when the outer `autodiff` call is differentiating through
44+
# a closure that carries an FWW (e.g. NonlinearSolve + SciMLSensitivity, see
45+
# SciML/FunctionWrappersWrappers.jl#48). The FWW struct itself only carries
46+
# `FunctionWrapper`s plus cache storage — none of those fields have a
47+
# meaningful tangent — so the function shadow is ignored and the inner
48+
# `Enzyme.autodiff` call uses `Const(f_orig)`.
3949
function EnzymeRules.forward(
4050
::EnzymeRules.FwdConfig{false, true, W, RuntimeActivity, StrongZero},
41-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
51+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
4252
RT::Type{<:EnzymeCore.Annotation{T}},
4353
args::Vararg{EnzymeCore.Annotation, N}
4454
) where {T, W, N, RuntimeActivity, StrongZero}
@@ -56,7 +66,7 @@ end
5666
# Both primal and shadow (ForwardWithPrimal mode)
5767
function EnzymeRules.forward(
5868
::EnzymeRules.FwdConfig{true, true, W, RuntimeActivity, StrongZero},
59-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
69+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
6070
RT::Type{<:EnzymeCore.Annotation{T}},
6171
args::Vararg{EnzymeCore.Annotation, N}
6272
) where {T, W, N, RuntimeActivity, StrongZero}
@@ -81,7 +91,7 @@ end
8191
# Primal only (Const return type) — width-independent
8292
function EnzymeRules.forward(
8393
::EnzymeRules.FwdConfig{true, false, W, RuntimeActivity, StrongZero},
84-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
94+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
8595
RT::Type{<:EnzymeCore.Annotation},
8696
args::Vararg{EnzymeCore.Annotation, N}
8797
) where {W, N, RuntimeActivity, StrongZero}
@@ -107,7 +117,7 @@ end
107117
# `set_runtime_activity(Forward)` on the way down into `f_orig`.
108118
function EnzymeRules.forward(
109119
::EnzymeRules.FwdConfig{false, false, W, RuntimeActivity, StrongZero},
110-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
120+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
111121
RT::Type{<:EnzymeCore.Annotation},
112122
args::Vararg{EnzymeCore.Annotation, N}
113123
) where {W, N, RuntimeActivity, StrongZero}
@@ -123,7 +133,7 @@ end
123133

124134
function EnzymeRules.augmented_primal(
125135
config::EnzymeRules.RevConfig,
126-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
136+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
127137
RT::Type{<:EnzymeCore.Active{T}},
128138
args::Vararg{EnzymeCore.Annotation, N}
129139
) where {T, N}
@@ -143,7 +153,7 @@ end
143153
# the reverse pass has nothing to propagate back from the return.
144154
function EnzymeRules.augmented_primal(
145155
config::EnzymeRules.RevConfig,
146-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
156+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
147157
RT::Type{<:EnzymeCore.Const},
148158
args::Vararg{EnzymeCore.Annotation, N}
149159
) where {N}
@@ -157,7 +167,7 @@ end
157167
# it available when propagating dret through the arguments.
158168
function EnzymeRules.augmented_primal(
159169
config::EnzymeRules.RevConfig,
160-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
170+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
161171
RT::Type{<:EnzymeCore.Duplicated{T}},
162172
args::Vararg{EnzymeCore.Annotation, N}
163173
) where {T, N}
@@ -173,7 +183,7 @@ end
173183

174184
function EnzymeRules.augmented_primal(
175185
config::EnzymeRules.RevConfig,
176-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
186+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
177187
RT::Type{<:EnzymeCore.BatchDuplicated{T, W}},
178188
args::Vararg{EnzymeCore.Annotation, N}
179189
) where {T, W, N}
@@ -220,7 +230,7 @@ end
220230

221231
function EnzymeRules.reverse(
222232
config::EnzymeRules.RevConfig,
223-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
233+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
224234
dret::EnzymeCore.Active{T},
225235
tape,
226236
args::Vararg{EnzymeCore.Active, N}
@@ -232,7 +242,7 @@ end
232242
# Handle mixed Active/Const args: return nothing for Const, gradient for Active
233243
function EnzymeRules.reverse(
234244
config::EnzymeRules.RevConfig,
235-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
245+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
236246
dret::EnzymeCore.Active,
237247
tape,
238248
args::Vararg{EnzymeCore.Annotation, N}
@@ -271,7 +281,7 @@ end
271281
# accumulated in-place by the `Enzyme.autodiff(Reverse, …)` call above.
272282
function EnzymeRules.reverse(
273283
config::EnzymeRules.RevConfig,
274-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
284+
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
275285
dret::Type{<:EnzymeCore.Const},
276286
tape,
277287
args::Vararg{EnzymeCore.Annotation, N}

test/enzyme_tests.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,112 @@ end
298298
# numerically correct.
299299
# =============================================================================
300300

301+
# =============================================================================
302+
# Duplicated function annotation on the FWW itself.
303+
#
304+
# Reproduces SciML/FunctionWrappersWrappers.jl#48: when Enzyme differentiates
305+
# through a closure that captures an FWW (e.g. NonlinearSolve +
306+
# SciMLSensitivity), the rule is invoked with
307+
# `Duplicated{<:FunctionWrappersWrapper}` for the function argument, not
308+
# `Const{<:FunctionWrappersWrapper}`. The FWW struct itself only carries
309+
# `FunctionWrapper`s + cache storage, so its "shadow" is ignored — we route
310+
# through `unwrap(func.val)` exactly as with `Const`.
311+
# =============================================================================
312+
313+
@testset "Enzyme forward, Duplicated FWW annotation — IIP Const return" begin
314+
f!(residual, u, p) = (residual[1] = u[1]^2 - p[1]; nothing)
315+
fww = FunctionWrappersWrapper(
316+
f!,
317+
(Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}},),
318+
(Nothing,)
319+
)
320+
321+
residual = [0.0]; dresidual = [0.0]
322+
u = [2.0]; du = [1.0]
323+
p = [1.0]; dp = [0.0]
324+
325+
config = EnzymeCore.EnzymeRules.FwdConfig{false, false, 1, false, false}()
326+
ret = EnzymeCore.EnzymeRules.forward(
327+
config,
328+
Duplicated(fww, fww), # <-- the failing dispatch in #48
329+
EnzymeCore.Const{Nothing},
330+
Duplicated(residual, dresidual),
331+
Duplicated(u, du),
332+
Duplicated(p, dp),
333+
)
334+
@test ret === nothing
335+
@test residual[1] 3.0 # u[1]^2 - p[1] = 4 - 1
336+
@test dresidual[1] 4.0 # 2*u[1]*du[1] - 1*dp[1] = 4
337+
end
338+
339+
@testset "Enzyme forward, Duplicated FWW annotation — shadow-only return" begin
340+
# Drive the {false, true, W, …} rule (shadow only, no primal) with a
341+
# Duplicated FWW.
342+
f(x) = x^2
343+
fww = FunctionWrappersWrapper(f, (Tuple{Float64},), (Float64,))
344+
345+
config = EnzymeCore.EnzymeRules.FwdConfig{false, true, 1, false, false}()
346+
shadow = EnzymeCore.EnzymeRules.forward(
347+
config,
348+
Duplicated(fww, fww),
349+
EnzymeCore.Duplicated{Float64},
350+
Duplicated(3.0, 1.0),
351+
)
352+
@test shadow 6.0 # f'(3) = 2*3 = 6
353+
end
354+
355+
@testset "Enzyme forward, Duplicated FWW annotation — primal + shadow return" begin
356+
# Drive the {true, true, W, …} rule (ForwardWithPrimal) with a Duplicated
357+
# FWW.
358+
f(x) = x^2
359+
fww = FunctionWrappersWrapper(f, (Tuple{Float64},), (Float64,))
360+
361+
config = EnzymeCore.EnzymeRules.FwdConfig{true, true, 1, false, false}()
362+
result = EnzymeCore.EnzymeRules.forward(
363+
config,
364+
Duplicated(fww, fww),
365+
EnzymeCore.Duplicated{Float64},
366+
Duplicated(3.0, 1.0),
367+
)
368+
@test result isa Duplicated
369+
@test result.val 9.0 # primal
370+
@test result.dval 6.0 # shadow
371+
end
372+
373+
@testset "Enzyme reverse, Duplicated FWW annotation — Const return IIP" begin
374+
# Mirror the forward IIP case on the reverse side. Duplicated FWW must
375+
# still drive the rule, gradients must accumulate into u_shadow.
376+
f!(du, u) = (du[1] = u[1]^2; nothing)
377+
fww = FunctionWrappersWrapper(
378+
f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,)
379+
)
380+
381+
du = [0.0]; du_shadow = [1.0]
382+
u = [3.0]; u_shadow = [0.0]
383+
384+
rconfig = EnzymeRules.RevConfig{false, false, 1, (false, false), false, false}()
385+
aug = EnzymeRules.augmented_primal(
386+
rconfig,
387+
Duplicated(fww, fww), # <-- Duplicated FWW
388+
EnzymeCore.Const{Nothing},
389+
Duplicated(du, du_shadow),
390+
Duplicated(u, u_shadow),
391+
)
392+
@test aug.primal === nothing
393+
@test aug.shadow === nothing
394+
395+
EnzymeRules.reverse(
396+
rconfig,
397+
Duplicated(fww, fww),
398+
EnzymeCore.Const{Nothing},
399+
aug.tape,
400+
Duplicated(du, du_shadow),
401+
Duplicated(u, u_shadow),
402+
)
403+
@test du[1] 9.0 # primal effect from augmented_primal
404+
@test u_shadow[1] 6.0 # reverse accumulation: 2*u[1]*du_shadow[1]
405+
end
406+
301407
@testset "Enzyme Forward: set_runtime_activity propagates through FWW (IIP, time-dependent)" begin
302408
# DiffEqBase's `wrapfun_iip(ff, (u, u, p, t))` shape.
303409
const_INPUTS = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}

0 commit comments

Comments
 (0)