Add Mooncake rrule!! for tmap and responsible_map#1214
Conversation
Update on Test Results✅ tmap/responsible_map rules work correctly in isolation: using SciMLBase, Mooncake
f(x) = x^2
xs = [1.0, 2.0, 3.0]
function loss(xs)
ys = SciMLBase.tmap(f, xs)
return sum(ys)
end
cache = Mooncake.prepare_gradient_cache(loss, xs)
val, grad = Mooncake.value_and_gradient!!(cache, loss, xs)
# Value: 14.0
# Gradient: [2.0, 4.0, 6.0] ✓
Root cause analysis:
This PR provides the foundational tmap/responsible_map rules that will be needed once the higher-level issue is resolved. |
MWE for the remaining issueusing OrdinaryDiffEq
using SciMLSensitivity
using DifferentiationInterface
using ADTypes: AutoMooncake
using Mooncake
function fiip(du, u, p, t)
du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0; 1.0]
prob = ODEProblem(fiip, u0, (0.0, 10.0), p)
N = 3
eu0 = rand(N, 2)
ep = rand(N, 4)
function sum_of_e_solution(p)
ensemble_prob = EnsembleProblem(
prob,
prob_func = (prob, i, repeat) -> remake(prob, u0 = eu0[i, :], p = p[i, :], saveat = 0.1)
)
sol = solve(ensemble_prob, Tsit5(), EnsembleSerial(), trajectories = N)
return sum(Array(sol.u[1]))
end
# This works:
sum_of_e_solution(ep)
# This fails with StackOverflowError during rule compilation:
DifferentiationInterface.gradient(sum_of_e_solution, AutoMooncake(; config = nothing), ep)Error: The stack overflow occurs in Mooncake's rule compilation for |
e590f8e to
5d9a389
Compare
This implements reverse-mode AD rules for SciMLBase.tmap and SciMLBase.responsible_map functions, enabling Mooncake to differentiate through ensemble solves. Key implementation details: - Uses Mooncake's fdata system for vector gradients (tangent field of CoDual) - Prepares pullback caches during forward pass for nested AD - Applies pullbacks in reverse order for responsible_map (for stateful f) Closes https://github.com/SciML/DiffEqBase.jl/issues/1256 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
5d9a389 to
4caeba9
Compare
|
Rebased onto current master (971535e) to pick up ~172 commits of upstream changes since the original Feb 12 CI run. The rebase was clean — the Triage of the 6 failures from the old Feb 12 run — none look related to the tmap/responsible_map rules:
All six read as stale-CI drift from pre-v3 refactors that have since landed on master, not regressions from Co-Authored-By: Chris Rackauckas accounts@chrisrackauckas.com |
|
Local verification against the rebased branch (SciMLBase 3.1.0 + Mooncake 0.5.24 + DifferentiationInterface): ``` Both single-arg Co-Authored-By: Chris Rackauckas accounts@chrisrackauckas.com |
Summary
rrule!!) forSciMLBase.tmapandSciMLBase.responsible_mapfunctionsImplementation Details
responsible_mapapplies pullbacks in reverse order (for correctness with stateful functions)_accum_tangentsfunction handles tangent accumulation for various typesTest Plan
tmap:sum(map(x->x^2, xs))produces correct gradients[2.0, 4.0, 6.0]forxs = [1.0, 2.0, 3.0]responsible_mapproduces same correct gradientssum(map((x,y)->x^2+y, xs, ys))produces correct gradients for both inputsRelated Issues
Closes SciML/OrdinaryDiffEq.jl#3229
🤖 Generated with Claude Code