Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 28 additions & 31 deletions codeflash/languages/javascript/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,53 +1196,50 @@ def _find_and_extract_body(self, source: str, function_name: str, analyzer: Tree
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)

def find_function_node(node, target_name: str):
"""Recursively find a function/method with the given name."""
mv = memoryview(source_bytes)
target_name_bytes = function_name.encode("utf8")

func_node = None
func_types = ("function_declaration", "function", "generator_function_declaration", "generator_function")
decl_types = ("lexical_declaration", "variable_declaration")
value_fn_types = ("arrow_function", "function_expression", "generator_function")

stack = [tree.root_node]
while stack:
node = stack.pop()

# Check method definitions
if node.type == "method_definition":
name_node = node.child_by_field_name("name")
if name_node:
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name:
return node
if mv[name_node.start_byte : name_node.end_byte] == target_name_bytes:
func_node = node
break

# Check function declarations
if node.type in (
"function_declaration",
"function",
"generator_function_declaration",
"generator_function",
):
if node.type in func_types:
name_node = node.child_by_field_name("name")
if name_node:
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name:
return node
if mv[name_node.start_byte : name_node.end_byte] == target_name_bytes:
func_node = node
break

# Check arrow functions and function expressions assigned to variables
if node.type in ("lexical_declaration", "variable_declaration"):
if node.type in decl_types:
for child in node.children:
if child.type == "variable_declarator":
name_node = child.child_by_field_name("name")
value_node = child.child_by_field_name("value")
if (
name_node
and value_node
and value_node.type in ("arrow_function", "function_expression", "generator_function")
):
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name:
return value_node

# Recurse into children
for child in node.children:
result = find_function_node(child, target_name)
if result:
return result
if name_node and value_node and value_node.type in value_fn_types:
if mv[name_node.start_byte : name_node.end_byte] == target_name_bytes:
func_node = value_node
break
if func_node:
break

return None
if node.children:
stack.extend(node.children)

func_node = find_function_node(tree.root_node, function_name)
if not func_node:
logger.debug("Could not find function '%s' in optimized code for body extraction", function_name)
return None
Expand Down
Loading