@@ -52,6 +52,7 @@ def build_testgen_context(
5252 * ,
5353 remove_docstrings : bool = False ,
5454 include_enrichment : bool = True ,
55+ function_to_optimize : FunctionToOptimize | None = None ,
5556) -> CodeStringsMarkdown :
5657 testgen_context = extract_code_markdown_context_from_files (
5758 helpers_of_fto_dict ,
@@ -66,6 +67,17 @@ def build_testgen_context(
6667 if enrichment .code_strings :
6768 testgen_context = CodeStringsMarkdown (code_strings = testgen_context .code_strings + enrichment .code_strings )
6869
70+ if function_to_optimize is not None :
71+ result = _parse_and_collect_imports (testgen_context )
72+ existing_classes = collect_existing_class_names (result [0 ]) if result else set ()
73+ constructor_stubs = extract_parameter_type_constructors (
74+ function_to_optimize , project_root_path , existing_classes
75+ )
76+ if constructor_stubs .code_strings :
77+ testgen_context = CodeStringsMarkdown (
78+ code_strings = testgen_context .code_strings + constructor_stubs .code_strings
79+ )
80+
6981 return testgen_context
7082
7183
@@ -156,12 +168,18 @@ def get_code_optimization_context(
156168 read_only_context_code = ""
157169
158170 # Progressive fallback for testgen context token limits
159- testgen_context = build_testgen_context (helpers_of_fto_dict , helpers_of_helpers_dict , project_root_path )
171+ testgen_context = build_testgen_context (
172+ helpers_of_fto_dict , helpers_of_helpers_dict , project_root_path , function_to_optimize = function_to_optimize
173+ )
160174
161175 if encoded_tokens_len (testgen_context .markdown ) > testgen_token_limit :
162176 logger .debug ("Testgen context exceeded token limit, removing docstrings" )
163177 testgen_context = build_testgen_context (
164- helpers_of_fto_dict , helpers_of_helpers_dict , project_root_path , remove_docstrings = True
178+ helpers_of_fto_dict ,
179+ helpers_of_helpers_dict ,
180+ project_root_path ,
181+ remove_docstrings = True ,
182+ function_to_optimize = function_to_optimize ,
165183 )
166184
167185 if encoded_tokens_len (testgen_context .markdown ) > testgen_token_limit :
@@ -627,6 +645,205 @@ def collect_existing_class_names(tree: ast.Module) -> set[str]:
627645 return class_names
628646
629647
648+ BUILTIN_AND_TYPING_NAMES = frozenset (
649+ {
650+ "int" ,
651+ "str" ,
652+ "float" ,
653+ "bool" ,
654+ "bytes" ,
655+ "bytearray" ,
656+ "complex" ,
657+ "list" ,
658+ "dict" ,
659+ "set" ,
660+ "frozenset" ,
661+ "tuple" ,
662+ "type" ,
663+ "object" ,
664+ "None" ,
665+ "NoneType" ,
666+ "Ellipsis" ,
667+ "NotImplemented" ,
668+ "memoryview" ,
669+ "range" ,
670+ "slice" ,
671+ "property" ,
672+ "classmethod" ,
673+ "staticmethod" ,
674+ "super" ,
675+ "Optional" ,
676+ "Union" ,
677+ "Any" ,
678+ "List" ,
679+ "Dict" ,
680+ "Set" ,
681+ "FrozenSet" ,
682+ "Tuple" ,
683+ "Type" ,
684+ "Callable" ,
685+ "Iterator" ,
686+ "Generator" ,
687+ "Coroutine" ,
688+ "AsyncGenerator" ,
689+ "AsyncIterator" ,
690+ "Iterable" ,
691+ "AsyncIterable" ,
692+ "Sequence" ,
693+ "MutableSequence" ,
694+ "Mapping" ,
695+ "MutableMapping" ,
696+ "Collection" ,
697+ "Awaitable" ,
698+ "Literal" ,
699+ "Final" ,
700+ "ClassVar" ,
701+ "TypeVar" ,
702+ "TypeAlias" ,
703+ "ParamSpec" ,
704+ "Concatenate" ,
705+ "Annotated" ,
706+ "TypeGuard" ,
707+ "Self" ,
708+ "Unpack" ,
709+ "TypeVarTuple" ,
710+ "Never" ,
711+ "NoReturn" ,
712+ "SupportsInt" ,
713+ "SupportsFloat" ,
714+ "SupportsComplex" ,
715+ "SupportsBytes" ,
716+ "SupportsAbs" ,
717+ "SupportsRound" ,
718+ "IO" ,
719+ "TextIO" ,
720+ "BinaryIO" ,
721+ "Pattern" ,
722+ "Match" ,
723+ }
724+ )
725+
726+
727+ def collect_type_names_from_annotation (node : ast .expr | None ) -> set [str ]:
728+ if node is None :
729+ return set ()
730+ if isinstance (node , ast .Name ):
731+ return {node .id }
732+ if isinstance (node , ast .Subscript ):
733+ names = collect_type_names_from_annotation (node .value )
734+ names |= collect_type_names_from_annotation (node .slice )
735+ return names
736+ if isinstance (node , ast .BinOp ) and isinstance (node .op , ast .BitOr ):
737+ return collect_type_names_from_annotation (node .left ) | collect_type_names_from_annotation (node .right )
738+ if isinstance (node , ast .Tuple ):
739+ names : set [str ] = set ()
740+ for elt in node .elts :
741+ names |= collect_type_names_from_annotation (elt )
742+ return names
743+ return set ()
744+
745+
746+ def extract_init_stub_from_class (class_name : str , module_source : str , module_tree : ast .Module ) -> str | None :
747+ class_node = None
748+ for node in ast .walk (module_tree ):
749+ if isinstance (node , ast .ClassDef ) and node .name == class_name :
750+ class_node = node
751+ break
752+ if class_node is None :
753+ return None
754+
755+ init_node = None
756+ for item in class_node .body :
757+ if isinstance (item , (ast .FunctionDef , ast .AsyncFunctionDef )) and item .name == "__init__" :
758+ init_node = item
759+ break
760+ if init_node is None :
761+ return None
762+
763+ lines = module_source .splitlines ()
764+ init_source = "\n " .join (lines [init_node .lineno - 1 : init_node .end_lineno ])
765+ return f"class { class_name } :\n { init_source } "
766+
767+
768+ def extract_parameter_type_constructors (
769+ function_to_optimize : FunctionToOptimize , project_root_path : Path , existing_class_names : set [str ]
770+ ) -> CodeStringsMarkdown :
771+ import jedi
772+
773+ try :
774+ source = function_to_optimize .file_path .read_text (encoding = "utf-8" )
775+ tree = ast .parse (source )
776+ except Exception :
777+ return CodeStringsMarkdown (code_strings = [])
778+
779+ func_node = None
780+ for node in ast .walk (tree ):
781+ if (
782+ isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef ))
783+ and node .name == function_to_optimize .function_name
784+ ):
785+ if function_to_optimize .starting_line is not None and node .lineno != function_to_optimize .starting_line :
786+ continue
787+ func_node = node
788+ break
789+ if func_node is None :
790+ return CodeStringsMarkdown (code_strings = [])
791+
792+ type_names : set [str ] = set ()
793+ for arg in func_node .args .args + func_node .args .posonlyargs + func_node .args .kwonlyargs :
794+ type_names |= collect_type_names_from_annotation (arg .annotation )
795+ if func_node .args .vararg :
796+ type_names |= collect_type_names_from_annotation (func_node .args .vararg .annotation )
797+ if func_node .args .kwarg :
798+ type_names |= collect_type_names_from_annotation (func_node .args .kwarg .annotation )
799+
800+ type_names -= BUILTIN_AND_TYPING_NAMES
801+ type_names -= existing_class_names
802+ if not type_names :
803+ return CodeStringsMarkdown (code_strings = [])
804+
805+ import_map : dict [str , str ] = {}
806+ for node in ast .walk (tree ):
807+ if isinstance (node , ast .ImportFrom ) and node .module :
808+ for alias in node .names :
809+ name = alias .asname if alias .asname else alias .name
810+ import_map [name ] = node .module
811+
812+ code_strings : list [CodeString ] = []
813+ module_cache : dict [Path , tuple [str , ast .Module ]] = {}
814+
815+ for type_name in sorted (type_names ):
816+ module_name = import_map .get (type_name )
817+ if not module_name :
818+ continue
819+ try :
820+ script_code = f"from { module_name } import { type_name } "
821+ script = jedi .Script (script_code , project = jedi .Project (path = project_root_path ))
822+ definitions = script .goto (1 , len (f"from { module_name } import " ) + len (type_name ), follow_imports = True )
823+ if not definitions :
824+ continue
825+
826+ module_path = definitions [0 ].module_path
827+ if not module_path :
828+ continue
829+
830+ if module_path in module_cache :
831+ mod_source , mod_tree = module_cache [module_path ]
832+ else :
833+ mod_source = module_path .read_text (encoding = "utf-8" )
834+ mod_tree = ast .parse (mod_source )
835+ module_cache [module_path ] = (mod_source , mod_tree )
836+
837+ stub = extract_init_stub_from_class (type_name , mod_source , mod_tree )
838+ if stub :
839+ code_strings .append (CodeString (code = stub , file_path = module_path ))
840+ except Exception :
841+ logger .debug (f"Error extracting constructor stub for { type_name } from { module_name } " )
842+ continue
843+
844+ return CodeStringsMarkdown (code_strings = code_strings )
845+
846+
630847def enrich_testgen_context (code_context : CodeStringsMarkdown , project_root_path : Path ) -> CodeStringsMarkdown :
631848 import jedi
632849
@@ -852,7 +1069,12 @@ def prune_cst(
8521069 return node , False
8531070
8541071 # Handle dunder methods for READ_ONLY/TESTGEN modes
855- if include_dunder_methods and len (node .name .value ) > 4 and node .name .value .startswith ("__" ) and node .name .value .endswith ("__" ):
1072+ if (
1073+ include_dunder_methods
1074+ and len (node .name .value ) > 4
1075+ and node .name .value .startswith ("__" )
1076+ and node .name .value .endswith ("__" )
1077+ ):
8561078 if not include_init_dunder and node .name .value == "__init__" :
8571079 return None , False
8581080 if remove_docstrings and isinstance (node .body , cst .IndentedBlock ):
0 commit comments