File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11# Copyright Contributors to the Pyro project.
22# SPDX-License-Identifier: Apache-2.0
33
4+ from .test_mcmc import * # noqa F401
45from .test_svi import * # noqa F401
Original file line number Diff line number Diff line change 1+ # Copyright Contributors to the Pyro project.
2+ # SPDX-License-Identifier: Apache-2.0
3+
4+ from pyroapi .dispatch import distributions as dist
5+ from pyroapi .dispatch import infer , pyro
6+
7+ # Note that the backend arg to these tests must be provided as a
8+ # user-defined fixture that sets the pyro_backend. For demonstration,
9+ # see test/conftest.py.
10+
11+
12+ def assert_ok (model , * args , ** kwargs ):
13+ """
14+ Assert that inference works without warnings or errors.
15+ """
16+ pyro .get_param_store ().clear ()
17+ kernel = infer .NUTS (model )
18+ mcmc = infer .MCMC (kernel , num_samples = 2 , warmup_steps = 2 )
19+ mcmc .run (* args , ** kwargs )
20+
21+
22+ def test_mcmc_run_ok (backend ):
23+ if backend not in ["pyro" , "numpy" ]:
24+ return
25+
26+ def model ():
27+ pyro .sample ("x" , dist .Normal (0 , 1 ))
28+
29+ assert_ok (model )
Original file line number Diff line number Diff line change @@ -28,3 +28,17 @@ def backend(request):
2828 pytest .importorskip (PACKAGE_NAME [request .param ])
2929 with pyro_backend (request .param ):
3030 yield
31+
32+
33+ # TODO(fehiepsi): Remove the following when the test passes in numpyro.
34+ _test_mcmc_run_ok = test_mcmc_run_ok # noqa F405
35+
36+
37+ @pytest .mark .parametrize ("backend" , [
38+ "pyro" ,
39+ pytest .param ("numpy" , marks = [
40+ pytest .mark .xfail (reason = "numpyro signature for MCMC is not consistent." )])])
41+ def test_mcmc_run_ok (backend ):
42+ pytest .importorskip (PACKAGE_NAME [backend ])
43+ with pyro_backend (backend ):
44+ _test_mcmc_run_ok (backend )
You can’t perform that action at this time.
0 commit comments