Skip to content

Commit dcc5bea

Browse files
committed
perf: cache file reads and AST parses during discovery pass
Introduces a discovery-scoped cache (the `discovery_cache()` context manager) that ensures each file is read from disk at most once and parsed into an AST at most once within a single discovery run.

Key changes:
- `read_file_cached()`: returns the cached file content when `discovery_cache()` is active, and falls back to a normal read otherwise
- `parse_ast_cached()`: returns the cached `ast.Module` when `discovery_cache()` is active
- `get_functions_to_optimize()`: wraps its body with `discovery_cache()`
- `find_all_functions_in_file()`: uses `read_file_cached`
- `inspect_top_level_functions_or_methods()`: uses `parse_ast_cached`
- `get_all_replay_test_functions()`: uses `parse_ast_cached`
- JS/TS export helpers: use `read_file_cached`
- Removed dead code: `_find_all_functions_via_language_support()` (never called; it also had a type error, passing a `Path` where a source `str` was expected)

Signature changes: none. All public function signatures are unchanged. The cache is transparent: outside `discovery_cache()`, all functions behave exactly as before (direct reads and parses).
1 parent 5c1cc1e commit dcc5bea

2 files changed

Lines changed: 204 additions & 46 deletions

File tree

codeflash/discovery/functions_to_optimize.py

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
if TYPE_CHECKING:
3939
from argparse import Namespace
40+
from collections.abc import Generator
4041

4142
from codeflash.models.models import CodeOptimizationContext
4243
from codeflash.verification.verification_utils import TestConfig
@@ -51,6 +52,46 @@ class FunctionProperties:
5152
staticmethod_class_name: Optional[str]
5253

5354

55+
# =============================================================================
56+
# Discovery-scoped file/AST cache
57+
# =============================================================================
58+
59+
_active_discovery_cache: dict[Path, str] | None = None
60+
_active_ast_cache: dict[Path, ast.Module] | None = None
61+
62+
63+
@contextlib.contextmanager
64+
def discovery_cache() -> Generator[None, None, None]:
65+
global _active_discovery_cache, _active_ast_cache
66+
_active_discovery_cache = {}
67+
_active_ast_cache = {}
68+
try:
69+
yield
70+
finally:
71+
_active_discovery_cache = None
72+
_active_ast_cache = None
73+
74+
75+
def read_file_cached(file_path: Path) -> str:
    """Return the UTF-8 text of ``file_path``, reusing the active discovery cache.

    When no discovery cache is active (the module-level cache is ``None``) the
    file is read from disk on every call. Otherwise the first read of a path is
    stored and later calls for the same path return the stored string.

    Args:
        file_path: File to read; also used as the cache key.

    Returns:
        The decoded file content.
    """
    cache = _active_discovery_cache
    if cache is None:
        return file_path.read_text(encoding="utf-8")

    content = cache.get(file_path)
    if content is None:
        # Cache miss: read once and remember the result for this scope.
        content = file_path.read_text(encoding="utf-8")
        cache[file_path] = content
    return content
81+
82+
83+
def parse_ast_cached(file_path: Path, source: str | None = None) -> ast.Module:
84+
if _active_ast_cache is not None:
85+
if file_path not in _active_ast_cache:
86+
if source is None:
87+
source = read_file_cached(file_path)
88+
_active_ast_cache[file_path] = ast.parse(source)
89+
return _active_ast_cache[file_path]
90+
if source is None:
91+
source = read_file_cached(file_path)
92+
return ast.parse(source)
93+
94+
5495
# =============================================================================
5596
# Multi-language support helpers
5697
# =============================================================================
@@ -152,7 +193,7 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo
152193
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
153194

154195
try:
155-
source = file_path.read_text(encoding="utf-8")
196+
source = read_file_cached(file_path)
156197
analyzer = get_analyzer_for_file(file_path)
157198
return analyzer.is_function_exported(source, function_name)
158199
except Exception as e:
@@ -170,7 +211,7 @@ def _is_js_ts_function_exists_but_not_exported(file_path: Path, function_name: s
170211
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
171212

172213
try:
173-
source = file_path.read_text(encoding="utf-8")
214+
source = read_file_cached(file_path)
174215
analyzer = get_analyzer_for_file(file_path)
175216
all_funcs = analyzer.find_functions(
176217
source, include_methods=True, include_arrow_functions=True, require_name=True
@@ -183,28 +224,6 @@ def _is_js_ts_function_exists_but_not_exported(file_path: Path, function_name: s
183224
return False
184225

185226

186-
def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
187-
"""Find all optimizable functions using the language support abstraction.
188-
189-
This function uses the registered language support for the file's language
190-
to discover functions, then converts them to FunctionToOptimize instances.
191-
"""
192-
from codeflash.languages.base import FunctionFilterCriteria
193-
194-
functions: dict[Path, list[FunctionToOptimize]] = {}
195-
196-
try:
197-
lang_support = get_language_support(file_path)
198-
require_return = lang_support.language != Language.JAVA
199-
criteria = FunctionFilterCriteria(require_return=require_return)
200-
source = file_path.read_text(encoding="utf-8")
201-
functions[file_path] = lang_support.discover_functions(source, file_path, criteria)
202-
except Exception as e:
203-
logger.debug(f"Failed to discover functions in {file_path}: {e}")
204-
205-
return functions
206-
207-
208227
def get_functions_to_optimize(
209228
optimize_all: str | None,
210229
replay_test: list[Path] | None,
@@ -222,7 +241,7 @@ def get_functions_to_optimize(
222241
functions: dict[Path, list[FunctionToOptimize]]
223242
trace_file_path: Path | None = None
224243
is_lsp = is_LSP_enabled()
225-
with warnings.catch_warnings():
244+
with discovery_cache(), warnings.catch_warnings():
226245
warnings.simplefilter(action="ignore", category=SyntaxWarning)
227246
if optimize_all:
228247
logger.info("!lsp|Finding all functions in the module '%s'…", optimize_all)
@@ -489,7 +508,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
489508
lang_support = get_language_support(file_path)
490509
require_return = lang_support.language != Language.JAVA
491510
criteria = FunctionFilterCriteria(require_return=require_return)
492-
source = file_path.read_text(encoding="utf-8")
511+
source = read_file_cached(file_path)
493512
return {file_path: lang_support.discover_functions(source, file_path, criteria)}
494513
except Exception as e:
495514
logger.debug(f"Failed to discover functions in {file_path}: {e}")
@@ -506,21 +525,20 @@ def get_all_replay_test_functions(
506525
trace_file_path: Path | None = None
507526
for replay_test_file in replay_test:
508527
try:
509-
with replay_test_file.open("r", encoding="utf8") as f:
510-
tree = ast.parse(f.read())
511-
for node in ast.walk(tree):
512-
if isinstance(node, ast.Assign):
513-
for target in node.targets:
514-
if (
515-
isinstance(target, ast.Name)
516-
and target.id == "trace_file_path"
517-
and isinstance(node.value, ast.Constant)
518-
and isinstance(node.value.value, str)
519-
):
520-
trace_file_path = Path(node.value.value)
521-
break
522-
if trace_file_path:
528+
tree = parse_ast_cached(replay_test_file)
529+
for node in ast.walk(tree):
530+
if isinstance(node, ast.Assign):
531+
for target in node.targets:
532+
if (
533+
isinstance(target, ast.Name)
534+
and target.id == "trace_file_path"
535+
and isinstance(node.value, ast.Constant)
536+
and isinstance(node.value.value, str)
537+
):
538+
trace_file_path = Path(node.value.value)
523539
break
540+
if trace_file_path:
541+
break
524542
if trace_file_path:
525543
break
526544
except Exception as e:
@@ -634,7 +652,7 @@ def _get_java_replay_test_functions(
634652
from codeflash.languages.registry import get_language_support
635653

636654
lang_support = get_language_support(source_file)
637-
source_code = source_file.read_text(encoding="utf-8")
655+
source_code = read_file_cached(source_file)
638656
all_functions = lang_support.discover_functions(source_code, source_file)
639657

640658
for func in all_functions:
@@ -762,11 +780,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
762780
def inspect_top_level_functions_or_methods(
763781
file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None
764782
) -> FunctionProperties | None:
765-
with file_name.open(encoding="utf8") as file:
766-
try:
767-
ast_module = ast.parse(file.read())
768-
except Exception:
769-
return None
783+
try:
784+
ast_module = parse_ast_cached(file_name)
785+
except Exception:
786+
return None
770787
visitor = TopLevelFunctionOrMethodVisitor(
771788
file_name=file_name, function_or_method_name=function_or_method_name, class_name=class_name, line_no=line_no
772789
)

tests/test_discovery_cache.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import tempfile
5+
from pathlib import Path
6+
from unittest.mock import patch
7+
8+
from codeflash.discovery.functions_to_optimize import (
9+
discovery_cache,
10+
find_all_functions_in_file,
11+
inspect_top_level_functions_or_methods,
12+
parse_ast_cached,
13+
read_file_cached,
14+
)
15+
16+
17+
def test_read_file_cached_without_context_manager(tmp_path: Path) -> None:
    """Outside ``discovery_cache()`` the helper simply returns the file's text."""
    sample = tmp_path / "sample.py"
    expected = "x = 1\n"
    sample.write_text(expected, encoding="utf-8")
    assert read_file_cached(sample) == expected
21+
22+
23+
def test_read_file_cached_returns_same_object_within_context(tmp_path: Path) -> None:
    """Two reads inside one cache scope return the identical string object."""
    sample = tmp_path / "sample.py"
    sample.write_text("x = 1\n", encoding="utf-8")
    with discovery_cache():
        first, second = read_file_cached(sample), read_file_cached(sample)
    assert first is second
30+
31+
32+
def test_read_file_cached_does_not_persist_across_contexts(tmp_path: Path) -> None:
    """Each cache scope re-reads the file, so an edit between scopes is seen."""
    sample = tmp_path / "sample.py"
    contents_seen = []
    for text in ("x = 1\n", "x = 2\n"):
        sample.write_text(text, encoding="utf-8")
        # A separate scope per version: nothing may carry over between them.
        with discovery_cache():
            contents_seen.append(read_file_cached(sample))
    before_edit, after_edit = contents_seen
    assert before_edit != after_edit
41+
42+
43+
def test_parse_ast_cached_returns_same_object_within_context(tmp_path: Path) -> None:
    """Parsing the same file twice in one scope yields the same ``ast.Module``."""
    module_file = tmp_path / "sample.py"
    module_file.write_text("def foo():\n return 1\n", encoding="utf-8")
    with discovery_cache():
        first_tree = parse_ast_cached(module_file)
        second_tree = parse_ast_cached(module_file)
    assert isinstance(first_tree, ast.Module)
    assert first_tree is second_tree
51+
52+
53+
def test_parse_ast_cached_uses_provided_source(tmp_path: Path) -> None:
    """An explicit ``source`` is parsed instead of the text stored on disk."""
    on_disk = tmp_path / "sample.py"
    on_disk.write_text("x = 1\n", encoding="utf-8")
    override = "y = 2\n"
    with discovery_cache():
        tree = parse_ast_cached(on_disk, source=override)
    # Collect the first target of every assignment whose first target is a plain name.
    assigned_names = [
        node.targets[0].id
        for node in ast.walk(tree)
        if isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name)
    ]
    assert "y" in assigned_names
65+
66+
67+
def test_discovery_cache_avoids_redundant_reads(tmp_path: Path) -> None:
    """Three cached reads of one file hit ``Path.read_text`` exactly once."""
    module_file = tmp_path / "module.py"
    module_file.write_text("def bar():\n return 42\n", encoding="utf-8")
    # ``wraps`` is bound before the patch is installed, so the real read still runs.
    with discovery_cache(), patch.object(Path, "read_text", wraps=module_file.read_text) as mock_read:
        for _ in range(3):
            read_file_cached(module_file)
    assert mock_read.call_count == 1
76+
77+
78+
def test_find_all_functions_in_file_uses_cache(tmp_path: Path) -> None:
    """Function discovery still finds the function when run inside a cache scope."""
    module_file = tmp_path / "module.py"
    module_file.write_text("def compute(x):\n return x * 2\n", encoding="utf-8")
    with discovery_cache():
        discovered = find_all_functions_in_file(module_file)
    assert module_file in discovered
    first_function = discovered[module_file][0]
    assert first_function.function_name == "compute"
85+
86+
87+
def test_inspect_top_level_functions_uses_cache(tmp_path: Path) -> None:
    """Inspecting a top-level function works while a cache scope is active."""
    module_file = tmp_path / "module.py"
    module_file.write_text("def top_func(a, b):\n return a + b\n", encoding="utf-8")
    with discovery_cache():
        properties = inspect_top_level_functions_or_methods(module_file, "top_func")
    assert properties is not None
    assert properties.is_top_level
    assert properties.has_args
95+
96+
97+
def test_find_and_inspect_share_cached_content(tmp_path: Path) -> None:
    """Discovery followed by inspection of the same file reads it from disk once."""
    module_file = tmp_path / "module.py"
    # The method body must be indented deeper than its ``def`` line for the
    # module to be valid Python.
    module_file.write_text(
        "class MyClass:\n    def method(self):\n        return 1\n\ndef standalone():\n    return 2\n",
        encoding="utf-8",
    )
    with discovery_cache():
        # ``wraps`` is bound before the patch is installed, so the real read still runs.
        with patch.object(Path, "read_text", wraps=module_file.read_text) as mock_read:
            find_all_functions_in_file(module_file)
            props = inspect_top_level_functions_or_methods(module_file, "method", class_name="MyClass")
        assert mock_read.call_count == 1
    assert props is not None
    assert props.is_top_level
110+
111+
112+
def test_discovery_results_correct_with_multiple_files(tmp_path: Path) -> None:
    """Results for two files discovered in one cache scope do not get mixed up."""
    alpha_file = tmp_path / "a.py"
    alpha_file.write_text("def alpha():\n return 'a'\n", encoding="utf-8")
    beta_file = tmp_path / "b.py"
    beta_file.write_text("def beta(x):\n return x + 1\n", encoding="utf-8")

    with discovery_cache():
        alpha_result = find_all_functions_in_file(alpha_file)
        beta_result = find_all_functions_in_file(beta_file)

    assert alpha_result[alpha_file][0].function_name == "alpha"
    assert beta_result[beta_file][0].function_name == "beta"
124+
125+
126+
def test_cache_handles_invalid_syntax_gracefully(tmp_path: Path) -> None:
    """A file that fails to parse yields an empty result rather than an error."""
    broken_file = tmp_path / "broken.py"
    broken_file.write_text("def incomplete(:\n", encoding="utf-8")
    with discovery_cache():
        discovered = find_all_functions_in_file(broken_file)
    assert discovered == {}
132+
133+
134+
def test_cache_handles_nonexistent_file_in_parse_ast(tmp_path: Path) -> None:
    """Parsing a missing file propagates the ``OSError`` from the read."""
    missing_file = tmp_path / "nonexistent.py"
    with discovery_cache():
        try:
            parse_ast_cached(missing_file)
        except OSError:
            # Expected: FileNotFoundError is a subclass of OSError.
            pass
        else:
            # Raise explicitly: a bare ``assert False`` is stripped under ``-O``.
            raise AssertionError("Should have raised")

0 commit comments

Comments
 (0)