|
13 | 13 | def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]: |
14 | 14 | """Extract the single dependent function from the code context excluding the main function.""" |
15 | 15 | dependent_functions = set() |
16 | | - for code_string in code_context.testgen_context.code_strings: |
17 | | - ast_tree = ast.parse(code_string.code) |
18 | | - dependent_functions.update( |
19 | | - {node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))} |
20 | | - ) |
21 | 16 |
|
22 | 17 | # Compare using bare name since AST extracts bare function names |
23 | 18 | bare_main = main_function.rsplit(".", 1)[-1] if "." in main_function else main_function |
24 | | - if bare_main in dependent_functions: |
25 | | - dependent_functions.discard(bare_main) |
| 19 | + |
| 20 | + for code_string in code_context.testgen_context.code_strings: |
| 21 | + # Quick heuristic: skip parsing entirely if there is no 'def' token, |
| 22 | + # since no function definitions can be present without it. |
| 23 | + if "def" not in code_string.code: |
| 24 | + continue |
| 25 | + |
| 26 | + ast_tree = ast.parse(code_string.code) |
| 27 | + # Add function names directly, skipping the bare main name. |
| 28 | + for node in ast_tree.body: |
| 29 | + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): |
| 30 | + name = node.name |
| 31 | + if name == bare_main: |
| 32 | + continue |
| 33 | + dependent_functions.add(name) |
| 34 | + # If more than one dependent function (other than the main) is found, |
| 35 | + # we can return False early since the final result cannot be a single name. |
| 36 | + if len(dependent_functions) > 1: |
| 37 | + return False |
| 38 | + |
26 | 39 |
|
27 | 40 | if not dependent_functions: |
28 | 41 | return False |
|
0 commit comments