@@ -17,6 +17,7 @@ using SimpleChains
1717using StatsFuns
1818using MLUtils
1919using DistributionFits
20+ using UnPack
2021```
2122
2223Next, specify many moving parts of the Hybrid variational inference (HVI)
3334``` julia
3435function f_doubleMM (θc:: CA.ComponentVector{ET} , x) where ET
3536 # extract parameters not depending on order, i.e whether they are in θP or θM
36- (r0, r1, K1, K2) = map ((:r0 , :r1 , :K1 , :K2 )) do par
37- CA. getdata (θc[par]):: ET
38- end
37+ @unpack r0, r1, K1, K2 = θc
3938 r0 .+ r1 .* x. S1 ./ (K1 .+ x. S1) .* x. S2 ./ (K2 .+ x. S2)
4039end
4140```
@@ -157,15 +156,16 @@ the problem below.
157156
158157### Providing data in batches
159158
160- HVI uses ` MLUtils.DataLoader ` to provide baches of the data during each
159+ HVI uses ` MLUtils.DataLoader ` to provide batches of the data during each
161160iteration of the solver. In addition to the data, it provides an
162161index to the sites inside a tuple.
163162
164163``` julia
165164n_site = size (y_o,2 )
166165n_batch = 20
167166train_dataloader = MLUtils. DataLoader (
168- (xM, xP, y_o, y_unc, 1 : n_site), batchsize= n_batch, partial= false )
167+ (CA. getdata (xM), CA. getdata (xP), y_o, y_unc, 1 : n_site),
168+ batchsize= n_batch, partial= false )
169169```
170170
171171## The Machine-Learning model
@@ -211,7 +211,7 @@ However, for simplicity, a [`NormalScalingModelApplicator`](@ref)
211211is fitted to the transformed 5% and 95% quantiles of the original prior.
212212
213213``` julia
214- priorsM = [ priors_dict[k] for k in keys (θM)]
214+ priorsM = Tuple ( priors_dict[k] for k in keys (θM))
215215lowers, uppers = get_quantile_transformed (priorsM, transM)
216216g_chain_scaled = NormalScalingModelApplicator (g_chain_app, lowers, uppers, FT)
217217```
@@ -231,8 +231,9 @@ invocation of the process based model (PBM), defined at the beginning.
231231
232232``` julia
233233f_batch = PBMSiteApplicator (f_doubleMM; θP, θM, θFix, xPvec= xP[:,1 ])
234+ ϕq0 = init_hybrid_ϕq (MeanHVIApproximation (), θP, θM, transP)
234235
235- prob = HybridProblem (θP, θM , g_chain_scaled, ϕg0,
236+ prob = HybridProblem (θM, ϕq0 , g_chain_scaled, ϕg0,
236237 f_batch, priors_dict, py,
237238 transM, transP, train_dataloader, n_covar, n_site, n_batch)
238239```
@@ -267,7 +268,7 @@ Then the solver is applied to the problem using [`solve`](@ref)
267268for a given number of iterations or epochs.
268269For this tutorial, we additionally specify that the function to transfer structures to
269270the GPU is the identity function, so that all stays on the CPU, and this tutorial
270- hence does not require ad GPU or GPU livraries .
271+ hence does not require ad GPU or GPU libraries .
271272
272273Among the return values are
273274- ` probo ` : A copy of the HybridProblem, with updated optimized parameters
@@ -276,7 +277,7 @@ will help analyzing the results.
276277
277278## Using a population-level process-based model
278279
279- So far, the process-based model ram for each single site.
280+ So far, the process-based model ran for each single site.
280281For this simple model, some performance grains result from matrix-computations
281282when running the model for all sites within one batch simultaneously.
282283
@@ -289,29 +290,25 @@ one site. For the drivers and predictions, one column corresponds to one site.
289290``` julia
290291function f_doubleMM_sites (θc:: CA.ComponentMatrix , xPc:: CA.ComponentMatrix )
291292 # extract several covariates from xP
292- ST = typeof (CA. getdata (xPc)[1 : 1 ,:]) # workaround for non-type-stable Symbol-indexing
293- S1 = (CA. getdata (xPc[:S1 ,:]):: ST )
294- S2 = (CA. getdata (xPc[:S2 ,:]):: ST )
293+ S1 = view (xPc, Val (:S1 ), :)
294+ S2 = view (xPc, Val (:S2 ), :)
295295 #
296296 # extract the parameters as row-repeated vectors
297- n_obs = size (S1, 1 )
298- VT = typeof (CA. getdata (θc)[:,1 ]) # workaround for non-type-stable Symbol-indexing
299- (r0, r1, K1, K2) = map ((:r0 , :r1 , :K1 , :K2 )) do par
300- p1 = CA. getdata (θc[:, par]) :: VT
301- repeat (p1' , n_obs) # matrix: same for each concentration row in S1
302- end
303- #
297+ # θc[:,:r0] is parameter r0 for each site in batch
298+ # dot-multiplication of full matrix times row-vector repeats for each observation row
299+ # also introduces zero for missing observations, leading to zero gradient there
300+ is_valid = isfinite .(S1) .&& isfinite .(S2)
301+ r0 = is_valid .* CA. getdata (θc[:, Val (:r0 )])'
302+ r1 = is_valid .* CA. getdata (θc[:, Val (:r1 )])'
303+ K1 = is_valid .* CA. getdata (θc[:, Val (:K1 )])'
304+ K2 = is_valid .* CA. getdata (θc[:, Val (:K2 )])'
304305 # each variable is a matrix (n_obs x n_site)
305306 r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
306307end
307308```
308309
309310Again, the function should not rely on the order of parameters but use symbolic indexing
310- to extract the parameter vectors. For type stability of this symbolic indexing,
311- it uses a workaround to get the type of a single row.
312- Similarly, it uses type hints to index into the drivers, ` xPc ` , to extract
313- sub-matrices by symbols. Alternatively, here it could rely on the structure and
314- ordering of the columns in ` xPc ` .
311+ to extract the parameter vectors.
315312
316313A corresponding [ ` PBMPopulationApplicator ` ] ( @ref ) transforms calls with
317314partitioned global and site parameters to calls of this matrix version of the PBM.
@@ -323,11 +320,9 @@ probo_sites = HybridProblem(probo; f_batch)
323320```
324321
325322For numerical efficiency, the number of sites within one batch is part of the
326- ` PBMPopulationApplicator ` . Hence, we have two different functions, one applied
327- to a batch of site, and another applied to all sites.
328-
329- As a test of the new applicator, the results are refined by running a few more
330- epochs of the optimization.
323+ ` PBMPopulationApplicator ` . The problem stores an applicator for ` n_batch ` sites,
324+ however, an applicator for ` n_site_pred ` sites can be obtained by
325+ ` create_nsite_applicator(f_batch, n_site_pred) ` .
331326
332327``` julia
333328(; probo) = solve (probo_sites, solver; rng,
@@ -344,7 +339,7 @@ in the following [Inspect results of fitted problem](@ref) tutorial.
344339In order to use the results from this tutorial in other tutorials,
345340the updated ` probo ` ` HybridProblem ` and the interpreters are saved to a JLD2 file.
346341
347- Before the problem is updated to use the redefinition [ ` DoubleMM.f_doubleMM_sites ` ] ( @ref )
342+ Before the problem is updated, so that it uses the redefinition [ ` DoubleMM.f_doubleMM_sites ` ] ( @ref )
348343of the PBM in module ` DoubleMM ` rather than
349344module ` Main ` to allow for easier reloading with JLD2.
350345
0 commit comments