2323 TestResult ,
2424)
2525from codeflash .languages .registry import register_language
26- from codeflash .languages .treesitter_utils import TreeSitterAnalyzer , TreeSitterLanguage , get_analyzer_for_file
26+ from codeflash .languages .treesitter_utils import TreeSitterAnalyzer , TreeSitterLanguage , TypeDefinition , get_analyzer_for_file
2727
2828if TYPE_CHECKING :
2929 from collections .abc import Sequence
@@ -364,11 +364,33 @@ def extract_code_context(self, function: FunctionInfo, project_root: Path, modul
364364 imp_lines = lines [imp .start_line - 1 : imp .end_line ]
365365 import_lines .append ("" .join (imp_lines ).strip ())
366366
367+ # Extract type definitions for function parameters and class fields
368+ type_definitions_context , type_definition_names = self ._extract_type_definitions_context (
369+ function = function ,
370+ source = source ,
371+ analyzer = analyzer ,
372+ imports = imports ,
373+ module_root = module_root ,
374+ )
375+
367376 # Find module-level declarations (global variables/constants) referenced by the function
377+ # Exclude type definitions that are already included above to avoid duplication
368378 read_only_context = self ._find_referenced_globals (
369- target_code = target_code , helpers = helpers , source = source , analyzer = analyzer , imports = imports
379+ target_code = target_code ,
380+ helpers = helpers ,
381+ source = source ,
382+ analyzer = analyzer ,
383+ imports = imports ,
384+ exclude_names = type_definition_names ,
370385 )
371386
387+ # Combine type definitions with other read-only context
388+ if type_definitions_context :
389+ if read_only_context :
390+ read_only_context = type_definitions_context + "\n \n " + read_only_context
391+ else :
392+ read_only_context = type_definitions_context
393+
372394 # Validate that the extracted code is syntactically valid
373395 # If not, raise an error to fail the optimization early
374396 if target_code and not self .validate_syntax (target_code ):
@@ -612,6 +634,7 @@ def _find_referenced_globals(
612634 source : str ,
613635 analyzer : TreeSitterAnalyzer ,
614636 imports : list [Any ],
637+ exclude_names : set [str ] | None = None ,
615638 ) -> str :
616639 """Find module-level declarations referenced by the target function and its helpers.
617640
@@ -621,11 +644,15 @@ def _find_referenced_globals(
621644 source: Full source code of the file.
622645 analyzer: TreeSitterAnalyzer for parsing.
623646 imports: List of ImportInfo objects.
647+ exclude_names: Names to exclude from the result (e.g., type definitions).
624648
625649 Returns:
626650 String containing all referenced global declarations.
627651
628652 """
653+ if exclude_names is None :
654+ exclude_names = set ()
655+
629656 # Find all module-level declarations in the source file
630657 module_declarations = analyzer .find_module_level_declarations (source )
631658
@@ -646,8 +673,8 @@ def _find_referenced_globals(
646673 decl_map : dict [str , Any ] = {}
647674 for decl in module_declarations :
648675 # Skip function declarations (they are handled as helpers)
649- # Also skip if it's an import
650- if decl .name not in imported_names :
676+ # Also skip if it's an import or an excluded name (type definitions)
677+ if decl .name not in imported_names and decl . name not in exclude_names :
651678 decl_map [decl .name ] = decl
652679
653680 if not decl_map :
@@ -687,6 +714,172 @@ def _find_referenced_globals(
687714 global_lines = [decl .source_code for decl in referenced_globals ]
688715 return "\n " .join (global_lines )
689716
717+ def _extract_type_definitions_context (
718+ self ,
719+ function : FunctionInfo ,
720+ source : str ,
721+ analyzer : TreeSitterAnalyzer ,
722+ imports : list [Any ],
723+ module_root : Path ,
724+ ) -> tuple [str , set [str ]]:
725+ """Extract type definitions used by the function for read-only context.
726+
727+ Finds user-defined types referenced in:
728+ 1. Function parameters
729+ 2. Function return type
730+ 3. Class fields (if the function is a class method)
731+
732+ Then looks up these type definitions in:
733+ 1. The same file
734+ 2. Imported files
735+
736+ Args:
737+ function: The target function to analyze.
738+ source: Source code of the file.
739+ analyzer: TreeSitterAnalyzer for parsing.
740+ imports: List of ImportInfo objects.
741+ module_root: Root directory of the module.
742+
743+ Returns:
744+ Tuple of (type definitions string, set of found type names).
745+
746+ """
747+ # Extract type names from function parameters and return type
748+ type_names = analyzer .extract_type_annotations (source , function .name , function .start_line or 1 )
749+
750+ # If this is a class method, also extract types from class fields
751+ if function .is_method and function .parents :
752+ for parent in function .parents :
753+ if parent .type == "ClassDef" :
754+ field_types = analyzer .extract_class_field_types (source , parent .name )
755+ type_names .update (field_types )
756+
757+ if not type_names :
758+ return "" , set ()
759+
760+ # Find type definitions in the same file
761+ same_file_definitions = analyzer .find_type_definitions (source )
762+ found_definitions : list [TypeDefinition ] = []
763+
764+ # Build a map of type name -> definition for same-file types
765+ same_file_type_map = {defn .name : defn for defn in same_file_definitions }
766+
767+ # Track which types we've found (avoid duplicates)
768+ found_type_names : set [str ] = set ()
769+
770+ # First, look for types defined in the same file
771+ for type_name in type_names :
772+ if type_name in same_file_type_map and type_name not in found_type_names :
773+ found_definitions .append (same_file_type_map [type_name ])
774+ found_type_names .add (type_name )
775+
776+ # For types not found in same file, look in imported files
777+ remaining_types = type_names - found_type_names
778+ if remaining_types :
779+ imported_definitions = self ._find_imported_type_definitions (
780+ remaining_types , imports , module_root , function .file_path
781+ )
782+ for defn in imported_definitions :
783+ found_definitions .append (defn )
784+ found_type_names .add (defn .name )
785+
786+ if not found_definitions :
787+ return "" , found_type_names
788+
789+ # Sort by file path and line number for consistent ordering
790+ found_definitions .sort (key = lambda d : (str (d .file_path or "" ), d .start_line ))
791+
792+ # Build the type definitions context string
793+ # Group by file for better organization
794+ type_def_parts : list [str ] = []
795+ current_file : Path | None = None
796+
797+ for defn in found_definitions :
798+ if defn .file_path and defn .file_path != current_file :
799+ current_file = defn .file_path
800+ # Add a comment indicating the source file
801+ type_def_parts .append (f"// From { current_file .name } " )
802+
803+ type_def_parts .append (defn .source_code )
804+
805+ return "\n \n " .join (type_def_parts ), found_type_names
806+
807+ def _find_imported_type_definitions (
808+ self ,
809+ type_names : set [str ],
810+ imports : list [Any ],
811+ module_root : Path ,
812+ source_file_path : Path ,
813+ ) -> list [TypeDefinition ]:
814+ """Find type definitions in imported files.
815+
816+ Args:
817+ type_names: Set of type names to look for.
818+ imports: List of ImportInfo objects from the source file.
819+ module_root: Root directory of the module.
820+ source_file_path: Path to the source file (for resolving relative imports).
821+
822+ Returns:
823+ List of TypeDefinition objects found in imported files.
824+
825+ """
826+ found_definitions : list [TypeDefinition ] = []
827+
828+ # Build a map of type names to their import info and original names
829+ type_import_map : dict [str , tuple [Any , str ]] = {} # local_name -> (ImportInfo, original_name)
830+ for imp in imports :
831+ # Check if any of our type names are imported from this module
832+ for name , alias in imp .named_imports :
833+ # The type could be imported with an alias
834+ local_name = alias if alias else name
835+ if local_name in type_names :
836+ type_import_map [local_name ] = (imp , name ) # (ImportInfo, original_name)
837+
838+ if not type_import_map :
839+ return found_definitions
840+
841+ # Resolve imports and find type definitions
842+ from codeflash .languages .javascript .import_resolver import ImportResolver
843+
844+ try :
845+ import_resolver = ImportResolver (module_root )
846+ except Exception :
847+ logger .debug ("Failed to create ImportResolver for type definition lookup" )
848+ return found_definitions
849+
850+ for local_name , (import_info , original_name ) in type_import_map .items ():
851+ try :
852+ # Resolve the import to an actual file path
853+ resolved_import = import_resolver .resolve_import (import_info , source_file_path )
854+ if not resolved_import or not resolved_import .file_path .exists ():
855+ continue
856+
857+ resolved_path = resolved_import .file_path
858+
859+ # Read the source file and find type definitions
860+ try :
861+ imported_source = resolved_path .read_text (encoding = "utf-8" )
862+ except Exception :
863+ continue
864+
865+ # Get analyzer for the imported file
866+ imported_analyzer = get_analyzer_for_file (resolved_path )
867+ type_defs = imported_analyzer .find_type_definitions (imported_source )
868+
869+ # Find the type we're looking for
870+ for defn in type_defs :
871+ if defn .name == original_name :
872+ # Add file path info to the definition
873+ defn .file_path = resolved_path
874+ found_definitions .append (defn )
875+ break
876+
877+ except Exception as e :
878+ logger .debug (f"Failed to resolve type definition for { local_name } : { e } " )
879+ continue
880+
881+ return found_definitions
882+
690883 def find_helper_functions (self , function : FunctionInfo , project_root : Path ) -> list [HelperFunction ]:
691884 """Find helper functions called by the target function.
692885
0 commit comments