Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions ext/FunctionWrappersWrappersEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ function EnzymeRules.forward(
return shadow_result[1]::T
else
shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...)
return shadow_result[1]::NTuple{W, T}
# Enzyme returns the batch shadow as an `AnonymousStruct` — a
# `NamedTuple{(:1, :2, …), NTuple{W, T}}` (see
# `Enzyme.Compiler.AnonymousStruct` in `Enzyme/src/compiler/utils.jl`).
# Convert to a plain tuple so the rule's return matches the
# `BatchDuplicated` shadow contract Enzyme expects from a forward rule.
return Tuple(shadow_result[1])::NTuple{W, T}
end
end

Expand All @@ -73,7 +78,9 @@ function EnzymeRules.forward(
return Duplicated(primal, shadow)
else
shadow_result = Enzyme.autodiff(mode, Const(f_orig), BatchDuplicated{T, W}, args...)
shadows = shadow_result[1]::NTuple{W, T}
# See the comment on the {false, true} rule — `shadow_result[1]` is a
# NamedTuple, not an NTuple.
shadows = Tuple(shadow_result[1])::NTuple{W, T}
return BatchDuplicated(primal, shadows)
end
end
Expand Down
53 changes: 53 additions & 0 deletions test/enzyme_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,59 @@ end
@test result_wp[2] ≈ 9.0 # primal f(3) = 9
end

@testset "Enzyme batch forward rule return type is NTuple, not NamedTuple" begin
# Regression test for the typeassert bug: the inner
# `Enzyme.autodiff(Forward, …, BatchDuplicated{T,W}, …)` returns the
# batch shadow wrapped in `Enzyme.Compiler.AnonymousStruct` — a
# `NamedTuple{(:1, :2, …), NTuple{W, T}}`. The rule must convert it
# to a plain `NTuple{W, T}` before returning, otherwise the
# `::NTuple{W, T}` typeassert fires and surfaces as:
# TypeError: in typeassert, expected Tuple{Float64, Float64},
# got a value of type @NamedTuple{1::Float64, 2::Float64}
#
# The outer `Enzyme.autodiff` testset above doesn't catch this on its
# own because the outer call ALSO wraps the result in
# `AnonymousStruct`, and `shadow[1] / shadow[2]` indexing works on
# both `NamedTuple` and `Tuple`. Call `EnzymeRules.forward`
# directly so we observe the rule's actual return value and can
# assert its concrete type.
f(x) = x^2
fww = FunctionWrappersWrapper(f, (Tuple{Float64},), (Float64,))

# {NeedsPrimal=false, NeedsShadow=true, W=2, RuntimeActivity=false,
# StrongZero=false} — the shadow-only batch branch.
config_shadow = EnzymeCore.EnzymeRules.FwdConfig{false, true, 2, false, false}()
shadow = EnzymeCore.EnzymeRules.forward(
config_shadow, Const(fww), EnzymeCore.BatchDuplicated{Float64, 2},
BatchDuplicated(3.0, (1.0, 2.0))
)
@test shadow isa NTuple{2, Float64}
@test !(shadow isa NamedTuple)
@test shadow == (6.0, 12.0)

# {NeedsPrimal=true, NeedsShadow=true, W=2, …} — ForwardWithPrimal
# batch branch. Same conversion bug existed on this path.
config_primal = EnzymeCore.EnzymeRules.FwdConfig{true, true, 2, false, false}()
result = EnzymeCore.EnzymeRules.forward(
config_primal, Const(fww), EnzymeCore.BatchDuplicated{Float64, 2},
BatchDuplicated(3.0, (1.0, 2.0))
)
@test result isa BatchDuplicated
@test result.val ≈ 9.0
@test result.dval isa NTuple{2, Float64}
@test !(result.dval isa NamedTuple)
@test result.dval == (6.0, 12.0)

# Confirm the conversion generalises to W > 2.
config_w3 = EnzymeCore.EnzymeRules.FwdConfig{false, true, 3, false, false}()
shadow3 = EnzymeCore.EnzymeRules.forward(
config_w3, Const(fww), EnzymeCore.BatchDuplicated{Float64, 3},
BatchDuplicated(3.0, (1.0, 2.0, 4.0))
)
@test shadow3 isa NTuple{3, Float64}
@test shadow3 == (6.0, 12.0, 24.0)
end

@testset "Enzyme forward mode, Const return (IIP, no return-shadow)" begin
# Covers EnzymeRules.FwdConfig{false, false, W, ...} — Enzyme dispatches on
# this combo for IIP functions with a Const return type where the caller
Expand Down
Loading