Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions .JuliaFormatter.toml

This file was deleted.

14 changes: 10 additions & 4 deletions .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
name: "Format Check"
name: format-check

on:
push:
branches:
- 'master'
- 'main'
- 'release-'
tags: '*'
pull_request:

jobs:
format-check:
name: "Format Check"
uses: "SciML/.github/.github/workflows/format-suggestions-on-pr.yml@v1"
runic:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: fredrikekre/runic-action@v1
with:
version: '1'
32 changes: 18 additions & 14 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
using Documenter, DocumenterCitations, DeepEquilibriumNetworks

cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force=true)
cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force=true)
cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force = true)
cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force = true)

bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); style=:authoryear)
bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); style = :authoryear)

include("pages.jl")

makedocs(; sitename="Deep Equilibrium Networks",
authors="Avik Pal et al.",
modules=[DeepEquilibriumNetworks],
clean=true,
doctest=false, # Tested in CI
linkcheck=true,
format=Documenter.HTML(; assets=["assets/favicon.ico"],
canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"),
plugins=[bib],
pages)
makedocs(;
sitename = "Deep Equilibrium Networks",
authors = "Avik Pal et al.",
modules = [DeepEquilibriumNetworks],
clean = true,
doctest = false, # Tested in CI
linkcheck = true,
format = Documenter.HTML(;
assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"
),
plugins = [bib],
pages
)

deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview=true)
deploydocs(; repo = "github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview = true)
6 changes: 4 additions & 2 deletions docs/pages.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pages = ["Home" => "index.md",
pages = [
"Home" => "index.md",
"Tutorials" => ["tutorials/basic_mnist_deq.md", "tutorials/reduced_dim_deq.md"],
"API References" => "api.md", "References" => "references.md"]
"API References" => "api.md", "References" => "references.md",
]
10 changes: 5 additions & 5 deletions src/DeepEquilibriumNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ using NonlinearSolveBase: AbsNormTerminationMode
using FastClosures: @closure
using Random: Random, AbstractRNG, randn!
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractODEAlgorithm,
NonlinearSolution, ODESolution, ODEFunction, ODEProblem,
SteadyStateProblem, _unwrap_val
NonlinearSolution, ODESolution, ODEFunction, ODEProblem,
SteadyStateProblem, _unwrap_val
using SciMLSensitivity: SteadyStateAdjoint, GaussAdjoint, ZygoteVJP
using Static: StaticSymbol, StaticInt, known, static

using Lux: Lux, LuxOps, BranchLayer, Chain, NoOpLayer, Parallel, RepeatedLayer,
StatefulLuxLayer, WrappedFunction
StatefulLuxLayer, WrappedFunction
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer
using NNlib: ⊠
using SteadyStateDiffEq: DynamicSS, SSRootfind
Expand All @@ -30,7 +30,7 @@ include("precompilation.jl")

# Exports
export DEQs, DeepEquilibriumSolution, DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork,
MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork,
MultiScaleNeuralODE
MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork,
MultiScaleNeuralODE

end
138 changes: 91 additions & 47 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ struct DeepEquilibriumSolution # This is intentionally left untyped to allow up
original
end

function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star,
u0, residual, jacobian_loss, nfe, original)
function CRC.rrule(
::Type{<:DeepEquilibriumSolution}, z_star,
u0, residual, jacobian_loss, nfe, original
)
sol = DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe, original)
∇DeepEquilibriumSolution(::CRC.NoTangent) = ntuple(_ -> CRC.NoTangent(), 7)
function ∇DeepEquilibriumSolution(∂sol)
return (CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual,
∂sol.jacobian_loss, ∂sol.nfe, CRC.NoTangent())
return (
CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual,
∂sol.jacobian_loss, ∂sol.nfe, CRC.NoTangent(),
)
end
return sol, ∇DeepEquilibriumSolution
end
Expand All @@ -39,15 +43,32 @@ end

function Base.show(io::IO, sol::DeepEquilibriumSolution)
println(io, "DeepEquilibriumSolution")
println(io, " * Initial Guess: ", sprint(print, sol.u0; context=(
:compact => true, :limit => true)))
println(io, " * Steady State: ", sprint(print, sol.z_star; context=(
:compact => true, :limit => true)))
println(io, " * Residual: ", sprint(print, sol.residual; context=(
:compact => true, :limit => true)))
println(io, " * Jacobian Loss: ",
sprint(print, sol.jacobian_loss; context=(:compact => true, :limit => true)))
print(io, " * NFE: ", sol.nfe)
println(
io, " * Initial Guess: ", sprint(
print, sol.u0; context = (
:compact => true, :limit => true,
)
)
)
println(
io, " * Steady State: ", sprint(
print, sol.z_star; context = (
:compact => true, :limit => true,
)
)
)
println(
io, " * Residual: ", sprint(
print, sol.residual; context = (
:compact => true, :limit => true,
)
)
)
println(
io, " * Jacobian Loss: ",
sprint(print, sol.jacobian_loss; context = (:compact => true, :limit => true))
)
return print(io, " * NFE: ", sol.nfe)
end

# Core Model
Expand All @@ -65,31 +86,37 @@ const DEQ = DeepEquilibriumNetwork
function LuxCore.initialstates(rng::AbstractRNG, deq::DEQ)
rng = LuxCore.replicate(rng)
randn(rng, 1)
return (; model=LuxCore.initialstates(rng, deq.model), fixed_depth=Val(0),
init=LuxCore.initialstates(rng, deq.init), solution=DeepEquilibriumSolution(), rng)
return (;
model = LuxCore.initialstates(rng, deq.model), fixed_depth = Val(0),
init = LuxCore.initialstates(rng, deq.init), solution = DeepEquilibriumSolution(), rng,
)
end

(deq::DEQ)(x, ps, st::NamedTuple) = deq(x, ps, st, check_unrolled_mode(st))

## Pretraining
function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true})
z, st = get_initial_condition(deq, x, ps, st)
repeated_model = RepeatedLayer(deq.model; repeats=st.fixed_depth)
repeated_model = RepeatedLayer(deq.model; repeats = st.fixed_depth)

z_star, st_ = repeated_model((z, x), ps.model, st.model)
model = StatefulLuxLayer{true}(deq.model, ps.model, st_)
resid = CRC.ignore_derivatives(z_star .- model((z_star, x)))

rng = LuxCore.replicate(st.rng)
jac_loss = estimate_jacobian_trace(
LuxOps.getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng)
LuxOps.getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng
)

solution = DeepEquilibriumSolution(
z_star, z, resid, zero(eltype(x)), _unwrap_val(st.fixed_depth), jac_loss)
res = split_and_reshape(z_star, LuxOps.getproperty(deq.model, Val(:split_idxs)),
LuxOps.getproperty(deq.model, Val(:scales)))

return res, (; st..., model=model.st, solution, rng)
z_star, z, resid, zero(eltype(x)), _unwrap_val(st.fixed_depth), jac_loss
)
res = split_and_reshape(
z_star, LuxOps.getproperty(deq.model, Val(:split_idxs)),
LuxOps.getproperty(deq.model, Val(:scales))
)

return res, (; st..., model = model.st, solution, rng)
end

function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{false})
Expand All @@ -104,23 +131,29 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{false})
return y .- u
end

prob = construct_prob(deq.kind, ODEFunction{false}(dudt), z, (; ps=ps.model, x))
prob = construct_prob(deq.kind, ODEFunction{false}(dudt), z, (; ps = ps.model, x))
alg = normalize_alg(deq)
termination_condition = AbsNormTerminationMode(Base.Fix1(maximum, abs))
sol = solve(prob, alg; sensealg=default_sensealg(prob), abstol=1e-3,
reltol=1e-3, termination_condition, maxiters=32, deq.kwargs...)
sol = solve(
prob, alg; sensealg = default_sensealg(prob), abstol = 1.0e-3,
reltol = 1.0e-3, termination_condition, maxiters = 32, deq.kwargs...
)
z_star = get_steady_state(sol)

rng = LuxCore.replicate(st.rng)
jac_loss = estimate_jacobian_trace(
LuxOps.getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng)
LuxOps.getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng
)

solution = DeepEquilibriumSolution(
z_star, z, LuxOps.getproperty(sol, Val(:resid)), jac_loss, get_nfe(sol), sol)
res = split_and_reshape(z_star, LuxOps.getproperty(deq.model, Val(:split_idxs)),
LuxOps.getproperty(deq.model, Val(:scales)))

return res, (; st..., model=model.st, solution, rng)
z_star, z, LuxOps.getproperty(sol, Val(:resid)), jac_loss, get_nfe(sol), sol
)
res = split_and_reshape(
z_star, LuxOps.getproperty(deq.model, Val(:split_idxs)),
LuxOps.getproperty(deq.model, Val(:scales))
)

return res, (; st..., model = model.st, solution, rng)
end

## Constructors
Expand Down Expand Up @@ -168,17 +201,20 @@ See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwo
[`MultiScaleSkipDeepEquilibriumNetwork`](@ref).
"""
function DeepEquilibriumNetwork(
model, solver; init=missing, jacobian_regularization=nothing,
problem_type::Type=SteadyStateProblem{false}, kwargs...)
model, solver; init = missing, jacobian_regularization = nothing,
problem_type::Type = SteadyStateProblem{false}, kwargs...
)
if init === missing # Regular DEQ
init = WrappedFunction(Base.Fix1(zeros_init, LuxOps.getproperty(model, Val(:scales))))
elseif init === nothing # SkipRegDEQ
init = NoOpLayer()
elseif !(init isa AbstractLuxLayer)
error("init::$(typeof(init)) is not a valid input for DeepEquilibriumNetwork.")
end
return DeepEquilibriumNetwork(init, model, solver, jacobian_regularization,
kwargs, problem_type_to_symbol(problem_type))
return DeepEquilibriumNetwork(
init, model, solver, jacobian_regularization,
kwargs, problem_type_to_symbol(problem_type)
)
end

"""
Expand All @@ -192,7 +228,7 @@ function SkipDeepEquilibriumNetwork(model, init, solver; kwargs...)
end

function SkipDeepEquilibriumNetwork(model, solver; kwargs...)
return DeepEquilibriumNetwork(model, solver; init=nothing, kwargs...)
return DeepEquilibriumNetwork(model, solver; init = nothing, kwargs...)
end

## MultiScale DEQ
Expand Down Expand Up @@ -242,8 +278,10 @@ julia> size.(first(model(x, ps, st)))
((4, 12), (3, 12), (2, 12), (1, 12))
```
"""
function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing, Tuple}, solver, scales; kwargs...)
function MultiScaleDeepEquilibriumNetwork(
main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing, Tuple}, solver, scales; kwargs...
)
l1 = Parallel(nothing, main_layers...)
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)

Expand All @@ -269,17 +307,23 @@ creates a [`MultiScaleDeepEquilibriumNetwork`](@ref) with `init` kwarg set to pa

If `init` is not passed, it creates a MultiScale Regularized Deep Equilibrium Network.
"""
function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing, Tuple}, init::Tuple, solver, scales; kwargs...)
function MultiScaleSkipDeepEquilibriumNetwork(
main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing, Tuple}, init::Tuple, solver, scales; kwargs...
)
init = Chain(Parallel(nothing, init...), flatten_vcat)
return MultiScaleDeepEquilibriumNetwork(
main_layers, mapping_layers, post_fuse_layer, solver, scales; init, kwargs...)
main_layers, mapping_layers, post_fuse_layer, solver, scales; init, kwargs...
)
end

function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing, Tuple}, args...; kwargs...)
function MultiScaleSkipDeepEquilibriumNetwork(
main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing, Tuple}, args...; kwargs...
)
return MultiScaleDeepEquilibriumNetwork(
main_layers, mapping_layers, post_fuse_layer, args...; init=nothing, kwargs...)
main_layers, mapping_layers, post_fuse_layer, args...; init = nothing, kwargs...
)
end

"""
Expand All @@ -289,19 +333,19 @@ Same arguments as [`MultiScaleDeepEquilibriumNetwork`](@ref) but sets `problem_t
`ODEProblem{false}`.
"""
function MultiScaleNeuralODE(args...; kwargs...)
return MultiScaleDeepEquilibriumNetwork(args...; kwargs..., problem_type=ODEProblem{false})
return MultiScaleDeepEquilibriumNetwork(args...; kwargs..., problem_type = ODEProblem{false})
end

## Generate Initial Condition
function get_initial_condition(deq::DEQ{NoOpLayer}, x, ps, st)
zₓ = zeros_init(LuxOps.getproperty(deq.model, Val(:scales)), x)
z, st_ = deq.model((zₓ, x), ps.model, st.model)
return z, (; st..., model=st_)
return z, (; st..., model = st_)
end

function get_initial_condition(deq::DEQ, x, ps, st)
z, st_ = deq.init(x, ps.init, st.init)
return z, (; st..., init=st_)
return z, (; st..., init = st_)
end

# Other Layers
Expand Down
4 changes: 2 additions & 2 deletions src/precompilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ using PrecompileTools: @compile_workload, @setup_workload
# Create a small model for precompilation
# Using SSRootfind which is already imported from SteadyStateDiffEq
model = DEQ(
Parallel(+, Lux.Dense(2, 2; use_bias=false), Lux.Dense(2, 2; use_bias=false)),
Parallel(+, Lux.Dense(2, 2; use_bias = false), Lux.Dense(2, 2; use_bias = false)),
SSRootfind();
verbose=false
verbose = false
)

# Initialize parameters and state (very common operation)
Expand Down
Loading
Loading