1 | | -/// Kullback-Leibler divergence and related measures |
2 | | -/// Core information-theoretic distance metrics |
| 1 | +/// Kullback-Leibler divergence and related measures. |
| 2 | +/// Core information-theoretic distance metrics for comparing probability distributions. |
| 3 | +/// |
| 4 | +/// LP usage: measure evolutionary distance between species, compare creature behaviour |
| 5 | +/// distributions, quantify trait drift between populations. In multiplayer, these |
| 6 | +/// metrics let the simulation decide when two players' lineages count as separate species. |
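| | +///
| | +/// A minimal sketch of the species-split check described above. The trait histograms
| | +/// and 0.5-bit threshold are illustrative assumptions, not sim constants:
| | +///
| | +/// ```ignore
| | +/// // Assumes `KLDivergence` is in scope; the real path depends on the crate layout.
| | +/// let lineage_a = [0.70, 0.20, 0.10]; // hypothetical normalised trait histograms
| | +/// let lineage_b = [0.15, 0.45, 0.40];
| | +/// let js_bits = KLDivergence::jensen_shannon(&lineage_a, &lineage_b);
| | +/// let separate_species = js_bits > 0.5; // illustrative threshold in bits
| | +/// ```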
3 | 7 | pub struct KLDivergence; |
4 | 8 |
5 | 9 | impl KLDivergence { |
6 | | - /// Calculate KL divergence D(P||Q) = Σ P(i) log₂(P(i)/Q(i)) |
7 | | - /// Measures how distribution P differs from reference Q |
8 | | - /// Returns divergence in bits - NOT symmetric |
| 10 | + /// D(P||Q) = Σ P(i) log₂(P(i)/Q(i)) |
| 11 | + /// |
| 12 | + /// How many extra bits are needed to encode P-distributed events using a code |
| 13 | + /// optimised for Q. NOT symmetric. Returns +∞ when P has mass where Q has none. |
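| | +    ///
| | +    /// e.g. D([0.5, 0.5] || [0.25, 0.75]) = 0.5·log₂(2) + 0.5·log₂(2/3) ≈ 0.208 bits
| | +    /// (worked through in `kl_known_value` in the tests below).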
9 | 14 | pub fn divergence(p_probs: &[f64], q_probs: &[f64]) -> f64 { |
10 | 15 | assert_eq!( |
11 | 16 | p_probs.len(), |
12 | 17 | q_probs.len(), |
13 | 18 | "P and Q must have same length" |
14 | 19 | ); |
15 | 20 |
16 | | - let mut kl_div = 0.0; |
| 21 | + let mut kl = 0.0; |
17 | 22 | for (&p, &q) in p_probs.iter().zip(q_probs) { |
18 | 23 | if p > 0.0 && q > 0.0 { |
19 | | - kl_div += p * (p / q).log2(); |
20 | | - } else if p > 0.0 && q == 0.0 { |
21 | | - // P has probability where Q doesn't - infinite divergence |
| 24 | + kl += p * (p / q).log2(); |
| 25 | + } else if p > 0.0 { |
| 26 | + // P has support where Q does not — infinite divergence by definition |
22 | 27 | return f64::INFINITY; |
23 | 28 | } |
24 | | - // p == 0.0 contributes nothing to KL divergence |
| 29 | + // p == 0 contributes 0 regardless of q (0 · log(0/q) := 0) |
25 | 30 | } |
26 | | - |
27 | | - kl_div |
| 31 | + kl |
28 | 32 | } |
29 | 33 |
30 | | - /// Calculate Jensen-Shannon divergence - symmetric version of KL |
31 | | - /// JS(P,Q) = 0.5 * [D(P||M) + D(Q||M)] where M = 0.5*(P+Q) |
32 | | - /// Always finite and bounded [0, 1] bits |
| 34 | + /// JS(P,Q) = 0.5 · [D(P||M) + D(Q||M)] where M = 0.5·(P+Q) |
| 35 | + /// |
| 36 | + /// Symmetric, always finite, bounded [0, 1] bits (log₂ base). |
| 37 | + /// Reaches 1 bit only when P and Q have fully disjoint support. |
| 38 | +    /// Preferred over raw KL for comparing creature trait distributions:
| 39 | +    /// it never blows up, and its square root is a proper metric (a true distance).
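| | +    ///
| | +    /// e.g. JS([1, 0], [0, 1]) = 1 bit and JS(P, P) = 0 (see the tests below).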
33 | 40 | pub fn jensen_shannon(p_probs: &[f64], q_probs: &[f64]) -> f64 { |
34 | 41 | assert_eq!( |
35 | 42 | p_probs.len(), |
36 | 43 | q_probs.len(), |
37 | 44 | "P and Q must have same length" |
38 | 45 | ); |
39 | 46 |
40 | | - // Calculate mixture distribution M = 0.5*(P+Q) |
41 | | - let m_probs: Vec<f64> = p_probs |
| 47 | + let m: Vec<f64> = p_probs |
42 | 48 | .iter() |
43 | 49 | .zip(q_probs) |
44 | 50 | .map(|(&p, &q)| 0.5 * (p + q)) |
45 | 51 | .collect(); |
46 | 52 |
47 | | - let kl_pm = Self::divergence(p_probs, &m_probs); |
48 | | - let kl_qm = Self::divergence(q_probs, &m_probs); |
49 | | - |
50 | | - 0.5 * (kl_pm + kl_qm) |
| 53 | + // M[i] >= 0.5 · max(P[i], Q[i]), so KL(P||M) and KL(Q||M) are always finite. |
| 54 | + 0.5 * (Self::divergence(p_probs, &m) + Self::divergence(q_probs, &m)) |
51 | 55 | } |
52 | 56 |
53 | | - /// Calculate cross-entropy H(P,Q) = -Σ P(i) log₂(Q(i)) |
54 | | - /// Useful for ML loss functions and distribution comparison |
| 57 | + /// H(P,Q) = -Σ P(i) log₂(Q(i)) |
| 58 | + /// |
| 59 | + /// Cross-entropy: bits to encode P-distributed events with a Q-optimal code. |
| 60 | + /// H(P,Q) = H(P) + D(P||Q). Returns +∞ when Q has no mass where P has mass. |
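| | +    ///
| | +    /// e.g. H([0.5, 0.5], [0.25, 0.75]) = 1 + D(P||Q) ≈ 1.208 bits, versus H(P) = 1 bit.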
55 | 61 | pub fn cross_entropy(p_probs: &[f64], q_probs: &[f64]) -> f64 { |
56 | 62 | assert_eq!( |
57 | 63 | p_probs.len(), |
58 | 64 | q_probs.len(), |
59 | 65 | "P and Q must have same length" |
60 | 66 | ); |
61 | 67 |
62 | | - let mut cross_ent = 0.0; |
| 68 | + let mut ce = 0.0; |
63 | 69 | for (&p, &q) in p_probs.iter().zip(q_probs) { |
64 | 70 | if p > 0.0 && q > 0.0 { |
65 | | - cross_ent -= p * q.log2(); |
66 | | - } else if p > 0.0 && q == 0.0 { |
| 71 | + ce -= p * q.log2(); |
| 72 | + } else if p > 0.0 { |
67 | 73 | return f64::INFINITY; |
68 | 74 | } |
69 | 75 | } |
| 76 | + ce |
| 77 | + } |
| 78 | + |
| 79 | + /// TV(P,Q) = 0.5 · Σ|P(i) - Q(i)| |
| 80 | + /// |
| 81 | +    /// Total variation distance: the largest gap in probability that P and Q
| 82 | +    /// assign to any event (set of outcomes). Symmetric, bounded [0, 1], no log required.
| 83 | + /// Fastest way to ask "how different are these two distributions?" when you |
| 84 | + /// do not need the information-theoretic interpretation of KL/JS. |
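| | +    ///
| | +    /// e.g. TV([0.6, 0.4], [0.2, 0.8]) = 0.5·(0.4 + 0.4) = 0.4.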
| 85 | + pub fn total_variation(p_probs: &[f64], q_probs: &[f64]) -> f64 { |
| 86 | + assert_eq!( |
| 87 | + p_probs.len(), |
| 88 | + q_probs.len(), |
| 89 | + "P and Q must have same length" |
| 90 | + ); |
| 91 | + 0.5 * p_probs |
| 92 | + .iter() |
| 93 | + .zip(q_probs) |
| 94 | + .map(|(&p, &q)| (p - q).abs()) |
| 95 | + .sum::<f64>() |
| 96 | + } |
| 97 | +} |
| 98 | + |
| 99 | +#[cfg(test)] |
| 100 | +mod tests { |
| 101 | + use super::*; |
| 102 | + |
| 103 | + #[test] |
| 104 | + fn kl_identical_is_zero() { |
| 105 | + let p = [0.5, 0.5]; |
| 106 | + assert!(KLDivergence::divergence(&p, &p).abs() < 1e-10); |
| 107 | + } |
| 108 | + |
| 109 | + #[test] |
| 110 | + fn kl_disjoint_support_is_infinity() { |
| 111 | + let p = [1.0, 0.0]; |
| 112 | + let q = [0.0, 1.0]; |
| 113 | + assert_eq!(KLDivergence::divergence(&p, &q), f64::INFINITY); |
| 114 | + } |
| 115 | + |
| 116 | + #[test] |
| 117 | + fn kl_known_value() { |
| 118 | + // D([0.5, 0.5] || [0.25, 0.75]) = 0.5*log2(2) + 0.5*log2(2/3) ≈ 0.2075 |
| 119 | + let p = [0.5, 0.5]; |
| 120 | + let q = [0.25, 0.75]; |
| 121 | + let kl = KLDivergence::divergence(&p, &q); |
| 122 | + let expected = 0.5 * (0.5_f64 / 0.25).log2() + 0.5 * (0.5_f64 / 0.75).log2(); |
| 123 | + assert!((kl - expected).abs() < 1e-10, "got {}", kl); |
| 124 | + } |
| 125 | + |
| 126 | + #[test] |
| 127 | + fn js_identical_is_zero() { |
| 128 | + let p = [0.25, 0.25, 0.25, 0.25]; |
| 129 | + assert!(KLDivergence::jensen_shannon(&p, &p).abs() < 1e-10); |
| 130 | + } |
| 131 | + |
| 132 | + #[test] |
| 133 | + fn js_disjoint_support_is_one_bit() { |
| 134 | + let p = [1.0, 0.0]; |
| 135 | + let q = [0.0, 1.0]; |
| 136 | + let js = KLDivergence::jensen_shannon(&p, &q); |
| 137 | + assert!((js - 1.0).abs() < 1e-10, "expected 1 bit, got {}", js); |
| 138 | + } |
| 139 | + |
| 140 | + #[test] |
| 141 | + fn js_is_symmetric() { |
| 142 | + let p = [0.7, 0.2, 0.1]; |
| 143 | + let q = [0.1, 0.5, 0.4]; |
| 144 | + let diff = |
| 145 | + (KLDivergence::jensen_shannon(&p, &q) - KLDivergence::jensen_shannon(&q, &p)).abs(); |
| 146 | + assert!(diff < 1e-12, "JS must be symmetric, diff = {}", diff); |
| 147 | + } |
| 148 | + |
| 149 | + #[test] |
| 150 | + fn js_bounded_between_zero_and_one() { |
| 151 | + let p = [0.6, 0.3, 0.1]; |
| 152 | + let q = [0.1, 0.2, 0.7]; |
| 153 | + let js = KLDivergence::jensen_shannon(&p, &q); |
| 154 | + assert!(js >= 0.0 && js <= 1.0, "JS out of [0,1]: {}", js); |
| 155 | + } |
| 156 | + |
| 157 | + #[test] |
| 158 | + fn cross_entropy_of_self_equals_entropy() { |
| 159 | + // H(P, P) = H(P). For fair coin P=[0.5,0.5], H = 1 bit. |
| 160 | + let p = [0.5, 0.5]; |
| 161 | + let h = KLDivergence::cross_entropy(&p, &p); |
| 162 | + assert!((h - 1.0).abs() < 1e-10, "expected 1 bit, got {}", h); |
| 163 | + } |
| 164 | + |
| 165 | + #[test] |
| 166 | + fn total_variation_identical_is_zero() { |
| 167 | + let p = [0.3, 0.3, 0.4]; |
| 168 | + assert!(KLDivergence::total_variation(&p, &p).abs() < 1e-12); |
| 169 | + } |
| 170 | + |
| 171 | + #[test] |
| 172 | + fn total_variation_disjoint_is_one() { |
| 173 | + let p = [1.0, 0.0]; |
| 174 | + let q = [0.0, 1.0]; |
| 175 | + assert!((KLDivergence::total_variation(&p, &q) - 1.0).abs() < 1e-12); |
| 176 | + } |
70 | 177 |
71 | | - cross_ent |
| 178 | + #[test] |
| 179 | + fn total_variation_bounded() { |
| 180 | + let p = [0.6, 0.4]; |
| 181 | + let q = [0.2, 0.8]; |
| 182 | + let tv = KLDivergence::total_variation(&p, &q); |
| 183 | + assert!(tv >= 0.0 && tv <= 1.0, "TV must be in [0, 1], got {}", tv); |
72 | 184 | } |
73 | 185 | } |