1+ import contextlib
12import logging
23import os
34import pathlib
89from dataclasses import dataclass , field
910from typing import Optional
1011
12+ try :
13+ import tomllib
14+ except ImportError :
15+ import tomli as tomllib
16+
1117
1218@dataclass
1319class CoverageExpectation :
@@ -21,7 +27,10 @@ class TestConfig:
2127 # Make file_path optional when trace_mode is True
2228 file_path : Optional [pathlib .Path ] = None
2329 function_name : Optional [str ] = None
24- expected_unit_tests : Optional [int ] = None
30+ # Global count: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path"
31+ expected_unit_tests_count : Optional [int ] = None
32+ # Per-function count: "Discovered X existing unit test files, Y replay test files, and Z concolic..."
33+ expected_unit_test_files : Optional [int ] = None
2534 min_improvement_x : float = 0.1
2635 trace_mode : bool = False
2736 coverage_expectations : list [CoverageExpectation ] = field (default_factory = list )
@@ -129,7 +138,20 @@ def build_command(
129138
130139 if config .function_name :
131140 base_command .extend (["--function" , config .function_name ])
132- base_command .extend (["--tests-root" , str (test_root ), "--module-root" , str (cwd )])
141+
142+ # Check if pyproject.toml exists with codeflash config - if so, don't override it
143+ pyproject_path = cwd / "pyproject.toml"
144+ has_codeflash_config = False
145+ if pyproject_path .exists ():
146+ with contextlib .suppress (Exception ):
147+ with open (pyproject_path , "rb" ) as f :
148+ pyproject_data = tomllib .load (f )
149+ has_codeflash_config = "tool" in pyproject_data and "codeflash" in pyproject_data ["tool" ]
150+
151+ # Only pass --tests-root and --module-root if they're not configured in pyproject.toml
152+ if not has_codeflash_config :
153+ base_command .extend (["--tests-root" , str (test_root ), "--module-root" , str (cwd )])
154+
133155 if benchmarks_root :
134156 base_command .extend (["--benchmark" , "--benchmarks-root" , str (benchmarks_root )])
135157 if config .use_worktree :
@@ -163,15 +185,30 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
163185 logging .error (f"Performance improvement rate { improvement_x } x not above { config .min_improvement_x } x" )
164186 return False
165187
166- if config .expected_unit_tests is not None :
167- unit_test_match = re .search (r"Discovered (\d+) existing unit test file" , stdout )
188+ if config .expected_unit_tests_count is not None :
189+ # Match the global test discovery message from optimizer.py which counts test invocations
190+ # Format: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path/to/tests"
191+ unit_test_match = re .search (r"Discovered (\d+) existing unit tests? and \d+ replay tests? in [\d.]+s at" , stdout )
168192 if not unit_test_match :
169- logging .error ("Could not find unit test count" )
193+ logging .error ("Could not find global unit test count" )
170194 return False
171195
172196 num_tests = int (unit_test_match .group (1 ))
173- if num_tests != config .expected_unit_tests :
174- logging .error (f"Expected { config .expected_unit_tests } unit tests, found { num_tests } " )
197+ if num_tests != config .expected_unit_tests_count :
198+ logging .error (f"Expected { config .expected_unit_tests_count } global unit tests, found { num_tests } " )
199+ return False
200+
201+ if config .expected_unit_test_files is not None :
202+ # Match the per-function test discovery message from function_optimizer.py
203+ # Format: "Discovered X existing unit test files, Y replay test files, and Z concolic..."
204+ unit_test_files_match = re .search (r"Discovered (\d+) existing unit test files?" , stdout )
205+ if not unit_test_files_match :
206+ logging .error ("Could not find per-function unit test file count" )
207+ return False
208+
209+ num_test_files = int (unit_test_files_match .group (1 ))
210+ if num_test_files != config .expected_unit_test_files :
211+ logging .error (f"Expected { config .expected_unit_test_files } unit test files, found { num_test_files } " )
175212 return False
176213
177214 if config .coverage_expectations :
0 commit comments