From 9e264aa417c8058d812094a760afe5a3a0a6d433 Mon Sep 17 00:00:00 2001 From: Adam Dong Date: Tue, 17 Mar 2026 18:47:19 +0100 Subject: [PATCH] Implemented initial Galerkin Neural ODE --- src/neural_de.jl | 234 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 234 insertions(+) diff --git a/src/neural_de.jl b/src/neural_de.jl index 340c7c77c..a6ae5c968 100644 --- a/src/neural_de.jl +++ b/src/neural_de.jl @@ -410,3 +410,237 @@ function (dm::DimMover)(x, ps, st) return cat(eachslice(x; dims = from)...; dims = to), st end + + +# Introduce a parent type for all Galerkin basis families, for exmaple Fourier bassis, Chebyshev basis. +abstract type AbstractGalerkinBasis end + +# Interface function, to be implemented by concrete basis types, and return the number of basis functions. +function basisdim end + +# Interface function, given the current solver time `t`, return the vector of basis function values psi(t). +function basis_eval end +function lift_parameter_tree end +function reconstruct_parameter_tree end + +# Helper for the constant-only mode index, which is always the first basis function. +constant_mode(::AbstractGalerkinBasis) = 1 + +""" + FourierBasis{M}() + +A simple Fourier basis with `M` total modes laid out as + [1, sin(2pi tau), cos(2pi tau), sin(4pi tau), cos(4pi tau), ...] + +where `tau` is the normalized solver time on `[0, 1]`. + +Notes: +- `M == 1` gives a constant-only basis, which should reduce to an ordinary `NeuralODE`. +- Even `M` values are allowed, though the last mode will be an unmatched sine term. +""" +# Define a concrete basis type. M is the number of basis entries, known at compile time. +struct FourierBasis{M} <: AbstractGalerkinBasis end + +# The number of basis functions is just M, but we also check that M is at least 1 to avoid invalid zero-mode bases. +basisdim(::FourierBasis{M}) where {M} = M >= 1 ? M : throw(ArgumentError("FourierBasis{$M} requires M ≥ 1.")) + +# Evaluate the Fourier basis functions at time t, given the timespan for normalization. The output is a length-M vector of basis values psi(t). +function basis_eval(::FourierBasis{M}, t, tspan) where {M} + t0, t1 = tspan + tau = (t - t0) / (t1 - t0) + + return map(1:M) do i + if i == 1 + one(tau) + else + k = i ÷ 2 + iseven(i) ? sinpi(2k * tau) : cospi(2k * tau) + end + end +end + +""" + GalerkinNeuralODE(model, basis, tspan, alg = nothing, args...; kwargs...) + +Construct a Neural ODE-like layer whose parameters vary continuously with +depth/time via a Galerkin expansion: + Theta(t) = Sum_j alpha_j * psi_j(t) + +where `psi_j` comes from `basis` and the trainable object is the coefficient tree `alpha`. + +This implementation reuses the current DiffEqFlux solve-and-adjoint path used by `NeuralODE`, but +reconstructs ordinary model parameters inside the ODE RHS at each solver time. + +Arguments: + + - `model`: A `Flux.Chain` or `Lux.AbstractLuxLayer` neural network that defines the + ̇x. + - `tspan`: The timespan to be solved on. + - `args`: Solver positional arguments, such as Tsit5() + - `kwargs`: Solver keyword arguments, such as saveat, tolerances, callbacks +""" +@concrete struct GalerkinNeuralODE <: NeuralDELayer + model <: AbstractLuxLayer + basis <: AbstractGalerkinBasis + tspan + args + kwargs +end + +function GalerkinNeuralODE(model, basis::AbstractGalerkinBasis, tspan, args...; kwargs...) + # If the user passes a Flux model instead of a Lux model, adapt it to Lux first + !(model isa AbstractLuxLayer) && (model = FromFluxAdaptor()(model)) + return GalerkinNeuralODE(model, basis, tspan, args, kwargs) +end + +# ------------------------------------------------------------------------------ +# Lux interface +# ------------------------------------------------------------------------------ + +# Lux requires custom layers to implement `initialparameters`, `initialstates`, `parameterlength`, and `statelength` methods +function LuxCore.initialparameters(rng::AbstractRNG, g::GalerkinNeuralODE) + # Start with the ordinary parameter tree from the base model + base_ps = LuxCore.initialparameters(rng, g.model) + # Lift it into the coefficient tree for the Galerkin expansion, which will be trained instead of the base parameters + return lift_parameter_tree(base_ps, g.basis) +end + +# The state tree is unaffected by the Galerkin expansion, so we can just delegate to the base model. +LuxCore.initialstates(rng::AbstractRNG, g::GalerkinNeuralODE) = + LuxCore.initialstates(rng, g.model) + +# The number of trainable parameters is the number of basis functions times the number of parameters in the base model +LuxCore.parameterlength(g::GalerkinNeuralODE) = + basisdim(g.basis) * LuxCore.parameterlength(g.model) + +# The state length is the same as the base model, since the state tree is unchanged. +LuxCore.statelength(g::GalerkinNeuralODE) = + LuxCore.statelength(g.model) + +# ------------------------------------------------------------------------------ +# Parameter lifting helpers: ordinary Lux parameter tree -> Galerkin coefficient tree +# ------------------------------------------------------------------------------ + +# Recursively lift an ordinary parameter tree of NamedTuples into a coefficient tree for the Galerkin expansion. +lift_parameter_tree(ps::NamedTuple, basis::AbstractGalerkinBasis) = + map(v -> lift_parameter_tree(v, basis), ps) + +# Same for Tuple structures. +lift_parameter_tree(ps::Tuple, basis::AbstractGalerkinBasis) = + map(v -> lift_parameter_tree(v, basis), ps) + +# Base case: if the parameter leaf is `nothing`, we can just return `nothing` in the lifted tree, since it won't be trained or used in the RHS. +lift_parameter_tree(::Nothing, ::AbstractGalerkinBasis) = nothing + +# Base case: handles array-valued parameter leaves +function lift_parameter_tree(x::AbstractArray, basis::AbstractGalerkinBasis) + # Number of basis functions + m = basisdim(basis) + + # Allocate a new array with the same element type and backend as x, but with one extra leading dimension of size m. + alpha = similar(x, (m, size(x)...)) + + # Initialize all coefficient slices to zero. + fill!(alpha, zero(eltype(alpha))) + + # Pick the constant basis slice and copy the original parameter x into it. + # So the initial model is exactly the original static-parameter model, with all nonconstant modes turned off. + selectdim(alpha, 1, constant_mode(basis)) .= x + + # Return the lifted array leaf. + return alpha +end + +# Base case: if the parameter leaf is a scalar number +function lift_parameter_tree(x::Number, basis::AbstractGalerkinBasis) + # Create a length-m vector of scalar coefficients. + alpha = fill(zero(x), basisdim(basis)) + + # Put the original scalar into the constant mode. + alpha[constant_mode(basis)] = x + + # Return the lifted scalar leaf. + return alpha +end + +# Fallback method for unsupported leaf types, which throws an error if we encounter a parameter leaf type that we don't know how to lift. +lift_parameter_tree(x, ::AbstractGalerkinBasis) = + throw(ArgumentError("Unsupported parameter leaf type $(typeof(x)) in Galerkin lifting.")) + +# ------------------------------------------------------------------------------ +# Reconstruction helpers: coefficient tree alpha + basis values psi(t) -> ordinary parameter tree +# ------------------------------------------------------------------------------ + +# Recursively reconstruct an ordinary parameter tree by applying the Galerkin expansion at each leaf. +reconstruct_parameter_tree(alpha::NamedTuple, psi::AbstractVector) = + map(v -> reconstruct_parameter_tree(v, psi), alpha) + +# Same for Tuple structures. +reconstruct_parameter_tree(alpha::Tuple, psi::AbstractVector) = + map(v -> reconstruct_parameter_tree(v, psi), alpha) + +# Base case: if the parameter leaf is `nothing`, we can just return `nothing` in the reconstructed tree +reconstruct_parameter_tree(::Nothing, ::AbstractVector) = nothing + +# Base case: handles array-valued parameter leaves, which are reconstructed via a linear combination of the basis slices weighted by the basis values psi. +@inline function reconstruct_parameter_tree(alpha::AbstractArray, psi::AbstractVector) + # Check that the number of basis functions matches the leading dimension of the coefficient array. + length(psi) == size(alpha, 1) || throw(DimensionMismatch( + "basis length $(length(psi)) does not match lifted parameter size $(size(alpha, 1))." + )) + + # Scalar leaf case: a base scalar parameter becomes an m-vector of coefficients. + if ndims(alpha) == 1 + return LinearAlgebra.dot(alpha, psi) + end + + # Fast path for ordinary dense/strided arrays. + if alpha isa StridedArray + m = length(psi) + flat = reshape(alpha, m, :) + return reshape(flat' * psi, Base.tail(size(alpha))) + end + + # Generic fallback for array types without a convenient strided contraction. + y = psi[1] .* selectdim(alpha, 1, 1) + @inbounds for j in 2:lastindex(psi) + y = y .+ psi[j] .* selectdim(alpha, 1, j) + end + return y +end + +reconstruct_parameter_tree(alpha, ::AbstractVector) = + throw(ArgumentError("Unsupported coefficient leaf type $(typeof(alpha)) in Galerkin reconstruction.")) + +# ------------------------------------------------------------------------------ +# Layer call +# ------------------------------------------------------------------------------ +# The main call method, which constructs the ODE problem and solves it using the reconstructed parameters at each RHS evaluation. +# With x as the initial state, alpha as the coefficient tree of trainable parameters, and st as the initial state tree +function (g::GalerkinNeuralODE)(x, alpha, st) + # The model is a stateful wrapper around the base model + model = StatefulLuxLayer{fixed_state_type(g.model)}(g.model, nothing, st) + + # The ODE RHS function, which reconstructs the parameter tree at the current time and evaluates the model to get the state derivative. + function dudt(u, alpha, t) + psi = basis_eval(g.basis, t, g.tspan) + p_t = reconstruct_parameter_tree(alpha, psi) + return model(u, p_t) + end + + # IMPORTANT: Unlike current `NeuralODE`, we do not pass `tgrad = basic_tgrad` here, + # because the RHS depends on time/depth through the reconstructed parameter tree p(t). + ff = ODEFunction{false}(dudt) + prob = ODEProblem{false}(ff, x, g.tspan, alpha) + + # Solve the ODE problem using the InterpolatingAdjoint, which will allow gradients to flow through the solver and the time-dependent parameters. + return ( + solve( + prob, + g.args...; + # sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()), + g.kwargs..., + ), + model.st, + ) +end