This repository uses BlackJAX (pinned to version 1.2.5) for MCMC sampling. The sampling code in `llc/sampling.py` is written as clean, idiomatic JAX loops built on `jax.lax.scan`.
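A minimal sketch of the `jax.lax.scan` loop pattern for a single chain. The random-walk Metropolis kernel `rwm_step` and the driver `run_chain` are hypothetical names used only for illustration. In the repository the per-step kernels come from BlackJAX or from the custom SGLD loop, not from this code.

```python
import jax
import jax.numpy as jnp


def rwm_step(key, x, logdensity_fn, scale=0.5):
    """One random-walk Metropolis step (hypothetical stand-in for a real kernel)."""
    key_prop, key_acc = jax.random.split(key)
    proposal = x + scale * jax.random.normal(key_prop, x.shape, x.dtype)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(x)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, x)


def run_chain(key, x0, logdensity_fn, num_steps):
    """Iterate the kernel with lax.scan: the carry is the state, outputs are samples."""
    keys = jax.random.split(key, num_steps)  # one PRNG key per step

    def body(x, k):
        x_new = rwm_step(k, x, logdensity_fn)
        return x_new, x_new  # (new carry, value stacked into the output)

    _, samples = jax.lax.scan(body, x0, keys)
    return samples


# Usage: sample a 3-D standard normal.
logdensity = lambda x: -0.5 * jnp.sum(x ** 2)
samples = run_chain(jax.random.PRNGKey(0), jnp.zeros(3), logdensity, 500)
print(samples.shape)  # (500, 3)
```

`lax.scan` compiles the loop body once and stacks each step's output along a new leading axis, which avoids a Python-level loop.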
- Implementation (HMC): Uses `blackjax.hmc` together with `blackjax.window_adaptation`, which tunes the step size and the inverse mass matrix automatically during warmup.
- Precision: Runs in `float64`.
- Parallelism: Chains are run in parallel using `jax.vmap`.
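For illustration, a minimal hand-written HMC kernel run in `float64` across several chains with `jax.vmap`. This is not `blackjax.hmc`: here the step size and the diagonal inverse mass matrix are fixed, whereas in the repository `blackjax.window_adaptation` tunes both during warmup. The names `hmc_step` and `run_chain`, the standard-normal target, and all numeric settings are assumptions.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats before creating any arrays (JAX defaults to float32).
jax.config.update("jax_enable_x64", True)


def logdensity(x):
    return -0.5 * jnp.sum(x ** 2)  # standard normal target (illustrative)


def hmc_step(key, x, step_size, inv_mass, num_integration_steps):
    """One HMC transition with a diagonal inverse mass matrix (sketch)."""
    key_mom, key_acc = jax.random.split(key)
    grad = jax.grad(logdensity)
    # Momentum p ~ N(0, M) with M = diag(1 / inv_mass).
    p0 = jax.random.normal(key_mom, x.shape, x.dtype) / jnp.sqrt(inv_mass)

    def leapfrog(carry, _):
        q, p = carry
        p = p + 0.5 * step_size * grad(q)  # half momentum step
        q = q + step_size * inv_mass * p   # full position step
        p = p + 0.5 * step_size * grad(q)  # half momentum step
        return (q, p), None

    (q1, p1), _ = jax.lax.scan(leapfrog, (x, p0), None, length=num_integration_steps)
    # Hamiltonian = potential energy + kinetic energy.
    energy = lambda q, p: -logdensity(q) + 0.5 * jnp.sum(inv_mass * p ** 2)
    log_accept = energy(x, p0) - energy(q1, p1)
    accept = jnp.log(jax.random.uniform(key_acc, dtype=x.dtype)) < log_accept
    return jnp.where(accept, q1, x)  # Metropolis correction


def run_chain(key, x0, num_steps=200, step_size=0.2, num_integration_steps=10):
    inv_mass = jnp.ones_like(x0)  # fixed here; window adaptation would tune this

    def body(x, k):
        x = hmc_step(k, x, step_size, inv_mass, num_integration_steps)
        return x, x

    _, samples = jax.lax.scan(body, x0, jax.random.split(key, num_steps))
    return samples


# Usage: 4 chains in parallel; jax.vmap adds the chain axis.
num_chains, dim = 4, 2
keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
x0 = jnp.zeros((num_chains, dim))  # float64 because x64 is enabled
samples = jax.vmap(run_chain)(keys, x0)
print(samples.shape, samples.dtype)  # (4, 200, 2) float64
```

`run_chain` is written for a single chain, and `jax.vmap` maps it over the leading axis of `keys` and `x0`, so each chain gets its own PRNG stream.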
- Implementation (SGLD): A custom SGLD kernel is implemented directly in `run_sgld` to handle minibatching efficiently.
- Precision: Runs in `float32`.
- Parallelism: Chains are run in parallel using `jax.vmap`.
- Note: Preconditioning (RMSProp or Adam) is implemented in the custom SGLD loop.
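A sketch of an SGLD loop with minibatching and RMSProp-style diagonal preconditioning (in the spirit of preconditioned SGLD), running in `float32` and vectorized over chains. It is not `run_sgld`: the names `run_psgld` and `minibatch_logpost`, the synthetic linear-regression model, the hyperparameters, and the omission of pSGLD's drift-correction term are all assumptions. An Adam-style preconditioner is not shown.

```python
import jax
import jax.numpy as jnp


def make_data(key, n=1000, dim=3):
    """Synthetic linear-regression data (illustrative only)."""
    kx, kn = jax.random.split(key)
    X = jax.random.normal(kx, (n, dim), dtype=jnp.float32)
    w_true = jnp.arange(1, dim + 1, dtype=jnp.float32)
    y = X @ w_true + 0.1 * jax.random.normal(kn, (n,), dtype=jnp.float32)
    return X, y


def minibatch_logpost(w, Xb, yb, n_total):
    """N(0, 1) prior plus minibatch Gaussian log-likelihood rescaled by N / B."""
    log_prior = -0.5 * jnp.sum(w ** 2)
    log_lik = -0.5 * jnp.sum((yb - Xb @ w) ** 2) / 0.1 ** 2
    return log_prior + (n_total / yb.shape[0]) * log_lik


def run_psgld(key, w0, X, y, num_steps=500, batch_size=32,
              step_size=1e-5, alpha=0.99, eps=1e-5):
    n = X.shape[0]
    grad_fn = jax.grad(minibatch_logpost)

    def step(carry, k):
        w, v = carry
        k_idx, k_noise = jax.random.split(k)
        idx = jax.random.choice(k_idx, n, (batch_size,), replace=False)
        g = grad_fn(w, X[idx], y[idx], n)         # stochastic gradient
        v = alpha * v + (1 - alpha) * g ** 2      # RMSProp second-moment estimate
        precond = 1.0 / (eps + jnp.sqrt(v))       # diagonal preconditioner G
        noise = jax.random.normal(k_noise, w.shape, dtype=w.dtype)
        # Langevin update: drift (eps/2) G g plus N(0, eps G) noise.
        w = w + 0.5 * step_size * precond * g + jnp.sqrt(step_size * precond) * noise
        return (w, v), w

    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, (w0, jnp.zeros_like(w0)), keys)
    return samples


# Usage: 4 chains sharing the same data (in_axes=None for X and y).
X, y = make_data(jax.random.PRNGKey(0))
num_chains = 4
chain_keys = jax.random.split(jax.random.PRNGKey(1), num_chains)
w0 = jnp.zeros((num_chains, 3), dtype=jnp.float32)
samples = jax.vmap(run_psgld, in_axes=(0, 0, None, None))(chain_keys, w0, X, y)
print(samples.shape, samples.dtype)  # (4, 500, 3) float32
```

Rescaling the minibatch log-likelihood by `N / batch_size` makes its gradient an unbiased estimate of the full-data gradient, which is what lets each step touch only a small slice of the data.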
- Implementation (MCLMC): Uses `blackjax.mclmc`.
- Precision: Runs in `float64`.
- Parallelism: Chains are run in parallel using `jax.vmap`.
- Note: The current implementation uses fixed values for the trajectory length (`L`) and `step_size`, both taken from the configuration. Automatic adaptation using `blackjax.mclmc_find_L_and_step_size` is not implemented.
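To show what the fixed `L` and `step_size` control, here is a simplified pure-JAX sketch of MCLMC-style dynamics: a unit-norm momentum updated isokinetically, followed by a partial momentum refresh whose strength depends on `step_size / L`. This is not `blackjax.mclmc`; it uses a plain isokinetic leapfrog (BlackJAX's default integrator may differ), has no Metropolis correction, and the names `momentum_update`, `mclmc_step`, `run_mclmc`, the target, and the settings are all assumptions.

```python
import jax
import jax.numpy as jnp

# MCLMC runs in float64 in this repository; enable 64-bit floats first.
jax.config.update("jax_enable_x64", True)


def logdensity(x):
    return -0.5 * jnp.sum(x ** 2)  # standard normal target (illustrative)


def normalize(v, tol=1e-12):
    norm = jnp.linalg.norm(v)
    return jnp.where(norm > tol, v / norm, v), norm


def momentum_update(u, g, eps):
    """Isokinetic momentum update over time eps; keeps |u| = 1."""
    d = u.shape[0]
    e, g_norm = normalize(g)
    ue = jnp.dot(u, e)
    delta = eps * g_norm / (d - 1)
    zeta = jnp.exp(-delta)
    new_u = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u
    return normalize(new_u)[0]


def mclmc_step(key, x, u, L, step_size):
    grad = jax.grad(logdensity)
    # Isokinetic leapfrog: half momentum, full position, half momentum.
    u = momentum_update(u, grad(x), 0.5 * step_size)
    x = x + step_size * u
    u = momentum_update(u, grad(x), 0.5 * step_size)
    # Partial refresh: larger L means the momentum direction is kept longer.
    c1 = jnp.exp(-step_size / L)
    c2 = jnp.sqrt(1 - c1 ** 2)
    z = jax.random.normal(key, u.shape, u.dtype)
    u = normalize(c1 * u + c2 * z / jnp.sqrt(u.shape[0]))[0]
    return x, u


def run_mclmc(key, x0, L=2.0, step_size=0.5, num_steps=1000):
    k_init, k_run = jax.random.split(key)
    u0 = normalize(jax.random.normal(k_init, x0.shape, x0.dtype))[0]

    def body(carry, k):
        x, u = carry
        x, u = mclmc_step(k, x, u, L, step_size)
        return (x, u), x

    (_, u_last), samples = jax.lax.scan(body, (x0, u0), jax.random.split(k_run, num_steps))
    return samples, u_last


samples, u_last = run_mclmc(jax.random.PRNGKey(0), jnp.ones(5))
print(samples.shape, samples.dtype)  # (1000, 5) float64
```

Because both values are plain Python floats here, changing them means editing the call site, which mirrors the repository's current behavior of reading them from configuration rather than adapting them.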
The online BlackJAX documentation tracks the main branch, which may differ from the pinned version 1.2.5 used here. When consulting it, check the source at the 1.2.5 tag instead.
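A small runtime check that the installed BlackJAX matches the pin, assuming it was installed under the PyPI distribution name `blackjax`. The helper `blackjax_version_matches` is hypothetical and not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version


def blackjax_version_matches(pinned="1.2.5"):
    """Return (matches, installed); installed is None if BlackJAX is absent."""
    try:
        installed = version("blackjax")
    except PackageNotFoundError:
        return False, None
    return installed == pinned, installed


matches, installed = blackjax_version_matches()
if not matches:
    print(f"Installed BlackJAX is {installed}; docs for 1.2.5 may not apply.")
```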