Skip to content

Add Reactant extension for branchless interpolation#516

Open
ChrisRackauckas-Claude wants to merge 1 commit into
SciML:masterfrom
ChrisRackauckas-Claude:reactant-ext
Open

Add Reactant extension for branchless interpolation#516
ChrisRackauckas-Claude wants to merge 1 commit into
SciML:masterfrom
ChrisRackauckas-Claude:reactant-ext

Conversation

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor

Summary

  • Adds DataInterpolationsReactantExt package extension that provides branchless _interpolate methods for TracedRNumber inputs
  • Supports LinearInterpolation, ConstantInterpolation, and QuadraticInterpolation
  • Uses cascading ifelse chains (lowered to stablehlo.select) instead of if branching, which fails with TracedRNumber{Bool}
  • Handles all ExtrapolationType variants (Constant, Linear, Extension) in the traced path

Motivation

Fixes SciML/SciMLSensitivity.jl#1430 — when using ReactantVJP as the autojacvec in GaussAdjoint, DataInterpolations calls inside ODE functions fail because _interpolate uses if t < first(A.t) where t is a Reactant.TracedRNumber. The comparison produces TracedRNumber{Bool} which can't be used in Julia's boolean context.

Approach

For each supported interpolation type, the extension evaluates all segments and selects the correct one using ifelse. This is O(n) in knot count but fully traceable by Reactant. For typical ODE interpolation tables (tens to hundreds of knots), this is efficient enough.

Test plan

  • 80 tests covering LinearInterpolation, ConstantInterpolation, QuadraticInterpolation with Reactant tracing
  • Multiple knot counts, extrapolation modes (Constant, Extension), boundary values
  • Verified the original SciMLSensitivity.jl#1430 MRE works end-to-end (forward solve + Zygote gradient with ReactantVJP)
  • Gradient accuracy verified against finite differences (~0.2% relative error)
  • Existing interpolation behavior unchanged (no regressions)
  • CI tests

🤖 Generated with Claude Code

DataInterpolations methods use `if t < first(A.t)` branching that fails when
`t` is a Reactant.TracedRNumber (produces TracedRNumber{Bool} which can't be
used in boolean context). This adds a package extension that provides branchless
`_interpolate` methods using `ifelse` chains for LinearInterpolation,
ConstantInterpolation, and QuadraticInterpolation.

Fixes SciML/SciMLSensitivity.jl#1430

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor Author

Investigation: alternative approaches

Explored several alternatives to the branchless ifelse approach:

@trace function / @trace if (ReactantCore)

  • @trace if converts if statements to ReactantCore.traced_if, which uses Reactant's if_condition MLIR op. However, if_condition traces both branches, and the ExtrapolationType.None branch calls throw() which fails during tracing.
  • @trace function wraps the function to call traced_call when traced arguments are detected. This routes through Reactant.Ops.callmake_mlir_fn, but make_mlir_fn still uses call_with_reactant to trace the function body, hitting the same if branch error.

@reactant_overlay with traced_call

  • Registering an overlay via Base.Experimental.@overlay REACTANT_METHOD_TABLE successfully intercepts the call in Reactant's interpreter. However, routing to traced_call(_interpolate, interp, t) still fails because Ops.callmake_mlir_fncall_with_reactant traces through _interpolate and hits the if branch.

Benchmark: branchless vs branching

Knots Branching (ns) Branchless (ns) Ratio
5 54 19 0.35x
10 53 32 0.61x
20 53 58 ~1.1x
50 53 143 2.7x
100 54 270 5x
1000 54 2827 52x

The branching version is O(log n) via correlated binary search. The branchless O(n) scan wins for small tables but degrades linearly. For typical ODE interpolation tables (<100 knots), the overhead is acceptable.

Conclusion

The branchless TracedRNumber dispatch approach is the only viable option from an extension. The fundamental issue is that Reactant's call_with_reactant generated function traces through function bodies directly, and no mechanism (@trace, overlay, traced_call) can prevent it from encountering the if statement in _interpolate. The only way to avoid this is to dispatch to an entirely different method body that never uses if on traced values.

A deeper fix would require either:

  1. DataInterpolations adding ReactantCore as a dependency and sprinkling @trace on if statements in source (but @trace if needs both branches to be traceable, ruling out throw branches)
  2. Reactant learning to automatically convert if with TracedRNumber{Bool} conditions to traced_if during its overlay pass (a Reactant.jl enhancement)

@ChrisRackauckas
Copy link
Copy Markdown
Member

@avik-pal I'm surprised an overlay doesn't end up good here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ReactantVJP issue with DataInterpolations

2 participants