Skip to content

Commit 6dba2f2

Browse files
committed
Reduce unnecessary iteration counts in slow tests
Targets tests identified by ricardoV94 in #7686: - test_weakref_leak: reduce warmup+check from 20 to 8 iterations - test_default_value_transform_logprob: reduce loop from 10 to 3 - test_interpolated (TestMatchesScipy): reduce x_points from 100k to 20k - test_progressbar_nested_compound: reduce draws/tune from 10 to 5 - TestLKJCorr: reduce n values from (2,10,50) to (2,50) - TestLKJCholeskyCov: reduce sizes_to_check from 7 to 4 - TestLKJCholeskCov: reduce parametrize combos from 6 to 4 - TestInterpolated: reduce mu*sigma loop from 56 to 4 combos Not modified: test_mvnormal_indef (26.8s) — slowness is PyTensor compilation overhead, not reducible without removing coverage.
1 parent d3ec5c7 commit 6dba2f2

4 files changed

Lines changed: 38 additions & 36 deletions

File tree

tests/distributions/test_continuous.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ class TestedInterpolated(pm.Interpolated):
965965

966966
@classmethod
967967
def dist(cls, **kwargs):
968-
x_points = np.linspace(xmin, xmax, 100000)
968+
x_points = np.linspace(xmin, xmax, 20000)
969969
pdf_points = st.norm.pdf(x_points, loc=mu, scale=sigma)
970970
return super().dist(x_points=x_points, pdf_points=pdf_points, **kwargs)
971971

@@ -2503,25 +2503,30 @@ def interpolated_rng_fn(self, size, mu, sigma, rng):
25032503
]
25042504

25052505
def check_draws(self):
2506-
for mu in R.vals:
2507-
for sigma in Rplus.vals:
2508-
rng = self.get_random_state()
2509-
2510-
def ref_rand(size):
2511-
return st.norm.rvs(loc=mu, scale=sigma, size=size, random_state=rng)
2512-
2513-
class TestedInterpolated(pm.Interpolated):
2514-
rv_op = interpolated
2515-
2516-
@classmethod
2517-
def dist(cls, **kwargs):
2518-
x_points = np.linspace(mu - 5 * sigma, mu + 5 * sigma, 100)
2519-
pdf_points = st.norm.pdf(x_points, loc=mu, scale=sigma)
2520-
return super().dist(x_points=x_points, pdf_points=pdf_points, **kwargs)
2521-
2522-
continuous_random_tester(
2523-
TestedInterpolated,
2524-
{},
2525-
extra_args={"rng": pytensor.shared(rng)},
2526-
ref_rand=ref_rand,
2527-
)
2506+
representative_params = [
2507+
(0.0, 1.0),
2508+
(-2.1, 0.1),
2509+
(1.0, 100.0),
2510+
(0.01, 0.99),
2511+
]
2512+
for mu, sigma in representative_params:
2513+
rng = self.get_random_state()
2514+
2515+
def ref_rand(size):
2516+
return st.norm.rvs(loc=mu, scale=sigma, size=size, random_state=rng)
2517+
2518+
class TestedInterpolated(pm.Interpolated):
2519+
rv_op = interpolated
2520+
2521+
@classmethod
2522+
def dist(cls, **kwargs):
2523+
x_points = np.linspace(mu - 5 * sigma, mu + 5 * sigma, 100)
2524+
pdf_points = st.norm.pdf(x_points, loc=mu, scale=sigma)
2525+
return super().dist(x_points=x_points, pdf_points=pdf_points, **kwargs)
2526+
2527+
continuous_random_tester(
2528+
TestedInterpolated,
2529+
{},
2530+
extra_args={"rng": pytensor.shared(rng)},
2531+
ref_rand=ref_rand,
2532+
)

tests/distributions/test_multivariate.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,6 @@ def test_no_warning_logp(self):
980980
"size, shape",
981981
[
982982
((10,), None),
983-
(None, (10, 6)),
984983
(None, (10, ...)),
985984
],
986985
)
@@ -2176,7 +2175,7 @@ def ref_rand(size, n, eta):
21762175

21772176
# If passed as a domain, continuous_random_tester would make `n` a shared variable
21782177
# But this RV needs it to be constant in order to define the inner graph
2179-
for n in (2, 10, 50):
2178+
for n in (2, 50):
21802179
continuous_random_tester(
21812180
_LKJCorr,
21822181
{
@@ -2214,12 +2213,9 @@ class TestLKJCholeskyCov(BaseTestDistributionRandom):
22142213
expected_rv_op_params = {"n": 3, "eta": 1.0, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}
22152214
size = None
22162215

2217-
sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
2216+
sizes_to_check = [None, 5, (4, 5), (2, 4, 2)]
22182217
sizes_expected = [
22192218
(6,),
2220-
(6,),
2221-
(1, 6),
2222-
(1, 6),
22232219
(5, 6),
22242220
(4, 5, 6),
22252221
(2, 4, 2, 6),

tests/logprob/test_transform_value.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def test_default_value_transform_logprob(pt_dist, dist_params, sp_dist, size):
248248
on_unused_input="ignore",
249249
)
250250

251-
for i in range(10):
251+
for i in range(3):
252252
a_dist = sp_dist(*dist_params)
253253
a_val = a_dist.rvs(size=size, random_state=test_val_rng).astype(a_value_var.dtype)
254254
b_dist = sp.stats.norm(a_val, 1.0)
@@ -614,9 +614,10 @@ def _growth(limit=10, peak_stats={}):
614614
rvs_to_values = {pt.random.beta(1, 1, name=f"p_{i}"): pt.scalar(f"p_{i}") for i in range(30)}
615615
tr = TransformValuesRewrite(dict.fromkeys(rvs_to_values.values(), logodds))
616616

617-
for i in range(20):
617+
n_warmup = 4
618+
n_check = 4
619+
for i in range(n_warmup + n_check):
618620
conditional_logp(rvs_to_values, extra_rewrites=tr)
619621
res = _growth()
620-
# Only start checking after warmup
621-
if i > 15:
622-
assert not res, "Object counts are still growing"
622+
if i >= n_warmup:
623+
assert not res, f"Object counts still growing after {n_warmup} warmup iterations: {res}"

tests/progress_bar/test_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def test_progressbar_nested_compound():
3535
)
3636

3737
kwargs = {
38-
"draws": 10,
39-
"tune": 10,
38+
"draws": 5,
39+
"tune": 5,
4040
"chains": 2,
4141
"compute_convergence_checks": False,
4242
"step": step,

0 commit comments

Comments
 (0)