Skip to content

Commit 2b2caf4

Browse files
committed
fix: resolve mypy errors in discovery and support files
- Change FunctionVisitor.file_path from str to Path - Unify dict keys to Path across discovery functions (get_all_files_and_functions, get_functions_within_lines, get_functions_within_git_diff, etc.) - Remove redundant isinstance check in discover_functions - Add assert for found_function narrowing after exit_with_message - Fix closest_matching_file_function_name return type narrowing
1 parent 91cf6ea commit 2b2caf4

3 files changed

Lines changed: 19 additions & 22 deletions

File tree

codeflash/discovery/functions_to_optimize.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def visit_Return(self, node: cst.Return) -> None:
6868
class FunctionVisitor(cst.CSTVisitor):
6969
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)
7070

71-
def __init__(self, file_path: str) -> None:
71+
def __init__(self, file_path: Path) -> None:
7272
super().__init__()
73-
self.file_path: str = file_path
73+
self.file_path: Path = file_path
7474
self.functions: list[FunctionToOptimize] = []
7575

7676
@staticmethod
@@ -264,7 +264,7 @@ def get_functions_to_optimize(
264264
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
265265
"Only one of optimize_all, replay_test, or file should be provided"
266266
)
267-
functions: dict[str, list[FunctionToOptimize]]
267+
functions: dict[Path, list[FunctionToOptimize]]
268268
trace_file_path: Path | None = None
269269
is_lsp = is_LSP_enabled()
270270
with warnings.catch_warnings():
@@ -281,7 +281,7 @@ def get_functions_to_optimize(
281281
logger.info("!lsp|Finding all functions in the file '%s'…", file)
282282
console.rule()
283283
file = Path(file) if isinstance(file, str) else file
284-
functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(file)
284+
functions = find_all_functions_in_file(file)
285285
if only_get_this_function is not None:
286286
split_function = only_get_this_function.split(".")
287287
if len(split_function) > 2:
@@ -316,6 +316,7 @@ def get_functions_to_optimize(
316316
f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property"
317317
)
318318

319+
assert found_function is not None
319320
# For JavaScript/TypeScript, verify that the function (or its parent class) is exported
320321
# Non-exported functions cannot be imported by tests
321322
if found_function.language in ("javascript", "typescript"):
@@ -359,7 +360,7 @@ def get_functions_to_optimize(
359360
return filtered_modified_functions, functions_count, trace_file_path
360361

361362

362-
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]:
363+
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[Path, list[FunctionToOptimize]]:
363364
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
364365
return get_functions_within_lines(modified_lines)
365366

@@ -400,7 +401,7 @@ def closest_matching_file_function_name(
400401
closest_match = function
401402
closest_file = file_path
402403

403-
if closest_match is not None:
404+
if closest_match is not None and closest_file is not None:
404405
return closest_file, closest_match
405406
return None
406407

@@ -434,13 +435,13 @@ def levenshtein_distance(s1: str, s2: str) -> int:
434435
return previous[len1]
435436

436437

437-
def get_functions_inside_a_commit(commit_hash: str) -> dict[str, list[FunctionToOptimize]]:
438+
def get_functions_inside_a_commit(commit_hash: str) -> dict[Path, list[FunctionToOptimize]]:
438439
modified_lines: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash)
439440
return get_functions_within_lines(modified_lines)
440441

441442

442-
def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str, list[FunctionToOptimize]]:
443-
functions: dict[str, list[FunctionToOptimize]] = {}
443+
def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Path, list[FunctionToOptimize]]:
444+
functions: dict[Path, list[FunctionToOptimize]] = {}
444445
for path_str, lines_in_file in modified_lines.items():
445446
path = Path(path_str)
446447
if not path.exists():
@@ -452,9 +453,9 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
452453
except Exception as e:
453454
logger.exception(e)
454455
continue
455-
function_lines = FunctionVisitor(file_path=str(path))
456+
function_lines = FunctionVisitor(file_path=path)
456457
wrapper.visit(function_lines)
457-
functions[str(path)] = [
458+
functions[path] = [
458459
function_to_optimize
459460
for function_to_optimize in function_lines.functions
460461
if (start_line := function_to_optimize.starting_line) is not None
@@ -466,7 +467,7 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
466467

467468
def get_all_files_and_functions(
468469
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
469-
) -> dict[str, list[FunctionToOptimize]]:
470+
) -> dict[Path, list[FunctionToOptimize]]:
470471
"""Get all optimizable functions from files in the module root.
471472
472473
Args:
@@ -478,9 +479,8 @@ def get_all_files_and_functions(
478479
Dictionary mapping file paths to lists of FunctionToOptimize.
479480
480481
"""
481-
functions: dict[str, list[FunctionToOptimize]] = {}
482+
functions: dict[Path, list[FunctionToOptimize]] = {}
482483
for file_path in get_files_for_language(module_root_path, ignore_paths, language):
483-
# Find all the functions in the file
484484
functions.update(find_all_functions_in_file(file_path).items())
485485
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
486486
# Helpful if an optimize-all run is stuck and we restart it.
@@ -785,7 +785,7 @@ def filter_functions(
785785
disable_logs: bool = False,
786786
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
787787
resolved_project_root = project_root.resolve()
788-
filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
788+
filtered_modified_functions: dict[Path, list[FunctionToOptimize]] = {}
789789
blocklist_funcs = get_blocklisted_functions()
790790
logger.debug(f"Blocklisted functions: {blocklist_funcs}")
791791
# Remove any function that we don't want to optimize
@@ -892,7 +892,7 @@ def is_test_file(file_path_normalized: str) -> bool:
892892
functions_tmp.append(function)
893893
_functions = functions_tmp
894894

895-
filtered_modified_functions[file_path] = _functions
895+
filtered_modified_functions[file_path_path] = _functions
896896
functions_count += len(_functions)
897897

898898
if not disable_logs:
@@ -913,7 +913,7 @@ def is_test_file(file_path_normalized: str) -> bool:
913913
if len(tree.children) > 0:
914914
console.print(tree)
915915
console.rule()
916-
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
916+
return {k: v for k, v in filtered_modified_functions.items() if v}, functions_count
917917

918918

919919
def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool:

codeflash/languages/python/support.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,11 @@ def discover_functions(
134134
tree = cst.parse_module(source)
135135

136136
wrapper = cst.metadata.MetadataWrapper(tree)
137-
function_visitor = FunctionVisitor(file_path=str(file_path))
137+
function_visitor = FunctionVisitor(file_path=file_path)
138138
wrapper.visit(function_visitor)
139139

140140
functions: list[FunctionToOptimize] = []
141141
for func in function_visitor.functions:
142-
if not isinstance(func, FunctionToOptimize):
143-
continue
144-
145142
if not criteria.include_async and func.is_async:
146143
continue
147144

codeflash/lsp/beta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def get_functions_in_commit(params: OptimizableFunctionsInCommitParams) -> dict[
114114
return {"functions": file_to_qualified_names, "status": "success"}
115115

116116

117-
def _group_functions_by_file(functions: dict[str, list[FunctionToOptimize]]) -> dict[str, list[str]]:
117+
def _group_functions_by_file(functions: dict[Path, list[FunctionToOptimize]]) -> dict[str, list[str]]:
118118
file_to_funcs_to_optimize, _ = filter_functions(
119119
modified_functions=functions,
120120
tests_root=server.optimizer.test_cfg.tests_root,

0 commit comments

Comments
 (0)