diff --git a/ext/FunctionWrappersWrappersEnzymeExt.jl b/ext/FunctionWrappersWrappersEnzymeExt.jl index 9e0a3ae..8797ea7 100644 --- a/ext/FunctionWrappersWrappersEnzymeExt.jl +++ b/ext/FunctionWrappersWrappersEnzymeExt.jl @@ -49,7 +49,12 @@ function EnzymeRules.forward( return shadow_result[1]::T else shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...) - return shadow_result[1]::NTuple{W, T} + # Enzyme returns the batch shadow as an `AnonymousStruct` — a + # `NamedTuple{(:1, :2, …), NTuple{W, T}}` (see + # `Enzyme.Compiler.AnonymousStruct` in `Enzyme/src/compiler/utils.jl`). + # Convert to a plain tuple so the rule's return matches the + # `BatchDuplicated` shadow contract Enzyme expects from a forward rule. + return Tuple(shadow_result[1])::NTuple{W, T} end end @@ -73,7 +78,9 @@ function EnzymeRules.forward( return Duplicated(primal, shadow) else shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...) - shadows = shadow_result[1]::NTuple{W, T} + # See the comment on the {false, true} rule — `shadow_result[1]` is a + # NamedTuple, not an NTuple. + shadows = Tuple(shadow_result[1])::NTuple{W, T} return BatchDuplicated(primal, shadows) end end diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index c427094..8bbf2ea 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -76,6 +76,59 @@ end @test result_wp[2] ≈ 9.0 # primal f(3) = 9 end +@testset "Enzyme batch forward rule return type is NTuple, not NamedTuple" begin + # Regression test for the typeassert bug: the inner + # `Enzyme.autodiff(Forward, …, BatchDuplicated{T,W}, …)` returns the + # batch shadow wrapped in `Enzyme.Compiler.AnonymousStruct` — a + # `NamedTuple{(:1, :2, …), NTuple{W, T}}`. The rule must convert it + # to a plain `NTuple{W, T}` before returning, otherwise the + # `::NTuple{W, T}` typeassert fires and surfaces as: + # TypeError: in typeassert, expected Tuple{Float64, Float64}, + # got a value of type @NamedTuple{1::Float64, 2::Float64} + # + # The outer `Enzyme.autodiff` testset above doesn't catch this on its + # own because the outer call ALSO wraps the result in + # `AnonymousStruct`, and `shadow[1] / shadow[2]` indexing works on + # both `NamedTuple` and `Tuple`. Call `EnzymeRules.forward` + # directly so we observe the rule's actual return value and can + # assert its concrete type. + f(x) = x^2 + fww = FunctionWrappersWrapper(f, (Tuple{Float64},), (Float64,)) + + # {NeedsPrimal=false, NeedsShadow=true, W=2, RuntimeActivity=false, + # StrongZero=false} — the shadow-only batch branch. + config_shadow = EnzymeCore.EnzymeRules.FwdConfig{false, true, 2, false, false}() + shadow = EnzymeCore.EnzymeRules.forward( + config_shadow, Const(fww), EnzymeCore.BatchDuplicated{Float64, 2}, + BatchDuplicated(3.0, (1.0, 2.0)) + ) + @test shadow isa NTuple{2, Float64} + @test !(shadow isa NamedTuple) + @test shadow == (6.0, 12.0) + + # {NeedsPrimal=true, NeedsShadow=true, W=2, …} — ForwardWithPrimal + # batch branch. Same conversion bug existed on this path. + config_primal = EnzymeCore.EnzymeRules.FwdConfig{true, true, 2, false, false}() + result = EnzymeCore.EnzymeRules.forward( + config_primal, Const(fww), EnzymeCore.BatchDuplicated{Float64, 2}, + BatchDuplicated(3.0, (1.0, 2.0)) + ) + @test result isa BatchDuplicated + @test result.val ≈ 9.0 + @test result.dval isa NTuple{2, Float64} + @test !(result.dval isa NamedTuple) + @test result.dval == (6.0, 12.0) + + # Confirm the conversion generalises to W > 2. + config_w3 = EnzymeCore.EnzymeRules.FwdConfig{false, true, 3, false, false}() + shadow3 = EnzymeCore.EnzymeRules.forward( + config_w3, Const(fww), EnzymeCore.BatchDuplicated{Float64, 3}, + BatchDuplicated(3.0, (1.0, 2.0, 4.0)) + ) + @test shadow3 isa NTuple{3, Float64} + @test shadow3 == (6.0, 12.0, 24.0) +end + @testset "Enzyme forward mode, Const return (IIP, no return-shadow)" begin # Covers EnzymeRules.FwdConfig{false, false, W, ...} — Enzyme dispatches on # this combo for IIP functions with a Const return type where the caller