Skip to content

Commit b15abbb

Browse files
Add direct-rule regression test for batch-forward NTuple conversion
Calls `EnzymeRules.forward` directly so the rule's actual return value is observable. The pre-existing `Enzyme.autodiff(Forward, …)`-driven testset doesn't catch a regression on its own because the outer `Enzyme.autodiff` ALSO wraps in `AnonymousStruct`, and `shadow[1]` indexing works on both `NamedTuple` and `Tuple`. The new testset asserts: - the shadow-only rule (`{false, true, W=2, …}`) returns `NTuple{2, Float64}`, not `NamedTuple`; - the ForwardWithPrimal rule (`{true, true, W=2, …}`) puts an `NTuple{2, Float64}` (not a `NamedTuple`) into `result.dval`; - the conversion generalises to W = 3. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 509d7fb commit b15abbb

1 file changed

Lines changed: 53 additions & 0 deletions

File tree

test/enzyme_tests.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,59 @@ end
7676
@test result_wp[2] 9.0 # primal f(3) = 9
7777
end
7878

79+
@testset "Enzyme batch forward rule return type is NTuple, not NamedTuple" begin
80+
# Regression test for the typeassert bug: the inner
81+
# `Enzyme.autodiff(Forward, …, BatchDuplicated{T,W}, …)` returns the
82+
# batch shadow wrapped in `Enzyme.Compiler.AnonymousStruct` — a
83+
# `NamedTuple{(:1, :2, …), NTuple{W, T}}`. The rule must convert it
84+
# to a plain `NTuple{W, T}` before returning, otherwise the
85+
# `::NTuple{W, T}` typeassert fires and surfaces as:
86+
# TypeError: in typeassert, expected Tuple{Float64, Float64},
87+
# got a value of type @NamedTuple{1::Float64, 2::Float64}
88+
#
89+
# The outer `Enzyme.autodiff` testset above doesn't catch this on its
90+
# own because the outer call ALSO wraps the result in
91+
# `AnonymousStruct`, and `shadow[1] / shadow[2]` indexing works on
92+
# both `NamedTuple` and `Tuple`. Call `EnzymeRules.forward`
93+
# directly so we observe the rule's actual return value and can
94+
# assert its concrete type.
95+
f(x) = x^2
96+
fww = FunctionWrappersWrapper(f, (Tuple{Float64},), (Float64,))
97+
98+
# {NeedsPrimal=false, NeedsShadow=true, W=2, RuntimeActivity=false,
99+
# StrongZero=false} — the shadow-only batch branch.
100+
config_shadow = EnzymeCore.EnzymeRules.FwdConfig{false, true, 2, false, false}()
101+
shadow = EnzymeCore.EnzymeRules.forward(
102+
config_shadow, Const(fww), EnzymeCore.BatchDuplicated{Float64, 2},
103+
BatchDuplicated(3.0, (1.0, 2.0))
104+
)
105+
@test shadow isa NTuple{2, Float64}
106+
@test !(shadow isa NamedTuple)
107+
@test shadow == (6.0, 12.0)
108+
109+
# {NeedsPrimal=true, NeedsShadow=true, W=2, …} — ForwardWithPrimal
110+
# batch branch. Same conversion bug existed on this path.
111+
config_primal = EnzymeCore.EnzymeRules.FwdConfig{true, true, 2, false, false}()
112+
result = EnzymeCore.EnzymeRules.forward(
113+
config_primal, Const(fww), EnzymeCore.BatchDuplicated{Float64, 2},
114+
BatchDuplicated(3.0, (1.0, 2.0))
115+
)
116+
@test result isa BatchDuplicated
117+
@test result.val 9.0
118+
@test result.dval isa NTuple{2, Float64}
119+
@test !(result.dval isa NamedTuple)
120+
@test result.dval == (6.0, 12.0)
121+
122+
# Confirm the conversion generalises to W > 2.
123+
config_w3 = EnzymeCore.EnzymeRules.FwdConfig{false, true, 3, false, false}()
124+
shadow3 = EnzymeCore.EnzymeRules.forward(
125+
config_w3, Const(fww), EnzymeCore.BatchDuplicated{Float64, 3},
126+
BatchDuplicated(3.0, (1.0, 2.0, 4.0))
127+
)
128+
@test shadow3 isa NTuple{3, Float64}
129+
@test shadow3 == (6.0, 12.0, 24.0)
130+
end
131+
79132
@testset "Enzyme forward mode, Const return (IIP, no return-shadow)" begin
80133
# Covers EnzymeRules.FwdConfig{false, false, W, ...} — Enzyme dispatches on
81134
# this combo for IIP functions with a Const return type where the caller

0 commit comments

Comments
 (0)