diff --git a/ext/FunctionWrappersWrappersEnzymeExt.jl b/ext/FunctionWrappersWrappersEnzymeExt.jl index 44975e5..85a76df 100644 --- a/ext/FunctionWrappersWrappersEnzymeExt.jl +++ b/ext/FunctionWrappersWrappersEnzymeExt.jl @@ -59,13 +59,16 @@ function EnzymeRules.forward( return f_orig(pargs...) end -# Neither primal nor shadow requested — Enzyme asks for this combo with Const -# return-type annotations where the caller only needs the side effects of the -# primal invocation (e.g. mutating an IIP RHS in SciML's solver path). No rule -# previously matched this case, so dispatch fell through to Enzyme's default -# path which tried to differentiate through the raw FunctionWrappersWrapper -# and failed with `MethodError: no method matching forward(…)` when the wrapper -# only held plain-Float64 signatures. Just run the primal and return nothing. +# Neither primal nor shadow requested in the RETURN. Enzyme dispatches on +# this combo for IIP functions (Const return type) where the caller still +# needs primal and shadow propagation through `Duplicated` args — e.g. SciML +# solvers calling an IIP RHS via `AutoEnzyme(…, function_annotation = Const)`. +# The previous revision ran `f_orig(pargs...)` by hand; that mutated the +# primal IIP buffer but left `Duplicated` shadow buffers untouched, giving +# trivial Jacobians and blowing up Rodas4/5/Veldd4 error tolerances 4–9 +# orders of magnitude in OrdinaryDiffEq.jl v7. Delegate to `Enzyme.autodiff` +# on the unwrapped function with a Const return annotation so the Duplicated +# arg shadows are propagated correctly and no return is produced. function EnzymeRules.forward( ::EnzymeRules.FwdConfig{false, false, W, RuntimeActivity, StrongZero}, func::EnzymeCore.Const{<:FunctionWrappersWrapper}, @@ -73,8 +76,7 @@ function EnzymeRules.forward( args::Vararg{EnzymeCore.Annotation, N} ) where {W, N, RuntimeActivity, StrongZero} f_orig = unwrap(func.val) - pargs = ntuple(i -> args[i].val, Val(N)) - f_orig(pargs...) + Enzyme.autodiff(Forward, Const(f_orig), Const, args...) return nothing end @@ -208,26 +210,37 @@ function EnzymeRules.reverse( end end -# Const return (no derivative to propagate from the return) — uniform Active args. +# Const return — Enzyme passes the RT as a `Type{<:Const}` to `reverse`, not +# as an instance. Delegate the reverse pass to +# `Enzyme.autodiff(Reverse, Const(f_orig), Const, args...)` so gradients +# accumulate into any `Duplicated` arg shadow buffers (the SciML IIP +# pattern). Simply returning `nothing` left Duplicated shadows at zero. +# +# Per Enzyme's rule return-type protocol, `Active` args require a concrete +# scalar gradient (not `nothing`). Under a `Const` return there is no +# gradient source, so Active arg gradients are zero. `Duplicated` / +# `BatchDuplicated` args return `nothing` because their gradients are +# accumulated in-place by the `Enzyme.autodiff(Reverse, …)` call above. function EnzymeRules.reverse( config::EnzymeRules.RevConfig, func::EnzymeCore.Const{<:FunctionWrappersWrapper}, - dret::EnzymeCore.Const, - tape, - args::Vararg{EnzymeCore.Active, N} -) where {N} - return ntuple(_ -> nothing, Val(N)) -end - -# Const return — mixed Active/Const args. -function EnzymeRules.reverse( - config::EnzymeRules.RevConfig, - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, - dret::EnzymeCore.Const, + dret::Type{<:EnzymeCore.Const}, tape, args::Vararg{EnzymeCore.Annotation, N} ) where {N} - return ntuple(_ -> nothing, Val(N)) + f_orig = unwrap(func.val) + # Only worth invoking Enzyme.autodiff when at least one arg is + # Duplicated/BatchDuplicated — otherwise there's nothing to accumulate. + if any(a -> a isa EnzymeCore.Duplicated || a isa EnzymeCore.BatchDuplicated, args) + Enzyme.autodiff(Reverse, Const(f_orig), Const, args...) + end + return ntuple(Val(N)) do i + if args[i] isa EnzymeCore.Active + zero(eltype(typeof(args[i]))) + else + nothing + end + end end end diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index e9ecc4d..c5713dc 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -76,15 +76,17 @@ end @test result_wp[2] ≈ 9.0 # primal f(3) = 9 end -@testset "Enzyme forward mode, neither primal nor shadow requested" begin - # Covers EnzymeRules.FwdConfig{false, false, W, ...}: caller wants only the - # side-effects of the primal invocation, no return value and no derivative. - # Reproduces the SciML/OrdinaryDiffEq.jl v7 Downstream regression where - # Enzyme dispatched on this config combination with a FWW wrapping an IIP - # RHS and found no matching rule, throwing - # MethodError: no method matching forward( - # ::FwdConfigWidth{1, false, false, false, false}, - # ::Const{<:FunctionWrappersWrapper}, ::Type{Const{Nothing}}, …) +@testset "Enzyme forward mode, Const return (IIP, no return-shadow)" begin + # Covers EnzymeRules.FwdConfig{false, false, W, ...} — Enzyme dispatches on + # this combo for IIP functions with a Const return type where the caller + # needs primal + shadow propagation via Duplicated args only (no return + # value to shadow). Reproduces the SciML/OrdinaryDiffEq.jl v7 Downstream + # regression where this call previously produced: + # - without any rule: MethodError: no method matching forward(…) + # - with a primal-only rule: trivial (zero) arg shadows, wrong Jacobians + # (Rodas4/5/Veldd4 errors 4–9 orders of magnitude above tolerance). + # The rule must delegate to `Enzyme.autodiff` on the unwrapped function + # so Duplicated arg shadows propagate correctly. f!(du, u) = (du[1] = -u[1]^2; nothing) fww = FunctionWrappersWrapper( f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,) @@ -93,21 +95,18 @@ end du = [0.0] u = [3.0] du_shadow = [0.0] - u_shadow = [1.0] + u_shadow = [1.0] # seed: ∂/∂u[1] = 1 - # Call forward directly with {false, false}: Enzyme's public-facing - # autodiff front-end doesn't normally expose this config, so invoke the - # rule by hand. config = EnzymeCore.EnzymeRules.FwdConfig{false, false, 1, false, false}() ret = EnzymeCore.EnzymeRules.forward( config, Const(fww), EnzymeCore.Const{Nothing}, Duplicated(du, du_shadow), Duplicated(u, u_shadow) ) @test ret === nothing - # primal side-effect did happen: f!(du, u) sets du[1] = -u[1]^2 = -9 + # Primal side-effect: du[1] = -u[1]^2 = -9 @test du[1] ≈ -9.0 - # shadow buffer was not touched by this no-diff path - @test du_shadow[1] == 0.0 + # Shadow propagation: ∂du[1]/∂u[1] * u_shadow[1] = -2*u[1]*1 = -6 + @test du_shadow[1] ≈ -6.0 end @testset "Enzyme reverse mode, Const return — augmented_primal runs primal" begin @@ -133,12 +132,14 @@ end @test aug.shadow === nothing @test aug.tape === nothing - # Reverse step — dret is Const, no grads to accumulate. + # Reverse step — dret is Const (passed as TYPE not instance in reverse + # rules). Enzyme's rule protocol requires concrete gradients for Active + # args; under a Const return they're zero (no gradient source). grads = EnzymeRules.reverse( - rconfig, Const(fww), EnzymeCore.Const{Float64}(0.0), + rconfig, Const(fww), EnzymeCore.Const{Float64}, aug.tape, Active(3.0), Active(4.0) ) - @test grads == (nothing, nothing) + @test grads == (0.0, 0.0) end @testset "Enzyme reverse mode, Duplicated return — augmented_primal initializes shadow" begin @@ -172,3 +173,107 @@ end @test aug.tape === nothing end +# ============================================================================= +# End-to-end reverse-mode derivative tests — exercise Enzyme.autodiff(Reverse, +# …) through the FWW and assert the resulting gradients are numerically correct. +# The prior reverse-mode testsets only checked dispatch / shape of +# AugmentedReturn; they did NOT verify the gradients are right. +# ============================================================================= + +@testset "Enzyme Reverse: Const return, Active args — no-flow gradients" begin + # For a function whose return is annotated Const in Reverse mode, there is + # no gradient source from the return, so Active arg gradients must be 0. + # (Enzyme's rule-return protocol requires concrete gradients for Active + # args — `nothing` is not allowed — so the rule returns zeros.) + g(x, y) = x * y + x^2 + fww = FunctionWrappersWrapper(g, (Tuple{Float64, Float64},), (Float64,)) + + # Const return (instead of Active) → no gradient flows back + result = Enzyme.autodiff(Reverse, Const(fww), Const, Active(3.0), Active(4.0)) + @test result[1] === (0.0, 0.0) +end + +@testset "Enzyme Reverse: IIP with Duplicated args, Const return" begin + # SciML's standard pattern: IIP RHS `f!(du, u)` with Const return, both du + # and u are Duplicated. Reverse mode should accumulate + # u_shadow[i] += du_shadow[j] * ∂(du[j])/∂(u[i]) + # into u_shadow. For f!(du, u) = (du[1] = u[1]^2; nothing) with + # du_shadow = [1.0] (incoming adjoint), + # u[1] = 3.0, + # ∂du[1]/∂u[1] = 2*u[1] = 6, + # the expected result is u_shadow[1] = 6.0 after the call. + f!(du, u) = (du[1] = u[1]^2; nothing) + fww = FunctionWrappersWrapper( + f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,) + ) + + du = [0.0] + u = [3.0] + du_shadow = [1.0] + u_shadow = [0.0] + + Enzyme.autodiff( + Reverse, Const(fww), Const, + Duplicated(du, du_shadow), Duplicated(u, u_shadow) + ) + @test du[1] ≈ 9.0 # primal effect: du[1] = u[1]^2 = 9 + @test u_shadow[1] ≈ 6.0 # reverse accumulation: 2 * u[1] * du_shadow[1] +end + +@testset "Enzyme Reverse: IIP multi-component IIP with Duplicated args" begin + # Cross-coupled IIP RHS: each output depends on multiple inputs. + # du[1] = u[1] * u[2] + # du[2] = u[1]^2 + u[2]^3 + # Jacobian at u = (x, y): + # J = [ y x ; + # 2x 3y^2 ] + # In reverse mode with du_shadow = [a, b], transpose of J applied to + # du_shadow gives the accumulation into u_shadow: + # u_shadow[1] += a*y + b*2x + # u_shadow[2] += a*x + b*3y^2 + f!(du, u) = (du[1] = u[1]*u[2]; du[2] = u[1]^2 + u[2]^3; nothing) + fww = FunctionWrappersWrapper( + f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,) + ) + + x, y = 2.0, 5.0 + a, b = 1.0, 0.5 + du = zeros(2) + u = [x, y] + du_shadow = [a, b] + u_shadow = zeros(2) + + Enzyme.autodiff( + Reverse, Const(fww), Const, + Duplicated(du, du_shadow), Duplicated(u, u_shadow) + ) + @test du ≈ [x*y, x^2 + y^3] + @test u_shadow[1] ≈ a*y + b*2*x # 5 + 2 = 7 + @test u_shadow[2] ≈ a*x + b*3*y^2 # 2 + 37.5 = 39.5 +end + +@testset "Enzyme ReverseWithPrimal: IIP with Duplicated args" begin + # Same IIP pattern but with ReverseWithPrimal so we also check the primal + # is available when the rule is asked to include it. + f!(du, u) = (du[1] = u[1]^3; nothing) + fww = FunctionWrappersWrapper( + f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,) + ) + + du = [0.0] + u = [2.0] + du_shadow = [1.0] + u_shadow = [0.0] + + # Capture the expected gradient BEFORE the call — Enzyme may zero + # `du_shadow` after consuming it during the reverse pass. + expected_u_grad = 3 * u[1]^2 * du_shadow[1] # = 12.0 + + Enzyme.autodiff( + ReverseWithPrimal, Const(fww), Const, + Duplicated(du, du_shadow), Duplicated(u, u_shadow) + ) + @test du[1] ≈ 8.0 + @test u_shadow[1] ≈ expected_u_grad +end +