Skip to content

Commit ee86330

Browse files
committed
feat(js): add Python-side JavaScript tracer integration
Add Python modules for JavaScript tracing orchestration: - replay_test.py: Generate Jest/Vitest replay tests from traces - tracer_runner.py: Run trace-runner.js and detect test frameworks - tracer.py: Refactored to use Babel-only approach, removed legacy source transformation code
1 parent bd63042 commit ee86330

3 files changed

Lines changed: 919 additions & 337 deletions

File tree

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
"""JavaScript replay test generation.
2+
3+
This module provides functionality to generate replay tests from traced JavaScript
4+
function calls. Replay tests allow verifying that optimized code produces the same
5+
results as the original code.
6+
7+
The generated tests can be run with Jest or Vitest, depending on the project's
8+
test framework configuration.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import json
14+
import sqlite3
15+
import textwrap
16+
from dataclasses import dataclass
17+
from pathlib import Path
18+
from typing import TYPE_CHECKING, Any, Optional
19+
20+
if TYPE_CHECKING:
21+
from collections.abc import Generator
22+
23+
24+
@dataclass
25+
class JavaScriptFunctionModule:
26+
"""Information about a traced JavaScript function for replay test generation."""
27+
28+
function_name: str
29+
file_name: Path
30+
module_name: str
31+
class_name: Optional[str] = None
32+
line_no: Optional[int] = None
33+
34+
35+
def get_next_arg_and_return(
36+
trace_file: str, function_name: str, file_name: str, class_name: Optional[str] = None, num_to_get: int = 25
37+
) -> Generator[Any]:
38+
"""Get traced function arguments from the database.
39+
40+
This mirrors the Python version in codeflash/tracing/replay_test.py.
41+
42+
Args:
43+
trace_file: Path to the trace SQLite database.
44+
function_name: Name of the function.
45+
file_name: Path to the source file.
46+
class_name: Optional class name for methods.
47+
num_to_get: Maximum number of traces to retrieve.
48+
49+
Yields:
50+
Serialized argument data for each traced call.
51+
52+
"""
53+
db = sqlite3.connect(trace_file)
54+
cur = db.cursor()
55+
56+
# Try the new schema first (function_calls table)
57+
try:
58+
cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
59+
tables = {row[0] for row in cur.fetchall()}
60+
61+
if "function_calls" in tables:
62+
if class_name:
63+
cursor = cur.execute(
64+
"SELECT args FROM function_calls WHERE function = ? AND filename = ? AND classname = ? AND type = 'call' ORDER BY time_ns ASC LIMIT ?",
65+
(function_name, file_name, class_name, num_to_get),
66+
)
67+
else:
68+
cursor = cur.execute(
69+
"SELECT args FROM function_calls WHERE function = ? AND filename = ? AND type = 'call' ORDER BY time_ns ASC LIMIT ?",
70+
(function_name, file_name, num_to_get),
71+
)
72+
73+
while (val := cursor.fetchone()) is not None:
74+
# args is stored as JSON or binary blob
75+
args_data = val[0]
76+
if isinstance(args_data, bytes):
77+
yield args_data
78+
else:
79+
yield args_data
80+
81+
elif "traces" in tables:
82+
# Legacy schema
83+
if class_name:
84+
cursor = cur.execute(
85+
"SELECT args FROM traces WHERE function = ? AND file = ? ORDER BY id ASC LIMIT ?",
86+
(function_name, file_name, num_to_get),
87+
)
88+
else:
89+
cursor = cur.execute(
90+
"SELECT args FROM traces WHERE function = ? AND file = ? ORDER BY id ASC LIMIT ?",
91+
(function_name, file_name, num_to_get),
92+
)
93+
94+
while (val := cursor.fetchone()) is not None:
95+
yield val[0]
96+
97+
finally:
98+
db.close()
99+
100+
101+
def get_function_alias(module: str, function_name: str, class_name: Optional[str] = None) -> str:
102+
"""Generate a unique alias for a function import.
103+
104+
Args:
105+
module: Module path.
106+
function_name: Function name.
107+
class_name: Optional class name.
108+
109+
Returns:
110+
A valid JavaScript identifier for the function.
111+
112+
"""
113+
import re
114+
115+
# Normalize module path to valid identifier
116+
module_alias = re.sub(r"[^a-zA-Z0-9]", "_", module).strip("_")
117+
118+
if class_name:
119+
return f"{module_alias}_{class_name}_{function_name}"
120+
return f"{module_alias}_{function_name}"
121+
122+
123+
def create_javascript_replay_test(
124+
trace_file: str,
125+
functions: list[JavaScriptFunctionModule],
126+
max_run_count: int = 100,
127+
framework: str = "jest",
128+
project_root: Optional[Path] = None,
129+
) -> str:
130+
"""Generate a JavaScript replay test file from traced function calls.
131+
132+
This mirrors the Python version in codeflash/tracing/replay_test.py but
133+
generates JavaScript test code for Jest or Vitest.
134+
135+
Args:
136+
trace_file: Path to the trace SQLite database.
137+
functions: List of functions to generate tests for.
138+
max_run_count: Maximum number of test cases per function.
139+
framework: Test framework ('jest' or 'vitest').
140+
project_root: Project root for calculating relative imports.
141+
142+
Returns:
143+
Generated test file content as a string.
144+
145+
"""
146+
is_vitest = framework.lower() == "vitest"
147+
148+
# Build imports section
149+
imports = []
150+
151+
if is_vitest:
152+
imports.append("import { describe, test } from 'vitest';")
153+
154+
imports.append("const { getNextArg } = require('codeflash/replay');")
155+
imports.append("")
156+
157+
# Build function imports
158+
for func in functions:
159+
if func.function_name in ("__init__", "constructor"):
160+
# Skip constructors
161+
continue
162+
163+
alias = get_function_alias(func.module_name, func.function_name, func.class_name)
164+
165+
if func.class_name:
166+
imports.append(f"const {{ {func.class_name}: {alias}_class }} = require('./{func.module_name}');")
167+
else:
168+
imports.append(f"const {{ {func.function_name}: {alias} }} = require('./{func.module_name}');")
169+
170+
imports.append("")
171+
172+
# Metadata
173+
functions_to_test = [f.function_name for f in functions if f.function_name not in ("__init__", "constructor")]
174+
metadata = f"""const traceFilePath = '{trace_file}';
175+
const functions = {json.dumps(functions_to_test)};
176+
"""
177+
178+
# Build test cases
179+
test_cases = []
180+
181+
for func in functions:
182+
if func.function_name in ("__init__", "constructor"):
183+
continue
184+
185+
alias = get_function_alias(func.module_name, func.function_name, func.class_name)
186+
test_name = f"{func.class_name}.{func.function_name}" if func.class_name else func.function_name
187+
188+
if func.class_name:
189+
# Method test - need to instantiate the class
190+
class_arg = f"'{func.class_name}'"
191+
test_body = textwrap.dedent(f"""
192+
describe('Replay: {test_name}', () => {{
193+
const traces = getNextArg(traceFilePath, '{func.function_name}', '{func.file_name.as_posix()}', {max_run_count}, {class_arg});
194+
195+
test.each(traces.map((args, i) => [i, args]))('call %i', (index, args) => {{
196+
// For instance methods, we need to create an instance
197+
// The traced args may include 'this' context as first argument
198+
const instance = new {alias}_class();
199+
instance.{func.function_name}(...args);
200+
}});
201+
}});
202+
""")
203+
else:
204+
# Regular function test
205+
test_body = textwrap.dedent(f"""
206+
describe('Replay: {test_name}', () => {{
207+
const traces = getNextArg(traceFilePath, '{func.function_name}', '{func.file_name.as_posix()}', {max_run_count});
208+
209+
test.each(traces.map((args, i) => [i, args]))('call %i', (index, args) => {{
210+
{alias}(...args);
211+
}});
212+
}});
213+
""")
214+
215+
test_cases.append(test_body)
216+
217+
# Combine all parts
218+
return "\n".join(
219+
[
220+
"// Auto-generated replay test by Codeflash",
221+
"// Do not edit this file directly",
222+
"",
223+
*imports,
224+
metadata,
225+
*test_cases,
226+
]
227+
)
228+
229+
230+
def get_traced_functions_from_db(trace_file: Path) -> list[JavaScriptFunctionModule]:
231+
"""Get list of functions that were traced from the database.
232+
233+
Args:
234+
trace_file: Path to trace database.
235+
236+
Returns:
237+
List of traced function information.
238+
239+
"""
240+
if not trace_file.exists():
241+
return []
242+
243+
try:
244+
conn = sqlite3.connect(trace_file)
245+
cursor = conn.cursor()
246+
247+
# Check schema
248+
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
249+
tables = {row[0] for row in cursor.fetchall()}
250+
251+
functions = []
252+
253+
if "function_calls" in tables:
254+
cursor.execute(
255+
"SELECT DISTINCT function, filename, classname, line_number FROM function_calls WHERE type = 'call'"
256+
)
257+
for row in cursor.fetchall():
258+
func_name = row[0]
259+
file_name = row[1]
260+
class_name = row[2]
261+
line_number = row[3]
262+
263+
# Calculate module path from filename
264+
module_path = file_name.replace("\\", "/").replace(".js", "").replace(".ts", "")
265+
module_path = module_path.removeprefix("./")
266+
267+
functions.append(
268+
JavaScriptFunctionModule(
269+
function_name=func_name,
270+
file_name=Path(file_name),
271+
module_name=module_path,
272+
class_name=class_name,
273+
line_no=line_number,
274+
)
275+
)
276+
277+
elif "traces" in tables:
278+
# Legacy schema
279+
cursor.execute("SELECT DISTINCT function, file FROM traces")
280+
for row in cursor.fetchall():
281+
func_name = row[0]
282+
file_name = row[1]
283+
284+
module_path = file_name.replace("\\", "/").replace(".js", "").replace(".ts", "")
285+
module_path = module_path.removeprefix("./")
286+
287+
functions.append(
288+
JavaScriptFunctionModule(
289+
function_name=func_name, file_name=Path(file_name), module_name=module_path
290+
)
291+
)
292+
293+
conn.close()
294+
return functions
295+
296+
except Exception:
297+
return []
298+
299+
300+
def create_replay_test_file(
301+
trace_file: Path,
302+
output_path: Path,
303+
framework: str = "jest",
304+
max_run_count: int = 100,
305+
project_root: Optional[Path] = None,
306+
) -> Optional[Path]:
307+
"""Generate a replay test file from a trace database.
308+
309+
This is the main entry point for creating JavaScript replay tests.
310+
311+
Args:
312+
trace_file: Path to the trace SQLite database.
313+
output_path: Path to write the test file.
314+
framework: Test framework ('jest' or 'vitest').
315+
max_run_count: Maximum number of test cases per function.
316+
project_root: Project root for calculating relative imports.
317+
318+
Returns:
319+
Path to generated test file, or None if generation failed.
320+
321+
"""
322+
functions = get_traced_functions_from_db(trace_file)
323+
324+
if not functions:
325+
return None
326+
327+
content = create_javascript_replay_test(
328+
trace_file=str(trace_file),
329+
functions=functions,
330+
max_run_count=max_run_count,
331+
framework=framework,
332+
project_root=project_root,
333+
)
334+
335+
try:
336+
output_path.parent.mkdir(parents=True, exist_ok=True)
337+
output_path.write_text(content)
338+
return output_path
339+
except Exception:
340+
return None

0 commit comments

Comments
 (0)