Skip to content

Commit ed33d37

Browse files
committed
Handle deprecated log_likelihood path in nutpie
And make message more clear
1 parent 5305585 commit ed33d37

2 files changed

Lines changed: 22 additions & 4 deletions

File tree

pymc/sampling/jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -651,8 +651,8 @@ def sample_jax_nuts(
651651

652652
if idata_kwargs.pop("log_likelihood", False):
653653
warnings.warn(
654-
"`passing log_likelihood` is deprecated and will be removed in future versions. Use "
655-
":func:`pymc.compute_log_likelihood` instead.",
654+
"Passing `log_likelihood` via `idata_kwargs` is deprecated and will be removed "
655+
"in future versions. Call `pm.compute_log_likelihood(idata)` instead.",
656656
FutureWarning,
657657
stacklevel=3,
658658
)

pymc/sampling/mcmc.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def _sample_external_nuts(
373373

374374
idata_kwargs = {} if idata_kwargs is None else {**idata_kwargs}
375375
include_transformed = idata_kwargs.pop("include_transformed", False)
376+
log_likelihood = idata_kwargs.pop("log_likelihood", False)
376377
if idata_kwargs:
377378
warnings.warn(
378379
f"`idata_kwargs` keys {sorted(idata_kwargs)} are currently ignored by the nutpie sampler",
@@ -404,6 +405,23 @@ def _sample_external_nuts(
404405
sampling_time=t_sample,
405406
include_transformed=include_transformed,
406407
)
408+
if log_likelihood:
409+
warnings.warn(
410+
"Passing `log_likelihood` via `idata_kwargs` is deprecated and will be removed "
411+
"in future versions. Call `pm.compute_log_likelihood(idata)` instead.",
412+
FutureWarning,
413+
stacklevel=2,
414+
)
415+
from pymc.stats.log_density import compute_log_likelihood
416+
417+
idata = compute_log_likelihood(
418+
idata,
419+
var_names=None if log_likelihood is True else log_likelihood,
420+
extend_inferencedata=True,
421+
model=model,
422+
sample_dims=["chain", "draw"],
423+
progressbar=False,
424+
)
407425
return idata
408426

409427
elif sampler in ("numpyro", "blackjax"):
@@ -1115,8 +1133,8 @@ def _sample_return(
11151133
log_likelihood = idata_kwargs.pop("log_likelihood", False)
11161134
if log_likelihood:
11171135
warnings.warn(
1118-
"`passing log_likelihood` is deprecated and will be removed in future versions. Use "
1119-
":func:`pymc.compute_log_likelihood` instead.",
1136+
"Passing `log_likelihood` via `idata_kwargs` is deprecated and will be removed "
1137+
"in future versions. Call `pm.compute_log_likelihood(idata)` instead.",
11201138
FutureWarning,
11211139
stacklevel=2,
11221140
)

0 commit comments

Comments
 (0)