Skip to content

Commit 950a279

Browse files
Merge branch 'main' into fix/class-method-export-detection
2 parents e392e6b + d557e74 commit 950a279

2 files changed

Lines changed: 54 additions & 5 deletions

File tree

codeflash/languages/treesitter_utils.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -454,23 +454,46 @@ def find_imports(self, source: str) -> list[ImportInfo]:
454454

455455
return imports
456456

457-
def _walk_tree_for_imports(self, node: Node, source_bytes: bytes, imports: list[ImportInfo]) -> None:
458-
"""Recursively walk the tree to find import statements."""
457+
def _walk_tree_for_imports(
458+
self, node: Node, source_bytes: bytes, imports: list[ImportInfo], in_function: bool = False
459+
) -> None:
460+
"""Recursively walk the tree to find import statements.
461+
462+
Args:
463+
node: Current node to check.
464+
source_bytes: Source code bytes.
465+
imports: List to append found imports to.
466+
in_function: Whether we're currently inside a function/method body.
467+
"""
468+
# Track when we enter function/method bodies
469+
# These node types contain function/method bodies where require() should not be treated as imports
470+
function_body_types = {
471+
"function_declaration",
472+
"method_definition",
473+
"arrow_function",
474+
"function_expression",
475+
"function", # Generic function in some grammars
476+
}
477+
459478
if node.type == "import_statement":
460479
import_info = self._extract_import_info(node, source_bytes)
461480
if import_info:
462481
imports.append(import_info)
463482

464-
# Also handle require() calls for CommonJS
465-
if node.type == "call_expression":
483+
# Also handle require() calls for CommonJS, but only at module level
484+
# require() inside functions is a dynamic import, not a module import
485+
if node.type == "call_expression" and not in_function:
466486
func_node = node.child_by_field_name("function")
467487
if func_node and self.get_node_text(func_node, source_bytes) == "require":
468488
import_info = self._extract_require_info(node, source_bytes)
469489
if import_info:
470490
imports.append(import_info)
471491

492+
# Update in_function flag for children
493+
child_in_function = in_function or node.type in function_body_types
494+
472495
for child in node.children:
473-
self._walk_tree_for_imports(child, source_bytes, imports)
496+
self._walk_tree_for_imports(child, source_bytes, imports, child_in_function)
474497

475498
def _extract_import_info(self, node: Node, source_bytes: bytes) -> ImportInfo | None:
476499
"""Extract import information from an import statement node."""

tests/test_languages/test_treesitter_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,32 @@ def test_find_require(self, js_analyzer):
351351
assert imports[0].module_path == "fs"
352352
assert imports[0].default_import == "fs"
353353

354+
def test_require_inside_function_not_import(self, js_analyzer):
355+
"""Test that require() inside functions is not treated as an import.
356+
357+
This is important because dynamic require() calls inside functions are
358+
not module-level imports and should not be extracted as such.
359+
"""
360+
code = """
361+
const fs = require('fs');
362+
363+
function loadModule() {
364+
const dynamic = require('dynamic-module');
365+
return dynamic;
366+
}
367+
368+
class MyClass {
369+
method() {
370+
const inMethod = require('method-module');
371+
}
372+
}
373+
"""
374+
imports = js_analyzer.find_imports(code)
375+
376+
# Only the module-level require should be found
377+
assert len(imports) == 1
378+
assert imports[0].module_path == "fs"
379+
354380
def test_find_multiple_imports(self, js_analyzer):
355381
"""Test finding multiple imports."""
356382
code = """

0 commit comments

Comments
 (0)