Drop vec/flatten/unflatten roundtrips in MIRK/FIRK/MIRKN nlprob construction #486

@ChrisRackauckas-Claude

Background

In __perform_mirk_iteration / __perform_firk_iteration / __perform_mirkn_iteration, the structured initial guess cache.y₀::VectorOfArray{Float64, 2, Vector{Vector{Float64}}} is flattened to a 1D vector before being handed to NonlinearSolve:

nlprob = __construct_problem(cache, copy(vec(cache.y₀)), copy(cache.y₀))

(copy(vec(...)) rather than vec(...) is necessary because under RecursiveArrayTools v4, vec(::VectorOfArray) returns Base.ReshapedArray{T, 1, VectorOfArray{...}, …}, which causes NonlinearSolve's polyalg to fail to infer T, N, uType, R of the resulting NonlinearSolution — see #484 / #473 / #485.)

The copy is essentially papering over a deeper structural choice: the entire NL solve pipeline below this point is written to take a flat AbstractVector and round-trip it through the structured VectorOfArray representation:

@views function __mirk_loss!(resid, u, p, y, pt, bc!::BC, residual, mesh, cache, eval_sol, trait, constraint) where {BC}
    y_ = recursive_unflatten!(y, u)            # u (flat) → y_ (VOA)
    resids = [get_tmp(r, u) for r in residual]
    Φ!(resids[2:end], cache, y_, u, trait, constraint)
    update_eval_sol!(eval_sol, y_, cache)
    eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
    recursive_flatten!(resid, resids)           # resids (VOA) → resid (flat)
    return nothing
end

Each NL iteration therefore goes VOA → flat → solve → flat → VOA via recursive_unflatten! / recursive_flatten!. With RAT v4 making VectorOfArray <: AbstractArray, NonlinearSolve in principle accepts cache.y₀ directly as u0, so the round-trip is no longer load-bearing; it is a historical artifact from when VOA was not an AbstractArray.
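To make the round-trip concrete, here is a minimal sketch using only Base Julia. The names flatten_sketch! and unflatten_sketch! are hypothetical stand-ins for recursive_flatten! / recursive_unflatten!, and a plain Vector{Vector{Float64}} stands in for the VectorOfArray; the real helpers may differ in detail.

```julia
# Hypothetical stand-ins for recursive_flatten! / recursive_unflatten!,
# written against Base only (no RecursiveArrayTools).

# Copy each per-timestep vector into consecutive slots of a flat vector.
function flatten_sketch!(flat::AbstractVector, ys::Vector{<:AbstractVector})
    i = 0
    for y in ys
        n = length(y)
        copyto!(flat, i + 1, y, 1, n)
        i += n
    end
    return flat
end

# Inverse: fill each per-timestep vector from consecutive slots of `flat`.
function unflatten_sketch!(ys::Vector{<:AbstractVector}, flat::AbstractVector)
    i = 0
    for y in ys
        n = length(y)
        copyto!(y, 1, flat, i + 1, n)
        i += n
    end
    return ys
end

ys = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 mesh points, 2 states each
flat = zeros(6)
flatten_sketch!(flat, ys)                   # structured → flat
flat .*= 2                                  # stand-in for a solver update on the flat vector
unflatten_sketch!(ys, flat)                 # flat → structured
```

Both directions are pure copies, which is the work the proposal aims to skip when the solver can consume the structured form directly.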

Proposal

Skip the flatten/unflatten and pass cache.y₀ (the VOA) directly as nlprob.u0. Concretely:

  • __perform_*_iteration: __construct_*problem(cache, cache.y₀, copy(cache.y₀)) (drop copy(vec(...))).
  • __mirk_loss! / __mirk_loss_bc! / __mirk_loss_collocation! (and MIRKN/FIRK analogues): receive u::AbstractVectorOfArray directly, drop the recursive_unflatten!(y, u) / recursive_flatten!(resid, resids) calls. The structured form is the input.
  • safe_similar(y, ...) calls building residual buffers: keep returning a flat Vector{Float64} (residuals stay 1D — recursive_flatten_twopoint! etc. pack into the flat layout the NL solver expects).
  • DI prepare_jacobian(loss, resid, diffmode, y, Constant(cache.p)): confirm DI handles a 2D VectorOfArray input correctly — Jacobian shape is (length(resid_flat), length(y_voa)) under linear indexing, which should match what the existing vcat(J_bc, J_c) jac_prototype assembly expects. Some adapter work may be required.
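The Jacobian-shape point in the last bullet can be illustrated with a small sketch. It uses a plain Matrix as a stand-in for the 2D VectorOfArray and a hand-written forward-difference Jacobian instead of DI; toy_resid! and fd_jacobian are hypothetical names, and whether a RAT v4 VectorOfArray linear-indexes the same way is exactly what the bullet asks to confirm.

```julia
# Toy residual (hypothetical, not the MIRK loss). Y is M×N: states × mesh points.
function toy_resid!(resid, Y)
    for k in eachindex(Y)          # linear index into the 2D input
        resid[k] = Y[k]^2 - k
    end
    return resid
end

# Forward-difference Jacobian whose columns follow the linear index of Y.
function fd_jacobian(f!, Y::AbstractMatrix; h = 1e-6)
    r0 = f!(zeros(length(Y)), Y)   # toy case: residual length equals length(Y)
    J = zeros(length(r0), length(Y))
    Yp = copy(Y)
    r1 = similar(r0)
    for k in eachindex(Y)
        Yp[k] += h                 # perturb one entry
        f!(r1, Yp)
        J[:, k] .= (r1 .- r0) ./ h
        Yp[k] = Y[k]               # restore it
    end
    return J
end

Y = [1.0 3.0 5.0; 2.0 4.0 6.0]     # columns play the role of per-timestep vectors
J = fd_jacobian(toy_resid!, Y)
```

For a Matrix, linear (column-major) indexing visits the entries in the same order as concatenating the per-timestep columns, so column k of J lines up with entry k of the flat layout. If the same holds for the VOA, the existing vcat(J_bc, J_c) assembly should not need reordering.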

Why bother

  1. Eliminates a per-iteration allocation and copy: copy(vec(cache.y₀)) allocates a fresh Vector each iteration, and the round-trip closures also allocate internally.
  2. Removes the RAT-v4 inference cliff at the source rather than per call site. Right now any new code path that hands a vec(VOA) to NonlinearSolve will silently lose T, N, uType, R inference, as MIRKN did until #485 ("Fix MIRK/FIRK adaptive mesh refinement under RAT v4", closes #484).
  3. Removes a source of confusion: it's not obvious from reading __perform_mirk_iteration why copy(vec(...)) is required and not vec(...). Comments help, but not having the round-trip is cleaner.

Scope

Probably one well-scoped PR per sublibrary (MIRK, FIRK, MIRKN). Each touches the iteration entry point and the loss function family. Shooting/MultipleShooting can be handled separately if useful.

Out of scope here

This issue is about removing the round-trip on the post-cache-construction side. The choice to store cache.y₀ as a VectorOfArray of per-timestep vectors (vs. e.g. a contiguous Matrix{Float64} with column views) is independent and not changed by this proposal.

🤖 Generated with Claude Code

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
