Skip to content

Commit 3aca8de

Browse files
authored
Merge branch 'main' into disable-lp-jit
2 parents 459129a + 9f4fe4f commit 3aca8de

29 files changed

Lines changed: 726 additions & 163 deletions
Binary file not shown.
Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,39 @@
11
from concurrent.futures import ThreadPoolExecutor
2-
from time import sleep
32

43

54
def funcA(number):
6-
number = number if number < 1000 else 1000
5+
number = number if number < 100 else 100
76
k = 0
8-
for i in range(number * 100):
7+
for i in range(number * 10):
98
k += i
10-
# Simplify the for loop by using sum with a range object
119
j = sum(range(number))
12-
13-
# Use a generator expression directly in join for more efficiency
1410
return " ".join(str(i) for i in range(number))
1511

1612

1713
def test_threadpool() -> None:
18-
pool = ThreadPoolExecutor(max_workers=3)
19-
args = list(range(10, 31, 10))
14+
pool = ThreadPoolExecutor(max_workers=2)
15+
args = [5, 10, 15]
2016
result = pool.map(funcA, args)
2117

2218
for r in result:
2319
print(r)
2420

2521
class AlexNet:
26-
def __init__(self, num_classes=1000):
22+
def __init__(self, num_classes=10):
2723
self.num_classes = num_classes
28-
self.features_size = 256 * 6 * 6
2924

3025
def forward(self, x):
31-
features = self._extract_features(x)
32-
33-
output = self._classify(features)
34-
return output
35-
36-
def _extract_features(self, x):
37-
result = []
38-
for i in range(len(x)):
39-
pass
40-
41-
return result
42-
43-
def _classify(self, features):
44-
total = sum(features)
45-
return [total % self.num_classes for _ in features]
46-
47-
class SimpleModel:
48-
@staticmethod
49-
def predict(data):
50-
result = []
51-
sleep(0.1) # can be optimized away
52-
for i in range(500):
53-
for x in data:
54-
computation = 0
55-
computation += x * i ** 2
56-
result.append(computation)
57-
return result
58-
59-
@classmethod
60-
def create_default(cls):
61-
return cls()
26+
result = 0
27+
for val in x:
28+
result += val * val
29+
return result % self.num_classes
6230

6331

6432
def test_models():
6533
model = AlexNet(num_classes=10)
6634
input_data = [1, 2, 3, 4, 5]
6735
result = model.forward(input_data)
6836

69-
model2 = SimpleModel.create_default()
70-
prediction = model2.predict(input_data)
71-
7237
if __name__ == "__main__":
7338
test_threadpool()
7439
test_models()

codeflash/api/aiservice.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from codeflash.models.models import (
2222
AIServiceRefinerRequest,
2323
CodeStringsMarkdown,
24+
OptimizationReviewResult,
2425
OptimizedCandidate,
2526
OptimizedCandidateSource,
2627
)
@@ -652,7 +653,7 @@ def get_optimization_review(
652653
replay_tests: str,
653654
concolic_tests: str, # noqa: ARG002
654655
calling_fn_details: str,
655-
) -> str:
656+
) -> OptimizationReviewResult:
656657
"""Compute the optimization review of current Pull Request.
657658
658659
Args:
@@ -670,7 +671,7 @@ def get_optimization_review(
670671
671672
Returns:
672673
-------
673-
- 'high', 'medium' or 'low' optimization review
674+
OptimizationReviewResult with review ('high', 'medium', 'low', or '') and explanation
674675
675676
"""
676677
diff_str = "\n".join(
@@ -706,18 +707,21 @@ def get_optimization_review(
706707
except requests.exceptions.RequestException as e:
707708
logger.exception(f"Error generating optimization refinements: {e}")
708709
ph("cli-optimize-error-caught", {"error": str(e)})
709-
return ""
710+
return OptimizationReviewResult(review="", explanation="")
710711

711712
if response.status_code == 200:
712-
return cast("str", response.json()["review"])
713+
data = response.json()
714+
return OptimizationReviewResult(
715+
review=cast("str", data["review"]), explanation=cast("str", data.get("review_explanation", ""))
716+
)
713717
try:
714718
error = cast("str", response.json()["error"])
715719
except Exception:
716720
error = response.text
717721
logger.error(f"Error generating optimization review: {response.status_code} - {error}")
718722
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
719723
console.rule()
720-
return ""
724+
return OptimizationReviewResult(review="", explanation="")
721725

722726
def generate_workflow_steps(
723727
self,

codeflash/benchmarking/trace_benchmarks.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from codeflash.cli_cmds.console import logger
99
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
10+
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
1011

1112

1213
def trace_benchmarks_pytest(
@@ -17,20 +18,18 @@ def trace_benchmarks_pytest(
1718
benchmark_env["PYTHONPATH"] = str(project_root)
1819
else:
1920
benchmark_env["PYTHONPATH"] += os.pathsep + str(project_root)
20-
result = subprocess.run(
21+
run_args = get_cross_platform_subprocess_run_args(
22+
cwd=project_root, env=benchmark_env, timeout=timeout, check=False, text=True, capture_output=True
23+
)
24+
result = subprocess.run( # noqa: PLW1510
2125
[
2226
SAFE_SYS_EXECUTABLE,
2327
Path(__file__).parent / "pytest_new_process_trace_benchmarks.py",
2428
benchmarks_root,
2529
tests_root,
2630
trace_file,
2731
],
28-
cwd=project_root,
29-
check=False,
30-
capture_output=True,
31-
text=True,
32-
env=benchmark_env,
33-
timeout=timeout,
32+
**run_args,
3433
)
3534
if result.returncode != 0:
3635
if "ERROR collecting" in result.stdout:

codeflash/cli_cmds/console.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@
5858
)
5959

6060

61+
class DummyTask:
62+
def __init__(self) -> None:
63+
self.id = 0
64+
65+
66+
class DummyProgress:
67+
def __init__(self) -> None:
68+
pass
69+
70+
def advance(self, task_id: TaskID, advance: int = 1) -> None:
71+
pass
72+
73+
6174
def lsp_log(message: LspMessage) -> None:
6275
if not is_LSP_enabled():
6376
return
@@ -120,10 +133,6 @@ def progress_bar(
120133
logger.info(message)
121134

122135
# Create a fake task ID since we still need to yield something
123-
class DummyTask:
124-
def __init__(self) -> None:
125-
self.id = 0
126-
127136
yield DummyTask().id
128137
else:
129138
progress = Progress(
@@ -141,6 +150,13 @@ def __init__(self) -> None:
141150
@contextmanager
142151
def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Progress, TaskID], None, None]:
143152
"""Progress bar for test files."""
153+
if is_LSP_enabled():
154+
lsp_log(LspTextMessage(text=description, takes_time=True))
155+
dummy_progress = DummyProgress()
156+
dummy_task = DummyTask()
157+
yield dummy_progress, dummy_task.id
158+
return
159+
144160
with Progress(
145161
SpinnerColumn(next(spinners)),
146162
TextColumn("[progress.description]{task.description}"),

codeflash/code_utils/code_replacer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def replace_function_definitions_in_module(
447447

448448
new_code: str = replace_functions_and_add_imports(
449449
# adding the global assignments before replacing the code, not after
450-
# becuase of an "edge case" where the optimized code intoduced a new import and a global assignment using that import
450+
# because of an "edge case" where the optimized code intoduced a new import and a global assignment using that import
451451
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
452452
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
453453
add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code,

codeflash/code_utils/git_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def get_git_diff(
3434
only_this_commit + "^1", only_this_commit, ignore_blank_lines=True, ignore_space_at_eol=True
3535
)
3636
elif uncommitted_changes:
37-
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
37+
uni_diff_text = repository.git.diff("HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
3838
else:
3939
uni_diff_text = repository.git.diff(
4040
commit.hexsha + "^1", commit.hexsha, ignore_blank_lines=True, ignore_space_at_eol=True

codeflash/code_utils/git_worktree_utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

33
import configparser
4+
import shutil
5+
import stat
46
import subprocess
57
import tempfile
68
import time
79
from pathlib import Path
8-
from typing import Optional
10+
from typing import Any, Callable, Optional
911

1012
import git
1113

@@ -95,10 +97,24 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:
9597
return worktree_dir
9698

9799

100+
def _handle_remove_readonly(
101+
func: Callable[[str], None], path: str, exc_info: tuple[type[BaseException], BaseException, Any]
102+
) -> None:
103+
"""Error handler for shutil.rmtree to handle read-only files on Windows."""
104+
if isinstance(exc_info[1], PermissionError):
105+
Path(path).chmod(stat.S_IWUSR | stat.S_IRUSR | stat.S_IXUSR)
106+
func(path)
107+
else:
108+
raise exc_info[1]
109+
110+
98111
def remove_worktree(worktree_dir: Path) -> None:
112+
"""Remove a git worktree directory."""
113+
if not worktree_dir.exists():
114+
return
99115
try:
100-
repository = git.Repo(worktree_dir, search_parent_directories=True)
101-
repository.git.worktree("remove", "--force", worktree_dir)
116+
shutil.rmtree(worktree_dir, onerror=_handle_remove_readonly)
117+
logger.debug(f"Removed worktree: {worktree_dir}")
102118
except Exception:
103119
logger.exception(f"Failed to remove worktree: {worktree_dir}")
104120

codeflash/code_utils/shell_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import contextlib
44
import os
55
import re
6+
import subprocess
7+
import sys
68
from pathlib import Path
79
from typing import TYPE_CHECKING, Optional
810

@@ -11,8 +13,11 @@
1113
from codeflash.either import Failure, Success
1214

1315
if TYPE_CHECKING:
16+
from collections.abc import Mapping
17+
1418
from codeflash.either import Result
1519

20+
1621
# PowerShell patterns and prefixes
1722
POWERSHELL_RC_EXPORT_PATTERN = re.compile(
1823
r'^\$env:CODEFLASH_API_KEY\s*=\s*(?:"|\')?(cf-[^\s"\']+)(?:"|\')?\s*$', re.MULTILINE
@@ -231,3 +236,24 @@ def save_api_key_to_rc(api_key: str) -> Result[str, str]:
231236
f"To ensure your Codeflash API key is automatically loaded into your environment at startup, you can create {shell_rc_path} and add the following line:{LF}"
232237
f"{LF}{api_key_line}{LF}"
233238
)
239+
240+
241+
def get_cross_platform_subprocess_run_args(
242+
cwd: Path | str | None = None,
243+
env: Mapping[str, str] | None = None,
244+
timeout: Optional[float] = None,
245+
check: bool = False, # noqa: FBT001, FBT002
246+
text: bool = True, # noqa: FBT001, FBT002
247+
capture_output: bool = True, # noqa: FBT001, FBT002 (only for non-Windows)
248+
) -> dict[str, str]:
249+
run_args = {"cwd": cwd, "env": env, "text": text, "timeout": timeout, "check": check}
250+
if sys.platform == "win32":
251+
creationflags = subprocess.CREATE_NEW_PROCESS_GROUP
252+
run_args["creationflags"] = creationflags
253+
run_args["stdout"] = subprocess.PIPE
254+
run_args["stderr"] = subprocess.PIPE
255+
run_args["stdin"] = subprocess.DEVNULL
256+
else:
257+
run_args["capture_output"] = capture_output
258+
259+
return run_args

0 commit comments

Comments
 (0)