|
1 | 1 | """ |
2 | | -evaluation/stats.py — Statistical helpers for multi-seed aggregation. |
| 2 | +evaluation/stats.py — Statistical helpers for multi-seed aggregation and |
| 3 | +forgetting-curve analysis. |
3 | 4 |
|
4 | | -Provides mean ± std + 95% confidence intervals over multiple benchmark runs. |
| 5 | +Provides: |
| 6 | + - mean ± std + 95% confidence intervals over multiple benchmark runs |
| 7 | + - fit_forgetting_curve() — fits Ebbinghaus / exponential decay to recall@T |
| 8 | + data and returns half-life, stability, and R² goodness-of-fit |
| 9 | +
|
| 10 | +Closes #6. |
5 | 11 | """ |
6 | 12 |
|
7 | 13 | import math |
8 | | -from typing import List, Optional, Tuple |
| 14 | +from typing import Dict, List, Optional, Tuple |
| 15 | + |
9 | 16 |
|
| 17 | +# ── Basic statistics ────────────────────────────────────────────────────────── |
10 | 18 |
|
11 | 19 | def _mean(values: List[float]) -> float: |
12 | 20 | return sum(values) / len(values) if values else 0.0 |
@@ -72,3 +80,172 @@ def aggregate_checkpoint_series( |
72 | 80 | aggregate_metric([run[i] for run in series]) |
73 | 81 | for i in range(n_checkpoints) |
74 | 82 | ] |
| 83 | + |
| 84 | + |
| 85 | +# ── Forgetting-curve fitting ────────────────────────────────────────────────── |
| 86 | + |
| 87 | +def _r_squared(observed: List[float], predicted: List[float]) -> float: |
| 88 | + """Coefficient of determination R² ∈ (-∞, 1]; 1 = perfect fit.""" |
| 89 | + if len(observed) < 2: |
| 90 | + return float("nan") |
| 91 | + mean_obs = _mean(observed) |
| 92 | + ss_tot = sum((y - mean_obs) ** 2 for y in observed) |
| 93 | + ss_res = sum((y - y_hat) ** 2 for y, y_hat in zip(observed, predicted)) |
| 94 | + if ss_tot == 0: |
| 95 | + return 1.0 if ss_res == 0 else float("-inf") |
| 96 | + return 1.0 - ss_res / ss_tot |
| 97 | + |
| 98 | + |
def _fit_exponential(
    turns: List[float],
    recalls: List[float],
) -> Tuple[float, float, float]:
    """
    Fit R(t) = a · exp(−k · t) via log-linear least squares.

    The regression is run on log(R) = log(a) − k·t using only the points whose
    recall is strictly positive (the logarithm is undefined otherwise).  R² is
    then evaluated in recall space against *all* supplied points, including
    any that were left out of the regression.

    Returns (a, k, r_squared).
        a — intercept (recall at t=0, ideally ≈ 1.0); +inf when exp(log a)
            is too large to represent as a float
        k — decay rate (higher = faster forgetting)
        r_squared — -inf when a prediction or its squared residual exceeds
            the float range

    All three values are NaN when fewer than two positive-recall points
    remain, or when those points all share the same turn (the slope is then
    undetermined).
    """
    # log(0) / log(negative) are undefined, so such points cannot take part
    # in the regression.
    valid = [(t, r) for t, r in zip(turns, recalls) if r > 0]
    if len(valid) < 2:
        return (float("nan"), float("nan"), float("nan"))

    xs = [t for t, _ in valid]
    ys = [math.log(r) for _, r in valid]

    # Ordinary least squares for y = intercept + slope·x, where slope = −k.
    n = len(xs)
    sx = sum(xs)
    sy = sum(ys)
    sxx = sum(x * x for x in xs)
    sxy = sum(x * y for x, y in zip(xs, ys))
    denom = n * sxx - sx * sx
    if denom == 0:
        # Every x is identical: the slope cannot be determined.
        return (float("nan"), float("nan"), float("nan"))

    slope = (n * sxy - sx * sy) / denom
    k = -slope
    log_a = (sy - slope * sx) / n

    # A steep decay measured far from t=0 extrapolates to a huge intercept;
    # math.exp raises OverflowError rather than returning inf in that case.
    try:
        a = math.exp(log_a)
    except OverflowError:
        a = float("inf")

    # Predict in log space (log_a − k·t) so that a large intercept does not
    # have to be materialised before being scaled back down.
    try:
        predicted = [math.exp(log_a - k * t) for t in turns]
        r2 = _r_squared(recalls, predicted)
    except OverflowError:
        r2 = float("-inf")
    return (a, k, r2)
| 135 | + |
| 136 | + |
def _fit_ebbinghaus(
    turns: List[float],
    recalls: List[float],
    t_max: float,
) -> Tuple[float, float]:
    """
    Fit R(t) = exp(−t_norm / (S · sqrt(1 + t_norm))) by grid-searching over S.
    t_norm = t / max(t_max, 1), so t_norm ∈ [0, 1] when t_max is the largest
    turn.

    The grid is S = 0.1, 0.2, …, 20.0 (a single pass with step 0.1).  The S
    with the smallest sum of squared residuals is kept, the first one winning
    ties.  This picks the same S as maximising R² whenever the recalls vary,
    and still discriminates between candidates when the recalls are constant
    (where R² is degenerate for every candidate).

    Returns (S, r_squared).
        S — stability constant (higher = slower forgetting); stays at 1.0 if
            no candidate yields a comparable residual (e.g. NaN recalls)
        r_squared — R² of the selected curve against *recalls*

    Both values are NaN when fewer than two turns are supplied.
    """
    if len(turns) < 2:
        return (float("nan"), float("nan"))

    scale = max(t_max, 1)  # guards against dividing by a zero span
    t_norm_list = [t / scale for t in turns]

    def _predict(s: float) -> List[float]:
        # Recall is pinned to 1.0 at (or before) the origin.
        return [
            1.0 if tn <= 0 else math.exp(-tn / (s * math.sqrt(1.0 + tn)))
            for tn in t_norm_list
        ]

    def _sse(predicted: List[float]) -> float:
        return sum((y - y_hat) ** 2 for y, y_hat in zip(recalls, predicted))

    best_s = 1.0
    best_sse = float("inf")

    # Grid search over S ∈ [0.1, 20.0] in steps of 0.1.
    for i in range(1, 201):
        s = i / 10.0
        sse = _sse(_predict(s))
        if sse < best_sse:
            best_sse = sse
            best_s = s

    # R² is only needed for the winner, so compute it once.
    return (best_s, _r_squared(recalls, _predict(best_s)))
| 176 | + |
| 177 | + |
def fit_forgetting_curve(
    checkpoints: List[int],
    recalls: List[float],
) -> Dict:
    """
    Fit forgetting-curve models to a backend's recall@T time-series and return
    interpretable memory-stability statistics.

    Models fitted
    -------------
    exponential : R(t) = a · exp(−k · t)
        Two-parameter exponential decay (intercept a, rate k).
    ebbinghaus  : R(t) = exp(−t_norm / (S · √(1 + t_norm)))
        Single-parameter curve (stability S) on turns normalised by the
        largest checkpoint, t_norm = t / max(t_max, 1).

    Parameters
    ----------
    checkpoints : list of turn numbers at which recall was measured
    recalls     : list of recall values ∈ [0, 1] corresponding to each checkpoint

    Returns
    -------
    dict with keys:
        exponential:
            a         — initial recall estimate at t=0
            k         — decay rate (nats per turn)
            half_life — turns until recall halves (ln(2)/k); None when k ≤ 0
            r2        — R² goodness-of-fit
        ebbinghaus:
            stability — S parameter (higher = more stable memory)
            half_life — turns until recall drops to 0.5; None when that
                        point lies more than 100× the observed span away
            r2        — R² goodness-of-fit
        checkpoints   : input turns (echoed for convenience)
        recalls       : input recalls, rounded to 4 decimals

    Any statistic that is NaN or infinite is reported as None.

    If the two inputs differ in length or contain fewer than two points, a
    dict with the single key "error" is returned instead.
    """
    if len(checkpoints) != len(recalls) or len(checkpoints) < 2:
        return {"error": "Need at least 2 (checkpoint, recall) pairs."}

    def _rounded(value: float, ndigits: int) -> Optional[float]:
        # NaN and ±inf both mean "not available"; None keeps the result
        # serialisable as strict JSON.
        return round(value, ndigits) if math.isfinite(value) else None

    turns = [float(t) for t in checkpoints]
    t_max = max(turns)

    # ── Exponential fit ──────────────────────────────────────────────────────
    a, k, r2_exp = _fit_exponential(turns, recalls)
    half_life_exp = math.log(2) / k if (not math.isnan(k) and k > 0) else float("inf")

    # ── Ebbinghaus fit ───────────────────────────────────────────────────────
    S, r2_ebb = _fit_ebbinghaus(turns, recalls, t_max)

    # Half-life: solve exp(-tn / (S * sqrt(1 + tn))) = 0.5 for tn.
    # With c = S·ln 2 this is tn² − c²·tn − c² = 0, whose positive root is
    # tn = (c² + sqrt(c⁴ + 4c²)) / 2.
    half_life_ebb = float("inf")
    max_t_norm = 100.0  # extrapolation horizon, in multiples of the span
    if math.isfinite(S) and S > 0:
        c = S * math.log(2)
        tn_half = (c * c + math.sqrt(c ** 4 + 4.0 * c * c)) / 2.0
        if tn_half <= max_t_norm:
            # Undo the normalisation with the same divisor the fit used.
            half_life_ebb = tn_half * max(t_max, 1)

    return {
        "exponential": {
            "a": _rounded(a, 4),
            "k": _rounded(k, 6),
            "half_life": _rounded(half_life_exp, 2),
            "r2": _rounded(r2_exp, 4),
        },
        "ebbinghaus": {
            "stability": _rounded(S, 4),
            "half_life": _rounded(half_life_ebb, 2),
            "r2": _rounded(r2_ebb, 4),
        },
        "checkpoints": checkpoints,
        "recalls": [round(r, 4) for r in recalls],
    }
0 commit comments