Skip to content

Commit dd19b7e

Browse files
Merge pull request #1 from SamuelSchlesinger/pmf-utilities
feat(Probability/PMF): joint and posterior distributions
2 parents 3ff9e18 + 480ff50 commit dd19b7e

2 files changed

Lines changed: 130 additions & 0 deletions

File tree

Mathlib.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6131,6 +6131,7 @@ public import Mathlib.Probability.ProbabilityMassFunction.Binomial
61316131
public import Mathlib.Probability.ProbabilityMassFunction.Constructions
61326132
public import Mathlib.Probability.ProbabilityMassFunction.Integrals
61336133
public import Mathlib.Probability.ProbabilityMassFunction.Monad
6134+
public import Mathlib.Probability.ProbabilityMassFunction.Posterior
61346135
public import Mathlib.Probability.Process.Adapted
61356136
public import Mathlib.Probability.Process.Filtration
61366137
public import Mathlib.Probability.Process.FiniteDimensionalLaws
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/-
2+
Copyright (c) 2026 Samuel Schlesinger. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Samuel Schlesinger
5+
-/
6+
module
7+
8+
public import Mathlib.Probability.ProbabilityMassFunction.Constructions
9+
10+
/-!
11+
# Joint and Posterior Distributions for Probability Mass Functions
12+
13+
Given a prior `p : PMF α` and a family of distributions `f : α → PMF β`,
14+
this file defines:
15+
16+
* The **joint** distribution on `α × β`, where `a` is sampled from `p` and then
17+
`b` is sampled from `f a`.
18+
* The **posterior** distribution on `α` conditioned on an observed value `b : β`.
19+
20+
This is an elementary, combinatorial treatment of discrete Bayesian inference.
21+
For the general measure-theoretic posterior defined via disintegration, see
22+
`ProbabilityTheory.posterior` in `Mathlib.Probability.Kernel.Posterior`.
23+
The two constructions compute the same thing at different levels of generality:
24+
`PMF.posterior` gives the explicit formula `P(A=a | B=b) = P(A=a) · P(B=b|A=a) / P(B=b)`,
25+
while `ProbabilityTheory.posterior` is defined as a conditional kernel and requires
26+
standard Borel spaces and the disintegration theorem.
27+
28+
## Main definitions
29+
30+
* `PMF.joint`: The joint distribution on `α × β` induced by a prior and a family of
31+
distributions.
32+
* `PMF.posterior`: The posterior distribution `Pr[A = a | B = b]`.
33+
34+
## Main results
35+
36+
* `PMF.joint_apply`: `(p.joint f) (a, b) = p a * f a b`.
37+
* `PMF.tsum_joint_fst`: Marginalizing over the first component gives `(p.bind f) b`.
38+
* `PMF.tsum_joint_snd`: Marginalizing over the second component gives `p a`.
39+
* `PMF.posterior_hasSum`: Posterior probabilities sum to 1 when `b` is in the support of `p.bind f`.
40+
-/
41+
42+
@[expose] public section
43+
44+
noncomputable section
45+
46+
variable {α β : Type*}
47+
48+
open ENNReal
49+
50+
namespace PMF
51+
52+
section Joint
53+
54+
/-- Joint law on `α × β` built from a prior `p : PMF α` and a family `f : α → PMF β`:
55+
bind `p` with the map that pushes each `f a` forward along `Prod.mk a`, i.e. draw
56+
`a ← p`, then `b ← f a`, and return the pair `(a, b)`. -/
57+
def joint (p : PMF α) (f : α → PMF β) : PMF (α × β) :=
58+
  PMF.bind p (fun a => PMF.map (Prod.mk a) (f a))
59+
60+
variable (p : PMF α) (f : α → PMF β)
61+
62+
/-- The joint mass at `(a, b)` is the prior mass of `a` times the mass that `f a`
assigns to `b`. -/
@[simp]
63+
theorem joint_apply (a : α) (b : β) :
64+
(p.joint f) (a, b) = p a * f a b := by
65+
-- Unfold `joint` through `bind` and `map`, exposing the sums over `α` and `β`.
simp only [joint, bind_apply, map_apply, Prod.mk.injEq]
66+
-- Collapse each sum to its single term at `a` and at `b`; the side goals show that
-- every other index contributes nothing.
rw [tsum_eq_single a (fun a' ha' => by simp [ha'.symm]),
67+
tsum_eq_single b (fun b' hb' => by simp [hb'.symm])]
68+
simp
69+
70+
/-- The support of `p.joint f` is the set of pairs whose first component lies in the
support of `p` and whose second component lies in the support of `f` applied to that
first component. -/
@[simp]
71+
theorem support_joint :
72+
(p.joint f).support = {ab | ab.1 ∈ p.support ∧ ab.2 ∈ (f ab.1).support} := by
73+
-- Check membership pointwise on a destructured pair `(a, b)`.
ext ⟨a, b⟩
74+
simp [mem_support_iff, mul_eq_zero, not_or]
75+
76+
/-- Membership form of `support_joint`: a pair `ab` is in the support of `p.joint f`
iff `ab.1` is in the support of `p` and `ab.2` is in the support of `f ab.1`. -/
theorem mem_support_joint_iff (ab : α × β) :
77+
ab ∈ (p.joint f).support ↔ ab.1 ∈ p.support ∧ ab.2 ∈ (f ab.1).support := by
78+
simp
79+
80+
/-- Summing the joint mass over the first component `a`, with `b` held fixed, gives the
mass of `b` under `p.bind f`. -/
theorem tsum_joint_fst (b : β) :
81+
∑' a, (p.joint f) (a, b) = (p.bind f) b := by
82+
simp [bind_apply]
83+
84+
/-- Summing the joint mass over the second component `b`, with `a` held fixed, recovers
the prior mass `p a`. -/
theorem tsum_joint_snd (a : α) :
85+
∑' b, (p.joint f) (a, b) = p a := by
86+
simp [ENNReal.tsum_mul_left]
87+
88+
end Joint
89+
90+
section Posterior
91+
92+
/-- Posterior probabilities `joint(a, b) / marginal(b)` sum to 1
93+
when `b` is in the support of the marginal `p.bind f`. This is the proof component
used by `PMF.posterior`. -/
94+
theorem posterior_hasSum (p : PMF α) (f : α → PMF β) (b : β)
95+
(hb : b ∈ (p.bind f).support) :
96+
HasSum (fun a => (p.joint f) (a, b) / (p.bind f) b) 1 :=
97+
-- Via `ENNReal.summable`, the `HasSum` goal is reduced to an equation about the `tsum`.
ENNReal.summable.hasSum_iff.2 <| by
98+
-- Write the division as multiplication by the inverse, pull that factor out of the
-- sum, and rewrite the remaining sum using `tsum_joint_fst`.
simp_rw [div_eq_mul_inv, ENNReal.tsum_mul_right, tsum_joint_fst]
99+
-- Cancel using two facts: the marginal at `b` is nonzero (from `hb`) and is not `⊤`.
exact ENNReal.mul_inv_cancel ((mem_support_iff _ _).mp hb) ((p.bind f).apply_ne_top b)
100+
101+
/-- The posterior distribution `Pr[A = a | B = b]` as a `PMF`,
102+
given a prior `p`, a family of distributions `f`, and that `b` has positive marginal
103+
probability under `p.bind f`. The mass assigned to `a` is
`(p.joint f) (a, b) / (p.bind f) b`, and `posterior_hasSum` supplies the proof that
these masses sum to `1`. -/
104+
def posterior (p : PMF α) (f : α → PMF β) (b : β)
105+
    (hb : b ∈ (p.bind f).support) : PMF α :=
106+
  -- Anonymous constructor: the mass function paired with its normalization proof.
  ⟨fun a => (p.joint f) (a, b) / (p.bind f) b, posterior_hasSum p f b hb⟩
107+
108+
variable (p : PMF α) (f : α → PMF β)
109+
110+
/-- Explicit formula for the posterior: the mass of `a` given the observation `b` is
`p a * f a b` divided by the marginal mass `(p.bind f) b`. -/
@[simp]
111+
theorem posterior_apply (b : β) (hb : b ∈ (p.bind f).support) (a : α) :
112+
(p.posterior f b hb) a = p a * f a b / (p.bind f) b := by
113+
-- `change` restates the left-hand side as the quotient that defines `posterior`;
-- `simp` then closes the goal.
change (p.joint f) (a, b) / (p.bind f) b = _; simp
114+
115+
/-- The support of the posterior given `b` is the set of `a` that lie in the support of
the prior `p` and for which `b` lies in the support of `f a`. -/
@[simp]
116+
theorem support_posterior (b : β) (hb : b ∈ (p.bind f).support) :
117+
(p.posterior f b hb).support = {a | a ∈ p.support ∧ b ∈ (f a).support} := by
118+
-- Check membership pointwise.
ext a
119+
simp only [mem_support_iff, posterior_apply, Set.mem_setOf_eq, ne_eq,
120+
ENNReal.div_ne_zero, mul_eq_zero, not_or]
121+
-- Forward: keep the first component and drop the second. Backward: supply
-- `(p.bind f).apply_ne_top b` as the extra component.
exact ⟨fun ⟨h, _⟩ => h, fun h => ⟨h, (p.bind f).apply_ne_top b⟩⟩
122+
123+
/-- Membership form of `support_posterior`: `a` is in the support of the posterior given
`b` iff `a` is in the support of `p` and `b` is in the support of `f a`. -/
theorem mem_support_posterior_iff (b : β) (hb : b ∈ (p.bind f).support) (a : α) :
124+
a ∈ (p.posterior f b hb).support ↔ a ∈ p.support ∧ b ∈ (f a).support := by
125+
simp
126+
127+
end Posterior
128+
129+
end PMF

0 commit comments

Comments
 (0)