Skip to content

Commit 7c6770e

Browse files
Return tunable vector for Enzyme tangent instead of NamedTuple (#1358)
For EnzymeOriginator with SciMLStructure parameters, return the tunable gradient vector directly from steadystatebackpass instead of the Zygote-repacked NamedTuple. The NonlinearSolveBaseEnzymeExt reverse rule uses SciMLStructures.replace! to accumulate it into the parameter shadow, going through the proper SciMLStructures interface. This avoids the NamedTuple broadcasting error and ensures all tangent accumulation uses SciMLStructures.canonicalize/replace! rather than making assumptions about the NamedTuple field structure. Companion PR: NonlinearSolve.jl#879 Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1a9e18e commit 7c6770e

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

src/concrete_solve.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2418,16 +2418,26 @@ function SciMLBase._concrete_solve_adjoint(
24182418
Δtunables
24192419
)
24202420

2421+
# For Enzyme with SciMLStructure parameters, return the tunable gradient
2422+
# vector directly instead of the Zygote-repacked NamedTuple. The Enzyme
2423+
# reverse rule in NonlinearSolveBaseEnzymeExt uses
2424+
# SciMLStructures.replace! to accumulate it into the parameter shadow.
2425+
dp_tangent = if originator isa SciMLBase.EnzymeOriginator && isscimlstructure(p)
2426+
dp
2427+
else
2428+
repack_adjoint(dp)[1]
2429+
end
2430+
24212431
return if originator isa SciMLBase.TrackerOriginator ||
24222432
originator isa SciMLBase.ReverseDiffOriginator
24232433
(
2424-
NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
2434+
NoTangent(), NoTangent(), NoTangent(), dp_tangent, NoTangent(),
24252435
ntuple(_ -> NoTangent(), length(args))...,
24262436
)
24272437
else
24282438
(
24292439
NoTangent(), NoTangent(), NoTangent(),
2430-
NoTangent(), repack_adjoint(dp)[1], NoTangent(),
2440+
NoTangent(), dp_tangent, NoTangent(),
24312441
ntuple(_ -> NoTangent(), length(args))...,
24322442
)
24332443
end

0 commit comments

Comments
 (0)