
Add ODERNN layer for ODE-RNN/ODE-LSTM architectures #995

Open
ChrisRackauckas-Claude wants to merge 1 commit into SciML:master from ChrisRackauckas-Claude:claude-fix-issue-422

@ChrisRackauckas-Claude
Contributor

Summary

Implements the ODERNN layer that combines a Neural ODE for continuous hidden state dynamics with a recurrent neural network cell (LSTM/GRU/RNN) for processing sequential observations.

This implements a variant of the ODE-RNN/ODE-LSTM architecture from:

  • Rubanova et al. "Latent ODEs for Irregularly-Sampled Time Series" (2019)
  • Lechner & Hasani "Learning Long-Term Dependencies in Irregularly-Sampled Time Series" (2020)

Key implementation details:

  • Follows the efficient approach suggested by @ChrisRackauckas: solve the ODE once from first to last time point with saveat=ts, then use the ODE solutions at each time point as continuous hidden states
  • The RNN cell then processes inputs using these ODE-evolved states
  • Uses foldl for a fully differentiable loop implementation (avoids mutation issues with Zygote)
  • Works with any Lux recurrent cell (LSTMCell, GRUCell, RNNCell)
  • Supports return_sequence=true (all outputs) and return_sequence=false (final only)
  • Uses InterpolatingAdjoint with ZygoteVJP for efficient gradient computation
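The foldl-based loop described above can be sketched in plain Julia, with no dependence on DiffEqFlux or Lux. This is only an illustration of the pattern, not the code in this PR: `toy_cell`, `run_sequence`, `W`, and `U` are hypothetical names, and `hs` is a stand-in for the ODE states that the real layer would obtain from a single solve with saveat=ts.

```julia
# Hypothetical stand-in for a recurrent cell: h_new = tanh.(W * x .+ U * h).
toy_cell(W, U, x, h) = tanh.(W * x .+ U * h)

# Fold over the time indices, threading a tuple of outputs as the accumulator.
# Each step builds a new tuple instead of writing into a preallocated array,
# which is the kind of in-place mutation the PR avoids for Zygote.
function run_sequence(W, U, xs, hs)
    step(outputs, i) = (outputs..., toy_cell(W, U, xs[:, i], hs[:, i]))
    return foldl(step, 1:size(xs, 2); init = ())
end

input_dim, hidden_dim, seq_len = 2, 4, 5
W = fill(0.1f0, hidden_dim, input_dim)
U = fill(0.1f0, hidden_dim, hidden_dim)
xs = ones(Float32, input_dim, seq_len)     # one input column per time point
hs = zeros(Float32, hidden_dim, seq_len)   # stand-in for ODE states at each time point

outputs = run_sequence(W, U, xs, hs)       # tuple with one hidden_dim vector per time point
```

With return_sequence=true the analogue would be to return every element of `outputs`; with return_sequence=false, only `last(outputs)`.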

Example usage:

using DiffEqFlux, Lux, Random, OrdinaryDiffEq

rng = Random.default_rng()
input_dim, hidden_dim, seq_len, batch_size = 2, 4, 10, 3

# ODE dynamics model (hidden_dim -> hidden_dim)
ode_model = Chain(Dense(hidden_dim => hidden_dim, tanh))

# RNN cell (processes input_dim -> hidden_dim)
cell = LSTMCell(input_dim => hidden_dim)

# Create ODE-RNN layer
odernn = ODERNN(ode_model, cell, Tsit5(); return_sequence=true)

# Setup
ps, st = Lux.setup(rng, odernn)

# Input data and time points (a uniform grid here; irregularly-sampled time points are also supported)
x = randn(Float32, input_dim, seq_len, batch_size)
ts = collect(Float32, range(0, 1, length=seq_len))

# Forward pass
outputs, st = odernn((x, ts), ps, st)
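The example above uses a uniform time grid. A minimal sketch of building irregularly spaced time points follows, assuming the layer accepts any strictly increasing Float32 vector whose length matches seq_len (that requirement is an assumption here, not something stated in the PR). The names `gaps` and the 0.01 minimum spacing are illustrative choices.

```julia
using Random

rng = Random.MersenneTwister(0)
seq_len = 10

# Random, strictly positive gaps between consecutive observations.
gaps = rand(rng, Float32, seq_len - 1) .+ 0.01f0

# Cumulative sum gives increasing time points that start at 0.
ts = cumsum(vcat(0.0f0, gaps))
```

Such a `ts` could then be passed in place of the uniform grid, for example `odernn((x, ts), ps, st)`.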

Closes #422

cc @ChrisRackauckas @MartinuzziFrancesco

Test plan

  • Added tests for ODERNN with all three cell types (LSTM, GRU, RNN)
  • Tests cover both return_sequence=true and return_sequence=false
  • Tests verify forward pass output shapes
  • Tests verify gradients are non-zero for both model and cell parameters
  • All 27 tests pass locally

🤖 Generated with Claude Code

Implements the ODERNN layer that combines a Neural ODE for continuous hidden state
dynamics with a recurrent neural network cell (LSTM/GRU/RNN) for processing
sequential observations.

This implements a variant of the ODE-RNN/ODE-LSTM architecture from:
- Rubanova et al. "Latent ODEs for Irregularly-Sampled Time Series" (2019)
- Lechner & Hasani "Learning Long-Term Dependencies in Irregularly-Sampled Time Series" (2020)

The implementation follows the efficient approach suggested by @ChrisRackauckas:
- Solve the ODE once from first to last time point with saveat=ts
- Use the ODE solutions at each time point as continuous hidden states
- The RNN cell then processes inputs using these ODE-evolved states

Features:
- Works with any Lux recurrent cell (LSTMCell, GRUCell, RNNCell)
- Supports both return_sequence=true (all outputs) and return_sequence=false (final only)
- Uses InterpolatingAdjoint with ZygoteVJP for efficient gradient computation
- Fully differentiable loop implementation using foldl

Closes SciML#422

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Development

Successfully merging this pull request may close these issues.

ODE-LSTM layer

2 participants