@@ -22,13 +22,17 @@ struct DeepEquilibriumSolution # This is intentionally left untyped to allow up
2222 original
2323end
2424
25- function CRC. rrule (:: Type{<:DeepEquilibriumSolution} , z_star,
26- u0, residual, jacobian_loss, nfe, original)
25+ function CRC. rrule (
26+ :: Type{<:DeepEquilibriumSolution} , z_star,
27+ u0, residual, jacobian_loss, nfe, original
28+ )
2729 sol = DeepEquilibriumSolution (z_star, u0, residual, jacobian_loss, nfe, original)
2830 ∇DeepEquilibriumSolution (:: CRC.NoTangent ) = ntuple (_ -> CRC. NoTangent (), 7 )
2931 function ∇DeepEquilibriumSolution (∂sol)
30- return (CRC. NoTangent (), ∂sol. z_star, ∂sol. u0, ∂sol. residual,
31- ∂sol. jacobian_loss, ∂sol. nfe, CRC. NoTangent ())
32+ return (
33+ CRC. NoTangent (), ∂sol. z_star, ∂sol. u0, ∂sol. residual,
34+ ∂sol. jacobian_loss, ∂sol. nfe, CRC. NoTangent (),
35+ )
3236 end
3337 return sol, ∇DeepEquilibriumSolution
3438end
3943
4044function Base. show (io:: IO , sol:: DeepEquilibriumSolution )
4145 println (io, " DeepEquilibriumSolution" )
42- println (io, " * Initial Guess: " , sprint (print, sol. u0; context= (
43- :compact => true , :limit => true )))
44- println (io, " * Steady State: " , sprint (print, sol. z_star; context= (
45- :compact => true , :limit => true )))
46- println (io, " * Residual: " , sprint (print, sol. residual; context= (
47- :compact => true , :limit => true )))
48- println (io, " * Jacobian Loss: " ,
49- sprint (print, sol. jacobian_loss; context= (:compact => true , :limit => true )))
50- print (io, " * NFE: " , sol. nfe)
46+ println (
47+ io, " * Initial Guess: " , sprint (
48+ print, sol. u0; context = (
49+ :compact => true , :limit => true ,
50+ )
51+ )
52+ )
53+ println (
54+ io, " * Steady State: " , sprint (
55+ print, sol. z_star; context = (
56+ :compact => true , :limit => true ,
57+ )
58+ )
59+ )
60+ println (
61+ io, " * Residual: " , sprint (
62+ print, sol. residual; context = (
63+ :compact => true , :limit => true ,
64+ )
65+ )
66+ )
67+ println (
68+ io, " * Jacobian Loss: " ,
69+ sprint (print, sol. jacobian_loss; context = (:compact => true , :limit => true ))
70+ )
71+ return print (io, " * NFE: " , sol. nfe)
5172end
5273
5374# Core Model
@@ -65,31 +86,37 @@ const DEQ = DeepEquilibriumNetwork
6586function LuxCore. initialstates (rng:: AbstractRNG , deq:: DEQ )
6687 rng = LuxCore. replicate (rng)
6788 randn (rng, 1 )
68- return (; model= LuxCore. initialstates (rng, deq. model), fixed_depth= Val (0 ),
69- init= LuxCore. initialstates (rng, deq. init), solution= DeepEquilibriumSolution (), rng)
89+ return (;
90+ model = LuxCore. initialstates (rng, deq. model), fixed_depth = Val (0 ),
91+ init = LuxCore. initialstates (rng, deq. init), solution = DeepEquilibriumSolution (), rng,
92+ )
7093end
7194
7295(deq:: DEQ )(x, ps, st:: NamedTuple ) = deq (x, ps, st, check_unrolled_mode (st))
7396
7497# # Pretraining
7598function (deq:: DEQ )(x, ps, st:: NamedTuple , :: Val{true} )
7699 z, st = get_initial_condition (deq, x, ps, st)
77- repeated_model = RepeatedLayer (deq. model; repeats= st. fixed_depth)
100+ repeated_model = RepeatedLayer (deq. model; repeats = st. fixed_depth)
78101
79102 z_star, st_ = repeated_model ((z, x), ps. model, st. model)
80103 model = StatefulLuxLayer {true} (deq. model, ps. model, st_)
81104 resid = CRC. ignore_derivatives (z_star .- model ((z_star, x)))
82105
83106 rng = LuxCore. replicate (st. rng)
84107 jac_loss = estimate_jacobian_trace (
85- LuxOps. getproperty (deq, Val (:jacobian_regularization )), model, z_star, x, rng)
108+ LuxOps. getproperty (deq, Val (:jacobian_regularization )), model, z_star, x, rng
109+ )
86110
87111 solution = DeepEquilibriumSolution (
88- z_star, z, resid, zero (eltype (x)), _unwrap_val (st. fixed_depth), jac_loss)
89- res = split_and_reshape (z_star, LuxOps. getproperty (deq. model, Val (:split_idxs )),
90- LuxOps. getproperty (deq. model, Val (:scales )))
91-
92- return res, (; st... , model= model. st, solution, rng)
112+ z_star, z, resid, zero (eltype (x)), _unwrap_val (st. fixed_depth), jac_loss
113+ )
114+ res = split_and_reshape (
115+ z_star, LuxOps. getproperty (deq. model, Val (:split_idxs )),
116+ LuxOps. getproperty (deq. model, Val (:scales ))
117+ )
118+
119+ return res, (; st... , model = model. st, solution, rng)
93120end
94121
95122function (deq:: DEQ )(x, ps, st:: NamedTuple , :: Val{false} )
@@ -104,23 +131,29 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{false})
104131 return y .- u
105132 end
106133
107- prob = construct_prob (deq. kind, ODEFunction {false} (dudt), z, (; ps= ps. model, x))
134+ prob = construct_prob (deq. kind, ODEFunction {false} (dudt), z, (; ps = ps. model, x))
108135 alg = normalize_alg (deq)
109136 termination_condition = AbsNormTerminationMode (Base. Fix1 (maximum, abs))
110- sol = solve (prob, alg; sensealg= default_sensealg (prob), abstol= 1e-3 ,
111- reltol= 1e-3 , termination_condition, maxiters= 32 , deq. kwargs... )
137+ sol = solve (
138+ prob, alg; sensealg = default_sensealg (prob), abstol = 1.0e-3 ,
139+ reltol = 1.0e-3 , termination_condition, maxiters = 32 , deq. kwargs...
140+ )
112141 z_star = get_steady_state (sol)
113142
114143 rng = LuxCore. replicate (st. rng)
115144 jac_loss = estimate_jacobian_trace (
116- LuxOps. getproperty (deq, Val (:jacobian_regularization )), model, z_star, x, rng)
145+ LuxOps. getproperty (deq, Val (:jacobian_regularization )), model, z_star, x, rng
146+ )
117147
118148 solution = DeepEquilibriumSolution (
119- z_star, z, LuxOps. getproperty (sol, Val (:resid )), jac_loss, get_nfe (sol), sol)
120- res = split_and_reshape (z_star, LuxOps. getproperty (deq. model, Val (:split_idxs )),
121- LuxOps. getproperty (deq. model, Val (:scales )))
122-
123- return res, (; st... , model= model. st, solution, rng)
149+ z_star, z, LuxOps. getproperty (sol, Val (:resid )), jac_loss, get_nfe (sol), sol
150+ )
151+ res = split_and_reshape (
152+ z_star, LuxOps. getproperty (deq. model, Val (:split_idxs )),
153+ LuxOps. getproperty (deq. model, Val (:scales ))
154+ )
155+
156+ return res, (; st... , model = model. st, solution, rng)
124157end
125158
126159# # Constructors
@@ -168,17 +201,20 @@ See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwo
168201[`MultiScaleSkipDeepEquilibriumNetwork`](@ref).
169202"""
170203function DeepEquilibriumNetwork (
171- model, solver; init= missing , jacobian_regularization= nothing ,
172- problem_type:: Type = SteadyStateProblem{false }, kwargs... )
204+ model, solver; init = missing , jacobian_regularization = nothing ,
205+ problem_type:: Type = SteadyStateProblem{false }, kwargs...
206+ )
173207 if init === missing # Regular DEQ
174208 init = WrappedFunction (Base. Fix1 (zeros_init, LuxOps. getproperty (model, Val (:scales ))))
175209 elseif init === nothing # SkipRegDEQ
176210 init = NoOpLayer ()
177211 elseif ! (init isa AbstractLuxLayer)
178212 error (" init::$(typeof (init)) is not a valid input for DeepEquilibriumNetwork." )
179213 end
180- return DeepEquilibriumNetwork (init, model, solver, jacobian_regularization,
181- kwargs, problem_type_to_symbol (problem_type))
214+ return DeepEquilibriumNetwork (
215+ init, model, solver, jacobian_regularization,
216+ kwargs, problem_type_to_symbol (problem_type)
217+ )
182218end
183219
184220"""
@@ -192,7 +228,7 @@ function SkipDeepEquilibriumNetwork(model, init, solver; kwargs...)
192228end
193229
194230function SkipDeepEquilibriumNetwork (model, solver; kwargs... )
195- return DeepEquilibriumNetwork (model, solver; init= nothing , kwargs... )
231+ return DeepEquilibriumNetwork (model, solver; init = nothing , kwargs... )
196232end
197233
198234# # MultiScale DEQ
@@ -242,8 +278,10 @@ julia> size.(first(model(x, ps, st)))
242278((4, 12), (3, 12), (2, 12), (1, 12))
243279```
244280"""
245- function MultiScaleDeepEquilibriumNetwork (main_layers:: Tuple , mapping_layers:: Matrix ,
246- post_fuse_layer:: Union{Nothing, Tuple} , solver, scales; kwargs... )
281+ function MultiScaleDeepEquilibriumNetwork (
282+ main_layers:: Tuple , mapping_layers:: Matrix ,
283+ post_fuse_layer:: Union{Nothing, Tuple} , solver, scales; kwargs...
284+ )
247285 l1 = Parallel (nothing , main_layers... )
248286 l2 = BranchLayer (Parallel .(+ , map (x -> tuple (x... ), eachrow (mapping_layers))... )... )
249287
@@ -269,17 +307,23 @@ creates a [`MultiScaleDeepEquilibriumNetwork`](@ref) with `init` kwarg set to pa
269307
270308If `init` is not passed, it creates a MultiScale Regularized Deep Equilibrium Network.
271309"""
272- function MultiScaleSkipDeepEquilibriumNetwork (main_layers:: Tuple , mapping_layers:: Matrix ,
273- post_fuse_layer:: Union{Nothing, Tuple} , init:: Tuple , solver, scales; kwargs... )
310+ function MultiScaleSkipDeepEquilibriumNetwork (
311+ main_layers:: Tuple , mapping_layers:: Matrix ,
312+ post_fuse_layer:: Union{Nothing, Tuple} , init:: Tuple , solver, scales; kwargs...
313+ )
274314 init = Chain (Parallel (nothing , init... ), flatten_vcat)
275315 return MultiScaleDeepEquilibriumNetwork (
276- main_layers, mapping_layers, post_fuse_layer, solver, scales; init, kwargs... )
316+ main_layers, mapping_layers, post_fuse_layer, solver, scales; init, kwargs...
317+ )
277318end
278319
279- function MultiScaleSkipDeepEquilibriumNetwork (main_layers:: Tuple , mapping_layers:: Matrix ,
280- post_fuse_layer:: Union{Nothing, Tuple} , args... ; kwargs... )
320+ function MultiScaleSkipDeepEquilibriumNetwork (
321+ main_layers:: Tuple , mapping_layers:: Matrix ,
322+ post_fuse_layer:: Union{Nothing, Tuple} , args... ; kwargs...
323+ )
281324 return MultiScaleDeepEquilibriumNetwork (
282- main_layers, mapping_layers, post_fuse_layer, args... ; init= nothing , kwargs... )
325+ main_layers, mapping_layers, post_fuse_layer, args... ; init = nothing , kwargs...
326+ )
283327end
284328
285329"""
@@ -289,19 +333,19 @@ Same arguments as [`MultiScaleDeepEquilibriumNetwork`](@ref) but sets `problem_t
289333`ODEProblem{false}`.
290334"""
291335function MultiScaleNeuralODE (args... ; kwargs... )
292- return MultiScaleDeepEquilibriumNetwork (args... ; kwargs... , problem_type= ODEProblem{false })
336+ return MultiScaleDeepEquilibriumNetwork (args... ; kwargs... , problem_type = ODEProblem{false })
293337end
294338
295339# # Generate Initial Condition
296340function get_initial_condition (deq:: DEQ{NoOpLayer} , x, ps, st)
297341 zₓ = zeros_init (LuxOps. getproperty (deq. model, Val (:scales )), x)
298342 z, st_ = deq. model ((zₓ, x), ps. model, st. model)
299- return z, (; st... , model= st_)
343+ return z, (; st... , model = st_)
300344end
301345
302346function get_initial_condition (deq:: DEQ , x, ps, st)
303347 z, st_ = deq. init (x, ps. init, st. init)
304- return z, (; st... , init= st_)
348+ return z, (; st... , init = st_)
305349end
306350
307351# Other Layers
0 commit comments