Skip to content

Commit 7251715

Browse files
committed
feat: send concurrency metrics to LLM for async function optimization
For async functions, run the concurrency benchmark before submitting the optimization request so the LLM receives runtime proof that the function blocks (concurrency_ratio ≈ 1.0). This steers the model toward correct async optimizations (e.g. time.sleep → asyncio.sleep). Sync functions keep the existing parallel test-gen + optimization flow.
1 parent cafcd7f commit 7251715

2 files changed

Lines changed: 59 additions & 12 deletions

File tree

codeflash/api/aiservice.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def optimize_code(
157157
n_candidates: int = 5,
158158
is_numerical_code: bool | None = None,
159159
rerun_trace_id: str | None = None,
160+
concurrency_metrics: dict[str, float] | None = None,
160161
) -> list[OptimizedCandidate]:
161162
"""Optimize the given code for performance by making a request to the Django endpoint.
162163
@@ -200,6 +201,9 @@ def optimize_code(
200201
"rerun_trace_id": rerun_trace_id,
201202
}
202203

204+
if concurrency_metrics is not None:
205+
payload["concurrency_metrics"] = concurrency_metrics
206+
203207
self.add_language_metadata(payload, language_version, module_system)
204208

205209
# DEBUG: Print payload language field

codeflash/languages/function_optimizer.py

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
AIServiceRefinerRequest,
8282
BestOptimization,
8383
CandidateEvaluationContext,
84+
ConcurrencyMetrics,
8485
GeneratedTests,
8586
GeneratedTestsList,
8687
OptimizationReviewResult,
@@ -502,6 +503,7 @@ def __init__(
502503
self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
503504
self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
504505
self.test_files = TestFiles(test_files=[])
506+
self.cached_concurrency_metrics: ConcurrencyMetrics | None = None
505507

506508
default_effort = getattr(args, "effort", EffortLevel.MEDIUM.value) if args else EffortLevel.MEDIUM.value
507509
self.effort = effort_override or default_effort
@@ -788,20 +790,53 @@ def optimize_function(self) -> Result[BestOptimization, str]:
788790
):
789791
console.rule()
790792
new_code_context = code_context
791-
# Generate tests and optimizations in parallel
792-
future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context)
793-
future_optimizations = self.executor.submit(
794-
self.generate_optimizations,
795-
read_writable_code=code_context.read_writable_code,
796-
read_only_context_code=code_context.read_only_context_code,
797-
run_experiment=should_run_experiment,
798-
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
799-
)
800793

801-
concurrent.futures.wait([future_tests, future_optimizations])
794+
if self.function_to_optimize.is_async:
795+
future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context)
796+
concurrent.futures.wait([future_tests])
797+
test_setup_result = future_tests.result()
798+
799+
pre_optimization_concurrency_metrics: dict[str, float] | None = None
800+
if is_successful(test_setup_result) and self.test_files.test_files:
801+
test_env = self.get_test_env(
802+
codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1
803+
)
804+
metrics = self.run_concurrency_benchmark(
805+
code_context=code_context, original_helper_code=original_helper_code, test_env=test_env
806+
)
807+
if metrics is not None:
808+
self.cached_concurrency_metrics = metrics
809+
pre_optimization_concurrency_metrics = {
810+
"concurrency_ratio": metrics.concurrency_ratio,
811+
"sequential_time_ns": float(metrics.sequential_time_ns),
812+
"concurrent_time_ns": float(metrics.concurrent_time_ns),
813+
}
814+
815+
future_optimizations = self.executor.submit(
816+
self.generate_optimizations,
817+
read_writable_code=code_context.read_writable_code,
818+
read_only_context_code=code_context.read_only_context_code,
819+
run_experiment=should_run_experiment,
820+
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
821+
concurrency_metrics=pre_optimization_concurrency_metrics,
822+
)
823+
concurrent.futures.wait([future_optimizations])
824+
optimization_result = future_optimizations.result()
825+
else:
826+
future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context)
827+
future_optimizations = self.executor.submit(
828+
self.generate_optimizations,
829+
read_writable_code=code_context.read_writable_code,
830+
read_only_context_code=code_context.read_only_context_code,
831+
run_experiment=should_run_experiment,
832+
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
833+
)
834+
835+
concurrent.futures.wait([future_tests, future_optimizations])
836+
837+
test_setup_result = future_tests.result()
838+
optimization_result = future_optimizations.result()
802839

803-
test_setup_result = future_tests.result()
804-
optimization_result = future_optimizations.result()
805840
console.rule()
806841

807842
if not is_successful(test_setup_result):
@@ -1861,6 +1896,7 @@ def generate_optimizations(
18611896
read_only_context_code: str,
18621897
run_experiment: bool = False,
18631898
is_numerical_code: bool | None = None,
1899+
concurrency_metrics: dict[str, float] | None = None,
18641900
) -> Result[tuple[OptimizationSet, str], str]:
18651901
"""Generate optimization candidates for the function. Backend handles multi-model diversity."""
18661902
n_candidates = get_effort_value(EffortKeys.N_OPTIMIZER_CANDIDATES, self.effort)
@@ -1876,6 +1912,7 @@ def generate_optimizations(
18761912
n_candidates=n_candidates,
18771913
is_numerical_code=is_numerical_code,
18781914
rerun_trace_id=self.rerun_trace_id,
1915+
concurrency_metrics=concurrency_metrics,
18791916
)
18801917

18811918
future_references = self.executor.submit(
@@ -1902,6 +1939,7 @@ def generate_optimizations(
19021939
language_version=self.language_support.language_version,
19031940
is_async=self.function_to_optimize.is_async,
19041941
n_candidates=n_candidates,
1942+
concurrency_metrics=concurrency_metrics,
19051943
)
19061944
futures.append(future_candidates_exp)
19071945

@@ -3291,6 +3329,11 @@ def run_concurrency_benchmark(
32913329
if not self.function_to_optimize.is_async:
32923330
return None
32933331

3332+
if self.cached_concurrency_metrics is not None:
3333+
cached = self.cached_concurrency_metrics
3334+
self.cached_concurrency_metrics = None
3335+
return cached
3336+
32943337
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
32953338

32963339
try:

0 commit comments

Comments
 (0)