Skip to content

Commit c884d99

Browse files
add mcp server for behavioural correctness and benchmarking
1 parent cafcd7f commit c884d99

12 files changed

Lines changed: 2140 additions & 65 deletions

File tree

mcp_server/__init__.py

Whitespace-only changes.

mcp_server/db.py

Lines changed: 219 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,219 @@
1+
from __future__ import annotations
2+
3+
import contextlib
4+
import json
5+
import os
6+
import pickle
7+
import sqlite3
8+
from datetime import datetime, timezone
9+
from pathlib import Path
10+
from typing import Any
11+
12+
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults
13+
from codeflash.models.test_type import TestType
14+
15+
_DEFAULT_DB_PATH = Path.home() / ".codeflash" / "mcp_results.db"


def get_db_path() -> Path:
    """Return the SQLite database path, creating its parent directory.

    The location is taken from the ``CODEFLASH_MCP_DB_PATH`` environment
    variable when it is set to a non-empty value; otherwise
    ``_DEFAULT_DB_PATH`` is used. A leading ``~`` in the override is
    expanded to the user's home directory.

    Returns:
        The database file path. Only the parent directory is created
        here; the file itself is not touched.
    """
    override = os.environ.get("CODEFLASH_MCP_DB_PATH")
    # An empty override would otherwise turn into Path("."), i.e. the
    # current directory, which is not a usable database file.
    path = Path(override).expanduser() if override else _DEFAULT_DB_PATH
    path.parent.mkdir(parents=True, exist_ok=True)
    return path
22+
23+
24+
def get_connection() -> sqlite3.Connection:
    """Open the results database and make sure its schema exists.

    The connection is opened on the path returned by ``get_db_path()``,
    switched to WAL journaling, has foreign-key enforcement turned on, and
    is passed to ``_create_tables`` before being returned.

    Returns:
        An open ``sqlite3.Connection``. The caller is responsible for
        closing it.

    Raises:
        sqlite3.Error: If a PRAGMA or the schema creation fails. The
            connection is closed before the error propagates.
    """
    conn = sqlite3.connect(str(get_db_path()))
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
        _create_tables(conn)
    except BaseException:
        # Do not leak the handle when initialisation fails part-way.
        conn.close()
        raise
    return conn
31+
32+
33+
def _create_tables(conn: sqlite3.Connection) -> None:
34+
conn.executescript("""
35+
CREATE TABLE IF NOT EXISTS runs (
36+
run_id TEXT PRIMARY KEY,
37+
run_type TEXT NOT NULL,
38+
created_at TEXT NOT NULL,
39+
project_root TEXT NOT NULL,
40+
test_files TEXT NOT NULL,
41+
total_runtime_ns INTEGER,
42+
total_tests INTEGER,
43+
passed INTEGER,
44+
failed INTEGER,
45+
loops_executed INTEGER,
46+
raw_stdout TEXT,
47+
raw_stderr TEXT
48+
);
49+
50+
CREATE TABLE IF NOT EXISTS test_invocations (
51+
id INTEGER PRIMARY KEY AUTOINCREMENT,
52+
run_id TEXT NOT NULL REFERENCES runs(run_id),
53+
test_module_path TEXT,
54+
test_class_name TEXT,
55+
test_function_name TEXT,
56+
function_getting_tested TEXT,
57+
loop_index INTEGER,
58+
iteration_id TEXT,
59+
runtime_ns INTEGER,
60+
return_value BLOB,
61+
verification_type TEXT,
62+
did_pass INTEGER NOT NULL,
63+
timed_out INTEGER DEFAULT 0,
64+
error_message TEXT,
65+
stdout TEXT,
66+
test_type TEXT
67+
);
68+
""")
69+
conn.commit()
70+
71+
72+
def store_run(
    conn: sqlite3.Connection,
    run_id: str,
    run_type: str,
    project_root: str,
    test_files: list[str],
    test_results: TestResults,
    raw_stdout: str = "",
    raw_stderr: str = "",
) -> None:
    """Persist one run and all of its test invocations atomically.

    One row is written to ``runs`` (with summary counts derived from
    ``test_results``) and one row per invocation to ``test_invocations``.
    All inserts happen inside a single transaction: on success it is
    committed, on any error it is rolled back so no partial run is left
    behind.

    Args:
        conn: Open connection to the results database.
        run_id: Primary key of the new ``runs`` row.
        run_type: Stored verbatim in ``runs.run_type``.
        project_root: Stored verbatim in ``runs.project_root``.
        test_files: Stored as a JSON array in ``runs.test_files``.
        test_results: Iterable, sized collection of invocations that also
            provides ``total_passed_runtime()`` and
            ``effective_loop_count()``.
        raw_stdout: Stored verbatim in ``runs.raw_stdout``.
        raw_stderr: Stored verbatim in ``runs.raw_stderr``.

    Raises:
        sqlite3.Error: If any insert fails (for example a duplicate
            ``run_id``). The transaction open on ``conn`` is rolled back
            before the error propagates.
    """
    total_runtime_ns = test_results.total_passed_runtime() if test_results else 0
    total_tests = len(test_results)
    passed = sum(1 for r in test_results if r.did_pass)
    failed = total_tests - passed
    loops_executed = test_results.effective_loop_count() if test_results else 0

    # Build every invocation row up front so the database work below is a
    # short, all-or-nothing transaction.
    invocation_rows = []
    for invocation in test_results:
        return_value_blob = None
        if invocation.return_value is not None:
            # Best effort: a return value that cannot be pickled is stored
            # as NULL rather than failing the whole run.
            with contextlib.suppress(Exception):
                return_value_blob = pickle.dumps(invocation.return_value)

        invocation_rows.append(
            (
                run_id,
                invocation.id.test_module_path,
                invocation.id.test_class_name,
                invocation.id.test_function_name,
                invocation.id.function_getting_tested,
                invocation.loop_index,
                invocation.id.iteration_id,
                invocation.runtime,
                return_value_blob,
                invocation.verification_type,
                int(invocation.did_pass),
                int(invocation.timed_out or False),
                None,  # error_message: nothing is recorded for it here
                invocation.stdout,
                invocation.test_type.value if invocation.test_type else None,
            )
        )

    # ``with conn`` commits on success and rolls back on any exception.
    with conn:
        conn.execute(
            "INSERT INTO runs (run_id, run_type, created_at, project_root, test_files, "
            "total_runtime_ns, total_tests, passed, failed, loops_executed, raw_stdout, raw_stderr) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                run_id,
                run_type,
                datetime.now(timezone.utc).isoformat(),
                project_root,
                json.dumps(test_files),
                total_runtime_ns,
                total_tests,
                passed,
                failed,
                loops_executed,
                raw_stdout,
                raw_stderr,
            ),
        )
        conn.executemany(
            "INSERT INTO test_invocations (run_id, test_module_path, test_class_name, "
            "test_function_name, function_getting_tested, loop_index, iteration_id, "
            "runtime_ns, return_value, verification_type, did_pass, timed_out, error_message, stdout, test_type) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            invocation_rows,
        )
138+
139+
140+
def load_test_results(conn: sqlite3.Connection, run_id: str) -> TestResults:
    """Rebuild a ``TestResults`` collection from the rows stored for ``run_id``.

    Each ``test_invocations`` row becomes one ``FunctionTestInvocation``.
    Values that are not stored are filled with placeholders: ``file_name``
    is always ``Path("unknown")`` and ``test_framework`` is always
    ``"pytest"``. A missing or unparsable ``test_type`` falls back to
    ``TestType.EXISTING_UNIT_TEST``; a return value that cannot be
    unpickled is loaded as ``None``.

    Args:
        conn: Open connection to the results database.
        run_id: Identifier of the run whose invocations are loaded.

    Returns:
        A ``TestResults`` with one entry added per stored row; it has no
        entries added when the run has no rows.
    """
    stored_rows = conn.execute(
        "SELECT test_module_path, test_class_name, test_function_name, function_getting_tested, "
        "loop_index, iteration_id, runtime_ns, return_value, verification_type, did_pass, "
        "timed_out, error_message, stdout, test_type FROM test_invocations WHERE run_id = ?",
        (run_id,),
    ).fetchall()

    loaded = TestResults()
    for (
        module_path,
        class_name,
        function_name,
        target_function,
        loop_idx,
        iter_id,
        runtime,
        pickled_value,
        verification,
        passed_flag,
        timeout_flag,
        _error_message,  # selected but not used when rebuilding
        captured_stdout,
        raw_test_type,
    ) in stored_rows:
        # NOTE(review): pickle.loads executes arbitrary code from the blob,
        # so the database file must be trusted.
        restored_value = None
        if pickled_value is not None:
            try:
                restored_value = pickle.loads(pickled_value)
            except Exception:
                restored_value = None

        # Fall back to the default kind when nothing usable was stored.
        kind = TestType.EXISTING_UNIT_TEST
        if raw_test_type:
            with contextlib.suppress(ValueError, TypeError):
                kind = TestType(int(raw_test_type))

        invocation_id = InvocationId(
            test_module_path=module_path or "",
            test_class_name=class_name,
            test_function_name=function_name or "",
            function_getting_tested=target_function or "",
            iteration_id=iter_id,
        )
        loaded.add(
            FunctionTestInvocation(
                loop_index=loop_idx or 1,
                id=invocation_id,
                file_name=Path("unknown"),
                did_pass=bool(passed_flag),
                runtime=runtime,
                test_framework="pytest",
                test_type=kind,
                return_value=restored_value,
                timed_out=bool(timeout_flag),
                verification_type=verification,
                stdout=captured_stdout,
            )
        )
    return loaded
199+
200+
201+
def load_run_metadata(conn: sqlite3.Connection, run_id: str) -> dict[str, Any] | None:
    """Fetch the summary columns of one run as a dictionary.

    Args:
        conn: Open connection to the results database.
        run_id: Identifier of the run to look up.

    Returns:
        ``None`` when no ``runs`` row has this id. Otherwise a dict keyed
        by column name, with ``test_files`` decoded from its JSON text.
        ``raw_stdout`` and ``raw_stderr`` are not included.
    """
    record = conn.execute(
        "SELECT run_type, created_at, project_root, test_files, total_runtime_ns, "
        "total_tests, passed, failed, loops_executed FROM runs WHERE run_id = ?",
        (run_id,),
    ).fetchone()
    if record is None:
        return None

    # Same order as the SELECT list above.
    keys = (
        "run_type",
        "created_at",
        "project_root",
        "test_files",
        "total_runtime_ns",
        "total_tests",
        "passed",
        "failed",
        "loops_executed",
    )
    metadata: dict[str, Any] = dict(zip(keys, record))
    metadata["test_files"] = json.loads(metadata["test_files"])
    return metadata

mcp_server/models.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
5+
6+
@dataclass
class TestInvocationResult:
    """Outcome of a single test invocation.

    The field names mirror a subset of the columns of the
    ``test_invocations`` table created in ``mcp_server/db.py``.
    """

    # The class name starts with "Test", so pytest would try to collect it
    # (and warn) in any test module that imports it. This un-annotated class
    # attribute opts out of collection and is not a dataclass field.
    __test__ = False

    test_module_path: str
    test_class_name: str | None
    test_function_name: str
    function_getting_tested: str | None
    loop_index: int
    iteration_id: str | None
    runtime_ns: int | None  # presumably nanoseconds, per the name
    did_pass: bool
    timed_out: bool = False
    error_message: str | None = None
18+
19+
20+
@dataclass
class BehavioralRunResult:
    """Summary of one behavioral test run together with its invocations."""

    run_id: str
    # Summary counters for the run.
    total_tests: int
    passed: int
    failed: int
    total_runtime_ns: int  # presumably nanoseconds, per the name
    # Per-invocation details.
    test_results: list[TestInvocationResult]
    # Every instance gets its own fresh list.
    errors: list[str] = field(default_factory=list)
29+
30+
31+
@dataclass
class CompareResult:
    """Result of comparing two runs, with the differences that were found."""

    equivalent: bool
    total_compared: int
    # One entry per recorded difference; every instance gets its own list.
    diffs: list[DiffEntry] = field(default_factory=list)
36+
37+
38+
@dataclass
class DiffEntry:
    """One recorded difference between an original and a candidate result."""

    scope: str
    test_name: str
    # Values are carried as strings.
    original_value: str
    candidate_value: str
    # Pass/fail status on each side.
    original_passed: bool
    candidate_passed: bool
46+
47+
48+
@dataclass
class SpeedupInfo:
    """Runtime comparison of a candidate run against a baseline run."""

    baseline_run_id: str
    # Runtimes being compared; presumably nanoseconds, per the names.
    baseline_runtime_ns: int
    candidate_runtime_ns: int
    performance_gain: float
    # String renderings of the speedup; exact format is not defined here.
    speedup_x: str
    speedup_pct: str
56+
57+
58+
@dataclass
class BenchmarkRunResult:
    """Summary of one benchmark run, optionally with a speedup comparison."""

    run_id: str
    total_runtime_ns: int  # presumably nanoseconds, per the name
    loops_executed: int
    # Per-invocation details.
    test_results: list[TestInvocationResult]
    # Left as None when no comparison is attached.
    speedup: SpeedupInfo | None = field(default=None)

0 commit comments

Comments
 (0)