Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pymc_extras/inference/pathfinder/importance_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import arviz as az
import numpy as np
import xarray as xr

from numpy.typing import NDArray
from scipy.special import logsumexp
Expand Down Expand Up @@ -99,12 +100,11 @@ def importance_sampling(
"ignore", category=RuntimeWarning, message="overflow encountered in exp"
)
match method:
case "psis":
replace = False
logiw, pareto_k = az.psislw(logiw)
case "psir":
replace = True
logiw, pareto_k = az.psislw(logiw)
case "psis" | "psir":
replace = method == "psir"
logiw_da, pareto_k_da = az.psislw(xr.DataArray(logiw, dims="__sample__"))
logiw = logiw_da.values
pareto_k = float(pareto_k_da)
case "identity":
replace = False
pareto_k = None
Expand Down
4 changes: 4 additions & 0 deletions tests/pathfinder/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ def test_pathfinder_importance_sampling(importance_sampling):
assert idata.posterior["tau"].shape == (1, num_draws)
assert idata.posterior["theta"].shape == (1, num_draws, 8)

if importance_sampling in ("psis", "psir"):
pareto_k = idata.sample_stats["pareto_k"].item()
assert isinstance(pareto_k, float) and np.isfinite(pareto_k)


def test_pathfinder_initvals():
# Run a model with an ordered transform that will fail unless initvals are in place
Expand Down
Loading