@@ -70,6 +70,12 @@ def build_testgen_context(
7070 code_strings = testgen_context .code_strings + external_base_inits .code_strings
7171 )
7272
73+ external_class_inits = get_external_class_inits (testgen_context , project_root_path )
74+ if external_class_inits .code_strings :
75+ testgen_context = CodeStringsMarkdown (
76+ code_strings = testgen_context .code_strings + external_class_inits .code_strings
77+ )
78+
7379 return testgen_context
7480
7581
@@ -821,6 +827,210 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
821827 return CodeStringsMarkdown (code_strings = code_strings )
822828
823829
830+ MAX_TRANSITIVE_DEPTH = 2
831+
832+
833+ def extract_classes_from_type_hint (hint : object ) -> list [type ]:
834+ """Recursively extract concrete class objects from a type annotation.
835+
836+ Unwraps Optional, Union, List, Dict, Callable, Annotated, etc.
837+ Filters out builtins and typing module types.
838+ """
839+ import typing
840+
841+ classes : list [type ] = []
842+ origin = getattr (hint , "__origin__" , None )
843+ args = getattr (hint , "__args__" , None )
844+
845+ if origin is not None and args :
846+ for arg in args :
847+ classes .extend (extract_classes_from_type_hint (arg ))
848+ elif isinstance (hint , type ):
849+ module = getattr (hint , "__module__" , "" )
850+ if module not in ("builtins" , "typing" , "typing_extensions" , "types" ):
851+ classes .append (hint )
852+ # Handle typing.Annotated on older Pythons where __origin__ may not be set
853+ if hasattr (typing , "get_args" ) and origin is None and args is None :
854+ try :
855+ inner_args = typing .get_args (hint )
856+ if inner_args :
857+ for arg in inner_args :
858+ classes .extend (extract_classes_from_type_hint (arg ))
859+ except Exception :
860+ pass
861+
862+ return classes
863+
864+
865+ def resolve_transitive_type_deps (cls : type ) -> list [type ]:
866+ """Find external classes referenced in cls.__init__ type annotations.
867+
868+ Returns classes from site-packages that have a custom __init__.
869+ """
870+ import inspect
871+ import typing
872+
873+ try :
874+ init_method = getattr (cls , "__init__" )
875+ hints = typing .get_type_hints (init_method )
876+ except Exception :
877+ return []
878+
879+ deps : list [type ] = []
880+ for param_name , hint in hints .items ():
881+ if param_name == "return" :
882+ continue
883+ for dep_cls in extract_classes_from_type_hint (hint ):
884+ if dep_cls is cls :
885+ continue
886+ init_method = getattr (dep_cls , "__init__" , None )
887+ if init_method is None or init_method is object .__init__ :
888+ continue
889+ try :
890+ class_file = Path (inspect .getfile (dep_cls ))
891+ except (OSError , TypeError ):
892+ continue
893+ if not path_belongs_to_site_packages (class_file ):
894+ continue
895+ deps .append (dep_cls )
896+
897+ return deps
898+
899+
900+ def extract_init_stub_for_class (cls : type , class_name : str ) -> CodeString | None :
901+ """Extract a stub containing the class definition with only its __init__ method."""
902+ import inspect
903+ import textwrap
904+
905+ init_method = getattr (cls , "__init__" , None )
906+ if init_method is None or init_method is object .__init__ :
907+ return None
908+
909+ try :
910+ class_file = Path (inspect .getfile (cls ))
911+ except (OSError , TypeError ):
912+ return None
913+
914+ if not path_belongs_to_site_packages (class_file ):
915+ return None
916+
917+ try :
918+ init_source = inspect .getsource (init_method )
919+ init_source = textwrap .dedent (init_source )
920+ except (OSError , TypeError ):
921+ return None
922+
923+ parts = class_file .parts
924+ if "site-packages" in parts :
925+ idx = parts .index ("site-packages" )
926+ class_file = Path (* parts [idx + 1 :])
927+
928+ class_source = f"class { class_name } :\n " + textwrap .indent (init_source , " " )
929+ return CodeString (code = class_source , file_path = class_file )
930+
931+
932+ def get_external_class_inits (code_context : CodeStringsMarkdown , project_root_path : Path ) -> CodeStringsMarkdown :
933+ """Extract __init__ methods from directly imported external library classes.
934+
935+ Scans the code context for classes imported from external packages (site-packages) and extracts
936+ their __init__ methods, including transitive type dependencies found in __init__ annotations.
937+ This helps the LLM understand constructor signatures for instantiation in generated tests.
938+ """
939+ import importlib
940+ import inspect
941+
942+ all_code = "\n " .join (cs .code for cs in code_context .code_strings )
943+
944+ try :
945+ tree = ast .parse (all_code )
946+ except SyntaxError :
947+ return CodeStringsMarkdown (code_strings = [])
948+
949+ # Collect all from X import Y statements
950+ imported_names : dict [str , str ] = {}
951+ is_project_cache : dict [str , bool ] = {}
952+
953+ # Track classes already defined in the context to avoid duplicates
954+ existing_classes : set [str ] = set ()
955+
956+ for node in ast .walk (tree ):
957+ if isinstance (node , ast .ImportFrom ) and node .module :
958+ for alias in node .names :
959+ if alias .name != "*" :
960+ imported_name = alias .asname if alias .asname else alias .name
961+ imported_names [imported_name ] = node .module
962+ elif isinstance (node , ast .ClassDef ):
963+ existing_classes .add (node .name )
964+
965+ if not imported_names :
966+ return CodeStringsMarkdown (code_strings = [])
967+
968+ # Filter to external-only imports
969+ external_imports : set [tuple [str , str ]] = set ()
970+ for name , module_name in imported_names .items ():
971+ if name in existing_classes :
972+ continue
973+ cached = is_project_cache .get (module_name )
974+ if cached is None :
975+ is_project = _is_project_module (module_name , project_root_path )
976+ is_project_cache [module_name ] = is_project
977+ else :
978+ is_project = cached
979+ if not is_project :
980+ external_imports .add ((name , module_name ))
981+
982+ if not external_imports :
983+ return CodeStringsMarkdown (code_strings = [])
984+
985+ code_strings : list [CodeString ] = []
986+ imported_module_cache : dict [str , object ] = {}
987+ processed_classes : set [type ] = set ()
988+ emitted_names : set [str ] = set ()
989+
990+ # BFS worklist: (class_object, class_name, depth)
991+ worklist : list [tuple [type , str , int ]] = []
992+
993+ # Seed the worklist with directly imported classes
994+ for class_name , module_name in external_imports :
995+ try :
996+ module = imported_module_cache .get (module_name )
997+ if module is None :
998+ module = importlib .import_module (module_name )
999+ imported_module_cache [module_name ] = module
1000+
1001+ cls = getattr (module , class_name , None )
1002+ if cls is None or not inspect .isclass (cls ):
1003+ continue
1004+
1005+ worklist .append ((cls , class_name , 0 ))
1006+ except (ImportError , ModuleNotFoundError , AttributeError ):
1007+ logger .debug (f"Failed to import { module_name } .{ class_name } " )
1008+ continue
1009+
1010+ while worklist :
1011+ cls , class_name , depth = worklist .pop (0 )
1012+
1013+ if cls in processed_classes :
1014+ continue
1015+ processed_classes .add (cls )
1016+
1017+ stub = extract_init_stub_for_class (cls , class_name )
1018+ if stub is None :
1019+ continue
1020+
1021+ if class_name not in emitted_names :
1022+ code_strings .append (stub )
1023+ emitted_names .add (class_name )
1024+
1025+ # Resolve transitive type dependencies up to MAX_TRANSITIVE_DEPTH
1026+ if depth < MAX_TRANSITIVE_DEPTH :
1027+ for dep_cls in resolve_transitive_type_deps (cls ):
1028+ if dep_cls not in processed_classes :
1029+ worklist .append ((dep_cls , dep_cls .__name__ , depth + 1 ))
1030+
1031+ return CodeStringsMarkdown (code_strings = code_strings )
1032+
1033+
8241034def _is_project_module (module_name : str , project_root_path : Path ) -> bool :
8251035 """Check if a module is part of the project (not external/stdlib)."""
8261036 import importlib .util
0 commit comments