diff --git a/Project.toml b/Project.toml index f9f92bc35..b37e16e4e 100644 --- a/Project.toml +++ b/Project.toml @@ -44,6 +44,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" [extensions] TuringDynamicHMCExt = "DynamicHMC" +TuringMCMCChainsExt = "MCMCChains" [compat] ADTypes = "1.9" diff --git a/ext/TuringMCMCChainsExt.jl b/ext/TuringMCMCChainsExt.jl index 446b70f6f..18416c433 100644 --- a/ext/TuringMCMCChainsExt.jl +++ b/ext/TuringMCMCChainsExt.jl @@ -1,10 +1,12 @@ module TuringMCMCChainsExt using Turing -using Turing: AbstractMCMC -using Turing.Inference: HMC, NUTS, HMCDA, Emcee, EmceeState +using Turing: AbstractMCMC, DynamicPPL +using Turing.Inference: HMC, NUTS, HMCDA, Emcee, EmceeState, _get_n_walkers using MCMCChains: MCMCChains +import Turing.Inference: post_sample_hook + """ loadstate(chain::MCMCChains.Chains) diff --git a/test/Project.toml b/test/Project.toml index d5dd07966..8a7cbb82e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -65,6 +65,7 @@ Libtask = "0.9.14" LinearAlgebra = "1" LogDensityProblems = "2" LogDensityProblemsAD = "1.4" +MCMCChains = "7" Mooncake = "0.4.182, 0.5" Optimization = "3, 4, 5" OptimizationBBO = "0.1, 0.2, 0.3, 0.4" diff --git a/test/mcmc/chains.jl b/test/mcmc/chains.jl index e9fb0bdf9..7189448d4 100644 --- a/test/mcmc/chains.jl +++ b/test/mcmc/chains.jl @@ -36,6 +36,10 @@ function AbstractMCMC.step( return DynamicPPL.ParamsWithStats(vnt, (;)), vnt end +function mcmc_values(chain::MCMCChains.Chains, name::Symbol) + return reshape(Array(chain[name]), size(chain, 1), size(chain, 3)) +end + @testset verbose = true "chains.jl" begin @testset "basic sampling" begin @model function demomodel(x) @@ -363,6 +367,67 @@ end end end end + + @testset "MCMCChains metadata" begin + @testset "save_state and initial_state" begin + @model f() = x ~ Normal() + model = f() + + @testset "single chain" begin + chn1 = sample( + model, + StaticSampler(), + 10; + chain_type=MCMCChains.Chains, + verbose=false, + save_state=true, + ) + state = loadstate(chn1) + @test state isa DynamicPPL.VarNamedTuple + + chn2 = sample( + model, + StaticSampler(), + 10; + chain_type=MCMCChains.Chains, + verbose=false, + initial_state=state, + ) + xval = mcmc_values(chn1, :x)[end, 1] + @test all(==(xval), mcmc_values(chn2, :x)) + end + + @testset "multiple chain" begin + chn1 = sample( + model, + StaticSampler(), + MCMCThreads(), + 10, + 3; + chain_type=MCMCChains.Chains, + verbose=false, + save_state=true, + ) + states = loadstate(chn1) + @test states isa AbstractVector{<:DynamicPPL.VarNamedTuple} + @test length(states) == 3 + + chn2 = sample( + model, + StaticSampler(), + MCMCThreads(), + 10, + 3; + chain_type=MCMCChains.Chains, + verbose=false, + initial_state=states, + ) + xval = mcmc_values(chn1, :x)[end, :] + samples = mcmc_values(chn2, :x) + @test all(i -> samples[i, :] == xval, axes(samples, 1)) + end + end + end end @testset "underlying data is same regardless of backend" begin