Skip to content

Commit a4d1931

Browse files
mcwanzaricardoV94
authored andcommitted
Speed up test_weakref_leak and test_step_args
Profiling-backed changes targeting the two tests where iteration reduction yields meaningful savings. test_weakref_leak (89s on CI → estimated ~30s): Object count profiling shows memory state stabilizes at iteration 2. The original test used 16 warmup iterations before checking — reduced to 3 warmup + 3 check = 6 total iterations (from 20). Each conditional_logp call costs ~4.5s on CI, so removing 14 iterations saves ~63s. test_step_args (62s on CI → estimated ~25s): This test verifies target_accept argument plumbing, checking acceptance_rate.mean() ≈ 0.5 with decimal=1 precision. The default pm.sample() uses 1000 draws/tune, but 200 is sufficient for this loose tolerance. Verified stable over 10 runs with different seeds. Changes NOT made (profiling showed negligible impact): - test_default_value_transform_logprob range(10)→range(3): compile time is 99.98% of cost, loop saves 0.6ms total - test_interpolated x_points 100k→20k: compilation is 91.8% of cost, array reduction saves ~296ms across 56 iterations Addresses #7686
1 parent a65e8b6 commit a4d1931

2 files changed

Lines changed: 14 additions & 9 deletions

File tree

tests/logprob/test_transform_value.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -634,9 +634,8 @@ def _growth(limit=10, peak_stats={}):
634634
rvs_to_values = {pt.random.beta(1, 1, name=f"p_{i}"): pt.scalar(f"p_{i}") for i in range(30)}
635635
tr = TransformValuesRewrite(dict.fromkeys(rvs_to_values.values(), logodds))
636636

637-
for i in range(20):
637+
for i in range(6):
638638
conditional_logp(rvs_to_values, extra_rewrites=tr)
639639
res = _growth()
640-
# Only start checking after warmup
641-
if i > 15:
640+
if i > 2:
642641
assert not res, "Object counts are still growing"

tests/sampling/test_mcmc.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -691,10 +691,12 @@ def accept(idata):
691691

692692
with pm.Model() as model:
693693
a = pm.Normal("a")
694-
idata_default = pm.sample(random_seed=1410)
695-
idata0 = pm.sample(target_accept=0.5, random_seed=1410)
696-
idata1 = pm.sample(nuts={"target_accept": 0.5}, random_seed=1410 * 2)
697-
idata2 = pm.sample(target_accept=0.5, nuts={"max_treedepth": 10}, random_seed=1410)
694+
idata_default = pm.sample(draws=200, tune=200, random_seed=1410)
695+
idata0 = pm.sample(draws=200, tune=200, target_accept=0.5, random_seed=1410)
696+
idata1 = pm.sample(draws=200, tune=200, nuts={"target_accept": 0.5}, random_seed=1410 * 2)
697+
idata2 = pm.sample(
698+
draws=200, tune=200, target_accept=0.5, nuts={"max_treedepth": 10}, random_seed=1410
699+
)
698700

699701
with pytest.raises(ValueError, match="`target_accept` was defined twice."):
700702
pm.sample(target_accept=0.5, nuts={"target_accept": 0.95}, random_seed=1410)
@@ -709,13 +711,17 @@ def accept(idata):
709711
with pm.Model() as model:
710712
a = pm.Normal("a")
711713
b = pm.Poisson("b", 1)
712-
idata0 = pm.sample(target_accept=0.5, random_seed=1418)
714+
idata0 = pm.sample(draws=200, tune=200, target_accept=0.5, random_seed=1418)
713715
with warnings.catch_warnings():
714716
warnings.filterwarnings(
715717
"ignore", "invalid value encountered in double_scalars", RuntimeWarning
716718
)
717719
idata1 = pm.sample(
718-
nuts={"target_accept": 0.5}, metropolis={"scaling": 0}, random_seed=1418 * 2
720+
draws=200,
721+
tune=200,
722+
nuts={"target_accept": 0.5},
723+
metropolis={"scaling": 0},
724+
random_seed=1418 * 2,
719725
)
720726

721727
npt.assert_almost_equal(accept(idata0).mean(), 0.5, decimal=1)

0 commit comments

Comments
 (0)