Skip to content

Commit 6e3c05a

Browse files
Merge branch 'main' into fix/jest-config-for-typescript
2 parents c66cc4e + 239a5a8 commit 6e3c05a

6 files changed

Lines changed: 149 additions & 15 deletions

File tree

codeflash/languages/javascript/find_references.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def find_references(
101101
source_file: Path,
102102
include_definition: bool = False,
103103
max_files: int = 1000,
104+
class_name: str | None = None,
104105
) -> list[Reference]:
105106
"""Find all references to a function across the project.
106107
@@ -109,6 +110,7 @@ def find_references(
109110
source_file: Path to the file where the function is defined.
110111
include_definition: Whether to include the function definition itself.
111112
max_files: Maximum number of files to search (prevents runaway searches).
113+
class_name: For class methods, the name of the containing class.
112114
113115
Returns:
114116
List of Reference objects describing each call site.
@@ -126,7 +128,7 @@ def find_references(
126128
return references
127129

128130
analyzer = get_analyzer_for_file(source_file)
129-
exported = self._analyze_exports(function_name, source_file, source_code, analyzer)
131+
exported = self._analyze_exports(function_name, source_file, source_code, analyzer, class_name)
130132

131133
if not exported:
132134
logger.debug("Function %s is not exported from %s", function_name, source_file)
@@ -250,21 +252,29 @@ def find_references(
250252
return unique_refs
251253

252254
def _analyze_exports(
253-
self, function_name: str, file_path: Path, source_code: str, analyzer: TreeSitterAnalyzer
255+
self,
256+
function_name: str,
257+
file_path: Path,
258+
source_code: str,
259+
analyzer: TreeSitterAnalyzer,
260+
class_name: str | None = None,
254261
) -> ExportedFunction | None:
255262
"""Analyze how a function is exported from its file.
256263
264+
For class methods, also checks if the containing class is exported.
265+
257266
Args:
258267
function_name: Name of the function to check.
259268
file_path: Path to the source file.
260269
source_code: Source code content.
261270
analyzer: TreeSitterAnalyzer instance.
271+
class_name: For class methods, the name of the containing class.
262272
263273
Returns:
264274
ExportedFunction if the function is exported, None otherwise.
265275
266276
"""
267-
is_exported, export_name = analyzer.is_function_exported(source_code, function_name)
277+
is_exported, export_name = analyzer.is_function_exported(source_code, function_name, class_name)
268278

269279
if not is_exported:
270280
return None
@@ -825,6 +835,7 @@ def find_references(
825835
source_file: Path,
826836
project_root: Path | None = None,
827837
max_files: int = 1000,
838+
class_name: str | None = None,
828839
) -> list[Reference]:
829840
"""Convenience function to find all references to a function.
830841
@@ -835,6 +846,7 @@ def find_references(
835846
source_file: Path to the file where the function is defined.
836847
project_root: Root directory of the project. If None, uses source_file's parent.
837848
max_files: Maximum number of files to search.
849+
class_name: For class methods, the name of the containing class.
838850
839851
Returns:
840852
List of Reference objects describing each call site.
@@ -858,4 +870,4 @@ def find_references(
858870
project_root = source_file.parent
859871

860872
finder = ReferenceFinder(project_root)
861-
return finder.find_references(function_name, source_file, max_files=max_files)
873+
return finder.find_references(function_name, source_file, max_files=max_files, class_name=class_name)

codeflash/languages/javascript/support.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,9 @@ def find_references(
988988

989989
try:
990990
finder = ReferenceFinder(project_root)
991-
refs = finder.find_references(function.name, function.file_path, max_files=max_files)
991+
refs = finder.find_references(
992+
function.name, function.file_path, max_files=max_files, class_name=function.class_name
993+
)
992994

993995
# Convert to ReferenceInfo and filter out tests
994996
result: list[ReferenceInfo] = []

codeflash/languages/treesitter_utils.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -454,23 +454,46 @@ def find_imports(self, source: str) -> list[ImportInfo]:
454454

455455
return imports
456456

457-
def _walk_tree_for_imports(self, node: Node, source_bytes: bytes, imports: list[ImportInfo]) -> None:
458-
"""Recursively walk the tree to find import statements."""
457+
def _walk_tree_for_imports(
458+
self, node: Node, source_bytes: bytes, imports: list[ImportInfo], in_function: bool = False
459+
) -> None:
460+
"""Recursively walk the tree to find import statements.
461+
462+
Args:
463+
node: Current node to check.
464+
source_bytes: Source code bytes.
465+
imports: List to append found imports to.
466+
in_function: Whether we're currently inside a function/method body.
467+
"""
468+
# Track when we enter function/method bodies
469+
# These node types contain function/method bodies where require() should not be treated as imports
470+
function_body_types = {
471+
"function_declaration",
472+
"method_definition",
473+
"arrow_function",
474+
"function_expression",
475+
"function", # Generic function in some grammars
476+
}
477+
459478
if node.type == "import_statement":
460479
import_info = self._extract_import_info(node, source_bytes)
461480
if import_info:
462481
imports.append(import_info)
463482

464-
# Also handle require() calls for CommonJS
465-
if node.type == "call_expression":
483+
# Also handle require() calls for CommonJS, but only at module level
484+
# require() inside functions is a dynamic import, not a module import
485+
if node.type == "call_expression" and not in_function:
466486
func_node = node.child_by_field_name("function")
467487
if func_node and self.get_node_text(func_node, source_bytes) == "require":
468488
import_info = self._extract_require_info(node, source_bytes)
469489
if import_info:
470490
imports.append(import_info)
471491

492+
# Update in_function flag for children
493+
child_in_function = in_function or node.type in function_body_types
494+
472495
for child in node.children:
473-
self._walk_tree_for_imports(child, source_bytes, imports)
496+
self._walk_tree_for_imports(child, source_bytes, imports, child_in_function)
474497

475498
def _extract_import_info(self, node: Node, source_bytes: bytes) -> ImportInfo | None:
476499
"""Extract import information from an import statement node."""
@@ -841,20 +864,27 @@ def _extract_commonjs_export(self, node: Node, source_bytes: bytes) -> ExportInf
841864
end_line=node.end_point[0] + 1,
842865
)
843866

844-
def is_function_exported(self, source: str, function_name: str) -> tuple[bool, str | None]:
867+
def is_function_exported(
868+
self, source: str, function_name: str, class_name: str | None = None
869+
) -> tuple[bool, str | None]:
845870
"""Check if a function is exported and get its export name.
846871
872+
For class methods, also checks if the containing class is exported.
873+
847874
Args:
848875
source: The source code to analyze.
849876
function_name: The name of the function to check.
877+
class_name: For class methods, the name of the containing class.
850878
851879
Returns:
852880
Tuple of (is_exported, export_name). export_name may differ from
853-
function_name if exported with an alias.
881+
function_name if exported with an alias. For class methods,
882+
returns the class export name.
854883
855884
"""
856885
exports = self.find_exports(source)
857886

887+
# First, check if the function itself is directly exported
858888
for export in exports:
859889
# Check default export
860890
if export.default_export == function_name:
@@ -865,6 +895,18 @@ def is_function_exported(self, source: str, function_name: str) -> tuple[bool, s
865895
if name == function_name:
866896
return (True, alias if alias else name)
867897

898+
# For class methods, check if the containing class is exported
899+
if class_name:
900+
for export in exports:
901+
# Check if class is default export
902+
if export.default_export == class_name:
903+
return (True, class_name)
904+
905+
# Check if class is in named exports
906+
for name, alias in export.exported_names:
907+
if name == class_name:
908+
return (True, alias if alias else name)
909+
868910
return (False, None)
869911

870912
def find_function_calls(self, source: str, within_function: FunctionNode) -> list[str]:

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2172,10 +2172,10 @@ def process_review(
21722172
else self.function_trace_id,
21732173
"coverage_message": coverage_message,
21742174
"replay_tests": replay_tests,
2175-
"concolic_tests": concolic_tests,
2175+
#"concolic_tests": concolic_tests,
21762176
"language": self.function_to_optimize.language,
2177-
"original_line_profiler": original_code_baseline.line_profile_results.get("str_out", ""),
2178-
"optimized_line_profiler": best_optimization.line_profiler_test_results.get("str_out", ""),
2177+
#"original_line_profiler": original_code_baseline.line_profile_results.get("str_out", ""),
2178+
#"optimized_line_profiler": best_optimization.line_profiler_test_results.get("str_out", ""),
21792179
}
21802180

21812181
raise_pr = not self.args.no_pr

tests/test_languages/test_import_resolver.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,13 @@ def js_analyzer(self):
474474

475475
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
476476

477+
@pytest.fixture
478+
def ts_analyzer(self):
479+
"""Create a TypeScript analyzer."""
480+
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
481+
482+
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
483+
477484
def test_module_exports_function(self, js_analyzer):
478485
"""Test module.exports = function() {}."""
479486
code = "module.exports = function helper() { return 1; };"
@@ -584,6 +591,51 @@ def test_is_function_exported_commonjs_property(self, js_analyzer):
584591
assert is_exported is True
585592
assert export_name == "helper"
586593

594+
def test_is_class_method_exported_via_class(self, ts_analyzer):
595+
"""Test is_function_exported returns True for method of exported class."""
596+
code = """
597+
export class BloomFilter {
598+
getHashValues(key: string): number[] {
599+
return [1, 2, 3];
600+
}
601+
}
602+
"""
603+
# Method itself is not directly exported
604+
is_exported, export_name = ts_analyzer.is_function_exported(code, "getHashValues")
605+
assert is_exported is False
606+
assert export_name is None
607+
608+
# But when we pass the class name, it should find the class export
609+
is_exported, export_name = ts_analyzer.is_function_exported(code, "getHashValues", "BloomFilter")
610+
assert is_exported is True
611+
assert export_name == "BloomFilter"
612+
613+
def test_is_class_method_exported_default_class(self, ts_analyzer):
614+
"""Test is_function_exported returns True for method of default exported class."""
615+
code = """
616+
export default class Calculator {
617+
add(a: number, b: number): number {
618+
return a + b;
619+
}
620+
}
621+
"""
622+
# When we pass the class name, it should find the default export
623+
is_exported, export_name = ts_analyzer.is_function_exported(code, "add", "Calculator")
624+
assert is_exported is True
625+
assert export_name == "Calculator"
626+
627+
def test_is_class_method_not_exported_non_exported_class(self, ts_analyzer):
628+
"""Test is_function_exported returns False for method of non-exported class."""
629+
code = """
630+
class InternalClass {
631+
helper(): void {}
632+
}
633+
"""
634+
# Even with class name, non-exported class method should not be exported
635+
is_exported, export_name = ts_analyzer.is_function_exported(code, "helper", "InternalClass")
636+
assert is_exported is False
637+
assert export_name is None
638+
587639

588640
class TestCommonJSImportResolver:
589641
"""Tests for ImportResolver with CommonJS require() imports."""

tests/test_languages/test_treesitter_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,32 @@ def test_find_require(self, js_analyzer):
351351
assert imports[0].module_path == "fs"
352352
assert imports[0].default_import == "fs"
353353

354+
def test_require_inside_function_not_import(self, js_analyzer):
355+
"""Test that require() inside functions is not treated as an import.
356+
357+
This is important because dynamic require() calls inside functions are
358+
not module-level imports and should not be extracted as such.
359+
"""
360+
code = """
361+
const fs = require('fs');
362+
363+
function loadModule() {
364+
const dynamic = require('dynamic-module');
365+
return dynamic;
366+
}
367+
368+
class MyClass {
369+
method() {
370+
const inMethod = require('method-module');
371+
}
372+
}
373+
"""
374+
imports = js_analyzer.find_imports(code)
375+
376+
# Only the module-level require should be found
377+
assert len(imports) == 1
378+
assert imports[0].module_path == "fs"
379+
354380
def test_find_multiple_imports(self, js_analyzer):
355381
"""Test finding multiple imports."""
356382
code = """

0 commit comments

Comments
 (0)