Skip to content

Commit e10f664

Browse files
authored
fix repeated recompilation in no-progress-bar fori_collect path (#2171)
* fix repeated recompilation in no-progress-bar fori_collect path
* fix ruff formatting
* use module-level loop fn instead of cached closure
1 parent 0152376 commit e10f664

2 files changed

Lines changed: 46 additions & 10 deletions

File tree

numpyro/util.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,15 @@ def wrapper_progress_bar(i, vals):
309309
return progress_bar_fori_loop
310310

311311

312+
def _fori_collect_loop(_body_fn, upper, init_val, collection, start_idx, thinning):
    """Drive ``_body_fn`` from index 0 up to ``upper`` with ``fori_loop``.

    The loop carry is the 4-tuple ``(init_val, collection, start_idx,
    thinning)``; on every iteration it is unpacked and passed to ``_body_fn``
    after the loop index, i.e. ``_body_fn(i, init_val, collection, start_idx,
    thinning)``.

    This lives at module level (rather than as a closure rebuilt per call) so
    that the same function object can be handed to ``maybe_jit`` each time.

    :param _body_fn: callable invoked as ``_body_fn(i, *carry)``; whatever it
        returns becomes the carry for the next iteration.
    :param upper: exclusive upper bound of the loop index.
    :param init_val: first element of the initial carry.
    :param collection: second element of the initial carry.
    :param start_idx: third element of the initial carry.
    :param thinning: fourth element of the initial carry.
    :return: the final carry produced by ``fori_loop``.
    """

    def _step(i, carry):
        # Spread the carried tuple into positional arguments for the body.
        return _body_fn(i, *carry)

    initial_carry = (init_val, collection, start_idx, thinning)
    return fori_loop(0, upper, _step, initial_carry)
319+
320+
312321
def fori_collect(
313322
lower: int,
314323
upper: int,
@@ -400,16 +409,11 @@ def map_fn(x):
400409
collection = jax.tree.map(map_fn, init_val_transformed)
401410

402411
if not progbar:
403-
404-
def loop_fn(collection):
405-
return fori_loop(
406-
0,
407-
upper,
408-
lambda i, vals: _body_fn(i, *vals),
409-
(init_val, collection, start_idx, thinning),
410-
)
411-
412-
last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)
412+
last_val, collection, _, _ = maybe_jit(
413+
_fori_collect_loop,
414+
static_argnums=(0, 1),
415+
donate_argnums=3,
416+
)(_body_fn, upper, init_val, collection, start_idx, thinning)
413417

414418
elif num_chains > 1:
415419
progress_bar_fori_loop = progress_bar_factory(upper, num_chains, progress_rate)

test/test_util.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,38 @@ def f(x):
6868
jax.tree.all(jax.tree.map(assert_allclose, tree, expected_tree))
6969

7070

71+
def test_fori_collect_no_recompilation():
    """Call ``fori_collect`` twice with the same body and ``progbar=False``.

    Both calls share the body function and the input shape and differ only in
    the starting value. The collected values of each call are compared with
    the expected running increments.

    NOTE(review): only the outputs are asserted here; the number of
    compilations is not measured by this test.
    """

    def increment(x):
        return x + 1

    num_steps = 10
    # Run both collections first, then verify, so the second call always
    # happens right after the first one.
    from_zero = fori_collect(0, num_steps, increment, jnp.array([0.0]), progbar=False)
    from_five = fori_collect(0, num_steps, increment, jnp.array([5.0]), progbar=False)

    expected_from_zero = np.arange(1, 11).reshape(-1, 1)
    expected_from_five = np.arange(6, 16).reshape(-1, 1)

    assert_allclose(from_zero, expected_from_zero)
    assert_allclose(from_five, expected_from_five)
80+
81+
82+
def test_fori_collect_repeated_mcmc_no_recompilation():
    """Run one ``MCMC`` object twice without a progress bar.

    The same sampler is run with two different PRNG keys. Each run must yield
    10 draws for site ``"x"``, and the two sets of draws must differ.

    NOTE(review): only shapes and inequality of the draws are asserted; the
    number of compilations is not measured by this test.
    """
    from numpyro.infer import MCMC, NUTS

    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    kernel = NUTS(model)
    sampler = MCMC(
        kernel,
        num_warmup=5,
        num_samples=10,
        num_chains=1,
        progress_bar=False,
    )

    # Re-use the same sampler for both seeds, reading the draws after each run.
    draws = []
    for seed in (0, 1):
        sampler.run(random.PRNGKey(seed))
        draws.append(sampler.get_samples()["x"])
    first_draws, second_draws = draws

    assert first_draws.shape == (10,)
    assert second_draws.shape == (10,)
    assert not np.allclose(first_draws, second_draws)
101+
102+
71103
@pytest.mark.parametrize(
72104
"pytree",
73105
[

0 commit comments

Comments
 (0)