Skip to content

Commit d480061

Browse files
committed
temp
1 parent 26989b2 commit d480061

2 files changed

Lines changed: 172 additions & 8 deletions

File tree

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -752,17 +752,30 @@ def extract_init_stub_from_class(class_name: str, module_source: str, module_tre
752752
if class_node is None:
753753
return None
754754

755-
init_node = None
755+
lines = module_source.splitlines()
756+
relevant_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
756757
for item in class_node.body:
757-
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__":
758-
init_node = item
759-
break
760-
if init_node is None:
758+
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
759+
if item.name in ("__init__", "__post_init__"):
760+
relevant_nodes.append(item)
761+
elif any(
762+
isinstance(d, ast.Name) and d.id == "property"
763+
or isinstance(d, ast.Attribute) and d.attr == "property"
764+
for d in item.decorator_list
765+
):
766+
relevant_nodes.append(item)
767+
768+
if not relevant_nodes:
761769
return None
762770

763-
lines = module_source.splitlines()
764-
init_source = "\n".join(lines[init_node.lineno - 1 : init_node.end_lineno])
765-
return f"class {class_name}:\n{init_source}"
771+
snippets: list[str] = []
772+
for node in relevant_nodes:
773+
start = node.lineno
774+
if node.decorator_list:
775+
start = min(d.lineno for d in node.decorator_list)
776+
snippets.append("\n".join(lines[start - 1 : node.end_lineno]))
777+
778+
return f"class {class_name}:\n" + "\n".join(snippets)
766779

767780

768781
def extract_parameter_type_constructors(
@@ -844,6 +857,27 @@ def extract_parameter_type_constructors(
844857
return CodeStringsMarkdown(code_strings=code_strings)
845858

846859

860+
def resolve_instance_class_name(name: str, module_tree: ast.Module) -> str | None:
861+
for node in module_tree.body:
862+
if isinstance(node, ast.Assign):
863+
for target in node.targets:
864+
if isinstance(target, ast.Name) and target.id == name:
865+
value = node.value
866+
if isinstance(value, ast.Call):
867+
func = value.func
868+
if isinstance(func, ast.Name):
869+
return func.id
870+
if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
871+
return func.value.id
872+
elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == name:
873+
ann = node.annotation
874+
if isinstance(ann, ast.Name):
875+
return ann.id
876+
if isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name):
877+
return ann.value.id
878+
return None
879+
880+
847881
def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
848882
import jedi
849883

@@ -938,6 +972,11 @@ def extract_class_and_bases(
938972

939973
extract_class_and_bases(name, module_path, module_source, module_tree)
940974

975+
if (module_path, name) not in extracted_classes:
976+
resolved_class = resolve_instance_class_name(name, module_tree)
977+
if resolved_class and resolved_class not in existing_classes:
978+
extract_class_and_bases(resolved_class, module_path, module_source, module_tree)
979+
941980
except Exception:
942981
logger.debug(f"Error extracting class definition for {name} from {module_name}")
943982
continue

tests/test_code_context_extractor.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
extract_init_stub_from_class,
1919
extract_parameter_type_constructors,
2020
get_code_optimization_context,
21+
resolve_instance_class_name,
2122
)
2223
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
2324
from codeflash.optimization.optimizer import Optimizer
@@ -4339,3 +4340,127 @@ def process(c: Config) -> str:
43394340
)
43404341
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
43414342
assert len(result.code_strings) == 0
4343+
4344+
4345+
# --- Tests for resolve_instance_class_name ---
4346+
4347+
4348+
def test_resolve_instance_class_name_direct_call() -> None:
4349+
source = "config = MyConfig(debug=True)"
4350+
tree = ast.parse(source)
4351+
assert resolve_instance_class_name("config", tree) == "MyConfig"
4352+
4353+
4354+
def test_resolve_instance_class_name_annotated() -> None:
4355+
source = "config: MyConfig = load()"
4356+
tree = ast.parse(source)
4357+
assert resolve_instance_class_name("config", tree) == "MyConfig"
4358+
4359+
4360+
def test_resolve_instance_class_name_factory_method() -> None:
4361+
source = "config = MyConfig.from_env()"
4362+
tree = ast.parse(source)
4363+
assert resolve_instance_class_name("config", tree) == "MyConfig"
4364+
4365+
4366+
def test_resolve_instance_class_name_no_match() -> None:
4367+
source = "x = 42"
4368+
tree = ast.parse(source)
4369+
assert resolve_instance_class_name("x", tree) is None
4370+
4371+
4372+
def test_resolve_instance_class_name_missing_variable() -> None:
4373+
source = "config = MyConfig()"
4374+
tree = ast.parse(source)
4375+
assert resolve_instance_class_name("other", tree) is None
4376+
4377+
4378+
# --- Tests for enhanced extract_init_stub_from_class ---
4379+
4380+
4381+
def test_extract_init_stub_includes_post_init() -> None:
4382+
source = """\
4383+
class MyDataclass:
4384+
def __init__(self, x: int):
4385+
self.x = x
4386+
def __post_init__(self):
4387+
self.y = self.x * 2
4388+
"""
4389+
tree = ast.parse(source)
4390+
stub = extract_init_stub_from_class("MyDataclass", source, tree)
4391+
assert stub is not None
4392+
assert "class MyDataclass:" in stub
4393+
assert "def __init__" in stub
4394+
assert "def __post_init__" in stub
4395+
assert "self.y = self.x * 2" in stub
4396+
4397+
4398+
def test_extract_init_stub_includes_properties() -> None:
4399+
source = """\
4400+
class MyClass:
4401+
def __init__(self, name: str):
4402+
self._name = name
4403+
@property
4404+
def name(self) -> str:
4405+
return self._name
4406+
"""
4407+
tree = ast.parse(source)
4408+
stub = extract_init_stub_from_class("MyClass", source, tree)
4409+
assert stub is not None
4410+
assert "def __init__" in stub
4411+
assert "@property" in stub
4412+
assert "def name" in stub
4413+
4414+
4415+
def test_extract_init_stub_property_only_class() -> None:
4416+
source = """\
4417+
class ReadOnly:
4418+
@property
4419+
def value(self) -> int:
4420+
return 42
4421+
"""
4422+
tree = ast.parse(source)
4423+
stub = extract_init_stub_from_class("ReadOnly", source, tree)
4424+
assert stub is not None
4425+
assert "class ReadOnly:" in stub
4426+
assert "@property" in stub
4427+
assert "def value" in stub
4428+
4429+
4430+
# --- Tests for enrich_testgen_context resolving instances ---
4431+
4432+
4433+
def test_enrich_testgen_context_resolves_instance_to_class(tmp_path: Path) -> None:
4434+
package_dir = tmp_path / "mypkg"
4435+
package_dir.mkdir()
4436+
(package_dir / "__init__.py").write_text("", encoding="utf-8")
4437+
4438+
config_module = """\
4439+
class AppConfig:
4440+
def __init__(self, debug: bool = False):
4441+
self.debug = debug
4442+
4443+
@property
4444+
def log_level(self) -> str:
4445+
return "DEBUG" if self.debug else "INFO"
4446+
4447+
app_config = AppConfig(debug=True)
4448+
"""
4449+
(package_dir / "config.py").write_text(config_module, encoding="utf-8")
4450+
4451+
consumer_code = """\
4452+
from mypkg.config import app_config
4453+
4454+
def get_log_level() -> str:
4455+
return app_config.log_level
4456+
"""
4457+
consumer_path = package_dir / "consumer.py"
4458+
consumer_path.write_text(consumer_code, encoding="utf-8")
4459+
4460+
context = CodeStringsMarkdown(code_strings=[CodeString(code=consumer_code, file_path=consumer_path)])
4461+
result = enrich_testgen_context(context, tmp_path)
4462+
4463+
assert len(result.code_strings) >= 1
4464+
combined = "\n".join(cs.code for cs in result.code_strings)
4465+
assert "class AppConfig:" in combined
4466+
assert "@property" in combined

0 commit comments

Comments
 (0)