From eb11fd8f04b6bedf46244e5b52d16235608c168d Mon Sep 17 00:00:00 2001 From: Samuel Schlesinger Date: Tue, 7 Apr 2026 14:33:29 -0400 Subject: [PATCH 1/5] feat(Probability/PMF): define joint distribution on product type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Define `PMF.joint p f` as the joint distribution on `α × β` induced by a prior `p : PMF α` and a family of distributions `f : α → PMF β`. Prove `PMF.joint_apply`: evaluating the joint at `(a, b)` gives `p a * f a b`. --- Mathlib.lean | 1 + .../ProbabilityMassFunction/Posterior.lean | 62 +++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 Mathlib/Probability/ProbabilityMassFunction/Posterior.lean diff --git a/Mathlib.lean b/Mathlib.lean index 1374639b55c706..8ec65c698e8378 100644 --- a/Mathlib.lean +++ b/Mathlib.lean @@ -6146,6 +6146,7 @@ public import Mathlib.Probability.ProbabilityMassFunction.Binomial public import Mathlib.Probability.ProbabilityMassFunction.Constructions public import Mathlib.Probability.ProbabilityMassFunction.Integrals public import Mathlib.Probability.ProbabilityMassFunction.Monad +public import Mathlib.Probability.ProbabilityMassFunction.Posterior public import Mathlib.Probability.Process.Adapted public import Mathlib.Probability.Process.Filtration public import Mathlib.Probability.Process.FiniteDimensionalLaws diff --git a/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean b/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean new file mode 100644 index 00000000000000..611f479d793ea4 --- /dev/null +++ b/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean @@ -0,0 +1,62 @@ +/- +Copyright (c) 2026 Samuel Schlesinger. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Samuel Schlesinger +-/ +module + +public import Mathlib.Probability.ProbabilityMassFunction.Monad + +/-! 
+# Joint and Posterior Distributions for Probability Mass Functions + +Given a prior `p : PMF α` and a family of distributions `f : α → PMF β`, +this file defines: + +* The **joint** distribution on `α × β`, where `a` is sampled from `p` and then + `b` is sampled from `f a`. +* The **posterior** distribution on `α` conditioned on an observed value `b : β`. + +## Main definitions + +* `PMF.joint`: The joint distribution on `α × β` induced by a prior and a Markov kernel. +* `PMF.posterior`: The posterior distribution `Pr[A = a | B = b]`. + +## Main results + +* `PMF.joint_apply`: `(p.joint f) (a, b) = p a * f a b`. +* `PMF.tsum_joint_fst`: Marginalizing over the first component gives `(p.bind f) b`. +* `PMF.posterior_hasSum`: Posterior probabilities sum to 1. +-/ + +@[expose] public section + +noncomputable section + +variable {α β : Type*} + +open ENNReal + +namespace PMF + +section Joint + +/-- The joint distribution on `α × β` induced by a prior `p : PMF α` and a family of +distributions `f : α → PMF β`. Sampling from `p.joint f` is equivalent to first sampling +`a ← p` and then `b ← f a`, returning `(a, b)`. -/ +def joint (p : PMF α) (f : α → PMF β) : PMF (α × β) := + p.bind fun a => (f a).bind fun b => pure (a, b) + +@[simp] +theorem joint_apply (p : PMF α) (f : α → PMF β) (a : α) (b : β) : + (p.joint f) (a, b) = p a * f a b := by + simp only [joint, bind_apply, pure_apply, Prod.mk.injEq] + rw [tsum_eq_single a] + · congr 1; rw [tsum_eq_single b] + · simp + · intro b' hb'; simp [hb'.symm] + · intro a' ha'; simp [ha'.symm] + +end Joint + +end PMF From adc11f3f571b7d6ff97e5c9e7d20af3723c09035 Mon Sep 17 00:00:00 2001 From: Samuel Schlesinger Date: Tue, 7 Apr 2026 14:33:38 -0400 Subject: [PATCH 2/5] feat(Probability/PMF): marginalization of joint distribution Prove `PMF.tsum_joint_fst`: summing the joint distribution over the first component recovers the marginal `(p.bind f) b`. 
--- Mathlib/Probability/ProbabilityMassFunction/Posterior.lean | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean b/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean index 611f479d793ea4..79759a91879530 100644 --- a/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean +++ b/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean @@ -57,6 +57,10 @@ theorem joint_apply (p : PMF α) (f : α → PMF β) (a : α) (b : β) : · intro b' hb'; simp [hb'.symm] · intro a' ha'; simp [ha'.symm] +theorem tsum_joint_fst (p : PMF α) (f : α → PMF β) (b : β) : + ∑' a, (p.joint f) (a, b) = (p.bind f) b := by + simp [bind_apply] + end Joint end PMF From 1527dbc4b3b6a4dfe4c4d8f5ba6d15377257d574 Mon Sep 17 00:00:00 2001 From: Samuel Schlesinger Date: Tue, 7 Apr 2026 14:33:57 -0400 Subject: [PATCH 3/5] feat(Probability/PMF): define posterior distribution Define `PMF.posterior p f b hb` as the posterior distribution `Pr[A = a | B = b]`, given a prior `p`, a family of distributions `f`, an observed value `b`, and a proof `hb` that `b` is in the support of the marginal `p.bind f`. Prove `PMF.posterior_hasSum` (posterior probabilities sum to 1) and `PMF.posterior_apply` (simp lemma for evaluation). --- .../ProbabilityMassFunction/Posterior.lean | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean b/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean index 79759a91879530..b907f45253cead 100644 --- a/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean +++ b/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean @@ -63,4 +63,34 @@ theorem tsum_joint_fst (p : PMF α) (f : α → PMF β) (b : β) : end Joint +section Posterior + +/-- Posterior probabilities `joint(a, b) / marginal(b)` sum to 1 +when `b` is in the support of the marginal. 
-/ +theorem posterior_hasSum (p : PMF α) (f : α → PMF β) (b : β) + (hb : b ∈ (p.bind f).support) : + HasSum (fun a => (p.joint f) (a, b) / (p.bind f) b) 1 := by + have hne := (mem_support_iff _ _).mp hb + have hne_top := (p.bind f).apply_ne_top b + have h : ∑' a, (p.joint f) (a, b) / (p.bind f) b = 1 := by + simp only [div_eq_mul_inv] + rw [ENNReal.tsum_mul_right, tsum_joint_fst] + exact ENNReal.mul_inv_cancel hne hne_top + exact h ▸ ENNReal.summable.hasSum + +/-- The posterior distribution `Pr[A = a | B = b]` as a `PMF`, +given a prior `p`, a family of distributions `f`, and that `b` has positive marginal +probability under `p.bind f`. -/ +def posterior (p : PMF α) (f : α → PMF β) (b : β) + (hb : b ∈ (p.bind f).support) : PMF α := + ⟨fun a => (p.joint f) (a, b) / (p.bind f) b, posterior_hasSum p f b hb⟩ + +@[simp] +theorem posterior_apply (p : PMF α) (f : α → PMF β) (b : β) + (hb : b ∈ (p.bind f).support) (a : α) : + (p.posterior f b hb) a = (p.joint f) (a, b) / (p.bind f) b := + rfl + +end Posterior + end PMF From 3cd3b4f7943add661eed8523dad71c35c6ed6352 Mon Sep 17 00:00:00 2001 From: Samuel Schlesinger Date: Tue, 7 Apr 2026 14:40:05 -0400 Subject: [PATCH 4/5] refactor(Probability/PMF): golf joint_apply and posterior_hasSum proofs Shorten `joint_apply` by inlining `tsum_eq_single` side conditions. Shorten `posterior_hasSum` by using `ENNReal.summable.hasSum_iff.2` instead of intermediate `have` bindings. 
--- .../ProbabilityMassFunction/Posterior.lean | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean b/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean index b907f45253cead..6225d1a06d1bad 100644 --- a/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean +++ b/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean @@ -51,11 +51,9 @@ def joint (p : PMF α) (f : α → PMF β) : PMF (α × β) := theorem joint_apply (p : PMF α) (f : α → PMF β) (a : α) (b : β) : (p.joint f) (a, b) = p a * f a b := by simp only [joint, bind_apply, pure_apply, Prod.mk.injEq] - rw [tsum_eq_single a] - · congr 1; rw [tsum_eq_single b] - · simp - · intro b' hb'; simp [hb'.symm] - · intro a' ha'; simp [ha'.symm] + rw [tsum_eq_single a (fun a' ha' => by simp [ha'.symm]), + tsum_eq_single b (fun b' hb' => by simp [hb'.symm])] + simp theorem tsum_joint_fst (p : PMF α) (f : α → PMF β) (b : β) : ∑' a, (p.joint f) (a, b) = (p.bind f) b := by @@ -69,14 +67,11 @@ section Posterior when `b` is in the support of the marginal. 
-/ theorem posterior_hasSum (p : PMF α) (f : α → PMF β) (b : β) (hb : b ∈ (p.bind f).support) : - HasSum (fun a => (p.joint f) (a, b) / (p.bind f) b) 1 := by - have hne := (mem_support_iff _ _).mp hb - have hne_top := (p.bind f).apply_ne_top b - have h : ∑' a, (p.joint f) (a, b) / (p.bind f) b = 1 := by + HasSum (fun a => (p.joint f) (a, b) / (p.bind f) b) 1 := + ENNReal.summable.hasSum_iff.2 <| by simp only [div_eq_mul_inv] rw [ENNReal.tsum_mul_right, tsum_joint_fst] - exact ENNReal.mul_inv_cancel hne hne_top - exact h ▸ ENNReal.summable.hasSum + exact ENNReal.mul_inv_cancel ((mem_support_iff _ _).mp hb) ((p.bind f).apply_ne_top b) /-- The posterior distribution `Pr[A = a | B = b]` as a `PMF`, given a prior `p`, a family of distributions `f`, and that `b` has positive marginal From 3e19ca2eae0a3db21f8d23ed64da9c22e04b35ab Mon Sep 17 00:00:00 2001 From: Samuel Schlesinger Date: Tue, 7 Apr 2026 15:08:12 -0400 Subject: [PATCH 5/5] feat(Probability/PMF): expand joint/posterior API and improve docs Add support lemmas and the second marginalization `tsum_joint_snd`, define `joint` via `map`, expand `posterior_apply` to state the Bayes formula directly, and document the relationship to `ProbabilityTheory.posterior` in `Mathlib.Probability.Kernel.Posterior`. --- .../ProbabilityMassFunction/Posterior.lean | 62 +++++++++++++++---- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean b/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean index 6225d1a06d1bad..a16be0064a6f2d 100644 --- a/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean +++ b/Mathlib/Probability/ProbabilityMassFunction/Posterior.lean @@ -5,7 +5,7 @@ Authors: Samuel Schlesinger -/ module -public import Mathlib.Probability.ProbabilityMassFunction.Monad +public import Mathlib.Probability.ProbabilityMassFunction.Constructions /-! # Joint and Posterior Distributions for Probability Mass Functions @@ -17,15 +17,25 @@ this file defines: `b` is sampled from `f a`. 
* The **posterior** distribution on `α` conditioned on an observed value `b : β`. +This is an elementary, combinatorial treatment of discrete Bayesian inference. +For the general measure-theoretic posterior defined via disintegration, see +`ProbabilityTheory.posterior` in `Mathlib.Probability.Kernel.Posterior`. +The two constructions compute the same thing at different levels of generality: +`PMF.posterior` gives the explicit formula `P(A=a | B=b) = P(A=a) · P(B=b|A=a) / P(B=b)`, +while `ProbabilityTheory.posterior` is defined as a conditional kernel and requires +standard Borel spaces and the disintegration theorem. + ## Main definitions -* `PMF.joint`: The joint distribution on `α × β` induced by a prior and a Markov kernel. +* `PMF.joint`: The joint distribution on `α × β` induced by a prior and a family of + distributions. * `PMF.posterior`: The posterior distribution `Pr[A = a | B = b]`. ## Main results * `PMF.joint_apply`: `(p.joint f) (a, b) = p a * f a b`. * `PMF.tsum_joint_fst`: Marginalizing over the first component gives `(p.bind f) b`. +* `PMF.tsum_joint_snd`: Marginalizing over the second component gives `p a`. * `PMF.posterior_hasSum`: Posterior probabilities sum to 1. -/ @@ -45,20 +55,36 @@ section Joint distributions `f : α → PMF β`. Sampling from `p.joint f` is equivalent to first sampling `a ← p` and then `b ← f a`, returning `(a, b)`. 
-/ def joint (p : PMF α) (f : α → PMF β) : PMF (α × β) := - p.bind fun a => (f a).bind fun b => pure (a, b) + p.bind fun a => (f a).map (Prod.mk a) + +variable (p : PMF α) (f : α → PMF β) @[simp] -theorem joint_apply (p : PMF α) (f : α → PMF β) (a : α) (b : β) : +theorem joint_apply (a : α) (b : β) : (p.joint f) (a, b) = p a * f a b := by - simp only [joint, bind_apply, pure_apply, Prod.mk.injEq] + simp only [joint, bind_apply, map_apply, Prod.mk.injEq] rw [tsum_eq_single a (fun a' ha' => by simp [ha'.symm]), tsum_eq_single b (fun b' hb' => by simp [hb'.symm])] simp -theorem tsum_joint_fst (p : PMF α) (f : α → PMF β) (b : β) : +@[simp] +theorem support_joint : + (p.joint f).support = {ab | ab.1 ∈ p.support ∧ ab.2 ∈ (f ab.1).support} := by + ext ⟨a, b⟩ + simp [mem_support_iff, mul_eq_zero, not_or] + +theorem mem_support_joint_iff (ab : α × β) : + ab ∈ (p.joint f).support ↔ ab.1 ∈ p.support ∧ ab.2 ∈ (f ab.1).support := by + simp + +theorem tsum_joint_fst (b : β) : ∑' a, (p.joint f) (a, b) = (p.bind f) b := by simp [bind_apply] +theorem tsum_joint_snd (a : α) : + ∑' b, (p.joint f) (a, b) = p a := by + simp [ENNReal.tsum_mul_left] + end Joint section Posterior @@ -69,8 +95,7 @@ theorem posterior_hasSum (p : PMF α) (f : α → PMF β) (b : β) (hb : b ∈ (p.bind f).support) : HasSum (fun a => (p.joint f) (a, b) / (p.bind f) b) 1 := ENNReal.summable.hasSum_iff.2 <| by - simp only [div_eq_mul_inv] - rw [ENNReal.tsum_mul_right, tsum_joint_fst] + simp_rw [div_eq_mul_inv, ENNReal.tsum_mul_right, tsum_joint_fst] exact ENNReal.mul_inv_cancel ((mem_support_iff _ _).mp hb) ((p.bind f).apply_ne_top b) /-- The posterior distribution `Pr[A = a | B = b]` as a `PMF`, @@ -80,11 +105,24 @@ def posterior (p : PMF α) (f : α → PMF β) (b : β) (hb : b ∈ (p.bind f).support) : PMF α := ⟨fun a => (p.joint f) (a, b) / (p.bind f) b, posterior_hasSum p f b hb⟩ +variable (p : PMF α) (f : α → PMF β) + @[simp] -theorem posterior_apply (p : PMF α) (f : α → PMF β) (b : β) - (hb : b ∈ (p.bind f).support) (a 
: α) : - (p.posterior f b hb) a = (p.joint f) (a, b) / (p.bind f) b := - rfl +theorem posterior_apply (b : β) (hb : b ∈ (p.bind f).support) (a : α) : + (p.posterior f b hb) a = p a * f a b / (p.bind f) b := by + change (p.joint f) (a, b) / (p.bind f) b = _; simp + +@[simp] +theorem support_posterior (b : β) (hb : b ∈ (p.bind f).support) : + (p.posterior f b hb).support = {a | a ∈ p.support ∧ b ∈ (f a).support} := by + ext a + simp only [mem_support_iff, posterior_apply, Set.mem_setOf_eq, ne_eq, + ENNReal.div_ne_zero, mul_eq_zero, not_or] + exact ⟨fun ⟨h, _⟩ => h, fun h => ⟨h, (p.bind f).apply_ne_top b⟩⟩ + +theorem mem_support_posterior_iff (b : β) (hb : b ∈ (p.bind f).support) (a : α) : + a ∈ (p.posterior f b hb).support ↔ a ∈ p.support ∧ b ∈ (f a).support := by + simp end Posterior