99from functools import lru_cache
1010from pathlib import Path
1111from re import Pattern
12- from typing import TYPE_CHECKING , Any , NamedTuple , Optional , cast
12+ from typing import TYPE_CHECKING , Any , NamedTuple , Optional
1313
1414from pydantic import BaseModel , ConfigDict , Field , PrivateAttr , ValidationError , model_validator
1515from pydantic .dataclasses import dataclass
1616
1717from codeflash .models .test_type import TestType
1818
1919if TYPE_CHECKING :
20- from collections .abc import Iterator
20+ from collections .abc import Generator
2121
2222 import libcst as cst
2323 from rich .tree import Tree
@@ -298,11 +298,13 @@ def flat(self) -> str:
298298
299299 """
300300 if self ._cache .get ("flat" ) is not None :
301- return self ._cache ["flat" ]
302- self ._cache ["flat" ] = "\n " .join (
301+ result : str = self ._cache ["flat" ]
302+ return result
303+ flat : str = "\n " .join (
303304 get_code_block_splitter (block .file_path ) + "\n " + block .code for block in self .code_strings
304305 )
305- return self ._cache ["flat" ]
306+ self ._cache ["flat" ] = flat
307+ return flat
306308
307309 @property
308310 def markdown (self ) -> str :
@@ -332,7 +334,8 @@ def file_to_path(self) -> dict[str, str]:
332334
333335 """
334336 try :
335- return self ._cache ["file_to_path" ]
337+ cached : dict [str , str ] = self ._cache ["file_to_path" ]
338+ return cached
336339 except KeyError :
337340 mapping = {str (code_string .file_path ): code_string .code for code_string in self .code_strings }
338341 self ._cache ["file_to_path" ] = mapping
@@ -494,8 +497,8 @@ def _normalize_path_for_comparison(path: Path) -> str:
494497 # Only lowercase on Windows where filesystem is case-insensitive
495498 return resolved .lower () if sys .platform == "win32" else resolved
496499
497- def __iter__ (self ) -> Iterator [ TestFile ]:
498- return iter ( self .test_files )
500+ def __iter__ (self ) -> Generator [ Any , None , None ]: # noqa: PYI058
501+ yield from self .test_files
499502
500503 def __len__ (self ) -> int :
501504 return len (self .test_files )
@@ -514,9 +517,9 @@ class CandidateEvaluationContext:
514517 optimized_runtimes : dict [str , float | None ] = Field (default_factory = dict )
515518 is_correct : dict [str , bool ] = Field (default_factory = dict )
516519 optimized_line_profiler_results : dict [str , str ] = Field (default_factory = dict )
517- ast_code_to_id : dict = Field (default_factory = dict )
520+ ast_code_to_id : dict [ str , Any ] = Field (default_factory = dict )
518521 optimizations_post : dict [str , str ] = Field (default_factory = dict )
519- valid_optimizations : list = Field (default_factory = list )
522+ valid_optimizations : list [ Any ] = Field (default_factory = list )
520523
521524 def record_failed_candidate (self , optimization_id : str ) -> None :
522525 """Record results for a failed candidate."""
@@ -543,7 +546,7 @@ def handle_duplicate_candidate(
543546 # Copy results from the previous evaluation (use .get() in case past_opt_id was registered
544547 # but never benchmarked due to an unhandled exception in process_single_candidate)
545548 self .speedup_ratios [candidate .optimization_id ] = self .speedup_ratios .get (past_opt_id )
546- self .is_correct [candidate .optimization_id ] = self .is_correct .get (past_opt_id )
549+ self .is_correct [candidate .optimization_id ] = self .is_correct .get (past_opt_id , False )
547550 self .optimized_runtimes [candidate .optimization_id ] = self .optimized_runtimes .get (past_opt_id )
548551
549552 # Line profiler results only available for successful runs
@@ -631,7 +634,7 @@ class OriginalCodeBaseline(BaseModel):
631634 behavior_test_results : TestResults
632635 benchmarking_test_results : TestResults
633636 replay_benchmarking_test_results : Optional [dict [BenchmarkKey , TestResults ]] = None
634- line_profile_results : dict
637+ line_profile_results : dict [ str , Any ]
635638 runtime : int
636639 coverage_results : Optional [CoverageData ]
637640 async_throughput : Optional [int ] = None
@@ -793,7 +796,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]:
793796 f"// Testing function: { self .function_getting_tested } "
794797 )
795798
796- if self .test_class_name :
799+ if self .test_class_name and self . test_function_name :
797800 for stmt in module_node .body :
798801 if isinstance (stmt , cst .ClassDef ) and stmt .name .value == self .test_class_name :
799802 func_node = self .find_func_in_class (stmt , self .test_function_name )
@@ -884,7 +887,7 @@ def group_by_benchmarks(
884887 """Group TestResults by benchmark for calculating improvements for each benchmark."""
885888 from codeflash .code_utils .code_utils import module_name_from_file_path
886889
887- test_results_by_benchmark = defaultdict (TestResults )
890+ test_results_by_benchmark : defaultdict [ BenchmarkKey , TestResults ] = defaultdict (TestResults )
888891 benchmark_module_path = {}
889892 for benchmark_key in benchmark_keys :
890893 benchmark_module_path [benchmark_key ] = module_name_from_file_path (
@@ -1015,7 +1018,7 @@ def effective_loop_count(self) -> int:
10151018 return max (loop_indices ) if loop_indices else 0
10161019
10171020 def file_to_no_of_tests (self , test_functions_to_remove : list [str ]) -> Counter [Path ]:
1018- map_gen_test_file_to_no_of_tests = Counter ()
1021+ map_gen_test_file_to_no_of_tests : Counter [ Path ] = Counter ()
10191022 for gen_test_result in self .test_results :
10201023 if (
10211024 gen_test_result .test_type == TestType .GENERATED_REGRESSION
@@ -1024,8 +1027,8 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Pa
10241027 map_gen_test_file_to_no_of_tests [gen_test_result .file_name ] += 1
10251028 return map_gen_test_file_to_no_of_tests
10261029
1027- def __iter__ (self ) -> Iterator [ FunctionTestInvocation ]:
1028- return iter ( self .test_results )
1030+ def __iter__ (self ) -> Generator [ Any , None , None ]: # noqa: PYI058
1031+ yield from self .test_results
10291032
10301033 def __len__ (self ) -> int :
10311034 return len (self .test_results )
@@ -1051,7 +1054,7 @@ def __eq__(self, other: object) -> bool:
10511054 if len (self ) != len (other ):
10521055 return False
10531056 original_recursion_limit = sys .getrecursionlimit ()
1054- cast ( "TestResults" , other )
1057+ assert isinstance ( other , TestResults )
10551058 for test_result in self :
10561059 other_test_result = other .get_by_unique_invocation_loop_id (test_result .unique_invocation_loop_id )
10571060 if other_test_result is None :
0 commit comments