|
| 1 | +/- |
| 2 | +Copyright (c) 2026 Samuel Schlesinger. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Samuel Schlesinger |
| 5 | +-/ |
| 6 | +module |
| 7 | + |
| 8 | +public import Mathlib.Probability.ProbabilityMassFunction.Constructions |
| 9 | + |
/-!
# Joint and Posterior Distributions for Probability Mass Functions

Given a prior `p : PMF α` and a family of distributions `f : α → PMF β`,
this file defines:

* The **joint** distribution on `α × β`, where `a` is sampled from `p` and then
  `b` is sampled from `f a`.
* The **posterior** distribution on `α` conditioned on an observed value `b : β`.

This is an elementary, combinatorial treatment of discrete Bayesian inference.
For the general measure-theoretic posterior defined via disintegration, see
`ProbabilityTheory.posterior` in `Mathlib.Probability.Kernel.Posterior`.
The two constructions compute the same thing at different levels of generality:
`PMF.posterior` gives the explicit formula `P(A=a | B=b) = P(A=a) · P(B=b|A=a) / P(B=b)`,
while `ProbabilityTheory.posterior` is defined as a conditional kernel and requires
standard Borel spaces and the disintegration theorem.

## Main definitions

* `PMF.joint`: The joint distribution on `α × β` induced by a prior and a family of
  distributions.
* `PMF.posterior`: The posterior distribution `Pr[A = a | B = b]`.

## Main results

* `PMF.joint_apply`: `(p.joint f) (a, b) = p a * f a b`.
* `PMF.tsum_joint_fst`: Marginalizing over the first component gives `(p.bind f) b`.
* `PMF.tsum_joint_snd`: Marginalizing over the second component gives `p a`.
* `PMF.posterior_hasSum`: Posterior probabilities sum to 1.
* `PMF.posterior_apply`: `(p.posterior f b hb) a = p a * f a b / (p.bind f) b`.
-/
| 41 | + |
| 42 | +@[expose] public section |
| 43 | + |
| 44 | +noncomputable section |
| 45 | + |
| 46 | +variable {α β : Type*} |
| 47 | + |
| 48 | +open ENNReal |
| 49 | + |
| 50 | +namespace PMF |
| 51 | + |
| 52 | +section Joint |
| 53 | + |
/-- The joint distribution on `α × β` induced by a prior `p : PMF α` and a family of
distributions `f : α → PMF β`.

It is built by binding `p` with, for each `a`, the image of `f a` under `Prod.mk a`
(that is, `b ↦ (a, b)`). Sampling from `p.joint f` is therefore equivalent to first sampling
`a ← p` and then `b ← f a`, returning `(a, b)`. -/
def joint (p : PMF α) (f : α → PMF β) : PMF (α × β) :=
  p.bind fun a => (f a).map (Prod.mk a)
| 59 | + |
| 60 | +variable (p : PMF α) (f : α → PMF β) |
| 61 | + |
/-- The mass that `p.joint f` assigns to the pair `(a, b)` is the prior mass `p a` times the
mass `f a b` that the distribution `f a` assigns to `b`. -/
@[simp]
theorem joint_apply (a : α) (b : β) :
    (p.joint f) (a, b) = p a * f a b := by
  -- Unfold `joint` and the pointwise formulas for `bind` and `map`; `Prod.mk.injEq` turns
  -- equalities between pairs into componentwise equalities.
  simp only [joint, bind_apply, map_apply, Prod.mk.injEq]
  -- Collapse the outer sum to its `a` term and the inner sum to its `b` term: every other
  -- term is shown to vanish by `simp` using the disequality `a' ≠ a` (resp. `b' ≠ b`).
  rw [tsum_eq_single a (fun a' ha' => by simp [ha'.symm]),
    tsum_eq_single b (fun b' hb' => by simp [hb'.symm])]
  simp
| 69 | + |
/-- The support of `p.joint f` is the set of pairs `ab` whose first component lies in the
support of `p` and whose second component lies in the support of `f ab.1`. -/
@[simp]
theorem support_joint :
    (p.joint f).support = {ab | ab.1 ∈ p.support ∧ ab.2 ∈ (f ab.1).support} := by
  -- Check membership pairwise, destructuring the point as `(a, b)`.
  ext ⟨a, b⟩
  -- Membership in a support is non-vanishing of the mass; a product is nonzero exactly when
  -- both factors are (`mul_eq_zero`, `not_or`).
  simp [mem_support_iff, mul_eq_zero, not_or]
| 75 | + |
/-- Membership form of `support_joint`: a pair `ab` is in the support of `p.joint f` iff
`ab.1` is in the support of `p` and `ab.2` is in the support of `f ab.1`. -/
theorem mem_support_joint_iff (ab : α × β) :
    ab ∈ (p.joint f).support ↔ ab.1 ∈ p.support ∧ ab.2 ∈ (f ab.1).support := by
  -- `support_joint` is tagged `@[simp]`, so a bare `simp` has it available.
  simp
| 79 | + |
/-- Summing the joint mass `(p.joint f) (a, b)` over the first component `a` yields the mass
of `b` under `p.bind f`. -/
theorem tsum_joint_fst (b : β) :
    ∑' a, (p.joint f) (a, b) = (p.bind f) b := by
  -- `joint_apply` is a simp lemma; `bind_apply` is supplied to unfold the right-hand side.
  simp [bind_apply]
| 83 | + |
/-- Summing the joint mass `(p.joint f) (a, b)` over the second component `b` yields the
prior mass `p a`. -/
theorem tsum_joint_snd (a : α) :
    ∑' b, (p.joint f) (a, b) = p a := by
  -- `ENNReal.tsum_mul_left` lets `simp` pull the constant factor out of the sum over `b`.
  simp [ENNReal.tsum_mul_left]
| 87 | + |
| 88 | +end Joint |
| 89 | + |
| 90 | +section Posterior |
| 91 | + |
/-- Posterior probabilities `joint(a, b) / marginal(b)` sum to 1
when `b` is in the support of the marginal.

Here the marginal is `p.bind f`, and the summed function is
`fun a => (p.joint f) (a, b) / (p.bind f) b`. -/
theorem posterior_hasSum (p : PMF α) (f : α → PMF β) (b : β)
    (hb : b ∈ (p.bind f).support) :
    HasSum (fun a => (p.joint f) (a, b) / (p.bind f) b) 1 :=
  -- Reduce the `HasSum` claim to an equation about the `tsum`, using `ENNReal.summable`.
  ENNReal.summable.hasSum_iff.2 <| by
    -- Write division as multiplication by the inverse, pull that constant out of the sum,
    -- and rewrite the remaining sum over `a` with `tsum_joint_fst`.
    simp_rw [div_eq_mul_inv, ENNReal.tsum_mul_right, tsum_joint_fst]
    -- Cancelling needs the marginal at `b` to be nonzero (from `hb`) and not `⊤`.
    exact ENNReal.mul_inv_cancel ((mem_support_iff _ _).mp hb) ((p.bind f).apply_ne_top b)
| 100 | + |
/-- The posterior distribution `Pr[A = a | B = b]` as a `PMF`,
given a prior `p`, a family of distributions `f`, and that `b` has positive marginal
probability under `p.bind f`.

Its mass function is `fun a => (p.joint f) (a, b) / (p.bind f) b`; the hypothesis `hb` is
passed to `posterior_hasSum` to prove that these masses sum to `1`. -/
def posterior (p : PMF α) (f : α → PMF β) (b : β)
    (hb : b ∈ (p.bind f).support) : PMF α :=
  ⟨fun a => (p.joint f) (a, b) / (p.bind f) b, posterior_hasSum p f b hb⟩
| 107 | + |
| 108 | +variable (p : PMF α) (f : α → PMF β) |
| 109 | + |
/-- Pointwise formula for the posterior: its mass at `a` is `p a * f a b` divided by the
mass `(p.bind f) b` of the observation `b`. -/
@[simp]
theorem posterior_apply (b : β) (hb : b ∈ (p.bind f).support) (a : α) :
    (p.posterior f b hb) a = p a * f a b / (p.bind f) b := by
  -- Restate the left-hand side as the quotient that defines `posterior`; `simp` then
  -- finishes (`joint_apply` is a simp lemma).
  change (p.joint f) (a, b) / (p.bind f) b = _; simp
| 114 | + |
/-- The support of the posterior given `b` is the set of `a` that lie in the support of `p`
and for which `b` lies in the support of `f a`. -/
@[simp]
theorem support_posterior (b : β) (hb : b ∈ (p.bind f).support) :
    (p.posterior f b hb).support = {a | a ∈ p.support ∧ b ∈ (f a).support} := by
  ext a
  -- Express both sides through non-vanishing of masses; `ENNReal.div_ne_zero` splits the
  -- quotient being nonzero into a condition on the numerator and one on the denominator.
  simp only [mem_support_iff, posterior_apply, Set.mem_setOf_eq, ne_eq,
    ENNReal.div_ne_zero, mul_eq_zero, not_or]
  -- The left side carries an extra conjunct that the marginal at `b` is not `⊤`; it is
  -- dropped in one direction and supplied by `apply_ne_top` in the other.
  exact ⟨fun ⟨h, _⟩ => h, fun h => ⟨h, (p.bind f).apply_ne_top b⟩⟩
| 122 | + |
/-- Membership form of `support_posterior`: `a` is in the support of the posterior given `b`
iff `a` is in the support of `p` and `b` is in the support of `f a`. -/
theorem mem_support_posterior_iff (b : β) (hb : b ∈ (p.bind f).support) (a : α) :
    a ∈ (p.posterior f b hb).support ↔ a ∈ p.support ∧ b ∈ (f a).support := by
  -- `support_posterior` is tagged `@[simp]`, so a bare `simp` has it available.
  simp
| 126 | + |
| 127 | +end Posterior |
| 128 | + |
| 129 | +end PMF |
0 commit comments