Skip to content

Add Mooncake rrule!! for tmap and responsible_map#1214

Merged
ChrisRackauckas merged 1 commit into
SciML:masterfrom
ChrisRackauckas-Claude:mooncake-tmap-responsible-map-rules
Apr 11, 2026
Merged

Add Mooncake rrule!! for tmap and responsible_map#1214
ChrisRackauckas merged 1 commit into
SciML:masterfrom
ChrisRackauckas-Claude:mooncake-tmap-responsible-map-rules

Conversation

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor

Summary

  • Implements reverse-mode AD rules (rrule!!) for SciMLBase.tmap and SciMLBase.responsible_map functions
  • Enables Mooncake to differentiate through ensemble solves
  • Uses Mooncake's fdata system for vector gradients (tangent field of CoDual)
  • Prepares pullback caches during forward pass for proper nested AD

Implementation Details

  • Forward pass computes primals and prepares pullback caches for each element
  • Reverse pass reads gradients from output fdata, computes input gradients via caches
  • responsible_map applies pullbacks in reverse order (for correctness with stateful functions)
  • Helper _accum_tangents function handles tangent accumulation for various types

Test Plan

  • Extension compiles successfully
  • Simple gradient computation with tmap: sum(map(x->x^2, xs)) produces correct gradients [2.0, 4.0, 6.0] for xs = [1.0, 2.0, 3.0]
  • responsible_map produces same correct gradients
  • Multi-argument case: sum(map((x,y)->x^2+y, xs, ys)) produces correct gradients for both inputs
  • SciMLBase tests pass
  • Downstream ensemble AD tests (to be validated after merge)

Related Issues

Closes SciML/OrdinaryDiffEq.jl#3229

🤖 Generated with Claude Code

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor Author

Update on Test Results

✅ tmap/responsible_map rules work correctly in isolation:

using SciMLBase, Mooncake

f(x) = x^2
xs = [1.0, 2.0, 3.0]

function loss(xs)
    ys = SciMLBase.tmap(f, xs)
    return sum(ys)
end

cache = Mooncake.prepare_gradient_cache(loss, xs)
val, grad = Mooncake.value_and_gradient!!(cache, loss, xs)
# Value: 14.0
# Gradient: [2.0, 4.0, 6.0] ✓

⚠️ Full ensemble solve still fails:
The downstream ensemble AD test still encounters a StackOverflowError during Mooncake's rule compilation for __solve with EnsembleProblem. The stack overflow occurs before reaching tmap/responsible_map - Mooncake gets stuck in infinite recursion while compiling rules for the ensemble solve machinery.

Root cause analysis:
The issue appears to be that Mooncake's rule compilation for __solve(::EnsembleProblem, ...) enters infinite recursion. The tmap/responsible_map rules here are necessary but not sufficient - additional work is needed to either:

  1. Mark __solve for EnsembleProblem as a primitive with a custom rule
  2. Or break the compilation cycle some other way

This PR provides the foundational tmap/responsible_map rules that will be needed once the higher-level issue is resolved.

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor Author

MWE for the remaining issue

using OrdinaryDiffEq
using SciMLSensitivity
using DifferentiationInterface
using ADTypes: AutoMooncake
using Mooncake

function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end

p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0; 1.0]
prob = ODEProblem(fiip, u0, (0.0, 10.0), p)

N = 3
eu0 = rand(N, 2)
ep = rand(N, 4)

function sum_of_e_solution(p)
    ensemble_prob = EnsembleProblem(
        prob,
        prob_func = (prob, i, repeat) -> remake(prob, u0 = eu0[i, :], p = p[i, :], saveat = 0.1)
    )
    sol = solve(ensemble_prob, Tsit5(), EnsembleSerial(), trajectories = N)
    return sum(Array(sol.u[1]))
end

# This works:
sum_of_e_solution(ep)

# This fails with StackOverflowError during rule compilation:
DifferentiationInterface.gradient(sum_of_e_solution, AutoMooncake(; config = nothing), ep)

Error:

Warning: detected a stack overflow; program state may be corrupted
Mooncake.MooncakeRuleCompilationError(...Tuple{SciMLBase.var"##__solve#799", ..., EnsembleProblem{...}, Tsit5{...}, EnsembleSerial}...)

The stack overflow occurs in Mooncake's rule compilation for __solve with EnsembleProblem, before it reaches tmap/responsible_map.

This implements reverse-mode AD rules for SciMLBase.tmap and
SciMLBase.responsible_map functions, enabling Mooncake to
differentiate through ensemble solves.

Key implementation details:
- Uses Mooncake's fdata system for vector gradients (tangent field of CoDual)
- Prepares pullback caches during forward pass for nested AD
- Applies pullbacks in reverse order for responsible_map (for stateful f)

Closes https://github.com/SciML/DiffEqBase.jl/issues/1256

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@ChrisRackauckas-Claude ChrisRackauckas-Claude force-pushed the mooncake-tmap-responsible-map-rules branch from 5d9a389 to 4caeba9 Compare April 11, 2026 11:32
@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor Author

Rebased onto current master (971535e) to pick up ~172 commits of upstream changes since the original Feb 12 CI run. The rebase was clean — the OverrideInitData/ODENLStepData tangent_type overrides added in be70f40 are preserved ahead of the new tmap/responsible_map rules.

Triage of the 6 failures from the old Feb 12 run — none look related to the tmap/responsible_map rules:

  • DiffEqBase.jl/DownstreamExtraVariablesSystemException: The system is unbalanced in MTK's tearing during the "Null DE Handling" test. MTK/DiffEqBase API drift.
  • Optimization.jl/AllERROR: LoadError: Dev path .../downstream/lib/All does not exist. Test harness pathing issue in the Optimization repo layout, nothing to do with SciMLBase.
  • ModelingToolkit.jl/DownstreamUndefVarError: __concrete_solve_algorithm not defined + UndefVarError: defaults not defined in the linearization test. MTK internal symbol drift.
  • ModelingToolkit.jl/FMI — Same __concrete_solve_algorithm undef + a standalone pendulum FMI error. Same MTK drift.
  • ModelingToolkitStandardLibrary.jl/CoreExtraVariablesSystemException in SISO check + UndefVarError: _clamp not defined in Blocks math. MTKStdLib internal drift.
  • SciMLSensitivity.jl/Core8MethodError: no method matching ModelingToolkitBase.System(...) keyword signature + type SCCNonlinearProblem has no field u0. MTK/SciMLBase API mismatch.

All six read as stale-CI drift from pre-v3 refactors that have since landed on master, not regressions from tmap/responsible_map. Fresh CI on the rebased branch should clarify.

Co-Authored-By: Chris Rackauckas accounts@chrisrackauckas.com

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor Author

Local verification against the rebased branch (SciMLBase 3.1.0 + Mooncake 0.5.24 + DifferentiationInterface):

```
Loaded. Mooncake version: 0.5.24
SciMLBase version: 3.1.0
tmap gradient: [2.0, 4.0, 6.0] (expect [2.0, 4.0, 6.0])
responsible_map gradient: [2.0, 4.0, 6.0] (expect [2.0, 4.0, 6.0])
```

Both single-arg tmap and responsible_map rules produce correct gradients via DifferentiationInterface.gradient(..., AutoMooncake(; config=nothing), xs) on sum(map(x -> x^2, xs)) for xs = [1.0, 2.0, 3.0]. Extension loads cleanly, compiles without errors.

Co-Authored-By: Chris Rackauckas accounts@chrisrackauckas.com

@ChrisRackauckas ChrisRackauckas merged commit 8569c5c into SciML:master Apr 11, 2026
50 of 61 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Mooncake AD backend limitations with MTK and Ensemble problems

3 participants