4141
4242 from codeflash .models .models import CodeOptimizationContext
4343 from codeflash .verification .verification_utils import TestConfig
44+ import contextlib
45+
4446from rich .text import Text
4547
4648_property_id = "property"
@@ -595,9 +597,10 @@ def get_all_replay_test_functions(
595597 except Exception as e :
596598 logger .warning (f"Error parsing replay test file { replay_test_file } : { e } " )
597599
598- if not trace_file_path :
600+ if trace_file_path is None :
599601 logger .error ("Could not find trace_file_path in replay test files." )
600602 exit_with_message ("Could not find trace_file_path in replay test files." )
603+ raise AssertionError ("Unreachable" ) # exit_with_message never returns
601604
602605 if not trace_file_path .exists ():
603606 logger .error (f"Trace file not found: { trace_file_path } " )
@@ -652,7 +655,7 @@ def get_all_replay_test_functions(
652655 if filtered_list :
653656 filtered_valid_functions [file_path ] = filtered_list
654657
655- return filtered_valid_functions , trace_file_path
658+ return dict ( filtered_valid_functions ) , trace_file_path
656659
657660
658661def is_git_repo (file_path : str ) -> bool :
@@ -664,11 +667,13 @@ def is_git_repo(file_path: str) -> bool:
664667
665668
666669@cache
667- def ignored_submodule_paths (module_root : str ) -> list [str ]:
670+ def ignored_submodule_paths (module_root : str ) -> list [Path ]:
668671 if is_git_repo (module_root ):
669672 git_repo = git .Repo (module_root , search_parent_directories = True )
670673 try :
671- return [Path (git_repo .working_tree_dir , submodule .path ).resolve () for submodule in git_repo .submodules ]
674+ working_dir = git_repo .working_tree_dir
675+ if working_dir is not None :
676+ return [Path (working_dir , submodule .path ).resolve () for submodule in git_repo .submodules ]
672677 except Exception as e :
673678 logger .warning (f"Error getting submodule paths: { e } " )
674679 return []
@@ -682,7 +687,7 @@ def __init__(
682687 self .class_name = class_name
683688 self .function_name = function_or_method_name
684689 self .is_top_level = False
685- self .function_has_args = None
690+ self .function_has_args : bool | None = None
686691 self .line_no = line_no
687692 self .is_staticmethod = False
688693 self .is_classmethod = False
@@ -796,31 +801,28 @@ def was_function_previously_optimized(
796801
797802 # Check optimization status if repository info is provided
798803 # already_optimized_count = 0
799- try :
804+
805+ # Check optimization status if repository info is provided
806+ # already_optimized_count = 0
807+ owner = None
808+ repo = None
809+ with contextlib .suppress (git .exc .InvalidGitRepositoryError ):
800810 owner , repo = get_repo_owner_and_name ()
801- except git .exc .InvalidGitRepositoryError :
802- logger .warning ("No git repository found" )
803- owner , repo = None , None
811+
804812 pr_number = get_pr_number ()
805813
806814 if not owner or not repo or pr_number is None or getattr (args , "no_pr" , False ):
807815 return False
808816
809- code_contexts = []
810-
811817 func_hash = code_context .hashing_code_context_hash
812- # Use a unique path identifier that includes function info
813818
814- code_contexts . append (
819+ code_contexts = [
815820 {
816- "file_path" : function_to_optimize .file_path ,
821+ "file_path" : str ( function_to_optimize .file_path ) ,
817822 "function_name" : function_to_optimize .qualified_name ,
818823 "code_hash" : func_hash ,
819824 }
820- )
821-
822- if not code_contexts :
823- return False
825+ ]
824826
825827 try :
826828 result = is_function_being_optimized_again (owner , repo , pr_number , code_contexts )
@@ -839,7 +841,7 @@ def filter_functions(
839841 ignore_paths : list [Path ],
840842 project_root : Path ,
841843 module_root : Path ,
842- previous_checkpoint_functions : dict [Path , dict [str , Any ]] | None = None ,
844+ previous_checkpoint_functions : dict [str , dict [str , Any ]] | None = None ,
843845 * ,
844846 disable_logs : bool = False ,
845847) -> tuple [dict [Path , list [FunctionToOptimize ]], int ]:
@@ -864,21 +866,49 @@ def filter_functions(
864866 # Normalize paths for case-insensitive comparison on Windows
865867 tests_root_str = os .path .normcase (str (tests_root ))
866868 module_root_str = os .path .normcase (str (module_root ))
869+ project_root_str = os .path .normcase (str (project_root ))
870+
871+ # Check if tests_root overlaps with module_root or project_root
872+ # In this case, we need to use file pattern matching instead of directory matching
873+ tests_root_overlaps_source = tests_root_str in (module_root_str , project_root_str ) or module_root_str .startswith (
874+ tests_root_str + os .sep
875+ )
876+
877+ # Test file patterns for when tests_root overlaps with source
878+ test_file_name_patterns = (".test." , ".spec." , "_test." , "_spec." )
879+ test_dir_patterns = (os .sep + "test" + os .sep , os .sep + "tests" + os .sep , os .sep + "__tests__" + os .sep )
880+
881+ def is_test_file (file_path_normalized : str ) -> bool :
882+ """Check if a file is a test file based on patterns."""
883+ if tests_root_overlaps_source :
884+ # Use file pattern matching when tests_root overlaps with source
885+ file_lower = file_path_normalized .lower ()
886+ # Check filename patterns (e.g., .test.ts, .spec.ts)
887+ if any (pattern in file_lower for pattern in test_file_name_patterns ):
888+ return True
889+ # Check directory patterns, but only within the project root
890+ # to avoid false positives from parent directories
891+ relative_path = file_lower
892+ if project_root_str and file_lower .startswith (project_root_str .lower ()):
893+ relative_path = file_lower [len (project_root_str ) :]
894+ return any (pattern in relative_path for pattern in test_dir_patterns )
895+ # Use directory-based filtering when tests are in a separate directory
896+ return file_path_normalized .startswith (tests_root_str + os .sep )
867897
868898 # We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
869899 for file_path_path , functions in modified_functions .items ():
870900 _functions = functions
871901 file_path = str (file_path_path )
872902 file_path_normalized = os .path .normcase (file_path )
873- if file_path_normalized . startswith ( tests_root_str + os . sep ):
903+ if is_test_file ( file_path_normalized ):
874904 test_functions_removed_count += len (_functions )
875905 continue
876- if file_path in ignore_paths or any (
906+ if file_path_path in ignore_paths or any (
877907 file_path_normalized .startswith (os .path .normcase (str (ignore_path )) + os .sep ) for ignore_path in ignore_paths
878908 ):
879909 ignore_paths_removed_count += 1
880910 continue
881- if file_path in submodule_paths or any (
911+ if file_path_path in submodule_paths or any (
882912 file_path_normalized .startswith (os .path .normcase (str (submodule_path )) + os .sep )
883913 for submodule_path in submodule_paths
884914 ):
@@ -970,7 +1000,7 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
9701000
9711001def function_has_return_statement (function_node : FunctionDef | AsyncFunctionDef ) -> bool :
9721002 # Custom DFS, return True as soon as a Return node is found
973- stack = [function_node ]
1003+ stack : list [ ast . AST ] = [function_node ]
9741004 while stack :
9751005 node = stack .pop ()
9761006 if isinstance (node , ast .Return ):
0 commit comments