@@ -595,9 +595,10 @@ def get_all_replay_test_functions(
595595 except Exception as e :
596596 logger .warning (f"Error parsing replay test file { replay_test_file } : { e } " )
597597
598- if not trace_file_path :
598+ if trace_file_path is None :
599599 logger .error ("Could not find trace_file_path in replay test files." )
600600 exit_with_message ("Could not find trace_file_path in replay test files." )
601+ raise AssertionError ("Unreachable" ) # exit_with_message never returns
601602
602603 if not trace_file_path .exists ():
603604 logger .error (f"Trace file not found: { trace_file_path } " )
@@ -652,7 +653,7 @@ def get_all_replay_test_functions(
652653 if filtered_list :
653654 filtered_valid_functions [file_path ] = filtered_list
654655
655- return filtered_valid_functions , trace_file_path
656+ return dict ( filtered_valid_functions ) , trace_file_path
656657
657658
658659def is_git_repo (file_path : str ) -> bool :
@@ -664,11 +665,13 @@ def is_git_repo(file_path: str) -> bool:
664665
665666
666667@cache
667- def ignored_submodule_paths (module_root : str ) -> list [str ]:
668+ def ignored_submodule_paths (module_root : str ) -> list [Path ]:
668669 if is_git_repo (module_root ):
669670 git_repo = git .Repo (module_root , search_parent_directories = True )
670671 try :
671- return [Path (git_repo .working_tree_dir , submodule .path ).resolve () for submodule in git_repo .submodules ]
672+ working_dir = git_repo .working_tree_dir
673+ if working_dir is not None :
674+ return [Path (working_dir , submodule .path ).resolve () for submodule in git_repo .submodules ]
672675 except Exception as e :
673676 logger .warning (f"Error getting submodule paths: { e } " )
674677 return []
@@ -682,7 +685,7 @@ def __init__(
682685 self .class_name = class_name
683686 self .function_name = function_or_method_name
684687 self .is_top_level = False
685- self .function_has_args = None
688+ self .function_has_args : bool | None = None
686689 self .line_no = line_no
687690 self .is_staticmethod = False
688691 self .is_classmethod = False
@@ -806,14 +809,14 @@ def was_function_previously_optimized(
806809 if not owner or not repo or pr_number is None or getattr (args , "no_pr" , False ):
807810 return False
808811
809- code_contexts = []
812+ code_contexts : list [ dict [ str , str ]] = []
810813
811814 func_hash = code_context .hashing_code_context_hash
812815 # Use a unique path identifier that includes function info
813816
814817 code_contexts .append (
815818 {
816- "file_path" : function_to_optimize .file_path ,
819+ "file_path" : str ( function_to_optimize .file_path ) ,
817820 "function_name" : function_to_optimize .qualified_name ,
818821 "code_hash" : func_hash ,
819822 }
@@ -839,7 +842,7 @@ def filter_functions(
839842 ignore_paths : list [Path ],
840843 project_root : Path ,
841844 module_root : Path ,
842- previous_checkpoint_functions : dict [Path , dict [str , Any ]] | None = None ,
845+ previous_checkpoint_functions : dict [str , dict [str , Any ]] | None = None ,
843846 * ,
844847 disable_logs : bool = False ,
845848) -> tuple [dict [Path , list [FunctionToOptimize ]], int ]:
@@ -868,10 +871,8 @@ def filter_functions(
868871
869872 # Check if tests_root overlaps with module_root or project_root
870873 # In this case, we need to use file pattern matching instead of directory matching
871- tests_root_overlaps_source = (
872- tests_root_str == module_root_str
873- or tests_root_str == project_root_str
874- or module_root_str .startswith (tests_root_str + os .sep )
874+ tests_root_overlaps_source = tests_root_str in (module_root_str , project_root_str ) or module_root_str .startswith (
875+ tests_root_str + os .sep
875876 )
876877
877878 # Test file patterns for when tests_root overlaps with source
@@ -903,12 +904,12 @@ def is_test_file(file_path_normalized: str) -> bool:
903904 if is_test_file (file_path_normalized ):
904905 test_functions_removed_count += len (_functions )
905906 continue
906- if file_path in ignore_paths or any (
907+ if file_path_path in ignore_paths or any (
907908 file_path_normalized .startswith (os .path .normcase (str (ignore_path )) + os .sep ) for ignore_path in ignore_paths
908909 ):
909910 ignore_paths_removed_count += 1
910911 continue
911- if file_path in submodule_paths or any (
912+ if file_path_path in submodule_paths or any (
912913 file_path_normalized .startswith (os .path .normcase (str (submodule_path )) + os .sep )
913914 for submodule_path in submodule_paths
914915 ):
@@ -1000,7 +1001,7 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
10001001
10011002def function_has_return_statement (function_node : FunctionDef | AsyncFunctionDef ) -> bool :
10021003 # Custom DFS, return True as soon as a Return node is found
1003- stack = [function_node ]
1004+ stack : list [ ast . AST ] = [function_node ]
10041005 while stack :
10051006 node = stack .pop ()
10061007 if isinstance (node , ast .Return ):
0 commit comments