Skip to content

Commit dc5090e

Browse files
authored
Merge pull request #2105 from codeflash-ai/feat/js-tracer-rebase-v2
feat(js): add JavaScript function tracer with Babel instrumentation
2 parents d2ec01a + 892bff4 commit dc5090e

13 files changed

Lines changed: 3227 additions & 400 deletions

File tree

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import re
5+
import sqlite3
6+
import textwrap
7+
from dataclasses import dataclass
8+
from pathlib import Path
9+
from typing import TYPE_CHECKING, Any, Optional
10+
11+
if TYPE_CHECKING:
12+
from collections.abc import Generator
13+
14+
15+
@dataclass
16+
class JavaScriptFunctionModule:
17+
function_name: str
18+
file_name: Path
19+
module_name: str
20+
class_name: Optional[str] = None
21+
line_no: Optional[int] = None
22+
23+
24+
def get_next_arg_and_return(
25+
trace_file: str, function_name: str, file_name: str, class_name: Optional[str] = None, num_to_get: int = 25
26+
) -> Generator[Any]:
27+
db = sqlite3.connect(trace_file)
28+
cur = db.cursor()
29+
30+
try:
31+
cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
32+
tables = {row[0] for row in cur.fetchall()}
33+
34+
if "function_calls" in tables:
35+
if class_name:
36+
cursor = cur.execute(
37+
"SELECT args FROM function_calls WHERE function = ? AND filename = ? AND classname = ? AND type = 'call' ORDER BY time_ns ASC LIMIT ?",
38+
(function_name, file_name, class_name, num_to_get),
39+
)
40+
else:
41+
cursor = cur.execute(
42+
"SELECT args FROM function_calls WHERE function = ? AND filename = ? AND type = 'call' ORDER BY time_ns ASC LIMIT ?",
43+
(function_name, file_name, num_to_get),
44+
)
45+
46+
while (val := cursor.fetchone()) is not None:
47+
args_data = val[0]
48+
if isinstance(args_data, bytes):
49+
yield args_data
50+
else:
51+
yield args_data
52+
53+
elif "traces" in tables:
54+
if class_name:
55+
cursor = cur.execute(
56+
"SELECT args FROM traces WHERE function = ? AND file = ? ORDER BY id ASC LIMIT ?",
57+
(function_name, file_name, num_to_get),
58+
)
59+
else:
60+
cursor = cur.execute(
61+
"SELECT args FROM traces WHERE function = ? AND file = ? ORDER BY id ASC LIMIT ?",
62+
(function_name, file_name, num_to_get),
63+
)
64+
65+
while (val := cursor.fetchone()) is not None:
66+
yield val[0]
67+
68+
finally:
69+
db.close()
70+
71+
72+
def get_function_alias(module: str, function_name: str, class_name: Optional[str] = None) -> str:
73+
module_alias = re.sub(r"[^a-zA-Z0-9]", "_", module).strip("_")
74+
75+
if class_name:
76+
return f"{module_alias}_{class_name}_{function_name}"
77+
return f"{module_alias}_{function_name}"
78+
79+
80+
def create_javascript_replay_test(
81+
trace_file: str,
82+
functions: list[JavaScriptFunctionModule],
83+
max_run_count: int = 100,
84+
framework: str = "jest",
85+
project_root: Optional[Path] = None,
86+
) -> str:
87+
is_vitest = framework.lower() == "vitest"
88+
89+
imports = []
90+
91+
if is_vitest:
92+
imports.append("import { describe, test } from 'vitest';")
93+
94+
imports.append("const { getNextArg } = require('codeflash/replay');")
95+
imports.append("")
96+
97+
for func in functions:
98+
if func.function_name in ("__init__", "constructor"):
99+
continue
100+
101+
alias = get_function_alias(func.module_name, func.function_name, func.class_name)
102+
103+
if func.class_name:
104+
imports.append(f"const {{ {func.class_name}: {alias}_class }} = require('./{func.module_name}');")
105+
else:
106+
imports.append(f"const {{ {func.function_name}: {alias} }} = require('./{func.module_name}');")
107+
108+
imports.append("")
109+
110+
functions_to_test = [f.function_name for f in functions if f.function_name not in ("__init__", "constructor")]
111+
metadata = f"""const traceFilePath = '{trace_file}';
112+
const functions = {json.dumps(functions_to_test)};
113+
"""
114+
115+
test_cases = []
116+
117+
for func in functions:
118+
if func.function_name in ("__init__", "constructor"):
119+
continue
120+
121+
alias = get_function_alias(func.module_name, func.function_name, func.class_name)
122+
test_name = f"{func.class_name}.{func.function_name}" if func.class_name else func.function_name
123+
124+
if func.class_name:
125+
class_arg = f"'{func.class_name}'"
126+
test_body = textwrap.dedent(f"""
127+
describe('Replay: {test_name}', () => {{
128+
const traces = getNextArg(traceFilePath, '{func.function_name}', '{func.file_name.as_posix()}', {max_run_count}, {class_arg});
129+
130+
test.each(traces.map((args, i) => [i, args]))('call %i', (index, args) => {{
131+
const instance = new {alias}_class();
132+
instance.{func.function_name}(...args);
133+
}});
134+
}});
135+
""")
136+
else:
137+
test_body = textwrap.dedent(f"""
138+
describe('Replay: {test_name}', () => {{
139+
const traces = getNextArg(traceFilePath, '{func.function_name}', '{func.file_name.as_posix()}', {max_run_count});
140+
141+
test.each(traces.map((args, i) => [i, args]))('call %i', (index, args) => {{
142+
{alias}(...args);
143+
}});
144+
}});
145+
""")
146+
147+
test_cases.append(test_body)
148+
149+
return "\n".join(
150+
[
151+
"// Auto-generated replay test by Codeflash",
152+
"// Do not edit this file directly",
153+
"",
154+
*imports,
155+
metadata,
156+
*test_cases,
157+
]
158+
)
159+
160+
161+
def get_traced_functions_from_db(trace_file: Path) -> list[JavaScriptFunctionModule]:
162+
if not trace_file.exists():
163+
return []
164+
165+
try:
166+
conn = sqlite3.connect(trace_file)
167+
cursor = conn.cursor()
168+
169+
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
170+
tables = {row[0] for row in cursor.fetchall()}
171+
172+
functions = []
173+
174+
if "function_calls" in tables:
175+
cursor.execute(
176+
"SELECT DISTINCT function, filename, classname, line_number FROM function_calls WHERE type = 'call'"
177+
)
178+
for row in cursor.fetchall():
179+
func_name = row[0]
180+
file_name = row[1]
181+
class_name = row[2]
182+
line_number = row[3]
183+
184+
module_path = file_name.replace("\\", "/").replace(".js", "").replace(".ts", "")
185+
module_path = module_path.removeprefix("./")
186+
187+
functions.append(
188+
JavaScriptFunctionModule(
189+
function_name=func_name,
190+
file_name=Path(file_name),
191+
module_name=module_path,
192+
class_name=class_name,
193+
line_no=line_number,
194+
)
195+
)
196+
197+
elif "traces" in tables:
198+
cursor.execute("SELECT DISTINCT function, file FROM traces")
199+
for row in cursor.fetchall():
200+
func_name = row[0]
201+
file_name = row[1]
202+
203+
module_path = file_name.replace("\\", "/").replace(".js", "").replace(".ts", "")
204+
module_path = module_path.removeprefix("./")
205+
206+
functions.append(
207+
JavaScriptFunctionModule(
208+
function_name=func_name, file_name=Path(file_name), module_name=module_path
209+
)
210+
)
211+
212+
conn.close()
213+
return functions
214+
215+
except Exception:
216+
return []
217+
218+
219+
def create_replay_test_file(
220+
trace_file: Path,
221+
output_path: Path,
222+
framework: str = "jest",
223+
max_run_count: int = 100,
224+
project_root: Optional[Path] = None,
225+
) -> Optional[Path]:
226+
functions = get_traced_functions_from_db(trace_file)
227+
228+
if not functions:
229+
return None
230+
231+
content = create_javascript_replay_test(
232+
trace_file=str(trace_file),
233+
functions=functions,
234+
max_run_count=max_run_count,
235+
framework=framework,
236+
project_root=project_root,
237+
)
238+
239+
try:
240+
output_path.parent.mkdir(parents=True, exist_ok=True)
241+
output_path.write_text(content, encoding="utf-8")
242+
return output_path
243+
except Exception:
244+
return None

codeflash/languages/javascript/support.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,30 +1702,21 @@ def instrument_for_behavior(
17021702
) -> str:
17031703
"""Add behavior instrumentation to capture inputs/outputs.
17041704
1705-
For JavaScript, this wraps functions to capture their arguments
1706-
and return values.
1705+
For JavaScript, instrumentation is handled at runtime by the Babel tracer plugin
1706+
(babel-tracer-plugin.js) via trace-runner.js. This method returns the source
1707+
unchanged since no source-level transformation is needed.
17071708
17081709
Args:
17091710
source: Source code to instrument.
17101711
functions: Functions to add tracing to.
17111712
output_file: Optional output file for traces.
17121713
17131714
Returns:
1714-
Instrumented source code.
1715+
Source code unchanged (Babel handles instrumentation at runtime).
17151716
17161717
"""
1717-
if not functions:
1718-
return source
1719-
1720-
from codeflash.languages.javascript.tracer import JavaScriptTracer
1721-
1722-
# Use first function's file path if output_file not specified
1723-
if output_file is None:
1724-
file_path = functions[0].file_path
1725-
output_file = file_path.parent / ".codeflash" / "traces.db"
1726-
1727-
tracer = JavaScriptTracer(output_file)
1728-
return tracer.instrument_source(source, functions[0].file_path, list(functions))
1718+
# JavaScript tracing is done at runtime via Babel plugin, not source transformation
1719+
return source
17291720

17301721
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str:
17311722
"""Add timing instrumentation to test code.

0 commit comments

Comments
 (0)