Skip to content

Commit 509d7fb

Browse files
Convert AnonymousStruct shadow to NTuple in batch forward rules
The forward rules for batch width > 1 asserted `shadow_result[1]::NTuple{W, T}` but `Enzyme.autodiff(Forward, …, BatchDuplicated{T,W}, …)` returns the batch shadow wrapped in `Enzyme.Compiler.AnonymousStruct` — a `NamedTuple{(:1, :2, …), NTuple{W, T}}` (see `Enzyme/src/compiler/utils.jl:480`). The mismatch tripped the existing `Enzyme batch forward mode (width > 1)` testset on `main` with: ``` TypeError: in typeassert, expected Tuple{Float64, Float64}, got a value of type @NamedTuple{1::Float64, 2::Float64} ``` Wrap the shadow in `Tuple(...)` before the type-assert so the rule's return matches the `BatchDuplicated` shadow contract that Enzyme expects from a forward rule. Applies to both `{false, true, W>1, …}` (shadow-only) and `{true, true, W>1, …}` (ForwardWithPrimal) rules. The existing testset (which was failing on `main`) now passes. Full local test summary on Julia 1.11 + Enzyme v0.13.147: ``` FunctionWrappersWrappers.jl | 48 48 BigFloat support | 5 5 UnionAll return types | 2 2 Enzyme extension | 44 44 Mooncake extension | 13 13 ``` Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent a03f55f commit 509d7fb

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

ext/FunctionWrappersWrappersEnzymeExt.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ function EnzymeRules.forward(
4949
return shadow_result[1]::T
5050
else
5151
shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...)
52-
return shadow_result[1]::NTuple{W, T}
52+
# Enzyme returns the batch shadow as an `AnonymousStruct` — a
53+
# `NamedTuple{(:1, :2, …), NTuple{W, T}}` (see
54+
# `Enzyme.Compiler.AnonymousStruct` in `Enzyme/src/compiler/utils.jl`).
55+
# Convert to a plain tuple so the rule's return matches the
56+
# `BatchDuplicated` shadow contract Enzyme expects from a forward rule.
57+
return Tuple(shadow_result[1])::NTuple{W, T}
5358
end
5459
end
5560

@@ -73,7 +78,9 @@ function EnzymeRules.forward(
7378
return Duplicated(primal, shadow)
7479
else
7580
shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...)
76-
shadows = shadow_result[1]::NTuple{W, T}
81+
# See the comment on the {false, true} rule — `shadow_result[1]` is a
82+
# NamedTuple, not an NTuple.
83+
shadows = Tuple(shadow_result[1])::NTuple{W, T}
7784
return BatchDuplicated(primal, shadows)
7885
end
7986
end

0 commit comments

Comments
 (0)