Skip to content

Commit df97bec

Browse files
committed
Install nutpie from main
1 parent ddc17e3 commit df97bec

3 files changed

Lines changed: 6 additions & 4 deletions

File tree

.github/workflows/tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ jobs:
328328
run: |
329329
pip install "pytensor @ git+https://github.com/pymc-devs/pytensor.git@main" --no-deps
330330
pip install -e . --no-deps
331+
# Track nutpie main so we exercise unreleased changes. Expected to be
332+
# red until nutpie main drops its `arviz.InferenceData` references.
333+
pip install "nutpie @ git+https://github.com/pymc-devs/nutpie@main"
331334
python --version
332335
micromamba list
333336
- name: Run tests

conda-envs/environment-alternative-backends.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ dependencies:
1111
- cloudpickle
1212
- zarr>=2.5.0,<3
1313
- numba
14-
- nutpie >= 0.15.1
14+
# nutpie is installed from git in the CI workflow (see tests.yml) so we can track
15+
# the arviz 1.0 -compatible branch until an upstream release ships.
1516
# Jaxlib version must not be greater than jax version!
1617
- jax>=0.4.28
1718
- jaxlib>=0.4.28

tests/sampling/test_mcmc_external.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,7 @@ def test_step_args():
100100
npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
101101

102102

103-
# temporarily skip nutpie
104-
@pytest.mark.parametrize("nuts_sampler", ["pymc", "blackjax", "numpyro"])
105-
# @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
103+
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
106104
def test_sample_var_names(nuts_sampler):
107105
if nuts_sampler != "pymc":
108106
pytest.importorskip(nuts_sampler)

0 commit comments

Comments
 (0)