1313 from codeflash .models .function_types import FunctionToOptimize
1414
1515from codeflash .cli_cmds .console import logger
16- from codeflash .code_utils .config_parser import parse_config_file
1716
1817
1918def run_compare (args : Namespace ) -> None :
2019 """Entry point for the compare subcommand."""
21- # Load project config
22- pyproject_config , pyproject_file_path = parse_config_file (args .config_file )
20+ # Resolve head_ref: explicit arg > --pr > current branch
21+ head_ref = args .head_ref
22+ if args .pr :
23+ head_ref = resolve_pr_branch (args .pr )
24+ if not head_ref :
25+ head_ref = get_current_branch ()
26+ if not head_ref :
27+ logger .error ("Must provide head_ref, --pr, or be on a branch" )
28+ sys .exit (1 )
29+ logger .info (f"Auto-detected head ref: { head_ref } " )
30+
31+ # Resolve base_ref: explicit arg > PR base branch > repo default branch
32+ base_ref = args .base_ref
33+ if not base_ref :
34+ base_ref = detect_base_ref (head_ref )
35+ if not base_ref :
36+ logger .error ("Could not auto-detect base ref. Provide it explicitly or ensure gh CLI is available." )
37+ sys .exit (1 )
38+ logger .info (f"Auto-detected base ref: { base_ref } " )
39+
40+ # Script mode: run an arbitrary benchmark command on each worktree (no codeflash config needed)
41+ script_cmd = getattr (args , "script" , None )
42+ if script_cmd :
43+ script_output = getattr (args , "script_output" , None )
44+ if not script_output :
45+ logger .error ("--script-output is required when using --script" )
46+ sys .exit (1 )
47+
48+ import git
49+
50+ project_root = Path (git .Repo (Path .cwd (), search_parent_directories = True ).working_dir )
51+
52+ from codeflash .benchmarking .compare import compare_with_script
53+
54+ result = compare_with_script (
55+ base_ref = base_ref ,
56+ head_ref = head_ref ,
57+ project_root = project_root ,
58+ script_cmd = script_cmd ,
59+ script_output = script_output ,
60+ timeout = args .timeout ,
61+ memory = getattr (args , "memory" , False ),
62+ )
63+
64+ if not result .base_results and not result .head_results :
65+ logger .warning ("No benchmark data collected. Check that --script-output points to a valid JSON file." )
66+ sys .exit (1 )
2367
68+ if args .output :
69+ md = result .format_markdown ()
70+ Path (args .output ).write_text (md , encoding = "utf-8" )
71+ logger .info (f"Markdown report written to { args .output } " )
72+ return
73+
74+ # Standard trace-benchmark mode: requires codeflash config
75+ from codeflash .code_utils .config_parser import parse_config_file
76+
77+ pyproject_config , pyproject_file_path = parse_config_file (args .config_file )
2478 module_root = Path (pyproject_config .get ("module_root" , "." )).resolve ()
79+
80+ from codeflash .cli_cmds .cli import project_root_from_module_root
81+
82+ project_root = project_root_from_module_root (module_root , pyproject_file_path )
2583 tests_root = Path (pyproject_config .get ("tests_root" , "tests" )).resolve ()
2684 benchmarks_root_str = pyproject_config .get ("benchmarks_root" )
2785
@@ -34,42 +92,89 @@ def run_compare(args: Namespace) -> None:
3492 logger .error (f"benchmarks-root { benchmarks_root } is not a valid directory" )
3593 sys .exit (1 )
3694
37- from codeflash .cli_cmds .cli import project_root_from_module_root
38-
39- project_root = project_root_from_module_root (module_root , pyproject_file_path )
40-
41- # Resolve head_ref
42- head_ref = args .head_ref
43- if args .pr :
44- head_ref = _resolve_pr_branch (args .pr )
45- if not head_ref :
46- logger .error ("Must provide head_ref or --pr" )
47- sys .exit (1 )
48-
4995 # Parse explicit functions if provided
5096 functions = None
5197 if args .functions :
52- functions = _parse_functions_arg (args .functions , project_root )
98+ functions = parse_functions_arg (args .functions , project_root )
5399
54100 from codeflash .benchmarking .compare import compare_branches
55101
56102 result = compare_branches (
57- base_ref = args . base_ref ,
103+ base_ref = base_ref ,
58104 head_ref = head_ref ,
59105 project_root = project_root ,
60106 benchmarks_root = benchmarks_root ,
61107 tests_root = tests_root ,
62108 functions = functions ,
63109 timeout = args .timeout ,
110+ memory = getattr (args , "memory" , False ),
64111 )
65112
66- if not result .base_total_ns and not result .head_total_ns :
113+ if not result .base_stats and not result .head_stats :
67114 logger .warning ("No benchmark data collected. Check that benchmarks-root is configured and benchmarks exist." )
68115 sys .exit (1 )
69116
117+ if args .output :
118+ md = result .format_markdown ()
119+ Path (args .output ).write_text (md , encoding = "utf-8" )
120+ logger .info (f"Markdown report written to { args .output } " )
121+
122+
123+ def get_current_branch () -> str | None :
124+ try :
125+ result = subprocess .run (
126+ ["git" , "rev-parse" , "--abbrev-ref" , "HEAD" ], capture_output = True , text = True , check = True
127+ )
128+ branch = result .stdout .strip ()
129+ return branch if branch and branch != "HEAD" else None
130+ except (FileNotFoundError , subprocess .CalledProcessError ):
131+ return None
132+
133+
134+ def detect_base_ref (head_ref : str ) -> str | None :
135+ # Try to find an open PR for this branch and use its base
136+ try :
137+ result = subprocess .run (
138+ ["gh" , "pr" , "view" , head_ref , "--json" , "baseRefName" , "-q" , ".baseRefName" ],
139+ capture_output = True ,
140+ text = True ,
141+ check = True ,
142+ )
143+ base = result .stdout .strip ()
144+ if base :
145+ return base
146+ except (FileNotFoundError , subprocess .CalledProcessError ):
147+ pass
148+
149+ # Fall back to repo default branch
150+ try :
151+ result = subprocess .run (
152+ ["gh" , "repo" , "view" , "--json" , "defaultBranchRef" , "-q" , ".defaultBranchRef.name" ],
153+ capture_output = True ,
154+ text = True ,
155+ check = True ,
156+ )
157+ default = result .stdout .strip ()
158+ if default :
159+ return default
160+ except (FileNotFoundError , subprocess .CalledProcessError ):
161+ pass
162+
163+ # Last resort: check for common default branch names
164+ try :
165+ for candidate in ("main" , "master" ):
166+ result = subprocess .run (
167+ ["git" , "rev-parse" , "--verify" , candidate ], capture_output = True , text = True , check = False
168+ )
169+ if result .returncode == 0 :
170+ return candidate
171+ except FileNotFoundError :
172+ pass
173+
174+ return None
175+
70176
71- def _resolve_pr_branch (pr_number : int ) -> str :
72- """Resolve a PR number to its head branch name using gh CLI."""
177+ def resolve_pr_branch (pr_number : int ) -> str :
73178 try :
74179 result = subprocess .run (
75180 ["gh" , "pr" , "view" , str (pr_number ), "--json" , "headRefName" , "-q" , ".headRefName" ],
@@ -91,7 +196,7 @@ def _resolve_pr_branch(pr_number: int) -> str:
91196 sys .exit (1 )
92197
93198
94- def _parse_functions_arg (functions_str : str , project_root : Path ) -> dict [Path , list [FunctionToOptimize ]]:
199+ def parse_functions_arg (functions_str : str , project_root : Path ) -> dict [Path , list [FunctionToOptimize ]]:
95200 """Parse --functions arg format: 'file.py::func1,func2;other.py::func3'."""
96201 from codeflash .models .function_types import FunctionToOptimize
97202
0 commit comments