@@ -127,20 +127,48 @@ def get_code_optimization_context(
127127 remove_docstrings = False ,
128128 code_context_type = CodeContextType .TESTGEN ,
129129 )
130+
131+ # Extract class definitions for imported types from project modules
132+ # This helps the LLM understand class constructors and structure
133+ imported_class_context = get_imported_class_definitions (testgen_context , project_root_path )
134+ if imported_class_context .code_strings :
135+ # Merge imported class definitions into testgen context
136+ testgen_context = CodeStringsMarkdown (
137+ code_strings = testgen_context .code_strings + imported_class_context .code_strings
138+ )
139+
130140 testgen_markdown_code = testgen_context .markdown
131141 testgen_code_token_length = encoded_tokens_len (testgen_markdown_code )
132142 if testgen_code_token_length > testgen_token_limit :
143+ # First try removing docstrings
133144 testgen_context = extract_code_markdown_context_from_files (
134145 helpers_of_fto_dict ,
135146 helpers_of_helpers_dict ,
136147 project_root_path ,
137148 remove_docstrings = True ,
138149 code_context_type = CodeContextType .TESTGEN ,
139150 )
151+ # Re-extract imported classes (they may still fit)
152+ imported_class_context = get_imported_class_definitions (testgen_context , project_root_path )
153+ if imported_class_context .code_strings :
154+ testgen_context = CodeStringsMarkdown (
155+ code_strings = testgen_context .code_strings + imported_class_context .code_strings
156+ )
140157 testgen_markdown_code = testgen_context .markdown
141158 testgen_code_token_length = encoded_tokens_len (testgen_markdown_code )
142159 if testgen_code_token_length > testgen_token_limit :
143- raise ValueError ("Testgen code context has exceeded token limit, cannot proceed" )
160+ # If still over limit, try without imported class definitions
161+ testgen_context = extract_code_markdown_context_from_files (
162+ helpers_of_fto_dict ,
163+ helpers_of_helpers_dict ,
164+ project_root_path ,
165+ remove_docstrings = True ,
166+ code_context_type = CodeContextType .TESTGEN ,
167+ )
168+ testgen_markdown_code = testgen_context .markdown
169+ testgen_code_token_length = encoded_tokens_len (testgen_markdown_code )
170+ if testgen_code_token_length > testgen_token_limit :
171+ raise ValueError ("Testgen code context has exceeded token limit, cannot proceed" )
144172 code_hash_context = hashing_code_context .markdown
145173 code_hash = hashlib .sha256 (code_hash_context .encode ("utf-8" )).hexdigest ()
146174
@@ -489,6 +517,147 @@ def get_function_sources_from_jedi(
489517 return file_path_to_function_source , function_source_list
490518
491519
520+ def get_imported_class_definitions (code_context : CodeStringsMarkdown , project_root_path : Path ) -> CodeStringsMarkdown :
521+ """Extract class definitions for imported types from project modules.
522+
523+ This function analyzes the imports in the extracted code context and fetches
524+ class definitions for any classes imported from project modules. This helps
525+ the LLM understand the actual class structure (constructors, methods, inheritance)
526+ rather than just seeing import statements.
527+
528+ Args:
529+ code_context: The already extracted code context containing imports
530+ project_root_path: Root path of the project
531+
532+ Returns:
533+ CodeStringsMarkdown containing class definitions from imported project modules
534+
535+ """
536+ import jedi
537+
538+ # Collect all code from the context
539+ all_code = "\n " .join (cs .code for cs in code_context .code_strings )
540+
541+ # Parse to find import statements
542+ try :
543+ tree = ast .parse (all_code )
544+ except SyntaxError :
545+ return CodeStringsMarkdown (code_strings = [])
546+
547+ # Collect imported names and their source modules
548+ imported_names : dict [str , str ] = {} # name -> module_path
549+ for node in ast .walk (tree ):
550+ if isinstance (node , ast .ImportFrom ) and node .module :
551+ for alias in node .names :
552+ if alias .name != "*" :
553+ imported_name = alias .asname if alias .asname else alias .name
554+ imported_names [imported_name ] = node .module
555+
556+ if not imported_names :
557+ return CodeStringsMarkdown (code_strings = [])
558+
559+ # Track which classes we've already extracted to avoid duplicates
560+ extracted_classes : set [tuple [Path , str ]] = set () # (file_path, class_name)
561+
562+ # Also track what's already defined in the context
563+ existing_definitions : set [str ] = set ()
564+ for node in ast .walk (tree ):
565+ if isinstance (node , ast .ClassDef ):
566+ existing_definitions .add (node .name )
567+
568+ class_code_strings : list [CodeString ] = []
569+
570+ for name , module_name in imported_names .items ():
571+ # Skip if already defined in context
572+ if name in existing_definitions :
573+ continue
574+
575+ # Try to find the module file using Jedi
576+ try :
577+ # Create a script that imports the module to resolve it
578+ test_code = f"import { module_name } "
579+ script = jedi .Script (test_code , project = jedi .Project (path = project_root_path ))
580+ completions = script .goto (1 , len (test_code ))
581+
582+ if not completions :
583+ continue
584+
585+ module_path = completions [0 ].module_path
586+ if not module_path :
587+ continue
588+
589+ # Check if this is a project module (not stdlib/third-party)
590+ if not str (module_path ).startswith (str (project_root_path ) + os .sep ):
591+ continue
592+ if path_belongs_to_site_packages (module_path ):
593+ continue
594+
595+ # Skip if we've already extracted this class
596+ if (module_path , name ) in extracted_classes :
597+ continue
598+
599+ # Parse the module to find the class definition
600+ module_source = module_path .read_text (encoding = "utf-8" )
601+ module_tree = ast .parse (module_source )
602+
603+ for node in ast .walk (module_tree ):
604+ if isinstance (node , ast .ClassDef ) and node .name == name :
605+ # Extract the class source code
606+ lines = module_source .split ("\n " )
607+ class_source = "\n " .join (lines [node .lineno - 1 : node .end_lineno ])
608+
609+ # Also extract any necessary imports for the class (base classes, type hints)
610+ class_imports = _extract_imports_for_class (module_tree , node , module_source )
611+
612+ full_source = class_imports + "\n \n " + class_source if class_imports else class_source
613+
614+ class_code_strings .append (CodeString (code = full_source , file_path = module_path ))
615+ extracted_classes .add ((module_path , name ))
616+ break
617+
618+ except Exception :
619+ logger .debug (f"Error extracting class definition for { name } from { module_name } " )
620+ continue
621+
622+ return CodeStringsMarkdown (code_strings = class_code_strings )
623+
624+
625+ def _extract_imports_for_class (module_tree : ast .Module , class_node : ast .ClassDef , module_source : str ) -> str :
626+ """Extract import statements needed for a class definition.
627+
628+ This extracts imports for base classes and commonly used type annotations.
629+ """
630+ needed_names : set [str ] = set ()
631+
632+ # Get base class names
633+ for base in class_node .bases :
634+ if isinstance (base , ast .Name ):
635+ needed_names .add (base .id )
636+ elif isinstance (base , ast .Attribute ) and isinstance (base .value , ast .Name ):
637+ # For things like abc.ABC, we need the module name
638+ needed_names .add (base .value .id )
639+
640+ # Find imports that provide these names
641+ import_lines : list [str ] = []
642+ source_lines = module_source .split ("\n " )
643+
644+ for node in module_tree .body :
645+ if isinstance (node , ast .Import ):
646+ for alias in node .names :
647+ name = alias .asname if alias .asname else alias .name .split ("." )[0 ]
648+ if name in needed_names :
649+ import_lines .append (source_lines [node .lineno - 1 ])
650+ break
651+ elif isinstance (node , ast .ImportFrom ):
652+ for alias in node .names :
653+ name = alias .asname if alias .asname else alias .name
654+ if name in needed_names :
655+ import_lines .append (source_lines [node .lineno - 1 ])
656+ break
657+
658+ return "\n " .join (import_lines )
659+
660+
492661def is_dunder_method (name : str ) -> bool :
493662 return len (name ) > 4 and name .isascii () and name .startswith ("__" ) and name .endswith ("__" )
494663
0 commit comments