Skip to content

Commit 448a6a1

Browse files
Propagate set_runtime_activity/set_strong_zero through FWW Enzyme rules
The forward-mode rules hard-coded plain `Forward` in their delegated `Enzyme.autodiff` calls, silently dropping the outer caller's `set_runtime_activity(Forward)` / `set_strong_zero(Forward)` flags. Enzyme's IR-level activity analysis can't see through FunctionWrappersWrapper's cfunction indirection, so the inner call raised `EnzymeRuntimeActivityError` at `broadcast_unalias` / `mightalias` inside `@. du = …` — even when the user had explicitly set runtime activity on the outer `autodiff` call. Reproducer (confirmed): `Rosenbrock23(autodiff = AutoEnzyme(set_runtime_activity(Enzyme.Forward)))` on any time-dependent in-place ODE RHS that DiffEqBase's `AutoSpecialize` wraps with `wrapfun_iip`. See OrdinaryDiffEq.jl PR #3518. Fix: extract `RuntimeActivity` and `StrongZero` from `FwdConfig` / `RevConfig` type parameters and rebuild the `ForwardMode` used in the delegated `Enzyme.autodiff` call with those flags set. The reverse rules' internal forward-mode helpers are updated analogously. Also adds a test exercising the exact DiffEqBase `wrapfun_iip` shape — a 4-arg `(du, u, p, t)` IIP RHS with `@.` broadcast — under `set_runtime_activity(Forward)`, and a `set_strong_zero(Forward)` case. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 11cbd62 commit 448a6a1

2 files changed

Lines changed: 141 additions & 9 deletions

File tree

ext/FunctionWrappersWrappersEnzymeExt.jl

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,32 @@ using Enzyme
55
using EnzymeCore
66
using EnzymeCore.EnzymeRules
77

8+
# =============================================================================
9+
# Helper: build a Forward mode from FwdConfig flags
10+
# =============================================================================
11+
# The outer caller may invoke `Enzyme.autodiff(set_runtime_activity(Forward), …)`
12+
# or `set_strong_zero(Forward)` or `ForwardWithPrimal`. Those settings flow
13+
# into the `EnzymeRules.FwdConfig{NeedsPrimal, NeedsShadow, Width,
14+
# RuntimeActivity, StrongZero}` type parameters of the rule's first argument.
15+
# Before this fix the rules hard-coded plain `Forward` in their inner
16+
# `Enzyme.autodiff` delegation, which dropped both `RuntimeActivity` and
17+
# `StrongZero` — breaking users who need `set_runtime_activity(Forward)` to
18+
# avoid `EnzymeRuntimeActivityError` inside the wrapped function (the SciML
19+
# `Rosenbrock23(autodiff = AutoEnzyme(set_runtime_activity(Forward)))` path
20+
# on an in-place time-dependent RHS; see OrdinaryDiffEq.jl PR #3518).
21+
#
22+
# `_fwd_mode(needs_primal, RuntimeActivity, StrongZero)` returns the
23+
# `ForwardMode` matching the outer config so the delegated call inherits
24+
# those flags.
25+
@inline function _fwd_mode(
26+
::Val{NeedsPrimal}, ::Val{RuntimeActivity}, ::Val{StrongZero}
27+
) where {NeedsPrimal, RuntimeActivity, StrongZero}
28+
mode = NeedsPrimal ? ForwardWithPrimal : Forward
29+
RuntimeActivity && (mode = Enzyme.set_runtime_activity(mode))
30+
StrongZero && (mode = Enzyme.set_strong_zero(mode))
31+
return mode
32+
end
33+
834
# =============================================================================
935
# Forward mode rules — generalized to arbitrary batch width W
1036
# =============================================================================
@@ -17,11 +43,12 @@ function EnzymeRules.forward(
1743
args::Vararg{EnzymeCore.Annotation, N}
1844
) where {T, W, N, RuntimeActivity, StrongZero}
1945
f_orig = unwrap(func.val)
46+
mode = _fwd_mode(Val(false), Val(RuntimeActivity), Val(StrongZero))
2047
if W == 1
21-
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{T}, args...)
48+
shadow_result = Enzyme.autodiff(mode, Const(f_orig), Duplicated{T}, args...)
2249
return shadow_result[1]::T
2350
else
24-
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), BatchDuplicated{T, W}, args...)
51+
shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...)
2552
return shadow_result[1]::NTuple{W, T}
2653
end
2754
end
@@ -36,12 +63,16 @@ function EnzymeRules.forward(
3663
f_orig = unwrap(func.val)
3764
pargs = ntuple(i -> args[i].val, Val(N))
3865
primal = f_orig(pargs...)::T
66+
# Use plain Forward (not ForwardWithPrimal) here — we already have the
67+
# primal above, and `Duplicated{T}` / `BatchDuplicated{T,W}` as the RT
68+
# annotation asks only for the shadow.
69+
mode = _fwd_mode(Val(false), Val(RuntimeActivity), Val(StrongZero))
3970
if W == 1
40-
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{T}, args...)
71+
shadow_result = Enzyme.autodiff(mode, Const(f_orig), Duplicated{T}, args...)
4172
shadow = shadow_result[1]::T
4273
return Duplicated(primal, shadow)
4374
else
44-
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), BatchDuplicated{T, W}, args...)
75+
shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...)
4576
shadows = shadow_result[1]::NTuple{W, T}
4677
return BatchDuplicated(primal, shadows)
4778
end
@@ -69,14 +100,24 @@ end
69100
# orders of magnitude in OrdinaryDiffEq.jl v7. Delegate to `Enzyme.autodiff`
70101
# on the unwrapped function with a Const return annotation so the Duplicated
71102
# arg shadows are propagated correctly and no return is produced.
103+
#
104+
# IMPORTANT: forward the `RuntimeActivity` and `StrongZero` flags from the
105+
# outer config into the delegated `Enzyme.autodiff` call. Prior to this
106+
# fix the rule hard-coded `Forward`, silently dropping
107+
# `set_runtime_activity(Forward)` on the way down into `f_orig`. That
108+
# broke `Rosenbrock23(autodiff = AutoEnzyme(set_runtime_activity(Forward)))`
109+
# on any time-dependent in-place RHS:
110+
# EnzymeRuntimeActivityError at `broadcast_unalias` → `mightalias`
111+
# despite the user explicitly setting runtime activity.
72112
function EnzymeRules.forward(
73113
::EnzymeRules.FwdConfig{false, false, W, RuntimeActivity, StrongZero},
74114
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
75115
RT::Type{<:EnzymeCore.Annotation},
76116
args::Vararg{EnzymeCore.Annotation, N}
77117
) where {W, N, RuntimeActivity, StrongZero}
78118
f_orig = unwrap(func.val)
79-
Enzyme.autodiff(Forward, Const(f_orig), Const, args...)
119+
mode = _fwd_mode(Val(false), Val(RuntimeActivity), Val(StrongZero))
120+
Enzyme.autodiff(mode, Const(f_orig), Const, args...)
80121
return nothing
81122
end
82123

@@ -151,11 +192,21 @@ function EnzymeRules.augmented_primal(
151192
end
152193
end
153194

195+
# Helper: build a Forward mode reflecting a RevConfig's runtime_activity /
196+
# strong_zero flags so the internal forward-mode delegation inside reverse
197+
# rules inherits the user's outer config.
198+
@inline function _fwd_mode_from_rev(config::EnzymeRules.RevConfig)
199+
mode = Forward
200+
EnzymeRules.runtime_activity(config) && (mode = Enzyme.set_runtime_activity(mode))
201+
EnzymeRules.strong_zero(config) && (mode = Enzyme.set_strong_zero(mode))
202+
return mode
203+
end
204+
154205
# Varargs reverse: compute each partial via forward-mode AD on the unwrapped
155206
# function, then scale by dret. This avoids type-inference issues that arise
156207
# from calling autodiff(Reverse, Const{Any}(...), ...).
157208
@generated function _fww_reverse_grads(
158-
f_orig, dret_val::T, args::Vararg{EnzymeCore.Active, N}
209+
mode, f_orig, dret_val::T, args::Vararg{EnzymeCore.Active, N}
159210
) where {T, N}
160211
# Build forward-mode calls for each partial derivative
161212
exprs = []
@@ -164,7 +215,7 @@ end
164215
dups = [:(Duplicated(args[$j].val, $(seeds[j]))) for j in 1:N]
165216
Ti = :(eltype(typeof(args[$i])))
166217
push!(exprs, quote
167-
fwd = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{$T}, $(dups...))
218+
fwd = Enzyme.autodiff(mode, Const(f_orig), Duplicated{$T}, $(dups...))
168219
$Ti(fwd[1] * dret_val)::$Ti
169220
end)
170221
end
@@ -179,7 +230,7 @@ function EnzymeRules.reverse(
179230
args::Vararg{EnzymeCore.Active, N}
180231
) where {T, N}
181232
f_orig = unwrap(func.val)
182-
return _fww_reverse_grads(f_orig, dret.val, args...)
233+
return _fww_reverse_grads(_fwd_mode_from_rev(config), f_orig, dret.val, args...)
183234
end
184235

185236
# Handle mixed Active/Const args: return nothing for Const, gradient for Active
@@ -192,6 +243,7 @@ function EnzymeRules.reverse(
192243
) where {N}
193244
f_orig = unwrap(func.val)
194245
dret_val = dret.val
246+
mode = _fwd_mode_from_rev(config)
195247
return ntuple(Val(N)) do i
196248
if args[i] isa EnzymeCore.Const
197249
nothing
@@ -204,7 +256,7 @@ function EnzymeRules.reverse(
204256
Duplicated(args[j].val, zero(eltype(typeof(args[j]))))
205257
end
206258
end
207-
fwd = Enzyme.autodiff(Forward, Const(f_orig), Duplicated, dup_args...)
259+
fwd = Enzyme.autodiff(mode, Const(f_orig), Duplicated, dup_args...)
208260
fwd[1] * dret_val
209261
end
210262
end

test/enzyme_tests.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,83 @@ end
277277
@test u_shadow[1] expected_u_grad
278278
end
279279

280+
# =============================================================================
281+
# Runtime-activity propagation through the FWW forward rules.
282+
#
283+
# Prior to this fix the rules hard-coded plain `Forward` when delegating to
284+
# `Enzyme.autodiff`, silently dropping the caller's
285+
# `set_runtime_activity(Forward)` flag. Enzyme's static IR-level activity
286+
# analysis can't see through `FunctionWrappersWrapper`'s opaque cfunction
287+
# indirection, so the inner call raised `EnzymeRuntimeActivityError` inside
288+
# `@.` broadcast's `broadcast_unalias` → `mightalias` — despite
289+
# `set_runtime_activity` being set on the outer `autodiff` call.
290+
#
291+
# Upstream motivation: OrdinaryDiffEq.jl PR #3518 —
292+
# Rosenbrock23(autodiff = AutoEnzyme(set_runtime_activity(Enzyme.Forward)))
293+
# on any time-dependent in-place RHS routed through DiffEqBase's
294+
# `wrapfun_iip`. Here we reproduce the shape (`f!(du, u, p, t) = @. du = …`)
295+
# in a 4-arg `FunctionWrappersWrapper` matching DiffEqBase's
296+
# `wrapfun_iip` output, and assert both that (a) the call completes without
297+
# an `EnzymeRuntimeActivityError` and (b) the resulting tangent is
298+
# numerically correct.
299+
# =============================================================================
300+
301+
@testset "Enzyme Forward: set_runtime_activity propagates through FWW (IIP, time-dependent)" begin
302+
# DiffEqBase's `wrapfun_iip(ff, (u, u, p, t))` shape.
303+
const_INPUTS = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}
304+
305+
# 1) Time-independent RHS — ∂du/∂t = 0.
306+
f!(du, u, p, t) = (@. du = p * u; nothing)
307+
fww = FunctionWrappersWrapper(f!, (const_INPUTS,), (Nothing,))
308+
309+
u = [1.0, 2.0, 3.0]
310+
p = [0.5, 0.5, 0.5]
311+
t = 1.7
312+
du = zero(u); ddu = zero(u); dt = 1.0
313+
314+
Enzyme.autodiff(
315+
Enzyme.set_runtime_activity(Forward),
316+
Const(fww), Const,
317+
Duplicated(du, ddu),
318+
Const(u), Const(p),
319+
Duplicated(t, dt),
320+
)
321+
@test du p .* u
322+
@test all(iszero, ddu)
323+
324+
# 2) Non-trivial time dependence: g!(du,u,p,t) = @. sin(t)*u.
325+
# Expected ∂du/∂t = cos(t) .* u.
326+
g!(du, u, p, t) = (@. du = sin(t) * u; nothing)
327+
gww = FunctionWrappersWrapper(g!, (const_INPUTS,), (Nothing,))
328+
329+
du2 = zero(u); ddu2 = zero(u)
330+
Enzyme.autodiff(
331+
Enzyme.set_runtime_activity(Forward),
332+
Const(gww), Const,
333+
Duplicated(du2, ddu2),
334+
Const(u), Const(p),
335+
Duplicated(t, 1.0),
336+
)
337+
@test du2 sin(t) .* u
338+
@test ddu2 cos(t) .* u
339+
340+
# 3) Confirm the rule also propagates `set_strong_zero(Forward)` (the
341+
# other ForwardMode flag carried in FwdConfig) — another RHS that
342+
# doesn't need runtime activity but exercises a distinct flag.
343+
h!(du, u, p, t) = (du[1] = u[1] * t; nothing)
344+
hww = FunctionWrappersWrapper(
345+
h!, (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64},),
346+
(Nothing,)
347+
)
348+
du_h = [0.0]; ddu_h = [0.0]
349+
Enzyme.autodiff(
350+
Enzyme.set_strong_zero(Forward),
351+
Const(hww), Const,
352+
Duplicated(du_h, ddu_h),
353+
Const([2.0]), Const([0.0]),
354+
Duplicated(3.5, 1.0),
355+
)
356+
@test du_h[1] 2.0 * 3.5 # primal: u[1] * t = 7.0
357+
@test ddu_h[1] 2.0 # ∂(u[1]*t)/∂t = u[1]
358+
end
359+

0 commit comments

Comments
 (0)