Skip to content

Commit 12a18c2

Browse files
Use vectorized distributions for post-estimation tasks where possible (#629)
1 parent 81c6f0e commit 12a18c2

2 files changed

Lines changed: 22 additions & 32 deletions

File tree

pymc_extras/statespace/core/statespace.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,6 @@ def build_statespace_graph(
10211021
.. deprecated:: 0.2.5
10221022
The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
10231023
model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
1024-
10251024
"""
10261025
if mode is not None:
10271026
warnings.warn(
@@ -1559,7 +1558,11 @@ def sample_conditional_prior(
15591558
"""
15601559

15611560
return self._sample_conditional(
1562-
idata=idata, group="prior", random_seed=random_seed, mvn_method=mvn_method, **kwargs
1561+
idata=idata,
1562+
group="prior",
1563+
random_seed=random_seed,
1564+
mvn_method=mvn_method,
1565+
**kwargs,
15631566
)
15641567

15651568
def sample_conditional_posterior(
@@ -1602,7 +1605,11 @@ def sample_conditional_posterior(
16021605
"""
16031606

16041607
return self._sample_conditional(
1605-
idata=idata, group="posterior", random_seed=random_seed, mvn_method=mvn_method, **kwargs
1608+
idata=idata,
1609+
group="posterior",
1610+
random_seed=random_seed,
1611+
mvn_method=mvn_method,
1612+
**kwargs,
16061613
)
16071614

16081615
def sample_unconditional_prior(

pymc_extras/statespace/filters/distributions.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
import pymc as pm
32
import pytensor
43
import pytensor.tensor as pt
@@ -8,7 +7,9 @@
87
from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
98
from pymc.distributions.shape_utils import get_support_shape_1d
109
from pymc.logprob.abstract import _logprob
10+
from pymc.pytensorf import normalize_rng_param
1111
from pytensor.graph.basic import Node
12+
from pytensor.tensor.random import multivariate_normal
1213

1314
floatX = pytensor.config.floatX
1415
COV_ZERO_TOL = 0
@@ -152,6 +153,7 @@ def rv_op(
152153
Q,
153154
steps,
154155
size=None,
156+
rng=None,
155157
sequence_names=None,
156158
append_x0=True,
157159
method="svd",
@@ -178,7 +180,7 @@ def rv_op(
178180
]
179181
non_sequences = [x for x in [c_, d_, T_, Z_, R_, H_, Q_] if x not in sequences]
180182

181-
rng = pytensor.shared(np.random.default_rng())
183+
rng = normalize_rng_param(rng)
182184

183185
def sort_args(args):
184186
sorted_args = []
@@ -367,44 +369,25 @@ def __new__(cls, *args, **kwargs):
367369

368370
@classmethod
369371
def dist(cls, mus, covs, logp, method="svd", **kwargs):
372+
mus, covs, logp = map(pt.as_tensor_variable, (mus, covs, logp))
370373
return super().dist([mus, covs, logp], method=method, **kwargs)
371374

372375
@classmethod
373-
def rv_op(cls, mus, covs, logp, method="svd", size=None):
374-
# Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
375-
if mus.ndim > 2:
376-
mus = pt.moveaxis(mus, -2, 0)
377-
if covs.ndim > 3:
378-
covs = pt.moveaxis(covs, -3, 0)
379-
380-
mus_, covs_ = mus.type(), covs.type()
381-
376+
def rv_op(cls, mus, covs, logp, method="svd", size=None, rng=None):
377+
rng = normalize_rng_param(rng)
382378
logp_ = logp.type()
383-
rng = pytensor.shared(np.random.default_rng())
384-
385-
def step(mu, cov, rng):
386-
new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method=method).owner.outputs
387-
return new_rng, mvn
388-
389-
seq_mvn_rng, mvn_seq = pytensor.scan(
390-
step,
391-
sequences=[mus_, covs_],
392-
outputs_info=[rng, None],
393-
strict=True,
394-
n_steps=mus_.shape[0],
395-
return_updates=False,
396-
)
397-
mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
398379

399-
# Move time axis back to position -2 so batches are on the left
400-
if mvn_seq.ndim > 2:
401-
mvn_seq = pt.moveaxis(mvn_seq, 0, -2)
380+
mus_, covs_ = mus.type(), covs.type()
381+
seq_mvn_rng, mvn_seq = multivariate_normal(
382+
mean=mus_, cov=covs_, rng=rng, method=method
383+
).owner.outputs
402384

403385
mvn_seq_op = KalmanFilterRV(
404386
inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
405387
)
406388

407389
mvn_seq = mvn_seq_op(mus, covs, logp, rng)
390+
408391
return mvn_seq
409392

410393

0 commit comments

Comments
 (0)