Predictive.exclude_deterministic does not filter deterministic sites #2086

Description

@brendancooley

Bug Description

The exclude_deterministic argument of the Predictive helper class does not exclude deterministic sites from the resulting posterior predictive dictionary.

Steps to Reproduce

uv run test_exclude_deterministic.py

where test_exclude_deterministic.py contains:

```python
#!/usr/bin/env -S uv run
# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "jax",
#     "jaxlib",
#     "numpyro",
# ]
# ///
"""
Test to verify if exclude_deterministic works in numpyro.infer.util.Predictive
"""

import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive


def model(x=None, y=None):
    """Simple model with both stochastic and deterministic sites"""
    # Stochastic sites
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))

    # Deterministic site
    linear_combination = numpyro.deterministic("linear_combination", a + 2 * b)

    # Another deterministic site
    _ = numpyro.deterministic("squared", a**2)

    # Likelihood
    if y is not None:
        sigma = numpyro.sample("sigma", dist.Exponential(1.0))
        numpyro.sample("obs", dist.Normal(linear_combination, sigma), obs=y)


# Generate some synthetic data
key = jax.random.PRNGKey(0)
n_samples = 100
x_data = jnp.linspace(0, 1, n_samples)
true_a = 1.5
true_b = 2.0
y_data = true_a + 2 * true_b + jax.random.normal(key, (n_samples,)) * 0.5

# Run MCMC
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(key, y=y_data)
samples = mcmc.get_samples()

# Test Predictive with exclude_deterministic=True
print("\nTest 1: Predictive with exclude_deterministic=True")
print("-" * 70)
predictive_exclude = Predictive(model, samples, exclude_deterministic=True)
predictions_exclude = predictive_exclude(key)
print("Keys with exclude_deterministic=True:", predictions_exclude.keys())
print("Expected: Should NOT include 'linear_combination' or 'squared'")

# Test Predictive with exclude_deterministic=False (default)
print("\nTest 2: Predictive with exclude_deterministic=False")
print("-" * 70)
predictive_include = Predictive(model, samples, exclude_deterministic=False)
predictions_include = predictive_include(key)
print("Keys with exclude_deterministic=False:", predictions_include.keys())
print("Expected: SHOULD include 'linear_combination' and 'squared'")
```

Expected Behavior

The "squared" and "linear_combination" sites should be present only in predictions_include, the dictionary returned by predictive_include(key), and not in predictions_exclude.

Notes

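Obviously, the correct fix would also be a breaking change, unless I'm missing something: code that currently reads deterministic sites from the output of Predictive(..., exclude_deterministic=True) would stop finding them. I'm happy to work on a fix, but perhaps there's some context I'm not aware of.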

Metadata

Assignees: no one assigned
Labels: bug (Something isn't working)
Type: no type
Fields: none configured for issues without a type
Projects: none
Milestone: none
Relationships: none yet
Development: no branches or pull requests