Skip to content

Commit 239a5a8

Browse files
Merge pull request #1246 from codeflash-ai/fix/class-method-export-detection
Fix class method export detection for JavaScript/TypeScript
2 parents b289730 + 8433fa9 commit 239a5a8

4 files changed

Lines changed: 92 additions & 7 deletions

File tree

codeflash/languages/javascript/find_references.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def find_references(
101101
source_file: Path,
102102
include_definition: bool = False,
103103
max_files: int = 1000,
104+
class_name: str | None = None,
104105
) -> list[Reference]:
105106
"""Find all references to a function across the project.
106107
@@ -109,6 +110,7 @@ def find_references(
109110
source_file: Path to the file where the function is defined.
110111
include_definition: Whether to include the function definition itself.
111112
max_files: Maximum number of files to search (prevents runaway searches).
113+
class_name: For class methods, the name of the containing class.
112114
113115
Returns:
114116
List of Reference objects describing each call site.
@@ -126,7 +128,7 @@ def find_references(
126128
return references
127129

128130
analyzer = get_analyzer_for_file(source_file)
129-
exported = self._analyze_exports(function_name, source_file, source_code, analyzer)
131+
exported = self._analyze_exports(function_name, source_file, source_code, analyzer, class_name)
130132

131133
if not exported:
132134
logger.debug("Function %s is not exported from %s", function_name, source_file)
@@ -250,21 +252,29 @@ def find_references(
250252
return unique_refs
251253

252254
def _analyze_exports(
253-
self, function_name: str, file_path: Path, source_code: str, analyzer: TreeSitterAnalyzer
255+
self,
256+
function_name: str,
257+
file_path: Path,
258+
source_code: str,
259+
analyzer: TreeSitterAnalyzer,
260+
class_name: str | None = None,
254261
) -> ExportedFunction | None:
255262
"""Analyze how a function is exported from its file.
256263
264+
For class methods, also checks if the containing class is exported.
265+
257266
Args:
258267
function_name: Name of the function to check.
259268
file_path: Path to the source file.
260269
source_code: Source code content.
261270
analyzer: TreeSitterAnalyzer instance.
271+
class_name: For class methods, the name of the containing class.
262272
263273
Returns:
264274
ExportedFunction if the function is exported, None otherwise.
265275
266276
"""
267-
is_exported, export_name = analyzer.is_function_exported(source_code, function_name)
277+
is_exported, export_name = analyzer.is_function_exported(source_code, function_name, class_name)
268278

269279
if not is_exported:
270280
return None
@@ -825,6 +835,7 @@ def find_references(
825835
source_file: Path,
826836
project_root: Path | None = None,
827837
max_files: int = 1000,
838+
class_name: str | None = None,
828839
) -> list[Reference]:
829840
"""Convenience function to find all references to a function.
830841
@@ -835,6 +846,7 @@ def find_references(
835846
source_file: Path to the file where the function is defined.
836847
project_root: Root directory of the project. If None, uses source_file's parent.
837848
max_files: Maximum number of files to search.
849+
class_name: For class methods, the name of the containing class.
838850
839851
Returns:
840852
List of Reference objects describing each call site.
@@ -858,4 +870,4 @@ def find_references(
858870
project_root = source_file.parent
859871

860872
finder = ReferenceFinder(project_root)
861-
return finder.find_references(function_name, source_file, max_files=max_files)
873+
return finder.find_references(function_name, source_file, max_files=max_files, class_name=class_name)

codeflash/languages/javascript/support.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,9 @@ def find_references(
988988

989989
try:
990990
finder = ReferenceFinder(project_root)
991-
refs = finder.find_references(function.name, function.file_path, max_files=max_files)
991+
refs = finder.find_references(
992+
function.name, function.file_path, max_files=max_files, class_name=function.class_name
993+
)
992994

993995
# Convert to ReferenceInfo and filter out tests
994996
result: list[ReferenceInfo] = []

codeflash/languages/treesitter_utils.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -864,20 +864,27 @@ def _extract_commonjs_export(self, node: Node, source_bytes: bytes) -> ExportInf
864864
end_line=node.end_point[0] + 1,
865865
)
866866

867-
def is_function_exported(self, source: str, function_name: str) -> tuple[bool, str | None]:
867+
def is_function_exported(
868+
self, source: str, function_name: str, class_name: str | None = None
869+
) -> tuple[bool, str | None]:
868870
"""Check if a function is exported and get its export name.
869871
872+
For class methods, also checks if the containing class is exported.
873+
870874
Args:
871875
source: The source code to analyze.
872876
function_name: The name of the function to check.
877+
class_name: For class methods, the name of the containing class.
873878
874879
Returns:
875880
Tuple of (is_exported, export_name). export_name may differ from
876-
function_name if exported with an alias.
881+
function_name if exported with an alias. For class methods,
882+
returns the class export name.
877883
878884
"""
879885
exports = self.find_exports(source)
880886

887+
# First, check if the function itself is directly exported
881888
for export in exports:
882889
# Check default export
883890
if export.default_export == function_name:
@@ -888,6 +895,18 @@ def is_function_exported(self, source: str, function_name: str) -> tuple[bool, s
888895
if name == function_name:
889896
return (True, alias if alias else name)
890897

898+
# For class methods, check if the containing class is exported
899+
if class_name:
900+
for export in exports:
901+
# Check if class is default export
902+
if export.default_export == class_name:
903+
return (True, class_name)
904+
905+
# Check if class is in named exports
906+
for name, alias in export.exported_names:
907+
if name == class_name:
908+
return (True, alias if alias else name)
909+
891910
return (False, None)
892911

893912
def find_function_calls(self, source: str, within_function: FunctionNode) -> list[str]:

tests/test_languages/test_import_resolver.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,13 @@ def js_analyzer(self):
474474

475475
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
476476

477+
@pytest.fixture
478+
def ts_analyzer(self):
479+
"""Create a TypeScript analyzer."""
480+
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
481+
482+
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
483+
477484
def test_module_exports_function(self, js_analyzer):
478485
"""Test module.exports = function() {}."""
479486
code = "module.exports = function helper() { return 1; };"
@@ -584,6 +591,51 @@ def test_is_function_exported_commonjs_property(self, js_analyzer):
584591
assert is_exported is True
585592
assert export_name == "helper"
586593

594+
def test_is_class_method_exported_via_class(self, ts_analyzer):
595+
"""Test is_function_exported returns True for method of exported class."""
596+
code = """
597+
export class BloomFilter {
598+
getHashValues(key: string): number[] {
599+
return [1, 2, 3];
600+
}
601+
}
602+
"""
603+
# Method itself is not directly exported
604+
is_exported, export_name = ts_analyzer.is_function_exported(code, "getHashValues")
605+
assert is_exported is False
606+
assert export_name is None
607+
608+
# But when we pass the class name, it should find the class export
609+
is_exported, export_name = ts_analyzer.is_function_exported(code, "getHashValues", "BloomFilter")
610+
assert is_exported is True
611+
assert export_name == "BloomFilter"
612+
613+
def test_is_class_method_exported_default_class(self, ts_analyzer):
614+
"""Test is_function_exported returns True for method of default exported class."""
615+
code = """
616+
export default class Calculator {
617+
add(a: number, b: number): number {
618+
return a + b;
619+
}
620+
}
621+
"""
622+
# When we pass the class name, it should find the default export
623+
is_exported, export_name = ts_analyzer.is_function_exported(code, "add", "Calculator")
624+
assert is_exported is True
625+
assert export_name == "Calculator"
626+
627+
def test_is_class_method_not_exported_non_exported_class(self, ts_analyzer):
628+
"""Test is_function_exported returns False for method of non-exported class."""
629+
code = """
630+
class InternalClass {
631+
helper(): void {}
632+
}
633+
"""
634+
# Even with class name, non-exported class method should not be exported
635+
is_exported, export_name = ts_analyzer.is_function_exported(code, "helper", "InternalClass")
636+
assert is_exported is False
637+
assert export_name is None
638+
587639

588640
class TestCommonJSImportResolver:
589641
"""Tests for ImportResolver with CommonJS require() imports."""

0 commit comments

Comments
 (0)