FwdConfig{false,false} rule: delegate to Enzyme.autodiff for IIP shadow propagation #44
…rapped f

Follow-up to SciML#43. The original revision ran `f_orig(pargs...)` by hand to cover the IIP-void-return Enzyme-forward path that was throwing `MethodError: no method matching forward(::FwdConfigWidth{1, false, false, false, false}, …)`. That version fixed the dispatch but left the `Duplicated` arg shadow buffers untouched (the inner call only exercised the primal-valued function wrapper), so downstream callers that rely on shadow propagation through args got a trivially zero Jacobian.

Observed concretely in SciML/OrdinaryDiffEq.jl v7 `Downstream` `time_derivative_test.jl` with `AutoEnzyme(mode = Enzyme.Forward, function_annotation = Const)`:

Rosenbrock23 error: 5.55e-17 < 1e-10 PASS
Rodas4 error: 1.11e-6 > 1e-10 FAIL
Rodas5 error: 0.022 > 1e-10 FAIL
Veldd4 error: 5.56e-7 > 1e-10 FAIL

After delegating to `Enzyme.autodiff(Forward, Const(f_orig), Const, args...)`, all four pass at machine epsilon, matching master.

Update the regression test to assert that the Duplicated shadow buffer is correctly updated (`∂du[1]/∂u[1] * u_shadow[1] = -2*u[1]*1 = -6`) rather than left at zero.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Audited the other dispatches per review feedback — none of them need additional changes here, though the pattern is worth explaining. Key constraint: That shapes the pattern of every rule:
The reverse-side The reverse rules that can delegate via forward-mode trick (Active dret) already do — that's So this PR stays narrowly on the |
…ive args

The reverse rules from SciML#43 had three bugs exposed by new end-to-end tests:

1. Const-dret reverse returned `nothing` per arg, but Enzyme's rule protocol requires concrete scalar gradients for Active args (not `nothing`). Fixed to return `zero(T)` for Active args and `nothing` for Duplicated/Const args.
2. IIP reverse with Duplicated args (SciML pattern) returned `nothing` and never propagated gradients into the Duplicated shadow buffers. Fixed by delegating to `Enzyme.autodiff(Reverse, Const(f_orig), Const, args...)` when Duplicated args are present, so Enzyme accumulates the transposed derivative into the shadow buffers.
3. Enzyme passes `Type{<:Const}` (not an instance) for the dret slot in Const-return reverse rules. Updated dispatch signatures from `dret::EnzymeCore.Const` to `dret::Type{<:EnzymeCore.Const}`.

New end-to-end reverse-mode tests that assert derivative correctness:

- Const return + Active args: gradients are (0.0, 0.0)
- IIP f!(du, u) with Duplicated args: u_shadow accumulates ∂du/∂u
- Multi-component IIP cross-coupled Jacobian transpose
- ReverseWithPrimal IIP variant

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Follow-up to #43. That PR's `FwdConfig{false, false, …}` rule invoked
`f_orig(pargs...)` directly, which fixed the original `MethodError` but
evaluated only the primal — the Duplicated arg shadows stayed at zero.
SciML solvers that rely on shadow propagation through IIP `Duplicated`
args under `AutoEnzyme(Forward, function_annotation = Const)` therefore
saw a trivially zero Jacobian.
Motivating failure (SciML/OrdinaryDiffEq.jl v7, `Downstream 1`)
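`time_derivative_test.jl` solving with `AutoEnzyme(mode = Enzyme.Forward, function_annotation = Enzyme.Const)` over an IIP RHS. Errors before this fix, against a tolerance of 1e-10:

- Rosenbrock23: 5.55e-17 (pass)
- Rodas4: 1.11e-6 (fail)
- Rodas5: 0.022 (fail)
- Veldd4: 5.56e-7 (fail)

All four match master at machine epsilon after this fix.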
Fix
Replace the hand-rolled primal call with `Enzyme.autodiff(Forward, Const(f_orig), Const, args...)`. The `Const` return annotation tells Enzyme there's no return-shadow to produce, while the Duplicated args get proper shadow propagation through the unwrapped function.
Test
Updated the regression test to assert `du_shadow` is populated to the analytic derivative value (`∂du[1]/∂u[1] * u_shadow[1] = -2u[1] = -6` at `u=3.0, u_shadow=1.0`) rather than left at zero.
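For readers who want to see the expected shadow value in isolation, here is a minimal sketch of the same arithmetic using plain Enzyme on an unwrapped function. It is not the regression test from this PR: the RHS `f!` below (`du[1] = -u[1]^2`) is an assumption chosen only because it reproduces the `-2u[1] = -6` derivative quoted above, and no function wrapper is involved.

```julia
using Enzyme

# Hypothetical in-place RHS (an assumption, not taken from the PR's test file):
# du[1] = -u[1]^2, so ∂du[1]/∂u[1] = -2u[1].
function f!(du, u)
    du[1] = -u[1]^2
    return nothing
end

u = [3.0]
u_shadow = [1.0]       # tangent seed for u
du = zeros(1)
du_shadow = zeros(1)   # Enzyme writes the output tangent here

# Same call shape as the fix: Const function annotation, Const return
# annotation, Duplicated in-place args carrying their shadow buffers.
Enzyme.autodiff(
    Forward, Const(f!), Const,
    Duplicated(du, du_shadow), Duplicated(u, u_shadow),
)

# A primal-only call would leave du_shadow at zero; with shadow propagation
# it holds ∂du[1]/∂u[1] * u_shadow[1] = -2 * 3.0 * 1.0.
@assert du_shadow[1] ≈ -6.0
```

If the call only evaluated the primal, as the previous revision of the rule did, the final assertion would fail because `du_shadow` would still be zero.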
Ran the full `Pkg.test()` locally: 25 passed, 0 failed, 1 errored (the errored test is the pre-existing `Enzyme batch forward mode (width > 1)` `TypeError: expected Tuple{Float64, Float64}, got @NamedTuple` — unrelated, looks like an EnzymeCore shape change).
Other dispatch audit
Checked every rule in the extension for similar `run-primal-when-you-shouldn't` / `don't-run-primal-when-you-should` bugs:
Only this one forward rule needed a semantic correction.
🤖 Generated with Claude Code