|
76 | 76 | @test result_wp[2] ≈ 9.0 # primal f(3) = 9 |
77 | 77 | end |
78 | 78 |
|
| 79 | +@testset "Enzyme batch forward rule return type is NTuple, not NamedTuple" begin |
| 80 | + # Regression test for the typeassert bug: the inner |
| 81 | + # `Enzyme.autodiff(Forward, …, BatchDuplicated{T,W}, …)` returns the |
| 82 | + # batch shadow wrapped in `Enzyme.Compiler.AnonymousStruct` — a |
| 83 | + # `NamedTuple{(:1, :2, …), NTuple{W, T}}`. The rule must convert it |
| 84 | + # to a plain `NTuple{W, T}` before returning, otherwise the |
| 85 | + # `::NTuple{W, T}` typeassert fires and surfaces as: |
| 86 | + # TypeError: in typeassert, expected Tuple{Float64, Float64}, |
| 87 | + # got a value of type @NamedTuple{1::Float64, 2::Float64} |
| 88 | + # |
| 89 | + # The outer `Enzyme.autodiff` testset above doesn't catch this on its |
| 90 | + # own because the outer call ALSO wraps the result in |
| 91 | + # `AnonymousStruct`, and `shadow[1] / shadow[2]` indexing works on |
| 92 | + # both `NamedTuple` and `Tuple`. Call `EnzymeRules.forward` |
| 93 | + # directly so we observe the rule's actual return value and can |
| 94 | + # assert its concrete type. |
| 95 | + f(x) = x^2 |
| 96 | + fww = FunctionWrappersWrapper(f, (Tuple{Float64},), (Float64,)) |
| 97 | + |
| 98 | + # {NeedsPrimal=false, NeedsShadow=true, W=2, RuntimeActivity=false, |
| 99 | + # StrongZero=false} — the shadow-only batch branch. |
| 100 | + config_shadow = EnzymeCore.EnzymeRules.FwdConfig{false, true, 2, false, false}() |
| 101 | + shadow = EnzymeCore.EnzymeRules.forward( |
| 102 | + config_shadow, Const(fww), EnzymeCore.BatchDuplicated{Float64, 2}, |
| 103 | + BatchDuplicated(3.0, (1.0, 2.0)) |
| 104 | + ) |
| 105 | + @test shadow isa NTuple{2, Float64} |
| 106 | + @test !(shadow isa NamedTuple) |
| 107 | + @test shadow == (6.0, 12.0) |
| 108 | + |
| 109 | + # {NeedsPrimal=true, NeedsShadow=true, W=2, …} — ForwardWithPrimal |
| 110 | + # batch branch. Same conversion bug existed on this path. |
| 111 | + config_primal = EnzymeCore.EnzymeRules.FwdConfig{true, true, 2, false, false}() |
| 112 | + result = EnzymeCore.EnzymeRules.forward( |
| 113 | + config_primal, Const(fww), EnzymeCore.BatchDuplicated{Float64, 2}, |
| 114 | + BatchDuplicated(3.0, (1.0, 2.0)) |
| 115 | + ) |
| 116 | + @test result isa BatchDuplicated |
| 117 | + @test result.val ≈ 9.0 |
| 118 | + @test result.dval isa NTuple{2, Float64} |
| 119 | + @test !(result.dval isa NamedTuple) |
| 120 | + @test result.dval == (6.0, 12.0) |
| 121 | + |
| 122 | + # Confirm the conversion generalises to W > 2. |
| 123 | + config_w3 = EnzymeCore.EnzymeRules.FwdConfig{false, true, 3, false, false}() |
| 124 | + shadow3 = EnzymeCore.EnzymeRules.forward( |
| 125 | + config_w3, Const(fww), EnzymeCore.BatchDuplicated{Float64, 3}, |
| 126 | + BatchDuplicated(3.0, (1.0, 2.0, 4.0)) |
| 127 | + ) |
| 128 | + @test shadow3 isa NTuple{3, Float64} |
| 129 | + @test shadow3 == (6.0, 12.0, 24.0) |
| 130 | +end |
| 131 | + |
79 | 132 | @testset "Enzyme forward mode, Const return (IIP, no return-shadow)" begin |
80 | 133 | # Covers EnzymeRules.FwdConfig{false, false, W, ...} — Enzyme dispatches on |
81 | 134 | # this combo for IIP functions with a Const return type where the caller |
|
0 commit comments