Skip to content

Commit 3a915cc

Browse files
committed
Explicit rng in expand_dist_dims
1 parent a4d9833 commit 3a915cc

4 files changed

Lines changed: 9 additions & 1 deletion

File tree

pymc/dims/distributions/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,8 @@ def expand_dist_dims(dist: XTensorVariable, extra_dims: dict[str, Any]) -> XTens
371371
dist_props["extra_dims"] = (*(extra_dims.keys()), *dist_props["extra_dims"])
372372
new_dist_op = type(dist.owner.op)(**dist_props)
373373
_old_rng, *params_and_dim_lengths = dist.owner.inputs
374-
new_rng = None # We don't propagate the old RNG, because we don't want the new and old dists to be correlated
374+
# We don't propagate the old RNG, because we don't want the new and old dists to be correlated
375+
new_rng = pt.random.shared_rng(seed=None)
375376
return new_dist_op(new_rng, *extra_dims.values(), *params_and_dim_lengths)
376377
case Transpose():
377378
return expand_dist_dims(dist.owner.inputs[0], extra_dims=extra_dims).transpose(

tests/dims/distributions/test_censored.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from pymc.model.core import Model
2525
from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph
2626

27+
pytestmark = pytest.mark.filterwarnings("error")
28+
2729

2830
@pytest.mark.parametrize("lower", [None, -1])
2931
@pytest.mark.parametrize("upper", [None, 1])

tests/dims/distributions/test_scalar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
)
4040
from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph
4141

42+
pytestmark = pytest.mark.filterwarnings("error")
43+
4244

4345
def test_flat():
4446
coords = {"a": range(3)}

tests/dims/distributions/test_vector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import numpy as np
1515
import pytensor.tensor as pt
16+
import pytest
1617

1718
from pytensor.xtensor import as_xtensor
1819

@@ -22,6 +23,8 @@
2223
from pymc.dims import Categorical, Dirichlet, MvNormal, ZeroSumNormal
2324
from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph
2425

26+
pytestmark = pytest.mark.filterwarnings("error")
27+
2528

2629
def test_categorical():
2730
coords = {"a": range(3), "b": range(4)}

0 commit comments

Comments
 (0)