99from functools import lru_cache
1010from pathlib import Path
1111from re import Pattern
12- from typing import TYPE_CHECKING , Any , NamedTuple , Optional , cast
12+ from typing import TYPE_CHECKING , Any , NamedTuple , Optional
1313
1414from pydantic import BaseModel , ConfigDict , Field , PrivateAttr , ValidationError , model_validator
1515from pydantic .dataclasses import dataclass
1616
1717from codeflash .models .test_type import TestType
1818
1919if TYPE_CHECKING :
20- from collections .abc import Iterator
20+ from collections .abc import Generator
2121
2222 import libcst as cst
2323 from rich .tree import Tree
@@ -298,11 +298,13 @@ def flat(self) -> str:
298298
299299 """
300300 if self ._cache .get ("flat" ) is not None :
301- return self ._cache ["flat" ]
302- self ._cache ["flat" ] = "\n " .join (
301+ result : str = self ._cache ["flat" ]
302+ return result
303+ flat : str = "\n " .join (
303304 get_code_block_splitter (block .file_path ) + "\n " + block .code for block in self .code_strings
304305 )
305- return self ._cache ["flat" ]
306+ self ._cache ["flat" ] = flat
307+ return flat
306308
307309 @property
308310 def markdown (self ) -> str :
@@ -332,7 +334,8 @@ def file_to_path(self) -> dict[str, str]:
332334
333335 """
334336 try :
335- return self ._cache ["file_to_path" ]
337+ cached : dict [str , str ] = self ._cache ["file_to_path" ]
338+ return cached
336339 except KeyError :
337340 mapping = {str (code_string .file_path ): code_string .code for code_string in self .code_strings }
338341 self ._cache ["file_to_path" ] = mapping
@@ -497,8 +500,8 @@ def _normalize_path_for_comparison(path: Path) -> str:
497500 # Only lowercase on Windows where filesystem is case-insensitive
498501 return resolved .lower () if sys .platform == "win32" else resolved
499502
500- def __iter__ (self ) -> Iterator [ TestFile ]:
501- return iter ( self .test_files )
503+ def __iter__ (self ) -> Generator [ Any , None , None ]: # noqa: PYI058
504+ yield from self .test_files
502505
503506 def __len__ (self ) -> int :
504507 return len (self .test_files )
@@ -517,9 +520,9 @@ class CandidateEvaluationContext:
517520 optimized_runtimes : dict [str , float | None ] = Field (default_factory = dict )
518521 is_correct : dict [str , bool ] = Field (default_factory = dict )
519522 optimized_line_profiler_results : dict [str , str ] = Field (default_factory = dict )
520- ast_code_to_id : dict = Field (default_factory = dict )
523+ ast_code_to_id : dict [ str , Any ] = Field (default_factory = dict )
521524 optimizations_post : dict [str , str ] = Field (default_factory = dict )
522- valid_optimizations : list = Field (default_factory = list )
525+ valid_optimizations : list [ Any ] = Field (default_factory = list )
523526
524527 def record_failed_candidate (self , optimization_id : str ) -> None :
525528 """Record results for a failed candidate."""
@@ -546,7 +549,7 @@ def handle_duplicate_candidate(
546549 # Copy results from the previous evaluation (use .get() in case past_opt_id was registered
547550 # but never benchmarked due to an unhandled exception in process_single_candidate)
548551 self .speedup_ratios [candidate .optimization_id ] = self .speedup_ratios .get (past_opt_id )
549- self .is_correct [candidate .optimization_id ] = self .is_correct .get (past_opt_id )
552+ self .is_correct [candidate .optimization_id ] = self .is_correct .get (past_opt_id , False )
550553 self .optimized_runtimes [candidate .optimization_id ] = self .optimized_runtimes .get (past_opt_id )
551554
552555 # Line profiler results only available for successful runs
@@ -634,7 +637,7 @@ class OriginalCodeBaseline(BaseModel):
634637 behavior_test_results : TestResults
635638 benchmarking_test_results : TestResults
636639 replay_benchmarking_test_results : Optional [dict [BenchmarkKey , TestResults ]] = None
637- line_profile_results : dict
640+ line_profile_results : dict [ str , Any ]
638641 runtime : int
639642 coverage_results : Optional [CoverageData ]
640643 async_throughput : Optional [int ] = None
@@ -796,7 +799,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]:
796799 f"// Testing function: { self .function_getting_tested } "
797800 )
798801
799- if self .test_class_name :
802+ if self .test_class_name and self . test_function_name :
800803 for stmt in module_node .body :
801804 if isinstance (stmt , cst .ClassDef ) and stmt .name .value == self .test_class_name :
802805 func_node = self .find_func_in_class (stmt , self .test_function_name )
@@ -887,7 +890,7 @@ def group_by_benchmarks(
887890 """Group TestResults by benchmark for calculating improvements for each benchmark."""
888891 from codeflash .code_utils .code_utils import module_name_from_file_path
889892
890- test_results_by_benchmark = defaultdict (TestResults )
893+ test_results_by_benchmark : defaultdict [ BenchmarkKey , TestResults ] = defaultdict (TestResults )
891894 benchmark_module_path = {}
892895 for benchmark_key in benchmark_keys :
893896 benchmark_module_path [benchmark_key ] = module_name_from_file_path (
@@ -1018,7 +1021,7 @@ def effective_loop_count(self) -> int:
10181021 return max (loop_indices ) if loop_indices else 0
10191022
10201023 def file_to_no_of_tests (self , test_functions_to_remove : list [str ]) -> Counter [Path ]:
1021- map_gen_test_file_to_no_of_tests = Counter ()
1024+ map_gen_test_file_to_no_of_tests : Counter [ Path ] = Counter ()
10221025 for gen_test_result in self .test_results :
10231026 if (
10241027 gen_test_result .test_type == TestType .GENERATED_REGRESSION
@@ -1027,8 +1030,8 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Pa
10271030 map_gen_test_file_to_no_of_tests [gen_test_result .file_name ] += 1
10281031 return map_gen_test_file_to_no_of_tests
10291032
1030- def __iter__ (self ) -> Iterator [ FunctionTestInvocation ]:
1031- return iter ( self .test_results )
1033+ def __iter__ (self ) -> Generator [ Any , None , None ]: # noqa: PYI058
1034+ yield from self .test_results
10321035
10331036 def __len__ (self ) -> int :
10341037 return len (self .test_results )
@@ -1054,7 +1057,7 @@ def __eq__(self, other: object) -> bool:
10541057 if len (self ) != len (other ):
10551058 return False
10561059 original_recursion_limit = sys .getrecursionlimit ()
1057- cast ( "TestResults" , other )
1060+ assert isinstance ( other , TestResults )
10581061 for test_result in self :
10591062 other_test_result = other .get_by_unique_invocation_loop_id (test_result .unique_invocation_loop_id )
10601063 if other_test_result is None :
0 commit comments