-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathusage.jl
More file actions
129 lines (118 loc) · 3.68 KB
/
usage.jl
File metadata and controls
129 lines (118 loc) · 3.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
## Enable Logging
using Logging, TerminalLoggers
# Make a TerminalLogger the process-wide logger so that records emitted through the
# Logging macros (e.g. `@info`) are rendered by it.
TerminalLogger() |> global_logger
## Data
using Distributions
# Target density that the flow is trained to reproduce.
data_dist = Beta(2.0, 4.0)
# Sample layout: `ndimensions` rows (variables) by `ndata` columns (samples).
ndimensions = 1
ndata = 1024
r = rand(data_dist, ndimensions, ndata)
## Parameters
# One flow variable per row of the data matrix.
nvariables = size(r, 1)
# Augmented (extra) dimensions carried alongside the variables.
naugments = nvariables + 1
# The network outputs one value per variable and per augmented dimension ...
n_out = nvariables + naugments
# ... and takes those same quantities plus one extra input: the concatenated time.
n_in = n_out + 1
# Hidden layer width: four times the input width.
n_hidden = 4 * n_in
## Model
# Build the continuous normalizing flow (`ICNF`) to be trained below.
# Among the packages loaded here: VCABM comes from OrdinaryDiffEqAdamsBashforthMoulton,
# QuadratureAdjoint/ZygoteVJP from SciMLSensitivity, cpu_device/gpu_device from
# MLDataDevices and AutoZygote from ADTypes; `Detailed()` presumably from SciMLLogging
# (TODO confirm).
using ContinuousNormalizingFlows,
Lux,
OrdinaryDiffEqAdamsBashforthMoulton,
SciMLLogging,
SciMLSensitivity,
ADTypes,
Zygote,
# ForwardDiff, # to use JVP
# LuxCUDA, # To use gpu
MLDataDevices
# The network is a 3-layer Lux MLP with
#   n_in  = nvariables + naugments + 1 inputs (variables, augmented dims, time) and
#   n_out = nvariables + naugments outputs,
# presumably the flow's vector field over the variables and augmented dims (TODO confirm).
icnf = ICNF(;
nn = Chain(
Dense(n_in => n_hidden, softplus),
Dense(n_hidden => n_hidden, softplus),
Dense(n_hidden => n_out),
),
nvariables = nvariables, # number of variables
naugments = naugments, # number of augmented dimensions
nconditions = 0, # number of conditioning inputs
λ₁ = 0.01, # regulate flow
λ₂ = 0.01, # regulate volume change
λ₃ = 0.01, # regulate augmented dimensions
steer_rate = 0.1, # add random noise to end of the time span
tspan = (0.0, 1.0), # time span
device = cpu_device(), # process data by CPU
# device = gpu_device(), # process data by GPU
autonomous = false, # using non-autonomous flow
inplace = false, # not using the inplace version of functions
compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use VJP via Zygote
# compute_mode = LuxJacVecMatrixMode(AutoForwardDiff()), # process data in batches and use JVP via ForwardDiff
# ODE solve settings:
# - VCABM integrator with reltol 1e-4 / abstol 1e-8;
# - no practical cap on solver iterations (maxiters = typemax(Int));
# - gradients via QuadratureAdjoint using Zygote VJPs, with the same tolerances.
sol_kwargs = (;
save_everystep = false,
maxiters = typemax(Int),
reltol = 1.0e-4,
abstol = 1.0e-8,
alg = VCABM(),
sensealg = QuadratureAdjoint(;
autodiff = true,
autojacvec = ZygoteVJP(),
reltol = 1.0e-4,
abstol = 1.0e-8,
),
progress = false,
verbose = Detailed(),
), # pass to the solver
)
## Fit It
using DataFrames, MLJBase, Zygote, ADTypes, OptimizationOptimisers
"""
    opt_callback(state, l; log_every::Integer = 64) -> Bool

Training callback: report the current loss `l` once every `log_every` iterations
and always return `false`, so training is never stopped early.

Reports fire on iterations 1, 1 + log_every, 1 + 2log_every, … (1, 65, 129, … by
default). They are emitted with `@info`, not `println`, so they go through the
installed global logger instead of being written straight to `stdout`. This lets
them be filtered by level and rendered consistently with the rest of the log output.

# Arguments
- `state`: optimizer state; only its `iter` field (current iteration number) is read.
- `l`: current loss value; it is interpolated into the log message.

# Keywords
- `log_every::Integer = 64`: reporting period in iterations; must be positive.

# Throws
- `ArgumentError` if `log_every` is not positive.
"""
function opt_callback(state, l; log_every::Integer = 64)
    log_every > 0 || throw(ArgumentError("log_every must be positive, got $log_every"))
    # For non-negative iteration counts, `mod(iter - 1, n) == 0` selects 1, 1 + n, …;
    # with n = 64 this matches `isone(iter % 64)`, and unlike it it also fires for n = 1.
    if iszero(mod(state.iter - 1, log_every))
        @info "Iteration: $(state.iter) | Loss: $l"
    end
    return false  # never request early termination
end
# Train the model once and cache the fitted MLJ machine on disk; when the cache file
# already exists, training is skipped entirely and the saved machine is reused.
# NOTE(review): the cache is keyed only on the file name, so after changing the data,
# the model or any hyperparameter, delete "icnf-machine.jls" to force retraining.
icnf_mach_fn = "icnf-machine.jls"
if !isfile(icnf_mach_fn)
# One row per sample: transpose the ndimensions × ndata matrix `r`; `:auto` names the
# columns x1, x2, ….
df = DataFrame(permutedims(r), :auto)
model = ICNFModel(;
icnf,
# A 1-tuple holding a single optimiser chain (note the trailing comma):
# weight decay, then gradient-norm clipping, then Adam.
optimizers = (
OptimiserChain(
WeightDecay(; lambda = 1.0e-4),
ClipNorm(10.0, 2.0; throw = true),
Adam(; eta = 0.001, beta = (0.9, 0.999), epsilon = 1.0e-8),
),
),
# Equals `ndata` (1024), so presumably each batch is the whole dataset (TODO confirm).
batchsize = 1024,
adtype = AutoZygote(),
# Training-loop settings: 300 epochs, loss reporting via `opt_callback`.
sol_kwargs = (;
epochs = 300,
callback = opt_callback,
progress = true,
verbose = Detailed(),
), # pass to the solver
)
mach = machine(model, df)
fit!(mach)
# CUDA.@allowscalar fit!(mach) # needed for gpu
MLJBase.save(icnf_mach_fn, mach) # save it
end
# Always reload from disk, so the run that trained and every later run use the same
# serialized machine.
mach = machine(icnf_mach_fn) # load it
## Use It
# Expose the trained machine as a distribution, evaluated in test mode.
d = ICNFDist(mach, TestMode())
# Ground truth: the true Beta density at every sample.
actual_pdf = map(Base.Fix1(pdf, data_dist), vec(r))
# The learned density at the same samples (compared element-wise with `actual_pdf` below).
estimated_pdf = pdf(d, r)
# Draw `ndata` new samples from the learned flow.
new_data = rand(d, ndata)
## Evaluate It
using Distances
# Point-wise discrepancies between the learned and the true densities at the samples.
mad_ = meanad(estimated_pdf, actual_pdf)  # mean absolute deviation
msd_ = msd(estimated_pdf, actual_pdf)     # mean squared deviation
# Total-variation discrepancy, averaged over the `ndata` evaluation points.
tv_dis = totalvariation(estimated_pdf, actual_pdf) / ndata
# Single-row summary table, one column per metric.
res_df = DataFrame(mad_ = mad_, msd_ = msd_, tv_dis = tv_dis)
display(res_df)
## Plot It
using CairoMakie
f = Figure()
ax = Axis(f[1, 1]; title = "Result")
# Overlay the true density and the learned one over [0, 1].
lines!(ax, 0.0 .. 1.0, x -> pdf(data_dist, x); label = "Actual")
# The learned distribution is queried with a one-element vector, not a scalar.
lines!(ax, 0.0 .. 1.0, x -> pdf(d, [x]); label = "Estimated")
axislegend(ax)
# Write the same figure as SVG first, then PNG.
foreach(path -> save(path, f), ("result-figure.svg", "result-figure.png"))