Skip to content

Commit a92ea2a

Browse files
committed
use module-level loop fn instead of cached closure
1 parent 9539b25 commit a92ea2a

2 files changed

Lines changed: 15 additions & 33 deletions

File tree

numpyro/util.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,15 @@ def wrapper_progress_bar(i, vals):
309309
return progress_bar_fori_loop
310310

311311

312+
def _fori_collect_loop(_body_fn, upper, init_val, collection, start_idx, thinning):
313+
return fori_loop(
314+
0,
315+
upper,
316+
lambda i, vals: _body_fn(i, *vals),
317+
(init_val, collection, start_idx, thinning),
318+
)
319+
320+
312321
def fori_collect(
313322
lower: int,
314323
upper: int,
@@ -400,20 +409,11 @@ def map_fn(x):
400409
collection = jax.tree.map(map_fn, init_val_transformed)
401410

402411
if not progbar:
403-
# Cache loop_fn so jit() reuses the compiled trace across calls.
404-
# Without this, loop_fn is a fresh closure each call and jit recompiles.
405-
@cached_by(fori_collect, body_fun, transform, upper, start_idx, thinning)
406-
def loop_fn(init_val, collection):
407-
return fori_loop(
408-
0,
409-
upper,
410-
lambda i, vals: _body_fn(i, *vals),
411-
(init_val, collection, start_idx, thinning),
412-
)
413-
414-
last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=1)(
415-
init_val, collection
416-
)
412+
last_val, collection, _, _ = maybe_jit(
413+
_fori_collect_loop,
414+
static_argnums=(0, 1),
415+
donate_argnums=3,
416+
)(_body_fn, upper, init_val, collection, start_idx, thinning)
417417

418418
elif num_chains > 1:
419419
progress_bar_fori_loop = progress_bar_factory(upper, num_chains, progress_rate)

test/test_util.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,35 +69,17 @@ def f(x):
6969

7070

7171
def test_fori_collect_no_recompilation():
72-
"""Regression test: repeated fori_collect(..., progbar=False) must reuse
73-
the cached loop_fn and not trigger JIT recompilation.
74-
75-
Before the fix, loop_fn was a fresh closure each call, causing full
76-
XLA recompilation on every MCMC.run().
77-
"""
78-
7972
def f(x):
8073
return x + 1
8174

82-
init_val = jnp.array([0.0])
83-
84-
# First call
85-
result1 = fori_collect(0, 10, f, init_val, progbar=False)
86-
87-
# Second call with same config but different init_val
75+
result1 = fori_collect(0, 10, f, jnp.array([0.0]), progbar=False)
8876
result2 = fori_collect(0, 10, f, jnp.array([5.0]), progbar=False)
8977

90-
# Results must be correct (not stale from cached closure)
9178
assert_allclose(result1, np.arange(1, 11).reshape(-1, 1))
9279
assert_allclose(result2, np.arange(6, 16).reshape(-1, 1))
9380

94-
# Verify the cache exists and has entries
95-
assert hasattr(fori_collect, "_cache")
96-
assert len(fori_collect._cache) > 0
97-
9881

9982
def test_fori_collect_repeated_mcmc_no_recompilation():
100-
"""End-to-end regression coverage through repeated MCMC.run()."""
10183
from numpyro.infer import MCMC, NUTS
10284

10385
def model():

0 commit comments

Comments
 (0)