Skip to content

Commit 2367b4c

Browse files
committed
feat: extract parameter type constructor signatures into testgen context
Add an enrichment step that parses the parameter type annotations of the function to optimize (FTO), resolves each annotated type via jedi (following re-exports), and extracts the full `__init__` source so the LLM has constructor context for typed parameters.
1 parent 68c148c commit 2367b4c

2 files changed

Lines changed: 432 additions & 3 deletions

File tree

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 225 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def build_testgen_context(
5252
*,
5353
remove_docstrings: bool = False,
5454
include_enrichment: bool = True,
55+
function_to_optimize: FunctionToOptimize | None = None,
5556
) -> CodeStringsMarkdown:
5657
testgen_context = extract_code_markdown_context_from_files(
5758
helpers_of_fto_dict,
@@ -66,6 +67,17 @@ def build_testgen_context(
6667
if enrichment.code_strings:
6768
testgen_context = CodeStringsMarkdown(code_strings=testgen_context.code_strings + enrichment.code_strings)
6869

70+
if function_to_optimize is not None:
71+
result = _parse_and_collect_imports(testgen_context)
72+
existing_classes = collect_existing_class_names(result[0]) if result else set()
73+
constructor_stubs = extract_parameter_type_constructors(
74+
function_to_optimize, project_root_path, existing_classes
75+
)
76+
if constructor_stubs.code_strings:
77+
testgen_context = CodeStringsMarkdown(
78+
code_strings=testgen_context.code_strings + constructor_stubs.code_strings
79+
)
80+
6981
return testgen_context
7082

7183

@@ -156,12 +168,18 @@ def get_code_optimization_context(
156168
read_only_context_code = ""
157169

158170
# Progressive fallback for testgen context token limits
159-
testgen_context = build_testgen_context(helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path)
171+
testgen_context = build_testgen_context(
172+
helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, function_to_optimize=function_to_optimize
173+
)
160174

161175
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
162176
logger.debug("Testgen context exceeded token limit, removing docstrings")
163177
testgen_context = build_testgen_context(
164-
helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True
178+
helpers_of_fto_dict,
179+
helpers_of_helpers_dict,
180+
project_root_path,
181+
remove_docstrings=True,
182+
function_to_optimize=function_to_optimize,
165183
)
166184

167185
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
@@ -627,6 +645,205 @@ def collect_existing_class_names(tree: ast.Module) -> set[str]:
627645
return class_names
628646

629647

648+
# Identifiers that can appear in a parameter annotation but never need a
# constructor stub: they are subtracted from the collected type names before
# any import resolution is attempted.
BUILTIN_AND_TYPING_NAMES = frozenset(
    {
        # Builtin scalar and container types
        "int", "str", "float", "bool", "bytes", "bytearray", "complex",
        "list", "dict", "set", "frozenset", "tuple", "type", "object",
        # Builtin singletons
        "None", "NoneType", "Ellipsis", "NotImplemented",
        # Other builtin classes
        "memoryview", "range", "slice", "property", "classmethod", "staticmethod", "super",
        # typing: generic aliases of the builtin containers
        "Optional", "Union", "Any", "List", "Dict", "Set", "FrozenSet", "Tuple", "Type",
        # typing: callables, iterators and async protocols
        "Callable", "Iterator", "Generator", "Coroutine", "AsyncGenerator", "AsyncIterator",
        "Iterable", "AsyncIterable", "Awaitable",
        # typing: abstract collections
        "Sequence", "MutableSequence", "Mapping", "MutableMapping", "Collection",
        # typing: special forms and type-variable machinery
        "Literal", "Final", "ClassVar", "TypeVar", "TypeAlias", "ParamSpec", "Concatenate",
        "Annotated", "TypeGuard", "Self", "Unpack", "TypeVarTuple", "Never", "NoReturn",
        # typing: Supports* protocols
        "SupportsInt", "SupportsFloat", "SupportsComplex", "SupportsBytes", "SupportsAbs", "SupportsRound",
        # typing: I/O and regex aliases
        "IO", "TextIO", "BinaryIO", "Pattern", "Match",
    }
)
725+
726+
727+
def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]:
    """Collect the bare identifiers referenced by a type annotation.

    Handles plain names, subscripted generics (``Optional[Foo]``), PEP 604
    unions (``Foo | None``), tuples of arguments, bracketed argument lists
    (``Callable[[Foo], Bar]``) and quoted forward references (``"Foo"``).
    Dotted names (``module.Foo``) contribute nothing.

    Args:
        node: The annotation expression, or ``None`` for an unannotated parameter.

    Returns:
        The set of identifier strings found; empty when there are none.
    """
    names: set[str] = set()
    _collect_annotation_names(node, names, parse_strings=True)
    return names


def _collect_annotation_names(node: ast.expr | None, names: set[str], *, parse_strings: bool) -> None:
    """Add every identifier found under ``node`` to ``names`` (mutated in place).

    ``parse_strings`` controls whether string constants are treated as quoted
    forward references; it is switched off below ``Literal[...]`` and for
    ``Annotated`` metadata, where strings are plain values rather than types.
    """
    if node is None:
        return
    if isinstance(node, ast.Name):
        names.add(node.id)
        return
    if isinstance(node, ast.Constant):
        if parse_strings and isinstance(node.value, str):
            try:
                parsed = ast.parse(node.value.strip(), mode="eval")
            except (SyntaxError, ValueError):
                # Not a valid expression, so it cannot be a forward reference.
                return
            _collect_annotation_names(parsed.body, names, parse_strings=True)
        return
    if isinstance(node, ast.Subscript):
        _collect_annotation_names(node.value, names, parse_strings=parse_strings)
        head = node.value
        if isinstance(head, ast.Name):
            head_name = head.id
        elif isinstance(head, ast.Attribute):
            head_name = head.attr
        else:
            head_name = None
        if head_name == "Literal":
            # Literal arguments are values; a string here is not a type name.
            _collect_annotation_names(node.slice, names, parse_strings=False)
        elif head_name == "Annotated" and isinstance(node.slice, ast.Tuple) and node.slice.elts:
            # Only the first argument is the type; the rest is arbitrary metadata.
            annotated_type, *metadata = node.slice.elts
            _collect_annotation_names(annotated_type, names, parse_strings=parse_strings)
            for item in metadata:
                _collect_annotation_names(item, names, parse_strings=False)
        else:
            _collect_annotation_names(node.slice, names, parse_strings=parse_strings)
        return
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        _collect_annotation_names(node.left, names, parse_strings=parse_strings)
        _collect_annotation_names(node.right, names, parse_strings=parse_strings)
        return
    if isinstance(node, (ast.Tuple, ast.List)):
        for elt in node.elts:
            _collect_annotation_names(elt, names, parse_strings=parse_strings)
744+
745+
746+
def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
    """Build a ``class <name>:`` stub holding only the class's ``__init__`` source.

    The first class named ``class_name`` found while walking ``module_tree`` is
    used, and only ``__init__`` methods defined directly in its body count.

    Args:
        class_name: Name of the class to look for.
        module_source: Source text that ``module_tree`` was parsed from.
        module_tree: Parsed module to search.

    Returns:
        The stub text, or ``None`` when the class or its ``__init__`` is missing.
    """
    matching_class = next(
        (
            candidate
            for candidate in ast.walk(module_tree)
            if isinstance(candidate, ast.ClassDef) and candidate.name == class_name
        ),
        None,
    )
    if matching_class is None:
        return None

    constructor = next(
        (
            member
            for member in matching_class.body
            if isinstance(member, (ast.FunctionDef, ast.AsyncFunctionDef)) and member.name == "__init__"
        ),
        None,
    )
    if constructor is None:
        return None

    # Slice whole source lines so the method keeps its original indentation.
    first_index = constructor.lineno - 1
    constructor_lines = module_source.splitlines()[first_index : constructor.end_lineno]
    return "class " + class_name + ":\n" + "\n".join(constructor_lines)
766+
767+
768+
def extract_parameter_type_constructors(
    function_to_optimize: FunctionToOptimize, project_root_path: Path, existing_class_names: set[str]
) -> CodeStringsMarkdown:
    """Extract ``__init__`` stubs for the classes used as parameter types of a function.

    The function's file is parsed, the identifiers in its parameter annotations
    are collected, and each one that was brought in by a ``from ... import``
    statement is resolved with jedi (following re-exports) to the module that
    defines it. The ``__init__`` source of the class is then wrapped in a
    ``class <name>:`` stub.

    Aliased imports (``from m import Foo as Bar``) are resolved through the
    name defined in the module (``Foo``), and relative imports keep their
    leading dots and are resolved relative to the function's own file.

    Args:
        function_to_optimize: Function whose parameter annotations are inspected.
        project_root_path: Root directory handed to jedi as the project path.
        existing_class_names: Class names already present in the context; no stub
            is produced for these.

    Returns:
        A ``CodeStringsMarkdown`` with one entry per stub; empty when the file
        cannot be read or parsed, the function is not found, or nothing resolves.
    """
    import jedi

    try:
        source = function_to_optimize.file_path.read_text(encoding="utf-8")
        tree = ast.parse(source)
    except Exception:
        return CodeStringsMarkdown(code_strings=[])

    func_node = None
    for node in ast.walk(tree):
        if (
            isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.name == function_to_optimize.function_name
        ):
            # Same-named functions are told apart by their starting line when one is known.
            if function_to_optimize.starting_line is not None and node.lineno != function_to_optimize.starting_line:
                continue
            func_node = node
            break
    if func_node is None:
        return CodeStringsMarkdown(code_strings=[])

    parameters = [*func_node.args.posonlyargs, *func_node.args.args, *func_node.args.kwonlyargs]
    if func_node.args.vararg:
        parameters.append(func_node.args.vararg)
    if func_node.args.kwarg:
        parameters.append(func_node.args.kwarg)

    type_names: set[str] = set()
    for parameter in parameters:
        type_names |= collect_type_names_from_annotation(parameter.annotation)

    type_names -= BUILTIN_AND_TYPING_NAMES
    type_names -= existing_class_names
    if not type_names:
        return CodeStringsMarkdown(code_strings=[])

    # Local name -> (module as written, including leading dots; name defined in that module).
    # Keeping the defined name separately is what makes ``import Foo as Bar`` resolvable.
    import_map: dict[str, tuple[str, str]] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            module_spec = "." * node.level + (node.module or "")
            for alias in node.names:
                import_map[alias.asname or alias.name] = (module_spec, alias.name)

    code_strings: list[CodeString] = []
    module_cache: dict[Path, tuple[str, ast.Module]] = {}
    project = None  # created lazily inside the try so a jedi failure stays best-effort

    for type_name in sorted(type_names):
        target = import_map.get(type_name)
        if target is None:
            continue
        module_name, imported_name = target
        if imported_name in existing_class_names:
            # The class is already in the context under the name it is defined with.
            continue
        try:
            if project is None:
                project = jedi.Project(path=project_root_path)
            script_code = f"from {module_name} import {imported_name}"
            # Relative imports can only be resolved when jedi knows which file they appear in.
            script_path = function_to_optimize.file_path if module_name.startswith(".") else None
            script = jedi.Script(script_code, path=script_path, project=project)
            # Place the cursor at the end of the imported name.
            definitions = script.goto(1, len(script_code), follow_imports=True)
            if not definitions:
                continue

            module_path = definitions[0].module_path
            if not module_path:
                continue

            if module_path in module_cache:
                mod_source, mod_tree = module_cache[module_path]
            else:
                mod_source = module_path.read_text(encoding="utf-8")
                mod_tree = ast.parse(mod_source)
                module_cache[module_path] = (mod_source, mod_tree)

            stub = extract_init_stub_from_class(imported_name, mod_source, mod_tree)
            if stub:
                code_strings.append(CodeString(code=stub, file_path=module_path))
        except Exception:
            # Best-effort enrichment: a type that cannot be resolved is skipped.
            logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
            continue

    return CodeStringsMarkdown(code_strings=code_strings)
845+
846+
630847
def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
631848
import jedi
632849

@@ -852,7 +1069,12 @@ def prune_cst(
8521069
return node, False
8531070

8541071
# Handle dunder methods for READ_ONLY/TESTGEN modes
855-
if include_dunder_methods and len(node.name.value) > 4 and node.name.value.startswith("__") and node.name.value.endswith("__"):
1072+
if (
1073+
include_dunder_methods
1074+
and len(node.name.value) > 4
1075+
and node.name.value.startswith("__")
1076+
and node.name.value.endswith("__")
1077+
):
8561078
if not include_init_dunder and node.name.value == "__init__":
8571079
return None, False
8581080
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):

0 commit comments

Comments
 (0)