Skip to content

Commit 7b42e55

Browse files
authored
Merge pull request #33 from EarthyScience/dev
Dev
2 parents 67e06e0 + 9ca7846 commit 7b42e55

93 files changed

Lines changed: 4552 additions & 943 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "HybridVariationalInference"
22
uuid = "a108c475-a4e2-4021-9a84-cfa7df242f64"
33
authors = ["Thomas Wutzler <twutz@bgc-jena.mpg.de> and contributors"]
4-
version = "0.2"
4+
version = "0.2.0"
55

66
[deps]
77
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
@@ -31,6 +31,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3131
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3232
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
3333
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
34+
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3435
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3536

3637
[weakdeps]
@@ -69,14 +70,15 @@ MLUtils = "0.4.5"
6970
Missings = "1.2.0"
7071
NaNMath = "1.1.3"
7172
Optimisers = "0.4.6"
72-
Optimization = "3.19.3, 4"
73+
Optimization = "3.11, 4"
7374
Random = "1.10.0"
7475
SimpleChains = "0.4"
7576
StableRNGs = "1.0.2"
7677
StaticArrays = "1.9.13"
7778
StatsBase = "0.34.4"
7879
StatsFuns = "1.3.2"
7980
Test = "1.10"
81+
UnPack = "1.0.2"
8082
Zygote = "0.7.10"
8183
julia = "1.10"
8284

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ of the posterior. It returns a NamedTuple of
5555
- the machine learning model parameters (usually weights), $\phi_g$
5656
- means of the global parameters, $\phi_P = \mu_{\zeta_P}$ at transformed
5757
unconstrained scale
58-
- additional parameters, $\phi_{unc}$ of the posterior, $q(\zeta)$, such as
59-
coefficients that describe the scaling of variance with magnitude
60-
and coefficients that parameterize the choleski-factor or the correlation matrix.
58+
- additional parameters, $\phi_{ϕq}$ of the posterior, $q(\zeta)$, such as
59+
- coefficients that describe the scaling of variance with magnitude
60+
- coefficients that parameterize the choleski-factor or the correlation matrix
61+
- mean of global parameters at unconstrained scale
6162
- `θP`: predicted means of the global parameters, $\theta_P$
6263
- `resopt`: the original result object of the optimizer (useful for debugging)
6364

docs/src/tutorials/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
1515
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1616
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
1717
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
18+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1819
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
20+
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1921
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2022
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

docs/src/tutorials/basic_cpu.md

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using SimpleChains
1717
using StatsFuns
1818
using MLUtils
1919
using DistributionFits
20+
using UnPack
2021
```
2122

2223
Next, specify many moving parts of the Hybrid variational inference (HVI)
@@ -33,9 +34,7 @@ $$
3334
``` julia
3435
function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
3536
# extract parameters not depending on order, i.e whether they are in θP or θM
36-
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
37-
CA.getdata(θc[par])::ET
38-
end
37+
@unpack r0, r1, K1, K2 = θc
3938
r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
4039
end
4140
```
@@ -157,15 +156,16 @@ the problem below.
157156

158157
### Providing data in batches
159158

160-
HVI uses `MLUtils.DataLoader` to provide baches of the data during each
159+
HVI uses `MLUtils.DataLoader` to provide batches of the data during each
161160
iteration of the solver. In addition to the data, it provides an
162161
index to the sites inside a tuple.
163162

164163
``` julia
165164
n_site = size(y_o,2)
166165
n_batch = 20
167166
train_dataloader = MLUtils.DataLoader(
168-
(xM, xP, y_o, y_unc, 1:n_site), batchsize=n_batch, partial=false)
167+
(CA.getdata(xM), CA.getdata(xP), y_o, y_unc, 1:n_site),
168+
batchsize=n_batch, partial=false)
169169
```
170170

171171
## The Machine-Learning model
@@ -211,7 +211,7 @@ However, for simplicity, a [`NormalScalingModelApplicator`](@ref)
211211
is fitted to the transformed 5% and 95% quantiles of the original prior.
212212

213213
``` julia
214-
priorsM = [priors_dict[k] for k in keys(θM)]
214+
priorsM = Tuple(priors_dict[k] for k in keys(θM))
215215
lowers, uppers = get_quantile_transformed(priorsM, transM)
216216
g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
217217
```
@@ -231,8 +231,9 @@ invocation of the process based model (PBM), defined at the beginning.
231231

232232
``` julia
233233
f_batch = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
234+
ϕq0 = init_hybrid_ϕq(MeanHVIApproximation(), θP, θM, transP)
234235

235-
prob = HybridProblem(θP, θM, g_chain_scaled, ϕg0,
236+
prob = HybridProblem(θM, ϕq0, g_chain_scaled, ϕg0,
236237
f_batch, priors_dict, py,
237238
transM, transP, train_dataloader, n_covar, n_site, n_batch)
238239
```
@@ -267,7 +268,7 @@ Then the solver is applied to the problem using [`solve`](@ref)
267268
for a given number of iterations or epochs.
268269
For this tutorial, we additionally specify that the function to transfer structures to
269270
the GPU is the identity function, so that all stays on the CPU, and this tutorial
270-
hence does not require ad GPU or GPU livraries.
271+
hence does not require ad GPU or GPU libraries.
271272

272273
Among the return values are
273274
- `probo`: A copy of the HybridProblem, with updated optimized parameters
@@ -276,7 +277,7 @@ will help analyzing the results.
276277

277278
## Using a population-level process-based model
278279

279-
So far, the process-based model ram for each single site.
280+
So far, the process-based model ran for each single site.
280281
For this simple model, some performance grains result from matrix-computations
281282
when running the model for all sites within one batch simultaneously.
282283

@@ -289,29 +290,25 @@ one site. For the drivers and predictions, one column corresponds to one site.
289290
``` julia
290291
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
291292
# extract several covariates from xP
292-
ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
293-
S1 = (CA.getdata(xPc[:S1,:])::ST)
294-
S2 = (CA.getdata(xPc[:S2,:])::ST)
293+
S1 = view(xPc, Val(:S1), :)
294+
S2 = view(xPc, Val(:S2), :)
295295
#
296296
# extract the parameters as row-repeated vectors
297-
n_obs = size(S1, 1)
298-
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
299-
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
300-
p1 = CA.getdata(θc[:, par]) ::VT
301-
repeat(p1', n_obs) # matrix: same for each concentration row in S1
302-
end
303-
#
297+
# θc[:,:r0] is parameter r0 for each site in batch
298+
# dot-multiplication of full matrix times row-vector repeats for each observation row
299+
# also introduces zero for missing observations, leading to zero gradient there
300+
is_valid = isfinite.(S1) .&& isfinite.(S2)
301+
r0 = is_valid .* CA.getdata(θc[:, Val(:r0)])'
302+
r1 = is_valid .* CA.getdata(θc[:, Val(:r1)])'
303+
K1 = is_valid .* CA.getdata(θc[:, Val(:K1)])'
304+
K2 = is_valid .* CA.getdata(θc[:, Val(:K2)])'
304305
# each variable is a matrix (n_obs x n_site)
305306
r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
306307
end
307308
```
308309

309310
Again, the function should not rely on the order of parameters but use symbolic indexing
310-
to extract the parameter vectors. For type stability of this symbolic indexing,
311-
it uses a workaround to get the type of a single row.
312-
Similarly, it uses type hints to index into the drivers, `xPc`, to extract
313-
sub-matrices by symbols. Alternatively, here it could rely on the structure and
314-
ordering of the columns in `xPc`.
311+
to extract the parameter vectors.
315312

316313
A corresponding [`PBMPopulationApplicator`](@ref) transforms calls with
317314
partitioned global and site parameters to calls of this matrix version of the PBM.
@@ -323,11 +320,9 @@ probo_sites = HybridProblem(probo; f_batch)
323320
```
324321

325322
For numerical efficiency, the number of sites within one batch is part of the
326-
`PBMPopulationApplicator`. Hence, we have two different functions, one applied
327-
to a batch of site, and another applied to all sites.
328-
329-
As a test of the new applicator, the results are refined by running a few more
330-
epochs of the optimization.
323+
`PBMPopulationApplicator`. The problem stores an applicator for `n_batch` sites,
324+
however, an applicator for `n_site_pred` sites can be obtained by
325+
`create_nsite_applicator(f_batch, n_site_pred)`.
331326

332327
``` julia
333328
(; probo) = solve(probo_sites, solver; rng,
@@ -344,7 +339,7 @@ in the following [Inspect results of fitted problem](@ref) tutorial.
344339
In order to use the results from this tutorial in other tutorials,
345340
the updated `probo` `HybridProblem` and the interpreters are saved to a JLD2 file.
346341

347-
Before the problem is updated to use the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
342+
Before the problem is updated, so that it uses the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
348343
of the PBM in module `DoubleMM` rather than
349344
module `Main` to allow for easier reloading with JLD2.
350345

docs/src/tutorials/basic_cpu.qmd

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ using SimpleChains
2727
using StatsFuns
2828
using MLUtils
2929
using DistributionFits
30+
using UnPack
3031
```
3132

3233
Next, specify many moving parts of the Hybrid variational inference (HVI)
@@ -42,9 +43,7 @@ $$
4243
```{julia}
4344
function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
4445
# extract parameters not depending on order, i.e whether they are in θP or θM
45-
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
46-
CA.getdata(θc[par])::ET
47-
end
46+
@unpack r0, r1, K1, K2 = θc
4847
r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
4948
end
5049
```
@@ -166,15 +165,16 @@ the problem below.
166165

167166
### Providing data in batches
168167

169-
HVI uses `MLUtils.DataLoader` to provide baches of the data during each
168+
HVI uses `MLUtils.DataLoader` to provide batches of the data during each
170169
iteration of the solver. In addition to the data, it provides an
171170
index to the sites inside a tuple.
172171

173172
```{julia}
174173
n_site = size(y_o,2)
175174
n_batch = 20
176175
train_dataloader = MLUtils.DataLoader(
177-
(xM, xP, y_o, y_unc, 1:n_site), batchsize=n_batch, partial=false)
176+
(CA.getdata(xM), CA.getdata(xP), y_o, y_unc, 1:n_site),
177+
batchsize=n_batch, partial=false)
178178
```
179179

180180
## The Machine-Learning model
@@ -220,7 +220,7 @@ However, for simplicity, a [`NormalScalingModelApplicator`](@ref)
220220
is fitted to the transformed 5% and 95% quantiles of the original prior.
221221

222222
```{julia}
223-
priorsM = [priors_dict[k] for k in keys(θM)]
223+
priorsM = Tuple(priors_dict[k] for k in keys(θM))
224224
lowers, uppers = get_quantile_transformed(priorsM, transM)
225225
g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
226226
```
@@ -241,8 +241,9 @@ invocation of the process based model (PBM), defined at the beginning.
241241

242242
```{julia}
243243
f_batch = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
244+
ϕq0 = init_hybrid_ϕq(MeanHVIApproximation(), θP, θM, transP)
244245
245-
prob = HybridProblem(θP, θM, g_chain_scaled, ϕg0,
246+
prob = HybridProblem(θM, ϕq0, g_chain_scaled, ϕg0,
246247
f_batch, priors_dict, py,
247248
transM, transP, train_dataloader, n_covar, n_site, n_batch)
248249
```
@@ -302,7 +303,7 @@ Then the solver is applied to the problem using [`solve`](@ref)
302303
for a given number of iterations or epochs.
303304
For this tutorial, we additionally specify that the function to transfer structures to
304305
the GPU is the identity function, so that all stays on the CPU, and this tutorial
305-
hence does not require ad GPU or GPU livraries.
306+
hence does not require ad GPU or GPU libraries.
306307

307308
Among the return values are
308309
- `probo`: A copy of the HybridProblem, with updated optimized parameters
@@ -311,7 +312,7 @@ Among the return values are
311312

312313
## Using a population-level process-based model
313314

314-
So far, the process-based model ram for each single site.
315+
So far, the process-based model ran for each single site.
315316
For this simple model, some performance grains result from matrix-computations
316317
when running the model for all sites within one batch simultaneously.
317318

@@ -323,31 +324,28 @@ one site. For the drivers and predictions, one column corresponds to one site.
323324

324325

325326
```{julia}
327+
using StaticArrays
326328
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
327329
# extract several covariates from xP
328-
ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
329-
S1 = (CA.getdata(xPc[:S1,:])::ST)
330-
S2 = (CA.getdata(xPc[:S2,:])::ST)
330+
S1 = view(xPc, Val(:S1), :)
331+
S2 = view(xPc, Val(:S2), :)
331332
#
332333
# extract the parameters as row-repeated vectors
333-
n_obs = size(S1, 1)
334-
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
335-
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
336-
p1 = CA.getdata(θc[:, par]) ::VT
337-
repeat(p1', n_obs) # matrix: same for each concentration row in S1
338-
end
339-
#
334+
# θc[:,:r0] is parameter r0 for each site in batch
335+
# dot-multiplication of full matrix times row-vector repeats for each observation row
336+
# also introduces zero for missing observations, leading to zero gradient there
337+
is_valid = isfinite.(S1) .&& isfinite.(S2)
338+
r0 = is_valid .* CA.getdata(θc[:, Val(:r0)])'
339+
r1 = is_valid .* CA.getdata(θc[:, Val(:r1)])'
340+
K1 = is_valid .* CA.getdata(θc[:, Val(:K1)])'
341+
K2 = is_valid .* CA.getdata(θc[:, Val(:K2)])'
340342
# each variable is a matrix (n_obs x n_site)
341343
r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
342344
end
343345
```
344346

345347
Again, the function should not rely on the order of parameters but use symbolic indexing
346-
to extract the parameter vectors. For type stability of this symbolic indexing,
347-
it uses a workaround to get the type of a single row.
348-
Similarly, it uses type hints to index into the drivers, `xPc`, to extract
349-
sub-matrices by symbols. Alternatively, here it could rely on the structure and
350-
ordering of the columns in `xPc`.
348+
to extract the parameter vectors.
351349

352350
A corresponding [`PBMPopulationApplicator`](@ref) transforms calls with
353351
partitioned global and site parameters to calls of this matrix version of the PBM.
@@ -359,11 +357,9 @@ probo_sites = HybridProblem(probo; f_batch)
359357
```
360358

361359
For numerical efficiency, the number of sites within one batch is part of the
362-
`PBMPopulationApplicator`. Hence, we have two different functions, one applied
363-
to a batch of site, and another applied to all sites.
364-
365-
As a test of the new applicator, the results are refined by running a few more
366-
epochs of the optimization.
360+
`PBMPopulationApplicator`. The problem stores an applicator for `n_batch` sites,
361+
however, an applicator for `n_site_pred` sites can be obtained by
362+
`create_nsite_applicator(f_batch, n_site_pred)`.
367363

368364
```{julia}
369365
(; probo) = solve(probo_sites, solver; rng,
@@ -379,7 +375,7 @@ in the following [Inspect results of fitted problem](@ref) tutorial.
379375
In order to use the results from this tutorial in other tutorials,
380376
the updated `probo` `HybridProblem` and the interpreters are saved to a JLD2 file.
381377

382-
Before the problem is updated to use the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
378+
Before the problem is updated, so that it uses the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
383379
of the PBM in module `DoubleMM` rather than
384380
module `Main` to allow for easier reloading with JLD2.
385381

0 commit comments

Comments
 (0)