1313# limitations under the License.
1414
1515import contextlib
16+ import importlib .util
1617import logging
1718import multiprocessing
1819import pickle
3233import numpy as np
3334import pytensor .gradient as tg
3435
36+ from pytensor .compile .mode import get_mode
3537from pytensor .graph .basic import Variable
38+ from pytensor .link .jax .linker import JAXLinker
39+ from pytensor .link .numba .linker import NumbaLinker
3640from rich .theme import Theme
3741from threadpoolctl import threadpool_limits
3842from typing_extensions import Protocol
8185except ImportError :
8286 MemoryStore = type ("MemoryStore" , (), {})
8387
88+ NUTPIE_INSTALLED = importlib .util .find_spec ("nutpie" ) is not None
89+
90+
8491sys .setrecursionlimit (10000 )
8592
8693__all__ = [
@@ -324,38 +331,64 @@ def _sample_external_nuts(
324331 draws : int ,
325332 tune : int ,
326333 chains : int ,
327- target_accept : float ,
334+ cores : int | None ,
328335 random_seed : RandomState | None ,
329336 initvals : StartDict | Sequence [StartDict | None ] | None ,
330337 model : Model ,
331338 var_names : Sequence [str ] | None ,
332339 progressbar : bool | ProgressBarOptions ,
333340 progressbar_theme : Theme | None ,
334341 quiet : bool ,
335- idata_kwargs : dict | None ,
336342 compute_convergence_checks : bool ,
337- nuts_sampler_kwargs : dict | None ,
343+ discard_tuned_samples : bool ,
344+ nuts_kwargs : dict ,
345+ compile_kwargs : dict ,
346+ idata_kwargs : dict | None ,
338347 ** kwargs ,
339348):
340- if nuts_sampler_kwargs is None :
341- nuts_sampler_kwargs = {}
349+ # Shallow copy dicts so we can safely matute them below
350+ nuts_kwargs = nuts_kwargs .copy ()
351+ compile_kwargs = compile_kwargs .copy ()
352+ idata_kwargs = {} if idata_kwargs is None else idata_kwargs .copy ()
353+
354+ if "backend" in nuts_kwargs :
355+ warnings .warn (
356+ "`backend` should be passed as a top-level argument to `pm.sample`, "
357+ "not nested in the NUTS step kwargs." ,
358+ FutureWarning ,
359+ )
360+ compile_kwargs ["mode" ] = get_mode (nuts_kwargs .pop ("backend" ))
361+
362+ if "gradient_backend" in nuts_kwargs :
363+ warnings .warn (
364+ "`gradient_backend` should be passed via `compile_kwargs` to `pm.sample`, "
365+ "not nested in the NUTS step kwargs." ,
366+ FutureWarning ,
367+ )
368+ compile_kwargs ["gradient_backend" ] = nuts_kwargs .pop ("gradient_backend" )
342369
343370 if sampler == "nutpie" :
344- try :
345- import nutpie
346- except ImportError as err :
371+ if not NUTPIE_INSTALLED :
347372 raise ImportError (
348373 "nutpie not found. Install it with conda install -c conda-forge nutpie"
349- ) from err
350-
351- if initvals is not None :
352- warnings .warn (
353- "`initvals` are currently not passed to nutpie sampler. "
354- "Use `init_mean` kwarg following nutpie specification instead." ,
355- UserWarning ,
374+ )
375+ import nutpie
376+
377+ if isinstance (initvals , dict ):
378+ compile_kwargs .setdefault ("initial_points" , initvals )
379+ elif initvals is not None :
380+ raise NotImplementedError (
381+ "nutpie does not support per-chain `initvals`. "
382+ "Pass a single dict, or use `nuts_sampler='pymc'`."
356383 )
357384
358- idata_kwargs = {} if idata_kwargs is None else {** idata_kwargs }
385+ # nuts-rs asserts `early_end < num_tune`, which panics when `tune == 0`.
386+ if tune == 0 :
387+ tune = 1
388+
389+ if "max_treedepth" in nuts_kwargs :
390+ nuts_kwargs ["maxdepth" ] = nuts_kwargs .pop ("max_treedepth" )
391+
359392 include_transformed = idata_kwargs .pop ("include_transformed" , False )
360393 log_likelihood = idata_kwargs .pop ("log_likelihood" , False )
361394 if idata_kwargs :
@@ -364,11 +397,8 @@ def _sample_external_nuts(
364397 UserWarning ,
365398 )
366399
367- compile_kwargs = {}
368- nuts_sampler_kwargs = nuts_sampler_kwargs .copy ()
369- for kwarg in ("backend" , "gradient_backend" ):
370- if kwarg in nuts_sampler_kwargs :
371- compile_kwargs [kwarg ] = nuts_sampler_kwargs .pop (kwarg )
400+ linker = get_mode (compile_kwargs .pop ("mode" , None )).linker
401+ compile_kwargs .setdefault ("backend" , "jax" if isinstance (linker , JAXLinker ) else "numba" )
372402 compiled_model = nutpie .compile_pymc_model (
373403 model ,
374404 var_names = var_names ,
@@ -390,11 +420,12 @@ def _sample_external_nuts(
390420 draws = draws ,
391421 tune = tune ,
392422 chains = chains ,
393- target_accept = target_accept ,
394- seed = _get_seeds_per_chain (random_seed , 1 )[0 ],
423+ cores = cores ,
424+ seed = int (random_seed [0 ]),
425+ save_warmup = not discard_tuned_samples ,
395426 progress_bar = False ,
396427 progress_callback = pb_manager .update ,
397- ** nuts_sampler_kwargs ,
428+ ** nuts_kwargs ,
398429 )
399430 t_sample = time .time () - t_start
400431 patch_nutpie_idata (
@@ -426,21 +457,25 @@ def _sample_external_nuts(
426457 elif sampler in ("numpyro" , "blackjax" ):
427458 import pymc .sampling .jax as pymc_jax
428459
460+ jax_nuts_kwargs = dict (nuts_kwargs )
461+ jax_target_accept = jax_nuts_kwargs .pop ("target_accept" , 0.8 )
429462 idata = pymc_jax .sample_jax_nuts (
430463 draws = draws ,
431464 tune = tune ,
432465 chains = chains ,
433- target_accept = target_accept ,
434- random_seed = random_seed ,
466+ target_accept = jax_target_accept ,
467+ # jax samplers take a single master seed; `random_seed` here is the
468+ # per-chain list produced by `pm.sample`, so use the first entry.
469+ random_seed = int (random_seed [0 ]),
435470 initvals = initvals ,
436471 model = model ,
437472 var_names = var_names ,
438473 progressbar = bool (progressbar ),
439474 quiet = quiet ,
440475 nuts_sampler = sampler ,
476+ nuts_kwargs = jax_nuts_kwargs ,
441477 idata_kwargs = idata_kwargs ,
442478 compute_convergence_checks = compute_convergence_checks ,
443- ** nuts_sampler_kwargs ,
444479 )
445480 return idata
446481
@@ -463,7 +498,7 @@ def sample(
463498 quiet : bool = False ,
464499 step = None ,
465500 var_names : Sequence [str ] | None = None ,
466- nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
501+ nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] | None = None ,
467502 initvals : StartDict | Sequence [StartDict | None ] | None = None ,
468503 init : str = "auto" ,
469504 jitter_max_retries : int = 10 ,
@@ -496,7 +531,7 @@ def sample(
496531 quiet : bool = False ,
497532 step = None ,
498533 var_names : Sequence [str ] | None = None ,
499- nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
534+ nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] | None = None ,
500535 initvals : StartDict | Sequence [StartDict | None ] | None = None ,
501536 init : str = "auto" ,
502537 jitter_max_retries : int = 10 ,
@@ -529,7 +564,7 @@ def sample(
529564 quiet : bool = False ,
530565 step = None ,
531566 var_names : Sequence [str ] | None = None ,
532- nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
567+ nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] | None = None ,
533568 initvals : StartDict | Sequence [StartDict | None ] | None = None ,
534569 init : str = "auto" ,
535570 jitter_max_retries : int = 10 ,
@@ -599,10 +634,11 @@ def sample(
599634 method will be used, if appropriate to the model.
600635 var_names : list of str, optional
601636 Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
602- nuts_sampler : str
637+ nuts_sampler : str, optional
603638 Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
604639 This requires the chosen sampler to be installed.
605640 All samplers, except "pymc", require the full model to be continuous.
641+ If ``None`` (default), "nutpie" is used if installed and can be compiled to the desired backend.
606642 blas_cores: int or "auto" or None, default = "auto"
607643 The total number of threads blas and openmp functions should use during sampling.
608644 Setting it to "auto" will ensure that the total number of active blas threads is the
@@ -651,8 +687,8 @@ def sample(
651687 idata_kwargs : dict, optional
652688 Keyword arguments for :func:`pymc.to_inference_data`
653689 nuts_sampler_kwargs : dict, optional
654- Keyword arguments for the sampling library that implements nuts.
655- Only used when an external sampler is specified via the `nuts_sampler` kwarg .
690+ Deprecated. Pass NUTS keyword arguments via `` nuts={...}`` instead
691+ (e.g. ``pm.sample(..., nuts={"target_accept": 0.9})``) .
656692 callback : function, default=None
657693 A function which gets called for every sample from the trace of a chain. The function is
658694 called with the trace and the current draw and will contain all samples for a single trace.
@@ -747,8 +783,18 @@ def sample(
747783 FutureWarning ,
748784 stacklevel = 2 ,
749785 )
750- if nuts_sampler_kwargs is None :
751- nuts_sampler_kwargs = {}
786+ if nuts_sampler_kwargs is not None :
787+ warnings .warn (
788+ "`nuts_sampler_kwargs` is deprecated. Pass NUTS keyword arguments via the "
789+ "`nuts={...}` argument to `pm.sample`." ,
790+ FutureWarning ,
791+ stacklevel = 2 ,
792+ )
793+ if "nuts" in kwargs :
794+ raise ValueError (
795+ "Cannot pass both `nuts_sampler_kwargs` and `nuts=`. Use `nuts=` only."
796+ )
797+ kwargs ["nuts" ] = nuts_sampler_kwargs
752798 if "target_accept" in kwargs :
753799 if "nuts" in kwargs and "target_accept" in kwargs ["nuts" ]:
754800 raise ValueError (
@@ -826,20 +872,67 @@ def sample(
826872
827873 compile_kwargs = resolve_backend_compile_kwargs (backend , compile_kwargs )
828874
875+ if nuts_sampler is None :
876+ # Try to use nutpie by default if no setting is clearly at odds.
877+ # Requires all model variables, numba or jax preference,
878+ # and must not conflict with pymc sample-only arguments.
879+ can_use_nutpie = (
880+ exclusive_nuts
881+ and not provided_steps
882+ and NUTPIE_INSTALLED
883+ and init == "auto"
884+ and return_inferencedata
885+ and trace is None
886+ and callback is None
887+ and (initvals is None or isinstance (initvals , dict ))
888+ and isinstance (get_mode (compile_kwargs .get ("mode" )).linker , NumbaLinker | JAXLinker )
889+ )
890+ nuts_sampler = "nutpie" if can_use_nutpie else "pymc"
891+ elif nuts_sampler != "pymc" and not exclusive_nuts :
892+ raise ValueError (
893+ f"`nuts_sampler={ nuts_sampler !r} ` requires all variables to be differentiable "
894+ "and not assigned to another step sampler."
895+ )
896+
829897 if nuts_sampler != "pymc" :
830- if not exclusive_nuts :
898+ if provided_steps :
899+ warnings .warn (
900+ f"The provided NUTS `step` is ignored by `nuts_sampler={ nuts_sampler !r} `; "
901+ "pass `nuts_sampler='pymc'` to use it." ,
902+ UserWarning ,
903+ stacklevel = 2 ,
904+ )
905+ if not return_inferencedata :
831906 raise ValueError (
832- "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
907+ f"`return_inferencedata=False` is not supported with `nuts_sampler={ nuts_sampler !r} `. "
908+ "External NUTS samplers can only return `InferenceData`."
833909 )
834-
910+ if trace is not None :
911+ raise ValueError (
912+ f"A custom `trace` backend is not supported with `nuts_sampler={ nuts_sampler !r} `. "
913+ "Trace backends (e.g. `ZarrTrace`) only work with `nuts_sampler='pymc'`."
914+ )
915+ if callback is not None :
916+ raise ValueError (
917+ f"`callback` is not supported with `nuts_sampler={ nuts_sampler !r} `. "
918+ "External NUTS samplers don't invoke per-draw callbacks."
919+ )
920+ if init != "auto" :
921+ warnings .warn (
922+ f"`init={ init !r} ` is ignored by `nuts_sampler={ nuts_sampler !r} `; "
923+ "the external sampler uses its own initialization." ,
924+ UserWarning ,
925+ stacklevel = 2 ,
926+ )
927+ nuts_kwargs = kwargs .pop ("nuts" , {})
835928 with joined_blas_limiter ():
836929 return _sample_external_nuts (
837930 sampler = nuts_sampler ,
838931 draws = draws ,
839932 tune = tune ,
840933 chains = chains ,
841- target_accept = kwargs . pop ( "nuts" , {}). get ( "target_accept" , 0.8 ) ,
842- random_seed = random_seed ,
934+ cores = cores ,
935+ random_seed = random_seed_list ,
843936 initvals = initvals ,
844937 model = model ,
845938 var_names = var_names ,
@@ -848,7 +941,9 @@ def sample(
848941 quiet = quiet ,
849942 idata_kwargs = idata_kwargs ,
850943 compute_convergence_checks = compute_convergence_checks ,
851- nuts_sampler_kwargs = nuts_sampler_kwargs ,
944+ discard_tuned_samples = discard_tuned_samples ,
945+ nuts_kwargs = nuts_kwargs ,
946+ compile_kwargs = compile_kwargs ,
852947 ** kwargs ,
853948 )
854949
0 commit comments