1414from pydantic import BaseModel , ConfigDict , Field , PrivateAttr , ValidationError , model_validator
1515from pydantic .dataclasses import dataclass
1616
17+ from codeflash .models .shared_types import OptimizedCandidateSource
1718from codeflash .models .test_type import TestType
1819
1920if TYPE_CHECKING :
@@ -50,6 +51,23 @@ class AIServiceRefinerRequest:
5051 additional_context_files : dict [str , str ] | None = None # {filepath: content} for imported modules
5152
5253
@dataclass(frozen=True)
class AIServiceBatchRefinerCandidate:
    """Immutable record describing a single optimization candidate.

    Every field is required; the class declares no defaults and no methods.
    NOTE(review): presumably the per-candidate payload of a batch refiner
    request -- confirm against the code that builds these objects.
    """

    # Identifier of the optimization this candidate belongs to.
    optimization_id: str
    # Source text of the optimized code.
    optimized_source_code: str
    # Explanation text accompanying the optimized code.
    optimized_explanation: str
    # Integer runtime of the optimized code; the unit is not visible here -- TODO confirm.
    optimized_code_runtime: int
    # Integer runtime of the original code; presumably the same unit as above -- TODO confirm.
    original_code_runtime: int
    # Speedup carried as a string; the exact format is not visible here -- TODO confirm.
    speedup: str
    # Line-profiler output for the optimized code, carried as text.
    optimized_line_profiler_results: str
63+
64+
@dataclass(frozen=True)
class AIServiceBatchRefinerRequest:
    """Immutable request payload: one shared context mapping plus a list of candidate mappings.

    Both fields are required; the class declares no defaults and no methods.
    """

    # Free-form string-keyed mapping; presumably data common to every candidate -- TODO confirm.
    shared_context: dict[str, Any]
    # One free-form string-keyed mapping per candidate.
    # NOTE(review): typed as plain dicts rather than AIServiceBatchRefinerCandidate;
    # presumably these are already-serialized candidates -- verify against the caller.
    candidates: list[dict[str, Any]]
69+
70+
5371# this should be possible to auto serialize
5472@dataclass (frozen = True )
5573class AdaptiveOptimizedCandidate :
@@ -298,11 +316,11 @@ def flat(self) -> str:
298316
299317 """
300318 if self ._cache .get ("flat" ) is not None :
301- return self ._cache ["flat" ]
319+ return cast ( "str" , self ._cache ["flat" ])
302320 self ._cache ["flat" ] = "\n " .join (
303321 get_code_block_splitter (block .file_path ) + "\n " + block .code for block in self .code_strings
304322 )
305- return self ._cache ["flat" ]
323+ return cast ( "str" , self ._cache ["flat" ])
306324
307325 @property
308326 def markdown (self ) -> str :
@@ -332,7 +350,7 @@ def file_to_path(self) -> dict[str, str]:
332350
333351 """
334352 try :
335- return self ._cache ["file_to_path" ]
353+ return cast ( "dict[str, str]" , self ._cache ["file_to_path" ])
336354 except KeyError :
337355 mapping = {str (code_string .file_path ): code_string .code for code_string in self .code_strings }
338356 self ._cache ["file_to_path" ] = mapping
@@ -494,7 +512,7 @@ def _normalize_path_for_comparison(path: Path) -> str:
494512 # Only lowercase on Windows where filesystem is case-insensitive
495513 return resolved .lower () if sys .platform == "win32" else resolved
496514
497- def __iter__ (self ) -> Iterator [TestFile ]:
def __iter__(self) -> Iterator[TestFile]:  # type: ignore[override]
    """Return an iterator over ``self.test_files``.

    NOTE(review): the ``type: ignore[override]`` presumably silences a
    signature mismatch with a base-class ``__iter__`` (the class header is
    not visible here) -- confirm before removing it.
    """
    return iter(self.test_files)
499517
500518 def __len__ (self ) -> int :
@@ -514,9 +532,9 @@ class CandidateEvaluationContext:
514532 optimized_runtimes : dict [str , float | None ] = Field (default_factory = dict )
515533 is_correct : dict [str , bool ] = Field (default_factory = dict )
516534 optimized_line_profiler_results : dict [str , str ] = Field (default_factory = dict )
517- ast_code_to_id : dict = Field (default_factory = dict )
535+ ast_code_to_id : dict [ str , Any ] = Field (default_factory = dict )
518536 optimizations_post : dict [str , str ] = Field (default_factory = dict )
519- valid_optimizations : list = Field (default_factory = list )
537+ valid_optimizations : list [ Any ] = Field (default_factory = list )
520538
521539 def record_failed_candidate (self , optimization_id : str ) -> None :
522540 """Record results for a failed candidate."""
@@ -543,7 +561,7 @@ def handle_duplicate_candidate(
543561 # Copy results from the previous evaluation (use .get() in case past_opt_id was registered
544562 # but never benchmarked due to an unhandled exception in process_single_candidate)
545563 self .speedup_ratios [candidate .optimization_id ] = self .speedup_ratios .get (past_opt_id )
546- self .is_correct [candidate .optimization_id ] = self .is_correct .get (past_opt_id )
564+ self .is_correct [candidate .optimization_id ] = self .is_correct .get (past_opt_id , False )
547565 self .optimized_runtimes [candidate .optimization_id ] = self .optimized_runtimes .get (past_opt_id )
548566
549567 # Line profiler results only available for successful runs
@@ -592,15 +610,6 @@ class TestsInFile:
592610 test_type : TestType
593611
594612
595- class OptimizedCandidateSource (str , Enum ):
596- OPTIMIZE = "OPTIMIZE"
597- OPTIMIZE_LP = "OPTIMIZE_LP"
598- REFINE = "REFINE"
599- REPAIR = "REPAIR"
600- ADAPTIVE = "ADAPTIVE"
601- JIT_REWRITE = "JIT_REWRITE"
602-
603-
604613@dataclass (frozen = True )
605614class OptimizedCandidate :
606615 source_code : CodeStringsMarkdown
@@ -631,7 +640,7 @@ class OriginalCodeBaseline(BaseModel):
631640 behavior_test_results : TestResults
632641 benchmarking_test_results : TestResults
633642 replay_benchmarking_test_results : Optional [dict [BenchmarkKey , TestResults ]] = None
634- line_profile_results : dict
643+ line_profile_results : dict [ str , Any ]
635644 runtime : int
636645 coverage_results : Optional [CoverageData ]
637646 async_throughput : Optional [int ] = None
@@ -794,6 +803,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]:
794803 )
795804
796805 if self .test_class_name :
806+ assert self .test_function_name is not None
797807 for stmt in module_node .body :
798808 if isinstance (stmt , cst .ClassDef ) and stmt .name .value == self .test_class_name :
799809 func_node = self .find_func_in_class (stmt , self .test_function_name )
@@ -884,7 +894,7 @@ def group_by_benchmarks(
884894 """Group TestResults by benchmark for calculating improvements for each benchmark."""
885895 from codeflash .code_utils .code_utils import module_name_from_file_path
886896
887- test_results_by_benchmark = defaultdict (TestResults )
897+ test_results_by_benchmark : defaultdict [ BenchmarkKey , TestResults ] = defaultdict (TestResults )
888898 benchmark_module_path = {}
889899 for benchmark_key in benchmark_keys :
890900 benchmark_module_path [benchmark_key ] = module_name_from_file_path (
@@ -1015,7 +1025,7 @@ def effective_loop_count(self) -> int:
10151025 return max (loop_indices ) if loop_indices else 0
10161026
10171027 def file_to_no_of_tests (self , test_functions_to_remove : list [str ]) -> Counter [Path ]:
1018- map_gen_test_file_to_no_of_tests = Counter ()
1028+ map_gen_test_file_to_no_of_tests : Counter [ Path ] = Counter ()
10191029 for gen_test_result in self .test_results :
10201030 if (
10211031 gen_test_result .test_type == TestType .GENERATED_REGRESSION
@@ -1024,7 +1034,7 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Pa
10241034 map_gen_test_file_to_no_of_tests [gen_test_result .file_name ] += 1
10251035 return map_gen_test_file_to_no_of_tests
10261036
1027- def __iter__ (self ) -> Iterator [FunctionTestInvocation ]:
def __iter__(self) -> Iterator[FunctionTestInvocation]:  # type: ignore[override]
    """Return an iterator over ``self.test_results``.

    NOTE(review): the ``type: ignore[override]`` presumably silences a
    signature mismatch with a base-class ``__iter__`` (the class header is
    not visible here) -- confirm before removing it.
    """
    return iter(self.test_results)
10291039
10301040 def __len__ (self ) -> int :
@@ -1051,7 +1061,7 @@ def __eq__(self, other: object) -> bool:
10511061 if len (self ) != len (other ):
10521062 return False
10531063 original_recursion_limit = sys .getrecursionlimit ()
1054- cast ( "TestResults" , other )
1064+ assert isinstance ( other , TestResults )
10551065 for test_result in self :
10561066 other_test_result = other .get_by_unique_invocation_loop_id (test_result .unique_invocation_loop_id )
10571067 if other_test_result is None :
0 commit comments