@@ -5,6 +5,32 @@ using Enzyme
55using EnzymeCore
66using EnzymeCore. EnzymeRules
77
8+ # =============================================================================
9+ # Helper: build a Forward mode from FwdConfig flags
10+ # =============================================================================
11+ # The outer caller may invoke `Enzyme.autodiff(set_runtime_activity(Forward), …)`
12+ # or `set_strong_zero(Forward)` or `ForwardWithPrimal`. Those settings flow
13+ # into the `EnzymeRules.FwdConfig{NeedsPrimal, NeedsShadow, Width,
14+ # RuntimeActivity, StrongZero}` type parameters of the rule's first argument.
15+ # Before this fix the rules hard-coded plain `Forward` in their inner
16+ # `Enzyme.autodiff` delegation, which dropped both `RuntimeActivity` and
17+ # `StrongZero` — breaking users who need `set_runtime_activity(Forward)` to
18+ # avoid `EnzymeRuntimeActivityError` inside the wrapped function (the SciML
19+ # `Rosenbrock23(autodiff = AutoEnzyme(set_runtime_activity(Forward)))` path
20+ # on an in-place time-dependent RHS; see OrdinaryDiffEq.jl PR #3518).
21+ #
22+ # `_fwd_mode(needs_primal, RuntimeActivity, StrongZero)` returns the
23+ # `ForwardMode` matching the outer config so the delegated call inherits
24+ # those flags.
25+ @inline function _fwd_mode (
26+ :: Val{NeedsPrimal} , :: Val{RuntimeActivity} , :: Val{StrongZero}
27+ ) where {NeedsPrimal, RuntimeActivity, StrongZero}
28+ mode = NeedsPrimal ? ForwardWithPrimal : Forward
29+ RuntimeActivity && (mode = Enzyme. set_runtime_activity (mode))
30+ StrongZero && (mode = Enzyme. set_strong_zero (mode))
31+ return mode
32+ end
33+
834# =============================================================================
935# Forward mode rules — generalized to arbitrary batch width W
1036# =============================================================================
@@ -17,11 +43,12 @@ function EnzymeRules.forward(
1743 args:: Vararg{EnzymeCore.Annotation, N}
1844) where {T, W, N, RuntimeActivity, StrongZero}
1945 f_orig = unwrap (func. val)
46+ mode = _fwd_mode (Val (false ), Val (RuntimeActivity), Val (StrongZero))
2047 if W == 1
21- shadow_result = Enzyme. autodiff (Forward , Const (f_orig), Duplicated{T}, args... )
48+ shadow_result = Enzyme. autodiff (mode , Const (f_orig), Duplicated{T}, args... )
2249 return shadow_result[1 ]:: T
2350 else
24- shadow_result = Enzyme. autodiff (Forward , Const (f_orig), BatchDuplicated{T, W}, args... )
51+ shadow_result = Enzyme. autodiff (mode , Const (f_orig), BatchDuplicated{T, W}, args... )
2552 return shadow_result[1 ]:: NTuple{W, T}
2653 end
2754end
@@ -36,12 +63,16 @@ function EnzymeRules.forward(
3663 f_orig = unwrap (func. val)
3764 pargs = ntuple (i -> args[i]. val, Val (N))
3865 primal = f_orig (pargs... ):: T
66+ # Use plain Forward (not ForwardWithPrimal) here — we already have the
67+ # primal above, and `Duplicated{T}` / `BatchDuplicated{T,W}` as the RT
68+ # annotation asks only for the shadow.
69+ mode = _fwd_mode (Val (false ), Val (RuntimeActivity), Val (StrongZero))
3970 if W == 1
40- shadow_result = Enzyme. autodiff (Forward , Const (f_orig), Duplicated{T}, args... )
71+ shadow_result = Enzyme. autodiff (mode , Const (f_orig), Duplicated{T}, args... )
4172 shadow = shadow_result[1 ]:: T
4273 return Duplicated (primal, shadow)
4374 else
44- shadow_result = Enzyme. autodiff (Forward , Const (f_orig), BatchDuplicated{T, W}, args... )
75+ shadow_result = Enzyme. autodiff (mode , Const (f_orig), BatchDuplicated{T, W}, args... )
4576 shadows = shadow_result[1 ]:: NTuple{W, T}
4677 return BatchDuplicated (primal, shadows)
4778 end
69100# orders of magnitude in OrdinaryDiffEq.jl v7. Delegate to `Enzyme.autodiff`
70101# on the unwrapped function with a Const return annotation so the Duplicated
71102# arg shadows are propagated correctly and no return is produced.
103+ #
104+ # IMPORTANT: forward the `RuntimeActivity` and `StrongZero` flags from the
105+ # outer config into the delegated `Enzyme.autodiff` call. Prior to this
106+ # fix the rule hard-coded `Forward`, silently dropping
107+ # `set_runtime_activity(Forward)` on the way down into `f_orig`. That
108+ # broke `Rosenbrock23(autodiff = AutoEnzyme(set_runtime_activity(Forward)))`
109+ # on any time-dependent in-place RHS:
110+ # EnzymeRuntimeActivityError at `broadcast_unalias` → `mightalias`
111+ # despite the user explicitly setting runtime activity.
72112function EnzymeRules. forward (
73113 :: EnzymeRules.FwdConfig{false, false, W, RuntimeActivity, StrongZero} ,
74114 func:: EnzymeCore.Const{<:FunctionWrappersWrapper} ,
75115 RT:: Type{<:EnzymeCore.Annotation} ,
76116 args:: Vararg{EnzymeCore.Annotation, N}
77117) where {W, N, RuntimeActivity, StrongZero}
78118 f_orig = unwrap (func. val)
79- Enzyme. autodiff (Forward, Const (f_orig), Const, args... )
119+ mode = _fwd_mode (Val (false ), Val (RuntimeActivity), Val (StrongZero))
120+ Enzyme. autodiff (mode, Const (f_orig), Const, args... )
80121 return nothing
81122end
82123
@@ -151,11 +192,21 @@ function EnzymeRules.augmented_primal(
151192 end
152193end
153194
195+ # Helper: build a Forward mode reflecting a RevConfig's runtime_activity /
196+ # strong_zero flags so the internal forward-mode delegation inside reverse
197+ # rules inherits the user's outer config.
198+ @inline function _fwd_mode_from_rev (config:: EnzymeRules.RevConfig )
199+ mode = Forward
200+ EnzymeRules. runtime_activity (config) && (mode = Enzyme. set_runtime_activity (mode))
201+ EnzymeRules. strong_zero (config) && (mode = Enzyme. set_strong_zero (mode))
202+ return mode
203+ end
204+
154205# Varargs reverse: compute each partial via forward-mode AD on the unwrapped
155206# function, then scale by dret. This avoids type-inference issues that arise
156207# from calling autodiff(Reverse, Const{Any}(...), ...).
157208@generated function _fww_reverse_grads (
158- f_orig, dret_val:: T , args:: Vararg{EnzymeCore.Active, N}
209+ mode, f_orig, dret_val:: T , args:: Vararg{EnzymeCore.Active, N}
159210) where {T, N}
160211 # Build forward-mode calls for each partial derivative
161212 exprs = []
164215 dups = [:(Duplicated (args[$ j]. val, $ (seeds[j]))) for j in 1 : N]
165216 Ti = :(eltype (typeof (args[$ i])))
166217 push! (exprs, quote
167- fwd = Enzyme. autodiff (Forward , Const (f_orig), Duplicated{$ T}, $ (dups... ))
218+ fwd = Enzyme. autodiff (mode , Const (f_orig), Duplicated{$ T}, $ (dups... ))
168219 $ Ti (fwd[1 ] * dret_val):: $Ti
169220 end )
170221 end
@@ -179,7 +230,7 @@ function EnzymeRules.reverse(
179230 args:: Vararg{EnzymeCore.Active, N}
180231) where {T, N}
181232 f_orig = unwrap (func. val)
182- return _fww_reverse_grads (f_orig, dret. val, args... )
233+ return _fww_reverse_grads (_fwd_mode_from_rev (config), f_orig, dret. val, args... )
183234end
184235
185236# Handle mixed Active/Const args: return nothing for Const, gradient for Active
@@ -192,6 +243,7 @@ function EnzymeRules.reverse(
192243) where {N}
193244 f_orig = unwrap (func. val)
194245 dret_val = dret. val
246+ mode = _fwd_mode_from_rev (config)
195247 return ntuple (Val (N)) do i
196248 if args[i] isa EnzymeCore. Const
197249 nothing
@@ -204,7 +256,7 @@ function EnzymeRules.reverse(
204256 Duplicated (args[j]. val, zero (eltype (typeof (args[j]))))
205257 end
206258 end
207- fwd = Enzyme. autodiff (Forward , Const (f_orig), Duplicated, dup_args... )
259+ fwd = Enzyme. autodiff (mode , Const (f_orig), Duplicated, dup_args... )
208260 fwd[1 ] * dret_val
209261 end
210262 end
0 commit comments