Commit 61a4cd4

Merge pull request #79 from ReactiveBayes/fix-invertible-nn-optim-v2
Fix Invertible Neural Network Tutorial for Optim.jl v2
2 parents 155aac5 + 4102932 commit 61a4cd4

2 files changed

Lines changed: 8 additions & 95 deletions

examples/Problem Specific/Invertible Neural Network Tutorial/Invertible Neural Network Tutorial.ipynb

Lines changed: 6 additions & 95 deletions
@@ -43,20 +43,10 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
-"source": [
-"using RxInfer\n",
-"using Random\n",
-"using StableRNGs\n",
-"\n",
-"using ReactiveMP # ReactiveMP is included in RxInfer, but we explicitly use some of its functionality\n",
-"using LinearAlgebra # only used for some matrix specifics\n",
-"using Plots # only used for visualisation\n",
-"using Distributions # only used for sampling from multivariate distributions\n",
-"using Optim # only used for parameter optimisation"
-]
+"source": "using RxInfer\nusing Random\nusing StableRNGs\n\nusing ReactiveMP # ReactiveMP is included in RxInfer, but we explicitly use some of its functionality\nusing LinearAlgebra # only used for some matrix specifics\nusing Plots # only used for visualisation\nusing Distributions # only used for sampling from multivariate distributions\nusing Optim # only used for parameter optimisation\nusing ADTypes # only used to specify the automatic differentiation backend for Optim\nusing ForwardDiff # only used for automatic differentiation in the parameter optimisation"
 },
 {
 "attachments": {},
@@ -9351,93 +9341,14 @@
 "attachments": {},
 "cell_type": "markdown",
 "metadata": {},
-"source": [
-"Optimization can be performed using the `Optim` package. Alternatively, other (custom) optimizers can be implemented, such as:\n",
-"\n",
-"```julia\n",
-"res = optimize(f, randn(StableRNG(42), nr_params(model)), GradientDescent(), Optim.Options(store_trace = true, show_trace = true, show_every = 50), autodiff=:forward)\n",
-"``` \n",
-"\n",
-"- uses finitediff and is slower/less accurate.\n",
-"\n",
-"*or*\n",
-"\n",
-"```julia\n",
-"# create gradient function\n",
-"g = (x) -> ForwardDiff.gradient(f, x);\n",
-"\n",
-"# specify initial params\n",
-"params = randn(nr_params(model))\n",
-"\n",
-"# create custom optimizer (here Adam)\n",
-"optimizer = Adam(params; λ=1e-1)\n",
-"\n",
-"# allocate space for gradient\n",
-"∇ = zeros(nr_params(model))\n",
-"\n",
-"# perform optimization\n",
-"for it = 1:10000\n",
-"\n",
-" # backward pass\n",
-" ∇ .= ForwardDiff.gradient(f, optimizer.x)\n",
-"\n",
-" # gradient update\n",
-" ReactiveMP.update!(optimizer, ∇)\n",
-"\n",
-"end\n",
-"\n",
-"```"
-]
+"source": "Optimization can be performed using the `Optim` package. Since Optim v2 the automatic differentiation backend is specified with an `ADTypes` object (e.g. `AutoForwardDiff()`) instead of a symbol. Alternatively, other (custom) optimizers can be implemented, such as:\n\n```julia\nres = optimize(f, randn(StableRNG(42), nr_params(model)), GradientDescent(), Optim.Options(store_trace = true, show_trace = true, show_every = 50))\n``` \n\n- uses finitediff and is slower/less accurate.\n\n*or*\n\n```julia\n# create gradient function\ng = (x) -> ForwardDiff.gradient(f, x);\n\n# specify initial params\nparams = randn(nr_params(model))\n\n# create custom optimizer (here Adam)\noptimizer = Adam(params; λ=1e-1)\n\n# allocate space for gradient\n∇ = zeros(nr_params(model))\n\n# perform optimization\nfor it = 1:10000\n\n # backward pass\n ∇ .= ForwardDiff.gradient(f, optimizer.x)\n\n # gradient update\n ReactiveMP.update!(optimizer, ∇)\n\nend\n\n```"
 },
 {
 "cell_type": "code",
-"execution_count": 21,
+"execution_count": null,
 "metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Iter Function value Gradient norm \n",
-" 0 5.888958e+02 8.943663e+02\n",
-" * time: 0.02565789222717285\n",
-" 100 1.059823e+01 4.118858e+00\n",
-" * time: 6.649883985519409\n"
-]
-},
-{
-"data": {
-"text/plain": [
-" * Status: success\n",
-"\n",
-" * Candidate solution\n",
-" Final objective value: 9.904775e+00\n",
-"\n",
-" * Found with\n",
-" Algorithm: Gradient Descent\n",
-"\n",
-" * Convergence measures\n",
-" |x - x'| = 1.22e-03 ≰ 0.0e+00\n",
-" |x - x'|/|x'| = 5.79e-04 ≰ 0.0e+00\n",
-" |f(x) - f(x')| = 9.55e-03 ≰ 0.0e+00\n",
-" |f(x) - f(x')|/|f(x')| = 9.65e-04 ≤ 1.0e-03\n",
-" |g(x)| = 2.21e+00 ≰ 1.0e-08\n",
-"\n",
-" * Work counters\n",
-" Seconds run: 8 (vs limit Inf)\n",
-" Iterations: 116\n",
-" f(x) calls: 312\n",
-" ∇f(x) calls: 312\n"
-]
-},
-"execution_count": 21,
-"metadata": {},
-"output_type": "execute_result"
-}
-],
-"source": [
-"res = optimize(f, randn(StableRNG(42), nr_params(model)), GradientDescent(), Optim.Options(f_tol = 1e-3, store_trace = true, show_trace = true, show_every = 100), autodiff=:forward)"
-]
+"outputs": [],
+"source": "res = optimize(f, randn(StableRNG(42), nr_params(model)), GradientDescent(), Optim.Options(f_reltol = 1e-3, store_trace = true, show_trace = true, show_every = 100), autodiff=AutoForwardDiff())"
 },
 {
 "attachments": {},

examples/Problem Specific/Invertible Neural Network Tutorial/Project.toml

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
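
The call-site change in this commit (`autodiff=:forward` becomes `autodiff=AutoForwardDiff()`, and `f_tol` becomes `f_reltol`) can be tried in isolation with a minimal sketch like the one below. The objective `objective`, the vector `target`, and the starting point `x0` are hypothetical stand-ins for the tutorial's `f` and its random initial parameters. The keyword forms are copied from the diff above and assume an environment with Optim v2, ADTypes, and ForwardDiff installed.

```julia
using Optim        # optimisation routines
using ADTypes      # backend selector types such as AutoForwardDiff()
using ForwardDiff  # the backend that AutoForwardDiff() refers to

# Hypothetical toy objective (not part of the tutorial): a quadratic bowl
# whose minimum sits at `target`.
target = [1.0, 2.0]
objective(x) = sum(abs2, x .- target)

# Hypothetical starting point; the tutorial draws random initial parameters instead.
x0 = zeros(2)

# `f_reltol` is the relative tolerance on the objective value,
# the option this commit uses in place of `f_tol`.
opts = Optim.Options(f_reltol = 1e-8)

# The autodiff backend is passed as an ADTypes object rather than a symbol,
# mirroring the updated notebook cell.
res = optimize(objective, x0, GradientDescent(), opts; autodiff = AutoForwardDiff())

println(Optim.minimizer(res))  # estimated location of the minimum
println(Optim.minimum(res))    # objective value at that location
```

As the notebook text in the diff notes, omitting the `autodiff` keyword leaves Optim on finite differences, which is slower and less accurate than forward-mode automatic differentiation.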
