Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from numpyro.diagnostics import print_summary
from numpyro.util import (
_is_under_jax_transform,
cached_by,
find_stack_level,
fori_collect,
Expand Down Expand Up @@ -447,7 +448,7 @@ def _postprocess_fn(state, args, kwargs):
)

fns = sample_fn, postprocess_fn
if key is not None:
if key is not None and not _is_under_jax_transform():
self._cache[key] = fns
return fns

Expand Down Expand Up @@ -539,7 +540,8 @@ def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
kwargs = jax.tree.map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
key = rng_key + args + kwargs
try:
self._init_state_cache[key] = self._last_state
if not _is_under_jax_transform():
self._init_state_cache[key] = self._last_state
# If unhashable arguments are provided, return None
except TypeError:
pass
Expand Down
16 changes: 16 additions & 0 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,28 @@ def identity(x, *args, **kwargs):
return x


def _is_under_jax_transform():
"""Check if we are currently under a JAX transform (e.g. jit, vmap).

When under a transform, caching functions that close over traced values
can cause tracer leaks (see https://github.com/pyro-ppl/numpyro/issues/2000).
"""
from jax._src.core import trace_state_clean

return not trace_state_clean()


def cached_by(outer_fn, *keys):
# Restrict cache size to prevent ref cycles.
max_size = 8
outer_fn._cache = getattr(outer_fn, "_cache", OrderedDict())

def _wrapped(fn):
# Skip caching when inside a JAX tracing context to avoid
# tracer leaks (https://github.com/pyro-ppl/numpyro/issues/2000).
if _is_under_jax_transform():
return fn

fn_cache = outer_fn._cache
hashkeys = (*keys, fn.__name__)
if hashkeys in fn_cache:
Expand Down
56 changes: 40 additions & 16 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,10 +605,6 @@ def model(labels):
@pytest.mark.skipif(
"CI" in os.environ, reason="Compiling time the whole sampling process is slow."
)
@pytest.mark.xfail(
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000",
)
def test_chain_inside_jit(kernel_cls, chain_method):
# NB: this feature is useful for consensus MC.
# Caution: compiling time will be slow (~ 90s)
Expand Down Expand Up @@ -665,10 +661,6 @@ def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob)
@pytest.mark.skipif(
"CI" in os.environ, reason="Compiling time the whole sampling process is slow."
)
@pytest.mark.xfail(
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000",
)
def test_chain_jit_args_smoke(chain_method, compile_args):
def model(data):
concentration = jnp.array([1.0, 1.0, 1.0])
Expand Down Expand Up @@ -782,10 +774,6 @@ def potential_fn(z):

@pytest.mark.parametrize("jit_args", [False, True])
@pytest.mark.parametrize("shape", [50, 100])
@pytest.mark.xfail(
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000",
)
def test_reuse_mcmc_run(jit_args, shape):
y1 = np.random.normal(3, 0.1, (100,))
y2 = np.random.normal(-3, 0.1, (shape,))
Expand All @@ -806,10 +794,6 @@ def model(y_obs):


@pytest.mark.parametrize("jit_args", [False, True])
@pytest.mark.xfail(
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000",
)
def test_model_with_multiple_exec_paths(jit_args):
def model(a=None, b=None, z=None):
int_term = numpyro.sample("a", dist.Normal(0.0, 0.2))
Expand Down Expand Up @@ -839,6 +823,46 @@ def model(a=None, b=None, z=None):
assert set(mcmc.get_samples()) == {"a", "x", "y", "sigma"}


def test_mcmc_inside_jit_no_tracer_leak():
"""Regression test for https://github.com/pyro-ppl/numpyro/issues/2000"""
from numpyro.infer.mcmc import _collect_and_postprocess
from numpyro.util import fori_collect

def model(data):
concentration = jnp.array([1.0, 1.0, 1.0])
p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
numpyro.sample("obs", dist.Categorical(p_latent), obs=data)

@jit
def get_samples(rng_key, data):
kernel = HMC(
model, step_size=1.0, trajectory_length=1.0, target_accept_prob=0.8
)
mcmc = MCMC(
kernel,
num_warmup=5,
num_samples=10,
num_chains=1,
chain_method="sequential",
progress_bar=False,
)
mcmc.run(rng_key, data)
return mcmc.get_samples()

data = dist.Categorical(jnp.array([0.1, 0.6, 0.3])).sample(random.key(1), (100,))
samples = get_samples(random.key(2), data)
assert "p_latent" in samples

# Verify no traced values leaked into module-level caches
for cached_fn in [_collect_and_postprocess, fori_collect]:
cache = getattr(cached_fn, "_cache", {})
leaves = jax.tree.leaves(list(cache.keys()) + list(cache.values()))
for leaf in leaves:
assert not isinstance(leaf, jax.core.Tracer), (
f"Tracer leaked into {cached_fn.__name__}._cache"
)


@pytest.mark.parametrize("num_chains", [1, 2])
@pytest.mark.parametrize("chain_method", ["parallel", "sequential", "vectorized"])
@pytest.mark.parametrize("progress_bar", [True, False])
Expand Down
Loading