Skip to content

Commit dbe4221

Browse files
Merge pull request #1226 from codeflash-ai/feat/find-references-javascript
feat: Add find references functionality for JavaScript/TypeScript
2 parents afc4941 + 88999f0 commit dbe4221

6 files changed

Lines changed: 2394 additions & 10 deletions

File tree

codeflash/code_utils/code_extractor.py

Lines changed: 215 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,23 +1563,228 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
15631563
def get_opt_review_metrics(
15641564
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language
15651565
) -> str:
1566-
if language != Language.PYTHON:
1567-
# TODO: {Claude} handle function references for other languages
1568-
return ""
1566+
"""Get function reference metrics for optimization review.
1567+
1568+
Uses the LanguageSupport abstraction to find references, supporting both Python and JavaScript/TypeScript.
1569+
1570+
Args:
1571+
source_code: Source code of the file containing the function.
1572+
file_path: Path to the file.
1573+
qualified_name: Qualified name of the function (e.g., "module.ClassName.method").
1574+
project_root: Root of the project.
1575+
tests_root: Root of the tests directory.
1576+
language: The programming language.
1577+
1578+
Returns:
1579+
Markdown-formatted string with code blocks showing calling functions.
1580+
"""
1581+
from codeflash.languages.base import FunctionInfo, ParentInfo, ReferenceInfo
1582+
from codeflash.languages.registry import get_language_support
1583+
15691584
start_time = time.perf_counter()
1585+
15701586
try:
1587+
# Get the language support
1588+
lang_support = get_language_support(language)
1589+
if lang_support is None:
1590+
return ""
1591+
1592+
# Parse qualified name to get function name and class name
15711593
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
15721594
if len(qualified_name_split) == 1:
1573-
target_function, target_class = qualified_name_split[0], None
1595+
function_name, class_name = qualified_name_split[0], None
15741596
else:
1575-
target_function, target_class = qualified_name_split[1], qualified_name_split[0]
1576-
matches = get_fn_references_jedi(
1577-
source_code, file_path, project_root, target_function, target_class
1578-
) # jedi is not perfect, it doesn't capture aliased references
1579-
calling_fns_details = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root)
1597+
function_name, class_name = qualified_name_split[1], qualified_name_split[0]
1598+
1599+
# Create a FunctionInfo for the function
1600+
# We don't have full line info here, so we'll use defaults
1601+
parents = ()
1602+
if class_name:
1603+
parents = (ParentInfo(name=class_name, type="ClassDef"),)
1604+
1605+
func_info = FunctionInfo(
1606+
name=function_name,
1607+
file_path=file_path,
1608+
start_line=1,
1609+
end_line=1,
1610+
parents=parents,
1611+
language=language,
1612+
)
1613+
1614+
# Find references using language support
1615+
references = lang_support.find_references(func_info, project_root, tests_root, max_files=500)
1616+
1617+
if not references:
1618+
return ""
1619+
1620+
# Format references as markdown code blocks
1621+
calling_fns_details = _format_references_as_markdown(
1622+
references, file_path, project_root, language
1623+
)
1624+
15801625
except Exception as e:
1626+
logger.debug(f"Error getting function references: {e}")
15811627
calling_fns_details = ""
1582-
logger.debug(f"Investigate {e}")
1628+
15831629
end_time = time.perf_counter()
15841630
logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
15851631
return calling_fns_details
1632+
1633+
1634+
def _format_references_as_markdown(
    references: list, file_path: Path, project_root: Path, language: Language
) -> str:
    """Render call-site references as fenced markdown blocks, one block per referencing file.

    Args:
        references: ReferenceInfo-like objects; ``file_path``, ``reference_type``,
            ``caller_function`` and ``line`` are read from each of them.
        file_path: File defining the target function; its own import/re-export
            references are left out of the output.
        project_root: Project root; files that are not located under it are skipped.
        language: Selects the fence language hint and the caller-extraction strategy.

    Returns:
        The concatenated fenced code blocks, or an empty string when nothing was collected.

    """
    # Bucket references per file, dropping import/re-export entries of the defining file.
    grouped: dict[Path, list] = {}
    for reference in references:
        is_own_import = reference.file_path == file_path and reference.reference_type in ("import", "reexport")
        if not is_own_import:
            grouped.setdefault(reference.file_path, []).append(reference)

    rendered = ""
    budget_used = 0

    for source_file, refs_in_file in grouped.items():
        # The size budget is only checked between files, so a single file may overshoot it.
        if budget_used > MAX_CONTEXT_LEN_REVIEW:
            break

        try:
            relative_path = source_file.relative_to(project_root)
        except ValueError:
            continue

        # Fence hint: Python projects always use "python"; otherwise decide by file extension.
        if language == Language.PYTHON:
            fence_lang = "python"
        else:
            fence_lang = "typescript" if source_file.suffix.lstrip(".") in ("ts", "tsx") else "javascript"

        # Unreadable files are silently skipped (best effort).
        try:
            text = source_file.read_text(encoding="utf-8")
            text_lines = text.splitlines()
        except Exception:
            continue

        # Emit at most one snippet per caller name; module-level references share one key.
        seen_callers: set[str] = set()
        snippets: list[str] = []

        for reference in refs_in_file:
            caller_name = reference.caller_function
            caller_key = caller_name or "<module>"
            if caller_key in seen_callers:
                continue
            seen_callers.add(caller_key)

            if caller_name:
                # Prefer the full source of the enclosing function; drop the reference if unavailable.
                snippet = _extract_calling_function(text, caller_name, reference.line, language)
                if not snippet:
                    continue
            else:
                # Module-level reference: take a small window of lines around it.
                first = max(0, reference.line - 3)
                last = min(len(text_lines), reference.line + 2)
                snippet = "\n".join(text_lines[first:last])

            snippets.append(snippet)
            budget_used += len(snippet)

        if snippets:
            rendered += f"```{fence_lang}:{relative_path}\n" + "\n".join(snippets) + "\n```\n"

    return rendered
1718+
1719+
1720+
def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None:
    """Return the source of the function that contains a reference, using a per-language extractor.

    Args:
        source_code: Full text of the file holding the reference.
        function_name: Name of the enclosing (calling) function to extract.
        ref_line: Line number of the reference.
        language: ``Language.PYTHON`` selects the Python extractor; every other value
            is handled by the JavaScript/TypeScript extractor.

    Returns:
        The calling function's source code, or None if it could not be found.

    """
    extractor = _extract_calling_function_python if language == Language.PYTHON else _extract_calling_function_js
    return extractor(source_code, function_name, ref_line)
1736+
1737+
1738+
def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
1739+
"""Extract the source code of a calling function in Python."""
1740+
try:
1741+
import ast
1742+
1743+
tree = ast.parse(source_code)
1744+
lines = source_code.splitlines()
1745+
1746+
for node in ast.walk(tree):
1747+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
1748+
if node.name == function_name:
1749+
# Check if the reference line is within this function
1750+
start_line = node.lineno
1751+
end_line = node.end_lineno or start_line
1752+
if start_line <= ref_line <= end_line:
1753+
return "\n".join(lines[start_line - 1 : end_line])
1754+
return None
1755+
except Exception:
1756+
return None
1757+
1758+
1759+
def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None:
    """Extract the source code of a calling function in JavaScript/TypeScript.

    The file is analysed with the TypeScript, TSX and JavaScript tree-sitter grammars, in
    that order; a grammar whose analysis raises is skipped in favour of the next one.

    Args:
        source_code: Full source code of the file.
        function_name: Name of the function to extract.
        ref_line: Line number of the reference; used to pick the function (among those
            named ``function_name``) whose line span contains it.

    Returns:
        Source text of the matching function, or None if not found or on any error.

    """
    try:
        from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage

        grammars = (TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT)
        for grammar in grammars:
            try:
                candidates = TreeSitterAnalyzer(grammar).find_functions(source_code, include_methods=True)
                for candidate in candidates:
                    if candidate.name == function_name and candidate.start_line <= ref_line <= candidate.end_line:
                        return candidate.source_text
                # NOTE(review): indentation is not recoverable from the diff view; this `break` is
                # assumed to follow the inner loop (stop once a grammar analysed the file without
                # raising) - confirm against the repository.
                break
            except Exception:
                continue

        return None
    except Exception:
        return None

codeflash/languages/base.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,37 @@ class FunctionFilterCriteria:
236236
max_lines: int | None = None
237237

238238

239+
@dataclass
240+
class ReferenceInfo:
241+
"""Information about a reference (call site) to a function.
242+
243+
This class captures information about where a function is called
244+
from, including the file, line number, context, and caller function.
245+
246+
Attributes:
247+
file_path: Path to the file containing the reference.
248+
line: Line number (1-indexed).
249+
column: Column number (0-indexed).
250+
end_line: End line number (1-indexed).
251+
end_column: End column number (0-indexed).
252+
context: The line of code containing the reference.
253+
reference_type: Type of reference ("call", "callback", "memoized", "import", "reexport").
254+
import_name: Name used to import the function (may differ from original).
255+
caller_function: Name of the function containing this reference (or None for module-level).
256+
257+
"""
258+
259+
file_path: Path
260+
line: int
261+
column: int
262+
end_line: int
263+
end_column: int
264+
context: str
265+
reference_type: str
266+
import_name: str | None
267+
caller_function: str | None = None
268+
269+
239270
@runtime_checkable
240271
class LanguageSupport(Protocol):
241272
"""Protocol defining what a language implementation must provide.
@@ -357,6 +388,29 @@ def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> l
357388
"""
358389
...
359390

391+
def find_references(
    self, function: FunctionInfo, project_root: Path, tests_root: Path | None = None, max_files: int = 500
) -> list[ReferenceInfo]:
    """Locate every reference (call site) to ``function`` across the codebase.

    Implementations are expected to report, among others:
    - direct calls,
    - uses as a callback (the function passed to another function),
    - memoized versions,
    - re-exports.

    Args:
        function: The function whose references are searched for.
        project_root: Root of the project to search.
        tests_root: Root of the tests directory (references in tests are excluded).
        max_files: Maximum number of files to search.

    Returns:
        One ReferenceInfo per reference location found.

    """
    ...
413+
360414
# === Code Transformation ===
361415

362416
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:

0 commit comments

Comments
 (0)