From f81c63775801a7304b7bab62812101abeb5d4275 Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Thu, 8 Jul 2021 07:34:10 -0400 Subject: [PATCH 1/4] support infer_discrete for Predictive --- examples/annotation.py | 23 +++--------- numpyro/contrib/funsor/discrete.py | 15 +++++--- numpyro/infer/util.py | 59 ++++++++++++++++++++++++++---- 3 files changed, 67 insertions(+), 30 deletions(-) diff --git a/examples/annotation.py b/examples/annotation.py index 6b7ada33e..863354946 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -42,10 +42,10 @@ import numpyro from numpyro import handlers -from numpyro.contrib.funsor import config_enumerate, infer_discrete +from numpyro.contrib.funsor import config_enumerate from numpyro.contrib.indexing import Vindex import numpyro.distributions as dist -from numpyro.infer import MCMC, NUTS +from numpyro.infer import MCMC, NUTS, Predictive from numpyro.infer.reparam import LocScaleReparam @@ -313,24 +313,11 @@ def main(args): mcmc.run(random.PRNGKey(0), *data) mcmc.print_summary() - def infer_discrete_model(rng_key, samples): - conditioned_model = handlers.condition(model, data=samples) - infer_discrete_model = infer_discrete( - config_enumerate(conditioned_model), rng_key=rng_key - ) - with handlers.trace() as tr: - infer_discrete_model(*data) - - return { - name: site["value"] - for name, site in tr.items() - if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel" - } - posterior_samples = mcmc.get_samples() - discrete_samples = vmap(infer_discrete_model)( - random.split(random.PRNGKey(1), args.num_samples), posterior_samples + predictive = Predictive( + config_enumerate(model), posterior_samples, infer_discrete=True ) + discrete_samples = predictive(random.PRNGKey(1), *data) item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)( discrete_samples["c"].squeeze(-1) diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py index 72767ff84..36462e28a 100644 --- a/numpyro/contrib/funsor/discrete.py +++ b/numpyro/contrib/funsor/discrete.py @@ -118,8 +118,7 @@ def _sample_posterior( values = [v.reshape((-1,) + prototype_shape[1:]) for v in values] data[root_name] = jnp.concatenate(values) - with substitute(data=data): - return model(*args, **kwargs) + return data def infer_discrete(fn=None, first_available_dim=None, temperature=1, rng_key=None): @@ -169,6 +168,12 @@ def viterbi_decoder(data, hidden_dim=10): temperature=temperature, rng_key=rng_key, ) - return functools.partial( - _sample_posterior, fn, first_available_dim, temperature, rng_key - ) + + def wrap_fn(*args, **kwargs): + samples = _sample_posterior( + fn, first_available_dim, temperature, rng_key, *args, **kwargs + ) + with substitute(data=samples): + return fn(*args, **kwargs) + + return wrap_fn diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index b4aee6ad5..144d19f22 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -11,6 +11,7 @@ from jax import device_get, jacfwd, lax, random, value_and_grad from jax.flatten_util import ravel_pytree import jax.numpy as jnp +from jax.tree_util import tree_map import numpyro from numpyro.distributions import constraints @@ -673,17 +674,45 @@ def _predictive( posterior_samples, batch_shape, return_sites=None, + infer_discrete=False, + temperature=1, parallel=True, model_args=(), model_kwargs={}, ): - model = numpyro.handlers.mask(model, mask=False) + masked_model = numpyro.handlers.mask(model, mask=False) + if infer_discrete: + # inspect the model to get some structure + rng_key, subkey = random.split(rng_key) + batch_ndim = len(batch_shape) + prototype_sample = tree_map( + lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0], + posterior_samples, + ) + model_trace = trace( + seed(substitute(masked_model, prototype_sample), subkey) + ).get_trace(*model_args, **model_kwargs) + first_available_dim = -_guess_max_plate_nesting(model_trace) - 1 def single_prediction(val): rng_key, samples = val - model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace( - *model_args, **model_kwargs - ) + if infer_discrete: + from numpyro.contrib.funsor.discrete import _sample_posterior + + pred_samples = _sample_posterior( + substitute(model, samples), + first_available_dim, + temperature, + subkey, + *model_args, + **model_kwargs, + ) + else: + model_trace = trace( + seed(substitute(masked_model, samples), rng_key) + ).get_trace(*model_args, **model_kwargs) + pred_samples = {name: site["value"] for name, site in model_trace.items()} + if return_sites is not None: if return_sites == "": sites = { @@ -698,9 +727,7 @@ def single_prediction(val): if (site["type"] == "sample" and k not in samples) or (site["type"] == "deterministic") } - return { - name: site["value"] for name, site in model_trace.items() if name in sites - } + return {name: value for name, value in pred_samples.items() if name in sites} num_samples = int(np.prod(batch_shape)) if num_samples > 1: @@ -729,6 +756,17 @@ class Predictive(object): :param int num_samples: number of samples :param list return_sites: sites to return; by default only sample sites not present in `posterior_samples` are returned. + :param bool infer_discrete: whether or not to sample discrete sites marked with + ``site["infer"]["enumerate"] = "parallel"`` from the posterior, + conditioned on observations. Default to False. + :param int temperature: This argument controls the behavior of + sampling discrete latent sites marked with + ``site["infer"]["enumerate"] = "parallel"`` from the posterior, + conditioned on observations. If not None, this can be set to either 1 + (sample via forward-filter backward-sample) or 0 (optimize via Viterbi-like + MAP inference). By default, this is None, which means the samples will + be drawn by running the model (conditioned on posterior samples) forward. + :type infer_discrete_temperature: int or None :param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`. Defaults to False. :param batch_ndims: the number of batch dimensions in posterior samples. Some usages: @@ -749,10 +787,13 @@ def __init__( self, model, posterior_samples=None, + *, guide=None, params=None, num_samples=None, return_sites=None, + infer_discrete=False, + temperature=1, parallel=False, batch_ndims=1, ): @@ -801,6 +842,8 @@ def __init__( self.num_samples = num_samples self.guide = guide self.params = {} if params is None else params + self.infer_discrete = infer_discrete + self.temperature = temperature self.return_sites = return_sites self.parallel = parallel self.batch_ndims = batch_ndims @@ -838,6 +881,8 @@ def __call__(self, rng_key, *args, **kwargs): posterior_samples, self._batch_shape, return_sites=self.return_sites, + infer_discrete=self.infer_discrete, + temperature=self.temperature, parallel=self.parallel, model_args=args, model_kwargs=kwargs, From 4521e94ea32b203fdb56445ece6d03d26ae4afce Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Thu, 8 Jul 2021 08:00:56 -0400 Subject: [PATCH 2/4] revise docs --- examples/annotation.py | 5 +---- numpyro/infer/util.py | 32 ++++++++++++++++---------------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/examples/annotation.py b/examples/annotation.py index 863354946..264d14ac7 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -42,7 +42,6 @@ import numpyro from numpyro import handlers -from numpyro.contrib.funsor import config_enumerate from numpyro.contrib.indexing import Vindex import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS, Predictive @@ -314,9 +313,7 @@ def main(args): mcmc.print_summary() posterior_samples = mcmc.get_samples() - predictive = Predictive( - config_enumerate(model), posterior_samples, infer_discrete=True - ) + predictive = Predictive(model, posterior_samples, infer_discrete=True) discrete_samples = predictive(random.PRNGKey(1), *data) item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)( diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 144d19f22..815eb53de 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -17,7 +17,7 @@ from numpyro.distributions import constraints from numpyro.distributions.transforms import biject_to from numpyro.distributions.util import is_identically_one, sum_rightmost -from numpyro.handlers import replay, seed, substitute, trace +from numpyro.handlers import condition, replay, seed, substitute, trace from numpyro.infer.initialization import init_to_uniform, init_to_value from numpyro.util import not_jax_tracer, soft_vmap, while_loop @@ -689,21 +689,23 @@ def _predictive( lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0], posterior_samples, ) - model_trace = trace( + prototype_trace = trace( seed(substitute(masked_model, prototype_sample), subkey) ).get_trace(*model_args, **model_kwargs) - first_available_dim = -_guess_max_plate_nesting(model_trace) - 1 + first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1 def single_prediction(val): rng_key, samples = val if infer_discrete: + from numpyro.contrib.funsor import config_enumerate from numpyro.contrib.funsor.discrete import _sample_posterior + model_trace = prototype_trace pred_samples = _sample_posterior( - substitute(model, samples), + config_enumerate(condition(model, samples)), first_available_dim, temperature, - subkey, + rng_key, *model_args, **model_kwargs, ) @@ -756,17 +758,15 @@ class Predictive(object): :param int num_samples: number of samples :param list return_sites: sites to return; by default only sample sites not present in `posterior_samples` are returned. - :param bool infer_discrete: whether or not to sample discrete sites marked with - ``site["infer"]["enumerate"] = "parallel"`` from the posterior, - conditioned on observations. Default to False. - :param int temperature: This argument controls the behavior of - sampling discrete latent sites marked with - ``site["infer"]["enumerate"] = "parallel"`` from the posterior, - conditioned on observations. If not None, this can be set to either 1 - (sample via forward-filter backward-sample) or 0 (optimize via Viterbi-like - MAP inference). By default, this is None, which means the samples will - be drawn by running the model (conditioned on posterior samples) forward. - :type infer_discrete_temperature: int or None + :param bool infer_discrete: whether or not to sample discrete sites from the + posterior, conditioned on observations and other latent values in + ``posterior_samples``. Under the hood, those sites will be marked with + ``site["infer"]["enumerate"] = "parallel"``. See how `infer_discrete` works at + the `Pyro enumeration tutorial `_. + This feature requires ``funsor`` installation. + :param int temperature: Either 1 (sample via forward-filter backward-sample) + or 0 (optimize via Viterbi-like MAP inference). Defaults to 1 (sample). + This argument only takes effect when ``infer_discrete=True``. :param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`. Defaults to False. :param batch_ndims: the number of batch dimensions in posterior samples. Some usages: From 0d9e42ed6205fd3cf5658b581454f8b958262a7d Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Thu, 8 Jul 2021 22:28:31 -0400 Subject: [PATCH 3/4] use infer_discrete_temperature --- examples/annotation.py | 2 +- numpyro/infer/util.py | 28 ++++++++++++---------------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/examples/annotation.py b/examples/annotation.py index 264d14ac7..77eb2c750 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -313,7 +313,7 @@ def main(args): mcmc.print_summary() posterior_samples = mcmc.get_samples() - predictive = Predictive(model, posterior_samples, infer_discrete=True) + predictive = Predictive(model, posterior_samples, infer_discrete_temperature=1) discrete_samples = predictive(random.PRNGKey(1), *data) item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)( diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 815eb53de..793e6edd1 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -674,14 +674,13 @@ def _predictive( posterior_samples, batch_shape, return_sites=None, - infer_discrete=False, - temperature=1, + infer_discrete_temperature=None, parallel=True, model_args=(), model_kwargs={}, ): masked_model = numpyro.handlers.mask(model, mask=False) - if infer_discrete: + if infer_discrete_temperature is not None: # inspect the model to get some structure rng_key, subkey = random.split(rng_key) batch_ndim = len(batch_shape) @@ -696,7 +695,7 @@ def _predictive( def single_prediction(val): rng_key, samples = val - if infer_discrete: + if infer_discrete_temperature is not None: from numpyro.contrib.funsor import config_enumerate from numpyro.contrib.funsor.discrete import _sample_posterior @@ -704,7 +703,7 @@ def single_prediction(val): pred_samples = _sample_posterior( config_enumerate(condition(model, samples)), first_available_dim, - temperature, + infer_discrete_temperature, rng_key, *model_args, **model_kwargs, @@ -758,15 +757,15 @@ class Predictive(object): :param int num_samples: number of samples :param list return_sites: sites to return; by default only sample sites not present in `posterior_samples` are returned. - :param bool infer_discrete: whether or not to sample discrete sites from the + :param infer_discrete_temperature: if not None, we'll sample discrete sites from the posterior, conditioned on observations and other latent values in ``posterior_samples``. Under the hood, those sites will be marked with ``site["infer"]["enumerate"] = "parallel"``. See how `infer_discrete` works at the `Pyro enumeration tutorial `_. - This feature requires ``funsor`` installation. - :param int temperature: Either 1 (sample via forward-filter backward-sample) - or 0 (optimize via Viterbi-like MAP inference). Defaults to 1 (sample). - This argument only takes effect when ``infer_discrete=True``. + The temperature value is either 1 (sample via forward-filter backward-sample) + or 0 (optimize via Viterbi-like MAP inference). + Note that this requires ``funsor`` installation. + :type infer_discrete_temperature: None or int :param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`. Defaults to False. :param batch_ndims: the number of batch dimensions in posterior samples. Some usages: @@ -792,8 +791,7 @@ def __init__( params=None, num_samples=None, return_sites=None, - infer_discrete=False, - temperature=1, + infer_discrete_temperature=None, parallel=False, batch_ndims=1, ): @@ -842,8 +840,7 @@ def __init__( self.num_samples = num_samples self.guide = guide self.params = {} if params is None else params - self.infer_discrete = infer_discrete - self.temperature = temperature + self.infer_discrete_temperature = infer_discrete_temperature self.return_sites = return_sites self.parallel = parallel self.batch_ndims = batch_ndims @@ -881,8 +878,7 @@ def __call__(self, rng_key, *args, **kwargs): posterior_samples, self._batch_shape, return_sites=self.return_sites, - infer_discrete=self.infer_discrete, - temperature=self.temperature, + infer_discrete_temperature=self.infer_discrete_temperature, parallel=self.parallel, model_args=args, model_kwargs=kwargs, From 93414d6110b4f9d4eed94c2f796e631b0f424495 Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Fri, 9 Jul 2021 17:16:36 -0400 Subject: [PATCH 4/4] use temperature=1 by default --- examples/annotation.py | 2 +- numpyro/infer/util.py | 20 +++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/examples/annotation.py b/examples/annotation.py index 77eb2c750..264d14ac7 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -313,7 +313,7 @@ def main(args): mcmc.print_summary() posterior_samples = mcmc.get_samples() - predictive = Predictive(model, posterior_samples, infer_discrete_temperature=1) + predictive = Predictive(model, posterior_samples, infer_discrete=True) discrete_samples = predictive(random.PRNGKey(1), *data) item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)( diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 793e6edd1..d3c41187a 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -674,13 +674,13 @@ def _predictive( posterior_samples, batch_shape, return_sites=None, - infer_discrete_temperature=None, + infer_discrete=False, parallel=True, model_args=(), model_kwargs={}, ): masked_model = numpyro.handlers.mask(model, mask=False) - if infer_discrete_temperature is not None: + if infer_discrete: # inspect the model to get some structure rng_key, subkey = random.split(rng_key) batch_ndim = len(batch_shape) @@ -695,15 +695,16 @@ def _predictive( def single_prediction(val): rng_key, samples = val - if infer_discrete_temperature is not None: + if infer_discrete: from numpyro.contrib.funsor import config_enumerate from numpyro.contrib.funsor.discrete import _sample_posterior model_trace = prototype_trace + temperature = 1 pred_samples = _sample_posterior( config_enumerate(condition(model, samples)), first_available_dim, - infer_discrete_temperature, + temperature, rng_key, *model_args, **model_kwargs, @@ -757,15 +758,12 @@ class Predictive(object): :param int num_samples: number of samples :param list return_sites: sites to return; by default only sample sites not present in `posterior_samples` are returned. - :param infer_discrete_temperature: if not None, we'll sample discrete sites from the + :param bool infer_discrete: whether or not to sample discrete sites from the posterior, conditioned on observations and other latent values in ``posterior_samples``. Under the hood, those sites will be marked with ``site["infer"]["enumerate"] = "parallel"``. See how `infer_discrete` works at the `Pyro enumeration tutorial `_. - The temperature value is either 1 (sample via forward-filter backward-sample) - or 0 (optimize via Viterbi-like MAP inference). Note that this requires ``funsor`` installation. - :type infer_discrete_temperature: None or int :param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`. Defaults to False. :param batch_ndims: the number of batch dimensions in posterior samples. Some usages: @@ -791,7 +789,7 @@ def __init__( params=None, num_samples=None, return_sites=None, - infer_discrete_temperature=None, + infer_discrete=False, parallel=False, batch_ndims=1, ): @@ -840,7 +838,7 @@ def __init__( self.num_samples = num_samples self.guide = guide self.params = {} if params is None else params - self.infer_discrete_temperature = infer_discrete_temperature + self.infer_discrete = infer_discrete self.return_sites = return_sites self.parallel = parallel self.batch_ndims = batch_ndims @@ -878,7 +876,7 @@ def __call__(self, rng_key, *args, **kwargs): posterior_samples, self._batch_shape, return_sites=self.return_sites, - infer_discrete_temperature=self.infer_discrete_temperature, + infer_discrete=self.infer_discrete, parallel=self.parallel, model_args=args, model_kwargs=kwargs,