Use native gradient API for ForwardDiff, Enzyme, Mooncake#458

Open

yebai wants to merge 11 commits into main from native-ad-extensions

Conversation

@yebai (Member) commented Apr 12, 2026

- Remove DifferentiationInterface from [deps]; add ADTypes
- Move Enzyme to [weakdeps]; add BijectorsEnzymeExt extension
- Add src/ad_utils.jl defining _value_and_gradient/_value_and_jacobian generic functions (see the sketch below)
- Implement native backends in each pkg ext: ForwardDiff, ReverseDiff (compiled + non-compiled), Mooncake (reverse + forward JVP), Enzyme (reverse + forward)
- Update src/vector/test_utils.jl to use ADTypes backend types and B._value_and_* API
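For orientation, here is a minimal sketch of what the generic API plus one extension method could look like (illustrative only; the method layout, error message, and exact signatures in the PR may differ):

```julia
using ADTypes: AbstractADType, AutoForwardDiff
using ForwardDiff, DiffResults

# Generic entry point defined in the main package; concrete methods are added
# by the package extensions for each AD backend.
function _value_and_gradient(f, backend::AbstractADType, x::AbstractVector)
    return error(
        "No _value_and_gradient method is loaded for backend $(typeof(backend)); " *
        "load the corresponding AD package first.",
    )
end

# Example extension method (ForwardDiff): value and gradient in a single pass
# via DiffResults, instead of calling f(x) and then the gradient separately.
function _value_and_gradient(f, ::AutoForwardDiff, x::AbstractVector)
    result = DiffResults.GradientResult(x)
    ForwardDiff.gradient!(result, f, x)
    return DiffResults.value(result), DiffResults.gradient(result)
end
```

Usage would then look like `_value_and_gradient(x -> sum(abs2, x), AutoForwardDiff(), randn(3))`.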

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@github-actions (Contributor)

Bijectors.jl documentation for PR #458 is available at:
https://TuringLang.github.io/Bijectors.jl/previews/PR458/

yebai and others added 7 commits April 12, 2026 21:21
- Avoid double f(x) evaluation in gradient/jacobian for ForwardDiff and ReverseDiff
  by using DiffResults (GradientResult, JacobianResult) with the in-place ! variants
- For Enzyme reverse mode, use autodiff(ReverseWithPrimal, ...) to get value and
  gradient in one pass instead of calling f(x) separately (see the sketch after this list)
- Fix _enzyme_mode to guard against mode=nothing (AutoEnzyme() default) which
  previously threw a MethodError from set_runtime_activity(::Nothing)
- Pre-allocate dy/dx tangent buffers outside loops in Mooncake implementations and
  use fill! to zero them, avoiding one heap allocation per iteration
- Add fallback _value_and_gradient/_value_and_jacobian methods with a clear error
  message for backends without a loaded extension
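To make the ReverseWithPrimal item above concrete, a rough sketch of the single-pass pattern (illustrative only; the actual extension also handles function annotations and runtime-activity settings omitted here):

```julia
using Enzyme

f(x) = sum(abs2, x)
x = randn(3)

dx = Enzyme.make_zero(x)   # gradient buffer, filled in-place by the reverse pass
ret = Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(x, dx))
value = ret[2]             # primal f(x) from the same pass, no second evaluation
grad = dx
```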

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add AutoEnzyme{Nothing} to the ReverseWithPrimal dispatch union so the
  default (mode=nothing) backend also avoids double-evaluating f
- Remove redundant `return` before `error(...)` in ad_utils.jl fallback
  methods; error() returns Union{} so return is a no-op

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add return before error() in ad_utils.jl (JuliaFormatter)
- Use ReverseDiff.DiffResults instead of ForwardDiff.DiffResults so the
  extension triggers on ReverseDiff alone
- Keep Bijectors in test/Project.toml for B._value_and_jacobian calls

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@yebai force-pushed the native-ad-extensions branch from b3577f0 to 669d4a8 on April 12, 2026 23:15

@gdalle left a comment


There seems to be a lot of duplication of effort compared to DI, along with some forgotten aspects. All in all, I'm not sure what we gain here.

Comment thread ext/BijectorsEnzymeExt.jl
}

function _annotate_function(f, backend::AutoEnzyme, mode)
    annotation = typeof(backend).parameters[2]

Accessing type parameters this way is not recommended since the field is internal (AFAICT)
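One possible alternative (a sketch, assuming the annotation is the second type parameter of `ADTypes.AutoEnzyme{M,A}` as in current ADTypes; the helper name is hypothetical) is to recover it by dispatch instead of indexing into `typeof(backend).parameters`:

```julia
using ADTypes: AutoEnzyme

# Hypothetical helper: read the function-annotation type parameter via dispatch.
_function_annotation(::AutoEnzyme{<:Any,A}) where {A} = A

# Inside _annotate_function this would replace typeof(backend).parameters[2]:
# annotation = _function_annotation(backend)
```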

Comment thread ext/BijectorsEnzymeExt.jl
    EnzymeCore.Duplicated, EnzymeCore.DuplicatedNoNeed, EnzymeCore.MixedDuplicated
}

function _annotate_function(f, backend::AutoEnzyme, mode)

Comment thread ext/BijectorsEnzymeExt.jl
    backend::Union{AutoEnzyme{Nothing},AutoEnzyme{<:EnzymeCore.ReverseMode}},
    x::AbstractVector,
)
    mode = if backend isa AutoEnzyme{Nothing}

Comment thread ext/BijectorsEnzymeExt.jl
Comment on lines +54 to +62
    for i in eachindex(x)
        dx = zero(x)
        dx[i] = one(eltype(x))
        directional, primal = Enzyme.autodiff(mode, annotated_f, Enzyme.Duplicated(x, dx))
        grad[i] = directional
        if i == firstindex(x)
            value = primal
        end
    end

Enzyme has a built-in forward-mode gradient function, which DI already uses in such cases. Any reason not to use it here too?
Ping @wsmoses

Collaborator:

++
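For reference, a sketch of what using the built-in helper could look like (the exact return shape of `Enzyme.gradient` has changed between Enzyme versions, so treat this as illustrative):

```julia
using Enzyme

f(x) = sum(abs2, x)
x = randn(3)

# Built-in forward-mode gradient instead of a hand-written loop over basis vectors;
# in recent Enzyme versions the result is a tuple with one entry per argument.
grad = Enzyme.gradient(Enzyme.Forward, f, x)[1]
```

Recent Enzyme versions also have a ForwardWithPrimal mode that returns the primal value from the same call, which would avoid the separate handling of `value` in the loop above.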

Comment thread ext/BijectorsEnzymeExt.jl
Comment on lines +89 to +98
    for i in eachindex(x)
        dx = zero(x)
        dx[i] = one(eltype(x))
        directional, primal = Enzyme.autodiff(mode, annotated_f, Enzyme.Duplicated(x, dx))
        if i == firstindex(x)
            value = primal isa AbstractArray ? copy(primal) : primal
            J = Matrix{eltype(directional)}(undef, length(directional), length(x))
        end
        J[:, i] .= directional
    end

Enzyme has a built-in forward Jacobian function, which DI already uses in such cases. Any reason not to use it here too?
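Likewise, a sketch of the built-in forward-mode Jacobian (illustrative; the return shape depends on the Enzyme version):

```julia
using Enzyme

f(x) = [sum(abs2, x), prod(x)]
x = randn(3)

# Built-in forward-mode Jacobian instead of the manual column-by-column loop;
# in recent Enzyme versions the result is a tuple with one entry per argument.
J = Enzyme.jacobian(Enzyme.Forward, f, x)[1]
```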

Comment on lines +17 to +20
    if T === Nothing
        ForwardDiff.checktag(config, f, x)
    end
    ForwardDiff.gradient!(result, f, x, config, Val(false))

function _mooncake_zero_tangent_or_primal(
    x, backend::Union{AutoMooncake,AutoMooncakeForward}
)
    if _mooncake_config(backend).friendly_tangents

This is type-unstable
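One generic way to avoid this kind of instability (a self-contained, hypothetical illustration rather than the PR's actual code, since the surrounding extension code isn't shown here) is to lift the runtime flag into the type domain once and dispatch on it, so each method has a single concrete return type:

```julia
struct Cfg
    friendly_tangents::Bool
end

# Type-unstable: the return type depends on a runtime field, so callers infer a Union.
unstable(x, cfg::Cfg) = cfg.friendly_tangents ? zero(x) : 0.0

# Function-barrier version: one dynamic dispatch on Val here, after which each
# method has a single, inferrable return type.
stable(x, cfg::Cfg) = _stable(x, Val(cfg.friendly_tangents))
_stable(x, ::Val{true}) = zero(x)
_stable(x, ::Val{false}) = 0.0
```

If `friendly_tangents` is known when the backend is constructed, encoding it as a type parameter of the config would avoid even that single dynamic dispatch.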

        return f(x), similar(x, 0)
    end
    tape = ReverseDiff.GradientTape(f, x)
    compiled = ReverseDiff.compile(tape)

Is it really worth compiling a tape you will only use once? I predict this slows things down significantly
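For a one-off evaluation, a sketch of two alternatives that skip compilation (whether `AutoReverseDiff(compile=true)` should instead cache and reuse the compiled tape across calls is a separate design question):

```julia
using ReverseDiff, DiffResults

f(x) = sum(abs2, x)
x = randn(3)
result = DiffResults.GradientResult(x)

# Option 1: record the tape and evaluate it without compiling.
tape = ReverseDiff.GradientTape(f, x)
ReverseDiff.gradient!(result, tape, x)

# Option 2: no explicit tape at all for a single call.
ReverseDiff.gradient!(result, f, x)

value, grad = DiffResults.value(result), DiffResults.gradient(result)
```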


function _value_and_jacobian(f, ::AutoReverseDiff{true}, x::AbstractVector)
    tape = ReverseDiff.JacobianTape(f, x)
    compiled = ReverseDiff.compile(tape)

Same here

Comment thread src/ad_utils.jl

Implementations are provided by package extensions for each AD backend.
"""
function _value_and_gradient(f, backend::ADTypes.AbstractADType, x::AbstractVector)

This is breaking
