1- import numpy as np
21import pymc as pm
32import pytensor
43import pytensor .tensor as pt
87from pymc .distributions .distribution import Continuous , SymbolicRandomVariable
98from pymc .distributions .shape_utils import get_support_shape_1d
109from pymc .logprob .abstract import _logprob
10+ from pymc .pytensorf import normalize_rng_param
1111from pytensor .graph .basic import Node
12+ from pytensor .tensor .random import multivariate_normal
1213
1314floatX = pytensor .config .floatX
1415COV_ZERO_TOL = 0
@@ -152,6 +153,7 @@ def rv_op(
152153 Q ,
153154 steps ,
154155 size = None ,
156+ rng = None ,
155157 sequence_names = None ,
156158 append_x0 = True ,
157159 method = "svd" ,
@@ -178,7 +180,7 @@ def rv_op(
178180 ]
179181 non_sequences = [x for x in [c_ , d_ , T_ , Z_ , R_ , H_ , Q_ ] if x not in sequences ]
180182
181- rng = pytensor . shared ( np . random . default_rng () )
183+ rng = normalize_rng_param ( rng )
182184
183185 def sort_args (args ):
184186 sorted_args = []
@@ -367,44 +369,25 @@ def __new__(cls, *args, **kwargs):
367369
368370 @classmethod
369371 def dist (cls , mus , covs , logp , method = "svd" , ** kwargs ):
372+ mus , covs , logp = map (pt .as_tensor_variable , (mus , covs , logp ))
370373 return super ().dist ([mus , covs , logp ], method = method , ** kwargs )
371374
372375 @classmethod
373- def rv_op (cls , mus , covs , logp , method = "svd" , size = None ):
374- # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
375- if mus .ndim > 2 :
376- mus = pt .moveaxis (mus , - 2 , 0 )
377- if covs .ndim > 3 :
378- covs = pt .moveaxis (covs , - 3 , 0 )
379-
380- mus_ , covs_ = mus .type (), covs .type ()
381-
376+ def rv_op (cls , mus , covs , logp , method = "svd" , size = None , rng = None ):
377+ rng = normalize_rng_param (rng )
382378 logp_ = logp .type ()
383- rng = pytensor .shared (np .random .default_rng ())
384-
385- def step (mu , cov , rng ):
386- new_rng , mvn = pm .MvNormal .dist (mu = mu , cov = cov , rng = rng , method = method ).owner .outputs
387- return new_rng , mvn
388-
389- seq_mvn_rng , mvn_seq = pytensor .scan (
390- step ,
391- sequences = [mus_ , covs_ ],
392- outputs_info = [rng , None ],
393- strict = True ,
394- n_steps = mus_ .shape [0 ],
395- return_updates = False ,
396- )
397- mvn_seq = pt .specify_shape (mvn_seq , mus .type .shape )
398379
399- # Move time axis back to position -2 so batches are on the left
400- if mvn_seq .ndim > 2 :
401- mvn_seq = pt .moveaxis (mvn_seq , 0 , - 2 )
380+ mus_ , covs_ = mus .type (), covs .type ()
381+ seq_mvn_rng , mvn_seq = multivariate_normal (
382+ mean = mus_ , cov = covs_ , rng = rng , method = method
383+ ).owner .outputs
402384
403385 mvn_seq_op = KalmanFilterRV (
404386 inputs = [mus_ , covs_ , logp_ , rng ], outputs = [seq_mvn_rng , mvn_seq ], ndim_supp = 2
405387 )
406388
407389 mvn_seq = mvn_seq_op (mus , covs , logp , rng )
390+
408391 return mvn_seq
409392
410393
0 commit comments