99from functools import lru_cache
1010from pathlib import Path
1111from re import Pattern
12- from typing import TYPE_CHECKING , Any , NamedTuple , Optional , cast
12+ from typing import TYPE_CHECKING , Any , NamedTuple , Optional
1313
1414from pydantic import BaseModel , ConfigDict , Field , PrivateAttr , ValidationError , model_validator
1515from pydantic .dataclasses import dataclass
1616
1717from codeflash .models .test_type import TestType
1818
1919if TYPE_CHECKING :
20- from collections .abc import Iterator
20+ from collections .abc import Generator
2121
2222 import libcst as cst
2323 from rich .tree import Tree
@@ -298,11 +298,13 @@ def flat(self) -> str:
298298
299299 """
300300 if self ._cache .get ("flat" ) is not None :
301- return self ._cache ["flat" ]
302- self ._cache ["flat" ] = "\n " .join (
301+ result : str = self ._cache ["flat" ]
302+ return result
303+ flat : str = "\n " .join (
303304 get_code_block_splitter (block .file_path ) + "\n " + block .code for block in self .code_strings
304305 )
305- return self ._cache ["flat" ]
306+ self ._cache ["flat" ] = flat
307+ return flat
306308
307309 @property
308310 def markdown (self ) -> str :
@@ -332,7 +334,8 @@ def file_to_path(self) -> dict[str, str]:
332334
333335 """
334336 try :
335- return self ._cache ["file_to_path" ]
337+ cached : dict [str , str ] = self ._cache ["file_to_path" ]
338+ return cached
336339 except KeyError :
337340 mapping = {str (code_string .file_path ): code_string .code for code_string in self .code_strings }
338341 self ._cache ["file_to_path" ] = mapping
@@ -426,13 +429,16 @@ class TestFile(BaseModel):
426429
427430class TestFiles (BaseModel ):
428431 test_files : list [TestFile ]
432+ _seen_paths : set [Path ] = PrivateAttr (default_factory = set )
433+
434+ def model_post_init (self , __context : Any , / ) -> None :
435+ self ._seen_paths = {tf .instrumented_behavior_file_path for tf in self .test_files }
429436
430437 def add (self , test_file : TestFile ) -> None :
431- if test_file not in self .test_files :
438+ key = test_file .instrumented_behavior_file_path
439+ if key not in self ._seen_paths :
440+ self ._seen_paths .add (key )
432441 self .test_files .append (test_file )
433- else :
434- msg = "Test file already exists in the list"
435- raise ValueError (msg )
436442
437443 def get_by_original_file_path (self , file_path : Path ) -> TestFile | None :
438444 normalized = self ._normalize_path_for_comparison (file_path )
@@ -494,8 +500,8 @@ def _normalize_path_for_comparison(path: Path) -> str:
494500 # Only lowercase on Windows where filesystem is case-insensitive
495501 return resolved .lower () if sys .platform == "win32" else resolved
496502
497- def __iter__ (self ) -> Iterator [ TestFile ]:
498- return iter ( self .test_files )
503+ def __iter__ (self ) -> Generator [ Any , None , None ]: # noqa: PYI058
504+ yield from self .test_files
499505
500506 def __len__ (self ) -> int :
501507 return len (self .test_files )
@@ -514,9 +520,9 @@ class CandidateEvaluationContext:
514520 optimized_runtimes : dict [str , float | None ] = Field (default_factory = dict )
515521 is_correct : dict [str , bool ] = Field (default_factory = dict )
516522 optimized_line_profiler_results : dict [str , str ] = Field (default_factory = dict )
517- ast_code_to_id : dict = Field (default_factory = dict )
523+ ast_code_to_id : dict [ str , Any ] = Field (default_factory = dict )
518524 optimizations_post : dict [str , str ] = Field (default_factory = dict )
519- valid_optimizations : list = Field (default_factory = list )
525+ valid_optimizations : list [ Any ] = Field (default_factory = list )
520526
521527 def record_failed_candidate (self , optimization_id : str ) -> None :
522528 """Record results for a failed candidate."""
@@ -543,7 +549,7 @@ def handle_duplicate_candidate(
543549 # Copy results from the previous evaluation (use .get() in case past_opt_id was registered
544550 # but never benchmarked due to an unhandled exception in process_single_candidate)
545551 self .speedup_ratios [candidate .optimization_id ] = self .speedup_ratios .get (past_opt_id )
546- self .is_correct [candidate .optimization_id ] = self .is_correct .get (past_opt_id )
552+ self .is_correct [candidate .optimization_id ] = self .is_correct .get (past_opt_id , False )
547553 self .optimized_runtimes [candidate .optimization_id ] = self .optimized_runtimes .get (past_opt_id )
548554
549555 # Line profiler results only available for successful runs
@@ -631,7 +637,7 @@ class OriginalCodeBaseline(BaseModel):
631637 behavior_test_results : TestResults
632638 benchmarking_test_results : TestResults
633639 replay_benchmarking_test_results : Optional [dict [BenchmarkKey , TestResults ]] = None
634- line_profile_results : dict
640+ line_profile_results : dict [ str , Any ]
635641 runtime : int
636642 coverage_results : Optional [CoverageData ]
637643 async_throughput : Optional [int ] = None
@@ -793,7 +799,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]:
793799 f"// Testing function: { self .function_getting_tested } "
794800 )
795801
796- if self .test_class_name :
802+ if self .test_class_name and self . test_function_name :
797803 for stmt in module_node .body :
798804 if isinstance (stmt , cst .ClassDef ) and stmt .name .value == self .test_class_name :
799805 func_node = self .find_func_in_class (stmt , self .test_function_name )
@@ -884,7 +890,7 @@ def group_by_benchmarks(
884890 """Group TestResults by benchmark for calculating improvements for each benchmark."""
885891 from codeflash .code_utils .code_utils import module_name_from_file_path
886892
887- test_results_by_benchmark = defaultdict (TestResults )
893+ test_results_by_benchmark : defaultdict [ BenchmarkKey , TestResults ] = defaultdict (TestResults )
888894 benchmark_module_path = {}
889895 for benchmark_key in benchmark_keys :
890896 benchmark_module_path [benchmark_key ] = module_name_from_file_path (
@@ -1015,7 +1021,7 @@ def effective_loop_count(self) -> int:
10151021 return max (loop_indices ) if loop_indices else 0
10161022
10171023 def file_to_no_of_tests (self , test_functions_to_remove : list [str ]) -> Counter [Path ]:
1018- map_gen_test_file_to_no_of_tests = Counter ()
1024+ map_gen_test_file_to_no_of_tests : Counter [ Path ] = Counter ()
10191025 for gen_test_result in self .test_results :
10201026 if (
10211027 gen_test_result .test_type == TestType .GENERATED_REGRESSION
@@ -1024,8 +1030,8 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Pa
10241030 map_gen_test_file_to_no_of_tests [gen_test_result .file_name ] += 1
10251031 return map_gen_test_file_to_no_of_tests
10261032
1027- def __iter__ (self ) -> Iterator [ FunctionTestInvocation ]:
1028- return iter ( self .test_results )
1033+ def __iter__ (self ) -> Generator [ Any , None , None ]: # noqa: PYI058
1034+ yield from self .test_results
10291035
10301036 def __len__ (self ) -> int :
10311037 return len (self .test_results )
@@ -1051,7 +1057,7 @@ def __eq__(self, other: object) -> bool:
10511057 if len (self ) != len (other ):
10521058 return False
10531059 original_recursion_limit = sys .getrecursionlimit ()
1054- cast ( "TestResults" , other )
1060+ assert isinstance ( other , TestResults )
10551061 for test_result in self :
10561062 other_test_result = other .get_by_unique_invocation_loop_id (test_result .unique_invocation_loop_id )
10571063 if other_test_result is None :
0 commit comments