3232from pytensor import tensor as pt
3333from pytensor .graph import vectorize_graph
3434from pytensor .graph .basic import (
35- Apply ,
3635 Constant ,
3736 Variable ,
3837)
@@ -110,9 +109,10 @@ def compile_forward_sampling_function(
110109 outputs : list [Variable ],
111110 vars_in_trace : list [Variable ],
112111 basic_rvs : list [Variable ] | None = None ,
113- givens_dict : dict [Variable , Any ] | None = None ,
114112 constant_data : dict [str , np .ndarray ] | None = None ,
115113 constant_coords : set [str ] | None = None ,
114+ volatile_vars : set [Variable ] | None = None ,
115+ freeze_vars : set [Variable ] | None = None ,
116116 ** kwargs ,
117117) -> tuple [Callable [..., np .ndarray | list [np .ndarray ]], set [Variable ]]:
118118 """Compile a function to draw samples, conditioned on the values of some variables.
@@ -125,12 +125,15 @@ def compile_forward_sampling_function(
125125 Volatile variables are variables whose values could change between runs of the
126126 compiled function or after inference has been run. These variables are:
127127
128- - Variables in the outputs list
128+ - Variables in ``volatile_vars``
129129 - ``SharedVariable`` instances that are not ``RandomGeneratorSharedVariable``, and whose values changed with respect to what they were at inference time
130130 - Variables that are in the ``basic_rvs`` list but not in the ``vars_in_trace`` list
131- - Variables that are keys in the ``givens_dict``
132131 - Variables that have volatile inputs
133132
133+ Variables in ``freeze_vars`` are never considered volatile, regardless of the above
134+ rules. They act as volatility barriers, stopping the propagation of volatility to
135+ their dependents. Frozen variables are always treated as trace inputs.
136+
134137 Concretely, this function can be used to compile a function to sample from the
135138 posterior predictive distribution of a model that has variables that are conditioned
136139 on ``Data`` instances. The variables that depend on the mutable data that have changed
@@ -139,19 +142,11 @@ def compile_forward_sampling_function(
139142 ignored and new values will be computed (in the case of deterministics and potentials) or
140143 sampled (in the case of random variables).
141144
142- This function also enables a way to impute values for any variable in the computational
143- graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
144- to set the ``givens`` argument of the pytensor function compilation. This will essentially
145- replace a node in the computational graph with any other expression that has the same
146- type as the desired node. Passing variables in the givens_dict is considered an intervention
147- that might lead to different variable values from those that could have been seen during
148- inference, as such, **any variable that is passed in the ``givens_dict`` will be considered
149- volatile**.
150-
151145 Parameters
152146 ----------
153147 outputs : List[pytensor.graph.basic.Variable]
154- The list of variables that will be returned by the compiled function
148+ The list of variables that will be returned by the compiled function.
149+ Outputs are not inherently volatile.
155150 vars_in_trace : List[pytensor.graph.basic.Variable]
156151 The list of variables that are assumed to have values stored in the trace
157152 basic_rvs : Optional[List[pytensor.graph.basic.Variable]]
@@ -160,10 +155,6 @@ def compile_forward_sampling_function(
160155 be considered as random variable instances. This includes variables that have
161156 a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or
162157 Censored distributions.
163- givens_dict : Optional[Dict[pytensor.graph.basic.Variable, Any]]
164- A dictionary that maps tensor variables to the values that should be used to replace them
165- in the compiled function. The types of the key and value should match or an error will be
166- raised during compilation.
167158 constant_data : Optional[Dict[str, numpy.ndarray]]
168159 A dictionary that maps the names of ``Data`` instances to their
169160 corresponding values at inference time. If a model was created with ``Data``, these
@@ -183,6 +174,14 @@ def compile_forward_sampling_function(
183174 which case, it is considered volatile. If a ``SharedVariable`` is not found
184175 in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile.
185176 Setting ``constant_coords`` to ``None`` is equivalent to passing an empty set.
177+ volatile_vars : Optional[Set[pytensor.graph.basic.Variable]]
178+ Variables that are unconditionally volatile. Volatility propagates from these
179+ to their dependents in the graph.
180+ freeze_vars : Optional[Set[pytensor.graph.basic.Variable]]
181+ A set of variables that should never be considered volatile, even if they would
182+ otherwise be due to having volatile inputs or depending on changed data. Frozen
183+ variables act as volatility barriers: they stop the propagation of volatility to
184+ their dependents and are always treated as inputs that pull values from the trace.
186185
187186 Returns
188187 -------
@@ -192,16 +191,17 @@ def compile_forward_sampling_function(
192191 Set of all basic_rvs that were considered volatile and will be resampled when
193192 the function is evaluated
194193 """
195- if givens_dict is None :
196- givens_dict = {}
197-
198194 if basic_rvs is None :
199195 basic_rvs = []
200196
201197 if constant_data is None :
202198 constant_data = {}
203199 if constant_coords is None :
204200 constant_coords = set ()
201+ if volatile_vars is None :
202+ volatile_vars = set ()
203+ if freeze_vars is None :
204+ freeze_vars = set ()
205205
206206 # We define a helper function to check if shared values match to an array
207207 def shared_value_matches (var ):
@@ -221,9 +221,10 @@ def shared_value_matches(var):
221221 ) # type: ignore[call-overload]
222222 volatile_nodes : set [Any ] = set ()
223223 for node in nodes :
224+ if node in freeze_vars :
225+ continue # Frozen variables are never volatile, and block propagation
224226 if (
225- node in fg .outputs
226- or node in givens_dict
227+ node in volatile_vars
227228 or ( # SharedVariables, except RandomState/Generators
228229 isinstance (node , SharedVariable )
229230 and not isinstance (node , RandomGeneratorSharedVariable )
@@ -260,20 +261,9 @@ def expand(node):
260261 # the entire graph
261262 list (walk (fg .outputs , expand ))
262263
263- # Populate the givens list
264- givens = [
265- (
266- node ,
267- value
268- if isinstance (value , Variable | Apply )
269- else pt .constant (value , dtype = getattr (node , "dtype" , None ), name = node .name ),
270- )
271- for node , value in givens_dict .items ()
272- ]
273-
274264 return (
275- compile (inputs , fg .outputs , givens = givens , on_unused_input = "ignore" , ** kwargs ),
276- set (basic_rvs ) & ( volatile_nodes - set ( givens_dict )) , # Basic RVs that will be resampled
265+ compile (inputs , fg .outputs , on_unused_input = "ignore" , ** kwargs ),
266+ set (basic_rvs ) & volatile_nodes , # Basic RVs that will be resampled
277267 )
278268
279269
@@ -450,7 +440,6 @@ def sample_prior_predictive(
450440 vars_to_sample ,
451441 vars_in_trace = [],
452442 basic_rvs = model .basic_RVs ,
453- givens_dict = None ,
454443 random_seed = random_seed ,
455444 ** compile_kwargs ,
456445 )
@@ -490,6 +479,8 @@ def sample_posterior_predictive(
490479 predictions : bool = False ,
491480 idata_kwargs : dict | None = None ,
492481 compile_kwargs : dict | None = None ,
482+ sample_vars : list [str ] | None = None ,
483+ freeze_vars : list [str ] | None = None ,
493484) -> InferenceData : ...
494485@overload
495486def sample_posterior_predictive (
@@ -505,6 +496,8 @@ def sample_posterior_predictive(
505496 predictions : bool = False ,
506497 idata_kwargs : dict | None = None ,
507498 compile_kwargs : dict | None = None ,
499+ sample_vars : list [str ] | None = None ,
500+ freeze_vars : list [str ] | None = None ,
508501) -> dict [str , np .ndarray ]: ...
509502def sample_posterior_predictive (
510503 trace ,
@@ -519,6 +512,8 @@ def sample_posterior_predictive(
519512 predictions : bool = False ,
520513 idata_kwargs : dict | None = None ,
521514 compile_kwargs : dict | None = None ,
515+ sample_vars : list [str ] | None = None ,
516+ freeze_vars : list [str ] | None = None ,
522517) -> InferenceData | dict [str , np .ndarray ]:
523518 """Generate forward samples for `var_names`, conditioned on the posterior samples of variables found in the `trace`.
524519
@@ -539,9 +534,10 @@ def sample_posterior_predictive(
539534 Model to be used to generate the posterior predictive samples. It will
540535 generally be the model used to generate the `trace`, but it doesn't need to be.
541536 var_names : Iterable[str], optional
542- Names of variables for which to compute the posterior predictive samples.
543- By default, only observed variables are sampled.
544- See the example below for what happens when this argument is customized.
537+ Names of variables to include in the returned dataset. This only controls which
538+ variables appear in the output, not which variables are resampled. Use ``sample_vars``
539+ to control resampling. By default, observed variables and their dependent
540+ deterministics are included.
545541 sample_dims : list of str, optional
546542 Dimensions over which to loop and generate posterior predictive samples.
547543 When ``sample_dims`` is ``None`` (default) both "chain" and "draw" are considered sample
@@ -567,6 +563,16 @@ def sample_posterior_predictive(
567563 :func:`pymc.predictions_to_inference_data` otherwise.
568564 compile_kwargs: dict, optional
569565 Keyword arguments for :func:`pymc.pytensorf.compile`.
566+ sample_vars : list of str, optional
567+ Names of unobserved variables that should be explicitly resampled. Observed variables
568+ are always resampled. Use this to request resampling of specific unobserved variables
569+ (e.g., for out-of-model predictions or forecasting). Variables not in ``sample_vars``
570+ will have their values taken from the trace if available.
571+ freeze_vars : list of str, optional
572+ Names of variables that should always be reused from the trace, even if they would
573+ otherwise be resampled due to depending on changed data or other resampled variables.
574+ Frozen variables act as barriers that stop the propagation of "volatility" in the
575+ computational graph. Must be present in the trace. Cannot overlap with ``sample_vars``.
570576
571577 Returns
572578 -------
@@ -859,17 +865,39 @@ def sample_posterior_predictive(
859865
860866 constant_coords = get_constant_coords (trace_coords , model )
861867
868+ # Resolve output variables (what to return in the dataset)
862869 if var_names is not None :
863- vars_ = [model [x ] for x in var_names ]
870+ output_vars = [model [x ] for x in var_names ]
864871 else :
865872 observed_vars = model .observed_RVs
866873 if observed_data is not None :
867874 observed_vars += [
868875 model [x ] for x in observed_data if x in model and x not in observed_vars
869876 ]
870- vars_ = observed_vars + observed_dependent_deterministics (model , observed_vars )
871-
872- vars_to_sample = list (get_default_varnames (vars_ , include_transformed = False ))
877+ output_vars = observed_vars + observed_dependent_deterministics (model , observed_vars )
878+
879+ # Resolve variables to resample
880+ # Observed variables are always resampled, plus any explicit sample_vars
881+ resample_vars : list [Variable ] = list (model .observed_RVs )
882+ if sample_vars is not None :
883+ resample_vars += [model [x ] for x in sample_vars ]
884+
885+ # Compiled function outputs = resample_vars + dependent deterministics +
886+ # deterministics from var_names (they need recomputation, not "resampling")
887+ basic_rv_set = set (model .basic_RVs )
888+ compiled_outputs = list (resample_vars )
889+ compiled_outputs += observed_dependent_deterministics (model , resample_vars )
890+ # Add deterministics from var_names that need recomputation
891+ for var in output_vars :
892+ if var not in basic_rv_set and var not in compiled_outputs :
893+ compiled_outputs .append (var )
894+
895+ vars_to_sample = list (
896+ get_default_varnames (
897+ list ({v .name : v for v in compiled_outputs }.values ()),
898+ include_transformed = False ,
899+ )
900+ )
873901
874902 if not vars_to_sample :
875903 if return_inferencedata and not extend_inferencedata :
@@ -880,6 +908,27 @@ def sample_posterior_predictive(
880908
881909 vars_in_trace = get_vars_in_point_list (_trace , model )
882910
911+ # Resolve freeze vars
912+ frozen : set [Variable ] = set ()
913+ if freeze_vars is not None :
914+ frozen = {model [x ] for x in freeze_vars }
915+ # Validate: freeze_vars must be in trace
916+ vars_in_trace_names = {v .name for v in vars_in_trace }
917+ missing = {x for x in freeze_vars if x not in vars_in_trace_names }
918+ if missing :
919+ raise ValueError (
920+ f"freeze_vars { sorted (missing )} are not present in the trace. "
921+ f"Cannot freeze variables without stored values."
922+ )
923+ # Validate: freeze_vars and sample_vars must be disjoint
924+ if sample_vars is not None :
925+ overlap = set (freeze_vars ) & set (sample_vars )
926+ if overlap :
927+ raise ValueError (
928+ f"Variables { sorted (overlap )} are in both sample_vars and freeze_vars. "
929+ f"A variable cannot be both resampled and frozen."
930+ )
931+
883932 if random_seed is not None :
884933 (random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
885934
@@ -892,15 +941,41 @@ def sample_posterior_predictive(
892941 outputs = vars_to_sample ,
893942 vars_in_trace = vars_in_trace ,
894943 basic_rvs = model .basic_RVs ,
895- givens_dict = None ,
896944 random_seed = random_seed ,
897945 constant_data = constant_data ,
898946 constant_coords = constant_coords ,
947+ volatile_vars = set (resample_vars ),
948+ freeze_vars = frozen ,
899949 ** compile_kwargs ,
900950 )
901951 sampler_fn = point_wrapper (_sampler_fn )
952+
953+ # Warn about implicitly volatile trace variables
954+ vars_in_trace_set = set (vars_in_trace )
955+ sample_var_names = set (sample_vars ) if sample_vars is not None else set ()
956+ implicit_volatile = {
957+ rv
958+ for rv in volatile_basic_rvs
959+ if rv in vars_in_trace_set and rv .name not in sample_var_names
960+ }
961+ if implicit_volatile :
962+ implicit_names = sorted (rv .name for rv in implicit_volatile ) # type: ignore[type-var]
963+ warnings .warn (
964+ f"Variables { implicit_names } are in the trace but will be resampled because "
965+ f"they depend on data/coords that changed. To silence this warning, add them "
966+ f"to `sample_vars` explicitly, or add them to `freeze_vars` to reuse their "
967+ f"trace values." ,
968+ UserWarning ,
969+ stacklevel = 2 ,
970+ )
971+
902972 # All model variables have a name, but mypy does not know this
903973 _log .info (f"Sampling: { sorted (volatile_basic_rvs , key = lambda var : var .name )} " ) # type: ignore[arg-type, return-value]
974+
975+ # Determine output-only variables that should be copied from trace, not sampled
976+ sampled_names = {v .name for v in vars_to_sample }
977+ copy_from_trace_names = [var .name for var in output_vars if var .name not in sampled_names ]
978+
904979 ppc_trace_t = _DefaultTrace (samples )
905980
906981 progress = create_simple_progress (
@@ -928,9 +1003,14 @@ def sample_posterior_predictive(
9281003
9291004 values = sampler_fn (** param )
9301005
931- for k , v in zip (vars_ , values ):
1006+ for k , v in zip (vars_to_sample , values ):
9321007 ppc_trace_t .insert (k .name , v , idx )
9331008
1009+ # Copy output-only variables from trace
1010+ for name in copy_from_trace_names :
1011+ if name in param :
1012+ ppc_trace_t .insert (name , param [name ], idx )
1013+
9341014 progress .advance (task )
9351015 progress .update (task , refresh = True , completed = samples )
9361016
@@ -939,6 +1019,12 @@ def sample_posterior_predictive(
9391019
9401020 ppc_trace = ppc_trace_t .trace_dict
9411021
1022+ # Filter to only include requested output variables and sample_vars
1023+ output_var_names = {v .name for v in output_vars }
1024+ if sample_vars is not None :
1025+ output_var_names |= set (sample_vars )
1026+ ppc_trace = {k : v for k , v in ppc_trace .items () if k in output_var_names }
1027+
9421028 for k , ary in ppc_trace .items ():
9431029 if stacked_dims is not None :
9441030 ppc_trace [k ] = ary .reshape (
0 commit comments