@@ -373,6 +373,7 @@ def _sample_external_nuts(
373373
374374 idata_kwargs = {} if idata_kwargs is None else {** idata_kwargs }
375375 include_transformed = idata_kwargs .pop ("include_transformed" , False )
376+ log_likelihood = idata_kwargs .pop ("log_likelihood" , False )
376377 if idata_kwargs :
377378 warnings .warn (
378379 f"`idata_kwargs` keys { sorted (idata_kwargs )} are currently ignored by the nutpie sampler" ,
@@ -404,6 +405,23 @@ def _sample_external_nuts(
404405 sampling_time = t_sample ,
405406 include_transformed = include_transformed ,
406407 )
408+ if log_likelihood :
409+ warnings .warn (
410+ "Passing `log_likelihood` via `idata_kwargs` is deprecated and will be removed "
411+ "in future versions. Call `pm.compute_log_likelihood(idata)` instead." ,
412+ FutureWarning ,
413+ stacklevel = 2 ,
414+ )
415+ from pymc .stats .log_density import compute_log_likelihood
416+
417+ idata = compute_log_likelihood (
418+ idata ,
419+ var_names = None if log_likelihood is True else log_likelihood ,
420+ extend_inferencedata = True ,
421+ model = model ,
422+ sample_dims = ["chain" , "draw" ],
423+ progressbar = False ,
424+ )
407425 return idata
408426
409427 elif sampler in ("numpyro" , "blackjax" ):
@@ -1115,8 +1133,8 @@ def _sample_return(
11151133 log_likelihood = idata_kwargs .pop ("log_likelihood" , False )
11161134 if log_likelihood :
11171135 warnings .warn (
1118- "`passing log_likelihood` is deprecated and will be removed in future versions. Use "
1119- ":func:`pymc .compute_log_likelihood` instead." ,
1136+ "Passing ` log_likelihood` via `idata_kwargs` is deprecated and will be removed "
1137+ "in future versions. Call `pm .compute_log_likelihood(idata) ` instead." ,
11201138 FutureWarning ,
11211139 stacklevel = 2 ,
11221140 )
0 commit comments