Skip to content

Commit 7b72148

Browse files
committed
some issues
1 parent 328c837 commit 7b72148

1 file changed

Lines changed: 16 additions & 15 deletions

File tree

codeflash/discovery/functions_to_optimize.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -595,9 +595,10 @@ def get_all_replay_test_functions(
595595
except Exception as e:
596596
logger.warning(f"Error parsing replay test file {replay_test_file}: {e}")
597597

598-
if not trace_file_path:
598+
if trace_file_path is None:
599599
logger.error("Could not find trace_file_path in replay test files.")
600600
exit_with_message("Could not find trace_file_path in replay test files.")
601+
raise AssertionError("Unreachable") # exit_with_message never returns
601602

602603
if not trace_file_path.exists():
603604
logger.error(f"Trace file not found: {trace_file_path}")
@@ -652,7 +653,7 @@ def get_all_replay_test_functions(
652653
if filtered_list:
653654
filtered_valid_functions[file_path] = filtered_list
654655

655-
return filtered_valid_functions, trace_file_path
656+
return dict(filtered_valid_functions), trace_file_path
656657

657658

658659
def is_git_repo(file_path: str) -> bool:
@@ -664,11 +665,13 @@ def is_git_repo(file_path: str) -> bool:
664665

665666

666667
@cache
667-
def ignored_submodule_paths(module_root: str) -> list[str]:
668+
def ignored_submodule_paths(module_root: str) -> list[Path]:
668669
if is_git_repo(module_root):
669670
git_repo = git.Repo(module_root, search_parent_directories=True)
670671
try:
671-
return [Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules]
672+
working_dir = git_repo.working_tree_dir
673+
if working_dir is not None:
674+
return [Path(working_dir, submodule.path).resolve() for submodule in git_repo.submodules]
672675
except Exception as e:
673676
logger.warning(f"Error getting submodule paths: {e}")
674677
return []
@@ -682,7 +685,7 @@ def __init__(
682685
self.class_name = class_name
683686
self.function_name = function_or_method_name
684687
self.is_top_level = False
685-
self.function_has_args = None
688+
self.function_has_args: bool | None = None
686689
self.line_no = line_no
687690
self.is_staticmethod = False
688691
self.is_classmethod = False
@@ -806,14 +809,14 @@ def was_function_previously_optimized(
806809
if not owner or not repo or pr_number is None or getattr(args, "no_pr", False):
807810
return False
808811

809-
code_contexts = []
812+
code_contexts: list[dict[str, str]] = []
810813

811814
func_hash = code_context.hashing_code_context_hash
812815
# Use a unique path identifier that includes function info
813816

814817
code_contexts.append(
815818
{
816-
"file_path": function_to_optimize.file_path,
819+
"file_path": str(function_to_optimize.file_path),
817820
"function_name": function_to_optimize.qualified_name,
818821
"code_hash": func_hash,
819822
}
@@ -839,7 +842,7 @@ def filter_functions(
839842
ignore_paths: list[Path],
840843
project_root: Path,
841844
module_root: Path,
842-
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
845+
previous_checkpoint_functions: dict[str, dict[str, Any]] | None = None,
843846
*,
844847
disable_logs: bool = False,
845848
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
@@ -868,10 +871,8 @@ def filter_functions(
868871

869872
# Check if tests_root overlaps with module_root or project_root
870873
# In this case, we need to use file pattern matching instead of directory matching
871-
tests_root_overlaps_source = (
872-
tests_root_str == module_root_str
873-
or tests_root_str == project_root_str
874-
or module_root_str.startswith(tests_root_str + os.sep)
874+
tests_root_overlaps_source = tests_root_str in (module_root_str, project_root_str) or module_root_str.startswith(
875+
tests_root_str + os.sep
875876
)
876877

877878
# Test file patterns for when tests_root overlaps with source
@@ -903,12 +904,12 @@ def is_test_file(file_path_normalized: str) -> bool:
903904
if is_test_file(file_path_normalized):
904905
test_functions_removed_count += len(_functions)
905906
continue
906-
if file_path in ignore_paths or any(
907+
if file_path_path in ignore_paths or any(
907908
file_path_normalized.startswith(os.path.normcase(str(ignore_path)) + os.sep) for ignore_path in ignore_paths
908909
):
909910
ignore_paths_removed_count += 1
910911
continue
911-
if file_path in submodule_paths or any(
912+
if file_path_path in submodule_paths or any(
912913
file_path_normalized.startswith(os.path.normcase(str(submodule_path)) + os.sep)
913914
for submodule_path in submodule_paths
914915
):
@@ -1000,7 +1001,7 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
10001001

10011002
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
10021003
# Custom DFS, return True as soon as a Return node is found
1003-
stack = [function_node]
1004+
stack: list[ast.AST] = [function_node]
10041005
while stack:
10051006
node = stack.pop()
10061007
if isinstance(node, ast.Return):

0 commit comments

Comments
 (0)