|
36 | 36 | # ============================================================================= |
37 | 37 |
|
38 | 38 | # Shadow only (Forward mode, no primal) |
| 39 | +# |
| 40 | +# `func` is `Annotation{<:FunctionWrappersWrapper}` rather than |
| 41 | +# `Const{<:FunctionWrappersWrapper}` so that callers passing |
| 42 | +# `Duplicated{<:FunctionWrappersWrapper}` also dispatch here. Enzyme drives |
| 43 | +# the rule that way when the outer `autodiff` call is differentiating through |
| 44 | +# a closure that carries an FWW (e.g. NonlinearSolve + SciMLSensitivity, see |
| 45 | +# SciML/FunctionWrappersWrappers.jl#48). The FWW struct itself only carries |
| 46 | +# `FunctionWrapper`s plus cache storage — none of those fields have a |
| 47 | +# meaningful tangent — so the function shadow is ignored and the inner |
| 48 | +# `Enzyme.autodiff` call uses `Const(f_orig)`. |
39 | 49 | function EnzymeRules.forward( |
40 | 50 | ::EnzymeRules.FwdConfig{false, true, W, RuntimeActivity, StrongZero}, |
41 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 51 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
42 | 52 | RT::Type{<:EnzymeCore.Annotation{T}}, |
43 | 53 | args::Vararg{EnzymeCore.Annotation, N} |
44 | 54 | ) where {T, W, N, RuntimeActivity, StrongZero} |
|
56 | 66 | # Both primal and shadow (ForwardWithPrimal mode) |
57 | 67 | function EnzymeRules.forward( |
58 | 68 | ::EnzymeRules.FwdConfig{true, true, W, RuntimeActivity, StrongZero}, |
59 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 69 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
60 | 70 | RT::Type{<:EnzymeCore.Annotation{T}}, |
61 | 71 | args::Vararg{EnzymeCore.Annotation, N} |
62 | 72 | ) where {T, W, N, RuntimeActivity, StrongZero} |
|
81 | 91 | # Primal only (Const return type) — width-independent |
82 | 92 | function EnzymeRules.forward( |
83 | 93 | ::EnzymeRules.FwdConfig{true, false, W, RuntimeActivity, StrongZero}, |
84 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 94 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
85 | 95 | RT::Type{<:EnzymeCore.Annotation}, |
86 | 96 | args::Vararg{EnzymeCore.Annotation, N} |
87 | 97 | ) where {W, N, RuntimeActivity, StrongZero} |
|
107 | 117 | # `set_runtime_activity(Forward)` on the way down into `f_orig`. |
108 | 118 | function EnzymeRules.forward( |
109 | 119 | ::EnzymeRules.FwdConfig{false, false, W, RuntimeActivity, StrongZero}, |
110 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 120 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
111 | 121 | RT::Type{<:EnzymeCore.Annotation}, |
112 | 122 | args::Vararg{EnzymeCore.Annotation, N} |
113 | 123 | ) where {W, N, RuntimeActivity, StrongZero} |
|
123 | 133 |
|
124 | 134 | function EnzymeRules.augmented_primal( |
125 | 135 | config::EnzymeRules.RevConfig, |
126 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 136 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
127 | 137 | RT::Type{<:EnzymeCore.Active{T}}, |
128 | 138 | args::Vararg{EnzymeCore.Annotation, N} |
129 | 139 | ) where {T, N} |
|
143 | 153 | # the reverse pass has nothing to propagate back from the return. |
144 | 154 | function EnzymeRules.augmented_primal( |
145 | 155 | config::EnzymeRules.RevConfig, |
146 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 156 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
147 | 157 | RT::Type{<:EnzymeCore.Const}, |
148 | 158 | args::Vararg{EnzymeCore.Annotation, N} |
149 | 159 | ) where {N} |
|
157 | 167 | # it available when propagating dret through the arguments. |
158 | 168 | function EnzymeRules.augmented_primal( |
159 | 169 | config::EnzymeRules.RevConfig, |
160 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 170 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
161 | 171 | RT::Type{<:EnzymeCore.Duplicated{T}}, |
162 | 172 | args::Vararg{EnzymeCore.Annotation, N} |
163 | 173 | ) where {T, N} |
|
173 | 183 |
|
174 | 184 | function EnzymeRules.augmented_primal( |
175 | 185 | config::EnzymeRules.RevConfig, |
176 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 186 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
177 | 187 | RT::Type{<:EnzymeCore.BatchDuplicated{T, W}}, |
178 | 188 | args::Vararg{EnzymeCore.Annotation, N} |
179 | 189 | ) where {T, W, N} |
|
220 | 230 |
|
221 | 231 | function EnzymeRules.reverse( |
222 | 232 | config::EnzymeRules.RevConfig, |
223 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 233 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
224 | 234 | dret::EnzymeCore.Active{T}, |
225 | 235 | tape, |
226 | 236 | args::Vararg{EnzymeCore.Active, N} |
|
232 | 242 | # Handle mixed Active/Const args: return nothing for Const, gradient for Active |
233 | 243 | function EnzymeRules.reverse( |
234 | 244 | config::EnzymeRules.RevConfig, |
235 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 245 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
236 | 246 | dret::EnzymeCore.Active, |
237 | 247 | tape, |
238 | 248 | args::Vararg{EnzymeCore.Annotation, N} |
|
271 | 281 | # accumulated in-place by the `Enzyme.autodiff(Reverse, …)` call above. |
272 | 282 | function EnzymeRules.reverse( |
273 | 283 | config::EnzymeRules.RevConfig, |
274 | | - func::EnzymeCore.Const{<:FunctionWrappersWrapper}, |
| 284 | + func::EnzymeCore.Annotation{<:FunctionWrappersWrapper}, |
275 | 285 | dret::Type{<:EnzymeCore.Const}, |
276 | 286 | tape, |
277 | 287 | args::Vararg{EnzymeCore.Annotation, N} |
|
0 commit comments