
Commit 18eedab

Make FlexiChains the default chain type (#2743)
Mostly interested in seeing what breaks, I ran some tests offline then got bored.

As far as source code changes go, I only needed to change the definition of `DEFAULT_CHAIN_TYPE` and twiddle with some imports. So in principle this is pretty much done.

Most changes were really in the test suite, using `chn[@varname(x)]` instead of `chn[:x]`. In fact, that [isn't even necessary](https://pysm.dev/FlexiChains.jl/stable/turing/#Indexing-by-Symbol:-a-shortcut), but I just consider it good practice. The only changes that I found to be mandatory were changing things like `chn[Symbol("m[1]")]` to `chn[@varname(m[1])]`, which is a Positive Change, and also other things that relied too heavily on the exact internal structure of MCMCChains, which is also a Positive Change, or a Neutral Change in the cases where I had to rely on the internal structure of FlexiChains.

This PR should definitely target breaking, but that would cause version conflicts with FlexiChains, since FlexiChains has a compat entry that pins Turing to its current minor version. That can be fixed by moving FlexiChainsTuringExt into Turing proper, so that that compat entry can be removed.
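The indexing change described above can be sketched roughly as follows. This is a minimal illustration, assuming Turing 0.45 with FlexiChains as the default chain type. The model `demo` and its variables `x` and `m` are hypothetical names chosen for the example, not taken from the test suite.

```julia
using Turing

# Hypothetical toy model: a scalar `x` and a length-2 vector `m`.
@model function demo()
    x ~ Normal()
    m ~ filldist(Normal(), 2)
end

# `Prior()` draws from the prior, so no gradients are needed for this sketch.
chn = sample(demo(), Prior(), 50; progress=false)

# Before (MCMCChains-style)       After (FlexiChains)
#   chn[:x]                         chn[@varname(x)]
#   chn[Symbol("m[1]")]             chn[@varname(m[1])]
xs = chn[@varname(x)]
m1 = chn[@varname(m[1])]
```

Per the linked FlexiChains documentation, `chn[:x]` should still work as a shortcut; only the `Symbol("m[1]")` form had to change.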
1 parent 6df73f8 commit 18eedab

33 files changed

Lines changed: 796 additions & 317 deletions

HISTORY.md

Lines changed: 10 additions & 0 deletions
@@ -1,3 +1,13 @@
+# 0.45.0
+
+## Breaking changes
+
+Make FlexiChains the default chain type for MCMC sampling.
+
+MCMCChains is still fully supported: you can specify `chain_type=MCMCChains.Chains` in the `sample` function to use it instead.
+However, it is no longer loaded as a dependency of Turing and re-exported (it is now an extension).
+That means that if you were previously importing MCMCChains via Turing, you will now have to import it directly.
+
 # 0.44.5
 
 Allow users to disable the post-sample hook by passing `verbose=false` keyword argument to `sample`.
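The 0.45.0 changelog entry above implies a small migration for code that relied on MCMCChains being re-exported. A minimal sketch, assuming MCMCChains is installed in the active environment; the model `demo` is a hypothetical placeholder:

```julia
using Turing
using MCMCChains  # no longer re-exported by Turing, so it is loaded directly

# Hypothetical one-variable model, used only for illustration.
@model function demo()
    x ~ Normal()
end

# Opt back in to the previous chain type for this call only.
chn = sample(demo(), Prior(), 50; chain_type=MCMCChains.Chains, progress=false)

# Symbol indexing on an MCMCChains.Chains object is unchanged.
xs = chn[:x]
```

Loading MCMCChains is also what activates the `TuringMCMCChainsExt` extension added in this commit.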

Project.toml

Lines changed: 4 additions & 2 deletions
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.44.5"
+version = "0.45.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -20,11 +20,11 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
 EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
+FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
@@ -40,6 +40,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 
 [weakdeps]
 DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
+MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 
 [extensions]
 TuringDynamicHMCExt = "DynamicHMC"
@@ -63,6 +64,7 @@ DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
 DynamicPPL = "0.41.3"
 EllipticalSliceSampling = "0.5, 1, 2"
+FlexiChains = "0.6"
 ForwardDiff = "0.10.3, 1"
 Libtask = "0.9.14"
 LinearAlgebra = "1"

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ links = InterLinks(
     "AbstractMCMC" => "https://turinglang.org/AbstractMCMC.jl/stable/",
     "ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/",
     "AdvancedVI" => "https://turinglang.org/AdvancedVI.jl/stable/",
+    "FlexiChains" => "https://pysm.dev/FlexiChains.jl/stable/",
     "OrderedCollections" => "https://juliacollections.github.io/OrderedCollections.jl/stable/",
     "Distributions" => "https://juliastats.org/Distributions.jl/stable/",
 )

docs/src/api.md

Lines changed: 10 additions & 13 deletions
@@ -2,12 +2,8 @@
 
 ## Module-wide re-exports
 
-Turing.jl directly re-exports the entire public API of the following packages:
-
-- [Distributions.jl](https://juliastats.org/Distributions.jl)
-- [MCMCChains.jl](https://turinglang.org/MCMCChains.jl)
-
-Please see the individual packages for their documentation.
+Turing.jl directly re-exports the entire public API of [Distributions.jl](https://juliastats.org/Distributions.jl).
+Please see its documentation for more details.
 
 ## Individual exports and re-exports
 
@@ -50,13 +46,14 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
 
 ### Inference
 
-| Exported symbol | Documentation | Description |
-|:----------------- |:------------------------------------------------------------------------- |:----------------------------------------- |
-| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model |
-| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads |
-| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes |
-| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism |
-| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved state from `MCMCChains.Chains` |
+| Exported symbol | Documentation | Description |
+|:----------------- |:------------------------------------------------------------------------- |:----------------------------------- |
+| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model |
+| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads |
+| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes |
+| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism |
+| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved state from an MCMC chain |
+| `VNChain` | n/a | Alias for `FlexiChain{VarName}` |
 
 ### Samplers
 

ext/TuringMCMCChainsExt.jl

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+module TuringMCMCChainsExt
+
+using Turing
+using Turing: AbstractMCMC
+using Turing.Inference: HMC, NUTS, HMCDA, Emcee, EmceeState
+using MCMCChains: MCMCChains
+
+"""
+    loadstate(chain::MCMCChains.Chains)
+
+Load the final state of the sampler from a `MCMCChains.Chains` object.
+
+To save the final state of the sampler, you must use `sample(...; save_state=true)`. If this
+argument was not used during sampling, calling `loadstate` will throw an error.
+"""
+function Turing.Inference.loadstate(chain::MCMCChains.Chains)
+    if !haskey(chain.info, :samplerstate)
+        throw(
+            ArgumentError(
+                "the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`",
+            ),
+        )
+    end
+    return chain.info[:samplerstate]
+end
+
+function AbstractMCMC.bundle_samples(
+    samples::Vector{<:Vector},
+    model::DynamicPPL.Model,
+    spl::Emcee,
+    state::EmceeState,
+    ::Type{MCMCChains.Chains},
+    kwargs...,
+)
+    n_walkers = _get_n_walkers(spl)
+    chains = map(1:n_walkers) do i
+        this_walker_samples = [s[i] for s in samples]
+        AbstractMCMC.bundle_samples(
+            this_walker_samples, model, spl, state, MCMCChains.Chains; kwargs...
+        )
+    end
+    return AbstractMCMC.chainscat(chains...)
+end
+
+"""
+    post_sample_hook(chain::MCMCChains.Chains, sampler::Union{HMC,NUTS,HMCDA}; kwargs...)
+
+Emit a warning message if there are divergent transitions in the chain.
+"""
+function post_sample_hook(
+    chain::MCMCChains.Chains, ::Union{HMC,NUTS,HMCDA}; verbose::Bool=true, kwargs...
+)
+    n_divergent = round(Int, sum(skipmissing(vec(chain[:numerical_error]))))
+    if verbose && n_divergent > 0
+        @warn "There were $n_divergent divergent transitions. Consider reparameterising your model or using a smaller step size. For adaptive samplers such as NUTS and HMCDA, consider increasing `target_accept`."
+    end
+    return nothing
+end
+
+end

src/Turing.jl

Lines changed: 5 additions & 2 deletions
@@ -4,7 +4,7 @@ using Reexport, ForwardDiff
 using Bijectors, StatsFuns, SpecialFunctions
 using Statistics, LinearAlgebra
 using Libtask
-@reexport using Distributions, MCMCChains
+@reexport using Distributions
 using Compat: pkgversion
 
 using AdvancedVI: AdvancedVI
@@ -14,6 +14,7 @@ using LogDensityProblems: LogDensityProblems
 using StatsAPI: StatsAPI
 using StatsBase: StatsBase
 using AbstractMCMC
+using FlexiChains
 
 using Printf: Printf
 using Random: Random
@@ -189,6 +190,8 @@ export
     loadstate,
     # kwargs in SMC
     might_produce,
-    @might_produce
+    @might_produce,
+    # FlexiChains re-export
+    VNChain
 
 end

src/mcmc/Inference.jl

Lines changed: 14 additions & 2 deletions
@@ -16,6 +16,7 @@ using DynamicPPL:
     Model,
     DefaultContext
 using Distributions, Libtask, Bijectors
+using FlexiChains: FlexiChains, VNChain
 using LinearAlgebra
 using ..Turing: PROGRESS, Turing
 using StatsFuns: logsumexp
@@ -35,7 +36,6 @@ import AdvancedPS
 import EllipticalSliceSampling
 import LogDensityProblems
 import Random
-import MCMCChains
 import StatsBase: predict
 
 export Hamiltonian,
@@ -62,7 +62,19 @@ export Hamiltonian,
     init_strategy,
     loadstate
 
-const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
+const DEFAULT_CHAIN_TYPE = VNChain
+
+"""
+    Turing.loadstate(chain::FlexiChain{<:VarName})
+
+Extracts the last sampler state from a `FlexiChain`. This is the same function as
+[`FlexiChains.last_sampler_state`](@ref).
+
+$(FlexiChains._INITIAL_STATE_DOCSTRING)
+"""
+function loadstate(chain::VNChain)
+    return FlexiChains.last_sampler_state(chain)
+end
 
 include("abstractmcmc.jl")
 include("repeat_sampler.jl")

src/mcmc/abstractmcmc.jl

Lines changed: 0 additions & 19 deletions
@@ -165,25 +165,6 @@ function AbstractMCMC.sample(
     return chain
 end
 
-"""
-    loadstate(chain::MCMCChains.Chains)
-
-Load the final state of the sampler from a `MCMCChains.Chains` object.
-
-To save the final state of the sampler, you must use `sample(...; save_state=true)`. If this
-argument was not used during sampling, calling `loadstate` will throw an error.
-"""
-function loadstate(chain::MCMCChains.Chains)
-    if !haskey(chain.info, :samplerstate)
-        throw(
-            ArgumentError(
-                "the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`",
-            ),
-        )
-    end
-    return chain.info[:samplerstate]
-end
-
 # TODO(penelopeysm): Remove initialstep and generalise MCMC sampling procedures
 function initialstep end
 

src/mcmc/emcee.jl

Lines changed: 4 additions & 4 deletions
@@ -122,18 +122,18 @@ function AbstractMCMC.step(
 end
 
 function AbstractMCMC.bundle_samples(
-    samples::Vector{<:Vector},
-    model::AbstractModel,
+    samples::Vector{<:AbstractVector},
+    model::DynamicPPL.Model,
     spl::Emcee,
     state::EmceeState,
-    chain_type::Type{MCMCChains.Chains};
+    chain_type::Type{VNChain};
     kwargs...,
 )
     n_walkers = _get_n_walkers(spl)
     chains = map(1:n_walkers) do i
         this_walker_samples = [s[i] for s in samples]
         AbstractMCMC.bundle_samples(
-            this_walker_samples, model, spl, state, chain_type; kwargs...
+            this_walker_samples, model, spl, state, VNChain; kwargs...
         )
     end
     return AbstractMCMC.chainscat(chains...)

src/mcmc/hmc.jl

Lines changed: 5 additions & 3 deletions
@@ -469,14 +469,16 @@ function AHMCAdaptor(::Hamiltonian, ::AHMC.AbstractMetric, nadapts::Int; kwargs.
 end
 
 """
-    post_sample_hook(chain::MCMCChains.Chains, sampler::Union{HMC,NUTS,HMCDA}; kwargs...)
+    post_sample_hook(chain::FlexiChains.VNChain, sampler::Union{HMC,NUTS,HMCDA}; kwargs...)
 
 Emit a warning message if there are divergent transitions in the chain.
 """
 function post_sample_hook(
-    chain::MCMCChains.Chains, ::Union{HMC,NUTS,HMCDA}; verbose::Bool=true, kwargs...
+    chain::FlexiChains.VNChain, ::Union{HMC,NUTS,HMCDA}; verbose::Bool=true, kwargs...
 )
-    n_divergent = round(Int, sum(skipmissing(vec(chain[:numerical_error]))))
+    n_divergent = round(
+        Int, sum(skipmissing(vec(chain[FlexiChains.Extra(:numerical_error)])))
+    )
     if verbose && n_divergent > 0
         @warn "There were $n_divergent divergent transitions. Consider reparameterising your model or using a smaller step size. For adaptive samplers such as NUTS and HMCDA, consider increasing `target_accept`."
     end
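The hunk above shows that sampler statistics such as `:numerical_error` are now looked up with a `FlexiChains.Extra` key rather than a bare `Symbol`. User code that counts divergences itself could follow the same pattern. A minimal sketch, assuming FlexiChains is available in the active environment; the model `demo` is a hypothetical placeholder:

```julia
using Turing
using FlexiChains: FlexiChains

# Hypothetical one-variable model, used only for illustration.
@model function demo()
    x ~ Normal()
end

# `verbose=false` disables the built-in post-sample warning (see HISTORY.md, 0.44.5).
chn = sample(demo(), NUTS(), 100; verbose=false, progress=false)

# Same expression as in `post_sample_hook`: index with an `Extra` key,
# flatten, skip missing entries, and sum the divergence flags.
n_divergent = round(
    Int, sum(skipmissing(vec(chn[FlexiChains.Extra(:numerical_error)])))
)
```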
