Skip to content

Commit c3eb5fd

Browse files
committed
New sample_posterior_predictive API
1 parent bca2b1e commit c3eb5fd

4 files changed

Lines changed: 343 additions & 80 deletions

File tree

pymc/sampling/forward.py

Lines changed: 132 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from pytensor import tensor as pt
3333
from pytensor.graph import vectorize_graph
3434
from pytensor.graph.basic import (
35-
Apply,
3635
Constant,
3736
Variable,
3837
)
@@ -110,9 +109,10 @@ def compile_forward_sampling_function(
110109
outputs: list[Variable],
111110
vars_in_trace: list[Variable],
112111
basic_rvs: list[Variable] | None = None,
113-
givens_dict: dict[Variable, Any] | None = None,
114112
constant_data: dict[str, np.ndarray] | None = None,
115113
constant_coords: set[str] | None = None,
114+
volatile_vars: set[Variable] | None = None,
115+
freeze_vars: set[Variable] | None = None,
116116
**kwargs,
117117
) -> tuple[Callable[..., np.ndarray | list[np.ndarray]], set[Variable]]:
118118
"""Compile a function to draw samples, conditioned on the values of some variables.
@@ -125,12 +125,15 @@ def compile_forward_sampling_function(
125125
Volatile variables are variables whose values could change between runs of the
126126
compiled function or after inference has been run. These variables are:
127127
128-
- Variables in the outputs list
128+
- Variables in ``volatile_vars``
129129
- ``SharedVariable`` instances that are not ``RandomGeneratorSharedVariable``, and whose values changed with respect to what they were at inference time
130130
- Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list
131-
- Variables that are keys in the ``givens_dict``
132131
- Variables that have volatile inputs
133132
133+
Variables in ``freeze_vars`` are never considered volatile, regardless of the above
134+
rules. They act as volatility barriers, stopping the propagation of volatility to
135+
their dependents. Frozen variables are always treated as trace inputs.
136+
134137
Concretely, this function can be used to compile a function to sample from the
135138
posterior predictive distribution of a model that has variables that are conditioned
136139
on ``Data`` instances. The variables that depend on the mutable data that have changed
@@ -139,19 +142,11 @@ def compile_forward_sampling_function(
139142
ignored and new values will be computed (in the case of deterministics and potentials) or
140143
sampled (in the case of random variables).
141144
142-
This function also enables a way to impute values for any variable in the computational
143-
graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
144-
to set the ``givens`` argument of the pytensor function compilation. This will essentially
145-
replace a node in the computational graph with any other expression that has the same
146-
type as the desired node. Passing variables in the givens_dict is considered an intervention
147-
that might lead to different variable values from those that could have been seen during
148-
inference, as such, **any variable that is passed in the ``givens_dict`` will be considered
149-
volatile**.
150-
151145
Parameters
152146
----------
153147
outputs : List[pytensor.graph.basic.Variable]
154-
The list of variables that will be returned by the compiled function
148+
The list of variables that will be returned by the compiled function.
149+
Outputs are not inherently volatile.
155150
vars_in_trace : List[pytensor.graph.basic.Variable]
156151
The list of variables that are assumed to have values stored in the trace
157152
basic_rvs : Optional[List[pytensor.graph.basic.Variable]]
@@ -160,10 +155,6 @@ def compile_forward_sampling_function(
160155
be considered as random variable instances. This includes variables that have
161156
a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or
162157
Censored distributions.
163-
givens_dict : Optional[Dict[pytensor.graph.basic.Variable, Any]]
164-
A dictionary that maps tensor variables to the values that should be used to replace them
165-
in the compiled function. The types of the key and value should match or an error will be
166-
raised during compilation.
167158
constant_data : Optional[Dict[str, numpy.ndarray]]
168159
A dictionary that maps the names of ``Data`` instances to their
169160
corresponding values at inference time. If a model was created with ``Data``, these
@@ -183,6 +174,14 @@ def compile_forward_sampling_function(
183174
which case, it is considered volatile. If a ``SharedVariable`` is not found
184175
in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile.
185176
Setting ``constant_coords`` to ``None`` is equivalent to passing an empty set.
177+
volatile_vars : Optional[Set[pytensor.graph.basic.Variable]]
178+
Variables that are unconditionally volatile. Volatility propagates from these
179+
to their dependents in the graph.
180+
freeze_vars : Optional[Set[pytensor.graph.basic.Variable]]
181+
A set of variables that should never be considered volatile, even if they would
182+
otherwise be due to having volatile inputs or depending on changed data. Frozen
183+
variables act as volatility barriers: they stop the propagation of volatility to
184+
their dependents and are always treated as inputs that pull values from the trace.
186185
187186
Returns
188187
-------
@@ -192,16 +191,17 @@ def compile_forward_sampling_function(
192191
Set of all basic_rvs that were considered volatile and will be resampled when
193192
the function is evaluated
194193
"""
195-
if givens_dict is None:
196-
givens_dict = {}
197-
198194
if basic_rvs is None:
199195
basic_rvs = []
200196

201197
if constant_data is None:
202198
constant_data = {}
203199
if constant_coords is None:
204200
constant_coords = set()
201+
if volatile_vars is None:
202+
volatile_vars = set()
203+
if freeze_vars is None:
204+
freeze_vars = set()
205205

206206
# We define a helper function to check if shared values match to an array
207207
def shared_value_matches(var):
@@ -221,9 +221,10 @@ def shared_value_matches(var):
221221
) # type: ignore[call-overload]
222222
volatile_nodes: set[Any] = set()
223223
for node in nodes:
224+
if node in freeze_vars:
225+
continue # Frozen variables are never volatile, and block propagation
224226
if (
225-
node in fg.outputs
226-
or node in givens_dict
227+
node in volatile_vars
227228
or ( # SharedVariables, except RandomState/Generators
228229
isinstance(node, SharedVariable)
229230
and not isinstance(node, RandomGeneratorSharedVariable)
@@ -260,20 +261,9 @@ def expand(node):
260261
# the entire graph
261262
list(walk(fg.outputs, expand))
262263

263-
# Populate the givens list
264-
givens = [
265-
(
266-
node,
267-
value
268-
if isinstance(value, Variable | Apply)
269-
else pt.constant(value, dtype=getattr(node, "dtype", None), name=node.name),
270-
)
271-
for node, value in givens_dict.items()
272-
]
273-
274264
return (
275-
compile(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
276-
set(basic_rvs) & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled
265+
compile(inputs, fg.outputs, on_unused_input="ignore", **kwargs),
266+
set(basic_rvs) & volatile_nodes, # Basic RVs that will be resampled
277267
)
278268

279269

@@ -450,7 +440,6 @@ def sample_prior_predictive(
450440
vars_to_sample,
451441
vars_in_trace=[],
452442
basic_rvs=model.basic_RVs,
453-
givens_dict=None,
454443
random_seed=random_seed,
455444
**compile_kwargs,
456445
)
@@ -490,6 +479,8 @@ def sample_posterior_predictive(
490479
predictions: bool = False,
491480
idata_kwargs: dict | None = None,
492481
compile_kwargs: dict | None = None,
482+
sample_vars: list[str] | None = None,
483+
freeze_vars: list[str] | None = None,
493484
) -> InferenceData: ...
494485
@overload
495486
def sample_posterior_predictive(
@@ -505,6 +496,8 @@ def sample_posterior_predictive(
505496
predictions: bool = False,
506497
idata_kwargs: dict | None = None,
507498
compile_kwargs: dict | None = None,
499+
sample_vars: list[str] | None = None,
500+
freeze_vars: list[str] | None = None,
508501
) -> dict[str, np.ndarray]: ...
509502
def sample_posterior_predictive(
510503
trace,
@@ -519,6 +512,8 @@ def sample_posterior_predictive(
519512
predictions: bool = False,
520513
idata_kwargs: dict | None = None,
521514
compile_kwargs: dict | None = None,
515+
sample_vars: list[str] | None = None,
516+
freeze_vars: list[str] | None = None,
522517
) -> InferenceData | dict[str, np.ndarray]:
523518
"""Generate forward samples for `var_names`, conditioned on the posterior samples of variables found in the `trace`.
524519
@@ -539,9 +534,10 @@ def sample_posterior_predictive(
539534
Model to be used to generate the posterior predictive samples. It will
540535
generally be the model used to generate the `trace`, but it doesn't need to be.
541536
var_names : Iterable[str], optional
542-
Names of variables for which to compute the posterior predictive samples.
543-
By default, only observed variables are sampled.
544-
See the example below for what happens when this argument is customized.
537+
Names of variables to include in the returned dataset. This only controls which
538+
variables appear in the output, not which variables are resampled. Use ``sample_vars``
539+
to control resampling. By default, observed variables and their dependent
540+
deterministics are included.
545541
sample_dims : list of str, optional
546542
Dimensions over which to loop and generate posterior predictive samples.
547543
When ``sample_dims`` is ``None`` (default) both "chain" and "draw" are considered sample
@@ -567,6 +563,16 @@ def sample_posterior_predictive(
567563
:func:`pymc.predictions_to_inference_data` otherwise.
568564
compile_kwargs: dict, optional
569565
Keyword arguments for :func:`pymc.pytensorf.compile`.
566+
sample_vars : list of str, optional
567+
Names of unobserved variables that should be explicitly resampled. Observed variables
568+
are always resampled. Use this to request resampling of specific unobserved variables
569+
(e.g., for out-of-model predictions or forecasting). Variables not in ``sample_vars``
570+
will have their values taken from the trace if available.
571+
freeze_vars : list of str, optional
572+
Names of variables that should always be reused from the trace, even if they would
573+
otherwise be resampled due to depending on changed data or other resampled variables.
574+
Frozen variables act as barriers that stop the propagation of "volatility" in the
575+
computational graph. Must be present in the trace. Cannot overlap with ``sample_vars``.
570576
571577
Returns
572578
-------
@@ -859,17 +865,39 @@ def sample_posterior_predictive(
859865

860866
constant_coords = get_constant_coords(trace_coords, model)
861867

868+
# Resolve output variables (what to return in the dataset)
862869
if var_names is not None:
863-
vars_ = [model[x] for x in var_names]
870+
output_vars = [model[x] for x in var_names]
864871
else:
865872
observed_vars = model.observed_RVs
866873
if observed_data is not None:
867874
observed_vars += [
868875
model[x] for x in observed_data if x in model and x not in observed_vars
869876
]
870-
vars_ = observed_vars + observed_dependent_deterministics(model, observed_vars)
871-
872-
vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))
877+
output_vars = observed_vars + observed_dependent_deterministics(model, observed_vars)
878+
879+
# Resolve variables to resample
880+
# Observed variables are always resampled, plus any explicit sample_vars
881+
resample_vars: list[Variable] = list(model.observed_RVs)
882+
if sample_vars is not None:
883+
resample_vars += [model[x] for x in sample_vars]
884+
885+
# Compiled function outputs = resample_vars + dependent deterministics +
886+
# deterministics from var_names (they need recomputation, not "resampling")
887+
basic_rv_set = set(model.basic_RVs)
888+
compiled_outputs = list(resample_vars)
889+
compiled_outputs += observed_dependent_deterministics(model, resample_vars)
890+
# Add deterministics from var_names that need recomputation
891+
for var in output_vars:
892+
if var not in basic_rv_set and var not in compiled_outputs:
893+
compiled_outputs.append(var)
894+
895+
vars_to_sample = list(
896+
get_default_varnames(
897+
list({v.name: v for v in compiled_outputs}.values()),
898+
include_transformed=False,
899+
)
900+
)
873901

874902
if not vars_to_sample:
875903
if return_inferencedata and not extend_inferencedata:
@@ -880,6 +908,27 @@ def sample_posterior_predictive(
880908

881909
vars_in_trace = get_vars_in_point_list(_trace, model)
882910

911+
# Resolve freeze vars
912+
frozen: set[Variable] = set()
913+
if freeze_vars is not None:
914+
frozen = {model[x] for x in freeze_vars}
915+
# Validate: freeze_vars must be in trace
916+
vars_in_trace_names = {v.name for v in vars_in_trace}
917+
missing = {x for x in freeze_vars if x not in vars_in_trace_names}
918+
if missing:
919+
raise ValueError(
920+
f"freeze_vars {sorted(missing)} are not present in the trace. "
921+
f"Cannot freeze variables without stored values."
922+
)
923+
# Validate: freeze_vars and sample_vars must be disjoint
924+
if sample_vars is not None:
925+
overlap = set(freeze_vars) & set(sample_vars)
926+
if overlap:
927+
raise ValueError(
928+
f"Variables {sorted(overlap)} are in both sample_vars and freeze_vars. "
929+
f"A variable cannot be both resampled and frozen."
930+
)
931+
883932
if random_seed is not None:
884933
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
885934

@@ -892,15 +941,41 @@ def sample_posterior_predictive(
892941
outputs=vars_to_sample,
893942
vars_in_trace=vars_in_trace,
894943
basic_rvs=model.basic_RVs,
895-
givens_dict=None,
896944
random_seed=random_seed,
897945
constant_data=constant_data,
898946
constant_coords=constant_coords,
947+
volatile_vars=set(resample_vars),
948+
freeze_vars=frozen,
899949
**compile_kwargs,
900950
)
901951
sampler_fn = point_wrapper(_sampler_fn)
952+
953+
# Warn about implicitly volatile trace variables
954+
vars_in_trace_set = set(vars_in_trace)
955+
sample_var_names = set(sample_vars) if sample_vars is not None else set()
956+
implicit_volatile = {
957+
rv
958+
for rv in volatile_basic_rvs
959+
if rv in vars_in_trace_set and rv.name not in sample_var_names
960+
}
961+
if implicit_volatile:
962+
implicit_names = sorted(rv.name for rv in implicit_volatile) # type: ignore[type-var]
963+
warnings.warn(
964+
f"Variables {implicit_names} are in the trace but will be resampled because "
965+
f"they depend on data/coords that changed. To silence this warning, add them "
966+
f"to `sample_vars` explicitly, or add them to `freeze_vars` to reuse their "
967+
f"trace values.",
968+
UserWarning,
969+
stacklevel=2,
970+
)
971+
902972
# All model variables have a name, but mypy does not know this
903973
_log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore[arg-type, return-value]
974+
975+
# Determine output-only variables that should be copied from trace, not sampled
976+
sampled_names = {v.name for v in vars_to_sample}
977+
copy_from_trace_names = [var.name for var in output_vars if var.name not in sampled_names]
978+
904979
ppc_trace_t = _DefaultTrace(samples)
905980

906981
progress = create_simple_progress(
@@ -928,9 +1003,14 @@ def sample_posterior_predictive(
9281003

9291004
values = sampler_fn(**param)
9301005

931-
for k, v in zip(vars_, values):
1006+
for k, v in zip(vars_to_sample, values):
9321007
ppc_trace_t.insert(k.name, v, idx)
9331008

1009+
# Copy output-only variables from trace
1010+
for name in copy_from_trace_names:
1011+
if name in param:
1012+
ppc_trace_t.insert(name, param[name], idx)
1013+
9341014
progress.advance(task)
9351015
progress.update(task, refresh=True, completed=samples)
9361016

@@ -939,6 +1019,12 @@ def sample_posterior_predictive(
9391019

9401020
ppc_trace = ppc_trace_t.trace_dict
9411021

1022+
# Filter to only include requested output variables and sample_vars
1023+
output_var_names = {v.name for v in output_vars}
1024+
if sample_vars is not None:
1025+
output_var_names |= set(sample_vars)
1026+
ppc_trace = {k: v for k, v in ppc_trace.items() if k in output_var_names}
1027+
9421028
for k, ary in ppc_trace.items():
9431029
if stacked_dims is not None:
9441030
ppc_trace[k] = ary.reshape(

tests/model/transform/test_conditioning.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def test_do_posterior_predictive():
187187
# Replace `y` by a constant `100.0`
188188
m_do = do(m, {y: 100.0})
189189
with m_do:
190-
idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")
190+
idata_do = pm.sample_posterior_predictive(idata_m, var_names=["z"], sample_vars=["z"])
191191

192192
assert 120 < idata_do.posterior_predictive["z"].mean() < 130
193193

@@ -315,7 +315,9 @@ def test_do_sample_posterior_predictive(make_interventions_shared):
315315
idata = az.from_dict({"a": [[1.0]], "b": [[2.0]], "c": [[1.0]]})
316316

317317
with do(model, {a: 1000}, make_interventions_shared=make_interventions_shared):
318-
pp = sample_posterior_predictive(idata, var_names=["c"], predictions=True).predictions
318+
pp = sample_posterior_predictive(
319+
idata, var_names=["c"], sample_vars=["c"], predictions=True
320+
).predictions
319321
assert (pp["c"] > 500).all()
320322

321323

0 commit comments

Comments
 (0)