diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index f1a570740..cf5d389f1 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -1196,53 +1196,50 @@ def _find_and_extract_body(self, source: str, function_name: str, analyzer: Tree source_bytes = source.encode("utf8") tree = analyzer.parse(source_bytes) - def find_function_node(node, target_name: str): - """Recursively find a function/method with the given name.""" + mv = memoryview(source_bytes) + target_name_bytes = function_name.encode("utf8") + + func_node = None + func_types = ("function_declaration", "function", "generator_function_declaration", "generator_function") + decl_types = ("lexical_declaration", "variable_declaration") + value_fn_types = ("arrow_function", "function_expression", "generator_function") + + stack = [tree.root_node] + while stack: + node = stack.pop() + # Check method definitions if node.type == "method_definition": name_node = node.child_by_field_name("name") if name_node: - name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") - if name == target_name: - return node + if mv[name_node.start_byte : name_node.end_byte] == target_name_bytes: + func_node = node + break # Check function declarations - if node.type in ( - "function_declaration", - "function", - "generator_function_declaration", - "generator_function", - ): + if node.type in func_types: name_node = node.child_by_field_name("name") if name_node: - name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") - if name == target_name: - return node + if mv[name_node.start_byte : name_node.end_byte] == target_name_bytes: + func_node = node + break # Check arrow functions and function expressions assigned to variables - if node.type in ("lexical_declaration", "variable_declaration"): + if node.type in decl_types: for child in node.children: if child.type == "variable_declarator": name_node = child.child_by_field_name("name") value_node = child.child_by_field_name("value") - if ( - name_node - and value_node - and value_node.type in ("arrow_function", "function_expression", "generator_function") - ): - name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") - if name == target_name: - return value_node - - # Recurse into children - for child in node.children: - result = find_function_node(child, target_name) - if result: - return result + if name_node and value_node and value_node.type in value_fn_types: + if mv[name_node.start_byte : name_node.end_byte] == target_name_bytes: + func_node = value_node + break + if func_node: + break - return None + if node.children: + stack.extend(node.children) - func_node = find_function_node(tree.root_node, function_name) if not func_node: logger.debug("Could not find function '%s' in optimized code for body extraction", function_name) return None