|
15 | 15 | from numpyro.distributions.transforms import AffineTransform, ExpTransform |
16 | 16 | import numpyro.handlers as handlers |
17 | 17 | from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO |
18 | | -from numpyro.infer.autoguide import AutoIAFNormal |
| 18 | +from numpyro.infer.autoguide import AutoDiagonalNormal, AutoIAFNormal |
19 | 19 | from numpyro.infer.reparam import ( |
20 | 20 | CircularReparam, |
21 | 21 | ExplicitReparam, |
@@ -228,6 +228,22 @@ def test_neutra_reparam_unobserved_model(): |
228 | 228 | reparam_model(data=None) |
229 | 229 |
|
230 | 230 |
|
| 231 | +def test_neutra_reparam_with_plate(): |
| 232 | + def model(): |
| 233 | + with numpyro.plate("N", 3, dim=-1): |
| 234 | + x = numpyro.sample("x", dist.Normal(0, 1)) |
| 235 | + assert x.shape == (3,) |
| 236 | + |
| 237 | + guide = AutoDiagonalNormal(model) |
| 238 | + svi = SVI(model, guide, Adam(1e-3), Trace_ELBO()) |
| 239 | + svi_state = svi.init(random.PRNGKey(0)) |
| 240 | + params = svi.get_params(svi_state) |
| 241 | + neutra = NeuTraReparam(guide, params) |
| 242 | + reparam_model = neutra.reparam(model) |
| 243 | + with handlers.seed(rng_seed=0): |
| 244 | + reparam_model() |
| 245 | + |
| 246 | + |
231 | 247 | @pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str) |
232 | 248 | @pytest.mark.parametrize("centered", [0.0, 0.6, 1.0, None]) |
233 | 249 | @pytest.mark.parametrize("dist_type", ["Normal", "StudentT"]) |
|
0 commit comments