1414def test_is_conditional_dependent_static_shape ():
1515 """Test that we don't consider dependencies through "constant" shape Ops"""
1616 x1 = pt .matrix ("x1" , shape = (None , 5 ))
17- y1 = pt .random .normal (size = pt .shape (x1 ))
17+ _ , y1 = pt .random .normal (
18+ size = pt .shape (x1 ), rng = pt .random .shared_rng (seed = 0 ), return_next_rng = True
19+ )
1820 assert is_conditional_dependent (y1 , x1 , [x1 , y1 ])
1921
2022 x2 = pt .matrix ("x2" , shape = (9 , 5 ))
21- y2 = pt .random .normal (size = pt .shape (x2 ))
23+ _ , y2 = pt .random .normal (
24+ size = pt .shape (x2 ), rng = pt .random .shared_rng (seed = 0 ), return_next_rng = True
25+ )
2226 assert not is_conditional_dependent (y2 , x2 , [x2 , y2 ])
2327
2428
@@ -145,25 +149,36 @@ def test_blockwise(self):
145149 def test_random_variable (self ):
146150 inp = pt .tensor (shape = (5 , 4 , 3 ))
147151
148- out1 = pt .random .normal (loc = inp )
149- out2 = pt .random .categorical (p = inp [..., None ])
150- out3 = pt .random .multivariate_normal (mean = inp [..., None ], cov = pt .eye (1 ))
152+ _ , out1 = pt .random .normal (loc = inp , rng = pt .random .shared_rng (seed = 0 ), return_next_rng = True )
153+ _ , out2 = pt .random .categorical (
154+ p = inp [..., None ], rng = pt .random .shared_rng (seed = 0 ), return_next_rng = True
155+ )
156+ _ , out3 = pt .random .multivariate_normal (
157+ mean = inp [..., None ],
158+ cov = pt .eye (1 ),
159+ rng = pt .random .shared_rng (seed = 0 ),
160+ return_next_rng = True ,
161+ )
151162 [dims1 , dims2 , dims3 ] = subgraph_batch_dim_connection (inp , [out1 , out2 , out3 ])
152163 assert dims1 == (0 , 1 , 2 )
153164 assert dims2 == (0 , 1 , 2 )
154165 assert dims3 == (0 , 1 , 2 , None )
155166
156- invalid_out = pt .random .categorical (p = inp )
167+ _ , invalid_out = pt .random .categorical (
168+ p = inp , rng = pt .random .shared_rng (seed = 0 ), return_next_rng = True
169+ )
157170 with pytest .raises (ValueError , match = "Use of known dimensions" ):
158171 subgraph_batch_dim_connection (inp , [invalid_out ])
159172
160- invalid_out = pt .random .multivariate_normal (mean = inp , cov = pt .eye (3 ))
173+ _ , invalid_out = pt .random .multivariate_normal (
174+ mean = inp , cov = pt .eye (3 ), rng = pt .random .shared_rng (seed = 0 ), return_next_rng = True
175+ )
161176 with pytest .raises (ValueError , match = "Use of known dimensions" ):
162177 subgraph_batch_dim_connection (inp , [invalid_out ])
163178
164179 def test_minibatched_random_variable (self ):
165180 inp = pt .tensor (shape = (4 , 3 , 2 ))
166- out1 = pt .random .normal (loc = inp )
181+ _ , out1 = pt .random .normal (loc = inp , rng = pt . random . shared_rng ( seed = 0 ), return_next_rng = True )
167182 out2 = create_minibatch_rv (out1 , total_size = (10 , 10 , 10 ))
168183 [dims1 ] = subgraph_batch_dim_connection (inp , [out2 ])
169184 assert dims1 == (0 , 1 , 2 )
@@ -174,7 +189,9 @@ def test_symbolic_random_variable(self):
174189 # Test univariate
175190 out = CustomDist .dist (
176191 inp ,
177- dist = lambda mu , size : pt .random .normal (loc = mu , size = size ),
192+ dist = lambda mu , size : pt .random .normal (
193+ loc = mu , size = size , rng = pt .random .shared_rng (seed = 0 ), return_next_rng = True
194+ )[1 ],
178195 )
179196 [dims ] = subgraph_batch_dim_connection (inp , [out ])
180197 assert dims == (0 , 1 , 2 )
@@ -183,7 +200,13 @@ def test_symbolic_random_variable(self):
183200 def dist (mu , size ):
184201 if isinstance (size .type , NoneTypeT ):
185202 size = mu .shape
186- return pt .random .normal (loc = mu [..., None ], size = (* size , 2 ))
203+ _ , rv = pt .random .normal (
204+ loc = mu [..., None ],
205+ size = (* size , 2 ),
206+ rng = pt .random .shared_rng (seed = 0 ),
207+ return_next_rng = True ,
208+ )
209+ return rv
187210
188211 out = CustomDist .dist (inp , dist = dist , size = (4 , 3 , 2 ), signature = "()->(2)" )
189212 [dims ] = subgraph_batch_dim_connection (inp , [out ])
0 commit comments