Skip to content

Commit 459129a

Browse files
committed
first draft, need to improve jit detection
1 parent 1047ad5 commit 459129a

2 files changed

Lines changed: 41 additions & 1 deletion

File tree

codeflash/code_utils/line_profile_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import re
56
from collections import defaultdict
67
from pathlib import Path
78
from typing import TYPE_CHECKING, Union
@@ -15,6 +16,28 @@
1516
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1617
from codeflash.models.models import CodeOptimizationContext
1718

19+
# Regex pattern to detect JIT compilation decorators from numba, torch, tensorflow, and jax
20+
JIT_DECORATOR_PATTERN = re.compile(
21+
r"@(?:"
22+
# numba decorators
23+
r"(?:numba\.)?(?:jit|njit|vectorize|guvectorize|stencil|cfunc|generated_jit)"
24+
r"|numba\.cuda\.jit"
25+
r"|cuda\.jit"
26+
# torch decorators
27+
r"|torch\.compile"
28+
r"|torch\.jit\.(?:script|trace)"
29+
# tensorflow decorators
30+
r"|(?:tf|tensorflow)\.function"
31+
# jax decorators
32+
r"|jax\.jit"
33+
r")"
34+
)
35+
36+
37+
def contains_jit_decorator(code: str) -> bool:
38+
"""Check if the code contains JIT compilation decorators from numba, torch, tensorflow, or jax."""
39+
return bool(JIT_DECORATOR_PATTERN.search(code))
40+
1841

1942
class LineProfilerDecoratorAdder(cst.CSTTransformer):
2043
"""Transformer that adds a decorator to a function with a specific qualified name."""

codeflash/optimization/function_optimizer.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
6767
from codeflash.code_utils.git_utils import git_root_dir
6868
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
69-
from codeflash.code_utils.line_profile_utils import add_decorator_imports
69+
from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator
7070
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
7171
from codeflash.code_utils.time_utils import humanize_runtime
7272
from codeflash.context import code_context_extractor
@@ -2389,6 +2389,23 @@ def get_test_env(
23892389
def line_profiler_step(
23902390
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
23912391
) -> dict:
2392+
# Check if candidate code contains JIT decorators - line profiler doesn't work with JIT compiled code
2393+
candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
2394+
if contains_jit_decorator(candidate_fto_code):
2395+
logger.info(
2396+
f"Skipping line profiler for {self.function_to_optimize.function_name} - code contains JIT decorator"
2397+
)
2398+
return {"timings": {}, "unit": 0, "str_out": ""}
2399+
2400+
# Check helper code for JIT decorators
2401+
for module_abspath in original_helper_code:
2402+
candidate_helper_code = Path(module_abspath).read_text("utf-8")
2403+
if contains_jit_decorator(candidate_helper_code):
2404+
logger.info(
2405+
f"Skipping line profiler for {self.function_to_optimize.function_name} - helper code contains JIT decorator"
2406+
)
2407+
return {"timings": {}, "unit": 0, "str_out": ""}
2408+
23922409
try:
23932410
console.rule()
23942411

0 commit comments

Comments
 (0)