diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index f3826939e..a1f7cb5d0 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -15,6 +15,7 @@ from numpyro.diagnostics import print_summary from numpyro.util import ( + _is_under_jax_transform, cached_by, find_stack_level, fori_collect, @@ -447,7 +448,7 @@ def _postprocess_fn(state, args, kwargs): ) fns = sample_fn, postprocess_fn - if key is not None: + if key is not None and not _is_under_jax_transform(): self._cache[key] = fns return fns @@ -539,7 +540,8 @@ def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): kwargs = jax.tree.map(lambda x: _hashable(x), tuple(sorted(kwargs.items()))) key = rng_key + args + kwargs try: - self._init_state_cache[key] = self._last_state + if not _is_under_jax_transform(): + self._init_state_cache[key] = self._last_state # If unhashable arguments are provided, return None except TypeError: pass diff --git a/numpyro/util.py b/numpyro/util.py index f05f541db..1722914a3 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -193,12 +193,28 @@ def identity(x, *args, **kwargs): return x +def _is_under_jax_transform(): + """Check if we are currently under a JAX transform (e.g. jit, vmap). + + When under a transform, caching functions that close over traced values + can cause tracer leaks (see https://github.com/pyro-ppl/numpyro/issues/2000). + """ + from jax._src.core import trace_state_clean + + return not trace_state_clean() + + def cached_by(outer_fn, *keys): # Restrict cache size to prevent ref cycles. max_size = 8 outer_fn._cache = getattr(outer_fn, "_cache", OrderedDict()) def _wrapped(fn): + # Skip caching when inside a JAX tracing context to avoid + # tracer leaks (https://github.com/pyro-ppl/numpyro/issues/2000). + if _is_under_jax_transform(): + return fn + fn_cache = outer_fn._cache hashkeys = (*keys, fn.__name__) if hashkeys in fn_cache: diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 1dd46b67e..39637e23d 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -605,10 +605,6 @@ def model(labels): @pytest.mark.skipif( "CI" in os.environ, reason="Compiling time the whole sampling process is slow." ) -@pytest.mark.xfail( - os.getenv("JAX_CHECK_TRACER_LEAKS") == "1", - reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000", -) def test_chain_inside_jit(kernel_cls, chain_method): # NB: this feature is useful for consensus MC. # Caution: compiling time will be slow (~ 90s) @@ -665,10 +661,6 @@ def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob) @pytest.mark.skipif( "CI" in os.environ, reason="Compiling time the whole sampling process is slow." ) -@pytest.mark.xfail( - os.getenv("JAX_CHECK_TRACER_LEAKS") == "1", - reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000", -) def test_chain_jit_args_smoke(chain_method, compile_args): def model(data): concentration = jnp.array([1.0, 1.0, 1.0]) @@ -782,10 +774,6 @@ def potential_fn(z): @pytest.mark.parametrize("jit_args", [False, True]) @pytest.mark.parametrize("shape", [50, 100]) -@pytest.mark.xfail( - os.getenv("JAX_CHECK_TRACER_LEAKS") == "1", - reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000", -) def test_reuse_mcmc_run(jit_args, shape): y1 = np.random.normal(3, 0.1, (100,)) y2 = np.random.normal(-3, 0.1, (shape,)) @@ -806,10 +794,6 @@ def model(y_obs): @pytest.mark.parametrize("jit_args", [False, True]) -@pytest.mark.xfail( - os.getenv("JAX_CHECK_TRACER_LEAKS") == "1", - reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000", -) def test_model_with_multiple_exec_paths(jit_args): def model(a=None, b=None, z=None): int_term = numpyro.sample("a", dist.Normal(0.0, 0.2)) @@ -839,6 +823,46 @@ def model(a=None, b=None, z=None): assert set(mcmc.get_samples()) == {"a", "x", "y", "sigma"} +def test_mcmc_inside_jit_no_tracer_leak(): + """Regression test for https://github.com/pyro-ppl/numpyro/issues/2000""" + from numpyro.infer.mcmc import _collect_and_postprocess + from numpyro.util import fori_collect + + def model(data): + concentration = jnp.array([1.0, 1.0, 1.0]) + p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration)) + numpyro.sample("obs", dist.Categorical(p_latent), obs=data) + + @jit + def get_samples(rng_key, data): + kernel = HMC( + model, step_size=1.0, trajectory_length=1.0, target_accept_prob=0.8 + ) + mcmc = MCMC( + kernel, + num_warmup=5, + num_samples=10, + num_chains=1, + chain_method="sequential", + progress_bar=False, + ) + mcmc.run(rng_key, data) + return mcmc.get_samples() + + data = dist.Categorical(jnp.array([0.1, 0.6, 0.3])).sample(random.key(1), (100,)) + samples = get_samples(random.key(2), data) + assert "p_latent" in samples + + # Verify no traced values leaked into module-level caches + for cached_fn in [_collect_and_postprocess, fori_collect]: + cache = getattr(cached_fn, "_cache", {}) + leaves = jax.tree.leaves(list(cache.keys()) + list(cache.values())) + for leaf in leaves: + assert not isinstance(leaf, jax.core.Tracer), ( + f"Tracer leaked into {cached_fn.__name__}._cache" + ) + + @pytest.mark.parametrize("num_chains", [1, 2]) @pytest.mark.parametrize("chain_method", ["parallel", "sequential", "vectorized"]) @pytest.mark.parametrize("progress_bar", [True, False])