Skip to content

Commit 9999e3d

Browse files
committed
New sample_posterior_predictive API
1 parent 22fb388 commit 9999e3d

4 files changed

Lines changed: 330 additions & 21 deletions

File tree

pymc/sampling/forward.py

Lines changed: 135 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,8 @@ def compile_forward_sampling_function(
113113
givens_dict: dict[Variable, Any] | None = None,
114114
constant_data: dict[str, np.ndarray] | None = None,
115115
constant_coords: set[str] | None = None,
116+
freeze_vars: set[Variable] | None = None,
117+
volatile_outputs: set[Variable] | None = None,
116118
**kwargs,
117119
) -> tuple[Callable[..., np.ndarray | list[np.ndarray]], set[Variable]]:
118120
"""Compile a function to draw samples, conditioned on the values of some variables.
@@ -131,6 +133,10 @@ def compile_forward_sampling_function(
131133
- Variables that are keys in the ``givens_dict``
132134
- Variables that have volatile inputs
133135
136+
Variables in ``freeze_vars`` are never considered volatile, regardless of the above
137+
rules. They act as volatility barriers, stopping the propagation of volatility to
138+
their dependents. Frozen variables are always treated as trace inputs.
139+
134140
Concretely, this function can be used to compile a function to sample from the
135141
posterior predictive distribution of a model that has variables that are conditioned
136142
on ``Data`` instances. The variables that depend on the mutable data that have changed
@@ -183,6 +189,17 @@ def compile_forward_sampling_function(
183189
which case, it is considered volatile. If a ``SharedVariable`` is not found
184190
in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile.
185191
Setting ``constant_coords`` to ``None`` is equivalent to passing an empty set.
192+
freeze_vars : Optional[Set[pytensor.graph.basic.Variable]]
193+
A set of variables that should never be considered volatile, even if they would
194+
otherwise be due to having volatile inputs, being in the outputs list, or depending
195+
on changed data. Frozen variables act as volatility barriers: they stop the
196+
propagation of volatility to their dependents and are always treated as inputs
197+
that pull values from the trace.
198+
volatile_outputs : Optional[Set[pytensor.graph.basic.Variable]]
199+
A subset of ``outputs`` that should be considered volatile. If provided, only these
200+
outputs (rather than all outputs) are marked as volatile. Outputs not in this set
201+
will still be computed but won't trigger volatility propagation. If ``None`` (default),
202+
all outputs are considered volatile (preserving backward compatibility).
186203
187204
Returns
188205
-------
@@ -202,6 +219,11 @@ def compile_forward_sampling_function(
202219
constant_data = {}
203220
if constant_coords is None:
204221
constant_coords = set()
222+
if freeze_vars is None:
223+
freeze_vars = set()
224+
# If volatile_outputs is not specified, all outputs are volatile (backward compatible)
225+
# If specified, only those outputs are marked volatile
226+
_volatile_outputs: set[Variable] | None = volatile_outputs
205227

206228
# We define a helper function to check if shared values match to an array
207229
def shared_value_matches(var):
@@ -221,8 +243,10 @@ def shared_value_matches(var):
221243
) # type: ignore[call-overload]
222244
volatile_nodes: set[Any] = set()
223245
for node in nodes:
246+
if node in freeze_vars:
247+
continue # Frozen variables are never volatile, and block propagation
224248
if (
225-
node in fg.outputs
249+
(node in fg.outputs if _volatile_outputs is None else node in _volatile_outputs)
226250
or node in givens_dict
227251
or ( # SharedVariables, except RandomState/Generators
228252
isinstance(node, SharedVariable)
@@ -504,6 +528,8 @@ def sample_posterior_predictive(
504528
predictions: bool = False,
505529
idata_kwargs: dict | None = None,
506530
compile_kwargs: dict | None = None,
531+
sample_vars: list[str] | None = None,
532+
freeze_vars: list[str] | None = None,
507533
) -> InferenceData: ...
508534
@overload
509535
def sample_posterior_predictive(
@@ -519,6 +545,8 @@ def sample_posterior_predictive(
519545
predictions: bool = False,
520546
idata_kwargs: dict | None = None,
521547
compile_kwargs: dict | None = None,
548+
sample_vars: list[str] | None = None,
549+
freeze_vars: list[str] | None = None,
522550
) -> dict[str, np.ndarray]: ...
523551
def sample_posterior_predictive(
524552
trace,
@@ -533,6 +561,8 @@ def sample_posterior_predictive(
533561
predictions: bool = False,
534562
idata_kwargs: dict | None = None,
535563
compile_kwargs: dict | None = None,
564+
sample_vars: list[str] | None = None,
565+
freeze_vars: list[str] | None = None,
536566
) -> InferenceData | dict[str, np.ndarray]:
537567
"""Generate forward samples for `var_names`, conditioned on the posterior samples of variables found in the `trace`.
538568
@@ -553,9 +583,10 @@ def sample_posterior_predictive(
553583
Model to be used to generate the posterior predictive samples. It will
554584
generally be the model used to generate the `trace`, but it doesn't need to be.
555585
var_names : Iterable[str], optional
556-
Names of variables for which to compute the posterior predictive samples.
557-
By default, only observed variables are sampled.
558-
See the example below for what happens when this argument is customized.
586+
Names of variables to include in the returned dataset. This only controls which
587+
variables appear in the output, not which variables are resampled. Use ``sample_vars``
588+
to control resampling. By default, observed variables and their dependent
589+
deterministics are included.
559590
sample_dims : list of str, optional
560591
Dimensions over which to loop and generate posterior predictive samples.
561592
When ``sample_dims`` is ``None`` (default) both "chain" and "draw" are considered sample
@@ -581,6 +612,16 @@ def sample_posterior_predictive(
581612
:func:`pymc.predictions_to_inference_data` otherwise.
582613
compile_kwargs: dict, optional
583614
Keyword arguments for :func:`pymc.pytensorf.compile`.
615+
sample_vars : list of str, optional
616+
Names of unobserved variables that should be explicitly resampled. Observed variables
617+
are always resampled. Use this to request resampling of specific unobserved variables
618+
(e.g., for out-of-model predictions or forecasting). Variables not in ``sample_vars``
619+
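will have their values taken from the trace if available, unless they are implicitly resampled because they depend on changed data or on other resampled variables (a warning is emitted in that case).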
620+
freeze_vars : list of str, optional
621+
Names of variables that should always be reused from the trace, even if they would
622+
otherwise be resampled due to depending on changed data or other resampled variables.
623+
Frozen variables act as barriers that stop the propagation of "volatility" in the
624+
computational graph. Must be present in the trace. Cannot overlap with ``sample_vars``.
584625
585626
Returns
586627
-------
@@ -873,17 +914,39 @@ def sample_posterior_predictive(
873914

874915
constant_coords = get_constant_coords(trace_coords, model)
875916

917+
# Resolve output variables (what to return in the dataset)
876918
if var_names is not None:
877-
vars_ = [model[x] for x in var_names]
919+
output_vars = [model[x] for x in var_names]
878920
else:
879921
observed_vars = model.observed_RVs
880922
if observed_data is not None:
881923
observed_vars += [
882924
model[x] for x in observed_data if x in model and x not in observed_vars
883925
]
884-
vars_ = observed_vars + observed_dependent_deterministics(model, observed_vars)
885-
886-
vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))
926+
output_vars = observed_vars + observed_dependent_deterministics(model, observed_vars)
927+
928+
# Resolve variables to resample
929+
# Observed variables are always resampled, plus any explicit sample_vars
930+
resample_vars: list[Variable] = list(model.observed_RVs)
931+
if sample_vars is not None:
932+
resample_vars += [model[x] for x in sample_vars]
933+
934+
# Compiled function outputs = resample_vars + dependent deterministics +
935+
# deterministics from var_names (they need recomputation, not "resampling")
936+
basic_rv_set = set(model.basic_RVs)
937+
compiled_outputs = list(resample_vars)
938+
compiled_outputs += observed_dependent_deterministics(model, resample_vars)
939+
# Add deterministics from var_names that need recomputation
940+
for var in output_vars:
941+
if var not in basic_rv_set and var not in compiled_outputs:
942+
compiled_outputs.append(var)
943+
944+
vars_to_sample = list(
945+
get_default_varnames(
946+
list({v.name: v for v in compiled_outputs}.values()),
947+
include_transformed=False,
948+
)
949+
)
887950

888951
if not vars_to_sample:
889952
if return_inferencedata and not extend_inferencedata:
@@ -894,6 +957,27 @@ def sample_posterior_predictive(
894957

895958
vars_in_trace = get_vars_in_point_list(_trace, model)
896959

960+
# Resolve freeze vars
961+
frozen: set[Variable] = set()
962+
if freeze_vars is not None:
963+
frozen = {model[x] for x in freeze_vars}
964+
# Validate: freeze_vars must be in trace
965+
vars_in_trace_names = {v.name for v in vars_in_trace}
966+
missing = {x for x in freeze_vars if x not in vars_in_trace_names}
967+
if missing:
968+
raise ValueError(
969+
f"freeze_vars {sorted(missing)} are not present in the trace. "
970+
f"Cannot freeze variables without stored values."
971+
)
972+
# Validate: freeze_vars and sample_vars must be disjoint
973+
if sample_vars is not None:
974+
overlap = set(freeze_vars) & set(sample_vars)
975+
if overlap:
976+
raise ValueError(
977+
f"Variables {sorted(overlap)} are in both sample_vars and freeze_vars. "
978+
f"A variable cannot be both resampled and frozen."
979+
)
980+
897981
if random_seed is not None:
898982
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
899983

@@ -902,6 +986,10 @@ def sample_posterior_predictive(
902986
compile_kwargs.setdefault("allow_input_downcast", True)
903987
compile_kwargs.setdefault("accept_inplace", True)
904988

989+
# Only resample_vars should be volatile outputs. Other outputs (deterministics from
990+
# var_names) are just computed from the graph without triggering volatility cascades.
991+
_volatile_outputs = set(resample_vars) & set(vars_to_sample)
992+
905993
_sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
906994
outputs=vars_to_sample,
907995
vars_in_trace=vars_in_trace,
@@ -910,11 +998,38 @@ def sample_posterior_predictive(
910998
random_seed=random_seed,
911999
constant_data=constant_data,
9121000
constant_coords=constant_coords,
1001+
freeze_vars=frozen,
1002+
volatile_outputs=_volatile_outputs,
9131003
**compile_kwargs,
9141004
)
9151005
sampler_fn = point_wrapper(_sampler_fn)
1006+
1007+
# Warn about implicitly volatile trace variables
1008+
vars_in_trace_set = set(vars_in_trace)
1009+
sample_var_names = set(sample_vars) if sample_vars is not None else set()
1010+
implicit_volatile = {
1011+
rv
1012+
for rv in volatile_basic_rvs
1013+
if rv in vars_in_trace_set and rv.name not in sample_var_names
1014+
}
1015+
if implicit_volatile:
1016+
implicit_names = sorted(rv.name for rv in implicit_volatile)
1017+
warnings.warn(
1018+
f"Variables {implicit_names} are in the trace but will be resampled because "
1019+
f"they depend on data/coords that changed. To silence this warning, add them "
1020+
f"to `sample_vars` explicitly, or add them to `freeze_vars` to reuse their "
1021+
f"trace values.",
1022+
UserWarning,
1023+
stacklevel=2,
1024+
)
1025+
9161026
# All model variables have a name, but mypy does not know this
9171027
_log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore[arg-type, return-value]
1028+
1029+
# Determine output-only variables that should be copied from trace, not sampled
1030+
sampled_names = {v.name for v in vars_to_sample}
1031+
copy_from_trace_names = [var.name for var in output_vars if var.name not in sampled_names]
1032+
9181033
ppc_trace_t = _DefaultTrace(samples)
9191034

9201035
progress = create_simple_progress(
@@ -942,9 +1057,14 @@ def sample_posterior_predictive(
9421057

9431058
values = sampler_fn(**param)
9441059

945-
for k, v in zip(vars_, values):
1060+
for k, v in zip(vars_to_sample, values):
9461061
ppc_trace_t.insert(k.name, v, idx)
9471062

1063+
# Copy output-only variables from trace
1064+
for name in copy_from_trace_names:
1065+
if name in param:
1066+
ppc_trace_t.insert(name, param[name], idx)
1067+
9481068
progress.advance(task)
9491069
progress.update(task, refresh=True, completed=samples)
9501070

@@ -953,6 +1073,12 @@ def sample_posterior_predictive(
9531073

9541074
ppc_trace = ppc_trace_t.trace_dict
9551075

1076+
# Filter to only include requested output variables and sample_vars
1077+
output_var_names = {v.name for v in output_vars}
1078+
if sample_vars is not None:
1079+
output_var_names |= set(sample_vars)
1080+
ppc_trace = {k: v for k, v in ppc_trace.items() if k in output_var_names}
1081+
9561082
for k, ary in ppc_trace.items():
9571083
if stacked_dims is not None:
9581084
ppc_trace[k] = ary.reshape(

tests/model/transform/test_conditioning.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -187,7 +187,7 @@ def test_do_posterior_predictive():
187187
# Replace `y` by a constant `100.0`
188188
m_do = do(m, {y: 100.0})
189189
with m_do:
190-
idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")
190+
idata_do = pm.sample_posterior_predictive(idata_m, var_names=["z"], sample_vars=["z"])
191191

192192
assert 120 < idata_do.posterior_predictive["z"].mean() < 130
193193

@@ -315,7 +315,9 @@ def test_do_sample_posterior_predictive(make_interventions_shared):
315315
idata = az.from_dict({"a": [[1.0]], "b": [[2.0]], "c": [[1.0]]})
316316

317317
with do(model, {a: 1000}, make_interventions_shared=make_interventions_shared):
318-
pp = sample_posterior_predictive(idata, var_names=["c"], predictions=True).predictions
318+
pp = sample_posterior_predictive(
319+
idata, var_names=["c"], sample_vars=["c"], predictions=True
320+
).predictions
319321
assert (pp["c"] > 500).all()
320322

321323

0 commit comments

Comments
 (0)