Skip to content

Commit 77f1eea

Browse files
misrasaurabh1claude
andcommitted
refactor: Move find_references into LanguageSupport abstraction
- Add ReferenceInfo dataclass to base.py for language-agnostic reference info - Add find_references method to LanguageSupport protocol - Implement find_references in JavaScriptSupport using tree-sitter - Implement find_references in PythonSupport using jedi - Refactor get_opt_review_metrics to use LanguageSupport abstraction - Both Python and JavaScript/TypeScript now use the same abstraction This provides a clean, unified API for finding function references across languages: ```python lang_support = get_language_support(Language.TYPESCRIPT) refs = lang_support.find_references(func_info, project_root, tests_root) ``` Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 4a54486 commit 77f1eea

4 files changed

Lines changed: 391 additions & 102 deletions

File tree

codeflash/code_utils/code_extractor.py

Lines changed: 161 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,138 +1563,197 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
15631563
def get_opt_review_metrics(
15641564
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language
15651565
) -> str:
1566-
start_time = time.perf_counter()
1566+
"""Get function reference metrics for optimization review.
15671567
1568-
if language == Language.PYTHON:
1569-
calling_fns_details = _get_python_references(source_code, file_path, qualified_name, project_root, tests_root)
1570-
elif language in (Language.JAVASCRIPT, Language.TYPESCRIPT):
1571-
calling_fns_details = _get_javascript_references(file_path, qualified_name, project_root, tests_root)
1572-
else:
1573-
calling_fns_details = ""
1568+
Uses the LanguageSupport abstraction to find references, supporting both Python and JavaScript/TypeScript.
15741569
1575-
end_time = time.perf_counter()
1576-
logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
1577-
return calling_fns_details
1570+
Args:
1571+
source_code: Source code of the file containing the function.
1572+
file_path: Path to the file.
1573+
qualified_name: Qualified name of the function (e.g., "module.ClassName.method").
1574+
project_root: Root of the project.
1575+
tests_root: Root of the tests directory.
1576+
language: The programming language.
1577+
1578+
Returns:
1579+
Markdown-formatted string with code blocks showing calling functions.
1580+
"""
1581+
from codeflash.languages.base import FunctionInfo, ParentInfo, ReferenceInfo
1582+
from codeflash.languages.registry import get_language_support
15781583

1584+
start_time = time.perf_counter()
15791585

1580-
def _get_python_references(
1581-
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path
1582-
) -> str:
1583-
"""Get function references for Python code using jedi."""
15841586
try:
1587+
# Get the language support
1588+
lang_support = get_language_support(language)
1589+
if lang_support is None:
1590+
return ""
1591+
1592+
# Parse qualified name to get function name and class name
15851593
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
15861594
if len(qualified_name_split) == 1:
1587-
target_function, target_class = qualified_name_split[0], None
1595+
function_name, class_name = qualified_name_split[0], None
15881596
else:
1589-
target_function, target_class = qualified_name_split[1], qualified_name_split[0]
1590-
matches = get_fn_references_jedi(
1591-
source_code, file_path, project_root, target_function, target_class
1592-
) # jedi is not perfect, it doesn't capture aliased references
1593-
return find_occurances(qualified_name, str(file_path), matches, project_root, tests_root)
1597+
function_name, class_name = qualified_name_split[1], qualified_name_split[0]
1598+
1599+
# Create a FunctionInfo for the function
1600+
# We don't have full line info here, so we'll use defaults
1601+
parents = ()
1602+
if class_name:
1603+
parents = (ParentInfo(name=class_name, type="ClassDef"),)
1604+
1605+
func_info = FunctionInfo(
1606+
name=function_name,
1607+
file_path=file_path,
1608+
start_line=1,
1609+
end_line=1,
1610+
parents=parents,
1611+
language=language,
1612+
)
1613+
1614+
# Find references using language support
1615+
references = lang_support.find_references(func_info, project_root, tests_root, max_files=500)
1616+
1617+
if not references:
1618+
return ""
1619+
1620+
# Format references as markdown code blocks
1621+
calling_fns_details = _format_references_as_markdown(
1622+
references, file_path, project_root, language
1623+
)
1624+
15941625
except Exception as e:
1595-
logger.debug(f"Error getting Python references: {e}")
1596-
return ""
1626+
logger.debug(f"Error getting function references: {e}")
1627+
calling_fns_details = ""
15971628

1629+
end_time = time.perf_counter()
1630+
logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
1631+
return calling_fns_details
15981632

1599-
def _get_javascript_references(
1600-
file_path: Path, qualified_name: str, project_root: Path, tests_root: Path
1633+
1634+
def _format_references_as_markdown(
1635+
references: list, file_path: Path, project_root: Path, language: Language
16011636
) -> str:
1602-
"""Get function references for JavaScript/TypeScript code using tree-sitter.
1637+
"""Format references as markdown code blocks with calling function code.
16031638
1604-
This function finds all call sites of a JavaScript/TypeScript function
1605-
across the codebase and formats them for the optimizer's context.
1606-
"""
1607-
try:
1608-
from codeflash.languages.javascript.find_references import ReferenceFinder
1609-
from codeflash.languages.treesitter_utils import get_analyzer_for_file
1639+
Args:
1640+
references: List of ReferenceInfo objects.
1641+
file_path: Path to the source file (to exclude).
1642+
project_root: Root of the project.
1643+
language: The programming language.
16101644
1611-
# Extract function name from qualified name
1612-
# Qualified name could be "functionName" or "ClassName.methodName"
1613-
function_name = qualified_name.rsplit(".", maxsplit=1)[-1]
1645+
Returns:
1646+
Markdown-formatted string.
1647+
"""
1648+
# Group references by file
1649+
refs_by_file: dict[Path, list] = {}
1650+
for ref in references:
1651+
# Exclude the source file's definition/import references
1652+
if ref.file_path == file_path and ref.reference_type in ("import", "reexport"):
1653+
continue
16141654

1615-
finder = ReferenceFinder(project_root)
1616-
references = finder.find_references(function_name, file_path, max_files=500)
1655+
if ref.file_path not in refs_by_file:
1656+
refs_by_file[ref.file_path] = []
1657+
refs_by_file[ref.file_path].append(ref)
16171658

1618-
if not references:
1619-
return ""
1659+
fn_call_context = ""
1660+
context_len = 0
16201661

1621-
# Format references similar to Python format
1622-
fn_call_context = ""
1623-
context_len = 0
1662+
for ref_file, file_refs in refs_by_file.items():
1663+
if context_len > MAX_CONTEXT_LEN_REVIEW:
1664+
break
16241665

1625-
# Group references by file
1626-
refs_by_file: dict[Path, list] = {}
1627-
for ref in references:
1628-
# Exclude test files
1629-
try:
1630-
if ref.file_path.relative_to(tests_root):
1631-
continue
1632-
except ValueError:
1633-
pass
1666+
try:
1667+
path_relative = ref_file.relative_to(project_root)
1668+
except ValueError:
1669+
continue
16341670

1635-
# Exclude the source file's definition
1636-
if ref.file_path == file_path and ref.reference_type == "import":
1637-
continue
1671+
# Get syntax highlighting language
1672+
ext = ref_file.suffix.lstrip(".")
1673+
if language == Language.PYTHON:
1674+
lang_hint = "python"
1675+
elif ext in ("ts", "tsx"):
1676+
lang_hint = "typescript"
1677+
else:
1678+
lang_hint = "javascript"
16381679

1639-
if ref.file_path not in refs_by_file:
1640-
refs_by_file[ref.file_path] = []
1641-
refs_by_file[ref.file_path].append(ref)
1680+
# Read the file to extract calling function context
1681+
try:
1682+
file_content = ref_file.read_text(encoding="utf-8")
1683+
lines = file_content.splitlines()
1684+
except Exception:
1685+
continue
16421686

1643-
for ref_file, file_refs in refs_by_file.items():
1644-
if context_len > MAX_CONTEXT_LEN_REVIEW:
1645-
break
1687+
# Get unique caller functions from this file
1688+
callers_seen: set[str] = set()
1689+
caller_contexts: list[str] = []
16461690

1647-
try:
1648-
path_relative = ref_file.relative_to(project_root)
1649-
except ValueError:
1691+
for ref in file_refs:
1692+
caller = ref.caller_function or "<module>"
1693+
if caller in callers_seen:
16501694
continue
1695+
callers_seen.add(caller)
1696+
1697+
# Extract context around the reference
1698+
if ref.caller_function:
1699+
# Try to extract the full calling function
1700+
func_code = _extract_calling_function(file_content, ref.caller_function, ref.line, language)
1701+
if func_code:
1702+
caller_contexts.append(func_code)
1703+
context_len += len(func_code)
1704+
else:
1705+
# Module-level call - show a few lines of context
1706+
start_line = max(0, ref.line - 3)
1707+
end_line = min(len(lines), ref.line + 2)
1708+
context_code = "\n".join(lines[start_line:end_line])
1709+
caller_contexts.append(context_code)
1710+
context_len += len(context_code)
1711+
1712+
if caller_contexts:
1713+
fn_call_context += f"```{lang_hint}:{path_relative}\n"
1714+
fn_call_context += "\n".join(caller_contexts)
1715+
fn_call_context += "\n```\n"
16511716

1652-
# Get the file extension for syntax highlighting
1653-
ext = ref_file.suffix.lstrip(".")
1654-
lang = "typescript" if ext in ("ts", "tsx") else "javascript"
1717+
return fn_call_context
16551718

1656-
# Read the file to extract calling function context
1657-
try:
1658-
file_content = ref_file.read_text(encoding="utf-8")
1659-
lines = file_content.splitlines()
1660-
except Exception:
1661-
continue
16621719

1663-
# Get unique caller functions from this file
1664-
callers_seen = set()
1665-
caller_contexts = []
1720+
def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None:
1721+
"""Extract the source code of a calling function.
16661722
1667-
for ref in file_refs:
1668-
caller = ref.caller_function or "<module>"
1669-
if caller in callers_seen:
1670-
continue
1671-
callers_seen.add(caller)
1672-
1673-
# Extract context around the reference (the calling function or surrounding lines)
1674-
if ref.caller_function:
1675-
# Try to extract the full calling function
1676-
func_code = _extract_calling_function_js(file_content, ref.caller_function, ref.line)
1677-
if func_code:
1678-
caller_contexts.append(func_code)
1679-
context_len += len(func_code)
1680-
else:
1681-
# Module-level call - just show a few lines of context
1682-
start_line = max(0, ref.line - 3)
1683-
end_line = min(len(lines), ref.line + 2)
1684-
context_code = "\n".join(lines[start_line:end_line])
1685-
caller_contexts.append(context_code)
1686-
context_len += len(context_code)
1723+
Args:
1724+
source_code: Full source code of the file.
1725+
function_name: Name of the function to extract.
1726+
ref_line: Line number where the reference is.
1727+
language: The programming language.
16871728
1688-
if caller_contexts:
1689-
fn_call_context += f"```{lang}:{path_relative}\n"
1690-
fn_call_context += "\n".join(caller_contexts)
1691-
fn_call_context += "\n```\n"
1729+
Returns:
1730+
Source code of the function, or None if not found.
1731+
"""
1732+
if language == Language.PYTHON:
1733+
return _extract_calling_function_python(source_code, function_name, ref_line)
1734+
else:
1735+
return _extract_calling_function_js(source_code, function_name, ref_line)
16921736

1693-
return fn_call_context
16941737

1695-
except Exception as e:
1696-
logger.debug(f"Error getting JavaScript references: {e}")
1697-
return ""
1738+
def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
1739+
"""Extract the source code of a calling function in Python."""
1740+
try:
1741+
import ast
1742+
1743+
tree = ast.parse(source_code)
1744+
lines = source_code.splitlines()
1745+
1746+
for node in ast.walk(tree):
1747+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
1748+
if node.name == function_name:
1749+
# Check if the reference line is within this function
1750+
start_line = node.lineno
1751+
end_line = node.end_lineno or start_line
1752+
if start_line <= ref_line <= end_line:
1753+
return "\n".join(lines[start_line - 1 : end_line])
1754+
return None
1755+
except Exception:
1756+
return None
16981757

16991758

17001759
def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None:

codeflash/languages/base.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,37 @@ class FunctionFilterCriteria:
236236
max_lines: int | None = None
237237

238238

239+
@dataclass
240+
class ReferenceInfo:
241+
"""Information about a reference (call site) to a function.
242+
243+
This class captures information about where a function is called
244+
from, including the file, line number, context, and caller function.
245+
246+
Attributes:
247+
file_path: Path to the file containing the reference.
248+
line: Line number (1-indexed).
249+
column: Column number (0-indexed).
250+
end_line: End line number (1-indexed).
251+
end_column: End column number (0-indexed).
252+
context: The line of code containing the reference.
253+
reference_type: Type of reference ("call", "callback", "memoized", "import", "reexport").
254+
import_name: Name used to import the function (may differ from original).
255+
caller_function: Name of the function containing this reference (or None for module-level).
256+
257+
"""
258+
259+
file_path: Path
260+
line: int
261+
column: int
262+
end_line: int
263+
end_column: int
264+
context: str
265+
reference_type: str
266+
import_name: str | None
267+
caller_function: str | None = None
268+
269+
239270
@runtime_checkable
240271
class LanguageSupport(Protocol):
241272
"""Protocol defining what a language implementation must provide.
@@ -352,6 +383,29 @@ def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> l
352383
"""
353384
...
354385

386+
def find_references(
387+
self, function: FunctionInfo, project_root: Path, tests_root: Path | None = None, max_files: int = 500
388+
) -> list[ReferenceInfo]:
389+
"""Find all references (call sites) to a function across the codebase.
390+
391+
This method finds all places where a function is called, including:
392+
- Direct calls
393+
- Callbacks (passed to other functions)
394+
- Memoized versions
395+
- Re-exports
396+
397+
Args:
398+
function: The function to find references for.
399+
project_root: Root of the project to search.
400+
tests_root: Root of tests directory (references in tests are excluded).
401+
max_files: Maximum number of files to search.
402+
403+
Returns:
404+
List of ReferenceInfo objects describing each reference location.
405+
406+
"""
407+
...
408+
355409
# === Code Transformation ===
356410

357411
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:

0 commit comments

Comments
 (0)