Skip to content

Commit 17e46d1

Browse files
committed
fix: resolve test file paths in discover_tests_pytest to fix path comparison
1 parent 6346c74 commit 17e46d1

29 files changed

Lines changed: 1213 additions & 1556 deletions

codeflash/cli_cmds/cli.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -352,32 +352,58 @@ def _handle_show_config() -> None:
352352
from codeflash.setup.detector import detect_project, has_existing_config
353353

354354
project_root = Path.cwd()
355-
detected = detect_project(project_root)
356-
357-
# Check if config exists or is auto-detected
358355
config_exists, config_file = has_existing_config(project_root)
359-
status = "Saved config" if config_exists else "Auto-detected (not saved)"
360356

361-
console.print()
362-
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
363-
if config_exists and config_file:
364-
console.print(f"[dim]Config file: {project_root / config_file}[/dim]")
365-
console.print()
357+
if config_exists:
358+
from codeflash.code_utils.config_parser import parse_config_file
366359

367-
table = Table(show_header=True, header_style="bold cyan")
368-
table.add_column("Setting", style="dim")
369-
table.add_column("Value")
370-
371-
table.add_row("Language", detected.language)
372-
table.add_row("Project root", str(detected.project_root))
373-
table.add_row("Module root", str(detected.module_root))
374-
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
375-
table.add_row("Test runner", detected.test_runner or "(not detected)")
376-
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
377-
table.add_row(
378-
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
379-
)
380-
table.add_row("Confidence", f"{detected.confidence:.0%}")
360+
config, config_file_path = parse_config_file()
361+
status = "Saved config"
362+
363+
console.print()
364+
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
365+
console.print(f"[dim]Config file: {config_file_path}[/dim]")
366+
console.print()
367+
368+
table = Table(show_header=True, header_style="bold cyan")
369+
table.add_column("Setting", style="dim")
370+
table.add_column("Value")
371+
372+
table.add_row("Project root", str(project_root))
373+
table.add_row("Module root", config.get("module_root", "(not set)"))
374+
table.add_row("Tests root", config.get("tests_root", "(not set)"))
375+
table.add_row("Test runner", config.get("test_framework", config.get("pytest_cmd", "(not set)")))
376+
table.add_row(
377+
"Formatter",
378+
", ".join(config["formatter_cmds"]) if config.get("formatter_cmds") else "(not set)",
379+
)
380+
ignore_paths = config.get("ignore_paths", [])
381+
table.add_row("Ignore paths", ", ".join(str(p) for p in ignore_paths) if ignore_paths else "(none)")
382+
else:
383+
detected = detect_project(project_root)
384+
status = "Auto-detected (not saved)"
385+
386+
console.print()
387+
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
388+
console.print()
389+
390+
table = Table(show_header=True, header_style="bold cyan")
391+
table.add_column("Setting", style="dim")
392+
table.add_column("Value")
393+
394+
table.add_row("Language", detected.language)
395+
table.add_row("Project root", str(detected.project_root))
396+
table.add_row("Module root", str(detected.module_root))
397+
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
398+
table.add_row("Test runner", detected.test_runner or "(not detected)")
399+
table.add_row(
400+
"Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)"
401+
)
402+
table.add_row(
403+
"Ignore paths",
404+
", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)",
405+
)
406+
table.add_row("Confidence", f"{detected.confidence:.0%}")
381407

382408
console.print(table)
383409
console.print()

codeflash/code_utils/time_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from codeflash.result.critic import performance_gain
4+
35

46
def humanize_runtime(time_in_ns: int) -> str:
57
runtime_human: str = str(time_in_ns)
@@ -89,3 +91,11 @@ def format_perf(percentage: float) -> str:
8991
if abs_perc >= 1:
9092
return f"{percentage:.2f}"
9193
return f"{percentage:.3f}"
94+
95+
96+
def format_runtime_comment(original_time_ns: int, optimized_time_ns: int, comment_prefix: str = "#") -> str:
97+
perf_gain = format_perf(
98+
abs(performance_gain(original_runtime_ns=original_time_ns, optimized_runtime_ns=optimized_time_ns) * 100)
99+
)
100+
status = "slower" if optimized_time_ns > original_time_ns else "faster"
101+
return f"{comment_prefix} {format_time(original_time_ns)} -> {format_time(optimized_time_ns)} ({perf_gain}% {status})"

codeflash/discovery/discover_unit_tests.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,10 @@ def discover_tests_pytest(
728728
logger.debug(f"Pytest collection exit code: {exitcode}")
729729
if pytest_rootdir is not None:
730730
cfg.tests_project_rootdir = Path(pytest_rootdir)
731+
if discover_only_these_tests:
732+
resolved_discover_only = {p.resolve() for p in discover_only_these_tests}
733+
else:
734+
resolved_discover_only = None
731735
file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list)
732736
for test in tests:
733737
if "__replay_test" in test["test_file"]:
@@ -737,13 +741,14 @@ def discover_tests_pytest(
737741
else:
738742
test_type = TestType.EXISTING_UNIT_TEST
739743

744+
test_file_path = Path(test["test_file"]).resolve()
740745
test_obj = TestsInFile(
741-
test_file=Path(test["test_file"]),
746+
test_file=test_file_path,
742747
test_class=test["test_class"],
743748
test_function=test["test_function"],
744749
test_type=test_type,
745750
)
746-
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
751+
if resolved_discover_only and test_obj.test_file not in resolved_discover_only:
747752
continue
748753
file_to_test_map[test_obj.test_file].append(test_obj)
749754
# Within these test files, find the project functions they are referring to and return their names/locations

codeflash/languages/javascript/edit_tests.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,8 @@
1111
from pathlib import Path
1212

1313
from codeflash.cli_cmds.console import logger
14-
from codeflash.code_utils.time_utils import format_perf, format_time
14+
from codeflash.code_utils.time_utils import format_runtime_comment
1515
from codeflash.models.models import GeneratedTests, GeneratedTestsList
16-
from codeflash.result.critic import performance_gain
17-
18-
19-
def format_runtime_comment(original_time: int, optimized_time: int) -> str:
20-
"""Format a runtime comparison comment for JavaScript.
21-
22-
Args:
23-
original_time: Original runtime in nanoseconds.
24-
optimized_time: Optimized runtime in nanoseconds.
25-
26-
Returns:
27-
Formatted comment string with // prefix.
28-
29-
"""
30-
perf_gain = format_perf(
31-
abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
32-
)
33-
status = "slower" if optimized_time > original_time else "faster"
34-
return f"// {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
3516

3617

3718
def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]) -> str:
@@ -120,7 +101,7 @@ def find_matching_test(test_description: str) -> str | None:
120101
# Only add comment if line has a function call and doesn't already have a comment
121102
if func_call_pattern.search(line) and "//" not in line and "expect(" in line:
122103
orig_time, opt_time = timing_by_full_name[current_matched_full_name]
123-
comment = format_runtime_comment(orig_time, opt_time)
104+
comment = format_runtime_comment(orig_time, opt_time, comment_prefix="//")
124105
logger.debug(f"[js-annotations] Adding comment to test '{current_test_name}': {comment}")
125106
# Add comment at end of line
126107
line = f"{line.rstrip()} {comment}"

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 71 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,7 @@
2323
recurse_sections,
2424
remove_unused_definitions_by_function_names,
2525
)
26-
from codeflash.languages.python.static_analysis.code_extractor import (
27-
add_needed_imports_from_module,
28-
find_preexisting_objects,
29-
)
26+
from codeflash.languages.python.static_analysis.code_extractor import add_needed_imports_from_module, find_preexisting_objects
3027
from codeflash.models.models import (
3128
CodeContextType,
3229
CodeOptimizationContext,
@@ -550,6 +547,35 @@ def get_function_sources_from_jedi(
550547
file_path_to_function_source[definition_path].add(function_source)
551548
function_source_list.append(function_source)
552549

550+
if definition.type == "statement":
551+
try:
552+
for inferred in name.infer():
553+
if (
554+
inferred.type == "class"
555+
and inferred.full_name
556+
and inferred.module_path
557+
and is_project_path(inferred.module_path, project_root_path)
558+
and inferred.full_name.startswith(inferred.module_name)
559+
):
560+
class_fqn = f"{inferred.full_name}.__init__"
561+
class_qname = get_qualified_name(inferred.module_name, class_fqn)
562+
if len(class_qname.split(".")) <= 2:
563+
class_path = inferred.module_path
564+
rel = safe_relative_to(class_path, project_root_path)
565+
if not rel.is_absolute():
566+
class_path = project_root_path / rel
567+
class_source = FunctionSource(
568+
file_path=class_path,
569+
qualified_name=class_qname,
570+
fully_qualified_name=class_fqn,
571+
only_function_name="__init__",
572+
source_code=inferred.get_line_code(),
573+
)
574+
file_path_to_function_source[class_path].add(class_source)
575+
function_source_list.append(class_source)
576+
except Exception:
577+
logger.debug(f"Error inferring type for statement {definition.full_name}")
578+
553579
return file_path_to_function_source, function_source_list
554580

555581

@@ -750,6 +776,16 @@ def extract_class_and_bases(
750776

751777
extract_class_and_bases(name, module_path, module_source, module_tree)
752778

779+
if (module_path, name) not in extracted_classes:
780+
for ast_node in module_tree.body:
781+
if isinstance(ast_node, ast.Assign):
782+
for target in ast_node.targets:
783+
if isinstance(target, ast.Name) and target.id == name:
784+
if isinstance(ast_node.value, ast.Call) and isinstance(ast_node.value.func, ast.Name):
785+
class_name = ast_node.value.func.id
786+
if class_name not in existing_classes:
787+
extract_class_and_bases(class_name, module_path, module_source, module_tree)
788+
753789
except Exception:
754790
logger.debug(f"Error extracting class definition for {name} from {module_name}")
755791
continue
@@ -759,7 +795,7 @@ def extract_class_and_bases(
759795
for cls, name in resolve_classes_from_modules(external_base_classes):
760796
if name in emitted_class_names:
761797
continue
762-
stub = extract_init_stub(cls, name, require_site_packages=False)
798+
stub = extract_class_stub(cls, name, require_site_packages=False)
763799
if stub is not None:
764800
code_strings.append(stub)
765801
emitted_class_names.add(name)
@@ -778,7 +814,7 @@ def extract_class_and_bases(
778814
continue
779815
processed_classes.add(cls)
780816

781-
stub = extract_init_stub(cls, class_name)
817+
stub = extract_class_stub(cls, class_name)
782818
if stub is None:
783819
continue
784820

@@ -888,22 +924,24 @@ def resolve_transitive_type_deps(cls: type) -> list[type]:
888924
return deps
889925

890926

891-
def extract_init_stub(cls: type, class_name: str, require_site_packages: bool = True) -> CodeString | None:
892-
"""Extract a stub containing the class definition with only its __init__ method.
927+
def extract_class_stub(cls: type, class_name: str, require_site_packages: bool = True) -> CodeString | None:
928+
"""Extract the full class source, falling back to an __init__-only stub.
929+
930+
Attempts ``inspect.getsource(cls)`` first so the LLM sees every method and
931+
attribute. Falls back to extracting just ``__init__`` when the full source
932+
is unavailable (C extensions, dynamically generated classes). Classes whose
933+
``__init__`` is inherited from ``object`` are still included when the full
934+
source can be retrieved.
893935
894936
Args:
895-
cls: The class object to extract __init__ from
937+
cls: The class object to extract from
896938
class_name: Name to use for the class in the stub
897939
require_site_packages: If True, only extract from site-packages. If False, include stdlib too.
898940
899941
"""
900942
import inspect
901943
import textwrap
902944

903-
init_method = getattr(cls, "__init__", None)
904-
if init_method is None or init_method is object.__init__:
905-
return None
906-
907945
try:
908946
class_file = Path(inspect.getfile(cls))
909947
except (OSError, TypeError):
@@ -912,17 +950,30 @@ def extract_init_stub(cls: type, class_name: str, require_site_packages: bool =
912950
if require_site_packages and not path_belongs_to_site_packages(class_file):
913951
return None
914952

953+
parts = class_file.parts
954+
if "site-packages" in parts:
955+
idx = parts.index("site-packages")
956+
class_file = Path(*parts[idx + 1 :])
957+
958+
# Try full class source first
959+
try:
960+
class_source = inspect.getsource(cls)
961+
class_source = textwrap.dedent(class_source)
962+
return CodeString(code=class_source, file_path=class_file)
963+
except (OSError, TypeError):
964+
pass
965+
966+
# Fallback: __init__-only stub
967+
init_method = getattr(cls, "__init__", None)
968+
if init_method is None or init_method is object.__init__:
969+
return None
970+
915971
try:
916972
init_source = inspect.getsource(init_method)
917973
init_source = textwrap.dedent(init_source)
918974
except (OSError, TypeError):
919975
return None
920976

921-
parts = class_file.parts
922-
if "site-packages" in parts:
923-
idx = parts.index("site-packages")
924-
class_file = Path(*parts[idx + 1 :])
925-
926977
class_source = f"class {class_name}:\n" + textwrap.indent(init_source, " ")
927978
return CodeString(code=class_source, file_path=class_file)
928979

@@ -1080,6 +1131,7 @@ def parse_code_and_prune_cst(
10801131
filtered_node, found_target = prune_cst(
10811132
module,
10821133
target_functions,
1134+
defs_with_usages=defs_with_usages,
10831135
helpers=helpers_of_helper_functions,
10841136
remove_docstrings=remove_docstrings,
10851137
include_dunder_methods=True,
@@ -1219,7 +1271,7 @@ def prune_cst(
12191271
stmt,
12201272
target_functions,
12211273
class_prefix,
1222-
defs_with_usages=defs_with_usages,
1274+
defs_with_usages=None,
12231275
helpers=helpers,
12241276
remove_docstrings=remove_docstrings,
12251277
include_target_in_output=include_target_in_output,

0 commit comments

Comments
 (0)