Skip to content

Commit 4caeba9

Browse files
claudeChrisRackauckas
authored andcommitted
Add Mooncake rrule!! for tmap and responsible_map
This implements reverse-mode AD rules for SciMLBase.tmap and SciMLBase.responsible_map functions, enabling Mooncake to differentiate through ensemble solves. Key implementation details: - Uses Mooncake's fdata system for vector gradients (tangent field of CoDual) - Prepares pullback caches during forward pass for nested AD - Applies pullbacks in reverse order for responsible_map (for stateful f) Closes https://github.com/SciML/DiffEqBase.jl/issues/1256 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 971535e commit 4caeba9

1 file changed

Lines changed: 213 additions & 1 deletion

File tree

ext/SciMLBaseMooncakeExt.jl

Lines changed: 213 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ using SciMLBase, Mooncake
44
using SciMLBase: ADOriginator, ChainRulesOriginator, MooncakeOriginator
55
import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
66
@from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx,
7-
NoPullback
7+
NoPullback, NoTangent, NoRData, primal, tangent, prepare_pullback_cache,
8+
value_and_pullback!!
89

910
# OverrideInitData and ODENLStepData are solver/initialization infrastructure
1011
# embedded in ODEFunction type parameters. They are not differentiable, but their
@@ -29,5 +30,216 @@ function rrule!!(
2930
return zero_fcodual(SciMLBase.MooncakeOriginator()), NoPullback(f, X)
3031
end
3132

33+
# ============================================================================
34+
# tmap and responsible_map rules for Ensemble AD
35+
# These enable Mooncake to differentiate through ensemble solves by providing
36+
# proper AD nesting for the mapped function.
37+
# See: https://github.com/SciML/DiffEqBase.jl/issues/1256
38+
# ============================================================================
39+
40+
# Mark tmap and responsible_map as primitives
41+
@is_primitive MinimalCtx Tuple{typeof(SciMLBase.tmap), Any, Vararg}
42+
@is_primitive MinimalCtx Tuple{typeof(SciMLBase.responsible_map), Any, Vararg}
43+
44+
# Helper to accumulate tangents
45+
function _accum_tangents(a::NoTangent, b::NoTangent)
46+
return NoTangent()
47+
end
48+
function _accum_tangents(a::NoTangent, b)
49+
return b
50+
end
51+
function _accum_tangents(a, b::NoTangent)
52+
return a
53+
end
54+
function _accum_tangents(a::T, b::T) where {T <: Number}
55+
return a + b
56+
end
57+
function _accum_tangents(a::Tuple, b::Tuple)
58+
return map(_accum_tangents, a, b)
59+
end
60+
function _accum_tangents(a::NamedTuple{N}, b::NamedTuple{N}) where {N}
61+
return NamedTuple{N}(map(_accum_tangents, values(a), values(b)))
62+
end
63+
function _accum_tangents(a, b)
64+
# Fallback: try addition
65+
return a + b
66+
end
67+
68+
"""
69+
rrule!! for SciMLBase.tmap
70+
71+
Implements reverse-mode AD for tmap by:
72+
1. Forward pass: Compute primals and prepare pullback caches for each element
73+
2. Reverse pass: Read gradients from output fdata, compute input gradients via caches
74+
75+
Note: For vectors, Mooncake uses fdata (tangent field of CoDual) for gradients,
76+
not rdata. The pullback receives NoRData() and must read the gradient from
77+
the output's fdata which was modified by the downstream operation.
78+
"""
79+
function rrule!!(
80+
::CoDual{typeof(SciMLBase.tmap)},
81+
f_dual::CoDual{F},
82+
args_dual::CoDual...
83+
) where {F}
84+
# Extract primals and tangents (fdata)
85+
f = primal(f_dual)
86+
args = map(primal, args_dual)
87+
args_tangents = map(tangent, args_dual)
88+
89+
n = length(args[1])
90+
91+
# Compute first element to determine output type
92+
if n == 0
93+
# Empty case - infer type from function signature if possible
94+
T = Core.Compiler.return_type(f, Tuple{map(eltype, args)...})
95+
ys = Vector{T}(undef, 0)
96+
return zero_fcodual(ys), NoPullback(zero_fcodual(SciMLBase.tmap), f_dual, args_dual...)
97+
end
98+
99+
# Forward pass: compute values and prepare caches
100+
caches = Vector{Any}(undef, n)
101+
102+
# Compute first element to get the type
103+
arg_1 = ntuple(j -> args[j][1], length(args))
104+
caches[1] = prepare_pullback_cache(f, arg_1...)
105+
y1 = f(arg_1...)
106+
107+
# Create properly typed output vector and its tangent (fdata)
108+
ys = Vector{typeof(y1)}(undef, n)
109+
ys[1] = y1
110+
ys_tangent = zeros(typeof(y1), n)
111+
112+
# Compute remaining elements
113+
for i in 2:n
114+
arg_i = ntuple(j -> args[j][i], length(args))
115+
# Prepare cache for this call
116+
caches[i] = prepare_pullback_cache(f, arg_i...)
117+
# Compute primal value
118+
ys[i] = f(arg_i...)
119+
end
120+
121+
# Create output CoDual - the tangent will be modified by downstream pullbacks
122+
ys_codual = CoDual(ys, ys_tangent)
123+
124+
function tmap_pullback!!(::NoRData)
125+
# For vectors, gradient comes from fdata (ys_tangent), not rdata
126+
# The downstream operation (e.g., sum) will have modified ys_tangent
127+
128+
# Compute gradients for each element and accumulate into input tangents
129+
Δf = NoTangent()
130+
131+
for i in 1:n
132+
arg_i = ntuple(j -> args[j][i], length(args))
133+
# Get cotangent for this element from output fdata
134+
δ_i = ys_tangent[i]
135+
136+
# Use cache to compute pullback
137+
_, tangents_i = value_and_pullback!!(caches[i], δ_i, f, arg_i...)
138+
# tangents_i is (df, darg1, darg2, ...)
139+
140+
Δf = _accum_tangents(Δf, tangents_i[1])
141+
142+
# Accumulate into input tangents (fdata)
143+
for j in 1:length(args)
144+
args_tangents[j][i] += tangents_i[j + 1]
145+
end
146+
end
147+
148+
# Return NoRData for all args since gradients are in fdata
149+
Δargs = ntuple(_ -> NoRData(), length(args))
150+
return NoTangent(), Δf, Δargs...
151+
end
152+
153+
return ys_codual, tmap_pullback!!
154+
end
155+
156+
"""
157+
rrule!! for SciMLBase.responsible_map
158+
159+
Implements reverse-mode AD for responsible_map by:
160+
1. Forward pass: Compute primals and prepare pullback caches for each element
161+
2. Reverse pass: Read gradients from output fdata, compute input gradients in reverse order (for stateful f)
162+
163+
Note: For vectors, Mooncake uses fdata (tangent field of CoDual) for gradients,
164+
not rdata. The pullback receives NoRData() and must read the gradient from
165+
the output's fdata which was modified by the downstream operation.
166+
"""
167+
function rrule!!(
168+
::CoDual{typeof(SciMLBase.responsible_map)},
169+
f_dual::CoDual{F},
170+
args_dual::CoDual...
171+
) where {F}
172+
# Extract primals and tangents (fdata)
173+
f = primal(f_dual)
174+
args = map(primal, args_dual)
175+
args_tangents = map(tangent, args_dual)
176+
177+
n = length(args[1])
178+
179+
# Compute first element to determine output type
180+
if n == 0
181+
# Empty case - infer type from function signature if possible
182+
T = Core.Compiler.return_type(f, Tuple{map(eltype, args)...})
183+
ys = Vector{T}(undef, 0)
184+
return zero_fcodual(ys), NoPullback(zero_fcodual(SciMLBase.responsible_map), f_dual, args_dual...)
185+
end
186+
187+
# Forward pass: compute values and prepare caches
188+
caches = Vector{Any}(undef, n)
189+
190+
# Compute first element to get the type
191+
arg_1 = ntuple(j -> args[j][1], length(args))
192+
caches[1] = prepare_pullback_cache(f, arg_1...)
193+
y1 = f(arg_1...)
194+
195+
# Create properly typed output vector and its tangent (fdata)
196+
ys = Vector{typeof(y1)}(undef, n)
197+
ys[1] = y1
198+
ys_tangent = zeros(typeof(y1), n)
199+
200+
# Compute remaining elements
201+
for i in 2:n
202+
arg_i = ntuple(j -> args[j][i], length(args))
203+
# Prepare cache for this call
204+
caches[i] = prepare_pullback_cache(f, arg_i...)
205+
# Compute primal value
206+
ys[i] = f(arg_i...)
207+
end
208+
209+
# Create output CoDual - the tangent will be modified by downstream pullbacks
210+
ys_codual = CoDual(ys, ys_tangent)
211+
212+
function responsible_map_pullback!!(::NoRData)
213+
# For vectors, gradient comes from fdata (ys_tangent), not rdata
214+
# The downstream operation (e.g., sum) will have modified ys_tangent
215+
216+
# Compute gradients for each element and accumulate into input tangents
217+
# Apply pullbacks in reverse order for correctness with stateful f
218+
Δf = NoTangent()
219+
220+
for i in n:-1:1
221+
arg_i = ntuple(j -> args[j][i], length(args))
222+
# Get cotangent for this element from output fdata
223+
δ_i = ys_tangent[i]
224+
225+
# Use cache to compute pullback
226+
_, tangents_i = value_and_pullback!!(caches[i], δ_i, f, arg_i...)
227+
# tangents_i is (df, darg1, darg2, ...)
228+
229+
Δf = _accum_tangents(Δf, tangents_i[1])
230+
231+
# Accumulate into input tangents (fdata)
232+
for j in 1:length(args)
233+
args_tangents[j][i] += tangents_i[j + 1]
234+
end
235+
end
236+
237+
# Return NoRData for all args since gradients are in fdata
238+
Δargs = ntuple(_ -> NoRData(), length(args))
239+
return NoTangent(), Δf, Δargs...
240+
end
241+
242+
return ys_codual, responsible_map_pullback!!
243+
end
32244

33245
end

0 commit comments

Comments
 (0)