Skip to content

Commit 6965e98

Browse files
committed
feat: add --script mode to codeflash compare
Allows running arbitrary benchmark scripts on both git refs and rendering a styled comparison table. Supports optional --memory via memray wrapping. No codeflash config required for script mode.
1 parent ca198ce commit 6965e98

4 files changed

Lines changed: 491 additions & 30 deletions

File tree

codeflash/benchmarking/compare.py

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,63 @@ def sort_key(fn: str, _bm_key: BenchmarkKey = bm_key) -> float:
145145
return "\n\n".join(sections)
146146

147147

148+
@dataclass
class ScriptCompareResult:
    """Outcome of running one benchmark script against a base and a head git ref.

    The timing dicts map a benchmark key to a duration in seconds. The reserved
    key ``"__total__"`` is rendered as a separate TOTAL row instead of a regular
    benchmark row. The memory fields stay ``None`` unless memory statistics were
    collected for the corresponding ref.
    """

    base_ref: str
    head_ref: str
    base_results: dict[str, float] = field(default_factory=dict)
    head_results: dict[str, float] = field(default_factory=dict)
    base_memory: Optional[MemoryStats] = None
    head_memory: Optional[MemoryStats] = None

    def format_markdown(self) -> str:
        """Render the comparison as a Markdown report.

        Returns:
            A placeholder sentence when there is neither timing nor memory data.
            Otherwise a heading, an optional timing table, an optional memory
            table, and a footer, joined with newlines.
        """
        has_timing = bool(self.base_results or self.head_results)
        has_memory = bool(self.base_memory or self.head_memory)
        if not has_timing and not has_memory:
            return "_No benchmark results to compare._"

        base_short = self.base_ref[:12]
        head_short = self.head_ref[:12]
        lines: list[str] = [f"## Benchmark: `{base_short}` vs `{head_short}`"]

        # Only emit the timing table when there is something to put in it; a
        # memory-only report would otherwise contain a header with no rows.
        if has_timing:
            all_keys = sorted((set(self.base_results) | set(self.head_results)) - {"__total__"})
            lines.extend(["", "| Key | Base | Head | Delta | Speedup |", "|:---|---:|---:|:---|---:|"])
            for key in all_keys:
                b = self.base_results.get(key)
                h = self.head_results.get(key)
                # A literal pipe would end the table cell early, even inside a code span.
                safe_key = key.replace("|", "\\|")
                lines.append(
                    f"| `{safe_key}` | {_fmt_seconds(b)} | {_fmt_seconds(h)} | {_md_delta_s(b, h)} | {md_speedup(b, h)} |"
                )

            if "__total__" in self.base_results or "__total__" in self.head_results:
                b = self.base_results.get("__total__")
                h = self.head_results.get("__total__")
                lines.append(
                    f"| **TOTAL** | **{_fmt_seconds(b)}** | **{_fmt_seconds(h)}** | {_md_delta_s(b, h)} | {md_speedup(b, h)} |"
                )

        if has_memory:
            lines.extend(
                ["", "#### Memory", "", "| Ref | Peak Memory | Allocations | Delta |", "|:---|---:|---:|:---|"]
            )
            if self.base_memory:
                lines.append(
                    f"| `{base_short}` (base) | {md_bytes(self.base_memory.peak_memory_bytes)}"
                    f" | {self.base_memory.total_allocations:,} | |"
                )
            if self.head_memory:
                # The delta is computed against the base peak when a base measurement exists.
                delta = md_memory_delta(
                    self.base_memory.peak_memory_bytes if self.base_memory else None, self.head_memory.peak_memory_bytes
                )
                lines.append(
                    f"| `{head_short}` (head) | {md_bytes(self.head_memory.peak_memory_bytes)}"
                    f" | {self.head_memory.total_allocations:,} | {delta} |"
                )

        lines.extend(["", "---", "*Generated by codeflash optimization agent*"])
        return "\n".join(lines)
203+
204+
148205
def compare_branches(
149206
base_ref: str,
150207
head_ref: str,
@@ -837,3 +894,289 @@ def has_meaningful_memory_change(
837894
if alloc_pct > threshold_pct:
838895
return True
839896
return False
897+
898+
899+
# --- Script-mode comparison ---
900+
901+
902+
def _fmt_seconds(s: Optional[float]) -> str:
903+
if s is None:
904+
return "-"
905+
if s >= 60:
906+
return f"{s / 60:,.1f}m"
907+
return f"{s:,.2f}s"
908+
909+
910+
def _fmt_delta_s(before: Optional[float], after: Optional[float]) -> str:
911+
if before is None or after is None:
912+
return "-"
913+
pct = ((after - before) / before) * 100 if before != 0 else 0
914+
if pct < 0:
915+
return _GREEN_TPL % pct
916+
return _RED_TPL % pct
917+
918+
919+
def _md_delta_s(before: Optional[float], after: Optional[float]) -> str:
920+
if before is None or after is None or before == 0:
921+
return "-"
922+
pct = ((after - before) / before) * 100
923+
emoji = "\U0001f7e2" if pct <= 0 else "\U0001f534"
924+
return f"{emoji} {pct:+.1f}%"
925+
926+
927+
def _speedup_s(before: Optional[float], after: Optional[float]) -> str:
928+
if before is None or after is None or after == 0:
929+
return "-"
930+
ratio = before / after
931+
if ratio >= 1:
932+
return f"[green]{ratio:.2f}x[/green]"
933+
return f"[red]{ratio:.2f}x[/red]"
934+
935+
936+
def compare_with_script(
    base_ref: str,
    head_ref: str,
    project_root: Path,
    script_cmd: str,
    script_output: str,
    timeout: int = 600,
    memory: bool = False,
) -> ScriptCompareResult:
    """Compare benchmark performance between two git refs using a custom script.

    The script is run in each worktree with CWD set to the worktree root.
    It must produce a JSON file at script_output (relative to worktree root)
    mapping keys to seconds, e.g. {"test1": 1.23, "__total__": 4.56}.

    Args:
        base_ref: Git ref used as the baseline.
        head_ref: Git ref compared against the baseline.
        project_root: Directory inside the repository; parent directories are
            searched for the repository root.
        script_cmd: Shell command to run in each worktree.
        script_output: Path of the JSON results file, relative to the worktree root.
        timeout: Per-run timeout in seconds passed to the script runner.
        memory: When true, wrap the script in memray and collect memory stats.

    Returns:
        The populated comparison result. It is empty when memory profiling is
        requested on Windows, and may be partial when the run is interrupted
        with Ctrl-C.
    """
    import sys

    if memory and sys.platform == "win32":
        logger.error("--memory requires memray which is not available on Windows")
        return ScriptCompareResult(base_ref=base_ref, head_ref=head_ref)

    repo = git.Repo(project_root, search_parent_directories=True)

    from codeflash.code_utils.git_worktree_utils import worktree_dirs

    worktree_dirs.mkdir(parents=True, exist_ok=True)
    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # All scratch paths share one timestamp so a run's artifacts are easy to correlate.
    base_worktree = worktree_dirs / f"compare-base-{timestamp}"
    head_worktree = worktree_dirs / f"compare-head-{timestamp}"
    base_memray_bin = worktree_dirs / f"script-memray-base-{timestamp}.bin"
    head_memray_bin = worktree_dirs / f"script-memray-head-{timestamp}.bin"

    result = ScriptCompareResult(base_ref=base_ref, head_ref=head_ref)

    from rich.console import Group
    from rich.live import Live
    from rich.panel import Panel
    from rich.text import Text

    base_short = base_ref[:12]
    head_short = head_ref[:12]

    step_labels = [
        "Creating worktrees",
        f"Running benchmark on base ({base_short})",
        f"Running benchmark on head ({head_short})",
    ]

    def build_steps(current_step: int) -> Group:
        """Render the step list: finished, in-progress, and pending steps get distinct markers."""
        lines: list[Text] = []
        for i, label in enumerate(step_labels):
            if i < current_step:
                lines.append(Text.from_markup(f"[green]\u2714[/green] {label}"))
            elif i == current_step:
                lines.append(Text.from_markup(f"[cyan]\u25cb[/cyan] {label}..."))
            else:
                lines.append(Text.from_markup(f"[dim]\u2500 {label}[/dim]"))
        return Group(*lines)

    def build_panel(current_step: int) -> Panel:
        """Build the live progress panel for the given step index."""
        return Panel(
            Group(
                Text.from_markup(
                    f"[bold cyan]{base_short}[/bold cyan] (base) vs [bold cyan]{head_short}[/bold cyan] (head)"
                ),
                "",
                # The command is user input and may contain square brackets; assemble
                # it as plain text so it is never parsed as console markup.
                Text.assemble(("Script:", "dim"), " ", script_cmd),
                "",
                build_steps(current_step),
            ),
            title="[bold]Script Benchmark Compare[/bold]",
            border_style="cyan",
            expand=True,
            padding=(1, 2),
        )

    try:
        step = 0
        with Live(build_panel(step), console=console, refresh_per_second=1) as live:
            # Resolve refs to SHAs first so both worktrees are pinned to fixed commits.
            base_sha = repo.commit(base_ref).hexsha
            head_sha = repo.commit(head_ref).hexsha
            repo.git.worktree("add", str(base_worktree), base_sha)
            repo.git.worktree("add", str(head_worktree), head_sha)
            step += 1
            live.update(build_panel(step))

            # Run script on base
            result.base_results = _run_script_in_worktree(
                script_cmd, base_worktree, script_output, timeout, base_memray_bin if memory else None
            )
            step += 1
            live.update(build_panel(step))

            # Run script on head
            result.head_results = _run_script_in_worktree(
                script_cmd, head_worktree, script_output, timeout, head_memray_bin if memory else None
            )
            # Mark the last step as done so the final panel does not look unfinished.
            step += 1
            live.update(build_panel(step))

        # Parse memory results
        if memory:
            result.base_memory = _parse_memray_bin(base_memray_bin)
            result.head_memory = _parse_memray_bin(head_memray_bin)

        render_script_comparison(result)

    except KeyboardInterrupt:
        console.print("\n[yellow]Interrupted — cleaning up...[/yellow]")

    finally:
        from codeflash.code_utils.git_worktree_utils import remove_worktree

        # Always remove scratch worktrees and memray captures, even after an interrupt or error.
        remove_worktree(base_worktree)
        remove_worktree(head_worktree)
        repo.git.worktree("prune")
        for f in [base_memray_bin, head_memray_bin]:
            if f.exists():
                f.unlink()

    return result
1056+
1057+
1058+
def _run_script_in_worktree(
1059+
script_cmd: str, worktree_dir: Path, script_output: str, timeout: int, memray_bin: Optional[Path]
1060+
) -> dict[str, float]:
1061+
import json
1062+
1063+
cmd = script_cmd
1064+
if memray_bin:
1065+
cmd = f"python -m memray run --trace-python-allocators -o {memray_bin} -- {cmd}"
1066+
1067+
try:
1068+
proc = subprocess.run( # noqa: S602
1069+
cmd, shell=True, cwd=worktree_dir, timeout=timeout, capture_output=True, text=True, check=False
1070+
)
1071+
if proc.returncode != 0:
1072+
logger.warning(f"Script exited with code {proc.returncode}")
1073+
if proc.stderr:
1074+
logger.debug(f"Script stderr:\n{proc.stderr[:2000]}")
1075+
except subprocess.TimeoutExpired:
1076+
logger.warning(f"Script timed out after {timeout}s")
1077+
return {}
1078+
1079+
output_path = worktree_dir / script_output
1080+
if not output_path.exists():
1081+
logger.warning(f"Script output not found at {output_path}")
1082+
return {}
1083+
1084+
try:
1085+
data = json.loads(output_path.read_text(encoding="utf-8"))
1086+
if not isinstance(data, dict):
1087+
logger.warning("Script output JSON is not a dict")
1088+
return {}
1089+
return {k: float(v) for k, v in data.items() if isinstance(v, (int, float))}
1090+
except (json.JSONDecodeError, ValueError) as e:
1091+
logger.warning(f"Failed to parse script output JSON: {e}")
1092+
return {}
1093+
1094+
1095+
def _parse_memray_bin(bin_path: Path) -> Optional[MemoryStats]:
1096+
if not bin_path.exists():
1097+
return None
1098+
try:
1099+
from memray import FileReader
1100+
1101+
from codeflash.benchmarking.plugin.plugin import MemoryStats
1102+
1103+
reader = FileReader(str(bin_path))
1104+
meta = reader.metadata
1105+
stats = MemoryStats(peak_memory_bytes=meta.peak_memory, total_allocations=meta.total_allocations)
1106+
reader.close()
1107+
return stats
1108+
except ImportError:
1109+
logger.warning("memray not installed — skipping memory results")
1110+
return None
1111+
except OSError as e:
1112+
logger.warning(f"Failed to read memray binary: {e}")
1113+
return None
1114+
1115+
1116+
def render_script_comparison(result: ScriptCompareResult) -> None:
    """Print the script-mode comparison to the console as tables.

    Logs a warning and prints nothing when the result holds neither timing nor
    memory data. Otherwise prints a rule with both refs, then a timing table
    (when timing data exists) and a memory table (when memory data exists).

    Args:
        result: The comparison to render.
    """
    from rich.text import Text

    has_timing = bool(result.base_results or result.head_results)
    has_memory = bool(result.base_memory or result.head_memory)
    if not has_timing and not has_memory:
        logger.warning("No benchmark results to compare")
        return

    base_short = result.base_ref[:12]
    head_short = result.head_ref[:12]

    console.print()
    console.rule(f"[bold]Script Benchmark: {base_short} vs {head_short}[/bold]")
    console.print()

    if has_timing:
        all_keys = sorted((set(result.base_results) | set(result.head_results)) - {"__total__"})
        has_total = "__total__" in result.base_results or "__total__" in result.head_results

        t = Table(title="Benchmark Results", border_style="blue", show_lines=True, expand=False)
        t.add_column("Key", style="cyan")
        t.add_column("Base", justify="right", style="yellow")
        t.add_column("Head", justify="right", style="yellow")
        t.add_column("Delta", justify="right")
        t.add_column("Speedup", justify="right")

        for key in all_keys:
            b = result.base_results.get(key)
            h = result.head_results.get(key)
            # Keys come from the script's JSON and may contain square brackets
            # (e.g. "test[param]"); wrap them in Text so they are shown literally
            # instead of being parsed as console markup.
            t.add_row(Text(key), _fmt_seconds(b), _fmt_seconds(h), _fmt_delta_s(b, h), _speedup_s(b, h))

        if has_total:
            # Separate the aggregate row from the per-key rows.
            t.add_section()
            b = result.base_results.get("__total__")
            h = result.head_results.get("__total__")
            t.add_row("[bold]TOTAL[/bold]", _fmt_seconds(b), _fmt_seconds(h), _fmt_delta_s(b, h), _speedup_s(b, h))

        console.print(t, justify="center")

    if has_memory:
        console.print()
        t_mem = Table(title="Memory (aggregate)", border_style="magenta", show_lines=True, expand=False)
        t_mem.add_column("Ref", style="bold cyan")
        t_mem.add_column("Peak Memory", justify="right")
        t_mem.add_column("Allocations", justify="right")
        t_mem.add_column("Delta", justify="right")

        if result.base_memory:
            t_mem.add_row(
                f"{base_short} (base)",
                fmt_bytes(result.base_memory.peak_memory_bytes),
                f"{result.base_memory.total_allocations:,}",
                "",
            )
        if result.head_memory:
            # The delta is computed against the base peak when a base measurement exists.
            delta = fmt_memory_delta(
                result.base_memory.peak_memory_bytes if result.base_memory else None,
                result.head_memory.peak_memory_bytes,
            )
            t_mem.add_row(
                f"{head_short} (head)",
                fmt_bytes(result.head_memory.peak_memory_bytes),
                f"{result.head_memory.total_allocations:,}",
                delta,
            )
        console.print(t_mem, justify="center")

    console.print()

codeflash/cli_cmds/cli.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,13 @@ def _build_parser() -> ArgumentParser:
395395
compare_parser.add_argument(
396396
"--memory", action="store_true", help="Profile peak memory usage per benchmark (requires memray, Linux/macOS)"
397397
)
398+
compare_parser.add_argument("--script", type=str, help="Shell command to run as benchmark in each worktree")
399+
compare_parser.add_argument(
400+
"--script-output",
401+
type=str,
402+
dest="script_output",
403+
help="Relative path to JSON results file produced by --script (required with --script)",
404+
)
398405
compare_parser.add_argument("--config-file", type=str, dest="config_file", help="Path to pyproject.toml")
399406

400407
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.")

0 commit comments

Comments
 (0)