|
| 1 | +# /// script |
| 2 | +# dependencies = ["numpy", "pymc", "pytensor"] |
| 3 | +# /// |
| 4 | + |
| 5 | +"""Verify the Tweedie dist function against theoretical values. |
| 6 | +
|
| 7 | +Compares the symbolic PyMC tweedie_dist (used by sample_prior_predictive |
| 8 | +and sample_posterior_predictive) against: |
| 9 | + - theoretical values (E[Y]=mu, Var[Y]=phi*mu^p, P(Y=0)=exp(-lambda)) |
| 10 | + - the numpy tweedie_random reference (known correct) |
| 11 | +
|
| 12 | +Tests two versions: the blog post version (suspected bug) and the corrected |
| 13 | +version (beta as rate = 1/(phi*(p-1)*mu^(p-1))). |
| 14 | +""" |
| 15 | + |
| 16 | +import sys |
| 17 | + |
| 18 | +import numpy as np |
| 19 | +import pymc as pm |
| 20 | +import pymc.dims as pmd |
| 21 | +import pytensor |
| 22 | +import pytensor.tensor as pt |
| 23 | +import pytensor.xtensor as px |
| 24 | + |
| 25 | +# ---- Import the known-correct numpy reference ---- |
| 26 | +sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent)) |
| 27 | +from tweedie_utils import tweedie_random |
| 28 | + |
| 29 | + |
| 30 | +def tweedie_dist_buggy(mu, phi, p): |
| 31 | + """Original blog post version (suspected wrong).""" |
| 32 | + lam = mu ** (2 - p) / (phi * (2 - p)) |
| 33 | + alpha_term = (2 - p) / (p - 1) |
| 34 | + beta = phi * (p - 1) * mu ** (p - 1) |
| 35 | + N = pmd.Poisson.dist(mu=lam) |
| 36 | + Y = pmd.Gamma.dist(alpha=px.math.maximum(N * alpha_term, 1e-10), beta=beta) |
| 37 | + return px.math.where(N > 0, Y, 0.0) |
| 38 | + |
| 39 | + |
| 40 | +def tweedie_dist_correct(mu, phi, p): |
| 41 | + """Corrected version: beta is rate = 1/scale.""" |
| 42 | + lam = mu ** (2 - p) / (phi * (2 - p)) |
| 43 | + alpha_term = (2 - p) / (p - 1) |
| 44 | + beta = 1.0 / (phi * (p - 1) * mu ** (p - 1)) |
| 45 | + N = pmd.Poisson.dist(mu=lam) |
| 46 | + Y = pmd.Gamma.dist(alpha=px.math.maximum(N * alpha_term, 1e-10), beta=beta) |
| 47 | + return px.math.where(N > 0, Y, 0.0) |
| 48 | + |
| 49 | + |
| 50 | +def theoretical_values(mu, phi, p): |
| 51 | + """Compute theoretical Tweedie moments.""" |
| 52 | + lam = mu ** (2 - p) / (phi * (2 - p)) |
| 53 | + zero_rate = np.exp(-lam) |
| 54 | + return { |
| 55 | + "mean": float(mu), |
| 56 | + "std": float(np.sqrt(phi * mu**p)), |
| 57 | + "zero_rate": float(zero_rate), |
| 58 | + } |
| 59 | + |
| 60 | + |
| 61 | +def draw_and_stats(dist, draws=50_000, rng=None): |
| 62 | + """Draw from a symbolic dist and return mean, std, zero_rate.""" |
| 63 | + samples = pm.draw(dist, draws=draws, random_seed=rng) |
| 64 | + samples = np.asarray(samples).ravel() |
| 65 | + return { |
| 66 | + "mean": float(np.mean(samples)), |
| 67 | + "std": float(np.std(samples)), |
| 68 | + "zero_rate": float(np.mean(samples == 0)), |
| 69 | + } |
| 70 | + |
| 71 | + |
| 72 | +def compare(name, stats, theo, tol_mean=0.05, tol_zero=0.01): |
| 73 | + """Compare sampled stats against theoretical values.""" |
| 74 | + results = [] |
| 75 | + passed = True |
| 76 | + |
| 77 | + for key in ("mean", "std", "zero_rate"): |
| 78 | + s, t = stats[key], theo[key] |
| 79 | + rel_err = abs(s - t) / max(abs(t), 1e-10) |
| 80 | + tol = tol_zero if key == "zero_rate" else tol_mean |
| 81 | + ok = rel_err < tol |
| 82 | + if not ok: |
| 83 | + passed = False |
| 84 | + status = "✓" if ok else "✗" |
| 85 | + results.append(f" {status} {key:<10s}: {s:>14.4f} (theory={t:>14.4f}, rel_err={rel_err:>8.4%})") |
| 86 | + |
| 87 | + print(f"\n{'='*75}") |
| 88 | + print(f" {name}") |
| 89 | + print(f"{'='*75}") |
| 90 | + for r in results: |
| 91 | + print(r) |
| 92 | + print(f" {'ALL PASS' if passed else 'FAILURES DETECTED'}") |
| 93 | + return passed |
| 94 | + |
| 95 | + |
| 96 | +def compare_vs_numpy(name, pymc_stats, numpy_stats): |
| 97 | + """Compare PyMC dist stats against numpy reference.""" |
| 98 | + print(f"\n{'='*75}") |
| 99 | + print(f" {name} vs numpy reference") |
| 100 | + print(f"{'='*75}") |
| 101 | + all_ok = True |
| 102 | + for key in ("mean", "std", "zero_rate"): |
| 103 | + s_pymc, s_np = pymc_stats[key], numpy_stats[key] |
| 104 | + rel_diff = abs(s_pymc - s_np) / max(abs(s_np), 1e-10) |
| 105 | + ok = rel_diff < 0.05 |
| 106 | + if not ok: |
| 107 | + all_ok = False |
| 108 | + status = "✓" if ok else "✗" |
| 109 | + print(f" {status} {key:<10s}: pymc={s_pymc:>14.4f} numpy={s_np:>14.4f} diff={rel_diff:>8.4%}") |
| 110 | + print(f" {'MATCHES NUMPY' if all_ok else 'DIVERGES FROM NUMPY'}") |
| 111 | + return all_ok |
| 112 | + |
| 113 | + |
| 114 | +def main(): |
| 115 | + test_cases = [ |
| 116 | + {"mu": 10.0, "phi": 2.0, "p": 1.5, "label": "μ=10, φ=2.0, p=1.5"}, |
| 117 | + {"mu": 50.0, "phi": 1.5, "p": 1.3, "label": "μ=50, φ=1.5, p=1.3"}, |
| 118 | + {"mu": 293.0, "phi": 174.0, "p": 1.574, "label": "μ=293, φ=174, p=1.574 (dataCar-like)"}, |
| 119 | + ] |
| 120 | + |
| 121 | + draws = 50_000 |
| 122 | + all_passed = True |
| 123 | + all_match_numpy = True |
| 124 | + |
| 125 | + for tc in test_cases: |
| 126 | + mu = tc["mu"] |
| 127 | + phi = tc["phi"] |
| 128 | + p = tc["p"] |
| 129 | + label = tc["label"] |
| 130 | + |
| 131 | + print(f"\n{'#'*75}") |
| 132 | + print(f"# {label}") |
| 133 | + print(f"{'#'*75}") |
| 134 | + |
| 135 | + # Seed for reproducibility |
| 136 | + rng = np.random.default_rng(42) |
| 137 | + seed = 42 |
| 138 | + |
| 139 | + theo = theoretical_values(mu, phi, p) |
| 140 | + print(f"\n Theoretical: mean={theo['mean']:.4f}, std={theo['std']:.4f}, zero_rate={theo['zero_rate']:.4%}") |
| 141 | + |
| 142 | + # ---- 1. Numpy reference ---- |
| 143 | + np_samples = tweedie_random(mu, phi, p, size=draws, rng=rng) |
| 144 | + np_stats = { |
| 145 | + "mean": float(np.mean(np_samples)), |
| 146 | + "std": float(np.std(np_samples)), |
| 147 | + "zero_rate": float(np.mean(np_samples == 0)), |
| 148 | + } |
| 149 | + compare("numpy reference (tweedie_random)", np_stats, theo) |
| 150 | + |
| 151 | + # ---- 2. Buggy PyMC dist ---- |
| 152 | + dist_buggy = tweedie_dist_buggy(mu, phi, p) |
| 153 | + buggy_stats = draw_and_stats(dist_buggy, draws=draws, rng=seed) |
| 154 | + passed_buggy = compare("BUGGY tweedie_dist (blog post)", buggy_stats, theo) |
| 155 | + match_buggy = compare_vs_numpy("BUGGY tweedie_dist", buggy_stats, np_stats) |
| 156 | + if not passed_buggy: |
| 157 | + all_passed = False |
| 158 | + if not match_buggy: |
| 159 | + all_match_numpy = False |
| 160 | + |
| 161 | + # ---- 3. Corrected PyMC dist ---- |
| 162 | + dist_correct = tweedie_dist_correct(mu, phi, p) |
| 163 | + correct_stats = draw_and_stats(dist_correct, draws=draws, rng=seed) |
| 164 | + passed_correct = compare("CORRECT tweedie_dist", correct_stats, theo) |
| 165 | + match_correct = compare_vs_numpy("CORRECT tweedie_dist", correct_stats, np_stats) |
| 166 | + if not passed_correct: |
| 167 | + all_passed = False |
| 168 | + if not match_correct: |
| 169 | + all_match_numpy = False |
| 170 | + |
| 171 | + # ---- Final verdict ---- |
| 172 | + print(f"\n{'='*75}") |
| 173 | + print(f" SUMMARY") |
| 174 | + print(f"{'='*75}") |
| 175 | + if all_passed: |
| 176 | + print(f" Buggy version FAILED theory check — CONFIRMED BUG") |
| 177 | + else: |
| 178 | + print(f" Buggy version MAY have passed — UNEXPECTED, investigate") |
| 179 | + |
| 180 | + print(f" Corrected version matches numpy reference: {'YES' if all_match_numpy else 'NO'}") |
| 181 | + |
| 182 | + if all_passed: |
| 183 | + print(f"\n >>> Bug confirmed. Proceed with fix in 3 locations. <<<") |
| 184 | + else: |
| 185 | + print(f"\n >>> Unexpected results. Investigate before proceeding. <<<") |
| 186 | + |
| 187 | + |
| 188 | +if __name__ == "__main__": |
| 189 | + main() |
0 commit comments