Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions ext/FunctionWrappersWrappersEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,24 @@ function EnzymeRules.forward(
return f_orig(pargs...)
end

# Neither primal nor shadow requested — Enzyme asks for this combo with Const
# return-type annotations where the caller only needs the side effects of the
# primal invocation (e.g. mutating an IIP RHS in SciML's solver path). No rule
# previously matched this case, so dispatch fell through to Enzyme's default
# path which tried to differentiate through the raw FunctionWrappersWrapper
# and failed with `MethodError: no method matching forward(…)` when the wrapper
# only held plain-Float64 signatures. Just run the primal and return nothing.
# Neither primal nor shadow requested in the RETURN. Enzyme dispatches on
# this combo for IIP functions (Const return type) where the caller still
# needs primal and shadow propagation through `Duplicated` args — e.g. SciML
# solvers calling an IIP RHS via `AutoEnzyme(…, function_annotation = Const)`.
# The previous revision ran `f_orig(pargs...)` by hand; that mutated the
# primal IIP buffer but left `Duplicated` shadow buffers untouched, giving
# trivial Jacobians and blowing up Rodas4/5/Veldd4 error tolerances 4–9
# orders of magnitude in OrdinaryDiffEq.jl v7. Delegate to `Enzyme.autodiff`
# on the unwrapped function with a Const return annotation so the Duplicated
# arg shadows are propagated correctly and no return is produced.
function EnzymeRules.forward(
::EnzymeRules.FwdConfig{false, false, W, RuntimeActivity, StrongZero},
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Annotation},
args::Vararg{EnzymeCore.Annotation, N}
) where {W, N, RuntimeActivity, StrongZero}
f_orig = unwrap(func.val)
pargs = ntuple(i -> args[i].val, Val(N))
f_orig(pargs...)
Enzyme.autodiff(Forward, Const(f_orig), Const, args...)
return nothing
end

Expand Down Expand Up @@ -208,26 +210,37 @@ function EnzymeRules.reverse(
end
end

# Const return (no derivative to propagate from the return) — uniform Active args.
# Const return — Enzyme passes the RT as a `Type{<:Const}` to `reverse`, not
# as an instance. Delegate the reverse pass to
# `Enzyme.autodiff(Reverse, Const(f_orig), Const, args...)` so gradients
# accumulate into any `Duplicated` arg shadow buffers (the SciML IIP
# pattern). Simply returning `nothing` left Duplicated shadows at zero.
#
# Per Enzyme's rule return-type protocol, `Active` args require a concrete
# scalar gradient (not `nothing`). Under a `Const` return there is no
# gradient source, so Active arg gradients are zero. `Duplicated` /
# `BatchDuplicated` args return `nothing` because their gradients are
# accumulated in-place by the `Enzyme.autodiff(Reverse, …)` call above.
function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
dret::EnzymeCore.Const,
tape,
args::Vararg{EnzymeCore.Active, N}
) where {N}
return ntuple(_ -> nothing, Val(N))
end

# Const return — mixed Active/Const args.
function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
dret::EnzymeCore.Const,
dret::Type{<:EnzymeCore.Const},
tape,
args::Vararg{EnzymeCore.Annotation, N}
) where {N}
return ntuple(_ -> nothing, Val(N))
f_orig = unwrap(func.val)
# Only worth invoking Enzyme.autodiff when at least one arg is
# Duplicated/BatchDuplicated — otherwise there's nothing to accumulate.
if any(a -> a isa EnzymeCore.Duplicated || a isa EnzymeCore.BatchDuplicated, args)
Enzyme.autodiff(Reverse, Const(f_orig), Const, args...)
end
return ntuple(Val(N)) do i
if args[i] isa EnzymeCore.Active
zero(eltype(typeof(args[i])))
else
nothing
end
end
end

end
143 changes: 124 additions & 19 deletions test/enzyme_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,17 @@ end
@test result_wp[2] ≈ 9.0 # primal f(3) = 9
end

@testset "Enzyme forward mode, neither primal nor shadow requested" begin
# Covers EnzymeRules.FwdConfig{false, false, W, ...}: caller wants only the
# side-effects of the primal invocation, no return value and no derivative.
# Reproduces the SciML/OrdinaryDiffEq.jl v7 Downstream regression where
# Enzyme dispatched on this config combination with a FWW wrapping an IIP
# RHS and found no matching rule, throwing
# MethodError: no method matching forward(
# ::FwdConfigWidth{1, false, false, false, false},
# ::Const{<:FunctionWrappersWrapper}, ::Type{Const{Nothing}}, …)
@testset "Enzyme forward mode, Const return (IIP, no return-shadow)" begin
# Covers EnzymeRules.FwdConfig{false, false, W, ...} — Enzyme dispatches on
# this combo for IIP functions with a Const return type where the caller
# needs primal + shadow propagation via Duplicated args only (no return
# value to shadow). Reproduces the SciML/OrdinaryDiffEq.jl v7 Downstream
# regression where this call previously produced:
# - without any rule: MethodError: no method matching forward(…)
# - with a primal-only rule: trivial (zero) arg shadows, wrong Jacobians
# (Rodas4/5/Veldd4 errors 4–9 orders of magnitude above tolerance).
# The rule must delegate to `Enzyme.autodiff` on the unwrapped function
# so Duplicated arg shadows propagate correctly.
f!(du, u) = (du[1] = -u[1]^2; nothing)
fww = FunctionWrappersWrapper(
f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,)
Expand All @@ -93,21 +95,18 @@ end
du = [0.0]
u = [3.0]
du_shadow = [0.0]
u_shadow = [1.0]
u_shadow = [1.0] # seed: ∂/∂u[1] = 1

# Call forward directly with {false, false}: Enzyme's public-facing
# autodiff front-end doesn't normally expose this config, so invoke the
# rule by hand.
config = EnzymeCore.EnzymeRules.FwdConfig{false, false, 1, false, false}()
ret = EnzymeCore.EnzymeRules.forward(
config, Const(fww), EnzymeCore.Const{Nothing},
Duplicated(du, du_shadow), Duplicated(u, u_shadow)
)
@test ret === nothing
# primal side-effect did happen: f!(du, u) sets du[1] = -u[1]^2 = -9
# Primal side-effect: du[1] = -u[1]^2 = -9
@test du[1] ≈ -9.0
# shadow buffer was not touched by this no-diff path
@test du_shadow[1] == 0.0
# Shadow propagation: ∂du[1]/∂u[1] * u_shadow[1] = -2*u[1]*1 = -6
@test du_shadow[1] ≈ -6.0
end

@testset "Enzyme reverse mode, Const return — augmented_primal runs primal" begin
Expand All @@ -133,12 +132,14 @@ end
@test aug.shadow === nothing
@test aug.tape === nothing

# Reverse step — dret is Const, no grads to accumulate.
# Reverse step — dret is Const (passed as TYPE not instance in reverse
# rules). Enzyme's rule protocol requires concrete gradients for Active
# args; under a Const return they're zero (no gradient source).
grads = EnzymeRules.reverse(
rconfig, Const(fww), EnzymeCore.Const{Float64}(0.0),
rconfig, Const(fww), EnzymeCore.Const{Float64},
aug.tape, Active(3.0), Active(4.0)
)
@test grads == (nothing, nothing)
@test grads == (0.0, 0.0)
end

@testset "Enzyme reverse mode, Duplicated return — augmented_primal initializes shadow" begin
Expand Down Expand Up @@ -172,3 +173,107 @@ end
@test aug.tape === nothing
end

# =============================================================================
# End-to-end reverse-mode derivative tests — exercise Enzyme.autodiff(Reverse,
# …) through the FWW and assert the resulting gradients are numerically correct.
# The prior reverse-mode testsets only checked dispatch / shape of
# AugmentedReturn; they did NOT verify the gradients are right.
# =============================================================================

@testset "Enzyme Reverse: Const return, Active args — no-flow gradients" begin
# For a function whose return is annotated Const in Reverse mode, there is
# no gradient source from the return, so Active arg gradients must be 0.
# (Enzyme's rule-return protocol requires concrete gradients for Active
# args — `nothing` is not allowed — so the rule returns zeros.)
g(x, y) = x * y + x^2
fww = FunctionWrappersWrapper(g, (Tuple{Float64, Float64},), (Float64,))

# Const return (instead of Active) → no gradient flows back
result = Enzyme.autodiff(Reverse, Const(fww), Const, Active(3.0), Active(4.0))
@test result[1] === (0.0, 0.0)
end

@testset "Enzyme Reverse: IIP with Duplicated args, Const return" begin
# SciML's standard pattern: IIP RHS `f!(du, u)` with Const return, both du
# and u are Duplicated. Reverse mode should accumulate
# u_shadow[i] += du_shadow[j] * ∂(du[j])/∂(u[i])
# into u_shadow. For f!(du, u) = (du[1] = u[1]^2; nothing) with
# du_shadow = [1.0] (incoming adjoint),
# u[1] = 3.0,
# ∂du[1]/∂u[1] = 2*u[1] = 6,
# the expected result is u_shadow[1] = 6.0 after the call.
f!(du, u) = (du[1] = u[1]^2; nothing)
fww = FunctionWrappersWrapper(
f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,)
)

du = [0.0]
u = [3.0]
du_shadow = [1.0]
u_shadow = [0.0]

Enzyme.autodiff(
Reverse, Const(fww), Const,
Duplicated(du, du_shadow), Duplicated(u, u_shadow)
)
@test du[1] ≈ 9.0 # primal effect: du[1] = u[1]^2 = 9
@test u_shadow[1] ≈ 6.0 # reverse accumulation: 2 * u[1] * du_shadow[1]
end

@testset "Enzyme Reverse: IIP multi-component IIP with Duplicated args" begin
# Cross-coupled IIP RHS: each output depends on multiple inputs.
# du[1] = u[1] * u[2]
# du[2] = u[1]^2 + u[2]^3
# Jacobian at u = (x, y):
# J = [ y x ;
# 2x 3y^2 ]
# In reverse mode with du_shadow = [a, b], transpose of J applied to
# du_shadow gives the accumulation into u_shadow:
# u_shadow[1] += a*y + b*2x
# u_shadow[2] += a*x + b*3y^2
f!(du, u) = (du[1] = u[1]*u[2]; du[2] = u[1]^2 + u[2]^3; nothing)
fww = FunctionWrappersWrapper(
f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,)
)

x, y = 2.0, 5.0
a, b = 1.0, 0.5
du = zeros(2)
u = [x, y]
du_shadow = [a, b]
u_shadow = zeros(2)

Enzyme.autodiff(
Reverse, Const(fww), Const,
Duplicated(du, du_shadow), Duplicated(u, u_shadow)
)
@test du ≈ [x*y, x^2 + y^3]
@test u_shadow[1] ≈ a*y + b*2*x # 5 + 2 = 7
@test u_shadow[2] ≈ a*x + b*3*y^2 # 2 + 37.5 = 39.5
end

@testset "Enzyme ReverseWithPrimal: IIP with Duplicated args" begin
# Same IIP pattern but with ReverseWithPrimal so we also check the primal
# is available when the rule is asked to include it.
f!(du, u) = (du[1] = u[1]^3; nothing)
fww = FunctionWrappersWrapper(
f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,)
)

du = [0.0]
u = [2.0]
du_shadow = [1.0]
u_shadow = [0.0]

# Capture the expected gradient BEFORE the call — Enzyme may zero
# `du_shadow` after consuming it during the reverse pass.
expected_u_grad = 3 * u[1]^2 * du_shadow[1] # = 12.0

Enzyme.autodiff(
ReverseWithPrimal, Const(fww), Const,
Duplicated(du, du_shadow), Duplicated(u, u_shadow)
)
@test du[1] ≈ 8.0
@test u_shadow[1] ≈ expected_u_grad
end

Loading