Skip to content

Commit 137a3f2

Browse files
Fix reverse rules: delegate via Enzyme.autodiff, return zeros for Active args
The reverse rules from #43 had two bugs exposed by new end-to-end tests: 1. Const-dret reverse returned `nothing` per arg, but Enzyme's rule protocol requires concrete scalar gradients for Active args (not nothing). Fixed to return `zero(T)` for Active args and `nothing` for Duplicated/Const args. 2. IIP reverse with Duplicated args (SciML pattern) returned nothing and never propagated gradients into the Duplicated shadow buffers. Fixed by delegating to `Enzyme.autodiff(Reverse, Const(f_orig), Const, args...)` when Duplicated args are present, so Enzyme accumulates the transposed derivative into the shadow buffers. 3. Enzyme passes `Type{<:Const}` (not an instance) for the dret slot in Const-return reverse rules. Updated dispatch signatures from `dret::EnzymeCore.Const` to `dret::Type{<:EnzymeCore.Const}`. New end-to-end reverse-mode tests that assert derivative correctness: - Const return + Active args: gradients are (0.0, 0.0) - IIP f!(du, u) with Duplicated args: u_shadow accumulates ∂du/∂u - Multi-component IIP cross-coupled Jacobian transpose - ReverseWithPrimal IIP variant Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
1 parent 9ea7132 commit 137a3f2

2 files changed

Lines changed: 134 additions & 17 deletions

File tree

ext/FunctionWrappersWrappersEnzymeExt.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -210,26 +210,37 @@ function EnzymeRules.reverse(
210210
end
211211
end
212212

213-
# Const return (no derivative to propagate from the return) — uniform Active args.
213+
# Const return — Enzyme passes the RT as a `Type{<:Const}` to `reverse`, not
214+
# as an instance. Delegate the reverse pass to
215+
# `Enzyme.autodiff(Reverse, Const(f_orig), Const, args...)` so gradients
216+
# accumulate into any `Duplicated` arg shadow buffers (the SciML IIP
217+
# pattern). Simply returning `nothing` left Duplicated shadows at zero.
218+
#
219+
# Per Enzyme's rule return-type protocol, `Active` args require a concrete
220+
# scalar gradient (not `nothing`). Under a `Const` return there is no
221+
# gradient source, so Active arg gradients are zero. `Duplicated` /
222+
# `BatchDuplicated` args return `nothing` because their gradients are
223+
# accumulated in-place by the `Enzyme.autodiff(Reverse, …)` call above.
214224
function EnzymeRules.reverse(
215225
config::EnzymeRules.RevConfig,
216226
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
217-
dret::EnzymeCore.Const,
218-
tape,
219-
args::Vararg{EnzymeCore.Active, N}
220-
) where {N}
221-
return ntuple(_ -> nothing, Val(N))
222-
end
223-
224-
# Const return — mixed Active/Const args.
225-
function EnzymeRules.reverse(
226-
config::EnzymeRules.RevConfig,
227-
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
228-
dret::EnzymeCore.Const,
227+
dret::Type{<:EnzymeCore.Const},
229228
tape,
230229
args::Vararg{EnzymeCore.Annotation, N}
231230
) where {N}
232-
return ntuple(_ -> nothing, Val(N))
231+
f_orig = unwrap(func.val)
232+
# Only worth invoking Enzyme.autodiff when at least one arg is
233+
# Duplicated/BatchDuplicated — otherwise there's nothing to accumulate.
234+
if any(a -> a isa EnzymeCore.Duplicated || a isa EnzymeCore.BatchDuplicated, args)
235+
Enzyme.autodiff(Reverse, Const(f_orig), Const, args...)
236+
end
237+
return ntuple(Val(N)) do i
238+
if args[i] isa EnzymeCore.Active
239+
zero(eltype(typeof(args[i])))
240+
else
241+
nothing
242+
end
243+
end
233244
end
234245

235246
end

test/enzyme_tests.jl

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,14 @@ end
132132
@test aug.shadow === nothing
133133
@test aug.tape === nothing
134134

135-
# Reverse step — dret is Const, no grads to accumulate.
135+
# Reverse step — dret is Const (passed as TYPE not instance in reverse
136+
# rules). Enzyme's rule protocol requires concrete gradients for Active
137+
# args; under a Const return they're zero (no gradient source).
136138
grads = EnzymeRules.reverse(
137-
rconfig, Const(fww), EnzymeCore.Const{Float64}(0.0),
139+
rconfig, Const(fww), EnzymeCore.Const{Float64},
138140
aug.tape, Active(3.0), Active(4.0)
139141
)
140-
@test grads == (nothing, nothing)
142+
@test grads == (0.0, 0.0)
141143
end
142144

143145
@testset "Enzyme reverse mode, Duplicated return — augmented_primal initializes shadow" begin
@@ -171,3 +173,107 @@ end
171173
@test aug.tape === nothing
172174
end
173175

176+
# =============================================================================
177+
# End-to-end reverse-mode derivative tests — exercise Enzyme.autodiff(Reverse,
178+
# …) through the FWW and assert the resulting gradients are numerically correct.
179+
# The prior reverse-mode testsets only checked dispatch / shape of
180+
# AugmentedReturn; they did NOT verify the gradients are right.
181+
# =============================================================================
182+
183+
@testset "Enzyme Reverse: Const return, Active args — no-flow gradients" begin
184+
# For a function whose return is annotated Const in Reverse mode, there is
185+
# no gradient source from the return, so Active arg gradients must be 0.
186+
# (Enzyme's rule-return protocol requires concrete gradients for Active
187+
# args — `nothing` is not allowed — so the rule returns zeros.)
188+
g(x, y) = x * y + x^2
189+
fww = FunctionWrappersWrapper(g, (Tuple{Float64, Float64},), (Float64,))
190+
191+
# Const return (instead of Active) → no gradient flows back
192+
result = Enzyme.autodiff(Reverse, Const(fww), Const, Active(3.0), Active(4.0))
193+
@test result[1] === (0.0, 0.0)
194+
end
195+
196+
@testset "Enzyme Reverse: IIP with Duplicated args, Const return" begin
197+
# SciML's standard pattern: IIP RHS `f!(du, u)` with Const return, both du
198+
# and u are Duplicated. Reverse mode should accumulate
199+
# u_shadow[i] += du_shadow[j] * ∂(du[j])/∂(u[i])
200+
# into u_shadow. For f!(du, u) = (du[1] = u[1]^2; nothing) with
201+
# du_shadow = [1.0] (incoming adjoint),
202+
# u[1] = 3.0,
203+
# ∂du[1]/∂u[1] = 2*u[1] = 6,
204+
# the expected result is u_shadow[1] = 6.0 after the call.
205+
f!(du, u) = (du[1] = u[1]^2; nothing)
206+
fww = FunctionWrappersWrapper(
207+
f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,)
208+
)
209+
210+
du = [0.0]
211+
u = [3.0]
212+
du_shadow = [1.0]
213+
u_shadow = [0.0]
214+
215+
Enzyme.autodiff(
216+
Reverse, Const(fww), Const,
217+
Duplicated(du, du_shadow), Duplicated(u, u_shadow)
218+
)
219+
@test du[1] 9.0 # primal effect: du[1] = u[1]^2 = 9
220+
@test u_shadow[1] 6.0 # reverse accumulation: 2 * u[1] * du_shadow[1]
221+
end
222+
223+
@testset "Enzyme Reverse: IIP multi-component IIP with Duplicated args" begin
224+
# Cross-coupled IIP RHS: each output depends on multiple inputs.
225+
# du[1] = u[1] * u[2]
226+
# du[2] = u[1]^2 + u[2]^3
227+
# Jacobian at u = (x, y):
228+
# J = [ y x ;
229+
# 2x 3y^2 ]
230+
# In reverse mode with du_shadow = [a, b], transpose of J applied to
231+
# du_shadow gives the accumulation into u_shadow:
232+
# u_shadow[1] += a*y + b*2x
233+
# u_shadow[2] += a*x + b*3y^2
234+
f!(du, u) = (du[1] = u[1]*u[2]; du[2] = u[1]^2 + u[2]^3; nothing)
235+
fww = FunctionWrappersWrapper(
236+
f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,)
237+
)
238+
239+
x, y = 2.0, 5.0
240+
a, b = 1.0, 0.5
241+
du = zeros(2)
242+
u = [x, y]
243+
du_shadow = [a, b]
244+
u_shadow = zeros(2)
245+
246+
Enzyme.autodiff(
247+
Reverse, Const(fww), Const,
248+
Duplicated(du, du_shadow), Duplicated(u, u_shadow)
249+
)
250+
@test du [x*y, x^2 + y^3]
251+
@test u_shadow[1] a*y + b*2*x # 5 + 2 = 7
252+
@test u_shadow[2] a*x + b*3*y^2 # 2 + 37.5 = 39.5
253+
end
254+
255+
@testset "Enzyme ReverseWithPrimal: IIP with Duplicated args" begin
256+
# Same IIP pattern but with ReverseWithPrimal so we also check the primal
257+
# is available when the rule is asked to include it.
258+
f!(du, u) = (du[1] = u[1]^3; nothing)
259+
fww = FunctionWrappersWrapper(
260+
f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,)
261+
)
262+
263+
du = [0.0]
264+
u = [2.0]
265+
du_shadow = [1.0]
266+
u_shadow = [0.0]
267+
268+
# Capture the expected gradient BEFORE the call — Enzyme may zero
269+
# `du_shadow` after consuming it during the reverse pass.
270+
expected_u_grad = 3 * u[1]^2 * du_shadow[1] # = 12.0
271+
272+
Enzyme.autodiff(
273+
ReverseWithPrimal, Const(fww), Const,
274+
Duplicated(du, du_shadow), Duplicated(u, u_shadow)
275+
)
276+
@test du[1] 8.0
277+
@test u_shadow[1] expected_u_grad
278+
end
279+

0 commit comments

Comments
 (0)