@@ -68,9 +68,9 @@ def visit_Return(self, node: cst.Return) -> None:
6868class FunctionVisitor (cst .CSTVisitor ):
6969 METADATA_DEPENDENCIES = (cst .metadata .PositionProvider , cst .metadata .ParentNodeProvider )
7070
71- def __init__ (self , file_path : str ) -> None :
71+ def __init__ (self , file_path : Path ) -> None :
7272 super ().__init__ ()
73- self .file_path : str = file_path
73+ self .file_path : Path = file_path
7474 self .functions : list [FunctionToOptimize ] = []
7575
7676 @staticmethod
@@ -264,7 +264,7 @@ def get_functions_to_optimize(
264264 assert sum ([bool (optimize_all ), bool (replay_test ), bool (file )]) <= 1 , (
265265 "Only one of optimize_all, replay_test, or file should be provided"
266266 )
267- functions : dict [str , list [FunctionToOptimize ]]
267+ functions : dict [Path , list [FunctionToOptimize ]]
268268 trace_file_path : Path | None = None
269269 is_lsp = is_LSP_enabled ()
270270 with warnings .catch_warnings ():
@@ -281,7 +281,7 @@ def get_functions_to_optimize(
281281 logger .info ("!lsp|Finding all functions in the file '%s'…" , file )
282282 console .rule ()
283283 file = Path (file ) if isinstance (file , str ) else file
284- functions : dict [ Path , list [ FunctionToOptimize ]] = find_all_functions_in_file (file )
284+ functions = find_all_functions_in_file (file )
285285 if only_get_this_function is not None :
286286 split_function = only_get_this_function .split ("." )
287287 if len (split_function ) > 2 :
@@ -316,6 +316,7 @@ def get_functions_to_optimize(
316316 f"Function { only_get_this_function } not found in file { file } \n or the function does not have a 'return' statement or is a property"
317317 )
318318
319+ assert found_function is not None
319320 # For JavaScript/TypeScript, verify that the function (or its parent class) is exported
320321 # Non-exported functions cannot be imported by tests
321322 if found_function .language in ("javascript" , "typescript" ):
@@ -359,7 +360,7 @@ def get_functions_to_optimize(
359360 return filtered_modified_functions , functions_count , trace_file_path
360361
361362
362- def get_functions_within_git_diff (uncommitted_changes : bool ) -> dict [str , list [FunctionToOptimize ]]:
363+ def get_functions_within_git_diff (uncommitted_changes : bool ) -> dict [Path , list [FunctionToOptimize ]]:
363364 modified_lines : dict [str , list [int ]] = get_git_diff (uncommitted_changes = uncommitted_changes )
364365 return get_functions_within_lines (modified_lines )
365366
@@ -400,7 +401,7 @@ def closest_matching_file_function_name(
400401 closest_match = function
401402 closest_file = file_path
402403
403- if closest_match is not None :
404+ if closest_match is not None and closest_file is not None :
404405 return closest_file , closest_match
405406 return None
406407
@@ -434,13 +435,13 @@ def levenshtein_distance(s1: str, s2: str) -> int:
434435 return previous [len1 ]
435436
436437
437- def get_functions_inside_a_commit (commit_hash : str ) -> dict [str , list [FunctionToOptimize ]]:
438+ def get_functions_inside_a_commit (commit_hash : str ) -> dict [Path , list [FunctionToOptimize ]]:
438439 modified_lines : dict [str , list [int ]] = get_git_diff (only_this_commit = commit_hash )
439440 return get_functions_within_lines (modified_lines )
440441
441442
442- def get_functions_within_lines (modified_lines : dict [str , list [int ]]) -> dict [str , list [FunctionToOptimize ]]:
443- functions : dict [str , list [FunctionToOptimize ]] = {}
443+ def get_functions_within_lines (modified_lines : dict [str , list [int ]]) -> dict [Path , list [FunctionToOptimize ]]:
444+ functions : dict [Path , list [FunctionToOptimize ]] = {}
444445 for path_str , lines_in_file in modified_lines .items ():
445446 path = Path (path_str )
446447 if not path .exists ():
@@ -452,9 +453,9 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
452453 except Exception as e :
453454 logger .exception (e )
454455 continue
455- function_lines = FunctionVisitor (file_path = str ( path ) )
456+ function_lines = FunctionVisitor (file_path = path )
456457 wrapper .visit (function_lines )
457- functions [str ( path ) ] = [
458+ functions [path ] = [
458459 function_to_optimize
459460 for function_to_optimize in function_lines .functions
460461 if (start_line := function_to_optimize .starting_line ) is not None
@@ -466,7 +467,7 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
466467
467468def get_all_files_and_functions (
468469 module_root_path : Path , ignore_paths : list [Path ], language : Language | None = None
469- ) -> dict [str , list [FunctionToOptimize ]]:
470+ ) -> dict [Path , list [FunctionToOptimize ]]:
470471 """Get all optimizable functions from files in the module root.
471472
472473 Args:
@@ -478,9 +479,8 @@ def get_all_files_and_functions(
478479 Dictionary mapping file paths to lists of FunctionToOptimize.
479480
480481 """
481- functions : dict [str , list [FunctionToOptimize ]] = {}
482+ functions : dict [Path , list [FunctionToOptimize ]] = {}
482483 for file_path in get_files_for_language (module_root_path , ignore_paths , language ):
483- # Find all the functions in the file
484484 functions .update (find_all_functions_in_file (file_path ).items ())
485485 # Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
486486 # Helpful if an optimize-all run is stuck and we restart it.
@@ -785,7 +785,7 @@ def filter_functions(
785785 disable_logs : bool = False ,
786786) -> tuple [dict [Path , list [FunctionToOptimize ]], int ]:
787787 resolved_project_root = project_root .resolve ()
788- filtered_modified_functions : dict [str , list [FunctionToOptimize ]] = {}
788+ filtered_modified_functions : dict [Path , list [FunctionToOptimize ]] = {}
789789 blocklist_funcs = get_blocklisted_functions ()
790790 logger .debug (f"Blocklisted functions: { blocklist_funcs } " )
791791 # Remove any function that we don't want to optimize
@@ -892,7 +892,7 @@ def is_test_file(file_path_normalized: str) -> bool:
892892 functions_tmp .append (function )
893893 _functions = functions_tmp
894894
895- filtered_modified_functions [file_path ] = _functions
895+ filtered_modified_functions [file_path_path ] = _functions
896896 functions_count += len (_functions )
897897
898898 if not disable_logs :
@@ -913,7 +913,7 @@ def is_test_file(file_path_normalized: str) -> bool:
913913 if len (tree .children ) > 0 :
914914 console .print (tree )
915915 console .rule ()
916- return {Path ( k ) : v for k , v in filtered_modified_functions .items () if v }, functions_count
916+ return {k : v for k , v in filtered_modified_functions .items () if v }, functions_count
917917
918918
919919def filter_files_optimized (file_path : Path , tests_root : Path , ignore_paths : list [Path ], module_root : Path ) -> bool :
0 commit comments