Skip to content

Commit ff6465b

Browse files
authored
Merge pull request #1175 from codeflash-ai/disable-jit-optimizations
add --no-jit-opts to disable JIT optimizations
2 parents 87894e8 + 63cc8e7 commit ff6465b

3 files changed

Lines changed: 9 additions & 4 deletions

File tree

codeflash/api/aiservice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def get_jit_rewritten_code( # noqa: D417
224224
logger.info("!lsp|Rewriting as a JIT function…")
225225
console.rule()
226226
try:
227-
response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=60)
227+
response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=self.timeout)
228228
except requests.exceptions.RequestException as e:
229229
logger.exception(f"Error generating jit rewritten candidate: {e}")
230230
ph("cli-jit-rewrite-error-caught", {"error": str(e)})

codeflash/cli_cmds/cli.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def parse_args() -> Namespace:
8484
parser.add_argument(
8585
"--no-gen-tests", action="store_true", help="Do not generate tests, use only existing tests for optimization."
8686
)
87+
parser.add_argument(
88+
"--no-jit-opts", action="store_true", help="Do not generate JIT-compiled optimizations for numerical code."
89+
)
8790
parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review")
8891
parser.add_argument(
8992
"--verify-setup",

codeflash/optimization/function_optimizer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,9 @@ def optimize_function(self) -> Result[BestOptimization, str]:
610610
):
611611
console.rule()
612612
new_code_context = code_context
613-
if self.is_numerical_code: # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax)
613+
if (
614+
self.is_numerical_code and not self.args.no_jit_opts
615+
): # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax)
614616
jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code(
615617
code_context.read_writable_code.markdown, self.function_trace_id
616618
)
@@ -639,7 +641,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
639641
read_writable_code=code_context.read_writable_code,
640642
read_only_context_code=code_context.read_only_context_code,
641643
run_experiment=should_run_experiment,
642-
is_numerical_code=self.is_numerical_code,
644+
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
643645
)
644646

645647
concurrent.futures.wait([future_tests, future_optimizations])
@@ -1158,7 +1160,7 @@ def determine_best_candidate(
11581160
)
11591161
if self.experiment_id
11601162
else None,
1161-
is_numerical_code=self.is_numerical_code,
1163+
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
11621164
)
11631165

11641166
processor = CandidateProcessor(

0 commit comments

Comments
 (0)