Skip to content

Commit c9cd7bc

Browse files
committed
extract function argument types and interfaces in read-only context
1 parent bf765c2 commit c9cd7bc

3 files changed

Lines changed: 893 additions & 4 deletions

File tree

codeflash/languages/javascript/support.py

Lines changed: 197 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
TestResult,
2424
)
2525
from codeflash.languages.registry import register_language
26-
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
26+
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, TypeDefinition, get_analyzer_for_file
2727

2828
if TYPE_CHECKING:
2929
from collections.abc import Sequence
@@ -364,11 +364,33 @@ def extract_code_context(self, function: FunctionInfo, project_root: Path, modul
364364
imp_lines = lines[imp.start_line - 1 : imp.end_line]
365365
import_lines.append("".join(imp_lines).strip())
366366

367+
# Extract type definitions for function parameters and class fields
368+
type_definitions_context, type_definition_names = self._extract_type_definitions_context(
369+
function=function,
370+
source=source,
371+
analyzer=analyzer,
372+
imports=imports,
373+
module_root=module_root,
374+
)
375+
367376
# Find module-level declarations (global variables/constants) referenced by the function
377+
# Exclude type definitions that are already included above to avoid duplication
368378
read_only_context = self._find_referenced_globals(
369-
target_code=target_code, helpers=helpers, source=source, analyzer=analyzer, imports=imports
379+
target_code=target_code,
380+
helpers=helpers,
381+
source=source,
382+
analyzer=analyzer,
383+
imports=imports,
384+
exclude_names=type_definition_names,
370385
)
371386

387+
# Combine type definitions with other read-only context
388+
if type_definitions_context:
389+
if read_only_context:
390+
read_only_context = type_definitions_context + "\n\n" + read_only_context
391+
else:
392+
read_only_context = type_definitions_context
393+
372394
# Validate that the extracted code is syntactically valid
373395
# If not, raise an error to fail the optimization early
374396
if target_code and not self.validate_syntax(target_code):
@@ -612,6 +634,7 @@ def _find_referenced_globals(
612634
source: str,
613635
analyzer: TreeSitterAnalyzer,
614636
imports: list[Any],
637+
exclude_names: set[str] | None = None,
615638
) -> str:
616639
"""Find module-level declarations referenced by the target function and its helpers.
617640
@@ -621,11 +644,15 @@ def _find_referenced_globals(
621644
source: Full source code of the file.
622645
analyzer: TreeSitterAnalyzer for parsing.
623646
imports: List of ImportInfo objects.
647+
exclude_names: Names to exclude from the result (e.g., type definitions).
624648
625649
Returns:
626650
String containing all referenced global declarations.
627651
628652
"""
653+
if exclude_names is None:
654+
exclude_names = set()
655+
629656
# Find all module-level declarations in the source file
630657
module_declarations = analyzer.find_module_level_declarations(source)
631658

@@ -646,8 +673,8 @@ def _find_referenced_globals(
646673
decl_map: dict[str, Any] = {}
647674
for decl in module_declarations:
648675
# Skip function declarations (they are handled as helpers)
649-
# Also skip if it's an import
650-
if decl.name not in imported_names:
676+
# Also skip if it's an import or an excluded name (type definitions)
677+
if decl.name not in imported_names and decl.name not in exclude_names:
651678
decl_map[decl.name] = decl
652679

653680
if not decl_map:
@@ -687,6 +714,172 @@ def _find_referenced_globals(
687714
global_lines = [decl.source_code for decl in referenced_globals]
688715
return "\n".join(global_lines)
689716

717+
def _extract_type_definitions_context(
    self,
    function: FunctionInfo,
    source: str,
    analyzer: TreeSitterAnalyzer,
    imports: list[Any],
    module_root: Path,
) -> tuple[str, set[str]]:
    """Collect the source of user-defined types the target function depends on.

    Type names are gathered from the function's annotations (via
    ``analyzer.extract_type_annotations``) and, for methods, from the fields of
    every enclosing ``ClassDef`` parent. Each name is looked up first among the
    type definitions of the current file, and any name still unresolved is
    handed to ``_find_imported_type_definitions``.

    Args:
        function: The target function to analyze.
        source: Source code of the file containing the function.
        analyzer: TreeSitterAnalyzer used to parse ``source``.
        imports: List of ImportInfo objects from the source file.
        module_root: Root directory of the module.

    Returns:
        Tuple of (type definitions joined by blank lines, names of the types
        that were found). Definitions that carry a ``file_path`` are preceded
        by a ``// From <file name>`` comment whenever the file changes.

    """
    wanted = analyzer.extract_type_annotations(source, function.name, function.start_line or 1)

    # Methods also depend on the types of the fields declared on their class.
    if function.is_method and function.parents:
        for enclosing in function.parents:
            if enclosing.type != "ClassDef":
                continue
            wanted.update(analyzer.extract_class_field_types(source, enclosing.name))

    if not wanted:
        return "", set()

    # Same-file definitions take precedence over imported ones.
    local_by_name = {definition.name: definition for definition in analyzer.find_type_definitions(source)}
    collected: list[TypeDefinition] = [local_by_name[name] for name in wanted if name in local_by_name]
    resolved_names: set[str] = {name for name in wanted if name in local_by_name}

    # Whatever is left may come from an imported module.
    unresolved = wanted - resolved_names
    if unresolved:
        for imported in self._find_imported_type_definitions(unresolved, imports, module_root, function.file_path):
            collected.append(imported)
            resolved_names.add(imported.name)

    if not collected:
        return "", resolved_names

    # Order by (file path, line) so the generated context is deterministic;
    # definitions without a file path sort first.
    ordered = sorted(collected, key=lambda d: (str(d.file_path or ""), d.start_line))

    pieces: list[str] = []
    last_file: Path | None = None
    for definition in ordered:
        origin = definition.file_path
        if origin and origin != last_file:
            last_file = origin
            # Emit one header per run of definitions coming from the same file.
            pieces.append(f"// From {last_file.name}")
        pieces.append(definition.source_code)

    return "\n\n".join(pieces), resolved_names
806+
807+
def _find_imported_type_definitions(
    self,
    type_names: set[str],
    imports: list[Any],
    module_root: Path,
    source_file_path: Path,
) -> list[TypeDefinition]:
    """Find type definitions in imported files.

    Only named imports are considered. A type imported under an alias is
    matched by its local (aliased) name against ``type_names`` and then looked
    up by its original name in the imported file. The lookup is best-effort:
    any failure for one type is logged at debug level and that type is skipped.

    Args:
        type_names: Set of type names to look for.
        imports: List of ImportInfo objects from the source file.
        module_root: Root directory of the module.
        source_file_path: Path to the source file (for resolving relative imports).

    Returns:
        List of TypeDefinition objects found in imported files, each with its
        ``file_path`` set to the file it was read from.

    """
    found_definitions: list[TypeDefinition] = []

    # local_name -> (ImportInfo, original_name)
    type_import_map: dict[str, tuple[Any, str]] = {}
    for imp in imports:
        for name, alias in imp.named_imports:
            # The type could be imported with an alias
            local_name = alias if alias else name
            if local_name in type_names:
                type_import_map[local_name] = (imp, name)

    if not type_import_map:
        return found_definitions

    # Imported lazily, only once we know there is something to resolve.
    from codeflash.languages.javascript.import_resolver import ImportResolver

    try:
        import_resolver = ImportResolver(module_root)
    except Exception:
        logger.debug("Failed to create ImportResolver for type definition lookup")
        return found_definitions

    # Several types are commonly imported from the same module; read and parse
    # each imported file at most once. Maps resolved path -> {type name -> definition}.
    definitions_by_file: dict[Path, dict[str, TypeDefinition]] = {}

    for local_name, (import_info, original_name) in type_import_map.items():
        try:
            # Resolve the import to an actual file path
            resolved_import = import_resolver.resolve_import(import_info, source_file_path)
            if not resolved_import or not resolved_import.file_path.exists():
                continue

            resolved_path = resolved_import.file_path

            if resolved_path not in definitions_by_file:
                try:
                    imported_source = resolved_path.read_text(encoding="utf-8")
                except (OSError, UnicodeError) as e:
                    logger.debug(f"Failed to read {resolved_path} for type definition lookup: {e}")
                    # Remember the failure so the file is not re-read for its other types.
                    definitions_by_file[resolved_path] = {}
                    continue

                imported_analyzer = get_analyzer_for_file(resolved_path)
                by_name: dict[str, TypeDefinition] = {}
                for defn in imported_analyzer.find_type_definitions(imported_source):
                    # Keep the first definition for each name.
                    by_name.setdefault(defn.name, defn)
                definitions_by_file[resolved_path] = by_name

            defn = definitions_by_file[resolved_path].get(original_name)
            if defn is not None:
                # Add file path info to the definition
                defn.file_path = resolved_path
                found_definitions.append(defn)

        except Exception as e:
            # Best-effort: one unresolvable type must not abort the others.
            logger.debug(f"Failed to resolve type definition for {local_name}: {e}")
            continue

    return found_definitions
882+
690883
def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]:
691884
"""Find helper functions called by the target function.
692885

0 commit comments

Comments
 (0)