Skip to content

Commit ed3a3f3

Browse files
committed
add everything as one commit
1 parent 91193e4 commit ed3a3f3

1 file changed

Lines changed: 49 additions & 12 deletions

File tree

scvi/model/base/_pyromixin.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -195,18 +195,18 @@ def _get_one_posterior_sample(
195195
and (
196196
(return_sites is None) or (name in return_sites)
197197
) # selected in return_sites list
198-
and (
199-
(
200-
(not site.get("is_observed", True)) or return_observed
201-
) # don't save observed unless requested
202-
or (site.get("infer", False).get("_deterministic", False))
203-
) # unless it is deterministic
204198
and not isinstance(
205199
site.get("fn", None), poutine.subsample_messenger._Subsample
206200
) # don't save plates
207201
)
208202
}
209203

204+
if not return_observed:
205+
observed_not_deterministic = self._get_observed_sites(*args, **kwargs)
206+
sample = {
207+
k: v for k, v in sample.items() if k not in observed_not_deterministic
208+
}
209+
210210
sample = {name: site.cpu().numpy() for name, site in sample.items()}
211211

212212
return sample
@@ -309,21 +309,58 @@ def _get_obs_plate_sites(
309309
for name, site in trace.nodes.items()
310310
if (
311311
(site["type"] == "sample") # sample statement
312-
and (
313-
(
314-
(not site.get("is_observed", True)) or return_observed
315-
) # don't save observed unless requested
316-
or (site.get("infer", False).get("_deterministic", False))
317-
) # unless it is deterministic
318312
and not isinstance(
319313
site.get("fn", None), poutine.subsample_messenger._Subsample
320314
) # don't save plates
321315
)
322316
if any(f.name == plate_name for f in site["cond_indep_stack"])
323317
}
318+
if not return_observed:
319+
observed_not_deterministic = self._get_observed_sites(*args, **kwargs)
320+
obs_plate = {
321+
k: v
322+
for k, v in obs_plate.items()
323+
if k not in observed_not_deterministic
324+
}
324325

325326
return obs_plate
326327

328+
def _get_observed_sites(
329+
self,
330+
args: list,
331+
kwargs: dict,
332+
):
333+
"""
334+
Automatically guess which model sites correspond to observed variables
335+
336+
This excludes pyro.deterministic variables.
337+
338+
Parameters
339+
----------
340+
args
341+
Arguments to the model.
342+
kwargs
343+
Keyword arguments to the model.
344+
345+
Returns
346+
-------
347+
List with site names.
348+
"""
349+
trace = poutine.trace(self.module.model).get_trace(*args, **kwargs)
350+
observed_sites = [
351+
name
352+
for name, site in trace.nodes.items()
353+
if (
354+
(site["type"] == "sample") # sample statement
355+
and (
356+
site.get("is_observed", True)
357+
and site.get("infer", False).get("_deterministic", False)
358+
) # exclude deterministic sites
359+
)
360+
]
361+
362+
return observed_sites
363+
327364
def _posterior_samples_minibatch(
328365
self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs
329366
):

0 commit comments

Comments
 (0)