@@ -113,6 +113,8 @@ def compile_forward_sampling_function(
113113 givens_dict : dict [Variable , Any ] | None = None ,
114114 constant_data : dict [str , np .ndarray ] | None = None ,
115115 constant_coords : set [str ] | None = None ,
116+ freeze_vars : set [Variable ] | None = None ,
117+ volatile_outputs : set [Variable ] | None = None ,
116118 ** kwargs ,
117119) -> tuple [Callable [..., np .ndarray | list [np .ndarray ]], set [Variable ]]:
118120 """Compile a function to draw samples, conditioned on the values of some variables.
@@ -131,6 +133,10 @@ def compile_forward_sampling_function(
131133 - Variables that are keys in the ``givens_dict``
132134 - Variables that have volatile inputs
133135
136+ Variables in ``freeze_vars`` are never considered volatile, regardless of the above
137+ rules. They act as volatility barriers, stopping the propagation of volatility to
138+ their dependents. Frozen variables are always treated as trace inputs.
139+
134140 Concretely, this function can be used to compile a function to sample from the
135141 posterior predictive distribution of a model that has variables that are conditioned
136142 on ``Data`` instances. The variables that depend on the mutable data that have changed
@@ -183,6 +189,17 @@ def compile_forward_sampling_function(
183189 which case, it is considered volatile. If a ``SharedVariable`` is not found
184190 in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile.
185191 Setting ``constant_coords`` to ``None`` is equivalent to passing an empty set.
192+ freeze_vars : Optional[Set[pytensor.graph.basic.Variable]]
193+ A set of variables that should never be considered volatile, even if they would
194+ otherwise be due to having volatile inputs, being in the outputs list, or depending
195+ on changed data. Frozen variables act as volatility barriers: they stop the
196+ propagation of volatility to their dependents and are always treated as inputs
197+ that pull values from the trace.
198+ volatile_outputs : Optional[Set[pytensor.graph.basic.Variable]]
199+ A subset of ``outputs`` that should be considered volatile. If provided, only these
200+ outputs (rather than all outputs) are marked as volatile. Outputs not in this set
201+ will still be computed but won't trigger volatility propagation. If ``None`` (default),
202+ all outputs are considered volatile (preserving backward compatibility).
186203
187204 Returns
188205 -------
@@ -202,6 +219,11 @@ def compile_forward_sampling_function(
202219 constant_data = {}
203220 if constant_coords is None :
204221 constant_coords = set ()
222+ if freeze_vars is None :
223+ freeze_vars = set ()
224+ # If volatile_outputs is not specified, all outputs are volatile (backward compatible)
225+ # If specified, only those outputs are marked volatile
226+ _volatile_outputs : set [Variable ] | None = volatile_outputs
205227
206228 # We define a helper function to check if shared values match to an array
207229 def shared_value_matches (var ):
@@ -221,8 +243,10 @@ def shared_value_matches(var):
221243 ) # type: ignore[call-overload]
222244 volatile_nodes : set [Any ] = set ()
223245 for node in nodes :
246+ if node in freeze_vars :
247+ continue # Frozen variables are never volatile, and block propagation
224248 if (
225- node in fg .outputs
249+ ( node in fg .outputs if _volatile_outputs is None else node in _volatile_outputs )
226250 or node in givens_dict
227251 or ( # SharedVariables, except RandomState/Generators
228252 isinstance (node , SharedVariable )
@@ -504,6 +528,8 @@ def sample_posterior_predictive(
504528 predictions : bool = False ,
505529 idata_kwargs : dict | None = None ,
506530 compile_kwargs : dict | None = None ,
531+ sample_vars : list [str ] | None = None ,
532+ freeze_vars : list [str ] | None = None ,
507533) -> InferenceData : ...
508534@overload
509535def sample_posterior_predictive (
@@ -519,6 +545,8 @@ def sample_posterior_predictive(
519545 predictions : bool = False ,
520546 idata_kwargs : dict | None = None ,
521547 compile_kwargs : dict | None = None ,
548+ sample_vars : list [str ] | None = None ,
549+ freeze_vars : list [str ] | None = None ,
522550) -> dict [str , np .ndarray ]: ...
523551def sample_posterior_predictive (
524552 trace ,
@@ -533,6 +561,8 @@ def sample_posterior_predictive(
533561 predictions : bool = False ,
534562 idata_kwargs : dict | None = None ,
535563 compile_kwargs : dict | None = None ,
564+ sample_vars : list [str ] | None = None ,
565+ freeze_vars : list [str ] | None = None ,
536566) -> InferenceData | dict [str , np .ndarray ]:
537567 """Generate forward samples for `var_names`, conditioned on the posterior samples of variables found in the `trace`.
538568
@@ -553,9 +583,10 @@ def sample_posterior_predictive(
553583 Model to be used to generate the posterior predictive samples. It will
554584 generally be the model used to generate the `trace`, but it doesn't need to be.
555585 var_names : Iterable[str], optional
556- Names of variables for which to compute the posterior predictive samples.
557- By default, only observed variables are sampled.
558- See the example below for what happens when this argument is customized.
586+ Names of variables to include in the returned dataset. This only controls which
587+ variables appear in the output, not which variables are resampled. Use ``sample_vars``
588+ to control resampling. By default, observed variables and their dependent
589+ deterministics are included.
559590 sample_dims : list of str, optional
560591 Dimensions over which to loop and generate posterior predictive samples.
561592 When ``sample_dims`` is ``None`` (default) both "chain" and "draw" are considered sample
@@ -581,6 +612,16 @@ def sample_posterior_predictive(
581612 :func:`pymc.predictions_to_inference_data` otherwise.
582613 compile_kwargs: dict, optional
583614 Keyword arguments for :func:`pymc.pytensorf.compile`.
615+ sample_vars : list of str, optional
616+ Names of unobserved variables that should be explicitly resampled. Observed variables
617+ are always resampled. Use this to request resampling of specific unobserved variables
618+ (e.g., for out-of-model predictions or forecasting). Variables not in ``sample_vars``
619+ will have their values taken from the trace if available.
620+ freeze_vars : list of str, optional
621+ Names of variables that should always be reused from the trace, even if they would
622+ otherwise be resampled due to depending on changed data or other resampled variables.
623+ Frozen variables act as barriers that stop the propagation of "volatility" in the
624+ computational graph. Must be present in the trace. Cannot overlap with ``sample_vars``.
584625
585626 Returns
586627 -------
@@ -873,17 +914,39 @@ def sample_posterior_predictive(
873914
874915 constant_coords = get_constant_coords (trace_coords , model )
875916
917+ # Resolve output variables (what to return in the dataset)
876918 if var_names is not None :
877- vars_ = [model [x ] for x in var_names ]
919+ output_vars = [model [x ] for x in var_names ]
878920 else :
879921 observed_vars = model .observed_RVs
880922 if observed_data is not None :
881923 observed_vars += [
882924 model [x ] for x in observed_data if x in model and x not in observed_vars
883925 ]
884- vars_ = observed_vars + observed_dependent_deterministics (model , observed_vars )
885-
886- vars_to_sample = list (get_default_varnames (vars_ , include_transformed = False ))
926+ output_vars = observed_vars + observed_dependent_deterministics (model , observed_vars )
927+
928+ # Resolve variables to resample
929+ # Observed variables are always resampled, plus any explicit sample_vars
930+ resample_vars : list [Variable ] = list (model .observed_RVs )
931+ if sample_vars is not None :
932+ resample_vars += [model [x ] for x in sample_vars ]
933+
934+ # Compiled function outputs = resample_vars + dependent deterministics +
935+ # deterministics from var_names (they need recomputation, not "resampling")
936+ basic_rv_set = set (model .basic_RVs )
937+ compiled_outputs = list (resample_vars )
938+ compiled_outputs += observed_dependent_deterministics (model , resample_vars )
939+ # Add deterministics from var_names that need recomputation
940+ for var in output_vars :
941+ if var not in basic_rv_set and var not in compiled_outputs :
942+ compiled_outputs .append (var )
943+
944+ vars_to_sample = list (
945+ get_default_varnames (
946+ list ({v .name : v for v in compiled_outputs }.values ()),
947+ include_transformed = False ,
948+ )
949+ )
887950
888951 if not vars_to_sample :
889952 if return_inferencedata and not extend_inferencedata :
@@ -894,6 +957,27 @@ def sample_posterior_predictive(
894957
895958 vars_in_trace = get_vars_in_point_list (_trace , model )
896959
960+ # Resolve freeze vars
961+ frozen : set [Variable ] = set ()
962+ if freeze_vars is not None :
963+ frozen = {model [x ] for x in freeze_vars }
964+ # Validate: freeze_vars must be in trace
965+ vars_in_trace_names = {v .name for v in vars_in_trace }
966+ missing = {x for x in freeze_vars if x not in vars_in_trace_names }
967+ if missing :
968+ raise ValueError (
969+ f"freeze_vars { sorted (missing )} are not present in the trace. "
970+ f"Cannot freeze variables without stored values."
971+ )
972+ # Validate: freeze_vars and sample_vars must be disjoint
973+ if sample_vars is not None :
974+ overlap = set (freeze_vars ) & set (sample_vars )
975+ if overlap :
976+ raise ValueError (
977+ f"Variables { sorted (overlap )} are in both sample_vars and freeze_vars. "
978+ f"A variable cannot be both resampled and frozen."
979+ )
980+
897981 if random_seed is not None :
898982 (random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
899983
@@ -902,6 +986,10 @@ def sample_posterior_predictive(
902986 compile_kwargs .setdefault ("allow_input_downcast" , True )
903987 compile_kwargs .setdefault ("accept_inplace" , True )
904988
989+ # Only resample_vars should be volatile outputs. Other outputs (deterministics from
990+ # var_names) are just computed from the graph without triggering volatility cascades.
991+ _volatile_outputs = set (resample_vars ) & set (vars_to_sample )
992+
905993 _sampler_fn , volatile_basic_rvs = compile_forward_sampling_function (
906994 outputs = vars_to_sample ,
907995 vars_in_trace = vars_in_trace ,
@@ -910,11 +998,38 @@ def sample_posterior_predictive(
910998 random_seed = random_seed ,
911999 constant_data = constant_data ,
9121000 constant_coords = constant_coords ,
1001+ freeze_vars = frozen ,
1002+ volatile_outputs = _volatile_outputs ,
9131003 ** compile_kwargs ,
9141004 )
9151005 sampler_fn = point_wrapper (_sampler_fn )
1006+
1007+ # Warn about implicitly volatile trace variables
1008+ vars_in_trace_set = set (vars_in_trace )
1009+ sample_var_names = set (sample_vars ) if sample_vars is not None else set ()
1010+ implicit_volatile = {
1011+ rv
1012+ for rv in volatile_basic_rvs
1013+ if rv in vars_in_trace_set and rv .name not in sample_var_names
1014+ }
1015+ if implicit_volatile :
1016+ implicit_names = sorted (rv .name for rv in implicit_volatile )
1017+ warnings .warn (
1018+ f"Variables { implicit_names } are in the trace but will be resampled because "
1019+ f"they depend on data/coords that changed. To silence this warning, add them "
1020+ f"to `sample_vars` explicitly, or add them to `freeze_vars` to reuse their "
1021+ f"trace values." ,
1022+ UserWarning ,
1023+ stacklevel = 2 ,
1024+ )
1025+
9161026 # All model variables have a name, but mypy does not know this
9171027 _log .info (f"Sampling: { sorted (volatile_basic_rvs , key = lambda var : var .name )} " ) # type: ignore[arg-type, return-value]
1028+
1029+ # Determine output-only variables that should be copied from trace, not sampled
1030+ sampled_names = {v .name for v in vars_to_sample }
1031+ copy_from_trace_names = [var .name for var in output_vars if var .name not in sampled_names ]
1032+
9181033 ppc_trace_t = _DefaultTrace (samples )
9191034
9201035 progress = create_simple_progress (
@@ -942,9 +1057,14 @@ def sample_posterior_predictive(
9421057
9431058 values = sampler_fn (** param )
9441059
945- for k , v in zip (vars_ , values ):
1060+ for k , v in zip (vars_to_sample , values ):
9461061 ppc_trace_t .insert (k .name , v , idx )
9471062
1063+ # Copy output-only variables from trace
1064+ for name in copy_from_trace_names :
1065+ if name in param :
1066+ ppc_trace_t .insert (name , param [name ], idx )
1067+
9481068 progress .advance (task )
9491069 progress .update (task , refresh = True , completed = samples )
9501070
@@ -953,6 +1073,12 @@ def sample_posterior_predictive(
9531073
9541074 ppc_trace = ppc_trace_t .trace_dict
9551075
1076+ # Filter to only include requested output variables and sample_vars
1077+ output_var_names = {v .name for v in output_vars }
1078+ if sample_vars is not None :
1079+ output_var_names |= set (sample_vars )
1080+ ppc_trace = {k : v for k , v in ppc_trace .items () if k in output_var_names }
1081+
9561082 for k , ary in ppc_trace .items ():
9571083 if stacked_dims is not None :
9581084 ppc_trace [k ] = ary .reshape (
0 commit comments