@@ -1563,138 +1563,197 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
15631563def get_opt_review_metrics (
15641564 source_code : str , file_path : Path , qualified_name : str , project_root : Path , tests_root : Path , language : Language
15651565) -> str :
1566- start_time = time . perf_counter ()
1566+ """Get function reference metrics for optimization review.
15671567
1568- if language == Language .PYTHON :
1569- calling_fns_details = _get_python_references (source_code , file_path , qualified_name , project_root , tests_root )
1570- elif language in (Language .JAVASCRIPT , Language .TYPESCRIPT ):
1571- calling_fns_details = _get_javascript_references (file_path , qualified_name , project_root , tests_root )
1572- else :
1573- calling_fns_details = ""
1568+ Uses the LanguageSupport abstraction to find references, supporting both Python and JavaScript/TypeScript.
15741569
1575- end_time = time .perf_counter ()
1576- logger .debug (f"Got function references in { end_time - start_time :.2f} seconds" )
1577- return calling_fns_details
1570+ Args:
1571+ source_code: Source code of the file containing the function.
1572+ file_path: Path to the file.
1573+ qualified_name: Qualified name of the function (e.g., "module.ClassName.method").
1574+ project_root: Root of the project.
1575+ tests_root: Root of the tests directory.
1576+ language: The programming language.
1577+
1578+ Returns:
1579+ Markdown-formatted string with code blocks showing calling functions.
1580+ """
1581+ from codeflash .languages .base import FunctionInfo , ParentInfo , ReferenceInfo
1582+ from codeflash .languages .registry import get_language_support
15781583
1584+ start_time = time .perf_counter ()
15791585
1580- def _get_python_references (
1581- source_code : str , file_path : Path , qualified_name : str , project_root : Path , tests_root : Path
1582- ) -> str :
1583- """Get function references for Python code using jedi."""
15841586 try :
1587+ # Get the language support
1588+ lang_support = get_language_support (language )
1589+ if lang_support is None :
1590+ return ""
1591+
1592+ # Parse qualified name to get function name and class name
15851593 qualified_name_split = qualified_name .rsplit ("." , maxsplit = 1 )
15861594 if len (qualified_name_split ) == 1 :
1587- target_function , target_class = qualified_name_split [0 ], None
1595+ function_name , class_name = qualified_name_split [0 ], None
15881596 else :
1589- target_function , target_class = qualified_name_split [1 ], qualified_name_split [0 ]
1590- matches = get_fn_references_jedi (
1591- source_code , file_path , project_root , target_function , target_class
1592- ) # jedi is not perfect, it doesn't capture aliased references
1593- return find_occurances (qualified_name , str (file_path ), matches , project_root , tests_root )
1597+ function_name , class_name = qualified_name_split [1 ], qualified_name_split [0 ]
1598+
1599+ # Create a FunctionInfo for the function
1600+ # We don't have full line info here, so we'll use defaults
1601+ parents = ()
1602+ if class_name :
1603+ parents = (ParentInfo (name = class_name , type = "ClassDef" ),)
1604+
1605+ func_info = FunctionInfo (
1606+ name = function_name ,
1607+ file_path = file_path ,
1608+ start_line = 1 ,
1609+ end_line = 1 ,
1610+ parents = parents ,
1611+ language = language ,
1612+ )
1613+
1614+ # Find references using language support
1615+ references = lang_support .find_references (func_info , project_root , tests_root , max_files = 500 )
1616+
1617+ if not references :
1618+ return ""
1619+
1620+ # Format references as markdown code blocks
1621+ calling_fns_details = _format_references_as_markdown (
1622+ references , file_path , project_root , language
1623+ )
1624+
15941625 except Exception as e :
1595- logger .debug (f"Error getting Python references: { e } " )
1596- return ""
1626+ logger .debug (f"Error getting function references: { e } " )
1627+ calling_fns_details = ""
15971628
1629+ end_time = time .perf_counter ()
1630+ logger .debug (f"Got function references in { end_time - start_time :.2f} seconds" )
1631+ return calling_fns_details
15981632
1599- def _get_javascript_references (
1600- file_path : Path , qualified_name : str , project_root : Path , tests_root : Path
1633+
1634+ def _format_references_as_markdown (
1635+ references : list , file_path : Path , project_root : Path , language : Language
16011636) -> str :
1602- """Get function references for JavaScript/TypeScript code using tree-sitter .
1637+ """Format references as markdown code blocks with calling function code .
16031638
1604- This function finds all call sites of a JavaScript/TypeScript function
1605- across the codebase and formats them for the optimizer's context.
1606- """
1607- try :
1608- from codeflash .languages .javascript .find_references import ReferenceFinder
1609- from codeflash .languages .treesitter_utils import get_analyzer_for_file
1639+ Args:
1640+ references: List of ReferenceInfo objects.
1641+ file_path: Path to the source file (to exclude).
1642+ project_root: Root of the project.
1643+ language: The programming language.
16101644
1611- # Extract function name from qualified name
1612- # Qualified name could be "functionName" or "ClassName.methodName"
1613- function_name = qualified_name .rsplit ("." , maxsplit = 1 )[- 1 ]
1645+ Returns:
1646+ Markdown-formatted string.
1647+ """
1648+ # Group references by file
1649+ refs_by_file : dict [Path , list ] = {}
1650+ for ref in references :
1651+ # Exclude the source file's definition/import references
1652+ if ref .file_path == file_path and ref .reference_type in ("import" , "reexport" ):
1653+ continue
16141654
1615- finder = ReferenceFinder (project_root )
1616- references = finder .find_references (function_name , file_path , max_files = 500 )
1655+ if ref .file_path not in refs_by_file :
1656+ refs_by_file [ref .file_path ] = []
1657+ refs_by_file [ref .file_path ].append (ref )
16171658
1618- if not references :
1619- return ""
1659+ fn_call_context = ""
1660+ context_len = 0
16201661
1621- # Format references similar to Python format
1622- fn_call_context = ""
1623- context_len = 0
1662+ for ref_file , file_refs in refs_by_file . items ():
1663+ if context_len > MAX_CONTEXT_LEN_REVIEW :
1664+ break
16241665
1625- # Group references by file
1626- refs_by_file : dict [Path , list ] = {}
1627- for ref in references :
1628- # Exclude test files
1629- try :
1630- if ref .file_path .relative_to (tests_root ):
1631- continue
1632- except ValueError :
1633- pass
1666+ try :
1667+ path_relative = ref_file .relative_to (project_root )
1668+ except ValueError :
1669+ continue
16341670
1635- # Exclude the source file's definition
1636- if ref .file_path == file_path and ref .reference_type == "import" :
1637- continue
1671+ # Get syntax highlighting language
1672+ ext = ref_file .suffix .lstrip ("." )
1673+ if language == Language .PYTHON :
1674+ lang_hint = "python"
1675+ elif ext in ("ts" , "tsx" ):
1676+ lang_hint = "typescript"
1677+ else :
1678+ lang_hint = "javascript"
16381679
1639- if ref .file_path not in refs_by_file :
1640- refs_by_file [ref .file_path ] = []
1641- refs_by_file [ref .file_path ].append (ref )
1680+ # Read the file to extract calling function context
1681+ try :
1682+ file_content = ref_file .read_text (encoding = "utf-8" )
1683+ lines = file_content .splitlines ()
1684+ except Exception :
1685+ continue
16421686
1643- for ref_file , file_refs in refs_by_file . items ():
1644- if context_len > MAX_CONTEXT_LEN_REVIEW :
1645- break
1687+ # Get unique caller functions from this file
1688+ callers_seen : set [ str ] = set ()
1689+ caller_contexts : list [ str ] = []
16461690
1647- try :
1648- path_relative = ref_file . relative_to ( project_root )
1649- except ValueError :
1691+ for ref in file_refs :
1692+ caller = ref . caller_function or "<module>"
1693+ if caller in callers_seen :
16501694 continue
1695+ callers_seen .add (caller )
1696+
1697+ # Extract context around the reference
1698+ if ref .caller_function :
1699+ # Try to extract the full calling function
1700+ func_code = _extract_calling_function (file_content , ref .caller_function , ref .line , language )
1701+ if func_code :
1702+ caller_contexts .append (func_code )
1703+ context_len += len (func_code )
1704+ else :
1705+ # Module-level call - show a few lines of context
1706+ start_line = max (0 , ref .line - 3 )
1707+ end_line = min (len (lines ), ref .line + 2 )
1708+ context_code = "\n " .join (lines [start_line :end_line ])
1709+ caller_contexts .append (context_code )
1710+ context_len += len (context_code )
1711+
1712+ if caller_contexts :
1713+ fn_call_context += f"```{ lang_hint } :{ path_relative } \n "
1714+ fn_call_context += "\n " .join (caller_contexts )
1715+ fn_call_context += "\n ```\n "
16511716
1652- # Get the file extension for syntax highlighting
1653- ext = ref_file .suffix .lstrip ("." )
1654- lang = "typescript" if ext in ("ts" , "tsx" ) else "javascript"
1717+ return fn_call_context
16551718
1656- # Read the file to extract calling function context
1657- try :
1658- file_content = ref_file .read_text (encoding = "utf-8" )
1659- lines = file_content .splitlines ()
1660- except Exception :
1661- continue
16621719
1663- # Get unique caller functions from this file
1664- callers_seen = set ()
1665- caller_contexts = []
1720+ def _extract_calling_function (source_code : str , function_name : str , ref_line : int , language : Language ) -> str | None :
1721+ """Extract the source code of a calling function.
16661722
1667- for ref in file_refs :
1668- caller = ref .caller_function or "<module>"
1669- if caller in callers_seen :
1670- continue
1671- callers_seen .add (caller )
1672-
1673- # Extract context around the reference (the calling function or surrounding lines)
1674- if ref .caller_function :
1675- # Try to extract the full calling function
1676- func_code = _extract_calling_function_js (file_content , ref .caller_function , ref .line )
1677- if func_code :
1678- caller_contexts .append (func_code )
1679- context_len += len (func_code )
1680- else :
1681- # Module-level call - just show a few lines of context
1682- start_line = max (0 , ref .line - 3 )
1683- end_line = min (len (lines ), ref .line + 2 )
1684- context_code = "\n " .join (lines [start_line :end_line ])
1685- caller_contexts .append (context_code )
1686- context_len += len (context_code )
1723+ Args:
1724+ source_code: Full source code of the file.
1725+ function_name: Name of the function to extract.
1726+ ref_line: Line number where the reference is.
1727+ language: The programming language.
16871728
1688- if caller_contexts :
1689- fn_call_context += f"```{ lang } :{ path_relative } \n "
1690- fn_call_context += "\n " .join (caller_contexts )
1691- fn_call_context += "\n ```\n "
1729+ Returns:
1730+ Source code of the function, or None if not found.
1731+ """
1732+ if language == Language .PYTHON :
1733+ return _extract_calling_function_python (source_code , function_name , ref_line )
1734+ else :
1735+ return _extract_calling_function_js (source_code , function_name , ref_line )
16921736
1693- return fn_call_context
16941737
1695- except Exception as e :
1696- logger .debug (f"Error getting JavaScript references: { e } " )
1697- return ""
1738+ def _extract_calling_function_python (source_code : str , function_name : str , ref_line : int ) -> str | None :
1739+ """Extract the source code of a calling function in Python."""
1740+ try :
1741+ import ast
1742+
1743+ tree = ast .parse (source_code )
1744+ lines = source_code .splitlines ()
1745+
1746+ for node in ast .walk (tree ):
1747+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef )):
1748+ if node .name == function_name :
1749+ # Check if the reference line is within this function
1750+ start_line = node .lineno
1751+ end_line = node .end_lineno or start_line
1752+ if start_line <= ref_line <= end_line :
1753+ return "\n " .join (lines [start_line - 1 : end_line ])
1754+ return None
1755+ except Exception :
1756+ return None
16981757
16991758
17001759def _extract_calling_function_js (source_code : str , function_name : str , ref_line : int ) -> str | None :
0 commit comments