Skip to content

Commit 2db9549

Browse files
authored
Merge pull request #1020 from codeflash-ai/disable-lp-jit
Disable Line Profiler Execution when dealing with a jit compiled function
2 parents 9f4fe4f + 3aca8de commit 2db9549

2 files changed

Lines changed: 41 additions & 1 deletion

File tree

codeflash/code_utils/line_profile_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import re
56
from collections import defaultdict
67
from pathlib import Path
78
from typing import TYPE_CHECKING, Union
@@ -15,6 +16,28 @@
1516
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1617
from codeflash.models.models import CodeOptimizationContext
1718

19+
# Regex pattern to detect JIT compilation decorators from numba, torch, tensorflow, and jax
20+
JIT_DECORATOR_PATTERN = re.compile(
21+
r"@(?:"
22+
# numba decorators
23+
r"(?:numba\.)?(?:jit|njit|vectorize|guvectorize|stencil|cfunc|generated_jit)"
24+
r"|numba\.cuda\.jit"
25+
r"|cuda\.jit"
26+
# torch decorators
27+
r"|torch\.compile"
28+
r"|torch\.jit\.(?:script|trace)"
29+
# tensorflow decorators
30+
r"|(?:tf|tensorflow)\.function"
31+
# jax decorators
32+
r"|jax\.jit"
33+
r")"
34+
)
35+
36+
37+
def contains_jit_decorator(code: str) -> bool:
38+
"""Check if the code contains JIT compilation decorators from numba, torch, tensorflow, or jax."""
39+
return bool(JIT_DECORATOR_PATTERN.search(code))
40+
1841

1942
class LineProfilerDecoratorAdder(cst.CSTTransformer):
2043
"""Transformer that adds a decorator to a function with a specific qualified name."""

codeflash/optimization/function_optimizer.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
6767
from codeflash.code_utils.git_utils import git_root_dir
6868
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
69-
from codeflash.code_utils.line_profile_utils import add_decorator_imports
69+
from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator
7070
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
7171
from codeflash.code_utils.time_utils import humanize_runtime
7272
from codeflash.context import code_context_extractor
@@ -2412,6 +2412,23 @@ def get_test_env(
24122412
def line_profiler_step(
24132413
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
24142414
) -> dict:
2415+
# Check if candidate code contains JIT decorators - line profiler doesn't work with JIT compiled code
2416+
candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
2417+
if contains_jit_decorator(candidate_fto_code):
2418+
logger.info(
2419+
f"Skipping line profiler for {self.function_to_optimize.function_name} - code contains JIT decorator"
2420+
)
2421+
return {"timings": {}, "unit": 0, "str_out": ""}
2422+
2423+
# Check helper code for JIT decorators
2424+
for module_abspath in original_helper_code:
2425+
candidate_helper_code = Path(module_abspath).read_text("utf-8")
2426+
if contains_jit_decorator(candidate_helper_code):
2427+
logger.info(
2428+
f"Skipping line profiler for {self.function_to_optimize.function_name} - helper code contains JIT decorator"
2429+
)
2430+
return {"timings": {}, "unit": 0, "str_out": ""}
2431+
24152432
try:
24162433
console.rule()
24172434

0 commit comments

Comments
 (0)