Skip to content

Commit d63f6b4

Browse files
Merge pull request #186 from ChrisRackauckas-Claude/runic-formatting
Switch from JuliaFormatter to Runic.jl for code formatting
2 parents 03da7a9 + a7dd057 commit d63f6b4

13 files changed

Lines changed: 238 additions & 149 deletions

.JuliaFormatter.toml

Lines changed: 0 additions & 7 deletions
This file was deleted.

.github/workflows/FormatCheck.yml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1-
name: "Format Check"
1+
name: format-check
22

33
on:
44
push:
55
branches:
6+
- 'master'
67
- 'main'
8+
- 'release-'
79
tags: '*'
810
pull_request:
911

1012
jobs:
11-
format-check:
12-
name: "Format Check"
13-
uses: "SciML/.github/.github/workflows/format-suggestions-on-pr.yml@v1"
13+
runic:
14+
runs-on: ubuntu-latest
15+
steps:
16+
- uses: actions/checkout@v4
17+
- uses: fredrikekre/runic-action@v1
18+
with:
19+
version: '1'

docs/make.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
using Documenter, DocumenterCitations, DeepEquilibriumNetworks
22

3-
cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force=true)
4-
cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force=true)
3+
cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force = true)
4+
cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force = true)
55

6-
bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); style=:authoryear)
6+
bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); style = :authoryear)
77

88
include("pages.jl")
99

10-
makedocs(; sitename="Deep Equilibrium Networks",
11-
authors="Avik Pal et al.",
12-
modules=[DeepEquilibriumNetworks],
13-
clean=true,
14-
doctest=false, # Tested in CI
15-
linkcheck=true,
16-
format=Documenter.HTML(; assets=["assets/favicon.ico"],
17-
canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"),
18-
plugins=[bib],
19-
pages)
10+
makedocs(;
11+
sitename = "Deep Equilibrium Networks",
12+
authors = "Avik Pal et al.",
13+
modules = [DeepEquilibriumNetworks],
14+
clean = true,
15+
doctest = false, # Tested in CI
16+
linkcheck = true,
17+
format = Documenter.HTML(;
18+
assets = ["assets/favicon.ico"],
19+
canonical = "https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"
20+
),
21+
plugins = [bib],
22+
pages
23+
)
2024

21-
deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview=true)
25+
deploydocs(; repo = "github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview = true)

docs/pages.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
pages = ["Home" => "index.md",
1+
pages = [
2+
"Home" => "index.md",
23
"Tutorials" => ["tutorials/basic_mnist_deq.md", "tutorials/reduced_dim_deq.md"],
3-
"API References" => "api.md", "References" => "references.md"]
4+
"API References" => "api.md", "References" => "references.md",
5+
]

src/DeepEquilibriumNetworks.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ using NonlinearSolveBase: AbsNormTerminationMode
99
using FastClosures: @closure
1010
using Random: Random, AbstractRNG, randn!
1111
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractODEAlgorithm,
12-
NonlinearSolution, ODESolution, ODEFunction, ODEProblem,
13-
SteadyStateProblem, _unwrap_val
12+
NonlinearSolution, ODESolution, ODEFunction, ODEProblem,
13+
SteadyStateProblem, _unwrap_val
1414
using SciMLSensitivity: SteadyStateAdjoint, GaussAdjoint, ZygoteVJP
1515
using Static: StaticSymbol, StaticInt, known, static
1616

1717
using Lux: Lux, LuxOps, BranchLayer, Chain, NoOpLayer, Parallel, RepeatedLayer,
18-
StatefulLuxLayer, WrappedFunction
18+
StatefulLuxLayer, WrappedFunction
1919
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer
2020
using NNlib:
2121
using SteadyStateDiffEq: DynamicSS, SSRootfind
@@ -30,7 +30,7 @@ include("precompilation.jl")
3030

3131
# Exports
3232
export DEQs, DeepEquilibriumSolution, DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork,
33-
MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork,
34-
MultiScaleNeuralODE
33+
MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork,
34+
MultiScaleNeuralODE
3535

3636
end

src/layers.jl

Lines changed: 91 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,17 @@ struct DeepEquilibriumSolution # This is intentionally left untyped to allow up
2222
original
2323
end
2424

25-
function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star,
26-
u0, residual, jacobian_loss, nfe, original)
25+
function CRC.rrule(
26+
::Type{<:DeepEquilibriumSolution}, z_star,
27+
u0, residual, jacobian_loss, nfe, original
28+
)
2729
sol = DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe, original)
2830
∇DeepEquilibriumSolution(::CRC.NoTangent) = ntuple(_ -> CRC.NoTangent(), 7)
2931
function ∇DeepEquilibriumSolution(∂sol)
30-
return (CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual,
31-
∂sol.jacobian_loss, ∂sol.nfe, CRC.NoTangent())
32+
return (
33+
CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual,
34+
∂sol.jacobian_loss, ∂sol.nfe, CRC.NoTangent(),
35+
)
3236
end
3337
return sol, ∇DeepEquilibriumSolution
3438
end
@@ -39,15 +43,32 @@ end
3943

4044
function Base.show(io::IO, sol::DeepEquilibriumSolution)
4145
println(io, "DeepEquilibriumSolution")
42-
println(io, " * Initial Guess: ", sprint(print, sol.u0; context=(
43-
:compact => true, :limit => true)))
44-
println(io, " * Steady State: ", sprint(print, sol.z_star; context=(
45-
:compact => true, :limit => true)))
46-
println(io, " * Residual: ", sprint(print, sol.residual; context=(
47-
:compact => true, :limit => true)))
48-
println(io, " * Jacobian Loss: ",
49-
sprint(print, sol.jacobian_loss; context=(:compact => true, :limit => true)))
50-
print(io, " * NFE: ", sol.nfe)
46+
println(
47+
io, " * Initial Guess: ", sprint(
48+
print, sol.u0; context = (
49+
:compact => true, :limit => true,
50+
)
51+
)
52+
)
53+
println(
54+
io, " * Steady State: ", sprint(
55+
print, sol.z_star; context = (
56+
:compact => true, :limit => true,
57+
)
58+
)
59+
)
60+
println(
61+
io, " * Residual: ", sprint(
62+
print, sol.residual; context = (
63+
:compact => true, :limit => true,
64+
)
65+
)
66+
)
67+
println(
68+
io, " * Jacobian Loss: ",
69+
sprint(print, sol.jacobian_loss; context = (:compact => true, :limit => true))
70+
)
71+
return print(io, " * NFE: ", sol.nfe)
5172
end
5273

5374
# Core Model
@@ -65,31 +86,37 @@ const DEQ = DeepEquilibriumNetwork
6586
function LuxCore.initialstates(rng::AbstractRNG, deq::DEQ)
6687
rng = LuxCore.replicate(rng)
6788
randn(rng, 1)
68-
return (; model=LuxCore.initialstates(rng, deq.model), fixed_depth=Val(0),
69-
init=LuxCore.initialstates(rng, deq.init), solution=DeepEquilibriumSolution(), rng)
89+
return (;
90+
model = LuxCore.initialstates(rng, deq.model), fixed_depth = Val(0),
91+
init = LuxCore.initialstates(rng, deq.init), solution = DeepEquilibriumSolution(), rng,
92+
)
7093
end
7194

7295
(deq::DEQ)(x, ps, st::NamedTuple) = deq(x, ps, st, check_unrolled_mode(st))
7396

7497
## Pretraining
7598
function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true})
7699
z, st = get_initial_condition(deq, x, ps, st)
77-
repeated_model = RepeatedLayer(deq.model; repeats=st.fixed_depth)
100+
repeated_model = RepeatedLayer(deq.model; repeats = st.fixed_depth)
78101

79102
z_star, st_ = repeated_model((z, x), ps.model, st.model)
80103
model = StatefulLuxLayer{true}(deq.model, ps.model, st_)
81104
resid = CRC.ignore_derivatives(z_star .- model((z_star, x)))
82105

83106
rng = LuxCore.replicate(st.rng)
84107
jac_loss = estimate_jacobian_trace(
85-
LuxOps.getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng)
108+
LuxOps.getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng
109+
)
86110

87111
solution = DeepEquilibriumSolution(
88-
z_star, z, resid, zero(eltype(x)), _unwrap_val(st.fixed_depth), jac_loss)
89-
res = split_and_reshape(z_star, LuxOps.getproperty(deq.model, Val(:split_idxs)),
90-
LuxOps.getproperty(deq.model, Val(:scales)))
91-
92-
return res, (; st..., model=model.st, solution, rng)
112+
z_star, z, resid, zero(eltype(x)), _unwrap_val(st.fixed_depth), jac_loss
113+
)
114+
res = split_and_reshape(
115+
z_star, LuxOps.getproperty(deq.model, Val(:split_idxs)),
116+
LuxOps.getproperty(deq.model, Val(:scales))
117+
)
118+
119+
return res, (; st..., model = model.st, solution, rng)
93120
end
94121

95122
function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{false})
@@ -104,23 +131,29 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{false})
104131
return y .- u
105132
end
106133

107-
prob = construct_prob(deq.kind, ODEFunction{false}(dudt), z, (; ps=ps.model, x))
134+
prob = construct_prob(deq.kind, ODEFunction{false}(dudt), z, (; ps = ps.model, x))
108135
alg = normalize_alg(deq)
109136
termination_condition = AbsNormTerminationMode(Base.Fix1(maximum, abs))
110-
sol = solve(prob, alg; sensealg=default_sensealg(prob), abstol=1e-3,
111-
reltol=1e-3, termination_condition, maxiters=32, deq.kwargs...)
137+
sol = solve(
138+
prob, alg; sensealg = default_sensealg(prob), abstol = 1.0e-3,
139+
reltol = 1.0e-3, termination_condition, maxiters = 32, deq.kwargs...
140+
)
112141
z_star = get_steady_state(sol)
113142

114143
rng = LuxCore.replicate(st.rng)
115144
jac_loss = estimate_jacobian_trace(
116-
LuxOps.getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng)
145+
LuxOps.getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng
146+
)
117147

118148
solution = DeepEquilibriumSolution(
119-
z_star, z, LuxOps.getproperty(sol, Val(:resid)), jac_loss, get_nfe(sol), sol)
120-
res = split_and_reshape(z_star, LuxOps.getproperty(deq.model, Val(:split_idxs)),
121-
LuxOps.getproperty(deq.model, Val(:scales)))
122-
123-
return res, (; st..., model=model.st, solution, rng)
149+
z_star, z, LuxOps.getproperty(sol, Val(:resid)), jac_loss, get_nfe(sol), sol
150+
)
151+
res = split_and_reshape(
152+
z_star, LuxOps.getproperty(deq.model, Val(:split_idxs)),
153+
LuxOps.getproperty(deq.model, Val(:scales))
154+
)
155+
156+
return res, (; st..., model = model.st, solution, rng)
124157
end
125158

126159
## Constructors
@@ -168,17 +201,20 @@ See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwo
168201
[`MultiScaleSkipDeepEquilibriumNetwork`](@ref).
169202
"""
170203
function DeepEquilibriumNetwork(
171-
model, solver; init=missing, jacobian_regularization=nothing,
172-
problem_type::Type=SteadyStateProblem{false}, kwargs...)
204+
model, solver; init = missing, jacobian_regularization = nothing,
205+
problem_type::Type = SteadyStateProblem{false}, kwargs...
206+
)
173207
if init === missing # Regular DEQ
174208
init = WrappedFunction(Base.Fix1(zeros_init, LuxOps.getproperty(model, Val(:scales))))
175209
elseif init === nothing # SkipRegDEQ
176210
init = NoOpLayer()
177211
elseif !(init isa AbstractLuxLayer)
178212
error("init::$(typeof(init)) is not a valid input for DeepEquilibriumNetwork.")
179213
end
180-
return DeepEquilibriumNetwork(init, model, solver, jacobian_regularization,
181-
kwargs, problem_type_to_symbol(problem_type))
214+
return DeepEquilibriumNetwork(
215+
init, model, solver, jacobian_regularization,
216+
kwargs, problem_type_to_symbol(problem_type)
217+
)
182218
end
183219

184220
"""
@@ -192,7 +228,7 @@ function SkipDeepEquilibriumNetwork(model, init, solver; kwargs...)
192228
end
193229

194230
function SkipDeepEquilibriumNetwork(model, solver; kwargs...)
195-
return DeepEquilibriumNetwork(model, solver; init=nothing, kwargs...)
231+
return DeepEquilibriumNetwork(model, solver; init = nothing, kwargs...)
196232
end
197233

198234
## MultiScale DEQ
@@ -242,8 +278,10 @@ julia> size.(first(model(x, ps, st)))
242278
((4, 12), (3, 12), (2, 12), (1, 12))
243279
```
244280
"""
245-
function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
246-
post_fuse_layer::Union{Nothing, Tuple}, solver, scales; kwargs...)
281+
function MultiScaleDeepEquilibriumNetwork(
282+
main_layers::Tuple, mapping_layers::Matrix,
283+
post_fuse_layer::Union{Nothing, Tuple}, solver, scales; kwargs...
284+
)
247285
l1 = Parallel(nothing, main_layers...)
248286
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
249287

@@ -269,17 +307,23 @@ creates a [`MultiScaleDeepEquilibriumNetwork`](@ref) with `init` kwarg set to pa
269307
270308
If `init` is not passed, it creates a MultiScale Regularized Deep Equilibrium Network.
271309
"""
272-
function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
273-
post_fuse_layer::Union{Nothing, Tuple}, init::Tuple, solver, scales; kwargs...)
310+
function MultiScaleSkipDeepEquilibriumNetwork(
311+
main_layers::Tuple, mapping_layers::Matrix,
312+
post_fuse_layer::Union{Nothing, Tuple}, init::Tuple, solver, scales; kwargs...
313+
)
274314
init = Chain(Parallel(nothing, init...), flatten_vcat)
275315
return MultiScaleDeepEquilibriumNetwork(
276-
main_layers, mapping_layers, post_fuse_layer, solver, scales; init, kwargs...)
316+
main_layers, mapping_layers, post_fuse_layer, solver, scales; init, kwargs...
317+
)
277318
end
278319

279-
function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
280-
post_fuse_layer::Union{Nothing, Tuple}, args...; kwargs...)
320+
function MultiScaleSkipDeepEquilibriumNetwork(
321+
main_layers::Tuple, mapping_layers::Matrix,
322+
post_fuse_layer::Union{Nothing, Tuple}, args...; kwargs...
323+
)
281324
return MultiScaleDeepEquilibriumNetwork(
282-
main_layers, mapping_layers, post_fuse_layer, args...; init=nothing, kwargs...)
325+
main_layers, mapping_layers, post_fuse_layer, args...; init = nothing, kwargs...
326+
)
283327
end
284328

285329
"""
@@ -289,19 +333,19 @@ Same arguments as [`MultiScaleDeepEquilibriumNetwork`](@ref) but sets `problem_t
289333
`ODEProblem{false}`.
290334
"""
291335
function MultiScaleNeuralODE(args...; kwargs...)
292-
return MultiScaleDeepEquilibriumNetwork(args...; kwargs..., problem_type=ODEProblem{false})
336+
return MultiScaleDeepEquilibriumNetwork(args...; kwargs..., problem_type = ODEProblem{false})
293337
end
294338

295339
## Generate Initial Condition
296340
function get_initial_condition(deq::DEQ{NoOpLayer}, x, ps, st)
297341
zₓ = zeros_init(LuxOps.getproperty(deq.model, Val(:scales)), x)
298342
z, st_ = deq.model((zₓ, x), ps.model, st.model)
299-
return z, (; st..., model=st_)
343+
return z, (; st..., model = st_)
300344
end
301345

302346
function get_initial_condition(deq::DEQ, x, ps, st)
303347
z, st_ = deq.init(x, ps.init, st.init)
304-
return z, (; st..., init=st_)
348+
return z, (; st..., init = st_)
305349
end
306350

307351
# Other Layers

src/precompilation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ using PrecompileTools: @compile_workload, @setup_workload
1111
# Create a small model for precompilation
1212
# Using SSRootfind which is already imported from SteadyStateDiffEq
1313
model = DEQ(
14-
Parallel(+, Lux.Dense(2, 2; use_bias=false), Lux.Dense(2, 2; use_bias=false)),
14+
Parallel(+, Lux.Dense(2, 2; use_bias = false), Lux.Dense(2, 2; use_bias = false)),
1515
SSRootfind();
16-
verbose=false
16+
verbose = false
1717
)
1818

1919
# Initialize parameters and state (very common operation)

0 commit comments

Comments
 (0)