Skip to content

Commit c05f0a3

Browse files
authored
Merge branch 'main' into kevinturcios/cf-948-display-the-results-of-the-reviewer
2 parents 7afba98 + 5ef4ed4 commit c05f0a3

18 files changed

Lines changed: 621 additions & 48 deletions

codeflash/benchmarking/trace_benchmarks.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from codeflash.cli_cmds.console import logger
99
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
10+
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
1011

1112

1213
def trace_benchmarks_pytest(
@@ -17,20 +18,18 @@ def trace_benchmarks_pytest(
1718
benchmark_env["PYTHONPATH"] = str(project_root)
1819
else:
1920
benchmark_env["PYTHONPATH"] += os.pathsep + str(project_root)
20-
result = subprocess.run(
21+
run_args = get_cross_platform_subprocess_run_args(
22+
cwd=project_root, env=benchmark_env, timeout=timeout, check=False, text=True, capture_output=True
23+
)
24+
result = subprocess.run( # noqa: PLW1510
2125
[
2226
SAFE_SYS_EXECUTABLE,
2327
Path(__file__).parent / "pytest_new_process_trace_benchmarks.py",
2428
benchmarks_root,
2529
tests_root,
2630
trace_file,
2731
],
28-
cwd=project_root,
29-
check=False,
30-
capture_output=True,
31-
text=True,
32-
env=benchmark_env,
33-
timeout=timeout,
32+
**run_args,
3433
)
3534
if result.returncode != 0:
3635
if "ERROR collecting" in result.stdout:

codeflash/cli_cmds/console.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@
5858
)
5959

6060

61+
class DummyTask:
62+
def __init__(self) -> None:
63+
self.id = 0
64+
65+
66+
class DummyProgress:
67+
def __init__(self) -> None:
68+
pass
69+
70+
def advance(self, task_id: TaskID, advance: int = 1) -> None:
71+
pass
72+
73+
6174
def lsp_log(message: LspMessage) -> None:
6275
if not is_LSP_enabled():
6376
return
@@ -120,10 +133,6 @@ def progress_bar(
120133
logger.info(message)
121134

122135
# Create a fake task ID since we still need to yield something
123-
class DummyTask:
124-
def __init__(self) -> None:
125-
self.id = 0
126-
127136
yield DummyTask().id
128137
else:
129138
progress = Progress(
@@ -141,6 +150,13 @@ def __init__(self) -> None:
141150
@contextmanager
142151
def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Progress, TaskID], None, None]:
143152
"""Progress bar for test files."""
153+
if is_LSP_enabled():
154+
lsp_log(LspTextMessage(text=description, takes_time=True))
155+
dummy_progress = DummyProgress()
156+
dummy_task = DummyTask()
157+
yield dummy_progress, dummy_task.id
158+
return
159+
144160
with Progress(
145161
SpinnerColumn(next(spinners)),
146162
TextColumn("[progress.description]{task.description}"),

codeflash/code_utils/code_replacer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def replace_function_definitions_in_module(
447447

448448
new_code: str = replace_functions_and_add_imports(
449449
# adding the global assignments before replacing the code, not after
450-
# becuase of an "edge case" where the optimized code intoduced a new import and a global assignment using that import
450+
# because of an "edge case" where the optimized code intoduced a new import and a global assignment using that import
451451
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
452452
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
453453
add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code,

codeflash/code_utils/shell_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import contextlib
44
import os
55
import re
6+
import subprocess
7+
import sys
68
from pathlib import Path
79
from typing import TYPE_CHECKING, Optional
810

@@ -11,8 +13,11 @@
1113
from codeflash.either import Failure, Success
1214

1315
if TYPE_CHECKING:
16+
from collections.abc import Mapping
17+
1418
from codeflash.either import Result
1519

20+
1621
# PowerShell patterns and prefixes
1722
POWERSHELL_RC_EXPORT_PATTERN = re.compile(
1823
r'^\$env:CODEFLASH_API_KEY\s*=\s*(?:"|\')?(cf-[^\s"\']+)(?:"|\')?\s*$', re.MULTILINE
@@ -231,3 +236,24 @@ def save_api_key_to_rc(api_key: str) -> Result[str, str]:
231236
f"To ensure your Codeflash API key is automatically loaded into your environment at startup, you can create {shell_rc_path} and add the following line:{LF}"
232237
f"{LF}{api_key_line}{LF}"
233238
)
239+
240+
241+
def get_cross_platform_subprocess_run_args(
242+
cwd: Path | str | None = None,
243+
env: Mapping[str, str] | None = None,
244+
timeout: Optional[float] = None,
245+
check: bool = False, # noqa: FBT001, FBT002
246+
text: bool = True, # noqa: FBT001, FBT002
247+
capture_output: bool = True, # noqa: FBT001, FBT002 (only for non-Windows)
248+
) -> dict[str, str]:
249+
run_args = {"cwd": cwd, "env": env, "text": text, "timeout": timeout, "check": check}
250+
if sys.platform == "win32":
251+
creationflags = subprocess.CREATE_NEW_PROCESS_GROUP
252+
run_args["creationflags"] = creationflags
253+
run_args["stdout"] = subprocess.PIPE
254+
run_args["stderr"] = subprocess.PIPE
255+
run_args["stdin"] = subprocess.DEVNULL
256+
else:
257+
run_args["capture_output"] = capture_output
258+
259+
return run_args

codeflash/context/code_context_extractor.py

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,20 +127,48 @@ def get_code_optimization_context(
127127
remove_docstrings=False,
128128
code_context_type=CodeContextType.TESTGEN,
129129
)
130+
131+
# Extract class definitions for imported types from project modules
132+
# This helps the LLM understand class constructors and structure
133+
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
134+
if imported_class_context.code_strings:
135+
# Merge imported class definitions into testgen context
136+
testgen_context = CodeStringsMarkdown(
137+
code_strings=testgen_context.code_strings + imported_class_context.code_strings
138+
)
139+
130140
testgen_markdown_code = testgen_context.markdown
131141
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
132142
if testgen_code_token_length > testgen_token_limit:
143+
# First try removing docstrings
133144
testgen_context = extract_code_markdown_context_from_files(
134145
helpers_of_fto_dict,
135146
helpers_of_helpers_dict,
136147
project_root_path,
137148
remove_docstrings=True,
138149
code_context_type=CodeContextType.TESTGEN,
139150
)
151+
# Re-extract imported classes (they may still fit)
152+
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
153+
if imported_class_context.code_strings:
154+
testgen_context = CodeStringsMarkdown(
155+
code_strings=testgen_context.code_strings + imported_class_context.code_strings
156+
)
140157
testgen_markdown_code = testgen_context.markdown
141158
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
142159
if testgen_code_token_length > testgen_token_limit:
143-
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
160+
# If still over limit, try without imported class definitions
161+
testgen_context = extract_code_markdown_context_from_files(
162+
helpers_of_fto_dict,
163+
helpers_of_helpers_dict,
164+
project_root_path,
165+
remove_docstrings=True,
166+
code_context_type=CodeContextType.TESTGEN,
167+
)
168+
testgen_markdown_code = testgen_context.markdown
169+
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
170+
if testgen_code_token_length > testgen_token_limit:
171+
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
144172
code_hash_context = hashing_code_context.markdown
145173
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
146174

@@ -489,6 +517,147 @@ def get_function_sources_from_jedi(
489517
return file_path_to_function_source, function_source_list
490518

491519

520+
def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
521+
"""Extract class definitions for imported types from project modules.
522+
523+
This function analyzes the imports in the extracted code context and fetches
524+
class definitions for any classes imported from project modules. This helps
525+
the LLM understand the actual class structure (constructors, methods, inheritance)
526+
rather than just seeing import statements.
527+
528+
Args:
529+
code_context: The already extracted code context containing imports
530+
project_root_path: Root path of the project
531+
532+
Returns:
533+
CodeStringsMarkdown containing class definitions from imported project modules
534+
535+
"""
536+
import jedi
537+
538+
# Collect all code from the context
539+
all_code = "\n".join(cs.code for cs in code_context.code_strings)
540+
541+
# Parse to find import statements
542+
try:
543+
tree = ast.parse(all_code)
544+
except SyntaxError:
545+
return CodeStringsMarkdown(code_strings=[])
546+
547+
# Collect imported names and their source modules
548+
imported_names: dict[str, str] = {} # name -> module_path
549+
for node in ast.walk(tree):
550+
if isinstance(node, ast.ImportFrom) and node.module:
551+
for alias in node.names:
552+
if alias.name != "*":
553+
imported_name = alias.asname if alias.asname else alias.name
554+
imported_names[imported_name] = node.module
555+
556+
if not imported_names:
557+
return CodeStringsMarkdown(code_strings=[])
558+
559+
# Track which classes we've already extracted to avoid duplicates
560+
extracted_classes: set[tuple[Path, str]] = set() # (file_path, class_name)
561+
562+
# Also track what's already defined in the context
563+
existing_definitions: set[str] = set()
564+
for node in ast.walk(tree):
565+
if isinstance(node, ast.ClassDef):
566+
existing_definitions.add(node.name)
567+
568+
class_code_strings: list[CodeString] = []
569+
570+
for name, module_name in imported_names.items():
571+
# Skip if already defined in context
572+
if name in existing_definitions:
573+
continue
574+
575+
# Try to find the module file using Jedi
576+
try:
577+
# Create a script that imports the module to resolve it
578+
test_code = f"import {module_name}"
579+
script = jedi.Script(test_code, project=jedi.Project(path=project_root_path))
580+
completions = script.goto(1, len(test_code))
581+
582+
if not completions:
583+
continue
584+
585+
module_path = completions[0].module_path
586+
if not module_path:
587+
continue
588+
589+
# Check if this is a project module (not stdlib/third-party)
590+
if not str(module_path).startswith(str(project_root_path) + os.sep):
591+
continue
592+
if path_belongs_to_site_packages(module_path):
593+
continue
594+
595+
# Skip if we've already extracted this class
596+
if (module_path, name) in extracted_classes:
597+
continue
598+
599+
# Parse the module to find the class definition
600+
module_source = module_path.read_text(encoding="utf-8")
601+
module_tree = ast.parse(module_source)
602+
603+
for node in ast.walk(module_tree):
604+
if isinstance(node, ast.ClassDef) and node.name == name:
605+
# Extract the class source code
606+
lines = module_source.split("\n")
607+
class_source = "\n".join(lines[node.lineno - 1 : node.end_lineno])
608+
609+
# Also extract any necessary imports for the class (base classes, type hints)
610+
class_imports = _extract_imports_for_class(module_tree, node, module_source)
611+
612+
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
613+
614+
class_code_strings.append(CodeString(code=full_source, file_path=module_path))
615+
extracted_classes.add((module_path, name))
616+
break
617+
618+
except Exception:
619+
logger.debug(f"Error extracting class definition for {name} from {module_name}")
620+
continue
621+
622+
return CodeStringsMarkdown(code_strings=class_code_strings)
623+
624+
625+
def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
626+
"""Extract import statements needed for a class definition.
627+
628+
This extracts imports for base classes and commonly used type annotations.
629+
"""
630+
needed_names: set[str] = set()
631+
632+
# Get base class names
633+
for base in class_node.bases:
634+
if isinstance(base, ast.Name):
635+
needed_names.add(base.id)
636+
elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
637+
# For things like abc.ABC, we need the module name
638+
needed_names.add(base.value.id)
639+
640+
# Find imports that provide these names
641+
import_lines: list[str] = []
642+
source_lines = module_source.split("\n")
643+
644+
for node in module_tree.body:
645+
if isinstance(node, ast.Import):
646+
for alias in node.names:
647+
name = alias.asname if alias.asname else alias.name.split(".")[0]
648+
if name in needed_names:
649+
import_lines.append(source_lines[node.lineno - 1])
650+
break
651+
elif isinstance(node, ast.ImportFrom):
652+
for alias in node.names:
653+
name = alias.asname if alias.asname else alias.name
654+
if name in needed_names:
655+
import_lines.append(source_lines[node.lineno - 1])
656+
break
657+
658+
return "\n".join(import_lines)
659+
660+
492661
def is_dunder_method(name: str) -> bool:
493662
return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__")
494663

codeflash/discovery/discover_unit_tests.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
if TYPE_CHECKING:
1818
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
19-
2019
from pydantic.dataclasses import dataclass
2120
from rich.panel import Panel
2221
from rich.text import Text
@@ -29,6 +28,7 @@
2928
module_name_from_file_path,
3029
)
3130
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
31+
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
3232
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
3333

3434
if TYPE_CHECKING:
@@ -331,7 +331,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
331331
# Be conservative except when an alias is used (which requires exact method matching)
332332
for target_func in fnames:
333333
if "." in target_func:
334-
class_name, method_name = target_func.split(".", 1)
334+
class_name, _method_name = target_func.split(".", 1)
335335
if aname == class_name and not alias.asname:
336336
self.found_any_target_function = True
337337
self.found_qualified_name = target_func
@@ -585,18 +585,18 @@ def discover_tests_pytest(
585585

586586
tmp_pickle_path = get_run_tmp_file("collected_tests.pkl")
587587
with custom_addopts():
588-
result = subprocess.run(
588+
run_kwargs = get_cross_platform_subprocess_run_args(
589+
cwd=project_root, check=False, text=True, capture_output=True
590+
)
591+
result = subprocess.run( # noqa: PLW1510
589592
[
590593
SAFE_SYS_EXECUTABLE,
591594
Path(__file__).parent / "pytest_new_process_discovery.py",
592595
str(project_root),
593596
str(tests_root),
594597
str(tmp_pickle_path),
595598
],
596-
cwd=project_root,
597-
check=False,
598-
capture_output=True,
599-
text=True,
599+
**run_kwargs,
600600
)
601601
try:
602602
with tmp_pickle_path.open(mode="rb") as f:

codeflash/discovery/functions_to_optimize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
175175
def get_functions_to_optimize(
176176
optimize_all: str | None,
177177
replay_test: list[Path] | None,
178-
file: Path | None,
178+
file: Path | str | None,
179179
only_get_this_function: str | None,
180180
test_cfg: TestConfig,
181181
ignore_paths: list[Path],
@@ -202,6 +202,7 @@ def get_functions_to_optimize(
202202
elif file is not None:
203203
logger.info("!lsp|Finding all functions in the file '%s'…", file)
204204
console.rule()
205+
file = Path(file) if isinstance(file, str) else file
205206
functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(file)
206207
if only_get_this_function is not None:
207208
split_function = only_get_this_function.split(".")

0 commit comments

Comments
 (0)