Skip to content

Commit ddc17e3

Browse files
committed
Use nutpie by default
A few tests previously checked `sample_stats.lp`; nutpie exposes it as `logp`, so switch those assertions to `diverging`, which both samplers emit identically.
1 parent 7fe04d5 commit ddc17e3

9 files changed

Lines changed: 395 additions & 68 deletions

File tree

pymc/sampling/mcmc.py

Lines changed: 136 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import contextlib
16+
import importlib.util
1617
import logging
1718
import multiprocessing
1819
import pickle
@@ -32,7 +33,10 @@
3233
import numpy as np
3334
import pytensor.gradient as tg
3435

36+
from pytensor.compile.mode import get_mode
3537
from pytensor.graph.basic import Variable
38+
from pytensor.link.jax.linker import JAXLinker
39+
from pytensor.link.numba.linker import NumbaLinker
3640
from rich.theme import Theme
3741
from threadpoolctl import threadpool_limits
3842
from typing_extensions import Protocol
@@ -81,6 +85,9 @@
8185
except ImportError:
8286
MemoryStore = type("MemoryStore", (), {})
8387

88+
NUTPIE_INSTALLED = importlib.util.find_spec("nutpie") is not None
89+
90+
8491
sys.setrecursionlimit(10000)
8592

8693
__all__ = [
@@ -324,38 +331,64 @@ def _sample_external_nuts(
324331
draws: int,
325332
tune: int,
326333
chains: int,
327-
target_accept: float,
334+
cores: int | None,
328335
random_seed: RandomState | None,
329336
initvals: StartDict | Sequence[StartDict | None] | None,
330337
model: Model,
331338
var_names: Sequence[str] | None,
332339
progressbar: bool | ProgressBarOptions,
333340
progressbar_theme: Theme | None,
334341
quiet: bool,
335-
idata_kwargs: dict | None,
336342
compute_convergence_checks: bool,
337-
nuts_sampler_kwargs: dict | None,
343+
discard_tuned_samples: bool,
344+
nuts_kwargs: dict,
345+
compile_kwargs: dict,
346+
idata_kwargs: dict | None,
338347
**kwargs,
339348
):
340-
if nuts_sampler_kwargs is None:
341-
nuts_sampler_kwargs = {}
349+
# Shallow copy dicts so we can safely mutate them below
350+
nuts_kwargs = nuts_kwargs.copy()
351+
compile_kwargs = compile_kwargs.copy()
352+
idata_kwargs = {} if idata_kwargs is None else idata_kwargs.copy()
353+
354+
if "backend" in nuts_kwargs:
355+
warnings.warn(
356+
"`backend` should be passed as a top-level argument to `pm.sample`, "
357+
"not nested in the NUTS step kwargs.",
358+
FutureWarning,
359+
)
360+
compile_kwargs["mode"] = get_mode(nuts_kwargs.pop("backend"))
361+
362+
if "gradient_backend" in nuts_kwargs:
363+
warnings.warn(
364+
"`gradient_backend` should be passed via `compile_kwargs` to `pm.sample`, "
365+
"not nested in the NUTS step kwargs.",
366+
FutureWarning,
367+
)
368+
compile_kwargs["gradient_backend"] = nuts_kwargs.pop("gradient_backend")
342369

343370
if sampler == "nutpie":
344-
try:
345-
import nutpie
346-
except ImportError as err:
371+
if not NUTPIE_INSTALLED:
347372
raise ImportError(
348373
"nutpie not found. Install it with conda install -c conda-forge nutpie"
349-
) from err
350-
351-
if initvals is not None:
352-
warnings.warn(
353-
"`initvals` are currently not passed to nutpie sampler. "
354-
"Use `init_mean` kwarg following nutpie specification instead.",
355-
UserWarning,
374+
)
375+
import nutpie
376+
377+
if isinstance(initvals, dict):
378+
compile_kwargs.setdefault("initial_points", initvals)
379+
elif initvals is not None:
380+
raise NotImplementedError(
381+
"nutpie does not support per-chain `initvals`. "
382+
"Pass a single dict, or use `nuts_sampler='pymc'`."
356383
)
357384

358-
idata_kwargs = {} if idata_kwargs is None else {**idata_kwargs}
385+
# nuts-rs asserts `early_end < num_tune`, which panics when `tune == 0`.
386+
if tune == 0:
387+
tune = 1
388+
389+
if "max_treedepth" in nuts_kwargs:
390+
nuts_kwargs["maxdepth"] = nuts_kwargs.pop("max_treedepth")
391+
359392
include_transformed = idata_kwargs.pop("include_transformed", False)
360393
log_likelihood = idata_kwargs.pop("log_likelihood", False)
361394
if idata_kwargs:
@@ -364,11 +397,8 @@ def _sample_external_nuts(
364397
UserWarning,
365398
)
366399

367-
compile_kwargs = {}
368-
nuts_sampler_kwargs = nuts_sampler_kwargs.copy()
369-
for kwarg in ("backend", "gradient_backend"):
370-
if kwarg in nuts_sampler_kwargs:
371-
compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
400+
linker = get_mode(compile_kwargs.pop("mode", None)).linker
401+
compile_kwargs.setdefault("backend", "jax" if isinstance(linker, JAXLinker) else "numba")
372402
compiled_model = nutpie.compile_pymc_model(
373403
model,
374404
var_names=var_names,
@@ -390,11 +420,12 @@ def _sample_external_nuts(
390420
draws=draws,
391421
tune=tune,
392422
chains=chains,
393-
target_accept=target_accept,
394-
seed=_get_seeds_per_chain(random_seed, 1)[0],
423+
cores=cores,
424+
seed=int(random_seed[0]),
425+
save_warmup=not discard_tuned_samples,
395426
progress_bar=False,
396427
progress_callback=pb_manager.update,
397-
**nuts_sampler_kwargs,
428+
**nuts_kwargs,
398429
)
399430
t_sample = time.time() - t_start
400431
patch_nutpie_idata(
@@ -426,21 +457,25 @@ def _sample_external_nuts(
426457
elif sampler in ("numpyro", "blackjax"):
427458
import pymc.sampling.jax as pymc_jax
428459

460+
jax_nuts_kwargs = dict(nuts_kwargs)
461+
jax_target_accept = jax_nuts_kwargs.pop("target_accept", 0.8)
429462
idata = pymc_jax.sample_jax_nuts(
430463
draws=draws,
431464
tune=tune,
432465
chains=chains,
433-
target_accept=target_accept,
434-
random_seed=random_seed,
466+
target_accept=jax_target_accept,
467+
# jax samplers take a single master seed; `random_seed` here is the
468+
# per-chain list produced by `pm.sample`, so use the first entry.
469+
random_seed=int(random_seed[0]),
435470
initvals=initvals,
436471
model=model,
437472
var_names=var_names,
438473
progressbar=bool(progressbar),
439474
quiet=quiet,
440475
nuts_sampler=sampler,
476+
nuts_kwargs=jax_nuts_kwargs,
441477
idata_kwargs=idata_kwargs,
442478
compute_convergence_checks=compute_convergence_checks,
443-
**nuts_sampler_kwargs,
444479
)
445480
return idata
446481

@@ -463,7 +498,7 @@ def sample(
463498
quiet: bool = False,
464499
step=None,
465500
var_names: Sequence[str] | None = None,
466-
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
501+
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] | None = None,
467502
initvals: StartDict | Sequence[StartDict | None] | None = None,
468503
init: str = "auto",
469504
jitter_max_retries: int = 10,
@@ -496,7 +531,7 @@ def sample(
496531
quiet: bool = False,
497532
step=None,
498533
var_names: Sequence[str] | None = None,
499-
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
534+
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] | None = None,
500535
initvals: StartDict | Sequence[StartDict | None] | None = None,
501536
init: str = "auto",
502537
jitter_max_retries: int = 10,
@@ -529,7 +564,7 @@ def sample(
529564
quiet: bool = False,
530565
step=None,
531566
var_names: Sequence[str] | None = None,
532-
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
567+
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] | None = None,
533568
initvals: StartDict | Sequence[StartDict | None] | None = None,
534569
init: str = "auto",
535570
jitter_max_retries: int = 10,
@@ -599,10 +634,11 @@ def sample(
599634
method will be used, if appropriate to the model.
600635
var_names : list of str, optional
601636
Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
602-
nuts_sampler : str
637+
nuts_sampler : str, optional
603638
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
604639
This requires the chosen sampler to be installed.
605640
All samplers, except "pymc", require the full model to be continuous.
641+
If ``None`` (default), "nutpie" is used when it is installed and the model can be compiled to the desired backend; otherwise "pymc" is used.
606642
blas_cores: int or "auto" or None, default = "auto"
607643
The total number of threads blas and openmp functions should use during sampling.
608644
Setting it to "auto" will ensure that the total number of active blas threads is the
@@ -651,8 +687,8 @@ def sample(
651687
idata_kwargs : dict, optional
652688
Keyword arguments for :func:`pymc.to_inference_data`
653689
nuts_sampler_kwargs : dict, optional
654-
Keyword arguments for the sampling library that implements nuts.
655-
Only used when an external sampler is specified via the `nuts_sampler` kwarg.
690+
Deprecated. Pass NUTS keyword arguments via ``nuts={...}`` instead
691+
(e.g. ``pm.sample(..., nuts={"target_accept": 0.9})``).
656692
callback : function, default=None
657693
A function which gets called for every sample from the trace of a chain. The function is
658694
called with the trace and the current draw and will contain all samples for a single trace.
@@ -747,8 +783,18 @@ def sample(
747783
FutureWarning,
748784
stacklevel=2,
749785
)
750-
if nuts_sampler_kwargs is None:
751-
nuts_sampler_kwargs = {}
786+
if nuts_sampler_kwargs is not None:
787+
warnings.warn(
788+
"`nuts_sampler_kwargs` is deprecated. Pass NUTS keyword arguments via the "
789+
"`nuts={...}` argument to `pm.sample`.",
790+
FutureWarning,
791+
stacklevel=2,
792+
)
793+
if "nuts" in kwargs:
794+
raise ValueError(
795+
"Cannot pass both `nuts_sampler_kwargs` and `nuts=`. Use `nuts=` only."
796+
)
797+
kwargs["nuts"] = nuts_sampler_kwargs
752798
if "target_accept" in kwargs:
753799
if "nuts" in kwargs and "target_accept" in kwargs["nuts"]:
754800
raise ValueError(
@@ -826,20 +872,67 @@ def sample(
826872

827873
compile_kwargs = resolve_backend_compile_kwargs(backend, compile_kwargs)
828874

875+
if nuts_sampler is None:
876+
# Try to use nutpie by default if no setting is clearly at odds.
877+
# Requires NUTS to be assigned to all model variables, a numba or jax backend,
878+
# and must not conflict with pymc sample-only arguments.
879+
can_use_nutpie = (
880+
exclusive_nuts
881+
and not provided_steps
882+
and NUTPIE_INSTALLED
883+
and init == "auto"
884+
and return_inferencedata
885+
and trace is None
886+
and callback is None
887+
and (initvals is None or isinstance(initvals, dict))
888+
and isinstance(get_mode(compile_kwargs.get("mode")).linker, NumbaLinker | JAXLinker)
889+
)
890+
nuts_sampler = "nutpie" if can_use_nutpie else "pymc"
891+
elif nuts_sampler != "pymc" and not exclusive_nuts:
892+
raise ValueError(
893+
f"`nuts_sampler={nuts_sampler!r}` requires all variables to be differentiable "
894+
"and not assigned to another step sampler."
895+
)
896+
829897
if nuts_sampler != "pymc":
830-
if not exclusive_nuts:
898+
if provided_steps:
899+
warnings.warn(
900+
f"The provided NUTS `step` is ignored by `nuts_sampler={nuts_sampler!r}`; "
901+
"pass `nuts_sampler='pymc'` to use it.",
902+
UserWarning,
903+
stacklevel=2,
904+
)
905+
if not return_inferencedata:
831906
raise ValueError(
832-
"Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
907+
f"`return_inferencedata=False` is not supported with `nuts_sampler={nuts_sampler!r}`. "
908+
"External NUTS samplers can only return `InferenceData`."
833909
)
834-
910+
if trace is not None:
911+
raise ValueError(
912+
f"A custom `trace` backend is not supported with `nuts_sampler={nuts_sampler!r}`. "
913+
"Trace backends (e.g. `ZarrTrace`) only work with `nuts_sampler='pymc'`."
914+
)
915+
if callback is not None:
916+
raise ValueError(
917+
f"`callback` is not supported with `nuts_sampler={nuts_sampler!r}`. "
918+
"External NUTS samplers don't invoke per-draw callbacks."
919+
)
920+
if init != "auto":
921+
warnings.warn(
922+
f"`init={init!r}` is ignored by `nuts_sampler={nuts_sampler!r}`; "
923+
"the external sampler uses its own initialization.",
924+
UserWarning,
925+
stacklevel=2,
926+
)
927+
nuts_kwargs = kwargs.pop("nuts", {})
835928
with joined_blas_limiter():
836929
return _sample_external_nuts(
837930
sampler=nuts_sampler,
838931
draws=draws,
839932
tune=tune,
840933
chains=chains,
841-
target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
842-
random_seed=random_seed,
934+
cores=cores,
935+
random_seed=random_seed_list,
843936
initvals=initvals,
844937
model=model,
845938
var_names=var_names,
@@ -848,7 +941,9 @@ def sample(
848941
quiet=quiet,
849942
idata_kwargs=idata_kwargs,
850943
compute_convergence_checks=compute_convergence_checks,
851-
nuts_sampler_kwargs=nuts_sampler_kwargs,
944+
discard_tuned_samples=discard_tuned_samples,
945+
nuts_kwargs=nuts_kwargs,
946+
compile_kwargs=compile_kwargs,
852947
**kwargs,
853948
)
854949

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@
5353

5454
test_reqs = ["pytest", "pytest-cov"]
5555

56+
extras_require = {
57+
"nutpie": ["nutpie>=0.16.8"],
58+
}
59+
5660
if __name__ == "__main__":
5761
setup(
5862
name="pymc",
@@ -72,5 +76,6 @@
7276
classifiers=classifiers,
7377
python_requires=">=3.12",
7478
install_requires=install_reqs,
79+
extras_require=extras_require,
7580
tests_require=test_reqs,
7681
)

tests/backends/test_arviz.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
# Related to https://github.com/arviz-devs/arviz/issues/2327
4040
"ignore:datetime.datetime.utcnow():DeprecationWarning",
4141
r"ignore::numba.NumbaPerformanceWarning",
42+
r"ignore:This process .* is multi-threaded:DeprecationWarning",
4243
)
4344

4445

@@ -240,7 +241,7 @@ def test_posterior_predictive_thinned(self, data):
240241
idata.update(pm.sample_posterior_predictive(thinned_idata))
241242
test_dict = {
242243
"posterior": ["mu", "tau", "eta", "theta"],
243-
"sample_stats": ["diverging", "lp", "~log_likelihood"],
244+
"sample_stats": ["diverging", "~log_likelihood"],
244245
"posterior_predictive": ["obs"],
245246
"observed_data": ["obs"],
246247
}
@@ -412,7 +413,7 @@ def test_multiple_observed_rv(self, log_likelihood):
412413
"posterior": ["x"],
413414
"observed_data": ["y1", "y2"],
414415
"log_likelihood": ["y1", "y2"],
415-
"sample_stats": ["diverging", "lp", "~log_likelihood"],
416+
"sample_stats": ["diverging", "~log_likelihood"],
416417
}
417418
if not log_likelihood:
418419
test_dict.pop("log_likelihood")
@@ -630,7 +631,7 @@ def test_multivariate_observations(self):
630631
)
631632
test_dict = {
632633
"posterior": ["p"],
633-
"sample_stats": ["lp"],
634+
"sample_stats": ["diverging"],
634635
"log_likelihood": ["y"],
635636
"observed_data": ["y"],
636637
}

tests/distributions/test_custom.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ def random(rng, size):
151151
assert isinstance(y_dist.owner.op, CustomDistRV)
152152
with warnings.catch_warnings():
153153
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
154-
sample(draws=5, tune=1, mp_ctx="spawn")
154+
# nutpie can't handle RNG in deterministics
155+
# https://github.com/pymc-devs/nutpie/issues/4
156+
sample(draws=5, tune=1, mp_ctx="spawn", nuts_sampler="pymc")
155157

156158
cloudpickle.loads(cloudpickle.dumps(y))
157159
cloudpickle.loads(cloudpickle.dumps(y_dist))

0 commit comments

Comments
 (0)