Skip to content

Commit accb245

Browse files
authored
Merge pull request #1941 from codeflash-ai/cf-compare-copy-benchmarks
feat: enhance codeflash compare with memory profiling, script mode, and auto-calibration
2 parents 30ea701 + ab728f7 commit accb245

13 files changed

Lines changed: 1733 additions & 379 deletions

codeflash/benchmarking/compare.py

Lines changed: 789 additions & 216 deletions
Large diffs are not rendered by default.

codeflash/benchmarking/plugin/plugin.py

Lines changed: 232 additions & 113 deletions
Large diffs are not rendered by default.
codeflash/benchmarking/pytest_new_process_memory_benchmarks.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
"""Subprocess entry point for memory profiling benchmarks via pytest-memray.

Runs pytest with --memray --native to profile peak memory per test function.
The codeflash-benchmark plugin is left active (without --codeflash-trace) so it
provides a no-op ``benchmark`` fixture for tests that depend on it.

Usage:
    pytest_new_process_memory_benchmarks.py BENCHMARKS_ROOT MEMRAY_BIN_DIR MEMRAY_BIN_PREFIX
"""

import sys
from pathlib import Path
from typing import Optional

USAGE = "usage: pytest_new_process_memory_benchmarks.py BENCHMARKS_ROOT MEMRAY_BIN_DIR MEMRAY_BIN_PREFIX"


def parse_cli_args(argv: list[str]) -> Optional[tuple[str, str, str]]:
    """Extract the three positional arguments from *argv*.

    Args:
        argv: The full argument vector, with the script name at index 0.

    Returns:
        ``(benchmarks_root, memray_bin_dir, memray_bin_prefix)``, or ``None``
        when fewer than three positional arguments were supplied. Extra
        trailing arguments are ignored.
    """
    if len(argv) < 4:
        return None
    return argv[1], argv[2], argv[3]


def build_pytest_args(benchmarks_root: str, memray_bin_dir: str, memray_bin_prefix: str) -> list[str]:
    """Return the argument list handed to ``pytest.main`` for a memray profiling run."""
    # Each third-party plugin that is switched off gets its own "-p no:<name>" pair.
    disabled_plugins = ("benchmark", "codspeed", "cov", "profiling")
    plugin_flags = [flag for name in disabled_plugins for flag in ("-p", f"no:{name}")]
    return [
        benchmarks_root,
        "--memray",
        "--native",
        f"--memray-bin-path={memray_bin_dir}",
        f"--memray-bin-prefix={memray_bin_prefix}",
        "--hide-memray-summary",
        *plugin_flags,
        "-s",
        # Blank out any addopts coming from the project's own pytest configuration.
        "-o",
        "addopts=",
    ]


if __name__ == "__main__":
    import pytest

    # Arguments are read only when executed as a script, so importing this
    # module has no side effects and cannot fail on a short sys.argv.
    cli_args = parse_cli_args(sys.argv)
    if cli_args is None:
        print(USAGE, file=sys.stderr)
        sys.exit(4)  # same value pytest uses for command-line usage errors

    benchmarks_root, memray_bin_dir, memray_bin_prefix = cli_args

    # The output directory must exist before the run starts writing into it.
    Path(memray_bin_dir).mkdir(parents=True, exist_ok=True)

    exitcode = pytest.main(build_pytest_args(benchmarks_root, memray_bin_dir, memray_bin_prefix))

    sys.exit(exitcode)

codeflash/benchmarking/trace_benchmarks.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,39 @@ def trace_benchmarks_pytest(
4646
error_section = combined_output
4747
logger.warning(f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}")
4848
logger.debug(f"Full pytest output:\n{combined_output}")
49+
50+
def _extract_pytest_section(output: str, title: str) -> str:
    """Return the body of the ``=== <title> ===`` section of pytest output.

    The section runs from the banner line up to the next run of three or more
    ``=`` characters (or the end of the text). When no such banner is present,
    the whole of *output* is returned unchanged.
    """
    found = re.search(rf"={{3,}}\s*{title}\s*={{3,}}\n([\s\S]*?)(?:={{3,}}|$)", output)
    return found.group(1) if found else output


def memory_benchmarks_pytest(
    benchmarks_root: Path, project_root: Path, memray_bin_dir: Path, memray_bin_prefix: str, timeout: int = 300
) -> None:
    """Run ``pytest_new_process_memory_benchmarks.py`` in a subprocess and report failures.

    Args:
        benchmarks_root: Forwarded to the script as its first argument.
        project_root: Working directory of the subprocess; also used to build its environment.
        memray_bin_dir: Forwarded to the script as its second argument.
        memray_bin_prefix: Forwarded to the script as its third argument.
        timeout: Passed through to ``get_cross_platform_subprocess_run_args``.

    A non-zero exit code is not raised (``check=False``); instead the most
    relevant part of the captured output is logged as a warning and the full
    output at debug level.
    """
    benchmark_env = make_env_with_project_root(project_root)
    run_args = get_cross_platform_subprocess_run_args(
        cwd=project_root, env=benchmark_env, timeout=timeout, check=False, text=True, capture_output=True
    )
    result = subprocess.run(  # noqa: PLW1510
        [
            SAFE_SYS_EXECUTABLE,
            Path(__file__).parent / "pytest_new_process_memory_benchmarks.py",
            benchmarks_root,
            memray_bin_dir,
            memray_bin_prefix,
        ],
        **run_args,
    )
    if result.returncode == 0:
        return

    # Errors often land on stderr, so report both streams; empty ones are left out.
    combined_output = "\n".join(stream for stream in (result.stdout, result.stderr) if stream)

    # Prefer the collection-error section, then the failures section, else everything.
    if "ERROR collecting" in combined_output:
        error_section = _extract_pytest_section(combined_output, "ERRORS")
    elif "FAILURES" in combined_output:
        error_section = _extract_pytest_section(combined_output, "FAILURES")
    else:
        error_section = combined_output
    logger.warning(f"Error collecting memory benchmarks - Pytest Exit code: {result.returncode}, {error_section}")
    logger.debug(f"Full pytest output:\n{combined_output}")

codeflash/benchmarking/utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
import logging
34
import shutil
5+
from operator import itemgetter
46
from typing import TYPE_CHECKING, Optional
57

68
from rich.console import Console
@@ -16,27 +18,30 @@
1618

1719

def validate_and_format_benchmark_table(
    function_benchmark_timings: dict[str, dict[BenchmarkKey, float]], total_benchmark_timings: dict[BenchmarkKey, float]
) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
    """Turn raw per-function benchmark timings into table rows.

    Args:
        function_benchmark_timings: Maps a function path to the time (in
            nanoseconds) that function took inside each benchmark.
        total_benchmark_timings: Maps each benchmark to its total time in
            nanoseconds. A benchmark missing from this mapping is treated as
            having a total of 0.

    Returns:
        For every function path, a list of
        ``(benchmark_key, total_time_ms, func_time_ms, percentage)`` rows
        sorted by percentage, highest first. A benchmark whose function time
        exceeds its total time gets an all-zero row; a benchmark whose
        function time and total time are both zero produces no row.
    """
    ns_per_ms = 1_000_000.0
    function_to_result = {}
    for func_path, test_times in function_benchmark_timings.items():
        rows = []
        for benchmark_key, func_time in test_times.items():
            total_time = total_benchmark_timings.get(benchmark_key, 0)
            if func_time > total_time:
                # If the function time is greater than total time, likely to have multithreading / multiprocessing issues.
                # Do not try to project the optimization impact for this function.
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(
                        f"Skipping test {benchmark_key} due to func_time {func_time} > total_time {total_time}"
                    )
                rows.append((benchmark_key, 0.0, 0.0, 0.0))
            elif total_time > 0:
                percentage = (func_time / total_time) * 100
                # Convert nanoseconds to milliseconds
                rows.append((benchmark_key, total_time / ns_per_ms, func_time / ns_per_ms, percentage))
            # Remaining case (func_time <= total_time <= 0): nothing to report for this benchmark.
        # Sort by percentage (highest first)
        rows.sort(key=itemgetter(3), reverse=True)
        function_to_result[func_path] = rows
    return function_to_result
4247

@@ -77,8 +82,8 @@ def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey
7782

7883
def process_benchmark_data(
7984
replay_performance_gain: dict[BenchmarkKey, float],
80-
fto_benchmark_timings: dict[BenchmarkKey, int],
81-
total_benchmark_timings: dict[BenchmarkKey, int],
85+
fto_benchmark_timings: dict[BenchmarkKey, float],
86+
total_benchmark_timings: dict[BenchmarkKey, float],
8287
) -> Optional[ProcessedBenchmarkInfo]:
8388
"""Process benchmark data and generate detailed benchmark information.
8489

codeflash/cli_cmds/cli.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,13 +383,26 @@ def _build_parser() -> ArgumentParser:
383383
auth_subparsers.add_parser("status", help="Check authentication status")
384384

385385
compare_parser = subparsers.add_parser("compare", help="Compare benchmark performance between two git refs.")
386-
compare_parser.add_argument("base_ref", help="Base git ref (branch, tag, or commit)")
386+
compare_parser.add_argument(
387+
"base_ref", nargs="?", default=None, help="Base git ref (default: auto-detect from PR or default branch)"
388+
)
387389
compare_parser.add_argument("head_ref", nargs="?", default=None, help="Head git ref (default: current branch)")
388390
compare_parser.add_argument("--pr", type=int, help="Resolve head ref from a PR number (requires gh CLI)")
389391
compare_parser.add_argument(
390392
"--functions", type=str, help="Explicit functions to instrument: 'file.py::func1,func2;other.py::func3'"
391393
)
392394
compare_parser.add_argument("--timeout", type=int, default=600, help="Benchmark timeout in seconds (default: 600)")
395+
compare_parser.add_argument("--output", "-o", type=str, help="Write markdown report to file")
396+
compare_parser.add_argument(
397+
"--memory", action="store_true", help="Profile peak memory usage per benchmark (requires memray, Linux/macOS)"
398+
)
399+
compare_parser.add_argument("--script", type=str, help="Shell command to run as benchmark in each worktree")
400+
compare_parser.add_argument(
401+
"--script-output",
402+
type=str,
403+
dest="script_output",
404+
help="Relative path to JSON results file produced by --script (required with --script)",
405+
)
393406
compare_parser.add_argument("--config-file", type=str, dest="config_file", help="Path to pyproject.toml")
394407

395408
trace_optimize.add_argument(

codeflash/cli_cmds/cmd_compare.py

Lines changed: 126 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,73 @@
1313
from codeflash.models.function_types import FunctionToOptimize
1414

1515
from codeflash.cli_cmds.console import logger
16-
from codeflash.code_utils.config_parser import parse_config_file
1716

1817

1918
def run_compare(args: Namespace) -> None:
2019
"""Entry point for the compare subcommand."""
21-
# Load project config
22-
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
20+
# Resolve head_ref: explicit arg > --pr > current branch
21+
head_ref = args.head_ref
22+
if args.pr:
23+
head_ref = resolve_pr_branch(args.pr)
24+
if not head_ref:
25+
head_ref = get_current_branch()
26+
if not head_ref:
27+
logger.error("Must provide head_ref, --pr, or be on a branch")
28+
sys.exit(1)
29+
logger.info(f"Auto-detected head ref: {head_ref}")
30+
31+
# Resolve base_ref: explicit arg > PR base branch > repo default branch
32+
base_ref = args.base_ref
33+
if not base_ref:
34+
base_ref = detect_base_ref(head_ref)
35+
if not base_ref:
36+
logger.error("Could not auto-detect base ref. Provide it explicitly or ensure gh CLI is available.")
37+
sys.exit(1)
38+
logger.info(f"Auto-detected base ref: {base_ref}")
39+
40+
# Script mode: run an arbitrary benchmark command on each worktree (no codeflash config needed)
41+
script_cmd = getattr(args, "script", None)
42+
if script_cmd:
43+
script_output = getattr(args, "script_output", None)
44+
if not script_output:
45+
logger.error("--script-output is required when using --script")
46+
sys.exit(1)
47+
48+
import git
49+
50+
project_root = Path(git.Repo(Path.cwd(), search_parent_directories=True).working_dir)
51+
52+
from codeflash.benchmarking.compare import compare_with_script
53+
54+
result = compare_with_script(
55+
base_ref=base_ref,
56+
head_ref=head_ref,
57+
project_root=project_root,
58+
script_cmd=script_cmd,
59+
script_output=script_output,
60+
timeout=args.timeout,
61+
memory=getattr(args, "memory", False),
62+
)
63+
64+
if not result.base_results and not result.head_results:
65+
logger.warning("No benchmark data collected. Check that --script-output points to a valid JSON file.")
66+
sys.exit(1)
2367

68+
if args.output:
69+
md = result.format_markdown()
70+
Path(args.output).write_text(md, encoding="utf-8")
71+
logger.info(f"Markdown report written to {args.output}")
72+
return
73+
74+
# Standard trace-benchmark mode: requires codeflash config
75+
from codeflash.code_utils.config_parser import parse_config_file
76+
77+
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
2478
module_root = Path(pyproject_config.get("module_root", ".")).resolve()
79+
80+
from codeflash.cli_cmds.cli import project_root_from_module_root
81+
82+
project_root = project_root_from_module_root(module_root, pyproject_file_path)
2583
tests_root = Path(pyproject_config.get("tests_root", "tests")).resolve()
2684
benchmarks_root_str = pyproject_config.get("benchmarks_root")
2785

@@ -34,42 +92,89 @@ def run_compare(args: Namespace) -> None:
3492
logger.error(f"benchmarks-root {benchmarks_root} is not a valid directory")
3593
sys.exit(1)
3694

37-
from codeflash.cli_cmds.cli import project_root_from_module_root
38-
39-
project_root = project_root_from_module_root(module_root, pyproject_file_path)
40-
41-
# Resolve head_ref
42-
head_ref = args.head_ref
43-
if args.pr:
44-
head_ref = _resolve_pr_branch(args.pr)
45-
if not head_ref:
46-
logger.error("Must provide head_ref or --pr")
47-
sys.exit(1)
48-
4995
# Parse explicit functions if provided
5096
functions = None
5197
if args.functions:
52-
functions = _parse_functions_arg(args.functions, project_root)
98+
functions = parse_functions_arg(args.functions, project_root)
5399

54100
from codeflash.benchmarking.compare import compare_branches
55101

56102
result = compare_branches(
57-
base_ref=args.base_ref,
103+
base_ref=base_ref,
58104
head_ref=head_ref,
59105
project_root=project_root,
60106
benchmarks_root=benchmarks_root,
61107
tests_root=tests_root,
62108
functions=functions,
63109
timeout=args.timeout,
110+
memory=getattr(args, "memory", False),
64111
)
65112

66-
if not result.base_total_ns and not result.head_total_ns:
113+
if not result.base_stats and not result.head_stats:
67114
logger.warning("No benchmark data collected. Check that benchmarks-root is configured and benchmarks exist.")
68115
sys.exit(1)
69116

117+
if args.output:
118+
md = result.format_markdown()
119+
Path(args.output).write_text(md, encoding="utf-8")
120+
logger.info(f"Markdown report written to {args.output}")
121+
122+
def get_current_branch() -> str | None:
    """Return the name of the checked-out git branch, or ``None`` if it cannot be determined.

    ``None`` is returned when the ``git`` executable is not found, when the
    command exits with an error, when it prints nothing, or when it prints
    the literal ``HEAD`` instead of a branch name.
    """
    command = ["git", "rev-parse", "--abbrev-ref", "HEAD"]
    try:
        completed = subprocess.run(command, capture_output=True, text=True, check=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None

    name = completed.stdout.strip()
    if not name or name == "HEAD":
        return None
    return name
132+
133+
def detect_base_ref(head_ref: str) -> str | None:
    """Guess the base ref to compare *head_ref* against.

    Three sources are tried in order, and the first non-empty answer wins:

    1. ``gh pr view <head_ref>`` - the base branch of the PR for this branch.
    2. ``gh repo view`` - the repository's default branch.
    3. The first of ``main`` / ``master`` that ``git rev-parse --verify`` accepts.

    Returns ``None`` when none of them yields a ref. A missing ``gh`` or
    ``git`` executable, or a failing ``gh`` command, just moves on to the
    next source.
    """

    def stripped_stdout(command: list[str]) -> str | None:
        # Run the command; any failure or empty output counts as "no answer".
        try:
            completed = subprocess.run(command, capture_output=True, text=True, check=True)
        except (FileNotFoundError, subprocess.CalledProcessError):
            return None
        return completed.stdout.strip() or None

    gh_queries = (
        # Try to find an open PR for this branch and use its base
        ["gh", "pr", "view", head_ref, "--json", "baseRefName", "-q", ".baseRefName"],
        # Fall back to repo default branch
        ["gh", "repo", "view", "--json", "defaultBranchRef", "-q", ".defaultBranchRef.name"],
    )
    for query in gh_queries:
        answer = stripped_stdout(query)
        if answer:
            return answer

    # Last resort: check for common default branch names
    try:
        for candidate in ("main", "master"):
            probe = subprocess.run(
                ["git", "rev-parse", "--verify", candidate], capture_output=True, text=True, check=False
            )
            if probe.returncode == 0:
                return candidate
    except FileNotFoundError:
        pass

    return None
175+
70176

71-
def _resolve_pr_branch(pr_number: int) -> str:
72-
"""Resolve a PR number to its head branch name using gh CLI."""
177+
def resolve_pr_branch(pr_number: int) -> str:
73178
try:
74179
result = subprocess.run(
75180
["gh", "pr", "view", str(pr_number), "--json", "headRefName", "-q", ".headRefName"],
@@ -91,7 +196,7 @@ def _resolve_pr_branch(pr_number: int) -> str:
91196
sys.exit(1)
92197

93198

94-
def _parse_functions_arg(functions_str: str, project_root: Path) -> dict[Path, list[FunctionToOptimize]]:
199+
def parse_functions_arg(functions_str: str, project_root: Path) -> dict[Path, list[FunctionToOptimize]]:
95200
"""Parse --functions arg format: 'file.py::func1,func2;other.py::func3'."""
96201
from codeflash.models.function_types import FunctionToOptimize
97202

codeflash/optimization/optimizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def run_benchmarks(
127127
function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(
128128
self.trace_file
129129
)
130-
total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(self.trace_file)
130+
total_benchmark_stats = CodeFlashBenchmarkPlugin.get_benchmark_timings(self.trace_file)
131+
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
131132
function_to_results = validate_and_format_benchmark_table(
132133
function_benchmark_timings, total_benchmark_timings
133134
)

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ dependencies = [
5353
"filelock>=3.20.3; python_version >= '3.10'",
5454
"filelock<3.20.3; python_version < '3.10'",
5555
"pytest-asyncio>=0.18.0",
56+
"memray>=1.12; sys_platform != 'win32'",
57+
"pytest-memray>=1.7; sys_platform != 'win32'",
5658
]
5759

5860
[project.urls]
@@ -339,8 +341,8 @@ vcs = "git"
339341

340342
[tool.hatch.build.hooks.version]
341343
path = "codeflash/version.py"
342-
template = """# These version placeholders will be replaced by uv-dynamic-versioning during build.
343-
__version__ = "{version}"
344+
template = """# These version placeholders will be replaced by uv-dynamic-versioning during build.
345+
__version__ = "{version}"
344346
"""
345347

346348

0 commit comments

Comments
 (0)