Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,15 @@ def wrapper_progress_bar(i, vals):
return progress_bar_fori_loop


def _fori_collect_loop(_body_fn, upper, init_val, collection, start_idx, thinning):
return fori_loop(
0,
upper,
lambda i, vals: _body_fn(i, *vals),
(init_val, collection, start_idx, thinning),
)


def fori_collect(
lower: int,
upper: int,
Expand Down Expand Up @@ -400,16 +409,11 @@ def map_fn(x):
collection = jax.tree.map(map_fn, init_val_transformed)

if not progbar:

def loop_fn(collection):
return fori_loop(
0,
upper,
lambda i, vals: _body_fn(i, *vals),
(init_val, collection, start_idx, thinning),
)

last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)
last_val, collection, _, _ = maybe_jit(
_fori_collect_loop,
static_argnums=(0, 1),
donate_argnums=3,
)(_body_fn, upper, init_val, collection, start_idx, thinning)

elif num_chains > 1:
progress_bar_fori_loop = progress_bar_factory(upper, num_chains, progress_rate)
Expand Down
32 changes: 32 additions & 0 deletions test/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,38 @@ def f(x):
jax.tree.all(jax.tree.map(assert_allclose, tree, expected_tree))


def test_fori_collect_no_recompilation():
def f(x):
return x + 1

result1 = fori_collect(0, 10, f, jnp.array([0.0]), progbar=False)
result2 = fori_collect(0, 10, f, jnp.array([5.0]), progbar=False)

assert_allclose(result1, np.arange(1, 11).reshape(-1, 1))
assert_allclose(result2, np.arange(6, 16).reshape(-1, 1))


def test_fori_collect_repeated_mcmc_no_recompilation():
from numpyro.infer import MCMC, NUTS

def model():
numpyro.sample("x", dist.Normal(0, 1))

mcmc = MCMC(
NUTS(model), num_warmup=5, num_samples=10, num_chains=1, progress_bar=False
)

mcmc.run(random.PRNGKey(0))
samples1 = mcmc.get_samples()["x"]

mcmc.run(random.PRNGKey(1))
samples2 = mcmc.get_samples()["x"]

assert samples1.shape == (10,)
assert samples2.shape == (10,)
assert not np.allclose(samples1, samples2)


@pytest.mark.parametrize(
"pytree",
[
Expand Down
Loading