Skip to content

Commit 3ea3778

Browse files
Merge pull request #45 from ChrisRackauckas-Claude/enzyme-forward-rule
Propagate set_runtime_activity through FWW Enzyme forward rules
2 parents 11cbd62 + 4eeba5d commit 3ea3778

2 files changed

Lines changed: 137 additions & 9 deletions

File tree

ext/FunctionWrappersWrappersEnzymeExt.jl

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,32 @@ using Enzyme
55
using EnzymeCore
66
using EnzymeCore.EnzymeRules
77

8+
# =============================================================================
9+
# Helper: build a Forward mode from FwdConfig flags
10+
# =============================================================================
11+
# The outer caller may invoke `Enzyme.autodiff(set_runtime_activity(Forward), …)`
12+
# or `set_strong_zero(Forward)` or `ForwardWithPrimal`. Those settings flow
13+
# into the `EnzymeRules.FwdConfig{NeedsPrimal, NeedsShadow, Width,
14+
# RuntimeActivity, StrongZero}` type parameters of the rule's first argument.
15+
# Before this fix the rules hard-coded plain `Forward` in their inner
16+
# `Enzyme.autodiff` delegation, which dropped both `RuntimeActivity` and
17+
# `StrongZero` — breaking users who need `set_runtime_activity(Forward)` to
18+
# avoid `EnzymeRuntimeActivityError` inside the wrapped function (the SciML
19+
# `Rosenbrock23(autodiff = AutoEnzyme(set_runtime_activity(Forward)))` path
20+
# on an in-place time-dependent RHS; see OrdinaryDiffEq.jl PR #3518).
21+
#
22+
# `_fwd_mode(needs_primal, RuntimeActivity, StrongZero)` returns the
23+
# `ForwardMode` matching the outer config so the delegated call inherits
24+
# those flags.
25+
@inline function _fwd_mode(
26+
::Val{NeedsPrimal}, ::Val{RuntimeActivity}, ::Val{StrongZero}
27+
) where {NeedsPrimal, RuntimeActivity, StrongZero}
28+
mode = NeedsPrimal ? ForwardWithPrimal : Forward
29+
RuntimeActivity && (mode = Enzyme.set_runtime_activity(mode))
30+
StrongZero && (mode = Enzyme.set_strong_zero(mode))
31+
return mode
32+
end
33+
834
# =============================================================================
935
# Forward mode rules — generalized to arbitrary batch width W
1036
# =============================================================================
@@ -17,11 +43,12 @@ function EnzymeRules.forward(
1743
args::Vararg{EnzymeCore.Annotation, N}
1844
) where {T, W, N, RuntimeActivity, StrongZero}
1945
f_orig = unwrap(func.val)
46+
mode = _fwd_mode(Val(false), Val(RuntimeActivity), Val(StrongZero))
2047
if W == 1
21-
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{T}, args...)
48+
shadow_result = Enzyme.autodiff(mode, Const(f_orig), Duplicated{T}, args...)
2249
return shadow_result[1]::T
2350
else
24-
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), BatchDuplicated{T, W}, args...)
51+
shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...)
2552
return shadow_result[1]::NTuple{W, T}
2653
end
2754
end
@@ -36,12 +63,16 @@ function EnzymeRules.forward(
3663
f_orig = unwrap(func.val)
3764
pargs = ntuple(i -> args[i].val, Val(N))
3865
primal = f_orig(pargs...)::T
66+
# Use plain Forward (not ForwardWithPrimal) here — we already have the
67+
# primal above, and `Duplicated{T}` / `BatchDuplicated{T,W}` as the RT
68+
# annotation asks only for the shadow.
69+
mode = _fwd_mode(Val(false), Val(RuntimeActivity), Val(StrongZero))
3970
if W == 1
40-
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{T}, args...)
71+
shadow_result = Enzyme.autodiff(mode, Const(f_orig), Duplicated{T}, args...)
4172
shadow = shadow_result[1]::T
4273
return Duplicated(primal, shadow)
4374
else
44-
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), BatchDuplicated{T, W}, args...)
75+
shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...)
4576
shadows = shadow_result[1]::NTuple{W, T}
4677
return BatchDuplicated(primal, shadows)
4778
end
@@ -69,14 +100,20 @@ end
69100
# orders of magnitude in OrdinaryDiffEq.jl v7. Delegate to `Enzyme.autodiff`
70101
# on the unwrapped function with a Const return annotation so the Duplicated
71102
# arg shadows are propagated correctly and no return is produced.
103+
#
104+
# IMPORTANT: forward the `RuntimeActivity` and `StrongZero` flags from the
105+
# outer config into the delegated `Enzyme.autodiff` call. Prior to this
106+
# fix the rule hard-coded `Forward`, silently dropping
107+
# `set_runtime_activity(Forward)` on the way down into `f_orig`.
72108
function EnzymeRules.forward(
73109
::EnzymeRules.FwdConfig{false, false, W, RuntimeActivity, StrongZero},
74110
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
75111
RT::Type{<:EnzymeCore.Annotation},
76112
args::Vararg{EnzymeCore.Annotation, N}
77113
) where {W, N, RuntimeActivity, StrongZero}
78114
f_orig = unwrap(func.val)
79-
Enzyme.autodiff(Forward, Const(f_orig), Const, args...)
115+
mode = _fwd_mode(Val(false), Val(RuntimeActivity), Val(StrongZero))
116+
Enzyme.autodiff(mode, Const(f_orig), Const, args...)
80117
return nothing
81118
end
82119

@@ -151,11 +188,21 @@ function EnzymeRules.augmented_primal(
151188
end
152189
end
153190

191+
# Helper: build a Forward mode reflecting a RevConfig's runtime_activity /
192+
# strong_zero flags so the internal forward-mode delegation inside reverse
193+
# rules inherits the user's outer config.
194+
@inline function _fwd_mode_from_rev(config::EnzymeRules.RevConfig)
195+
mode = Forward
196+
EnzymeRules.runtime_activity(config) && (mode = Enzyme.set_runtime_activity(mode))
197+
EnzymeRules.strong_zero(config) && (mode = Enzyme.set_strong_zero(mode))
198+
return mode
199+
end
200+
154201
# Varargs reverse: compute each partial via forward-mode AD on the unwrapped
155202
# function, then scale by dret. This avoids type-inference issues that arise
156203
# from calling autodiff(Reverse, Const{Any}(...), ...).
157204
@generated function _fww_reverse_grads(
158-
f_orig, dret_val::T, args::Vararg{EnzymeCore.Active, N}
205+
mode, f_orig, dret_val::T, args::Vararg{EnzymeCore.Active, N}
159206
) where {T, N}
160207
# Build forward-mode calls for each partial derivative
161208
exprs = []
@@ -164,7 +211,7 @@ end
164211
dups = [:(Duplicated(args[$j].val, $(seeds[j]))) for j in 1:N]
165212
Ti = :(eltype(typeof(args[$i])))
166213
push!(exprs, quote
167-
fwd = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{$T}, $(dups...))
214+
fwd = Enzyme.autodiff(mode, Const(f_orig), Duplicated{$T}, $(dups...))
168215
$Ti(fwd[1] * dret_val)::$Ti
169216
end)
170217
end
@@ -179,7 +226,7 @@ function EnzymeRules.reverse(
179226
args::Vararg{EnzymeCore.Active, N}
180227
) where {T, N}
181228
f_orig = unwrap(func.val)
182-
return _fww_reverse_grads(f_orig, dret.val, args...)
229+
return _fww_reverse_grads(_fwd_mode_from_rev(config), f_orig, dret.val, args...)
183230
end
184231

185232
# Handle mixed Active/Const args: return nothing for Const, gradient for Active
@@ -192,6 +239,7 @@ function EnzymeRules.reverse(
192239
) where {N}
193240
f_orig = unwrap(func.val)
194241
dret_val = dret.val
242+
mode = _fwd_mode_from_rev(config)
195243
return ntuple(Val(N)) do i
196244
if args[i] isa EnzymeCore.Const
197245
nothing
@@ -204,7 +252,7 @@ function EnzymeRules.reverse(
204252
Duplicated(args[j].val, zero(eltype(typeof(args[j]))))
205253
end
206254
end
207-
fwd = Enzyme.autodiff(Forward, Const(f_orig), Duplicated, dup_args...)
255+
fwd = Enzyme.autodiff(mode, Const(f_orig), Duplicated, dup_args...)
208256
fwd[1] * dret_val
209257
end
210258
end

test/enzyme_tests.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,83 @@ end
277277
@test u_shadow[1] ≈ expected_u_grad
278278
end
279279

280+
# =============================================================================
281+
# Runtime-activity propagation through the FWW forward rules.
282+
#
283+
# Prior to this fix the rules hard-coded plain `Forward` when delegating to
284+
# `Enzyme.autodiff`, silently dropping the caller's
285+
# `set_runtime_activity(Forward)` flag. Enzyme's static IR-level activity
286+
# analysis can't see through `FunctionWrappersWrapper`'s opaque cfunction
287+
# indirection, so the inner call raised `EnzymeRuntimeActivityError` inside
288+
# `@.` broadcast's `broadcast_unalias` → `mightalias` — despite
289+
# `set_runtime_activity` being set on the outer `autodiff` call.
290+
#
291+
# Upstream motivation: OrdinaryDiffEq.jl PR #3518 —
292+
# Rosenbrock23(autodiff = AutoEnzyme(set_runtime_activity(Enzyme.Forward)))
293+
# on any time-dependent in-place RHS routed through DiffEqBase's
294+
# `wrapfun_iip`. Here we reproduce the shape (`f!(du, u, p, t) = @. du = …`)
295+
# in a 4-arg `FunctionWrappersWrapper` matching DiffEqBase's
296+
# `wrapfun_iip` output, and assert both that (a) the call completes without
297+
# an `EnzymeRuntimeActivityError` and (b) the resulting tangent is
298+
# numerically correct.
299+
# =============================================================================
300+
301+
@testset "Enzyme Forward: set_runtime_activity propagates through FWW (IIP, time-dependent)" begin
302+
# DiffEqBase's `wrapfun_iip(ff, (u, u, p, t))` shape.
303+
const_INPUTS = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}
304+
305+
# 1) Time-independent RHS — ∂du/∂t = 0.
306+
f!(du, u, p, t) = (@. du = p * u; nothing)
307+
fww = FunctionWrappersWrapper(f!, (const_INPUTS,), (Nothing,))
308+
309+
u = [1.0, 2.0, 3.0]
310+
p = [0.5, 0.5, 0.5]
311+
t = 1.7
312+
du = zero(u); ddu = zero(u); dt = 1.0
313+
314+
Enzyme.autodiff(
315+
Enzyme.set_runtime_activity(Forward),
316+
Const(fww), Const,
317+
Duplicated(du, ddu),
318+
Const(u), Const(p),
319+
Duplicated(t, dt),
320+
)
321+
@test du ≈ p .* u
322+
@test all(iszero, ddu)
323+
324+
# 2) Non-trivial time dependence: g!(du,u,p,t) = @. sin(t)*u.
325+
# Expected ∂du/∂t = cos(t) .* u.
326+
g!(du, u, p, t) = (@. du = sin(t) * u; nothing)
327+
gww = FunctionWrappersWrapper(g!, (const_INPUTS,), (Nothing,))
328+
329+
du2 = zero(u); ddu2 = zero(u)
330+
Enzyme.autodiff(
331+
Enzyme.set_runtime_activity(Forward),
332+
Const(gww), Const,
333+
Duplicated(du2, ddu2),
334+
Const(u), Const(p),
335+
Duplicated(t, 1.0),
336+
)
337+
@test du2 ≈ sin(t) .* u
338+
@test ddu2 ≈ cos(t) .* u
339+
340+
# 3) Confirm the rule also propagates `set_strong_zero(Forward)` (the
341+
# other ForwardMode flag carried in FwdConfig) — another RHS that
342+
# doesn't need runtime activity but exercises a distinct flag.
343+
h!(du, u, p, t) = (du[1] = u[1] * t; nothing)
344+
hww = FunctionWrappersWrapper(
345+
h!, (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64},),
346+
(Nothing,)
347+
)
348+
du_h = [0.0]; ddu_h = [0.0]
349+
Enzyme.autodiff(
350+
Enzyme.set_strong_zero(Forward),
351+
Const(hww), Const,
352+
Duplicated(du_h, ddu_h),
353+
Const([2.0]), Const([0.0]),
354+
Duplicated(3.5, 1.0),
355+
)
356+
@test du_h[1] ≈ 2.0 * 3.5 # primal: u[1] * t = 7.0
357+
@test ddu_h[1] ≈ 2.0 # ∂(u[1]*t)/∂t = u[1]
358+
end
359+

0 commit comments

Comments
 (0)