|
18 | 18 | extract_init_stub_from_class, |
19 | 19 | extract_parameter_type_constructors, |
20 | 20 | get_code_optimization_context, |
| 21 | + resolve_instance_class_name, |
21 | 22 | ) |
22 | 23 | from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent |
23 | 24 | from codeflash.optimization.optimizer import Optimizer |
@@ -4339,3 +4340,127 @@ def process(c: Config) -> str: |
4339 | 4340 | ) |
4340 | 4341 | result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) |
4341 | 4342 | assert len(result.code_strings) == 0 |
| 4343 | + |
| 4344 | + |
| 4345 | +# --- Tests for resolve_instance_class_name --- |
| 4346 | + |
| 4347 | + |
| 4348 | +def test_resolve_instance_class_name_direct_call() -> None: |
| 4349 | + source = "config = MyConfig(debug=True)" |
| 4350 | + tree = ast.parse(source) |
| 4351 | + assert resolve_instance_class_name("config", tree) == "MyConfig" |
| 4352 | + |
| 4353 | + |
| 4354 | +def test_resolve_instance_class_name_annotated() -> None: |
| 4355 | + source = "config: MyConfig = load()" |
| 4356 | + tree = ast.parse(source) |
| 4357 | + assert resolve_instance_class_name("config", tree) == "MyConfig" |
| 4358 | + |
| 4359 | + |
| 4360 | +def test_resolve_instance_class_name_factory_method() -> None: |
| 4361 | + source = "config = MyConfig.from_env()" |
| 4362 | + tree = ast.parse(source) |
| 4363 | + assert resolve_instance_class_name("config", tree) == "MyConfig" |
| 4364 | + |
| 4365 | + |
| 4366 | +def test_resolve_instance_class_name_no_match() -> None: |
| 4367 | + source = "x = 42" |
| 4368 | + tree = ast.parse(source) |
| 4369 | + assert resolve_instance_class_name("x", tree) is None |
| 4370 | + |
| 4371 | + |
| 4372 | +def test_resolve_instance_class_name_missing_variable() -> None: |
| 4373 | + source = "config = MyConfig()" |
| 4374 | + tree = ast.parse(source) |
| 4375 | + assert resolve_instance_class_name("other", tree) is None |
| 4376 | + |
| 4377 | + |
| 4378 | +# --- Tests for enhanced extract_init_stub_from_class --- |
| 4379 | + |
| 4380 | + |
| 4381 | +def test_extract_init_stub_includes_post_init() -> None: |
| 4382 | + source = """\ |
| 4383 | +class MyDataclass: |
| 4384 | + def __init__(self, x: int): |
| 4385 | + self.x = x |
| 4386 | + def __post_init__(self): |
| 4387 | + self.y = self.x * 2 |
| 4388 | +""" |
| 4389 | + tree = ast.parse(source) |
| 4390 | + stub = extract_init_stub_from_class("MyDataclass", source, tree) |
| 4391 | + assert stub is not None |
| 4392 | + assert "class MyDataclass:" in stub |
| 4393 | + assert "def __init__" in stub |
| 4394 | + assert "def __post_init__" in stub |
| 4395 | + assert "self.y = self.x * 2" in stub |
| 4396 | + |
| 4397 | + |
| 4398 | +def test_extract_init_stub_includes_properties() -> None: |
| 4399 | + source = """\ |
| 4400 | +class MyClass: |
| 4401 | + def __init__(self, name: str): |
| 4402 | + self._name = name |
| 4403 | + @property |
| 4404 | + def name(self) -> str: |
| 4405 | + return self._name |
| 4406 | +""" |
| 4407 | + tree = ast.parse(source) |
| 4408 | + stub = extract_init_stub_from_class("MyClass", source, tree) |
| 4409 | + assert stub is not None |
| 4410 | + assert "def __init__" in stub |
| 4411 | + assert "@property" in stub |
| 4412 | + assert "def name" in stub |
| 4413 | + |
| 4414 | + |
| 4415 | +def test_extract_init_stub_property_only_class() -> None: |
| 4416 | + source = """\ |
| 4417 | +class ReadOnly: |
| 4418 | + @property |
| 4419 | + def value(self) -> int: |
| 4420 | + return 42 |
| 4421 | +""" |
| 4422 | + tree = ast.parse(source) |
| 4423 | + stub = extract_init_stub_from_class("ReadOnly", source, tree) |
| 4424 | + assert stub is not None |
| 4425 | + assert "class ReadOnly:" in stub |
| 4426 | + assert "@property" in stub |
| 4427 | + assert "def value" in stub |
| 4428 | + |
| 4429 | + |
| 4430 | +# --- Tests for enrich_testgen_context resolving instances --- |
| 4431 | + |
| 4432 | + |
| 4433 | +def test_enrich_testgen_context_resolves_instance_to_class(tmp_path: Path) -> None: |
| 4434 | + package_dir = tmp_path / "mypkg" |
| 4435 | + package_dir.mkdir() |
| 4436 | + (package_dir / "__init__.py").write_text("", encoding="utf-8") |
| 4437 | + |
| 4438 | + config_module = """\ |
| 4439 | +class AppConfig: |
| 4440 | + def __init__(self, debug: bool = False): |
| 4441 | + self.debug = debug |
| 4442 | +
|
| 4443 | + @property |
| 4444 | + def log_level(self) -> str: |
| 4445 | + return "DEBUG" if self.debug else "INFO" |
| 4446 | +
|
| 4447 | +app_config = AppConfig(debug=True) |
| 4448 | +""" |
| 4449 | + (package_dir / "config.py").write_text(config_module, encoding="utf-8") |
| 4450 | + |
| 4451 | + consumer_code = """\ |
| 4452 | +from mypkg.config import app_config |
| 4453 | +
|
| 4454 | +def get_log_level() -> str: |
| 4455 | + return app_config.log_level |
| 4456 | +""" |
| 4457 | + consumer_path = package_dir / "consumer.py" |
| 4458 | + consumer_path.write_text(consumer_code, encoding="utf-8") |
| 4459 | + |
| 4460 | + context = CodeStringsMarkdown(code_strings=[CodeString(code=consumer_code, file_path=consumer_path)]) |
| 4461 | + result = enrich_testgen_context(context, tmp_path) |
| 4462 | + |
| 4463 | + assert len(result.code_strings) >= 1 |
| 4464 | + combined = "\n".join(cs.code for cs in result.code_strings) |
| 4465 | + assert "class AppConfig:" in combined |
| 4466 | + assert "@property" in combined |
0 commit comments