Skip to content

Commit a23322f

Browse files
authored
Merge branch 'main' into inquirer
2 parents f0fe4d8 + d2d57fe commit a23322f

20 files changed

Lines changed: 303 additions & 96 deletions

codeflash/api/aiservice.py

Lines changed: 32 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import os
55
import platform
66
import time
7+
from itertools import count
78
from typing import TYPE_CHECKING, Any, cast
89

910
import requests
@@ -12,7 +13,6 @@
1213
from codeflash.cli_cmds.console import console, logger
1314
from codeflash.code_utils.code_replacer import is_zero_diff
1415
from codeflash.code_utils.code_utils import unified_diff_strings
15-
from codeflash.code_utils.config_consts import N_CANDIDATES_EFFECTIVE, N_CANDIDATES_LP_EFFECTIVE
1616
from codeflash.code_utils.env_utils import get_codeflash_api_key
1717
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
1818
from codeflash.code_utils.time_utils import humanize_runtime
@@ -40,6 +40,11 @@ class AiServiceClient:
4040
def __init__(self) -> None:
4141
self.base_url = self.get_aiservice_base_url()
4242
self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"}
43+
self.llm_call_counter = count(1)
44+
45+
def get_next_sequence(self) -> int:
46+
"""Get the next LLM call sequence number."""
47+
return next(self.llm_call_counter)
4348

4449
def get_aiservice_base_url(self) -> str:
4550
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
@@ -106,6 +111,7 @@ def _get_valid_candidates(
106111
optimization_id=opt["optimization_id"],
107112
source=source,
108113
parent_id=opt.get("parent_id", None),
114+
model=opt.get("model"),
109115
)
110116
)
111117
return candidates
@@ -115,7 +121,6 @@ def optimize_python_code( # noqa: D417
115121
source_code: str,
116122
dependency_code: str,
117123
trace_id: str,
118-
num_candidates: int = 10,
119124
experiment_metadata: ExperimentMetadata | None = None,
120125
*,
121126
is_async: bool = False,
@@ -127,46 +132,49 @@ def optimize_python_code( # noqa: D417
127132
- source_code (str): The python code to optimize.
128133
- dependency_code (str): The dependency code used as read-only context for the optimization
129134
- trace_id (str): Trace id of optimization run
130-
- num_candidates (int): Number of optimization variants to generate. Default is 10.
131135
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
136+
- is_async (bool): Whether the function being optimized is async
132137
133138
Returns
134139
-------
135140
- List[OptimizationCandidate]: A list of Optimization Candidates.
136141
137142
"""
143+
logger.info("Generating optimized candidates…")
144+
console.rule()
138145
start_time = time.perf_counter()
139146
git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()
140147

141148
payload = {
142149
"source_code": source_code,
143150
"dependency_code": dependency_code,
144-
"num_variants": num_candidates,
145151
"trace_id": trace_id,
146152
"python_version": platform.python_version(),
147153
"experiment_metadata": experiment_metadata,
148154
"codeflash_version": codeflash_version,
149155
"current_username": get_last_commit_author_if_pr_exists(None),
150156
"repo_owner": git_repo_owner,
151157
"repo_name": git_repo_name,
152-
"n_candidates": N_CANDIDATES_EFFECTIVE,
153158
"is_async": is_async,
159+
"lsp_mode": is_LSP_enabled(),
160+
"call_sequence": self.get_next_sequence(),
154161
}
162+
logger.debug(f"Sending optimize request: trace_id={trace_id}, lsp_mode={payload['lsp_mode']}")
155163

156-
logger.info("!lsp|Generating optimized candidates…")
157-
console.rule()
158164
try:
159165
response = self.make_ai_service_request("/optimize", payload=payload, timeout=60)
160166
except requests.exceptions.RequestException as e:
161167
logger.exception(f"Error generating optimized candidates: {e}")
162168
ph("cli-optimize-error-caught", {"error": str(e)})
169+
console.rule()
163170
return []
164171

165172
if response.status_code == 200:
166173
optimizations_json = response.json()["optimizations"]
167-
console.rule()
168174
end_time = time.perf_counter()
169175
logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
176+
logger.info(f"!lsp|Received {len(optimizations_json)} optimization candidates.")
177+
console.rule()
170178
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE)
171179
try:
172180
error = response.json()["error"]
@@ -183,54 +191,53 @@ def optimize_python_code_line_profiler( # noqa: D417
183191
dependency_code: str,
184192
trace_id: str,
185193
line_profiler_results: str,
186-
num_candidates: int = 10,
187194
experiment_metadata: ExperimentMetadata | None = None,
188195
) -> list[OptimizedCandidate]:
189-
"""Optimize the given python code for performance by making a request to the Django endpoint.
196+
"""Optimize the given python code for performance using line profiler results.
190197
191198
Parameters
192199
----------
193200
- source_code (str): The python code to optimize.
194201
- dependency_code (str): The dependency code used as read-only context for the optimization
195202
- trace_id (str): Trace id of optimization run
196-
- num_candidates (int): Number of optimization variants to generate. Default is 10.
203+
- line_profiler_results (str): Line profiler output to guide optimization
197204
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
198205
199206
Returns
200207
-------
201208
- List[OptimizationCandidate]: A list of Optimization Candidates.
202209
203210
"""
211+
if line_profiler_results == "":
212+
logger.info("No LineProfiler results were provided, Skipping optimization.")
213+
return []
214+
215+
logger.info("Generating optimized candidates with line profiler…")
216+
console.rule()
217+
204218
payload = {
205219
"source_code": source_code,
206220
"dependency_code": dependency_code,
207-
"num_variants": num_candidates,
208221
"line_profiler_results": line_profiler_results,
209222
"trace_id": trace_id,
210223
"python_version": platform.python_version(),
211224
"experiment_metadata": experiment_metadata,
212225
"codeflash_version": codeflash_version,
213226
"lsp_mode": is_LSP_enabled(),
214-
"n_candidates_lp": N_CANDIDATES_LP_EFFECTIVE,
227+
"call_sequence": self.get_next_sequence(),
215228
}
216229

217-
console.rule()
218-
if line_profiler_results == "":
219-
logger.info("No LineProfiler results were provided, Skipping optimization.")
220-
console.rule()
221-
return []
222230
try:
223231
response = self.make_ai_service_request("/optimize-line-profiler", payload=payload, timeout=60)
224232
except requests.exceptions.RequestException as e:
225233
logger.exception(f"Error generating optimized candidates: {e}")
226234
ph("cli-optimize-error-caught", {"error": str(e)})
235+
console.rule()
227236
return []
228237

229238
if response.status_code == 200:
230239
optimizations_json = response.json()["optimizations"]
231-
logger.info(
232-
f"!lsp|Generated {len(optimizations_json)} candidate optimizations using line profiler information."
233-
)
240+
logger.info(f"!lsp|Received {len(optimizations_json)} line profiler optimization candidates.")
234241
console.rule()
235242
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE_LP)
236243
try:
@@ -268,6 +275,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
268275
"trace_id": opt.trace_id,
269276
"function_references": opt.function_references,
270277
"python_version": platform.python_version(),
278+
"call_sequence": self.get_next_sequence(),
271279
}
272280
for opt in request
273281
]
@@ -402,6 +410,7 @@ def get_new_explanation( # noqa: D417
402410
"throughput_improvement": throughput_improvement,
403411
"function_references": function_references,
404412
"codeflash_version": codeflash_version,
413+
"call_sequence": self.get_next_sequence(),
405414
}
406415
logger.info("loading|Generating explanation")
407416
console.rule()
@@ -564,6 +573,7 @@ def generate_regression_tests( # noqa: D417
564573
"python_version": platform.python_version(),
565574
"codeflash_version": codeflash_version,
566575
"is_async": function_to_optimize.is_async,
576+
"call_sequence": self.get_next_sequence(),
567577
}
568578
try:
569579
response = self.make_ai_service_request("/testgen", payload=payload, timeout=90)
@@ -650,6 +660,7 @@ def get_optimization_review(
650660
"codeflash_version": codeflash_version,
651661
"calling_fn_details": calling_fn_details,
652662
"python_version": platform.python_version(),
663+
"call_sequence": self.get_next_sequence(),
653664
}
654665
console.rule()
655666
try:

codeflash/cli_cmds/cli.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,9 @@ def parse_args() -> Namespace:
7979
parser.add_argument(
8080
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
8181
)
82+
parser.add_argument(
83+
"--no-gen-tests", action="store_true", help="Do not generate tests, use only existing tests for optimization."
84+
)
8285
parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review")
8386
parser.add_argument(
8487
"--verify-setup",

codeflash/code_utils/config_consts.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,12 @@
1414
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
1515
N_CANDIDATES_LP = 6
1616

17+
# pytest loop stability
18+
# For now, we use strict thresholds (large windows and low tolerances), since this is still experimental.
19+
STABILITY_WINDOW_SIZE = 0.35 # 35% of total window
20+
STABILITY_CENTER_TOLERANCE = 0.0025 # ±0.25% around median
21+
STABILITY_SPREAD_TOLERANCE = 0.0025 # 0.25% window spread
22+
1723
# Refinement
1824
REFINE_ALL_THRESHOLD = 2 # when valid optimizations count is 2 or less, refine all optimizations
1925
REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1) # (runtime, diff), runtime is more important than diff by a factor of 2

codeflash/code_utils/env_utils.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
2020
if not formatter_cmds or formatter_cmds[0] == "disabled":
2121
return True
22-
2322
first_cmd = formatter_cmds[0]
2423
cmd_tokens = shlex.split(first_cmd) if isinstance(first_cmd, str) else [first_cmd]
2524

codeflash/code_utils/formatter.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -46,18 +46,13 @@ def apply_formatter_cmds(
4646
print_status: bool, # noqa
4747
exit_on_failure: bool = True, # noqa
4848
) -> tuple[Path, str, bool]:
49-
should_make_copy = False
50-
file_path = path
51-
52-
if test_dir_str:
53-
should_make_copy = True
54-
file_path = Path(test_dir_str) / "temp.py"
55-
5649
if not path.exists():
5750
msg = f"File {path} does not exist. Cannot apply formatter commands."
5851
raise FileNotFoundError(msg)
5952

60-
if should_make_copy:
53+
file_path = path
54+
if test_dir_str:
55+
file_path = Path(test_dir_str) / "temp.py"
6156
shutil.copy2(path, file_path)
6257

6358
file_token = "$file" # noqa: S105

codeflash/discovery/discover_unit_tests.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -751,6 +751,7 @@ def process_test_files(
751751

752752
tests_cache = TestsCache(project_root_path)
753753
logger.info("!lsp|Discovering tests and processing unit tests")
754+
console.rule()
754755
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
755756
progress,
756757
task_id,

codeflash/models/models.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ class AIServiceRefinerRequest:
4646
original_line_profiler_results: str
4747
optimized_line_profiler_results: str
4848
function_references: str | None = None
49+
call_sequence: int | None = None
4950

5051

5152
class TestDiffScope(str, Enum):
@@ -464,6 +465,7 @@ class OptimizedCandidate:
464465
optimization_id: str
465466
source: OptimizedCandidateSource
466467
parent_id: str | None = None
468+
model: str | None = None # Which LLM model generated this candidate
467469

468470

469471
@dataclass(frozen=True)

0 commit comments

Comments
 (0)