Skip to content

Commit 8a58920

Browse files
Neal006claude
andcommitted
feat: fit Ebbinghaus and exponential forgetting curves to recall@T data (#19)
Adds fit_forgetting_curve() to evaluation/stats.py — fits both exponential (log-linear OLS) and Ebbinghaus (grid-search over S) models to recall@T time-series and reports half-life, stability constant, and R² per backend. CLI: --fit-curves flag prints the analysis after any benchmark run, works with both single-seed and multi-seed (uses mean recall). Merge resolves conflict with PR #18 --scenario arg; both args preserved. Co-authored-by: Priyanshu-byte-coder Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 2a6fdad commit 8a58920

2 files changed

Lines changed: 227 additions & 3 deletions

File tree

evaluation/stats.py

Lines changed: 180 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
"""
2-
evaluation/stats.py — Statistical helpers for multi-seed aggregation.
2+
evaluation/stats.py — Statistical helpers for multi-seed aggregation and
3+
forgetting-curve analysis.
34
4-
Provides mean ± std + 95% confidence intervals over multiple benchmark runs.
5+
Provides:
6+
- mean ± std + 95% confidence intervals over multiple benchmark runs
7+
- fit_forgetting_curve() — fits Ebbinghaus / exponential decay to recall@T
8+
data and returns half-life, stability, and R² goodness-of-fit
9+
10+
Closes #6.
511
"""
612

713
import math
8-
from typing import List, Optional, Tuple
14+
from typing import Dict, List, Optional, Tuple
15+
916

17+
# ── Basic statistics ──────────────────────────────────────────────────────────
1018

1119
def _mean(values: List[float]) -> float:
1220
return sum(values) / len(values) if values else 0.0
@@ -72,3 +80,172 @@ def aggregate_checkpoint_series(
7280
aggregate_metric([run[i] for run in series])
7381
for i in range(n_checkpoints)
7482
]
83+
84+
85+
# ── Forgetting-curve fitting ──────────────────────────────────────────────────
86+
87+
def _r_squared(observed: List[float], predicted: List[float]) -> float:
88+
"""Coefficient of determination R² ∈ (-∞, 1]; 1 = perfect fit."""
89+
if len(observed) < 2:
90+
return float("nan")
91+
mean_obs = _mean(observed)
92+
ss_tot = sum((y - mean_obs) ** 2 for y in observed)
93+
ss_res = sum((y - y_hat) ** 2 for y, y_hat in zip(observed, predicted))
94+
if ss_tot == 0:
95+
return 1.0 if ss_res == 0 else float("-inf")
96+
return 1.0 - ss_res / ss_tot
97+
98+
99+
def _fit_exponential(
100+
turns: List[float],
101+
recalls: List[float],
102+
) -> Tuple[float, float, float]:
103+
"""
104+
Fit R(t) = a · exp(−k · t) via log-linear least squares.
105+
106+
Returns (a, k, r_squared).
107+
a — intercept (recall at t=0, ideally ≈ 1.0)
108+
k — decay rate (higher = faster forgetting)
109+
"""
110+
# Filter out zero/negative recalls to avoid log(0)
111+
valid = [(t, r) for t, r in zip(turns, recalls) if r > 0]
112+
if len(valid) < 2:
113+
return (float("nan"), float("nan"), float("nan"))
114+
115+
xs = [t for t, _ in valid]
116+
ys = [math.log(r) for _, r in valid]
117+
118+
# Linear regression on log(R) = log(a) - k*t
119+
n = len(xs)
120+
sx = sum(xs)
121+
sy = sum(ys)
122+
sxx = sum(x * x for x in xs)
123+
sxy = sum(x * y for x, y in zip(xs, ys))
124+
denom = n * sxx - sx * sx
125+
if denom == 0:
126+
return (float("nan"), float("nan"), float("nan"))
127+
128+
k = -(n * sxy - sx * sy) / denom
129+
log_a = (sy - (-k) * sx) / n # using -k because slope is -k
130+
a = math.exp(log_a)
131+
132+
predicted = [a * math.exp(-k * t) for t in turns]
133+
r2 = _r_squared(recalls, predicted)
134+
return (a, k, r2)
135+
136+
137+
def _fit_ebbinghaus(
138+
turns: List[float],
139+
recalls: List[float],
140+
t_max: float,
141+
) -> Tuple[float, float]:
142+
"""
143+
Fit R(t) = exp(−t_norm / (S · sqrt(1 + t_norm))) by grid-searching over S.
144+
t_norm = t / t_max so that t ∈ [0, 1].
145+
146+
Returns (S, r_squared).
147+
S — stability constant (higher = slower forgetting)
148+
"""
149+
if len(turns) < 2:
150+
return (float("nan"), float("nan"))
151+
152+
t_norm_list = [t / max(t_max, 1) for t in turns]
153+
154+
def _predict(s: float) -> List[float]:
155+
result = []
156+
for tn in t_norm_list:
157+
if tn <= 0:
158+
result.append(1.0)
159+
else:
160+
denom = s * math.sqrt(1.0 + tn)
161+
result.append(math.exp(-tn / denom))
162+
return result
163+
164+
best_s = 1.0
165+
best_r2 = float("-inf")
166+
167+
# Coarse + fine grid search over S ∈ [0.01, 20]
168+
for s in [i * 0.1 for i in range(1, 201)]:
169+
predicted = _predict(s)
170+
r2 = _r_squared(recalls, predicted)
171+
if not math.isnan(r2) and r2 > best_r2:
172+
best_r2 = r2
173+
best_s = s
174+
175+
return (best_s, best_r2)
176+
177+
178+
def fit_forgetting_curve(
179+
checkpoints: List[int],
180+
recalls: List[float],
181+
) -> Dict:
182+
"""
183+
Fit forgetting-curve models to a backend's recall@T time-series and return
184+
interpretable memory-stability statistics.
185+
186+
Models fitted
187+
-------------
188+
exponential : R(t) = a · exp(−k · t)
189+
Classic single-parameter decay (Jost 1897).
190+
ebbinghaus : R(t) = exp(−t_norm / (S · √(1 + t_norm)))
191+
Two-parameter Ebbinghaus (1885) forgetting curve.
192+
193+
Parameters
194+
----------
195+
checkpoints : list of turn numbers at which recall was measured
196+
recalls : list of recall values ∈ [0, 1] corresponding to each checkpoint
197+
198+
Returns
199+
-------
200+
dict with keys:
201+
exponential:
202+
a — initial recall estimate at t=0
203+
k — decay rate (nats per turn)
204+
half_life — turns until recall halves (ln(2)/k)
205+
r2 — R² goodness-of-fit
206+
ebbinghaus:
207+
stability — S parameter (higher = more stable memory)
208+
half_life — turns until recall drops to 0.5
209+
r2 — R² goodness-of-fit
210+
checkpoints : input turns (echoed for convenience)
211+
recalls : input recalls (echoed for convenience)
212+
"""
213+
if len(checkpoints) != len(recalls) or len(checkpoints) < 2:
214+
return {"error": "Need at least 2 (checkpoint, recall) pairs."}
215+
216+
turns = [float(t) for t in checkpoints]
217+
t_max = max(turns)
218+
219+
# ── Exponential fit ──────────────────────────────────────────────────────
220+
a, k, r2_exp = _fit_exponential(turns, recalls)
221+
half_life_exp = math.log(2) / k if (not math.isnan(k) and k > 0) else float("inf")
222+
223+
# ── Ebbinghaus fit ───────────────────────────────────────────────────────
224+
S, r2_ebb = _fit_ebbinghaus(turns, recalls, t_max)
225+
226+
# Half-life for Ebbinghaus: solve exp(-tn / (S * sqrt(1+tn))) = 0.5
227+
# Numerically: scan t_norm values
228+
half_life_ebb = float("inf")
229+
if not math.isnan(S):
230+
for step in range(1, 10001):
231+
tn = step / 100.0
232+
val = math.exp(-tn / (S * math.sqrt(1.0 + tn)))
233+
if val <= 0.5:
234+
half_life_ebb = round(tn * t_max, 2)
235+
break
236+
237+
return {
238+
"exponential": {
239+
"a": round(a, 4) if not math.isnan(a) else None,
240+
"k": round(k, 6) if not math.isnan(k) else None,
241+
"half_life": round(half_life_exp, 2) if not math.isinf(half_life_exp) else None,
242+
"r2": round(r2_exp, 4) if not math.isnan(r2_exp) else None,
243+
},
244+
"ebbinghaus": {
245+
"stability": round(S, 4) if not math.isnan(S) else None,
246+
"half_life": half_life_ebb if not math.isinf(half_life_ebb) else None,
247+
"r2": round(r2_ebb, 4) if not math.isnan(r2_ebb) else None,
248+
},
249+
"checkpoints": checkpoints,
250+
"recalls": [round(r, 4) for r in recalls],
251+
}

main.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
Realistic chunked RAG backend:
2828
python main.py --backends naive rag_chunked cascading
2929
30+
Forgetting-curve analysis (fit Ebbinghaus + exponential to recall@T data):
31+
python main.py --fit-curves
32+
python main.py --seeds 5 --fit-curves
33+
3034
Other options:
3135
python main.py --turns 50 --backends naive rag --log
3236
python main.py --list-providers
@@ -66,6 +70,9 @@ def main() -> None:
6670
parser.add_argument("--decay", type=str, default="ebbinghaus",
6771
choices=["ebbinghaus", "exponential", "linear", "default"],
6872
help="Temporal decay function for CascadingMemory warm tier")
73+
parser.add_argument("--fit-curves", action="store_true",
74+
help="After benchmarking, fit Ebbinghaus + exponential decay curves "
75+
"to recall@T data and report half-life / stability / R²")
6976
parser.add_argument("--scenario", type=str, default="default",
7077
choices=["default", "edtech"],
7178
help="Conversation scenario: default (tech Q&A) | edtech (student-tutor)")
@@ -182,11 +189,51 @@ def main() -> None:
182189
})
183190
print(f"Experiment logged -> {path}")
184191

192+
# ── Forgetting-curve analysis ─────────────────────────────────────────────
193+
if args.fit_curves:
194+
from evaluation.stats import fit_forgetting_curve
195+
checkpoints = sorted(args.checkpoints)
196+
print("\nFORGETTING CURVE FIT (Ebbinghaus + Exponential)")
197+
print("-" * 65)
198+
if multi_seed:
199+
for name in args.backends:
200+
if name not in aggregated:
201+
continue
202+
mean_recalls = [stat["mean"] for stat in aggregated[name]["recall"]]
203+
fit = fit_forgetting_curve(checkpoints, mean_recalls)
204+
_print_curve_fit(name, fit)
205+
else:
206+
for name in args.backends:
207+
if name not in display:
208+
continue
209+
fit = fit_forgetting_curve(checkpoints, display[name]["recall"])
210+
_print_curve_fit(name, fit)
211+
185212
print("Visualise: streamlit run dashboard.py")
186213

187214

188215
# ── Output helpers ────────────────────────────────────────────────────────────
189216

217+
218+
def _print_curve_fit(backend: str, fit: dict) -> None:
219+
if "error" in fit:
220+
print(f" {backend:<14} {fit['error']}")
221+
return
222+
exp = fit["exponential"]
223+
ebb = fit["ebbinghaus"]
224+
hl_exp = f"{exp['half_life']:.1f} turns" if exp["half_life"] is not None else "N/A"
225+
hl_ebb = f"{ebb['half_life']:.1f} turns" if ebb["half_life"] is not None else "N/A"
226+
r2_exp = f"{exp['r2']:.3f}" if exp["r2"] is not None else "N/A"
227+
r2_ebb = f"{ebb['r2']:.3f}" if ebb["r2"] is not None else "N/A"
228+
stab = f"{ebb['stability']:.4f}" if ebb["stability"] is not None else "N/A"
229+
k_val = f"{exp['k']:.6f}" if exp["k"] is not None else "N/A"
230+
print(f" {backend}")
231+
print(f" Exponential k={k_val} half-life={hl_exp} R²={r2_exp}")
232+
print(f" Ebbinghaus S={stab} half-life={hl_ebb} R²={r2_ebb}")
233+
234+
235+
236+
190237
def _print_single_seed_results(display: dict, backends: list) -> None:
191238
checkpoints = display["checkpoints"]
192239
col = " ".join(f"T={c:3d}" for c in checkpoints)

0 commit comments

Comments
 (0)