@@ -4,7 +4,8 @@ using SciMLBase, Mooncake
44using SciMLBase: ADOriginator, ChainRulesOriginator, MooncakeOriginator
55import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive ,
66 @from_rrule , @zero_adjoint , @mooncake_overlay , MinimalCtx,
7- NoPullback
7+ NoPullback, NoTangent, NoRData, primal, tangent, prepare_pullback_cache,
8+ value_and_pullback!!
89
910# OverrideInitData and ODENLStepData are solver/initialization infrastructure
1011# embedded in ODEFunction type parameters. They are not differentiable, but their
@@ -29,5 +30,216 @@ function rrule!!(
2930 return zero_fcodual (SciMLBase. MooncakeOriginator ()), NoPullback (f, X)
3031end
3132
33+ # ============================================================================
34+ # tmap and responsible_map rules for Ensemble AD
35+ # These enable Mooncake to differentiate through ensemble solves by providing
36+ # proper AD nesting for the mapped function.
37+ # See: https://github.com/SciML/DiffEqBase.jl/issues/1256
38+ # ============================================================================
39+
40+ # Mark tmap and responsible_map as primitives
41+ @is_primitive MinimalCtx Tuple{typeof (SciMLBase. tmap), Any, Vararg}
42+ @is_primitive MinimalCtx Tuple{typeof (SciMLBase. responsible_map), Any, Vararg}
43+
44+ # Helper to accumulate tangents
45+ function _accum_tangents (a:: NoTangent , b:: NoTangent )
46+ return NoTangent ()
47+ end
48+ function _accum_tangents (a:: NoTangent , b)
49+ return b
50+ end
51+ function _accum_tangents (a, b:: NoTangent )
52+ return a
53+ end
54+ function _accum_tangents (a:: T , b:: T ) where {T <: Number }
55+ return a + b
56+ end
57+ function _accum_tangents (a:: Tuple , b:: Tuple )
58+ return map (_accum_tangents, a, b)
59+ end
60+ function _accum_tangents (a:: NamedTuple{N} , b:: NamedTuple{N} ) where {N}
61+ return NamedTuple {N} (map (_accum_tangents, values (a), values (b)))
62+ end
63+ function _accum_tangents (a, b)
64+ # Fallback: try addition
65+ return a + b
66+ end
67+
68+ """
69+ rrule!! for SciMLBase.tmap
70+
71+ Implements reverse-mode AD for tmap by:
72+ 1. Forward pass: Compute primals and prepare pullback caches for each element
73+ 2. Reverse pass: Read gradients from output fdata, compute input gradients via caches
74+
75+ Note: For vectors, Mooncake uses fdata (tangent field of CoDual) for gradients,
76+ not rdata. The pullback receives NoRData() and must read the gradient from
77+ the output's fdata which was modified by the downstream operation.
78+ """
79+ function rrule!! (
80+ :: CoDual{typeof(SciMLBase.tmap)} ,
81+ f_dual:: CoDual{F} ,
82+ args_dual:: CoDual...
83+ ) where {F}
84+ # Extract primals and tangents (fdata)
85+ f = primal (f_dual)
86+ args = map (primal, args_dual)
87+ args_tangents = map (tangent, args_dual)
88+
89+ n = length (args[1 ])
90+
91+ # Compute first element to determine output type
92+ if n == 0
93+ # Empty case - infer type from function signature if possible
94+ T = Core. Compiler. return_type (f, Tuple{map (eltype, args)... })
95+ ys = Vector {T} (undef, 0 )
96+ return zero_fcodual (ys), NoPullback (zero_fcodual (SciMLBase. tmap), f_dual, args_dual... )
97+ end
98+
99+ # Forward pass: compute values and prepare caches
100+ caches = Vector {Any} (undef, n)
101+
102+ # Compute first element to get the type
103+ arg_1 = ntuple (j -> args[j][1 ], length (args))
104+ caches[1 ] = prepare_pullback_cache (f, arg_1... )
105+ y1 = f (arg_1... )
106+
107+ # Create properly typed output vector and its tangent (fdata)
108+ ys = Vector {typeof(y1)} (undef, n)
109+ ys[1 ] = y1
110+ ys_tangent = zeros (typeof (y1), n)
111+
112+ # Compute remaining elements
113+ for i in 2 : n
114+ arg_i = ntuple (j -> args[j][i], length (args))
115+ # Prepare cache for this call
116+ caches[i] = prepare_pullback_cache (f, arg_i... )
117+ # Compute primal value
118+ ys[i] = f (arg_i... )
119+ end
120+
121+ # Create output CoDual - the tangent will be modified by downstream pullbacks
122+ ys_codual = CoDual (ys, ys_tangent)
123+
124+ function tmap_pullback!! (:: NoRData )
125+ # For vectors, gradient comes from fdata (ys_tangent), not rdata
126+ # The downstream operation (e.g., sum) will have modified ys_tangent
127+
128+ # Compute gradients for each element and accumulate into input tangents
129+ Δf = NoTangent ()
130+
131+ for i in 1 : n
132+ arg_i = ntuple (j -> args[j][i], length (args))
133+ # Get cotangent for this element from output fdata
134+ δ_i = ys_tangent[i]
135+
136+ # Use cache to compute pullback
137+ _, tangents_i = value_and_pullback!! (caches[i], δ_i, f, arg_i... )
138+ # tangents_i is (df, darg1, darg2, ...)
139+
140+ Δf = _accum_tangents (Δf, tangents_i[1 ])
141+
142+ # Accumulate into input tangents (fdata)
143+ for j in 1 : length (args)
144+ args_tangents[j][i] += tangents_i[j + 1 ]
145+ end
146+ end
147+
148+ # Return NoRData for all args since gradients are in fdata
149+ Δargs = ntuple (_ -> NoRData (), length (args))
150+ return NoTangent (), Δf, Δargs...
151+ end
152+
153+ return ys_codual, tmap_pullback!!
154+ end
155+
156+ """
157+ rrule!! for SciMLBase.responsible_map
158+
159+ Implements reverse-mode AD for responsible_map by:
160+ 1. Forward pass: Compute primals and prepare pullback caches for each element
161+ 2. Reverse pass: Read gradients from output fdata, compute input gradients in reverse order (for stateful f)
162+
163+ Note: For vectors, Mooncake uses fdata (tangent field of CoDual) for gradients,
164+ not rdata. The pullback receives NoRData() and must read the gradient from
165+ the output's fdata which was modified by the downstream operation.
166+ """
167+ function rrule!! (
168+ :: CoDual{typeof(SciMLBase.responsible_map)} ,
169+ f_dual:: CoDual{F} ,
170+ args_dual:: CoDual...
171+ ) where {F}
172+ # Extract primals and tangents (fdata)
173+ f = primal (f_dual)
174+ args = map (primal, args_dual)
175+ args_tangents = map (tangent, args_dual)
176+
177+ n = length (args[1 ])
178+
179+ # Compute first element to determine output type
180+ if n == 0
181+ # Empty case - infer type from function signature if possible
182+ T = Core. Compiler. return_type (f, Tuple{map (eltype, args)... })
183+ ys = Vector {T} (undef, 0 )
184+ return zero_fcodual (ys), NoPullback (zero_fcodual (SciMLBase. responsible_map), f_dual, args_dual... )
185+ end
186+
187+ # Forward pass: compute values and prepare caches
188+ caches = Vector {Any} (undef, n)
189+
190+ # Compute first element to get the type
191+ arg_1 = ntuple (j -> args[j][1 ], length (args))
192+ caches[1 ] = prepare_pullback_cache (f, arg_1... )
193+ y1 = f (arg_1... )
194+
195+ # Create properly typed output vector and its tangent (fdata)
196+ ys = Vector {typeof(y1)} (undef, n)
197+ ys[1 ] = y1
198+ ys_tangent = zeros (typeof (y1), n)
199+
200+ # Compute remaining elements
201+ for i in 2 : n
202+ arg_i = ntuple (j -> args[j][i], length (args))
203+ # Prepare cache for this call
204+ caches[i] = prepare_pullback_cache (f, arg_i... )
205+ # Compute primal value
206+ ys[i] = f (arg_i... )
207+ end
208+
209+ # Create output CoDual - the tangent will be modified by downstream pullbacks
210+ ys_codual = CoDual (ys, ys_tangent)
211+
212+ function responsible_map_pullback!! (:: NoRData )
213+ # For vectors, gradient comes from fdata (ys_tangent), not rdata
214+ # The downstream operation (e.g., sum) will have modified ys_tangent
215+
216+ # Compute gradients for each element and accumulate into input tangents
217+ # Apply pullbacks in reverse order for correctness with stateful f
218+ Δf = NoTangent ()
219+
220+ for i in n: - 1 : 1
221+ arg_i = ntuple (j -> args[j][i], length (args))
222+ # Get cotangent for this element from output fdata
223+ δ_i = ys_tangent[i]
224+
225+ # Use cache to compute pullback
226+ _, tangents_i = value_and_pullback!! (caches[i], δ_i, f, arg_i... )
227+ # tangents_i is (df, darg1, darg2, ...)
228+
229+ Δf = _accum_tangents (Δf, tangents_i[1 ])
230+
231+ # Accumulate into input tangents (fdata)
232+ for j in 1 : length (args)
233+ args_tangents[j][i] += tangents_i[j + 1 ]
234+ end
235+ end
236+
237+ # Return NoRData for all args since gradients are in fdata
238+ Δargs = ntuple (_ -> NoRData (), length (args))
239+ return NoTangent (), Δf, Δargs...
240+ end
241+
242+ return ys_codual, responsible_map_pullback!!
243+ end
32244
33245end
0 commit comments