-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathinstrument_codeflash_trace.py
More file actions
124 lines (98 loc) · 5.15 KB
/
instrument_codeflash_trace.py
File metadata and controls
124 lines (98 loc) · 5.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union
import libcst as cst
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.code_utils.formatter import sort_imports
if TYPE_CHECKING:
from pathlib import Path
from libcst import BaseStatement, ClassDef, FlattenSentinel, FunctionDef, RemovalSentinel
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
class AddDecoratorTransformer(cst.CSTTransformer):
def __init__(self, target_functions: set[tuple[str, str]]) -> None:
super().__init__()
self.target_functions = target_functions
self.added_codeflash_trace = False
self.class_name = ""
self.function_name = ""
self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace"))
def leave_ClassDef(
self, original_node: ClassDef, updated_node: ClassDef
) -> Union[BaseStatement, FlattenSentinel[BaseStatement], RemovalSentinel]:
if self.class_name == original_node.name.value:
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
return updated_node
def visit_ClassDef(self, node: ClassDef) -> Optional[bool]:
if self.class_name: # Don't go into nested class
return False
self.class_name = node.name.value
return None
def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]:
if self.function_name: # Don't go into nested function
return False
self.function_name = node.name.value
return None
def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
if self.function_name == original_node.name.value:
self.function_name = ""
if (self.class_name, original_node.name.value) in self.target_functions:
# Add the new decorator after any existing decorators, so it gets executed first
updated_decorators = [*list(updated_node.decorators), self.decorator]
self.added_codeflash_trace = True
return updated_node.with_changes(decorators=updated_decorators)
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Create import statement for codeflash_trace
if not self.added_codeflash_trace:
return updated_node
import_stmt = cst.SimpleStatementLine(
body=[
cst.ImportFrom(
module=cst.Attribute(
value=cst.Attribute(value=cst.Name(value="codeflash"), attr=cst.Name(value="benchmarking")),
attr=cst.Name(value="codeflash_trace"),
),
names=[cst.ImportAlias(name=cst.Name(value="codeflash_trace"))],
)
]
)
# Insert at the beginning of the file. We'll use isort later to sort the imports.
new_body = [import_stmt, *list(updated_node.body)]
return updated_node.with_changes(body=new_body)
def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str:
"""Add codeflash_trace to a function.
Args:
----
code: The source code as a string
functions_to_optimize: List of FunctionToOptimize instances containing function details
Returns:
-------
The modified source code as a string
"""
target_functions = set()
for function_to_optimize in functions_to_optimize:
class_name = ""
if len(function_to_optimize.parents) == 1 and function_to_optimize.parents[0].type == "ClassDef":
class_name = function_to_optimize.parents[0].name
target_functions.add((class_name, function_to_optimize.function_name))
transformer = AddDecoratorTransformer(target_functions=target_functions)
module = cst.parse_module(code)
modified_module = module.visit(transformer)
return modified_module.code
def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None:
"""Instrument codeflash_trace decorator to functions to optimize."""
for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
# Skip codeflash's own benchmarking and picklepatch modules to avoid circular imports
# (codeflash_trace.py imports from picklepatch, and instrumenting these would cause
# them to import codeflash_trace back, creating a circular import)
# Use rpartition to find the last "codeflash" in path (handles nested paths)
_, sep, after = file_path.as_posix().rpartition("/codeflash/")
if sep:
submodule = after.partition("/")[0]
if submodule in ("benchmarking", "picklepatch"):
continue
original_code = file_path.read_text(encoding="utf-8")
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
# Modify the code
modified_code = sort_imports(code=new_code, float_to_top=True)
# Write the modified code back to the file
file_path.write_text(modified_code, encoding="utf-8")