4141
4242 from codeflash .models .models import CodeOptimizationContext
4343 from codeflash .verification .verification_utils import TestConfig
44+ import contextlib
45+
4446from rich .text import Text
4547
4648_property_id = "property"
@@ -595,9 +597,10 @@ def get_all_replay_test_functions(
595597 except Exception as e :
596598 logger .warning (f"Error parsing replay test file { replay_test_file } : { e } " )
597599
598- if not trace_file_path :
600+ if trace_file_path is None :
599601 logger .error ("Could not find trace_file_path in replay test files." )
600602 exit_with_message ("Could not find trace_file_path in replay test files." )
603+ raise AssertionError ("Unreachable" ) # exit_with_message never returns
601604
602605 if not trace_file_path .exists ():
603606 logger .error (f"Trace file not found: { trace_file_path } " )
@@ -652,7 +655,7 @@ def get_all_replay_test_functions(
652655 if filtered_list :
653656 filtered_valid_functions [file_path ] = filtered_list
654657
655- return filtered_valid_functions , trace_file_path
658+ return dict ( filtered_valid_functions ) , trace_file_path
656659
657660
658661def is_git_repo (file_path : str ) -> bool :
@@ -664,11 +667,13 @@ def is_git_repo(file_path: str) -> bool:
664667
665668
666669@cache
667- def ignored_submodule_paths (module_root : str ) -> list [str ]:
670+ def ignored_submodule_paths (module_root : str ) -> list [Path ]:
668671 if is_git_repo (module_root ):
669672 git_repo = git .Repo (module_root , search_parent_directories = True )
670673 try :
671- return [Path (git_repo .working_tree_dir , submodule .path ).resolve () for submodule in git_repo .submodules ]
674+ working_dir = git_repo .working_tree_dir
675+ if working_dir is not None :
676+ return [Path (working_dir , submodule .path ).resolve () for submodule in git_repo .submodules ]
672677 except Exception as e :
673678 logger .warning (f"Error getting submodule paths: { e } " )
674679 return []
@@ -682,7 +687,7 @@ def __init__(
682687 self .class_name = class_name
683688 self .function_name = function_or_method_name
684689 self .is_top_level = False
685- self .function_has_args = None
690+ self .function_has_args : bool | None = None
686691 self .line_no = line_no
687692 self .is_staticmethod = False
688693 self .is_classmethod = False
@@ -796,31 +801,28 @@ def was_function_previously_optimized(
796801
797802 # Check optimization status if repository info is provided
798803 # already_optimized_count = 0
799- try :
804+
805+ # Check optimization status if repository info is provided
806+ # already_optimized_count = 0
807+ owner = None
808+ repo = None
809+ with contextlib .suppress (git .exc .InvalidGitRepositoryError ):
800810 owner , repo = get_repo_owner_and_name ()
801- except git .exc .InvalidGitRepositoryError :
802- logger .warning ("No git repository found" )
803- owner , repo = None , None
811+
804812 pr_number = get_pr_number ()
805813
806814 if not owner or not repo or pr_number is None or getattr (args , "no_pr" , False ):
807815 return False
808816
809- code_contexts = []
810-
811817 func_hash = code_context .hashing_code_context_hash
812- # Use a unique path identifier that includes function info
813818
814- code_contexts . append (
819+ code_contexts = [
815820 {
816- "file_path" : function_to_optimize .file_path ,
821+ "file_path" : str ( function_to_optimize .file_path ) ,
817822 "function_name" : function_to_optimize .qualified_name ,
818823 "code_hash" : func_hash ,
819824 }
820- )
821-
822- if not code_contexts :
823- return False
825+ ]
824826
825827 try :
826828 result = is_function_being_optimized_again (owner , repo , pr_number , code_contexts )
@@ -839,7 +841,7 @@ def filter_functions(
839841 ignore_paths : list [Path ],
840842 project_root : Path ,
841843 module_root : Path ,
842- previous_checkpoint_functions : dict [Path , dict [str , Any ]] | None = None ,
844+ previous_checkpoint_functions : dict [str , dict [str , Any ]] | None = None ,
843845 * ,
844846 disable_logs : bool = False ,
845847) -> tuple [dict [Path , list [FunctionToOptimize ]], int ]:
@@ -868,24 +870,13 @@ def filter_functions(
868870
869871 # Check if tests_root overlaps with module_root or project_root
870872 # In this case, we need to use file pattern matching instead of directory matching
871- tests_root_overlaps_source = (
872- tests_root_str == module_root_str
873- or tests_root_str == project_root_str
874- or module_root_str .startswith (tests_root_str + os .sep )
873+ tests_root_overlaps_source = tests_root_str in (module_root_str , project_root_str ) or module_root_str .startswith (
874+ tests_root_str + os .sep
875875 )
876876
877877 # Test file patterns for when tests_root overlaps with source
878- test_file_name_patterns = (
879- ".test." ,
880- ".spec." ,
881- "_test." ,
882- "_spec." ,
883- )
884- test_dir_patterns = (
885- os .sep + "test" + os .sep ,
886- os .sep + "tests" + os .sep ,
887- os .sep + "__tests__" + os .sep ,
888- )
878+ test_file_name_patterns = (".test." , ".spec." , "_test." , "_spec." )
879+ test_dir_patterns = (os .sep + "test" + os .sep , os .sep + "tests" + os .sep , os .sep + "__tests__" + os .sep )
889880
890881 def is_test_file (file_path_normalized : str ) -> bool :
891882 """Check if a file is a test file based on patterns."""
@@ -899,11 +890,10 @@ def is_test_file(file_path_normalized: str) -> bool:
899890 # to avoid false positives from parent directories
900891 relative_path = file_lower
901892 if project_root_str and file_lower .startswith (project_root_str .lower ()):
902- relative_path = file_lower [len (project_root_str ):]
893+ relative_path = file_lower [len (project_root_str ) :]
903894 return any (pattern in relative_path for pattern in test_dir_patterns )
904- else :
905- # Use directory-based filtering when tests are in a separate directory
906- return file_path_normalized .startswith (tests_root_str + os .sep )
895+ # Use directory-based filtering when tests are in a separate directory
896+ return file_path_normalized .startswith (tests_root_str + os .sep )
907897
908898 # We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
909899 for file_path_path , functions in modified_functions .items ():
@@ -913,12 +903,12 @@ def is_test_file(file_path_normalized: str) -> bool:
913903 if is_test_file (file_path_normalized ):
914904 test_functions_removed_count += len (_functions )
915905 continue
916- if file_path in ignore_paths or any (
906+ if file_path_path in ignore_paths or any (
917907 file_path_normalized .startswith (os .path .normcase (str (ignore_path )) + os .sep ) for ignore_path in ignore_paths
918908 ):
919909 ignore_paths_removed_count += 1
920910 continue
921- if file_path in submodule_paths or any (
911+ if file_path_path in submodule_paths or any (
922912 file_path_normalized .startswith (os .path .normcase (str (submodule_path )) + os .sep )
923913 for submodule_path in submodule_paths
924914 ):
@@ -1010,7 +1000,7 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
10101000
10111001def function_has_return_statement (function_node : FunctionDef | AsyncFunctionDef ) -> bool :
10121002 # Custom DFS, return True as soon as a Return node is found
1013- stack = [function_node ]
1003+ stack : list [ ast . AST ] = [function_node ]
10141004 while stack :
10151005 node = stack .pop ()
10161006 if isinstance (node , ast .Return ):
0 commit comments