|
43 | 43 | }, |
44 | 44 | { |
45 | 45 | "cell_type": "code", |
46 | | - "execution_count": 1, |
| 46 | + "execution_count": null, |
47 | 47 | "metadata": {}, |
48 | 48 | "outputs": [], |
49 | | - "source": [ |
50 | | - "using RxInfer\n", |
51 | | - "using Random\n", |
52 | | - "using StableRNGs\n", |
53 | | - "\n", |
54 | | - "using ReactiveMP # ReactiveMP is included in RxInfer, but we explicitly use some of its functionality\n", |
55 | | - "using LinearAlgebra # only used for some matrix specifics\n", |
56 | | - "using Plots # only used for visualisation\n", |
57 | | - "using Distributions # only used for sampling from multivariate distributions\n", |
58 | | - "using Optim # only used for parameter optimisation" |
59 | | - ] |
| 49 | + "source": "using RxInfer\nusing Random\nusing StableRNGs\n\nusing ReactiveMP # ReactiveMP is included in RxInfer, but we explicitly use some of its functionality\nusing LinearAlgebra # only used for some matrix specifics\nusing Plots # only used for visualisation\nusing Distributions # only used for sampling from multivariate distributions\nusing Optim # only used for parameter optimisation\nusing ADTypes # only used to specify the automatic differentiation backend for Optim\nusing ForwardDiff # only used for automatic differentiation in the parameter optimisation" |
60 | 50 | }, |
61 | 51 | { |
62 | 52 | "attachments": {}, |
|
9351 | 9341 | "attachments": {}, |
9352 | 9342 | "cell_type": "markdown", |
9353 | 9343 | "metadata": {}, |
9354 | | - "source": [ |
9355 | | - "Optimization can be performed using the `Optim` package. Alternatively, other (custom) optimizers can be implemented, such as:\n", |
9356 | | - "\n", |
9357 | | - "```julia\n", |
9358 | | - "res = optimize(f, randn(StableRNG(42), nr_params(model)), GradientDescent(), Optim.Options(store_trace = true, show_trace = true, show_every = 50), autodiff=:forward)\n", |
9359 | | - "``` \n", |
9360 | | - "\n", |
9361 | | - "- uses finitediff and is slower/less accurate.\n", |
9362 | | - "\n", |
9363 | | - "*or*\n", |
9364 | | - "\n", |
9365 | | - "```julia\n", |
9366 | | - "# create gradient function\n", |
9367 | | - "g = (x) -> ForwardDiff.gradient(f, x);\n", |
9368 | | - "\n", |
9369 | | - "# specify initial params\n", |
9370 | | - "params = randn(nr_params(model))\n", |
9371 | | - "\n", |
9372 | | - "# create custom optimizer (here Adam)\n", |
9373 | | - "optimizer = Adam(params; λ=1e-1)\n", |
9374 | | - "\n", |
9375 | | - "# allocate space for gradient\n", |
9376 | | - "∇ = zeros(nr_params(model))\n", |
9377 | | - "\n", |
9378 | | - "# perform optimization\n", |
9379 | | - "for it = 1:10000\n", |
9380 | | - "\n", |
9381 | | - " # backward pass\n", |
9382 | | - " ∇ .= ForwardDiff.gradient(f, optimizer.x)\n", |
9383 | | - "\n", |
9384 | | - " # gradient update\n", |
9385 | | - " ReactiveMP.update!(optimizer, ∇)\n", |
9386 | | - "\n", |
9387 | | - "end\n", |
9388 | | - "\n", |
9389 | | - "```" |
9390 | | - ] |
| 9344 | + "source": "Optimization can be performed using the `Optim` package. Since Optim v2, the automatic differentiation backend is specified with an `ADTypes` object (e.g. `autodiff = AutoForwardDiff()`) instead of a symbol. Alternatively, `Optim` can be called without an `autodiff` argument, or other (custom) optimizers can be implemented, such as:\n\n```julia\nres = optimize(f, randn(StableRNG(42), nr_params(model)), GradientDescent(), Optim.Options(store_trace = true, show_trace = true, show_every = 50))\n``` \n\n- because no `autodiff` backend is given, this call uses finitediff (finite differences) and is slower/less accurate.\n\n*or*\n\n```julia\n# create gradient function\ng = (x) -> ForwardDiff.gradient(f, x);\n\n# specify initial params\nparams = randn(nr_params(model))\n\n# create custom optimizer (here Adam)\noptimizer = Adam(params; λ=1e-1)\n\n# allocate space for gradient\n∇ = zeros(nr_params(model))\n\n# perform optimization\nfor it = 1:10000\n\n # backward pass\n ∇ .= ForwardDiff.gradient(f, optimizer.x)\n\n # gradient update\n ReactiveMP.update!(optimizer, ∇)\n\nend\n\n```" |
9391 | 9345 | }, |
9392 | 9346 | { |
9393 | 9347 | "cell_type": "code", |
9394 | | - "execution_count": 21, |
| 9348 | + "execution_count": null, |
9395 | 9349 | "metadata": {}, |
9396 | | - "outputs": [ |
9397 | | - { |
9398 | | - "name": "stdout", |
9399 | | - "output_type": "stream", |
9400 | | - "text": [ |
9401 | | - "Iter Function value Gradient norm \n", |
9402 | | - " 0 5.888958e+02 8.943663e+02\n", |
9403 | | - " * time: 0.02565789222717285\n", |
9404 | | - " 100 1.059823e+01 4.118858e+00\n", |
9405 | | - " * time: 6.649883985519409\n" |
9406 | | - ] |
9407 | | - }, |
9408 | | - { |
9409 | | - "data": { |
9410 | | - "text/plain": [ |
9411 | | - " * Status: success\n", |
9412 | | - "\n", |
9413 | | - " * Candidate solution\n", |
9414 | | - " Final objective value: 9.904775e+00\n", |
9415 | | - "\n", |
9416 | | - " * Found with\n", |
9417 | | - " Algorithm: Gradient Descent\n", |
9418 | | - "\n", |
9419 | | - " * Convergence measures\n", |
9420 | | - " |x - x'| = 1.22e-03 ≰ 0.0e+00\n", |
9421 | | - " |x - x'|/|x'| = 5.79e-04 ≰ 0.0e+00\n", |
9422 | | - " |f(x) - f(x')| = 9.55e-03 ≰ 0.0e+00\n", |
9423 | | - " |f(x) - f(x')|/|f(x')| = 9.65e-04 ≤ 1.0e-03\n", |
9424 | | - " |g(x)| = 2.21e+00 ≰ 1.0e-08\n", |
9425 | | - "\n", |
9426 | | - " * Work counters\n", |
9427 | | - " Seconds run: 8 (vs limit Inf)\n", |
9428 | | - " Iterations: 116\n", |
9429 | | - " f(x) calls: 312\n", |
9430 | | - " ∇f(x) calls: 312\n" |
9431 | | - ] |
9432 | | - }, |
9433 | | - "execution_count": 21, |
9434 | | - "metadata": {}, |
9435 | | - "output_type": "execute_result" |
9436 | | - } |
9437 | | - ], |
9438 | | - "source": [ |
9439 | | - "res = optimize(f, randn(StableRNG(42), nr_params(model)), GradientDescent(), Optim.Options(f_tol = 1e-3, store_trace = true, show_trace = true, show_every = 100), autodiff=:forward)" |
9440 | | - ] |
| 9350 | + "outputs": [], |
| 9351 | + "source": "res = optimize(f, randn(StableRNG(42), nr_params(model)), GradientDescent(), Optim.Options(f_reltol = 1e-3, store_trace = true, show_trace = true, show_every = 100), autodiff=AutoForwardDiff())" |
9441 | 9352 | }, |
9442 | 9353 | { |
9443 | 9354 | "attachments": {}, |
|
0 commit comments