Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions src/neural_de.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,237 @@ function (dm::DimMover)(x, ps, st)

return cat(eachslice(x; dims = from)...; dims = to), st
end


# Introduce a parent type for all Galerkin basis families, for exmaple Fourier bassis, Chebyshev basis.
abstract type AbstractGalerkinBasis end

# Interface function, to be implemented by concrete basis types, and return the number of basis functions.
function basisdim end

# Interface function, given the current solver time `t`, return the vector of basis function values psi(t).
function basis_eval end
function lift_parameter_tree end
function reconstruct_parameter_tree end

# Helper for the constant-only mode index, which is always the first basis function.
constant_mode(::AbstractGalerkinBasis) = 1

"""
FourierBasis{M}()

A simple Fourier basis with `M` total modes laid out as
[1, sin(2pi tau), cos(2pi tau), sin(4pi tau), cos(4pi tau), ...]

where `tau` is the normalized solver time on `[0, 1]`.

Notes:
- `M == 1` gives a constant-only basis, which should reduce to an ordinary `NeuralODE`.
- Even `M` values are allowed, though the last mode will be an unmatched sine term.
"""
# Define a concrete basis type. M is the number of basis entries, known at compile time.
struct FourierBasis{M} <: AbstractGalerkinBasis end

# The number of basis functions is just M, but we also check that M is at least 1 to avoid invalid zero-mode bases.
basisdim(::FourierBasis{M}) where {M} = M >= 1 ? M : throw(ArgumentError("FourierBasis{$M} requires M ≥ 1."))

# Evaluate the Fourier basis functions at time t, given the timespan for normalization. The output is a length-M vector of basis values psi(t).
function basis_eval(::FourierBasis{M}, t, tspan) where {M}
t0, t1 = tspan
tau = (t - t0) / (t1 - t0)

return map(1:M) do i
if i == 1
one(tau)
else
k = i ÷ 2
iseven(i) ? sinpi(2k * tau) : cospi(2k * tau)
end
end
end

"""
GalerkinNeuralODE(model, basis, tspan, alg = nothing, args...; kwargs...)

Construct a Neural ODE-like layer whose parameters vary continuously with
depth/time via a Galerkin expansion:
Theta(t) = Sum_j alpha_j * psi_j(t)

where `psi_j` comes from `basis` and the trainable object is the coefficient tree `alpha`.

This implementation reuses the current DiffEqFlux solve-and-adjoint path used by `NeuralODE`, but
reconstructs ordinary model parameters inside the ODE RHS at each solver time.

Arguments:

- `model`: A `Flux.Chain` or `Lux.AbstractLuxLayer` neural network that defines the
̇x.
- `tspan`: The timespan to be solved on.
- `args`: Solver positional arguments, such as Tsit5()
- `kwargs`: Solver keyword arguments, such as saveat, tolerances, callbacks
"""
@concrete struct GalerkinNeuralODE <: NeuralDELayer
model <: AbstractLuxLayer
basis <: AbstractGalerkinBasis
tspan
args
kwargs
end

function GalerkinNeuralODE(model, basis::AbstractGalerkinBasis, tspan, args...; kwargs...)
# If the user passes a Flux model instead of a Lux model, adapt it to Lux first
!(model isa AbstractLuxLayer) && (model = FromFluxAdaptor()(model))
return GalerkinNeuralODE(model, basis, tspan, args, kwargs)
end

# ------------------------------------------------------------------------------
# Lux interface
# ------------------------------------------------------------------------------

# Lux requires custom layers to implement `initialparameters`, `initialstates`, `parameterlength`, and `statelength` methods
function LuxCore.initialparameters(rng::AbstractRNG, g::GalerkinNeuralODE)
# Start with the ordinary parameter tree from the base model
base_ps = LuxCore.initialparameters(rng, g.model)
# Lift it into the coefficient tree for the Galerkin expansion, which will be trained instead of the base parameters
return lift_parameter_tree(base_ps, g.basis)
end

# The state tree is unaffected by the Galerkin expansion, so we can just delegate to the base model.
LuxCore.initialstates(rng::AbstractRNG, g::GalerkinNeuralODE) =
LuxCore.initialstates(rng, g.model)

# The number of trainable parameters is the number of basis functions times the number of parameters in the base model
LuxCore.parameterlength(g::GalerkinNeuralODE) =
basisdim(g.basis) * LuxCore.parameterlength(g.model)

# The state length is the same as the base model, since the state tree is unchanged.
LuxCore.statelength(g::GalerkinNeuralODE) =
LuxCore.statelength(g.model)

# ------------------------------------------------------------------------------
# Parameter lifting helpers: ordinary Lux parameter tree -> Galerkin coefficient tree
# ------------------------------------------------------------------------------

# Recursively lift an ordinary parameter tree of NamedTuples into a coefficient tree for the Galerkin expansion.
lift_parameter_tree(ps::NamedTuple, basis::AbstractGalerkinBasis) =
map(v -> lift_parameter_tree(v, basis), ps)

# Same for Tuple structures.
lift_parameter_tree(ps::Tuple, basis::AbstractGalerkinBasis) =
map(v -> lift_parameter_tree(v, basis), ps)

# Base case: if the parameter leaf is `nothing`, we can just return `nothing` in the lifted tree, since it won't be trained or used in the RHS.
lift_parameter_tree(::Nothing, ::AbstractGalerkinBasis) = nothing

# Base case: handles array-valued parameter leaves
function lift_parameter_tree(x::AbstractArray, basis::AbstractGalerkinBasis)
# Number of basis functions
m = basisdim(basis)

# Allocate a new array with the same element type and backend as x, but with one extra leading dimension of size m.
alpha = similar(x, (m, size(x)...))

# Initialize all coefficient slices to zero.
fill!(alpha, zero(eltype(alpha)))

# Pick the constant basis slice and copy the original parameter x into it.
# So the initial model is exactly the original static-parameter model, with all nonconstant modes turned off.
selectdim(alpha, 1, constant_mode(basis)) .= x

# Return the lifted array leaf.
return alpha
end

# Base case: if the parameter leaf is a scalar number
function lift_parameter_tree(x::Number, basis::AbstractGalerkinBasis)
# Create a length-m vector of scalar coefficients.
alpha = fill(zero(x), basisdim(basis))

# Put the original scalar into the constant mode.
alpha[constant_mode(basis)] = x

# Return the lifted scalar leaf.
return alpha
end

# Fallback method for unsupported leaf types, which throws an error if we encounter a parameter leaf type that we don't know how to lift.
lift_parameter_tree(x, ::AbstractGalerkinBasis) =
throw(ArgumentError("Unsupported parameter leaf type $(typeof(x)) in Galerkin lifting."))

# ------------------------------------------------------------------------------
# Reconstruction helpers: coefficient tree alpha + basis values psi(t) -> ordinary parameter tree
# ------------------------------------------------------------------------------

# Recursively reconstruct an ordinary parameter tree by applying the Galerkin expansion at each leaf.
reconstruct_parameter_tree(alpha::NamedTuple, psi::AbstractVector) =
map(v -> reconstruct_parameter_tree(v, psi), alpha)

# Same for Tuple structures.
reconstruct_parameter_tree(alpha::Tuple, psi::AbstractVector) =
map(v -> reconstruct_parameter_tree(v, psi), alpha)

# Base case: if the parameter leaf is `nothing`, we can just return `nothing` in the reconstructed tree
reconstruct_parameter_tree(::Nothing, ::AbstractVector) = nothing

# Base case: handles array-valued parameter leaves, which are reconstructed via a linear combination of the basis slices weighted by the basis values psi.
@inline function reconstruct_parameter_tree(alpha::AbstractArray, psi::AbstractVector)
# Check that the number of basis functions matches the leading dimension of the coefficient array.
length(psi) == size(alpha, 1) || throw(DimensionMismatch(
"basis length $(length(psi)) does not match lifted parameter size $(size(alpha, 1))."
))

# Scalar leaf case: a base scalar parameter becomes an m-vector of coefficients.
if ndims(alpha) == 1
return LinearAlgebra.dot(alpha, psi)
end

# Fast path for ordinary dense/strided arrays.
if alpha isa StridedArray
m = length(psi)
flat = reshape(alpha, m, :)
return reshape(flat' * psi, Base.tail(size(alpha)))
end

# Generic fallback for array types without a convenient strided contraction.
y = psi[1] .* selectdim(alpha, 1, 1)
@inbounds for j in 2:lastindex(psi)
y = y .+ psi[j] .* selectdim(alpha, 1, j)
end
return y
end

reconstruct_parameter_tree(alpha, ::AbstractVector) =
throw(ArgumentError("Unsupported coefficient leaf type $(typeof(alpha)) in Galerkin reconstruction."))

# ------------------------------------------------------------------------------
# Layer call
# ------------------------------------------------------------------------------
# The main call method, which constructs the ODE problem and solves it using the reconstructed parameters at each RHS evaluation.
# With x as the initial state, alpha as the coefficient tree of trainable parameters, and st as the initial state tree
function (g::GalerkinNeuralODE)(x, alpha, st)
# The model is a stateful wrapper around the base model
model = StatefulLuxLayer{fixed_state_type(g.model)}(g.model, nothing, st)

# The ODE RHS function, which reconstructs the parameter tree at the current time and evaluates the model to get the state derivative.
function dudt(u, alpha, t)
psi = basis_eval(g.basis, t, g.tspan)
p_t = reconstruct_parameter_tree(alpha, psi)
return model(u, p_t)
end

# IMPORTANT: Unlike current `NeuralODE`, we do not pass `tgrad = basic_tgrad` here,
# because the RHS depends on time/depth through the reconstructed parameter tree p(t).
ff = ODEFunction{false}(dudt)
prob = ODEProblem{false}(ff, x, g.tspan, alpha)

# Solve the ODE problem using the InterpolatingAdjoint, which will allow gradients to flow through the solver and the time-dependent parameters.
return (
solve(
prob,
g.args...;
# sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()),
g.kwargs...,
),
model.st,
)
end