Skip to content

Commit 04b2e88

Browse files
authored
Add a simple mcmc test (#22)
1 parent 1984c5d commit 04b2e88

3 files changed

Lines changed: 44 additions & 0 deletions

File tree

pyroapi/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright Contributors to the Pyro project.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from .test_mcmc import * # noqa F401
45
from .test_svi import * # noqa F401

pyroapi/tests/test_mcmc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright Contributors to the Pyro project.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from pyroapi.dispatch import distributions as dist
5+
from pyroapi.dispatch import infer, pyro
6+
7+
# Note that the backend arg to these tests must be provided as a
8+
# user-defined fixture that sets the pyro_backend. For demonstration,
9+
# see test/conftest.py.
10+
11+
12+
def assert_ok(model, *args, **kwargs):
13+
"""
14+
Assert that inference works without warnings or errors.
15+
"""
16+
pyro.get_param_store().clear()
17+
kernel = infer.NUTS(model)
18+
mcmc = infer.MCMC(kernel, num_samples=2, warmup_steps=2)
19+
mcmc.run(*args, **kwargs)
20+
21+
22+
def test_mcmc_run_ok(backend):
23+
if backend not in ["pyro", "numpy"]:
24+
return
25+
26+
def model():
27+
pyro.sample("x", dist.Normal(0, 1))
28+
29+
assert_ok(model)

test/test_tests.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,17 @@ def backend(request):
2828
pytest.importorskip(PACKAGE_NAME[request.param])
2929
with pyro_backend(request.param):
3030
yield
31+
32+
33+
# TODO(fehiepsi): Remove the following when the test passes in numpyro.
34+
_test_mcmc_run_ok = test_mcmc_run_ok # noqa F405
35+
36+
37+
@pytest.mark.parametrize("backend", [
38+
"pyro",
39+
pytest.param("numpy", marks=[
40+
pytest.mark.xfail(reason="numpyro signature for MCMC is not consistent.")])])
41+
def test_mcmc_run_ok(backend):
42+
pytest.importorskip(PACKAGE_NAME[backend])
43+
with pyro_backend(backend):
44+
_test_mcmc_run_ok(backend)

0 commit comments

Comments
 (0)