@@ -5,6 +5,32 @@ using Enzyme
55using EnzymeCore
66using EnzymeCore. EnzymeRules
77
8+ # =============================================================================
9+ # Helper: build a Forward mode from FwdConfig flags
10+ # =============================================================================
11+ # The outer caller may invoke `Enzyme.autodiff(set_runtime_activity(Forward), …)`
12+ # or `set_strong_zero(Forward)` or `ForwardWithPrimal`. Those settings flow
13+ # into the `EnzymeRules.FwdConfig{NeedsPrimal, NeedsShadow, Width,
14+ # RuntimeActivity, StrongZero}` type parameters of the rule's first argument.
15+ # Before this fix the rules hard-coded plain `Forward` in their inner
16+ # `Enzyme.autodiff` delegation, which dropped both `RuntimeActivity` and
17+ # `StrongZero` — breaking users who need `set_runtime_activity(Forward)` to
18+ # avoid `EnzymeRuntimeActivityError` inside the wrapped function (the SciML
19+ # `Rosenbrock23(autodiff = AutoEnzyme(set_runtime_activity(Forward)))` path
20+ # on an in-place time-dependent RHS; see OrdinaryDiffEq.jl PR #3518).
21+ #
22+ # `_fwd_mode(needs_primal, RuntimeActivity, StrongZero)` returns the
23+ # `ForwardMode` matching the outer config so the delegated call inherits
24+ # those flags.
25+ @inline function _fwd_mode (
26+ :: Val{NeedsPrimal} , :: Val{RuntimeActivity} , :: Val{StrongZero}
27+ ) where {NeedsPrimal, RuntimeActivity, StrongZero}
28+ mode = NeedsPrimal ? ForwardWithPrimal : Forward
29+ RuntimeActivity && (mode = Enzyme. set_runtime_activity (mode))
30+ StrongZero && (mode = Enzyme. set_strong_zero (mode))
31+ return mode
32+ end
33+
834# =============================================================================
935# Forward mode rules — generalized to arbitrary batch width W
1036# =============================================================================
@@ -17,11 +43,12 @@ function EnzymeRules.forward(
1743 args:: Vararg{EnzymeCore.Annotation, N}
1844) where {T, W, N, RuntimeActivity, StrongZero}
1945 f_orig = unwrap (func. val)
46+ mode = _fwd_mode (Val (false ), Val (RuntimeActivity), Val (StrongZero))
2047 if W == 1
21- shadow_result = Enzyme. autodiff (Forward , Const (f_orig), Duplicated{T}, args... )
48+ shadow_result = Enzyme. autodiff (mode , Const (f_orig), Duplicated{T}, args... )
2249 return shadow_result[1 ]:: T
2350 else
24- shadow_result = Enzyme. autodiff (Forward , Const (f_orig), BatchDuplicated{T, W}, args... )
51+ shadow_result = Enzyme. autodiff (mode , Const (f_orig), BatchDuplicated{T, W}, args... )
2552 return shadow_result[1 ]:: NTuple{W, T}
2653 end
2754end
@@ -36,12 +63,16 @@ function EnzymeRules.forward(
3663 f_orig = unwrap (func. val)
3764 pargs = ntuple (i -> args[i]. val, Val (N))
3865 primal = f_orig (pargs... ):: T
66+ # Use plain Forward (not ForwardWithPrimal) here — we already have the
67+ # primal above, and `Duplicated{T}` / `BatchDuplicated{T,W}` as the RT
68+ # annotation asks only for the shadow.
69+ mode = _fwd_mode (Val (false ), Val (RuntimeActivity), Val (StrongZero))
3970 if W == 1
40- shadow_result = Enzyme. autodiff (Forward , Const (f_orig), Duplicated{T}, args... )
71+ shadow_result = Enzyme. autodiff (mode , Const (f_orig), Duplicated{T}, args... )
4172 shadow = shadow_result[1 ]:: T
4273 return Duplicated (primal, shadow)
4374 else
44- shadow_result = Enzyme. autodiff (Forward , Const (f_orig), BatchDuplicated{T, W}, args... )
75+ shadow_result = Enzyme. autodiff (mode , Const (f_orig), BatchDuplicated{T, W}, args... )
4576 shadows = shadow_result[1 ]:: NTuple{W, T}
4677 return BatchDuplicated (primal, shadows)
4778 end
69100# orders of magnitude in OrdinaryDiffEq.jl v7. Delegate to `Enzyme.autodiff`
70101# on the unwrapped function with a Const return annotation so the Duplicated
71102# arg shadows are propagated correctly and no return is produced.
103+ #
104+ # IMPORTANT: forward the `RuntimeActivity` and `StrongZero` flags from the
105+ # outer config into the delegated `Enzyme.autodiff` call. Prior to this
106+ # fix the rule hard-coded `Forward`, silently dropping
107+ # `set_runtime_activity(Forward)` on the way down into `f_orig`.
72108function EnzymeRules. forward (
73109 :: EnzymeRules.FwdConfig{false, false, W, RuntimeActivity, StrongZero} ,
74110 func:: EnzymeCore.Const{<:FunctionWrappersWrapper} ,
75111 RT:: Type{<:EnzymeCore.Annotation} ,
76112 args:: Vararg{EnzymeCore.Annotation, N}
77113) where {W, N, RuntimeActivity, StrongZero}
78114 f_orig = unwrap (func. val)
79- Enzyme. autodiff (Forward, Const (f_orig), Const, args... )
115+ mode = _fwd_mode (Val (false ), Val (RuntimeActivity), Val (StrongZero))
116+ Enzyme. autodiff (mode, Const (f_orig), Const, args... )
80117 return nothing
81118end
82119
@@ -151,11 +188,21 @@ function EnzymeRules.augmented_primal(
151188 end
152189end
153190
191+ # Helper: build a Forward mode reflecting a RevConfig's runtime_activity /
192+ # strong_zero flags so the internal forward-mode delegation inside reverse
193+ # rules inherits the user's outer config.
194+ @inline function _fwd_mode_from_rev (config:: EnzymeRules.RevConfig )
195+ mode = Forward
196+ EnzymeRules. runtime_activity (config) && (mode = Enzyme. set_runtime_activity (mode))
197+ EnzymeRules. strong_zero (config) && (mode = Enzyme. set_strong_zero (mode))
198+ return mode
199+ end
200+
154201# Varargs reverse: compute each partial via forward-mode AD on the unwrapped
155202# function, then scale by dret. This avoids type-inference issues that arise
156203# from calling autodiff(Reverse, Const{Any}(...), ...).
157204@generated function _fww_reverse_grads (
158- f_orig, dret_val:: T , args:: Vararg{EnzymeCore.Active, N}
205+ mode, f_orig, dret_val:: T , args:: Vararg{EnzymeCore.Active, N}
159206) where {T, N}
160207 # Build forward-mode calls for each partial derivative
161208 exprs = []
164211 dups = [:(Duplicated (args[$ j]. val, $ (seeds[j]))) for j in 1 : N]
165212 Ti = :(eltype (typeof (args[$ i])))
166213 push! (exprs, quote
167- fwd = Enzyme. autodiff (Forward , Const (f_orig), Duplicated{$ T}, $ (dups... ))
214+ fwd = Enzyme. autodiff (mode , Const (f_orig), Duplicated{$ T}, $ (dups... ))
168215 $ Ti (fwd[1 ] * dret_val):: $Ti
169216 end )
170217 end
@@ -179,7 +226,7 @@ function EnzymeRules.reverse(
179226 args:: Vararg{EnzymeCore.Active, N}
180227) where {T, N}
181228 f_orig = unwrap (func. val)
182- return _fww_reverse_grads (f_orig, dret. val, args... )
229+ return _fww_reverse_grads (_fwd_mode_from_rev (config), f_orig, dret. val, args... )
183230end
184231
185232# Handle mixed Active/Const args: return nothing for Const, gradient for Active
@@ -192,6 +239,7 @@ function EnzymeRules.reverse(
192239) where {N}
193240 f_orig = unwrap (func. val)
194241 dret_val = dret. val
242+ mode = _fwd_mode_from_rev (config)
195243 return ntuple (Val (N)) do i
196244 if args[i] isa EnzymeCore. Const
197245 nothing
@@ -204,7 +252,7 @@ function EnzymeRules.reverse(
204252 Duplicated (args[j]. val, zero (eltype (typeof (args[j]))))
205253 end
206254 end
207- fwd = Enzyme. autodiff (Forward , Const (f_orig), Duplicated, dup_args... )
255+ fwd = Enzyme. autodiff (mode , Const (f_orig), Duplicated, dup_args... )
208256 fwd[1 ] * dret_val
209257 end
210258 end
0 commit comments