Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Expand Down Expand Up @@ -45,6 +46,7 @@ NonlinearSolve = "3.10.0, 4"
NonlinearSolveBase = "1.5"
OrdinaryDiffEq = "6.74.1"
Pkg = "1.10"
PrecompileTools = "1"
Random = "1.10"
ReTestItems = "1.23.1"
SciMLBase = "2"
Expand Down
1 change: 1 addition & 0 deletions src/DeepEquilibriumNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const DEQs = DeepEquilibriumNetworks

include("layers.jl")
include("utils.jl")
include("precompilation.jl")

# Exports
export DEQs, DeepEquilibriumSolution, DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork,
Expand Down
45 changes: 45 additions & 0 deletions src/precompilation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using PrecompileTools: @compile_workload, @setup_workload

@setup_workload begin
@compile_workload begin
# Precompile core functionality for DeepEquilibriumNetwork
# These are the most common operations users perform

# Simple Dense-based DEQ model setup
rng = Random.Xoshiro(0)

# Create a small model for precompilation
# Using SSRootfind which is already imported from SteadyStateDiffEq
model = DEQ(
Parallel(+, Lux.Dense(2, 2; use_bias=false), Lux.Dense(2, 2; use_bias=false)),
SSRootfind();
verbose=false
)

# Initialize parameters and state (very common operation)
ps, st = LuxCore.setup(rng, model)

# Precompile DeepEquilibriumSolution constructor
_ = DeepEquilibriumSolution()

# Precompile utility functions
x = ones(Float32, 2, 1)

# Precompile check_unrolled_mode
_ = check_unrolled_mode(st)

# Precompile zeros_init
_ = zeros_init(nothing, x)

# Precompile flatten operations
_ = flatten(x)
_ = flatten_vcat((x, x))

# Precompile split_and_reshape with Nothing
_ = split_and_reshape(x, nothing, nothing)

# Precompile with fixed depth (unrolled mode)
st_unrolled = Lux.update_state(st, :fixed_depth, Val(2))
_ = check_unrolled_mode(st_unrolled)
end
end
Loading