Skip to content

Commit 1b08d46

Browse files
re-use same db connection and app lifespan
1 parent f388ca6 commit 1b08d46

6 files changed

Lines changed: 54 additions & 35 deletions

File tree

mcp_server/db.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
_DEFAULT_DB_PATH = Path.home() / ".codeflash" / "mcp_results.db"
1616

17+
db_conn = None
1718

1819
def get_db_path() -> Path:
1920
path = Path(os.environ.get("CODEFLASH_MCP_DB_PATH", str(_DEFAULT_DB_PATH)))
@@ -22,14 +23,26 @@ def get_db_path() -> Path:
2223

2324

2425
def get_connection() -> sqlite3.Connection:
26+
global db_conn
27+
if db_conn is not None:
28+
return db_conn
2529
db_path = get_db_path()
30+
print(db_path)
2631
conn = sqlite3.connect(str(db_path))
2732
conn.execute("PRAGMA journal_mode=WAL")
2833
conn.execute("PRAGMA foreign_keys=ON")
2934
_create_tables(conn)
35+
db_conn = conn
3036
return conn
3137

3238

39+
def close_conn() -> None:
40+
global db_conn
41+
if db_conn is not None:
42+
db_conn.close()
43+
db_conn = None
44+
45+
3346
def _create_tables(conn: sqlite3.Connection) -> None:
3447
conn.executescript("""
3548
CREATE TABLE IF NOT EXISTS runs (

mcp_server/server.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,24 @@
33
from typing import Any
44

55
from fastmcp import FastMCP
6+
from fastmcp.server.mixins import lifespan
7+
8+
from mcp_server.db import close_conn
9+
10+
11+
@lifespan
12+
async def app_lifespan(server) -> None:
13+
print("codeflash-mcp is up")
14+
yield
15+
print("shutting dowm the server")
16+
close_conn()
17+
18+
619

720
mcp = FastMCP(
821
"codeflash-mcp",
922
instructions="Run behavioral tests, compare results, and benchmark performance for code optimization.",
23+
lifespan=app_lifespan,
1024
)
1125

1226

mcp_server/test_mcp_workflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def get_e2e_test_config_by_language(language: str) -> dict[str, str]:
2121
"module_path": str(CODE_DIR / "bubble_sort.py"),
2222
"function_name": "sorter",
2323
"root": str(CODE_DIR),
24-
"optimized_code": "def sorter(arr):\n return sorted(arr)",
24+
"optimized_code": "def sorter(arr):\n result = sorted(arr)\n print('codeflash stdout: Sorting list')\n print(f'result: {result}')\n return result",
2525
}
2626
if language == "javascript":
2727
return {
@@ -203,6 +203,7 @@ def main() -> int:
203203
print("=" * 60)
204204
print("\n")
205205
langs = ["python", "javascript"]
206+
langs = ["python"]
206207
for lang in langs:
207208
exit_code = doTest(get_e2e_test_config_by_language(lang))
208209
if exit_code != 0:

mcp_server/tools/behavioral.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,16 @@ def run_behavioral_tests(
3737
)
3838

3939
conn = get_connection()
40-
try:
41-
store_run(
42-
conn=conn,
43-
run_id=run_id,
44-
run_type="behavioral",
45-
project_root=project_root,
46-
test_files=test_files,
47-
test_results=test_results,
48-
raw_stdout=run_result.stdout or "",
49-
raw_stderr=run_result.stderr or "",
50-
)
51-
finally:
52-
conn.close()
40+
store_run(
41+
conn=conn,
42+
run_id=run_id,
43+
run_type="behavioral",
44+
project_root=project_root,
45+
test_files=test_files,
46+
test_results=test_results,
47+
raw_stdout=run_result.stdout or "",
48+
raw_stderr=run_result.stderr or "",
49+
)
5350

5451
invocation_results = []
5552
errors = []

mcp_server/tools/benchmarking.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,20 @@ def run_benchmarking_tests(
4646
)
4747

4848
conn = get_connection()
49-
try:
50-
store_run(
51-
conn=conn,
52-
run_id=run_id,
53-
run_type="benchmarking",
54-
project_root=project_root,
55-
test_files=test_files,
56-
test_results=test_results,
57-
raw_stdout=run_result.stdout or "",
58-
raw_stderr=run_result.stderr or "",
59-
)
49+
store_run(
50+
conn=conn,
51+
run_id=run_id,
52+
run_type="benchmarking",
53+
project_root=project_root,
54+
test_files=test_files,
55+
test_results=test_results,
56+
raw_stdout=run_result.stdout or "",
57+
raw_stderr=run_result.stderr or "",
58+
)
6059

61-
speedup_info = None
62-
if baseline_run_id:
63-
speedup_info = _compute_speedup(conn, baseline_run_id, test_results)
64-
finally:
65-
conn.close()
60+
speedup_info = None
61+
if baseline_run_id:
62+
speedup_info = _compute_speedup(conn, baseline_run_id, test_results)
6663

6764
best_summed_runtime_ns = test_results.total_passed_runtime() if test_results else 0
6865
loops_executed = test_results.effective_loop_count() if test_results else 0

mcp_server/tools/compare.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,8 @@
77

88
def compare_test_results(original_run_id: str, candidate_run_id: str, pass_fail_only: bool = False) -> dict[str, Any]:
99
conn = get_connection()
10-
try:
11-
original_results = load_test_results(conn, original_run_id)
12-
candidate_results = load_test_results(conn, candidate_run_id)
13-
finally:
14-
conn.close()
10+
original_results = load_test_results(conn, original_run_id)
11+
candidate_results = load_test_results(conn, candidate_run_id)
1512

1613
if not original_results:
1714
return {"error": f"No results found for run_id: {original_run_id}"}

0 commit comments

Comments
 (0)