Skip to content

Commit b8a449d

Browse files
jessegrabowski authored
and ricardoV94 committed
Update distributions to handle new rng semantics
1 parent 9f89e6a commit b8a449d

8 files changed

Lines changed: 97 additions & 33 deletions

File tree

pymc_extras/distributions/continuous.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,10 @@ class Chi:
270270

271271
@staticmethod
272272
def chi_dist(nu: TensorVariable, size: TensorVariable) -> TensorVariable:
273-
return pt.math.sqrt(ChiSquared.dist(nu=nu, size=size))
273+
_, rv = ChiSquared.dist(
274+
nu=nu, size=size, rng=pt.random.shared_rng(seed=0), return_next_rng=True
275+
)
276+
return pt.math.sqrt(rv)
274277

275278
def __new__(cls, name, nu, **kwargs):
276279
if "observed" not in kwargs:
@@ -331,7 +334,8 @@ def maxwell_dist(a: TensorVariable, size: TensorVariable) -> TensorVariable:
331334

332335
a = CheckParameterValue("a > 0")(a, pt.all(pt.gt(a, 0)))
333336

334-
return Chi.dist(nu=3, size=size) * a
337+
_, chi = Chi.dist(nu=3, size=size, rng=pt.random.shared_rng(seed=0), return_next_rng=True)
338+
return chi * a
335339

336340
def __new__(cls, name, a, **kwargs):
337341
return CustomDist(

pymc_extras/distributions/discrete.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,13 @@ def beta_negative_binomial_dist(alpha, beta, r, size):
252252
if rv_size_is_none(size):
253253
alpha, beta, r = pt.broadcast_arrays(alpha, beta, r)
254254

255-
p = pm.Beta.dist(alpha, beta, size=size)
256-
return pm.NegativeBinomial.dist(p, r, size=size)
255+
_, p = pm.Beta.dist(
256+
alpha, beta, size=size, rng=pt.random.shared_rng(seed=0), return_next_rng=True
257+
)
258+
_, rv = pm.NegativeBinomial.dist(
259+
p, r, size=size, rng=pt.random.shared_rng(seed=0), return_next_rng=True
260+
)
261+
return rv
257262

258263
@staticmethod
259264
def beta_negative_binomial_logp(value, alpha, beta, r):
@@ -361,7 +366,13 @@ def skellam_dist(mu1, mu2, size):
361366
if rv_size_is_none(size):
362367
mu1, mu2 = pt.broadcast_arrays(mu1, mu2)
363368

364-
return pm.Poisson.dist(mu=mu1, size=size) - pm.Poisson.dist(mu=mu2, size=size)
369+
_, p1 = pm.Poisson.dist(
370+
mu=mu1, size=size, rng=pt.random.shared_rng(seed=0), return_next_rng=True
371+
)
372+
_, p2 = pm.Poisson.dist(
373+
mu=mu2, size=size, rng=pt.random.shared_rng(seed=0), return_next_rng=True
374+
)
375+
return p1 - p2
365376

366377
@staticmethod
367378
def skellam_logp(value, mu1, mu2):

pymc_extras/distributions/timeseries.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwar
175175
UserWarning,
176176
)
177177
k = P.shape[-1]
178-
init_dist = pm.Categorical.dist(p=pt.full((k,), 1 / k))
178+
_, init_dist = pm.Categorical.dist(
179+
p=pt.full((k,), 1 / k), rng=pt.random.shared_rng(seed=0), return_next_rng=True
180+
)
179181

180182
return super().dist([P, steps, init_dist], n_lags=n_lags, **kwargs)
181183

@@ -198,7 +200,7 @@ def rv_op(cls, P, steps, init_dist, n_lags, size=None):
198200
def transition(*args):
199201
old_rng, *states, transition_probs = args
200202
p = transition_probs[tuple(states)]
201-
next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
203+
next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng, return_next_rng=True)
202204
return next_rng, next_state
203205

204206
state_next_rng, markov_chain = pytensor.scan(

pymc_extras/inference/advi/autoguide.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,14 @@ def AutoDiagonalNormal(model: Model) -> AutoGuideModel:
8787
loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape)
8888
scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape)
8989
# TODO: Make these customizable
90-
params_init_values[loc] = pt.random.uniform(-1, 1, size=free_rv_shapes[rv]).eval()
90+
_, loc_init = pt.random.uniform(
91+
-1,
92+
1,
93+
size=free_rv_shapes[rv],
94+
rng=pt.random.shared_rng(seed=0),
95+
return_next_rng=True,
96+
)
97+
params_init_values[loc] = loc_init.eval()
9198
params_init_values[scale] = pt.full(free_rv_shapes[rv], 0.1).eval()
9299

93100
z = Normal(

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,10 +232,23 @@ def draws_from_laplace_approx(
232232
size = (draws,) if vectorize_draws else ()
233233
if covariance is not None:
234234
sigma_pt = pt.matrix("cov", shape=(n, n), dtype=covariance.dtype)
235-
laplace_approximation = pm.MvNormal.dist(mu=mu_pt, cov=sigma_pt, size=size, method="svd")
235+
_, laplace_approximation = pm.MvNormal.dist(
236+
mu=mu_pt,
237+
cov=sigma_pt,
238+
size=size,
239+
method="svd",
240+
rng=pt.random.shared_rng(seed=0),
241+
return_next_rng=True,
242+
)
236243
else:
237244
sigma_pt = pt.vector("sigma", shape=(n,), dtype=standard_deviation.dtype)
238-
laplace_approximation = pm.Normal.dist(mu=mu_pt, sigma=sigma_pt, size=(*size, n))
245+
_, laplace_approximation = pm.Normal.dist(
246+
mu=mu_pt,
247+
sigma=sigma_pt,
248+
size=(*size, n),
249+
rng=pt.random.shared_rng(seed=0),
250+
return_next_rng=True,
251+
)
239252

240253
constrained_vars = unpack_last_axis(
241254
laplace_approximation,

pymc_extras/model/transforms/autoreparam.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,12 @@ def _(
231231
rng, size, loc, scale = node.inputs
232232
if transform is not None:
233233
raise NotImplementedError("Reparametrization of Normal with Transform is not implemented")
234-
vip_rv_ = pm.Normal.dist(
234+
_, vip_rv_ = pm.Normal.dist(
235235
lam * loc,
236236
scale**lam,
237237
size=size,
238238
rng=rng,
239+
return_next_rng=True,
239240
)
240241
vip_rv_.name = f"{name}::tau_"
241242

@@ -266,10 +267,11 @@ def _(
266267
rng, size, scale = node.inputs
267268
scale_centered = scale**lam
268269
scale_noncentered = scale ** (1 - lam)
269-
vip_rv_ = pm.Exponential.dist(
270+
_, vip_rv_ = pm.Exponential.dist(
270271
scale=scale_centered,
271272
size=size,
272273
rng=rng,
274+
return_next_rng=True,
273275
)
274276
vip_rv_value_ = vip_rv_.clone()
275277
vip_rv_.name = f"{name}::tau_"

pymc_extras/statespace/filters/distributions.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,6 @@ def rv_op(
180180
]
181181
non_sequences = [x for x in [c_, d_, T_, Z_, R_, H_, Q_] if x not in sequences]
182182

183-
rng = normalize_rng_param(rng)
184-
185183
def sort_args(args):
186184
sorted_args = []
187185

@@ -206,11 +204,11 @@ def step_fn(*args):
206204
a = state[:k]
207205

208206
middle_rng, a_innovation = pm.MvNormal.dist(
209-
mu=0, cov=Q, rng=rng, method=method
210-
).owner.outputs
207+
mu=0, cov=Q, rng=rng, method=method, return_next_rng=True
208+
)
211209
next_rng, y_innovation = pm.MvNormal.dist(
212-
mu=0, cov=H, rng=middle_rng, method=method
213-
).owner.outputs
210+
mu=0, cov=H, rng=middle_rng, method=method, return_next_rng=True
211+
)
214212

215213
a_mu = c + T @ a
216214
a_next = a_mu + R @ a_innovation
@@ -225,14 +223,18 @@ def step_fn(*args):
225223
Z_init = Z_ if Z_ in non_sequences else Z_[0]
226224
H_init = H_ if H_ in non_sequences else H_[0]
227225

228-
init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method=method)
229-
init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method=method)
226+
rng = normalize_rng_param(rng)
227+
228+
next_rng, init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method=method, return_next_rng=True)
229+
next_rng, init_y_ = pm.MvNormal.dist(
230+
Z_init @ init_x_, H_init, rng=next_rng, method=method, return_next_rng=True
231+
)
230232

231233
init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
232234

233235
ss_rng, statespace = pytensor.scan(
234236
step_fn,
235-
outputs_info=[rng, init_dist_],
237+
outputs_info=[next_rng, init_dist_],
236238
sequences=None if len(sequences) == 0 else sequences,
237239
non_sequences=[*non_sequences],
238240
n_steps=steps,
@@ -379,8 +381,8 @@ def rv_op(cls, mus, covs, logp, method="svd", size=None, rng=None):
379381

380382
mus_, covs_ = mus.type(), covs.type()
381383
seq_mvn_rng, mvn_seq = multivariate_normal(
382-
mean=mus_, cov=covs_, rng=rng, method=method
383-
).owner.outputs
384+
mean=mus_, cov=covs_, rng=rng, method=method, return_next_rng=True
385+
)
384386

385387
mvn_seq_op = KalmanFilterRV(
386388
inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2

tests/model/marginal/test_graph_analysis.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
def test_is_conditional_dependent_static_shape():
1515
"""Test that we don't consider dependencies through "constant" shape Ops"""
1616
x1 = pt.matrix("x1", shape=(None, 5))
17-
y1 = pt.random.normal(size=pt.shape(x1))
17+
_, y1 = pt.random.normal(
18+
size=pt.shape(x1), rng=pt.random.shared_rng(seed=0), return_next_rng=True
19+
)
1820
assert is_conditional_dependent(y1, x1, [x1, y1])
1921

2022
x2 = pt.matrix("x2", shape=(9, 5))
21-
y2 = pt.random.normal(size=pt.shape(x2))
23+
_, y2 = pt.random.normal(
24+
size=pt.shape(x2), rng=pt.random.shared_rng(seed=0), return_next_rng=True
25+
)
2226
assert not is_conditional_dependent(y2, x2, [x2, y2])
2327

2428

@@ -145,25 +149,36 @@ def test_blockwise(self):
145149
def test_random_variable(self):
146150
inp = pt.tensor(shape=(5, 4, 3))
147151

148-
out1 = pt.random.normal(loc=inp)
149-
out2 = pt.random.categorical(p=inp[..., None])
150-
out3 = pt.random.multivariate_normal(mean=inp[..., None], cov=pt.eye(1))
152+
_, out1 = pt.random.normal(loc=inp, rng=pt.random.shared_rng(seed=0), return_next_rng=True)
153+
_, out2 = pt.random.categorical(
154+
p=inp[..., None], rng=pt.random.shared_rng(seed=0), return_next_rng=True
155+
)
156+
_, out3 = pt.random.multivariate_normal(
157+
mean=inp[..., None],
158+
cov=pt.eye(1),
159+
rng=pt.random.shared_rng(seed=0),
160+
return_next_rng=True,
161+
)
151162
[dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3])
152163
assert dims1 == (0, 1, 2)
153164
assert dims2 == (0, 1, 2)
154165
assert dims3 == (0, 1, 2, None)
155166

156-
invalid_out = pt.random.categorical(p=inp)
167+
_, invalid_out = pt.random.categorical(
168+
p=inp, rng=pt.random.shared_rng(seed=0), return_next_rng=True
169+
)
157170
with pytest.raises(ValueError, match="Use of known dimensions"):
158171
subgraph_batch_dim_connection(inp, [invalid_out])
159172

160-
invalid_out = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3))
173+
_, invalid_out = pt.random.multivariate_normal(
174+
mean=inp, cov=pt.eye(3), rng=pt.random.shared_rng(seed=0), return_next_rng=True
175+
)
161176
with pytest.raises(ValueError, match="Use of known dimensions"):
162177
subgraph_batch_dim_connection(inp, [invalid_out])
163178

164179
def test_minibatched_random_variable(self):
165180
inp = pt.tensor(shape=(4, 3, 2))
166-
out1 = pt.random.normal(loc=inp)
181+
_, out1 = pt.random.normal(loc=inp, rng=pt.random.shared_rng(seed=0), return_next_rng=True)
167182
out2 = create_minibatch_rv(out1, total_size=(10, 10, 10))
168183
[dims1] = subgraph_batch_dim_connection(inp, [out2])
169184
assert dims1 == (0, 1, 2)
@@ -174,7 +189,9 @@ def test_symbolic_random_variable(self):
174189
# Test univariate
175190
out = CustomDist.dist(
176191
inp,
177-
dist=lambda mu, size: pt.random.normal(loc=mu, size=size),
192+
dist=lambda mu, size: pt.random.normal(
193+
loc=mu, size=size, rng=pt.random.shared_rng(seed=0), return_next_rng=True
194+
)[1],
178195
)
179196
[dims] = subgraph_batch_dim_connection(inp, [out])
180197
assert dims == (0, 1, 2)
@@ -183,7 +200,13 @@ def test_symbolic_random_variable(self):
183200
def dist(mu, size):
184201
if isinstance(size.type, NoneTypeT):
185202
size = mu.shape
186-
return pt.random.normal(loc=mu[..., None], size=(*size, 2))
203+
_, rv = pt.random.normal(
204+
loc=mu[..., None],
205+
size=(*size, 2),
206+
rng=pt.random.shared_rng(seed=0),
207+
return_next_rng=True,
208+
)
209+
return rv
187210

188211
out = CustomDist.dist(inp, dist=dist, size=(4, 3, 2), signature="()->(2)")
189212
[dims] = subgraph_batch_dim_connection(inp, [out])

0 commit comments

Comments (0)