Skip to content

Commit 0a2a4a5

Browse files
committed
DO NOT MERGE: pip-install nutpie from main in all CI jobs
Needed until conda-forge ships a nutpie release that supports arviz>=1.0. Until then the conda env can't co-install nutpie and pymc's arviz pin, so we get nutpie via pip-from-git-main in every CI job. Once a compatible nutpie is on conda-forge, drop this commit and put `nutpie>=<that version>` back into `environment-test.yml` and `windows-environment-test.yml`.
1 parent c925cf0 commit 0a2a4a5

3 files changed

Lines changed: 85 additions & 2 deletions

File tree

.github/workflows/tests.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ jobs:
168168
run: |
169169
pip install "pytensor @ git+https://github.com/pymc-devs/pytensor.git@v3" --no-deps
170170
pip install -e . --no-deps
171+
# Track nutpie main to catch any incompatibility with pymc early.
172+
pip install "nutpie @ git+https://github.com/pymc-devs/nutpie@main"
171173
python --version
172174
micromamba list
173175
- name: Run tests
@@ -218,6 +220,8 @@ jobs:
218220
run: |
219221
pip install "pytensor @ git+https://github.com/pymc-devs/pytensor.git@v3" --no-deps
220222
pip install -e . --no-deps
223+
# Track nutpie main to catch any incompatibility with pymc early.
224+
pip install "nutpie @ git+https://github.com/pymc-devs/nutpie@main"
221225
python --version
222226
micromamba list
223227
- name: Run tests
@@ -276,6 +280,8 @@ jobs:
276280
run: |
277281
pip install "pytensor @ git+https://github.com/pymc-devs/pytensor.git@v3" --no-deps
278282
pip install -e . --no-deps
283+
# Track nutpie main to catch any incompatibility with pymc early.
284+
pip install "nutpie @ git+https://github.com/pymc-devs/nutpie@main"
279285
python --version
280286
micromamba list
281287
- name: Run tests
@@ -378,6 +384,8 @@ jobs:
378384
run: |
379385
pip install "pytensor @ git+https://github.com/pymc-devs/pytensor.git@v3" --no-deps
380386
pip install -e . --no-deps
387+
# Track nutpie main to catch any incompatibility with pymc early.
388+
pip install "nutpie @ git+https://github.com/pymc-devs/nutpie@main"
381389
python --version
382390
micromamba list
383391
- name: Run tests

pymc/sampling/mcmc.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -908,11 +908,12 @@ def sample(
908908

909909
if nuts_sampler is None:
910910
# Nutpie must take all the variables and can only compile to Numba or JAX.
911-
# When the user asks for a MultiTrace or provides a pymc-backend trace object
912-
# (e.g. `ZarrTrace`), route to the pymc sampler.
911+
# When the user asks for a MultiTrace, a pymc-only `init`, or provides a
912+
# pymc-backend trace object (e.g. `ZarrTrace`), route to the pymc sampler.
913913
can_use_nutpie = (
914914
exclusive_nuts
915915
and NUTPIE_INSTALLED
916+
and init == "auto"
916917
and return_inferencedata
917918
and trace is None
918919
and callback is None

tests/sampling/test_mcmc_external.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import unittest.mock as mock
16+
import warnings
1617

1718
from types import SimpleNamespace
1819

@@ -214,3 +215,76 @@ def test_nutpie_end_to_end():
214215
assert {"posterior", "sample_stats", "observed_data"} <= set(idata.children)
215216
assert set(idata.posterior.data_vars) == {"mu", "sigma"}
216217
assert idata.posterior.sizes == {"chain": 2, "draw": 20}
218+
219+
220+
class TestExternalSamplerKwargCompat:
221+
"""Validate how `pm.sample` handles kwargs that external samplers don't fully honor."""
222+
223+
@pytest.fixture
224+
def model(self):
225+
with Model() as m:
226+
Normal("x")
227+
return m
228+
229+
@pytest.fixture
230+
def patched_sampler(self):
231+
with (
232+
mock.patch("pymc.sampling.mcmc.NUTPIE_INSTALLED", True),
233+
mock.patch("pymc.sampling.mcmc._sample_external_nuts") as mock_ext,
234+
):
235+
yield mock_ext
236+
237+
_BASE_KWARGS = {
238+
"tune": 2,
239+
"draws": 2,
240+
"chains": 1,
241+
"progressbar": False,
242+
"compile_kwargs": {"mode": "NUMBA"},
243+
}
244+
245+
@pytest.mark.parametrize(
246+
"extra",
247+
[
248+
{"return_inferencedata": False},
249+
{"trace": object()},
250+
{"callback": lambda **kw: None},
251+
],
252+
ids=["return_inferencedata_false", "custom_trace", "callback"],
253+
)
254+
def test_explicit_nutpie_raises_on_incompatible(self, model, patched_sampler, extra):
255+
# Filter the separate FutureWarning for return_inferencedata=False; we only
256+
# care that pm.sample raises the external-sampler ValueError.
257+
with warnings.catch_warnings():
258+
warnings.filterwarnings("ignore", ".*return_inferencedata=False.*", FutureWarning)
259+
with model, pytest.raises(ValueError, match="`nuts_sampler='nutpie'`"):
260+
sample(nuts_sampler="nutpie", **self._BASE_KWARGS, **extra)
261+
patched_sampler.assert_not_called()
262+
263+
def test_explicit_nutpie_warns_on_non_default_init(self, model, patched_sampler):
264+
with model, pytest.warns(UserWarning, match="`init='advi'` is ignored"):
265+
sample(nuts_sampler="nutpie", init="advi", **self._BASE_KWARGS)
266+
patched_sampler.assert_called_once()
267+
268+
@pytest.mark.parametrize(
269+
"extra",
270+
[
271+
{"return_inferencedata": False},
272+
{"trace": object()},
273+
{"callback": lambda **kw: None},
274+
{"init": "advi"},
275+
],
276+
ids=["return_inferencedata_false", "custom_trace", "callback", "non_default_init"],
277+
)
278+
def test_auto_selection_falls_back_to_pymc(self, model, patched_sampler, extra):
279+
# Each of these kwargs disqualifies nutpie auto-selection; pm.sample should
280+
# route to the pymc sampler (external stub never called) instead of raising.
281+
with warnings.catch_warnings():
282+
warnings.filterwarnings("ignore", ".*return_inferencedata=False.*", FutureWarning)
283+
with model:
284+
try:
285+
sample(**self._BASE_KWARGS, **extra)
286+
except Exception:
287+
# The pymc path may error on the stand-in objects (dummy trace, etc.);
288+
# we only care that nutpie wasn't dispatched.
289+
pass
290+
patched_sampler.assert_not_called()

0 commit comments

Comments
 (0)