diff --git a/Project.toml b/Project.toml
index 542a1db680..8f54d9d68a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,6 +23,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -31,6 +32,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 [extensions]
 EnzymeBFloat16sExt = "BFloat16s"
 EnzymeChainRulesCoreExt = "ChainRulesCore"
+EnzymeFunctionWrappersExt = "FunctionWrappers"
 EnzymeGPUArraysCoreExt = "GPUArraysCore"
 EnzymeLogExpFunctionsExt = "LogExpFunctions"
 EnzymeSpecialFunctionsExt = "SpecialFunctions"
@@ -42,6 +44,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5, 0.6"
 CEnum = "0.4, 0.5"
 ChainRulesCore = "1"
 EnzymeCore = "0.8.16"
 Enzyme_jll = "0.0.249"
+FunctionWrappers = "1.1"
 GPUArraysCore = "0.1.6, 0.2"
 GPUCompiler = "1.6.2"
@@ -59,6 +62,7 @@ julia = "1.10"
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
diff --git a/ext/EnzymeFunctionWrappersExt.jl b/ext/EnzymeFunctionWrappersExt.jl
new file mode 100644
index 0000000000..04abc32efe
--- /dev/null
+++ b/ext/EnzymeFunctionWrappersExt.jl
@@ -0,0 +1,229 @@
+module EnzymeFunctionWrappersExt
+
+using FunctionWrappers: FunctionWrapper
+using EnzymeCore
+using EnzymeCore.EnzymeRules
+using Enzyme
+
+# Helper to extract the raw function from a FunctionWrapper by dereferencing
+# the reference held in its `obj` field.
+@inline unwrap_fw(fw::FunctionWrapper) = fw.obj[]
+
+# Helper to reconstruct an annotation with a cached primal value. The cached
+# copy replaces the primal only when the argument was overwritten and a copy
+# was taken; the shadow (`dval`) is always the one supplied by the caller.
+@inline _reconstruct_arg(arg::Const, cached, overwritten::Bool) = arg
+@inline function _reconstruct_arg(arg::Duplicated, cached, overwritten::Bool)
+    return overwritten && cached !== nothing ? Duplicated(cached, arg.dval) : arg
+end
+@inline function _reconstruct_arg(arg::BatchDuplicated, cached, overwritten::Bool)
+    return overwritten && cached !== nothing ? BatchDuplicated(cached, arg.dval) : arg
+end
+@inline _reconstruct_arg(arg::Active, cached, overwritten::Bool) = arg
+
+# Helper for type-stable reverse return values: an `Active` argument yields its
+# gradient, every other annotation yields `nothing`.
+@inline _reverse_val(::Active{T}, grad) where {T} = grad::T
+@inline _reverse_val(::Const, grad) = nothing
+@inline _reverse_val(::Duplicated, grad) = nothing
+@inline _reverse_val(::BatchDuplicated, grad) = nothing
+
+# ---------------------------------------------------------------------------
+# Forward mode rule
+# ---------------------------------------------------------------------------
+# Single rule for both IIP (Nothing return) and OOP FunctionWrappers.
+# Extracts the wrapped function and delegates to autodiff_deferred.
+function EnzymeRules.forward(
+    config::EnzymeRules.FwdConfig,
+    func::Const{<:FunctionWrapper},
+    RT::Type{<:Annotation},
+    args::Annotation...,
+)
+    raw_f = unwrap_fw(func.val)
+
+    # For IIP functions (Const{Nothing} return), needs_shadow is false but we
+    # still must propagate tangents into argument shadow arrays via AD.
+    if RT <: Const
+        if EnzymeRules.needs_primal(config)
+            # Obtain the primal from the AD call itself. Calling `raw_f` again
+            # afterwards would execute a mutating (IIP) function twice.
+            res = Enzyme.autodiff_deferred(
+                ForwardWithPrimal, Const(raw_f), Const{eltype(RT)}, args...
+            )
+            # autodiff ForwardWithPrimal places the primal last in its result.
+            return last(res)
+        else
+            # IIP or inactive return: run AD only for the tangent propagation
+            # into the argument shadows.
+            Enzyme.autodiff_deferred(Forward, Const(raw_f), Const{eltype(RT)}, args...)
+            return nothing
+        end
+    end
+
+    # OOP: shadow is needed. Always use Duplicated for autodiff_deferred
+    # (it rejects DuplicatedNoNeed).
+    RealRt = eltype(RT)
+    if EnzymeRules.needs_primal(config)
+        res = Enzyme.autodiff_deferred(ForwardWithPrimal, Const(raw_f), Duplicated, args...)
+        # autodiff ForwardWithPrimal returns (derivs, primal)
+        if EnzymeRules.width(config) == 1
+            return Duplicated(res[2]::RealRt, res[1]::RealRt)
+        else
+            return BatchDuplicated(res[2]::RealRt, res[1]::NTuple{EnzymeRules.width(config),RealRt})
+        end
+    else
+        res = Enzyme.autodiff_deferred(Forward, Const(raw_f), Duplicated, args...)
+        # autodiff Forward returns (derivs,)
+        if EnzymeRules.width(config) == 1
+            return res[1]::RealRt
+        else
+            return res[1]::NTuple{EnzymeRules.width(config),RealRt}
+        end
+    end
+end
+
+# ---------------------------------------------------------------------------
+# Reverse mode rules
+# ---------------------------------------------------------------------------
+
+# augmented_primal: execute the forward pass, cache data for reverse
+function EnzymeRules.augmented_primal(
+    config::EnzymeRules.RevConfig,
+    func::Const{<:FunctionWrapper{Ret}},
+    RT::Type{<:Annotation},
+    args::Annotation...,
+) where {Ret}
+    raw_f = unwrap_fw(func.val)
+    ow = EnzymeRules.overwritten(config)
+    nargs = length(args)
+
+    # Cache copies of overwritten mutable args (needed for reverse pass)
+    cached_args = ntuple(Val(nargs)) do i
+        Base.@_inline_meta
+        # ow[1] is the function itself, ow[i+1] is the i-th argument
+        if ow[i + 1] && !(args[i] isa Const)
+            deepcopy(args[i].val)
+        else
+            nothing
+        end
+    end
+
+    # Execute the primal
+    primal_result = raw_f(map(x -> x.val, args)...)
+
+    primal = if EnzymeRules.needs_primal(config)
+        primal_result
+    else
+        nothing
+    end
+
+    shadow = if EnzymeRules.needs_shadow(config)
+        if Ret === Nothing
+            nothing
+        else
+            if EnzymeRules.width(config) == 1
+                Enzyme.make_zero(primal_result)
+            else
+                ntuple(Val(EnzymeRules.width(config))) do j
+                    Base.@_inline_meta
+                    Enzyme.make_zero(primal_result)
+                end
+            end
+        end
+    else
+        nothing
+    end
+
+    tape = (raw_f, cached_args)
+    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
+end
+
+# reverse for IIP (Nothing return): accumulate gradients into dval arrays
+function EnzymeRules.reverse(
+    config::EnzymeRules.RevConfig,
+    func::Const{<:FunctionWrapper{Nothing}},
+    ::Type{<:Const{Nothing}},
+    tape,
+    args::Annotation...,
+)
+    raw_f, cached_args = tape
+    ow = EnzymeRules.overwritten(config)
+    nargs = length(args)
+
+    new_args = ntuple(Val(nargs)) do i
+        Base.@_inline_meta
+        _reconstruct_arg(args[i], cached_args[i], ow[i + 1])
+    end
+
+    Enzyme.autodiff_deferred(Reverse, Const(raw_f), Const{Nothing}, new_args...)
+
+    return ntuple(Val(nargs)) do i
+        Base.@_inline_meta
+        nothing
+    end
+end
+
+# reverse for OOP with Active return: return per-arg gradients scaled by the
+# incoming adjoint `dret.val`
+function EnzymeRules.reverse(
+    config::EnzymeRules.RevConfig,
+    func::Const{<:FunctionWrapper{Ret}},
+    dret::Active,
+    tape,
+    args::Annotation...,
+) where {Ret}
+    raw_f, cached_args = tape
+    ow = EnzymeRules.overwritten(config)
+    nargs = length(args)
+
+    new_args = ntuple(Val(nargs)) do i
+        Base.@_inline_meta
+        _reconstruct_arg(args[i], cached_args[i], ow[i + 1])
+    end
+
+    # Differentiate `seed * raw_f(...)` rather than `raw_f` itself so that the
+    # incoming adjoint scales every gradient, including the contributions
+    # accumulated in place into Duplicated shadows. Scaling only the returned
+    # tuple afterwards would leave those shadows seeded with 1 instead.
+    scaled_f = let f = raw_f, seed = dret.val
+        (xs...) -> seed * f(xs...)
+    end
+
+    # autodiff_deferred(Reverse, ..., Active, args...) returns ((grad1, grad2, ...),)
+    res = Enzyme.autodiff_deferred(Reverse, Const(scaled_f), Active, new_args...)
+    grads = res[1]
+
+    return ntuple(Val(nargs)) do i
+        Base.@_inline_meta
+        _reverse_val(args[i], grads[i])
+    end
+end
+
+# reverse for OOP with Duplicated/Const return type (non-Active)
+function EnzymeRules.reverse(
+    config::EnzymeRules.RevConfig,
+    func::Const{<:FunctionWrapper{Ret}},
+    dret::Type{<:Annotation},
+    tape,
+    args::Annotation...,
+) where {Ret}
+    if !(dret <: Const)
+        raw_f, cached_args = tape
+        ow = EnzymeRules.overwritten(config)
+        nargs = length(args)
+
+        new_args = ntuple(Val(nargs)) do i
+            Base.@_inline_meta
+            _reconstruct_arg(args[i], cached_args[i], ow[i + 1])
+        end
+
+        Enzyme.autodiff_deferred(Reverse, Const(raw_f), dret, new_args...)
+    end
+
+    return ntuple(Val(length(args))) do i
+        Base.@_inline_meta
+        nothing
+    end
+end
+
+end # module
diff --git a/test/Project.toml b/test/Project.toml
index 87bbd781dd..c7fe12b0d0 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -7,6 +7,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
 Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
 GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
 InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
 JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
diff --git a/test/ext/functionwrappers.jl b/test/ext/functionwrappers.jl
new file mode 100644
index 0000000000..94b4e9c83c
--- /dev/null
+++ b/test/ext/functionwrappers.jl
@@ -0,0 +1,87 @@
+using Enzyme, Test
+using FunctionWrappers: FunctionWrapper
+
+@testset "FunctionWrappers Extension" begin
+
+    # In-place (IIP) test function: du[1] = p[1] * u[1]^2
+    f!(du, u, p) = (du[1] = p[1] * u[1]^2; nothing)
+
+    # Out-of-place (OOP) test function: returns p[1] * x^2
+    f_oop(x, p) = p[1] * x^2
+
+    @testset "IIP Forward Mode" begin
+        fw = FunctionWrapper{Nothing,Tuple{Vector{Float64},Vector{Float64},Vector{Float64}}}(f!)
+
+        u = [2.0]; du = zeros(1); p = [3.0]
+        ddu = zeros(1); du_u = [1.0]
+
+        # Differentiate through FunctionWrapper
+        Enzyme.autodiff(Forward, fw, Const{Nothing},
+            Duplicated(du, ddu), Duplicated(u, du_u), Const(p))
+
+        # Compare with raw function
+        u2 = [2.0]; du2 = zeros(1); ddu2 = zeros(1); du_u2 = [1.0]
+        Enzyme.autodiff(Forward, f!, Const{Nothing},
+            Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p))
+
+        @test ddu ≈ ddu2
+        # ddu[1] should be d/du(p*u^2) * du_u = 3.0 * 2 * 2.0 * 1.0 = 12.0
+        @test ddu[1] ≈ 12.0
+    end
+
+    @testset "IIP Reverse Mode" begin
+        fw = FunctionWrapper{Nothing,Tuple{Vector{Float64},Vector{Float64},Vector{Float64}}}(f!)
+
+        u = [2.0]; du = zeros(1); p = [3.0]
+        ddu = [1.0]; du_u = zeros(1)
+
+        Enzyme.autodiff(Reverse, fw, Const{Nothing},
+            Duplicated(du, ddu), Duplicated(u, du_u), Const(p))
+
+        # Compare with raw function
+        u2 = [2.0]; du2 = zeros(1); ddu2 = [1.0]; du_u2 = zeros(1)
+        Enzyme.autodiff(Reverse, f!, Const{Nothing},
+            Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p))
+
+        @test du_u ≈ du_u2
+        # du/du[1] of (du[1] = p[1]*u[1]^2) with seed ddu[1]=1.0:
+        # = p[1] * 2 * u[1] = 3.0 * 2 * 2.0 = 12.0
+        @test du_u[1] ≈ 12.0
+    end
+
+    @testset "OOP Forward Mode" begin
+        fw_oop = FunctionWrapper{Float64,Tuple{Float64,Vector{Float64}}}(f_oop)
+
+        x = 3.0; p = [2.0]
+        dx = 1.0
+
+        res = Enzyme.autodiff(Forward, fw_oop, Duplicated,
+            Duplicated(x, dx), Const(p))
+
+        # Compare with raw function
+        res2 = Enzyme.autodiff(Forward, f_oop, Duplicated,
+            Duplicated(x, dx), Const(p))
+
+        @test res[1] ≈ res2[1]
+        # d/dx(p*x^2) = 2*p*x = 2*2.0*3.0 = 12.0
+        @test res[1] ≈ 12.0
+    end
+
+    @testset "OOP Reverse Mode" begin
+        fw_oop = FunctionWrapper{Float64,Tuple{Float64,Vector{Float64}}}(f_oop)
+
+        x = 3.0; p = [2.0]
+
+        res = Enzyme.autodiff(Reverse, fw_oop, Active,
+            Active(x), Const(p))
+
+        # Compare with raw function
+        res2 = Enzyme.autodiff(Reverse, f_oop, Active,
+            Active(x), Const(p))
+
+        @test res[1][1] ≈ res2[1][1]
+        # d/dx(p*x^2) = 2*p*x = 2*2.0*3.0 = 12.0
+        @test res[1][1] ≈ 12.0
+    end
+
+end