Skip to content

Commit d50dad1

Browse files
committed
Install nutpie from main
1 parent 074b867 commit d50dad1

3 files changed

Lines changed: 7 additions & 7 deletions

File tree

.github/workflows/tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ jobs:
328328
run: |
329329
pip install "pytensor @ git+https://github.com/pymc-devs/pytensor.git@v3" --no-deps
330330
pip install -e . --no-deps
331+
# Track nutpie main so we exercise unreleased changes. Expected to be
332+
# red until nutpie main drops its `arviz.InferenceData` references.
333+
pip install "nutpie @ git+https://github.com/pymc-devs/nutpie@main"
331334
python --version
332335
micromamba list
333336
- name: Run tests

conda-envs/environment-alternative-backends.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ dependencies:
1111
- cloudpickle
1212
- zarr>=2.5.0,<3
1313
- numba
14-
- nutpie >= 0.15.1
14+
# nutpie is installed from git in the CI workflow (see tests.yml) so we can track
15+
# the arviz 1.0 -compatible branch until an upstream release ships.
1516
# Jaxlib version must not be greater than jax version!
1617
- jax>=0.4.28
1718
- jaxlib>=0.4.28

tests/sampling/test_mcmc_external.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525
from pymc.progress_bar import NutpieProgressBarManager
2626

2727

28-
# temporarily skip nutpie
29-
@pytest.mark.parametrize("nuts_sampler", ["pymc", "blackjax", "numpyro"])
30-
# @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
28+
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
3129
def test_external_nuts_sampler(recwarn, nuts_sampler):
3230
if nuts_sampler != "pymc":
3331
pytest.importorskip(nuts_sampler)
@@ -97,9 +95,7 @@ def test_step_args():
9795
npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
9896

9997

100-
# temporarily skip nutpie
101-
@pytest.mark.parametrize("nuts_sampler", ["pymc", "blackjax", "numpyro"])
102-
# @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
98+
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
10399
def test_sample_var_names(nuts_sampler):
104100
if nuts_sampler != "pymc":
105101
pytest.importorskip(nuts_sampler)

0 commit comments

Comments
 (0)