Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from arviz import InferenceData, dict_to_dataset
from arviz.data.base import make_attrs
from pytensor.compile.mode import get_mode
from pytensor.graph.basic import Variable
from rich.theme import Theme
from threadpoolctl import threadpool_limits
Expand Down Expand Up @@ -83,6 +84,17 @@
except ImportError:
MemoryStore = type("MemoryStore", (), {})

try:
from pytensor.link.jax.dispatch import JAXLinker
except ImportError:
JAXLinker = type("JAXLinker", (), {})

try:
from pytensor.link.numba.dispatch import NumbaLinker
except ImportError:
NumbaLinker = type("NumbaLinker", (), {})


sys.setrecursionlimit(10000)

__all__ = [
Expand Down Expand Up @@ -336,10 +348,27 @@ def _sample_external_nuts(
idata_kwargs: dict | None,
compute_convergence_checks: bool,
nuts_sampler_kwargs: dict | None,
compile_kwargs: dict | None = None,
**kwargs,
):
compile_kwargs = compile_kwargs or {}
nutpie_compile_kwargs = {}

# Propagate backend based on compile_kwargs mode
if "mode" in compile_kwargs:
mode = get_mode(compile_kwargs["mode"])
if isinstance(mode.linker, JAXLinker):
nutpie_compile_kwargs["backend"] = "jax"
elif isinstance(mode.linker, NumbaLinker):
nutpie_compile_kwargs["backend"] = "numba"

if nuts_sampler_kwargs is None:
nuts_sampler_kwargs = {}
nuts_sampler_kwargs = nuts_sampler_kwargs.copy()

for kwarg in ("backend", "gradient_backend"):
if kwarg in nuts_sampler_kwargs:
nutpie_compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)

if sampler == "nutpie":
try:
Expand All @@ -362,15 +391,10 @@ def _sample_external_nuts(
UserWarning,
)

compile_kwargs = {}
nuts_sampler_kwargs = nuts_sampler_kwargs.copy()
for kwarg in ("backend", "gradient_backend"):
if kwarg in nuts_sampler_kwargs:
compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
compiled_model = nutpie.compile_pymc_model(
model,
var_names=var_names,
**compile_kwargs,
**nutpie_compile_kwargs,
)
t_start = time.time()
idata = nutpie.sample(
Expand Down Expand Up @@ -457,7 +481,7 @@ def sample(
quiet: bool = False,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] | None = None,
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -490,7 +514,7 @@ def sample(
quiet: bool = False,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] | None = None,
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -523,7 +547,7 @@ def sample(
quiet: bool = False,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] | None = None,
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -592,10 +616,13 @@ def sample(
method will be used, if appropriate to the model.
var_names : list of str, optional
Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
nuts_sampler : str
nuts_sampler : str, optional
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
This requires the chosen sampler to be installed.
All samplers, except "pymc", require the full model to be continuous.
If ``None`` (default), "nutpie" is used if installed and the model is suitable
(all continuous variables, no incompatible compile_kwargs).
Otherwise "pymc" is used.
blas_cores: int or "auto" or None, default = "auto"
The total number of threads blas and openmp functions should use during sampling.
Setting it to "auto" will ensure that the total number of active blas threads is the
Expand Down Expand Up @@ -838,6 +865,15 @@ def joined_blas_limiter():
)
)

if nuts_sampler is None:
if exclusive_nuts and (
compile_kwargs is None
or isinstance(get_mode(compile_kwargs.get("mode")).linker, JAXLinker | NumbaLinker)
):
nuts_sampler = "nutpie"
else:
nuts_sampler = "pymc"

if nuts_sampler != "pymc":
if not exclusive_nuts:
raise ValueError(
Expand All @@ -850,7 +886,7 @@ def joined_blas_limiter():
draws=draws,
tune=tune,
chains=chains,
target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
target_accept=nuts_sampler_kwargs.get("target_accept", 0.8),
random_seed=random_seed,
initvals=initvals,
model=model,
Expand All @@ -860,6 +896,7 @@ def joined_blas_limiter():
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
nuts_sampler_kwargs=nuts_sampler_kwargs,
compile_kwargs=compile_kwargs,
**kwargs,
)

Expand Down
124 changes: 124 additions & 0 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,3 +965,127 @@ def test_quiet_false_shows_logs(self, caplog):

pymc_logs = [r for r in caplog.records if r.name.startswith("pymc")]
assert len(pymc_logs) > 0


class TestNutpieSelection:
@pytest.fixture
def continuous_model(self):
with pm.Model() as model:
pm.Normal("x")
return model

def test_auto_selection_numba(self, continuous_model):
with (
mock.patch("pymc.sampling.mcmc.get_mode") as mock_get_mode,
mock.patch("pymc.sampling.mcmc._sample_external_nuts") as mock_sample_external,
):
# Create a mock linker and mode
MockNumbaLinker = type("MockNumbaLinker", (), {})
with mock.patch("pymc.sampling.mcmc.NumbaLinker", MockNumbaLinker):
mock_mode = mock.Mock()
mock_mode.linker = MockNumbaLinker()
mock_get_mode.return_value = mock_mode

pm.sample(
model=continuous_model,
compile_kwargs={"mode": "NUMBA"},
tune=10,
draws=10,
chains=1,
progressbar=False,
)

mock_sample_external.assert_called_once()
assert mock_sample_external.call_args[1]["sampler"] == "nutpie"
assert mock_sample_external.call_args[1].get("compile_kwargs") == {"mode": "NUMBA"}

def test_auto_selection_jax(self, continuous_model):
with (
mock.patch("pymc.sampling.mcmc.get_mode") as mock_get_mode,
mock.patch("pymc.sampling.mcmc._sample_external_nuts") as mock_sample_external,
):
MockJAXLinker = type("MockJAXLinker", (), {})
with mock.patch("pymc.sampling.mcmc.JAXLinker", MockJAXLinker):
mock_mode = mock.Mock()
mock_mode.linker = MockJAXLinker()
mock_get_mode.return_value = mock_mode

pm.sample(
model=continuous_model,
compile_kwargs={"mode": "JAX"},
tune=10,
draws=10,
chains=1,
progressbar=False,
)

mock_sample_external.assert_called_once()
assert mock_sample_external.call_args[1]["sampler"] == "nutpie"
# Backend should be propagated correctly in _sample_external_nuts, but here we check kwargs passed TO it
assert mock_sample_external.call_args[1].get("compile_kwargs") == {"mode": "JAX"}

def test_fallback_cvm(self, continuous_model):
with (
mock.patch("pymc.sampling.mcmc.get_mode") as mock_get_mode,
mock.patch("pymc.sampling.mcmc._sample_external_nuts") as mock_sample_external,
mock.patch("pymc.sampling.mcmc._iter_sample"),
mock.patch("pymc.sampling.mcmc._mp_sample"),
):
# Use real NumbaLinker/JAXLinker classes if possible, or mocks that won't match CVM

mock_mode = mock.Mock()
mock_mode.linker = mock.Mock() # Generic mock, not JAX or Numba
mock_get_mode.return_value = mock_mode

pm.sample(
model=continuous_model,
compile_kwargs={"mode": "FAST_RUN"},
tune=10,
draws=10,
chains=1,
progressbar=False,
)

mock_sample_external.assert_not_called()

def test_explicit_selection(self, continuous_model):
with mock.patch("pymc.sampling.mcmc._sample_external_nuts") as mock_sample_external:
pm.sample(
model=continuous_model,
nuts_sampler="nutpie",
tune=10,
draws=10,
chains=1,
progressbar=False,
)
mock_sample_external.assert_called_once()
assert mock_sample_external.call_args[1]["sampler"] == "nutpie"

def test_backend_propagation_internal(self, continuous_model):
with mock.patch.dict("sys.modules", {"nutpie": mock.Mock()}):
import nutpie

nutpie.compile_pymc_model = mock.Mock()
nutpie.sample = mock.Mock(return_value=mock.Mock())

with mock.patch("pymc.sampling.mcmc.get_mode") as mock_get_mode:
MockNumbaLinker = type("MockNumbaLinker", (), {})
with mock.patch("pymc.sampling.mcmc.NumbaLinker", MockNumbaLinker):
mock_mode = mock.Mock()
mock_mode.linker = MockNumbaLinker()
mock_get_mode.return_value = mock_mode

# We can call pm.sample with nuts_sampler="nutpie" and compile_kwargs
pm.sample(
model=continuous_model,
nuts_sampler="nutpie",
compile_kwargs={"mode": "NUMBA"},
tune=10,
draws=10,
chains=1,
progressbar=False,
)

nutpie.compile_pymc_model.assert_called()
_, kwargs = nutpie.compile_pymc_model.call_args
assert kwargs.get("backend") == "numba"