|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Plot grounded evolution convergence curves from experiment data. |
| 3 | +
|
| 4 | +Usage: |
| 5 | + python analysis/plot_convergence.py # Use main run_log.jsonl |
| 6 | + python analysis/plot_convergence.py --ablation # Use per-condition files |
| 7 | + python analysis/plot_convergence.py --ablation --rolling 5 # Rolling average |
| 8 | +
|
| 9 | +Output: PNG files in analysis/charts/ |
| 10 | +""" |
| 11 | + |
| 12 | +import json |
| 13 | +import sys |
| 14 | +from collections import defaultdict |
| 15 | +from pathlib import Path |
| 16 | +from typing import Any |
| 17 | + |
| 18 | + |
| 19 | +CHARTS_DIR: Path = Path("analysis/charts") |
| 20 | +ROLLING_WINDOW: int = 10 # default rolling average window |
| 21 | + |
| 22 | + |
| 23 | +def load_main_log() -> list[dict[str, Any]]: |
| 24 | + """Load all cycles from the main experiment log.""" |
| 25 | + log_path: Path = Path("experiments/run_log.jsonl") |
| 26 | + if not log_path.exists(): |
| 27 | + print("No experiment log found at experiments/run_log.jsonl") |
| 28 | + sys.exit(1) |
| 29 | + return [json.loads(line) for line in log_path.read_text().strip().splitlines() if line] |
| 30 | + |
| 31 | + |
| 32 | +def load_ablation_runs() -> dict[str, list[dict[str, Any]]]: |
| 33 | + """Load per-condition results from experiments/ablation_runs/*.jsonl.""" |
| 34 | + runs_dir: Path = Path("experiments/ablation_runs") |
| 35 | + if not runs_dir.exists(): |
| 36 | + print("No ablation runs found at experiments/ablation_runs/") |
| 37 | + sys.exit(1) |
| 38 | + |
| 39 | + results: dict[str, list[dict[str, Any]]] = {} |
| 40 | + for fpath in sorted(runs_dir.glob("*.jsonl")): |
| 41 | + condition: str = fpath.stem |
| 42 | + results[condition] = [ |
| 43 | + json.loads(line) for line in fpath.read_text().strip().splitlines() if line |
| 44 | + ] |
| 45 | + return results |
| 46 | + |
| 47 | + |
| 48 | +def rolling_average(values: list[float], window: int) -> list[float]: |
| 49 | + """Compute rolling average with the given window size.""" |
| 50 | + if not values or window <= 1: |
| 51 | + return list(values) |
| 52 | + smoothed: list[float] = [] |
| 53 | + for i in range(len(values)): |
| 54 | + start: int = max(0, i - window + 1) |
| 55 | + chunk: list[float] = values[start:i + 1] |
| 56 | + smoothed.append(sum(chunk) / len(chunk)) |
| 57 | + return smoothed |
| 58 | + |
| 59 | + |
| 60 | +def plot_main_convergence(records: list[dict[str, Any]]) -> None: |
| 61 | + """Plot overall score vs cycles from the main log.""" |
| 62 | + try: |
| 63 | + import matplotlib |
| 64 | + matplotlib.use("Agg") |
| 65 | + import matplotlib.pyplot as plt |
| 66 | + except ImportError: |
| 67 | + print("matplotlib not installed. Install it with: pip install matplotlib") |
| 68 | + return |
| 69 | + |
| 70 | + CHARTS_DIR.mkdir(parents=True, exist_ok=True) |
| 71 | + |
| 72 | + scores: list[float] = [r.get("score", 0) for r in records] |
| 73 | + best: list[float] = [] |
| 74 | + best_sofar: float = 0 |
| 75 | + for s in scores: |
| 76 | + best_sofar = max(best_sofar, s) |
| 77 | + best.append(best_sofar) |
| 78 | + |
| 79 | + fig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True) |
| 80 | + |
| 81 | + ax1, ax2 = axes |
| 82 | + |
| 83 | + # Top: per-cycle score |
| 84 | + ax1.plot(scores, alpha=0.4, color="blue", linewidth=0.8, label="Per-cycle score") |
| 85 | + smoothed = rolling_average(scores, ROLLING_WINDOW) |
| 86 | + ax1.plot(smoothed, color="blue", linewidth=2, label=f"Rolling avg (w={ROLLING_WINDOW})") |
| 87 | + ax1.set_ylabel("Execution Score") |
| 88 | + ax1.set_title("Grounded Evolution: Per-Cycle Scores") |
| 89 | + ax1.legend() |
| 90 | + ax1.grid(True, alpha=0.3) |
| 91 | + |
| 92 | + # Bottom: best-so-far |
| 93 | + ax2.plot(best, color="green", linewidth=2, label="Best so far") |
| 94 | + ax2.set_xlabel("Cycle") |
| 95 | + ax2.set_ylabel("Best Score") |
| 96 | + ax2.set_title("Grounded Evolution: Best Score Convergence") |
| 97 | + ax2.legend() |
| 98 | + ax2.grid(True, alpha=0.3) |
| 99 | + |
| 100 | + fig.tight_layout() |
| 101 | + out: Path = CHARTS_DIR / "convergence_main.png" |
| 102 | + fig.savefig(out, dpi=150) |
| 103 | + plt.close(fig) |
| 104 | + print(f"Saved {out}") |
| 105 | + |
| 106 | + |
| 107 | +def plot_ablation_convergence(conditions: dict[str, list[dict[str, Any]]]) -> None: |
| 108 | + """Plot ablation study comparison: one line per condition.""" |
| 109 | + try: |
| 110 | + import matplotlib |
| 111 | + matplotlib.use("Agg") |
| 112 | + import matplotlib.pyplot as plt |
| 113 | + except ImportError: |
| 114 | + print("matplotlib not installed. Install it with: pip install matplotlib") |
| 115 | + return |
| 116 | + |
| 117 | + CHARTS_DIR.mkdir(parents=True, exist_ok=True) |
| 118 | + |
| 119 | + fig, axes = plt.subplots(2, 1, figsize=(14, 12)) |
| 120 | + |
| 121 | + ax1, ax2 = axes |
| 122 | + |
| 123 | + colors: dict[str, str] = { |
| 124 | + "full": "blue", |
| 125 | + "mutation_only": "orange", |
| 126 | + "crossover_only": "green", |
| 127 | + "random_walk": "red", |
| 128 | + } |
| 129 | + markers: dict[str, str] = { |
| 130 | + "full": "o", |
| 131 | + "mutation_only": "s", |
| 132 | + "crossover_only": "^", |
| 133 | + "random_walk": "v", |
| 134 | + } |
| 135 | + |
| 136 | + # Top: per-condition best-so-far |
| 137 | + for cid, records in sorted(conditions.items()): |
| 138 | + scores: list[float] = [r.get("score", 0) for r in records] |
| 139 | + best: list[float] = [] |
| 140 | + best_sofar: float = 0 |
| 141 | + for s in scores: |
| 142 | + best_sofar = max(best_sofar, s) |
| 143 | + best.append(best_sofar) |
| 144 | + |
| 145 | + base_cid: str = cid.rsplit("_", 1)[0] if "_" in cid else cid |
| 146 | + color: str = colors.get(base_cid, "gray") |
| 147 | + marker: str = markers.get(base_cid, ".") |
| 148 | + label: str = cid |
| 149 | + ax1.plot(best, color=color, linewidth=1.5, label=label, marker=marker, markevery=max(1, len(best) // 10)) |
| 150 | + |
| 151 | + ax1.set_ylabel("Best Score") |
| 152 | + ax1.set_title("Ablation Study: Best Score Convergence by Condition") |
| 153 | + ax1.legend(fontsize=8, ncol=2) |
| 154 | + ax1.grid(True, alpha=0.3) |
| 155 | + |
| 156 | + # Bottom: aggregated per-condition (group by condition, average across benchmarks) |
| 157 | + condition_scores: dict[str, list[list[float]]] = defaultdict(list) |
| 158 | + for cid, records in sorted(conditions.items()): |
| 159 | + base_cid = cid.rsplit("_", 1)[0] if "_" in cid else cid |
| 160 | + condition_scores[base_cid].append([r.get("score", 0) for r in records]) |
| 161 | + |
| 162 | + for cond, all_scores in sorted(condition_scores.items()): |
| 163 | + # Average across benchmarks at each cycle |
| 164 | + min_len: int = min(len(s) for s in all_scores) |
| 165 | + avg_scores: list[float] = [sum(s[i] for s in all_scores) / len(all_scores) for i in range(min_len)] |
| 166 | + best_avg: list[float] = [] |
| 167 | + best_sofar = 0 |
| 168 | + for s in avg_scores: |
| 169 | + best_sofar = max(best_sofar, s) |
| 170 | + best_avg.append(best_sofar) |
| 171 | + |
| 172 | + color: str = colors.get(cond, "gray") |
| 173 | + marker: str = markers.get(cond, ".") |
| 174 | + ax2.plot(best_avg, color=color, linewidth=2.5, label=cond, marker=marker, markevery=max(1, min_len // 8)) |
| 175 | + |
| 176 | + ax2.set_xlabel("Cycle") |
| 177 | + ax2.set_ylabel("Best Score (avg across benchmarks)") |
| 178 | + ax2.set_title("Ablation Study: Aggregate Convergence (averaged across benchmarks)") |
| 179 | + ax2.legend(fontsize=10) |
| 180 | + ax2.grid(True, alpha=0.3) |
| 181 | + |
| 182 | + fig.tight_layout() |
| 183 | + out: Path = CHARTS_DIR / "convergence_ablation.png" |
| 184 | + fig.savefig(out, dpi=150) |
| 185 | + plt.close(fig) |
| 186 | + print(f"Saved {out}") |
| 187 | + |
| 188 | + |
| 189 | +def main() -> None: |
| 190 | + """Main entry point.""" |
| 191 | + use_ablation: bool = "--ablation" in sys.argv |
| 192 | + rolling_window: int = ROLLING_WINDOW |
| 193 | + for arg in sys.argv: |
| 194 | + if arg.startswith("--rolling="): |
| 195 | + rolling_window = int(arg.split("=")[1]) |
| 196 | + |
| 197 | + global ROLLING_WINDOW |
| 198 | + ROLLING_WINDOW = rolling_window |
| 199 | + |
| 200 | + if use_ablation: |
| 201 | + conditions = load_ablation_runs() |
| 202 | + print(f"Loaded {len(conditions)} condition files from experiments/ablation_runs/") |
| 203 | + print(f"Conditions: {', '.join(sorted(conditions.keys()))}") |
| 204 | + plot_ablation_convergence(conditions) |
| 205 | + else: |
| 206 | + records = load_main_log() |
| 207 | + n_benchmarks = len(set(r.get("benchmark", "?") for r in records)) |
| 208 | + print(f"Loaded {len(records)} cycles across {n_benchmarks} benchmarks") |
| 209 | + plot_main_convergence(records) |
| 210 | + |
| 211 | + print(f"Charts saved to {CHARTS_DIR}/") |
| 212 | + |
| 213 | + |
| 214 | +if __name__ == "__main__": |
| 215 | + main() |
0 commit comments