|
| 1 | +module TuringMCMCChainsExt |
| 2 | + |
| 3 | +using Turing |
| 4 | +using Turing: AbstractMCMC |
| 5 | +using Turing.Inference: HMC, NUTS, HMCDA, Emcee, EmceeState |
| 6 | +using MCMCChains: MCMCChains |
| 7 | + |
| 8 | +""" |
| 9 | + loadstate(chain::MCMCChains.Chains) |
| 10 | +
|
| 11 | +Load the final state of the sampler from a `MCMCChains.Chains` object. |
| 12 | +
|
| 13 | +To save the final state of the sampler, you must use `sample(...; save_state=true)`. If this |
| 14 | +argument was not used during sampling, calling `loadstate` will throw an error. |
| 15 | +""" |
| 16 | +function Turing.Inference.loadstate(chain::MCMCChains.Chains) |
| 17 | + if !haskey(chain.info, :samplerstate) |
| 18 | + throw( |
| 19 | + ArgumentError( |
| 20 | + "the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`", |
| 21 | + ), |
| 22 | + ) |
| 23 | + end |
| 24 | + return chain.info[:samplerstate] |
| 25 | +end |
| 26 | + |
| 27 | +function AbstractMCMC.bundle_samples( |
| 28 | + samples::Vector{<:Vector}, |
| 29 | + model::DynamicPPL.Model, |
| 30 | + spl::Emcee, |
| 31 | + state::EmceeState, |
| 32 | + ::Type{MCMCChains.Chains}, |
| 33 | + kwargs..., |
| 34 | +) |
| 35 | + n_walkers = _get_n_walkers(spl) |
| 36 | + chains = map(1:n_walkers) do i |
| 37 | + this_walker_samples = [s[i] for s in samples] |
| 38 | + AbstractMCMC.bundle_samples( |
| 39 | + this_walker_samples, model, spl, state, MCMCChains.Chains; kwargs... |
| 40 | + ) |
| 41 | + end |
| 42 | + return AbstractMCMC.chainscat(chains...) |
| 43 | +end |
| 44 | + |
| 45 | +""" |
| 46 | + post_sample_hook(chain::MCMCChains.Chains, sampler::Union{HMC,NUTS,HMCDA}; kwargs...) |
| 47 | +
|
| 48 | +Emit a warning message if there are divergent transitions in the chain. |
| 49 | +""" |
| 50 | +function post_sample_hook( |
| 51 | + chain::MCMCChains.Chains, ::Union{HMC,NUTS,HMCDA}; verbose::Bool=true, kwargs... |
| 52 | +) |
| 53 | + n_divergent = round(Int, sum(skipmissing(vec(chain[:numerical_error])))) |
| 54 | + if verbose && n_divergent > 0 |
| 55 | + @warn "There were $n_divergent divergent transitions. Consider reparameterising your model or using a smaller step size. For adaptive samplers such as NUTS and HMCDA, consider increasing `target_accept`." |
| 56 | + end |
| 57 | + return nothing |
| 58 | +end |
| 59 | + |
| 60 | +end |
0 commit comments