Implement linearly constrained samplers#814
Draft
ripaul wants to merge 8 commits into
Draft
Conversation
…ckjax into linearly-constrained-samplers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
A few important guidelines and requirements before we can merge your PR:
maincommit;pre-commitis installed and configured on your machine, and you ran it before opening the PR;Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.
Overview
This PR adds a the Dikin walk, Vaidya walk, MAPLA and Hit-&-Run-based samplers specialized for sampling on linearly constrained densities. Since Dikin, Vaidya and MAPLA are specialized versions of a Gaussian random walk with position-dependant covariance and the simplified Manifold MALA algorithm, respectively, this PR additionally implements these two algorithms as well.
API Design
The top-level API closely follows that of the other MCMC samplers. For the linearly constrained samplers, passing the left-hand side matrix
Aand right-hand side boundsbof the linear inequality systemAx <= bis required.Implementation
Implementations of each sampler are found in
blackjax/mcmc/{sampler}.py. Additionally,blackjax/mcmc/step_distributions.pycontains univariate distributions which can be passed toblackjax.ehr(..., step_dist=...)in order to control the univariate distribution from which the magnitude component of the Hit-&-Run sample is drawn. Moreover, overdamped simplified manifold diffusion was added toblackjax/mcmc/diffusions.pyand some required infrastructure to deal with the metrics was added toblackjax/mcmc/metrics.pyTests
An additional test case for sampling a normal distribution truncated to a simplex in 2 dimensions was added. All of the above mentioned samplers are tested on this new test case.
Questions
I am unsure whether some of the algorithms could be implemented in a more
blackjax-ish way by using more of the existing infrastructure. For example, the position dependant RWMH could potentially be implemented as a special case of the existingblackjax.rmh? Also, the simplified overdamped manifold Langevin diffusion could be possibly implemented by providing in terms of a special integrator, I guess? But I am not knowledgeable enough to do that, which is why I opted for the way I implemented it here. I'd be happy to get some feedback on this.