Skip to content

Commit 808b76b

Browse files
authored
Merge pull request #2133 from codeflash-ai/fix/test-files-silent-dedup
fix: silently skip duplicate TestFiles.add() instead of raising
2 parents 342e902 + 9459c5a commit 808b76b

3 files changed

Lines changed: 88 additions & 40 deletions

File tree

codeflash/languages/python/parse_xml.py

Lines changed: 16 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010
import os
1111
import re
12-
from typing import TYPE_CHECKING
12+
from typing import TYPE_CHECKING, Any
1313

1414
from junitparser.xunit2 import JUnitXml
1515

@@ -48,7 +48,7 @@
4848
)
4949

5050

51-
def _parse_func(file_path: Path):
51+
def _parse_func(file_path: Path) -> Any:
5252
from lxml.etree import XMLParser, parse
5353

5454
xml_parser = XMLParser(huge_tree=True)
@@ -59,13 +59,22 @@ def parse_python_test_xml(
5959
test_xml_file_path: Path,
6060
test_files: TestFiles,
6161
test_config: TestConfig,
62-
run_result: subprocess.CompletedProcess | None = None,
62+
run_result: subprocess.CompletedProcess[str] | None = None,
6363
) -> TestResults:
6464
from codeflash.verification.parse_test_output import resolve_test_file_from_class_path
6565

6666
test_results = TestResults()
6767
if not test_xml_file_path.exists():
68-
logger.warning(f"No test results for {test_xml_file_path} found.")
68+
if run_result is not None and run_result.returncode != 0:
69+
stderr_snippet = (run_result.stderr or "")[:500]
70+
stdout_snippet = (run_result.stdout or "")[:500]
71+
logger.warning(
72+
f"No test results for {test_xml_file_path} found. "
73+
f"Subprocess exited with code {run_result.returncode}.\n"
74+
f"stdout: {stdout_snippet}\nstderr: {stderr_snippet}"
75+
)
76+
else:
77+
logger.warning(f"No test results for {test_xml_file_path} found.")
6978
console.rule()
7079
return test_results
7180
try:
@@ -87,12 +96,7 @@ def parse_python_test_xml(
8796
):
8897
logger.info("Test failed to load, skipping it.")
8998
if run_result is not None:
90-
if isinstance(run_result.stdout, str) and isinstance(run_result.stderr, str):
91-
logger.info(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}")
92-
else:
93-
logger.info(
94-
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}"
95-
)
99+
logger.info(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}")
96100
return test_results
97101

98102
test_class_path = testcase.classname
@@ -159,7 +163,7 @@ def parse_python_test_xml(
159163
sys_stdout = testcase.system_out or ""
160164

161165
begin_matches = list(matches_re_start.finditer(sys_stdout))
162-
end_matches: dict[tuple, re.Match] = {}
166+
end_matches: dict[tuple[str, ...], re.Match[str]] = {}
163167
for match in matches_re_end.finditer(sys_stdout):
164168
groups = match.groups()
165169
if len(groups[5].split(":")) > 1:
@@ -234,11 +238,5 @@ def parse_python_test_xml(
234238
f"Tests '{[test_file.original_file_path for test_file in test_files.test_files]}' failed to run, skipping"
235239
)
236240
if run_result is not None:
237-
stdout, stderr = "", ""
238-
try:
239-
stdout = run_result.stdout.decode()
240-
stderr = run_result.stderr.decode()
241-
except AttributeError:
242-
stdout = run_result.stderr
243-
logger.debug(f"Test log - STDOUT : {stdout} \n STDERR : {stderr}")
241+
logger.debug(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}")
244242
return test_results

codeflash/models/models.py

Lines changed: 28 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -9,15 +9,15 @@
99
from functools import lru_cache
1010
from pathlib import Path
1111
from re import Pattern
12-
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, cast
12+
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
1313

1414
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
1515
from pydantic.dataclasses import dataclass
1616

1717
from codeflash.models.test_type import TestType
1818

1919
if TYPE_CHECKING:
20-
from collections.abc import Iterator
20+
from collections.abc import Generator
2121

2222
import libcst as cst
2323
from rich.tree import Tree
@@ -298,11 +298,13 @@ def flat(self) -> str:
298298
299299
"""
300300
if self._cache.get("flat") is not None:
301-
return self._cache["flat"]
302-
self._cache["flat"] = "\n".join(
301+
result: str = self._cache["flat"]
302+
return result
303+
flat: str = "\n".join(
303304
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
304305
)
305-
return self._cache["flat"]
306+
self._cache["flat"] = flat
307+
return flat
306308

307309
@property
308310
def markdown(self) -> str:
@@ -332,7 +334,8 @@ def file_to_path(self) -> dict[str, str]:
332334
333335
"""
334336
try:
335-
return self._cache["file_to_path"]
337+
cached: dict[str, str] = self._cache["file_to_path"]
338+
return cached
336339
except KeyError:
337340
mapping = {str(code_string.file_path): code_string.code for code_string in self.code_strings}
338341
self._cache["file_to_path"] = mapping
@@ -426,13 +429,16 @@ class TestFile(BaseModel):
426429

427430
class TestFiles(BaseModel):
428431
test_files: list[TestFile]
432+
_seen_paths: set[Path] = PrivateAttr(default_factory=set)
433+
434+
def model_post_init(self, __context: Any, /) -> None:
435+
self._seen_paths = {tf.instrumented_behavior_file_path for tf in self.test_files}
429436

430437
def add(self, test_file: TestFile) -> None:
431-
if test_file not in self.test_files:
438+
key = test_file.instrumented_behavior_file_path
439+
if key not in self._seen_paths:
440+
self._seen_paths.add(key)
432441
self.test_files.append(test_file)
433-
else:
434-
msg = "Test file already exists in the list"
435-
raise ValueError(msg)
436442

437443
def get_by_original_file_path(self, file_path: Path) -> TestFile | None:
438444
normalized = self._normalize_path_for_comparison(file_path)
@@ -494,8 +500,8 @@ def _normalize_path_for_comparison(path: Path) -> str:
494500
# Only lowercase on Windows where filesystem is case-insensitive
495501
return resolved.lower() if sys.platform == "win32" else resolved
496502

497-
def __iter__(self) -> Iterator[TestFile]:
498-
return iter(self.test_files)
503+
def __iter__(self) -> Generator[Any, None, None]: # noqa: PYI058
504+
yield from self.test_files
499505

500506
def __len__(self) -> int:
501507
return len(self.test_files)
@@ -514,9 +520,9 @@ class CandidateEvaluationContext:
514520
optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
515521
is_correct: dict[str, bool] = Field(default_factory=dict)
516522
optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
517-
ast_code_to_id: dict = Field(default_factory=dict)
523+
ast_code_to_id: dict[str, Any] = Field(default_factory=dict)
518524
optimizations_post: dict[str, str] = Field(default_factory=dict)
519-
valid_optimizations: list = Field(default_factory=list)
525+
valid_optimizations: list[Any] = Field(default_factory=list)
520526

521527
def record_failed_candidate(self, optimization_id: str) -> None:
522528
"""Record results for a failed candidate."""
@@ -543,7 +549,7 @@ def handle_duplicate_candidate(
543549
# Copy results from the previous evaluation (use .get() in case past_opt_id was registered
544550
# but never benchmarked due to an unhandled exception in process_single_candidate)
545551
self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios.get(past_opt_id)
546-
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id)
552+
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id, False)
547553
self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes.get(past_opt_id)
548554

549555
# Line profiler results only available for successful runs
@@ -631,7 +637,7 @@ class OriginalCodeBaseline(BaseModel):
631637
behavior_test_results: TestResults
632638
benchmarking_test_results: TestResults
633639
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
634-
line_profile_results: dict
640+
line_profile_results: dict[str, Any]
635641
runtime: int
636642
coverage_results: Optional[CoverageData]
637643
async_throughput: Optional[int] = None
@@ -793,7 +799,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]:
793799
f"// Testing function: {self.function_getting_tested}"
794800
)
795801

796-
if self.test_class_name:
802+
if self.test_class_name and self.test_function_name:
797803
for stmt in module_node.body:
798804
if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
799805
func_node = self.find_func_in_class(stmt, self.test_function_name)
@@ -884,7 +890,7 @@ def group_by_benchmarks(
884890
"""Group TestResults by benchmark for calculating improvements for each benchmark."""
885891
from codeflash.code_utils.code_utils import module_name_from_file_path
886892

887-
test_results_by_benchmark = defaultdict(TestResults)
893+
test_results_by_benchmark: defaultdict[BenchmarkKey, TestResults] = defaultdict(TestResults)
888894
benchmark_module_path = {}
889895
for benchmark_key in benchmark_keys:
890896
benchmark_module_path[benchmark_key] = module_name_from_file_path(
@@ -1015,7 +1021,7 @@ def effective_loop_count(self) -> int:
10151021
return max(loop_indices) if loop_indices else 0
10161022

10171023
def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
1018-
map_gen_test_file_to_no_of_tests = Counter()
1024+
map_gen_test_file_to_no_of_tests: Counter[Path] = Counter()
10191025
for gen_test_result in self.test_results:
10201026
if (
10211027
gen_test_result.test_type == TestType.GENERATED_REGRESSION
@@ -1024,8 +1030,8 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Pa
10241030
map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1
10251031
return map_gen_test_file_to_no_of_tests
10261032

1027-
def __iter__(self) -> Iterator[FunctionTestInvocation]:
1028-
return iter(self.test_results)
1033+
def __iter__(self) -> Generator[Any, None, None]: # noqa: PYI058
1034+
yield from self.test_results
10291035

10301036
def __len__(self) -> int:
10311037
return len(self.test_results)
@@ -1051,7 +1057,7 @@ def __eq__(self, other: object) -> bool:
10511057
if len(self) != len(other):
10521058
return False
10531059
original_recursion_limit = sys.getrecursionlimit()
1054-
cast("TestResults", other)
1060+
assert isinstance(other, TestResults)
10551061
for test_result in self:
10561062
other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id)
10571063
if other_test_result is None:

tests/test_test_files_add.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,44 @@
1+
from pathlib import Path
2+
3+
from codeflash.models.models import TestFile, TestFiles
4+
from codeflash.models.test_type import TestType
5+
6+
7+
class TestTestFilesAdd:
8+
def test_add_unique_test_file(self) -> None:
9+
tf = TestFiles(test_files=[])
10+
test_file = TestFile(
11+
instrumented_behavior_file_path=Path("/tmp/test_behavior.py"),
12+
benchmarking_file_path=Path("/tmp/test_perf.py"),
13+
test_type=TestType.GENERATED_REGRESSION,
14+
)
15+
tf.add(test_file)
16+
assert len(tf.test_files) == 1
17+
assert tf.test_files[0] is test_file
18+
19+
def test_add_duplicate_is_noop(self) -> None:
20+
tf = TestFiles(test_files=[])
21+
test_file = TestFile(
22+
instrumented_behavior_file_path=Path("/tmp/test_behavior.py"),
23+
benchmarking_file_path=Path("/tmp/test_perf.py"),
24+
test_type=TestType.GENERATED_REGRESSION,
25+
)
26+
tf.add(test_file)
27+
tf.add(test_file) # silent skip — first write wins
28+
assert len(tf.test_files) == 1
29+
30+
def test_add_many_files_performance(self) -> None:
31+
tf = TestFiles(test_files=[])
32+
for i in range(100):
33+
test_file = TestFile(
34+
instrumented_behavior_file_path=Path(f"/tmp/test_behavior_{i}.py"),
35+
benchmarking_file_path=Path(f"/tmp/test_perf_{i}.py"),
36+
test_type=TestType.GENERATED_REGRESSION,
37+
)
38+
tf.add(test_file)
39+
40+
assert len(tf.test_files) == 100
41+
assert len(tf._seen_paths) == 100
42+
# Verify all paths are unique in the set
43+
expected_paths = {Path(f"/tmp/test_behavior_{i}.py") for i in range(100)}
44+
assert tf._seen_paths == expected_paths

0 commit comments

Comments
 (0)