@@ -195,18 +195,18 @@ def _get_one_posterior_sample(
195195 and (
196196 (return_sites is None ) or (name in return_sites )
197197 ) # selected in return_sites list
198- and (
199- (
200- (not site .get ("is_observed" , True )) or return_observed
201- ) # don't save observed unless requested
202- or (site .get ("infer" , False ).get ("_deterministic" , False ))
203- ) # unless it is deterministic
204198 and not isinstance (
205199 site .get ("fn" , None ), poutine .subsample_messenger ._Subsample
206200 ) # don't save plates
207201 )
208202 }
209203
204+ if not return_observed :
205+ observed_not_deterministic = self ._get_observed_sites (* args , ** kwargs )
206+ sample = {
207+ k : v for k , v in sample .items () if k not in observed_not_deterministic
208+ }
209+
210210 sample = {name : site .cpu ().numpy () for name , site in sample .items ()}
211211
212212 return sample
@@ -309,21 +309,58 @@ def _get_obs_plate_sites(
309309 for name , site in trace .nodes .items ()
310310 if (
311311 (site ["type" ] == "sample" ) # sample statement
312- and (
313- (
314- (not site .get ("is_observed" , True )) or return_observed
315- ) # don't save observed unless requested
316- or (site .get ("infer" , False ).get ("_deterministic" , False ))
317- ) # unless it is deterministic
318312 and not isinstance (
319313 site .get ("fn" , None ), poutine .subsample_messenger ._Subsample
320314 ) # don't save plates
321315 )
322316 if any (f .name == plate_name for f in site ["cond_indep_stack" ])
323317 }
318+ if not return_observed :
319+ observed_not_deterministic = self ._get_observed_sites (* args , ** kwargs )
320+ obs_plate = {
321+ k : v
322+ for k , v in obs_plate .items ()
323+ if k not in observed_not_deterministic
324+ }
324325
325326 return obs_plate
326327
328+ def _get_observed_sites (
329+ self ,
330+ args : list ,
331+ kwargs : dict ,
332+ ):
333+ """
334+ Automatically guess which model sites correspond to observed variables
335+
336+ This excludes pyro.deterministic variables.
337+
338+ Parameters
339+ ----------
340+ args
341+ Arguments to the model.
342+ kwargs
343+ Keyword arguments to the model.
344+
345+ Returns
346+ -------
347+ List with site names.
348+ """
349+ trace = poutine .trace (self .module .model ).get_trace (* args , ** kwargs )
350+ observed_sites = [
351+ name
352+ for name , site in trace .nodes .items ()
353+ if (
354+ (site ["type" ] == "sample" ) # sample statement
355+ and (
356+ site .get ("is_observed" , True )
357+ and site .get ("infer" , False ).get ("_deterministic" , False )
358+ ) # exclude deterministic sites
359+ )
360+ ]
361+
362+ return observed_sites
363+
327364 def _posterior_samples_minibatch (
328365 self , use_gpu : bool = None , batch_size : Optional [int ] = None , ** sample_kwargs
329366 ):
0 commit comments