Skip to content

Commit 01bcfe7

Browse files
Merge pull request #50 from ChrisRackauckas-Claude/fix-batch-forward-typeassert
Fix batch-forward typeassert: Enzyme returns AnonymousStruct, not NTuple
2 parents 6575926 + b15abbb commit 01bcfe7

2 files changed

Lines changed: 62 additions & 2 deletions

File tree

ext/FunctionWrappersWrappersEnzymeExt.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@ function EnzymeRules.forward(
5959
return shadow_result[1]::T
6060
else
6161
shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...)
62-
return shadow_result[1]::NTuple{W, T}
62+
# Enzyme returns the batch shadow as an `AnonymousStruct` — a
63+
# `NamedTuple{(:1, :2, …), NTuple{W, T}}` (see
64+
# `Enzyme.Compiler.AnonymousStruct` in `Enzyme/src/compiler/utils.jl`).
65+
# Convert to a plain tuple so the rule's return matches the
66+
# `BatchDuplicated` shadow contract Enzyme expects from a forward rule.
67+
return Tuple(shadow_result[1])::NTuple{W, T}
6368
end
6469
end
6570

@@ -83,7 +88,9 @@ function EnzymeRules.forward(
8388
return Duplicated(primal, shadow)
8489
else
8590
shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...)
86-
shadows = shadow_result[1]::NTuple{W, T}
91+
# See the comment on the {false, true} rule — `shadow_result[1]` is a
92+
# NamedTuple, not an NTuple.
93+
shadows = Tuple(shadow_result[1])::NTuple{W, T}
8794
return BatchDuplicated(primal, shadows)
8895
end
8996
end

test/enzyme_tests.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,59 @@ end
7676
@test result_wp[2] 9.0 # primal f(3) = 9
7777
end
7878

79+
@testset "Enzyme batch forward rule return type is NTuple, not NamedTuple" begin
80+
# Regression test for the typeassert bug: the inner
81+
# `Enzyme.autodiff(Forward, …, BatchDuplicated{T,W}, …)` returns the
82+
# batch shadow wrapped in `Enzyme.Compiler.AnonymousStruct` — a
83+
# `NamedTuple{(:1, :2, …), NTuple{W, T}}`. The rule must convert it
84+
# to a plain `NTuple{W, T}` before returning, otherwise the
85+
# `::NTuple{W, T}` typeassert fires and surfaces as:
86+
# TypeError: in typeassert, expected Tuple{Float64, Float64},
87+
# got a value of type @NamedTuple{1::Float64, 2::Float64}
88+
#
89+
# The outer `Enzyme.autodiff` testset above doesn't catch this on its
90+
# own because the outer call ALSO wraps the result in
91+
# `AnonymousStruct`, and `shadow[1] / shadow[2]` indexing works on
92+
# both `NamedTuple` and `Tuple`. Call `EnzymeRules.forward`
93+
# directly so we observe the rule's actual return value and can
94+
# assert its concrete type.
95+
f(x) = x^2
96+
fww = FunctionWrappersWrapper(f, (Tuple{Float64},), (Float64,))
97+
98+
# {NeedsPrimal=false, NeedsShadow=true, W=2, RuntimeActivity=false,
99+
# StrongZero=false} — the shadow-only batch branch.
100+
config_shadow = EnzymeCore.EnzymeRules.FwdConfig{false, true, 2, false, false}()
101+
shadow = EnzymeCore.EnzymeRules.forward(
102+
config_shadow, Const(fww), EnzymeCore.BatchDuplicated{Float64, 2},
103+
BatchDuplicated(3.0, (1.0, 2.0))
104+
)
105+
@test shadow isa NTuple{2, Float64}
106+
@test !(shadow isa NamedTuple)
107+
@test shadow == (6.0, 12.0)
108+
109+
# {NeedsPrimal=true, NeedsShadow=true, W=2, …} — ForwardWithPrimal
110+
# batch branch. Same conversion bug existed on this path.
111+
config_primal = EnzymeCore.EnzymeRules.FwdConfig{true, true, 2, false, false}()
112+
result = EnzymeCore.EnzymeRules.forward(
113+
config_primal, Const(fww), EnzymeCore.BatchDuplicated{Float64, 2},
114+
BatchDuplicated(3.0, (1.0, 2.0))
115+
)
116+
@test result isa BatchDuplicated
117+
@test result.val 9.0
118+
@test result.dval isa NTuple{2, Float64}
119+
@test !(result.dval isa NamedTuple)
120+
@test result.dval == (6.0, 12.0)
121+
122+
# Confirm the conversion generalises to W > 2.
123+
config_w3 = EnzymeCore.EnzymeRules.FwdConfig{false, true, 3, false, false}()
124+
shadow3 = EnzymeCore.EnzymeRules.forward(
125+
config_w3, Const(fww), EnzymeCore.BatchDuplicated{Float64, 3},
126+
BatchDuplicated(3.0, (1.0, 2.0, 4.0))
127+
)
128+
@test shadow3 isa NTuple{3, Float64}
129+
@test shadow3 == (6.0, 12.0, 24.0)
130+
end
131+
79132
@testset "Enzyme forward mode, Const return (IIP, no return-shadow)" begin
80133
# Covers EnzymeRules.FwdConfig{false, false, W, ...} — Enzyme dispatches on
81134
# this combo for IIP functions with a Const return type where the caller

0 commit comments

Comments
 (0)