Skip to content

Implement linearly constrained samplers#814

Draft
ripaul wants to merge 8 commits into
blackjax-devs:mainfrom
ripaul:linearly-constrained-samplers
Draft

Implement linearly constrained samplers#814
ripaul wants to merge 8 commits into
blackjax-devs:mainfrom
ripaul:linearly-constrained-samplers

Conversation

@ripaul
Copy link
Copy Markdown

@ripaul ripaul commented Mar 24, 2026

A few important guidelines and requirements before we can merge your PR:

  • If I add a new sampler, there is an issue discussing it already; cf. Linearly Constrained Samplers #813
  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;
  • If I add a new sampler* I added/updated related examples

Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.

Overview

This PR adds a the Dikin walk, Vaidya walk, MAPLA and Hit-&-Run-based samplers specialized for sampling on linearly constrained densities. Since Dikin, Vaidya and MAPLA are specialized versions of a Gaussian random walk with position-dependant covariance and the simplified Manifold MALA algorithm, respectively, this PR additionally implements these two algorithms as well.

API Design

The top-level API closely follows that of the other MCMC samplers. For the linearly constrained samplers, passing the left-hand side matrix Aand right-hand side bounds b of the linear inequality system Ax <= b is required.

import blackjax

dikin = blackjax.dikin(logp, A, b, step_size=1.0)
vaidya = blackjax.vaidya(logp, A, b, step_size=1.0)
mapla = blackjax.mapla(logp, A, b, step_size=1.0)

ehr = blackjax.ehr(logp, A, b, grad_fn, mass_matrix_fn, step_size=1.0)

posdep_rwmh = blackjax.posdep_rwmh(logp, mass_matrix_fn, step_size=1.0)
smmala = blackjax.smmala(logp, mass_matrix_fn, step_size=1.0)

Implementation

Implementations of each sampler are found in blackjax/mcmc/{sampler}.py. Additionally, blackjax/mcmc/step_distributions.py contains univariate distributions which can be passed to blackjax.ehr(..., step_dist=...) in order to control the univariate distribution from which the magnitude component of the Hit-&-Run sample is drawn. Moreover, overdamped simplified manifold diffusion was added to blackjax/mcmc/diffusions.py and some required infrastructure to deal with the metrics was added to blackjax/mcmc/metrics.py

Tests

An additional test case for sampling a normal distribution truncated to a simplex in 2 dimensions was added. All of the above mentioned samplers are tested on this new test case.

Questions

I am unsure whether some of the algorithms could be implemented in a more blackjax-ish way by using more of the existing infrastructure. For example, the position dependant RWMH could potentially be implemented as a special case of the existing blackjax.rmh? Also, the simplified overdamped manifold Langevin diffusion could be possibly implemented by providing in terms of a special integrator, I guess? But I am not knowledgeable enough to do that, which is why I opted for the way I implemented it here. I'd be happy to get some feedback on this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant