-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmcp_sql_server.py
More file actions
1020 lines (843 loc) · 37.1 KB
/
mcp_sql_server.py
File metadata and controls
1020 lines (843 loc) · 37.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
"""
MCP Server for SQL Safety Checker
Provides safe, read-only SQL query execution through MCP protocol.
Following FastMCP best practices for tool design and context usage.
Supported Databases:
- MySQL (default): Full INFORMATION_SCHEMA support
- SQLite: Uses sqlite_master and PRAGMA for metadata
Configuration:
- Set DB_TYPE environment variable to 'mysql' or 'sqlite'
- MySQL: Configure DB_USER, DB_PASSWORD, DB_HOST, DB_NAME
- SQLite: Configure SQLITE_DATABASE_PATH
Backward Compatibility:
- Default DB_TYPE=mysql maintains existing behavior
- All existing environment variables continue to work
"""
import os
import re
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator
from mcp.types import ToolAnnotations
from fastmcp import FastMCP, Context
from sql_safety_checker import is_sql_safe, execute_sql
from db_adapter import get_adapter, DB_TYPE
logger = logging.getLogger(__name__)
# =============================================================================
# Configuration
# =============================================================================
# Registers the optional sample() tool when "1" (default: enabled).
SCHEMA_TOOLS_ENABLED = os.getenv("ENABLE_SCHEMA_TOOLS", "1") == "1"
# Enable get_table_summary tool (with exact COUNT(*) option)
# Default: disabled - describe_table already provides row_count (estimated)
# Enable when exact counts are needed for specific workflows
TABLE_SUMMARY_ENABLED = os.getenv("ENABLE_TABLE_SUMMARY", "0") == "1"
# Large table threshold for is_large flag and query recommendations
# Reference: MySQL InnoDB full table scan cost considerations
LARGE_TABLE_THRESHOLD = int(os.getenv("LARGE_TABLE_THRESHOLD", "1000"))
# Token optimization: Limit result size to prevent context overflow
# Reference: Google Gemini best practices - "Token limits: function descriptions and parameters count toward input token limits"
# Reference: MCP Security best practices - "Sanitize tool outputs"
# Set to 0 to disable truncation (for data export scenarios)
MAX_RESULT_ROWS = int(os.getenv("MAX_RESULT_ROWS", "100"))  # Max rows (0=unlimited)
MAX_RESULT_CHARS = int(os.getenv("MAX_RESULT_CHARS", "16000"))  # Max chars (0=unlimited)
MAX_SCHEMA_TABLES = int(os.getenv("MAX_SCHEMA_TABLES", "50"))  # Max tables in get_full_schema
MAX_OVERVIEW_TABLES = int(os.getenv("MAX_OVERVIEW_TABLES", "100"))  # Max tables in list_tables
# UNION Query Policy (P2 Security: Configurable UNION handling)
# Reference: OWASP Defense-in-Depth - block UNION by default for safety
# Reference: OpenAI "minimize tool calls" - allow UNION for efficiency when needed
# When disabled (default): LLM uses multiple queries (safer, more calls)
# When enabled: UNION allowed but requires table allowlist for validation
ALLOW_UNION = os.getenv("ALLOW_UNION", "0") == "1"
# =============================================================================
# Table Allowlist Configuration (P1 Security: Restrict table access)
# Reference: Microsoft "Least Privilege Principle" - only allow access to necessary tables
# Reference: Anthropic MCP Security - "Implement proper access controls"
# =============================================================================
def _parse_table_allowlist() -> set[str] | None:
"""
Parse ALLOWED_TABLES environment variable into a set.
Format: Comma-separated table names (case-insensitive)
Special value: "*" means allow all tables (explicit opt-in for UNION)
Example: ALLOWED_TABLES=products,orders,customers
Returns:
Set of allowed table names (lowercase), None if not configured,
or {"*"} if explicitly set to allow all
"""
allowed_tables_env = os.getenv("ALLOWED_TABLES", "").strip()
if not allowed_tables_env:
return None # No allowlist configured - allow all tables
# Special case: "*" means explicitly allow all tables
if allowed_tables_env == "*":
logger.info("Table allowlist set to '*' - all tables allowed (explicit)")
return {"*"} # Special marker for "allow all"
# Parse comma-separated list, normalize to lowercase
tables = {t.strip().lower() for t in allowed_tables_env.split(",") if t.strip()}
if tables:
logger.info(f"Table allowlist enabled: {len(tables)} tables allowed")
return tables if tables else None
# Load allowlist once at import time; all request-time checks read this value.
ALLOWED_TABLES: set[str] | None = _parse_table_allowlist()
# Log the effective UNION/allowlist policy at module load so operators can
# audit the security configuration from startup logs.
if ALLOW_UNION:
    if ALLOWED_TABLES:
        logger.info(f"UNION queries enabled with table allowlist: {sorted(ALLOWED_TABLES)}")
    else:
        # ALLOW_UNION without an allowlist is an inconsistent configuration;
        # UNION stays blocked (enforced in _is_query_safe_extended).
        logger.warning("ALLOW_UNION=1 but no ALLOWED_TABLES configured - UNION will be blocked")
else:
    logger.info("UNION queries disabled (default safe mode)")
def _is_table_allowed(table_name: str) -> bool:
    """
    Decide whether the configured allowlist permits access to a table.

    Args:
        table_name: Table to check (compared case-insensitively).

    Returns:
        True when no allowlist is configured, when the allowlist is the
        explicit "*" allow-all marker, or when the table is listed.
    """
    allowlist = ALLOWED_TABLES
    if allowlist is None or "*" in allowlist:
        # No restriction configured, or explicit allow-all via ALLOWED_TABLES=*.
        return True
    return table_name.lower() in allowlist
def _extract_tables_from_sql(sql: str) -> list[str]:
    """
    Pull table names out of a SQL statement.

    Recognized positions:
    - FROM / JOIN clauses (optionally backtick-quoted and/or schema-qualified)
    - DESCRIBE / DESC / EXPLAIN targets

    Args:
        sql: SQL query string.

    Returns:
        De-duplicated list of table names (order is unspecified).
    """
    # An optional `schema`. prefix is consumed but not captured, so only the
    # trailing table identifier is reported. CJK characters are accepted to
    # mirror the identifier rules used by _is_valid_identifier.
    schema_prefix = r'(?:`?[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*`?\s*\.\s*)?'
    identifier = r'`?([a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*)`?'
    found: set[str] = set()
    for keyword in (r'(?:FROM|JOIN)', r'(?:DESCRIBE|DESC|EXPLAIN)'):
        pattern = keyword + r'\s+' + schema_prefix + identifier
        found.update(re.findall(pattern, sql, re.IGNORECASE))
    return list(found)
def _check_table_allowlist(sql: str) -> tuple[bool, str | None]:
    """
    Verify that every table referenced by the query is permitted.

    Args:
        sql: SQL query to check.

    Returns:
        (True, None) when access is allowed; otherwise (False, message)
        naming the blocked tables and the configured allowlist.
    """
    if ALLOWED_TABLES is None:
        # No allowlist configured - nothing to enforce.
        return True, None
    denied = [name for name in _extract_tables_from_sql(sql) if not _is_table_allowed(name)]
    if not denied:
        return True, None
    return False, f"Access denied to table(s): {', '.join(denied)}. Only allowed tables: {', '.join(sorted(ALLOWED_TABLES))}"
# =============================================================================
# Lifespan Management (Best Practice: manage resources properly)
# =============================================================================
@asynccontextmanager
async def lifespan(mcp_server: FastMCP) -> AsyncIterator[dict[str, Any]]:
    """
    Manage server startup and shutdown.

    Logs on entry/exit and yields the shared lifespan state dict that
    FastMCP makes available to tools for the server's lifetime.
    """
    logger.info("SQL Safety Checker MCP Server starting...")
    # Future: Initialize database connection pool here
    yield {"initialized": True}
    logger.info("SQL Safety Checker MCP Server shutting down...")
    # Future: Cleanup database connections here
# Create the MCP server instance with the lifespan manager above.
# Reference: Google/Anthropic best practices - server instructions should describe capabilities,
# not prescribe workflow (let LLM decide based on task context).
# NOTE: the instructions string is sent to the client LLM verbatim.
mcp = FastMCP(
    name="sql-safety-executor",
    instructions="""Database query assistant with READ-ONLY access.
Use query() for all data requests. Use describe_table() or get_full_schema() first if structure unknown.
For single-table queries, if schema/columns unknown, call describe_table(table_name) before query().""",
    lifespan=lifespan,
)
# =============================================================================
# Helper Functions (Internal)
# =============================================================================
def _serialize_result(data: Any) -> Any:
    """
    Recursively convert SQLAlchemy Row objects to JSON-serializable values.

    Lists are converted element-wise; objects exposing ``_mapping``
    (SQLAlchemy rows) become dicts; generic objects with ``__dict__``
    are reduced to their public attributes; anything else passes through
    unchanged.
    """
    if data is None:
        return None
    if isinstance(data, list):
        return [_serialize_result(element) for element in data]
    if hasattr(data, '_mapping'):
        # SQLAlchemy Row: its mapping view converts cleanly to a dict.
        return dict(data._mapping)
    if hasattr(data, '__dict__'):
        # Generic object: keep only public attributes.
        return {key: value for key, value in vars(data).items() if not key.startswith('_')}
    return data
def _is_valid_identifier(name: str) -> bool:
    """
    Validate a table/column identifier to prevent SQL injection.

    Rules:
    - Non-empty and at most 64 characters (MySQL identifier limit)
    - Letters, digits, underscore, and CJK characters; must not start
      with a digit
    - Defense-in-depth denylist: quotes, backticks, semicolons,
      backslashes, SQL comment markers, and null bytes are rejected
    """
    if not name or len(name) > 64:
        return False
    if re.match(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$', name) is None:
        return False
    # Belt-and-braces: the pattern above should already exclude all of
    # these, but the explicit denylist stays as a second layer.
    for bad in (r'[\'\"`;\\]', r'--', r'/\*', r'\*/', r'\x00'):
        if re.search(bad, name):
            return False
    return True
# Allowlist for SHOW commands (restrict information disclosure)
# NOTE(review): this set is not referenced anywhere in this module's visible
# code - confirm whether it is consumed elsewhere or is dead configuration.
ALLOWED_SHOW_COMMANDS = {
    'SHOW TABLES',
    'SHOW COLUMNS',
    'SHOW INDEX',
    'SHOW CREATE TABLE',
    'SHOW TABLE STATUS',
    'SHOW DATABASES',
}
# Blocked SHOW commands that leak sensitive info.
# Matched with re.match() in _is_query_safe_extended, i.e. anchored at the
# start of the (uppercased, stripped) query.
BLOCKED_SHOW_PATTERNS = [
    r'SHOW\s+VARIABLES',
    r'SHOW\s+GRANTS',
    r'SHOW\s+PROCESSLIST',
    r'SHOW\s+MASTER',
    r'SHOW\s+SLAVE',
    r'SHOW\s+BINARY',
    r'SHOW\s+ENGINE',
    r'SHOW\s+PLUGINS',
    r'SHOW\s+PRIVILEGES',
    r'SHOW\s+STATUS',  # Can leak sensitive metrics
]
def _is_query_safe_extended(sql: str) -> tuple[bool, str | None]:
    """
    Apply extended safety checks beyond basic statement-type validation.

    Checks, in order: blocked SHOW variants, system-schema references,
    the UNION policy (governed by ALLOW_UNION / ALLOWED_TABLES), and
    derived tables (subqueries in FROM).

    Returns:
        (is_safe, error_message) - error_message is None when safe.
    """
    normalized = sql.upper().strip()
    # Reject SHOW variants that can leak server configuration or state.
    for blocked in BLOCKED_SHOW_PATTERNS:
        if re.match(blocked, normalized, re.IGNORECASE):
            return False, f"SHOW command not allowed for security: {blocked}"
    # Deny direct references to MySQL system schemas.
    if re.search(r'\b(mysql|performance_schema)\s*\.', sql, re.IGNORECASE):
        return False, "Access to system databases not allowed"
    # UNION policy: blocked by default (common injection vector); when
    # enabled it additionally requires a table allowlist so that
    # _check_table_allowlist() can validate every referenced table.
    if re.search(r'\bUNION\b', sql, re.IGNORECASE):
        if not ALLOW_UNION:
            return False, (
                "UNION queries disabled for security. "
                "Execute separate queries for each table and combine results in your response."
            )
        if ALLOWED_TABLES is None:
            return False, (
                "UNION requires ALLOWED_TABLES. "
                "Set ALLOWED_TABLES=table1,table2 or ALLOWED_TABLES=* to enable."
            )
        logger.info("UNION query allowed - validating tables")
    # Derived tables could smuggle data through nested SELECTs; subqueries
    # in WHERE remain allowed for legitimate filtering.
    if re.search(r'FROM\s*\(', sql, re.IGNORECASE):
        return False, "Subqueries in FROM clause not allowed"
    return True, None
def _truncate_result(data: list, total_rows: int) -> dict[str, Any]:
"""
Truncate query results to prevent token explosion.
Set MAX_RESULT_ROWS=0 or MAX_RESULT_CHARS=0 to disable respective limits.
Returns:
Dict with truncated data and metadata
"""
import json
truncated = False
truncation_reason = None
returned_rows = len(data)
# Step 1: Limit by row count (0 = disabled)
if MAX_RESULT_ROWS > 0 and len(data) > MAX_RESULT_ROWS:
data = data[:MAX_RESULT_ROWS]
truncated = True
truncation_reason = f"row_limit ({MAX_RESULT_ROWS})"
returned_rows = MAX_RESULT_ROWS
# Step 2: Limit by character count (0 = disabled)
if MAX_RESULT_CHARS > 0:
try:
json_str = json.dumps(data, ensure_ascii=False, default=str)
if len(json_str) > MAX_RESULT_CHARS:
# Binary search for optimal row count within char limit
low, high = 1, len(data)
while low < high:
mid = (low + high + 1) // 2
test_str = json.dumps(data[:mid], ensure_ascii=False, default=str)
if len(test_str) <= MAX_RESULT_CHARS:
low = mid
else:
high = mid - 1
data = data[:low]
truncated = True
truncation_reason = f"char_limit ({MAX_RESULT_CHARS})"
returned_rows = low
except (TypeError, ValueError):
pass # If JSON encoding fails, skip char limit check
return {
"data": data,
"returned_rows": returned_rows,
"total_rows": total_rows,
"truncated": truncated,
"truncation_reason": truncation_reason,
}
# =============================================================================
# MCP Tools (Following FastMCP Best Practices)
# =============================================================================
@mcp.tool(
    annotations=ToolAnnotations(
        title="Execute SQL Query",
        readOnlyHint=True,
        destructiveHint=False,
        idempotentHint=True,
        openWorldHint=False,
    )
)
async def query(sql: str, ctx: Context) -> dict[str, Any]:
    """
    Execute a SQL SELECT query on the database.

    This is the PRIMARY tool for all database queries.
    Safety validation is automatic - only read-only statements are allowed.
    Supported: SELECT, SHOW, DESCRIBE, EXPLAIN.

    Args:
        sql: A SQL SELECT query to execute

    Returns:
        Query results with data rows, or error message if query fails

    Examples:
        query("SELECT * FROM users LIMIT 10")
        query("SELECT name, email FROM users WHERE active = 1")
        query("SELECT COUNT(*) as total FROM orders")
        query("SHOW TABLES")
        query("DESCRIBE users")
        query("EXPLAIN SELECT * FROM products WHERE id = 1")
    """
    await ctx.info(f"Executing query: {sql}")

    def _reject(message: str) -> dict[str, Any]:
        # Uniform error envelope for rejected and failed queries.
        return {"success": False, "error": message, "query": sql}

    # Layer 1: statement-type check (read-only statements only).
    if not is_sql_safe(sql):
        await ctx.warning(f"Rejected unsafe query: {sql}")
        return _reject("Only read-only queries allowed (SELECT, SHOW, DESCRIBE, EXPLAIN)")
    # Layer 2: extended pattern checks (SHOW variants, system schemas,
    # UNION policy, derived tables).
    is_safe, error_msg = _is_query_safe_extended(sql)
    if not is_safe:
        await ctx.warning(f"Rejected query (extended check): {error_msg}")
        return _reject(error_msg)
    # Layer 3: table allowlist (least-privilege enforcement).
    is_allowed, allowlist_error = _check_table_allowlist(sql)
    if not is_allowed:
        await ctx.warning(f"Table access denied: {allowlist_error}")
        return _reject(allowlist_error)
    # All checks passed - run the query.
    result = execute_sql(sql)
    if isinstance(result, str) and result.startswith("Error:"):
        await ctx.error(f"Query failed: {result}")
        return _reject(result)
    # Success path: serialize rows, then truncate to keep the response
    # within token budgets.
    data = _serialize_result(result)
    total_rows = len(result) if isinstance(result, list) else 0
    truncation = _truncate_result(data, total_rows)
    if truncation["truncated"]:
        await ctx.warning(
            f"Truncated: {truncation['returned_rows']}/{total_rows} rows. "
            f"Add LIMIT to your query for precise control."
        )
    else:
        await ctx.info(f"Query returned {total_rows} rows")
    note = None
    if truncation["truncated"]:
        note = (
            f"Showing {truncation['returned_rows']}/{total_rows} rows. "
            f"Use LIMIT clause for full control."
        )
    return {
        "success": True,
        "data": truncation["data"],
        "row_count": truncation["returned_rows"],
        "total_rows": total_rows,
        "truncated": truncation["truncated"],
        "truncation_note": note,
        "query": sql,
    }
@mcp.tool(
    annotations=ToolAnnotations(
        title="Check Database Connection",
        readOnlyHint=True,
        destructiveHint=False,
        idempotentHint=True,
    )
)
async def check_connection(ctx: Context) -> dict[str, Any]:
    """
    Check if the database connection is working.

    Use this to verify database connectivity before running queries.

    Returns:
        Connection status, database type, and configuration check
    """
    await ctx.info("Checking database connection...")
    adapter = get_adapter()
    success, message = adapter.check_connection()
    if not success:
        await ctx.error(f"Connection failed: {message}")
        # Report which config vars are present without leaking their values.
        if DB_TYPE == "sqlite":
            db_type = "sqlite"
            config = {
                "SQLITE_DATABASE_PATH": "set" if os.getenv("SQLITE_DATABASE_PATH") else "missing (using :memory:)",
            }
        else:  # mysql
            db_type = "mysql"
            config = {
                var: "set" if os.getenv(var) else "missing"
                for var in ("DB_USER", "DB_PASSWORD", "DB_HOST", "DB_NAME")
            }
        return {
            "connected": False,
            "error": message,
            "db_type": db_type,
            "config": config,
        }
    await ctx.info(f"Database connection successful ({DB_TYPE})")
    return {
        "connected": True,
        "message": message,
        "db_type": DB_TYPE,
        "database_name": adapter.get_database_name(),
    }
@mcp.tool(
    annotations=ToolAnnotations(
        title="List Database Tables",
        readOnlyHint=True,
        destructiveHint=False,
        idempotentHint=True,
    )
)
async def list_tables(ctx: Context) -> dict[str, Any]:
    """
    Database overview: list all tables with names and approximate row counts.

    Lightweight initial discovery tool. Row counts are estimates:
    - MySQL: from INFORMATION_SCHEMA (InnoDB may vary ±40%)
    - SQLite: from sqlite_stat1 or sampling
    For column details, use describe_table(name).

    Returns:
        Database name, table count, and list of tables with approximate row counts
    """
    await ctx.info("Listing database tables")
    adapter = get_adapter()
    database_name = adapter.get_database_name()
    # Adapter call keeps this portable across MySQL and SQLite.
    tables = adapter.get_tables()
    if not tables:
        await ctx.info(f"No tables found in {database_name}")
        return {
            "success": True,
            "database_name": database_name,
            "db_type": DB_TYPE,
            "returned_table_count": 0,
            "total_tables": 0,
            "tables": [],
            "row_count_approximate": True,
            "truncated": False,
            "truncation_note": None,
            "hint": "No tables found in database.",
        }
    # Allowlist filter (P1 Security); skipped for the explicit "*" marker.
    if ALLOWED_TABLES is not None and "*" not in ALLOWED_TABLES:
        before = len(tables)
        tables = [entry for entry in tables if entry["table_name"].lower() in ALLOWED_TABLES]
        if len(tables) < before:
            await ctx.info(f"Filtered {before - len(tables)} tables by allowlist")
    # Truncate to keep the response within token budgets (consistent with
    # get_full_schema).
    total_tables = len(tables)
    truncated = MAX_OVERVIEW_TABLES > 0 and total_tables > MAX_OVERVIEW_TABLES
    if truncated:
        tables = tables[:MAX_OVERVIEW_TABLES]
        await ctx.warning(f"Overview truncated: showing {MAX_OVERVIEW_TABLES}/{total_tables} tables")
    await ctx.info(f"Found {len(tables)} tables in {database_name}")
    note = (
        f"Showing {len(tables)}/{total_tables} tables. Use describe_table(name) for specific tables."
    ) if truncated else None
    # Field naming: returned_table_count = entries in this response;
    # total_tables = visible after allowlist filtering, before truncation.
    return {
        "success": True,
        "database_name": database_name,
        "db_type": DB_TYPE,
        "returned_table_count": len(tables),
        "total_tables": total_tables,
        "tables": tables,
        "row_count_approximate": True,
        "truncated": truncated,
        "truncation_note": note,
        "hint": f"Row counts are estimates. total_tables = visible after allowlist. DB type: {DB_TYPE}",
    }
@mcp.tool(
    annotations=ToolAnnotations(
        title="Describe Table Structure",
        readOnlyHint=True,
        destructiveHint=False,
        idempotentHint=True,
    )
)
async def describe_table(table_name: str, ctx: Context) -> dict[str, Any]:
    """
    Get table structure: columns, row count estimate, and query hints.

    Returns column details plus approximate row count:
    - MySQL: from INFORMATION_SCHEMA (InnoDB ±40% variance)
    - SQLite: from sqlite_stat1 or sampling
    Includes is_large flag and recommendations for large tables.

    Args:
        table_name: Name of the table to describe

    Returns:
        Table structure with columns, row count, and query recommendations
    """
    # Guard 1: reject anything that is not a plain identifier (injection defense).
    if not _is_valid_identifier(table_name):
        await ctx.warning(f"Invalid table name rejected: {table_name}")
        return {"success": False, "error": f"Invalid table name: {table_name}"}
    # Guard 2: enforce the table allowlist (P1 Security).
    if not _is_table_allowed(table_name):
        await ctx.warning(f"Table access denied by allowlist: {table_name}")
        return {"success": False, "error": f"Access denied to table: {table_name}"}
    await ctx.info(f"Describing table: {table_name}")
    # Adapter calls keep this portable across MySQL and SQLite.
    adapter = get_adapter()
    columns_data = adapter.get_columns(table_name)
    row_count = adapter.get_row_estimate(table_name)
    if not columns_data:
        await ctx.error(f"Table not found: {table_name}")
        return {"success": False, "error": f"Table '{table_name}' not found"}
    # Tables above the threshold get a LIMIT/aggregation recommendation.
    is_large = row_count > LARGE_TABLE_THRESHOLD
    await ctx.info(f"Table {table_name}: ~{row_count} rows, {len(columns_data)} columns")
    response: dict[str, Any] = {
        "success": True,
        "table_name": table_name,
        "db_type": DB_TYPE,
        "row_count": row_count,
        "row_count_approximate": True,
        "column_count": len(columns_data),
        "columns": columns_data,
        "is_large": is_large,
    }
    if is_large:
        # Only large tables carry the hint, to keep token overhead down.
        response["recommendation"] = (
            f"Large table (~{row_count} rows). Use LIMIT or aggregation (COUNT/GROUP BY)."
        )
    return response
# =============================================================================
# Schema Caching Tools (Microsoft Best Practices: Reduce repeated tool calls)
# Reference: "Too many tools in the same agent can have negative effect on agent quality"
# =============================================================================
@mcp.tool(
    annotations=ToolAnnotations(
        title="Get Full Database Schema",
        readOnlyHint=True,
        destructiveHint=False,
        idempotentHint=True,
    )
)
async def get_full_schema(ctx: Context) -> dict[str, Any]:
    """
    Get complete database schema (all tables and columns) in one call.

    Best for: exploring unknown databases, multi-table queries, or complex JOINs.
    Works with both MySQL and SQLite databases.

    Returns:
        Complete schema with all tables and their column definitions
    """
    await ctx.info("Fetching complete database schema...")
    adapter = get_adapter()
    # Step 1: Get all tables with row counts using adapter
    tables_data = adapter.get_tables()
    if not tables_data:
        await ctx.info("No tables found in database")
        return {
            "success": True,
            "schema": {},
            "db_type": DB_TYPE,
            "returned_table_count": 0,
            "total_tables": 0,
            "total_columns": 0,
            "row_count_approximate": True,
            "truncated": False,
            "truncation_note": None,
            "hint": "No tables found in database."
        }
    # Filter tables by allowlist if configured (P1 Security)
    # Skip filtering if ALLOWED_TABLES=* (explicit allow all)
    if ALLOWED_TABLES is not None and "*" not in ALLOWED_TABLES:
        tables_data = [t for t in tables_data if t["table_name"].lower() in ALLOWED_TABLES]
        await ctx.info(f"Allowlist active: showing {len(tables_data)} allowed tables")
    # Step 2: Apply truncation to prevent token overflow (P0 security/performance)
    # Reference: Google Gemini best practices - token limits
    total_tables = len(tables_data)
    truncated = False
    if MAX_SCHEMA_TABLES > 0 and total_tables > MAX_SCHEMA_TABLES:
        tables_data = tables_data[:MAX_SCHEMA_TABLES]
        truncated = True
        await ctx.warning(f"Schema truncated: showing {MAX_SCHEMA_TABLES}/{total_tables} tables")
    # Step 3: Get columns for each table and organize into structured schema.
    # NOTE: one adapter.get_columns() call per table (N+1 pattern) - acceptable
    # here because the table count is capped by MAX_SCHEMA_TABLES above.
    schema = {}
    for table in tables_data:
        table_name = table["table_name"]
        columns_data = adapter.get_columns(table_name)
        schema[table_name] = {
            "row_count": table["row_count"],
            "columns": [
                {
                    "name": col["column_name"],
                    "type": col["data_type"],
                    "nullable": col["nullable"],
                    "key": col["key_type"]
                }
                for col in columns_data
            ]
        }
    total_columns_shown = sum(len(t["columns"]) for t in schema.values())
    await ctx.info(f"Schema loaded: {len(schema)} tables, {total_columns_shown} columns")
    # Field naming: returned_table_count = tables actually in this response;
    # total_tables = visible after allowlist filtering, before truncation.
    # "returned_" prefix avoids confusion with "total tables in database".
    return {
        "success": True,
        "schema": schema,
        "db_type": DB_TYPE,
        "returned_table_count": len(schema),
        "total_tables": total_tables,  # Visible tables (after allowlist, before truncation)
        "total_columns": total_columns_shown,
        "row_count_approximate": True,
        "truncated": truncated,
        "truncation_note": (
            f"Showing {len(schema)}/{total_tables} tables. Use describe_table(name) for specific tables."
        ) if truncated else None,
        "hint": f"Row counts are estimates. Use LIMIT for large tables (row_count > {LARGE_TABLE_THRESHOLD}). total_tables = visible after allowlist. DB type: {DB_TYPE}"
    }
# Optional tool: get_table_summary with exact COUNT(*) option
# Default: disabled - describe_table already provides estimated row_count
# Enable via ENABLE_TABLE_SUMMARY=1 when exact counts are needed
if TABLE_SUMMARY_ENABLED:
    @mcp.tool(
        annotations=ToolAnnotations(
            title="Get Table Summary with Exact Count",
            readOnlyHint=True,
            destructiveHint=False,
            idempotentHint=True,
        )
    )
    async def get_table_summary(
        table_name: str,
        ctx: Context,
        exact_count: bool = False
    ) -> dict[str, Any]:
        """
        Table statistics with optional exact row count.

        WARNING: exact_count=True runs COUNT(*) which may be slow on large InnoDB tables
        (full table scan, potential MDL contention). Use only when precision is required.
        Default: Uses INFORMATION_SCHEMA estimate (fast, ~40% variance for InnoDB).

        Args:
            table_name: Name of the table
            exact_count: If True, run COUNT(*) for precise count (slow on large tables)

        Returns:
            Table statistics with row count, columns, and query hints
        """
        # Validate table name. The identifier check rejects quotes, backticks,
        # semicolons and comment markers, which also makes the f-string
        # interpolation of table_name into SQL below injection-safe.
        if not _is_valid_identifier(table_name):
            await ctx.warning(f"Invalid table name rejected: {table_name}")
            return {"success": False, "error": f"Invalid table name: {table_name}"}
        # Check table allowlist (P1 Security)
        if not _is_table_allowed(table_name):
            await ctx.warning(f"Table access denied by allowlist: {table_name}")
            return {"success": False, "error": f"Access denied to table: {table_name}"}
        await ctx.info(f"Getting summary for table: {table_name} (exact_count={exact_count})")
        # Get row count - choose method based on exact_count flag
        # NOTE(review): both branches use MySQL-specific syntax (backtick
        # quoting, INFORMATION_SCHEMA); this tool likely fails when
        # DB_TYPE=sqlite - confirm before enabling with SQLite.
        row_count_approximate = True
        if exact_count:
            # WARNING: COUNT(*) can be slow on large InnoDB tables
            count_sql = f"SELECT COUNT(*) as total_rows FROM `{table_name}`"
            row_count_approximate = False
            await ctx.warning(f"Running COUNT(*) on {table_name} - may be slow on large tables")
        else:
            # Fast estimate from INFORMATION_SCHEMA (no table scan)
            count_sql = f"""
            SELECT TABLE_ROWS as total_rows
            FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = '{table_name}'
            """
        count_result = execute_sql(count_sql)
        if isinstance(count_result, str) and count_result.startswith("Error:"):
            await ctx.error(f"Failed to count rows: {count_result}")
            return {"success": False, "error": count_result}
        count_data = _serialize_result(count_result)
        total_rows = count_data[0]["total_rows"] if count_data else 0
        total_rows = total_rows or 0  # Handle None (INFORMATION_SCHEMA can return NULL)
        # Get column info
        columns_sql = f"""
        SELECT
            COLUMN_NAME as name,
            DATA_TYPE as type,
            IS_NULLABLE as nullable
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = '{table_name}'
        ORDER BY ORDINAL_POSITION
        """
        columns_result = execute_sql(columns_sql)
        # Column lookup failures degrade to an empty column list rather than erroring.
        columns_data = _serialize_result(columns_result) if not isinstance(columns_result, str) else []
        # Determine if table is large
        is_large = total_rows > LARGE_TABLE_THRESHOLD
        await ctx.info(f"Table {table_name}: {'~' if row_count_approximate else ''}{total_rows} rows, {len(columns_data)} columns")
        result = {
            "success": True,
            "table_name": table_name,
            "row_count": total_rows,
            "row_count_approximate": row_count_approximate,
            "column_count": len(columns_data),
            "columns": columns_data,
            "is_large": is_large,
        }
        if is_large:
            result["recommendation"] = (
                f"Large table ({'~' if row_count_approximate else ''}{total_rows} rows). "
                "Use LIMIT or aggregation (COUNT/GROUP BY)."
            )
        return result
# Optional tool: Only register if enabled
if SCHEMA_TOOLS_ENABLED:
    @mcp.tool(
        annotations=ToolAnnotations(
            title="Sample Table Data",
            readOnlyHint=True,
            destructiveHint=False,
            idempotentHint=True,
        )
    )
    async def sample(table_name: str, ctx: Context, limit: int = 5) -> dict[str, Any]:
        """
        Get sample rows from a table to preview its data.

        Args:
            table_name: Name of the table to sample
            limit: Number of rows to return (max 20)

        Returns:
            Sample rows from the table
        """
        # Guard 1: identifier validation (blocks injection via table name;
        # also makes the backtick-quoted f-string below safe).
        if not _is_valid_identifier(table_name):
            await ctx.warning(f"Invalid table name rejected: {table_name}")
            return {"success": False, "error": f"Invalid table name: {table_name}"}
        # Guard 2: allowlist enforcement (P1 Security).
        if not _is_table_allowed(table_name):
            await ctx.warning(f"Table access denied by allowlist: {table_name}")
            return {"success": False, "error": f"Access denied to table: {table_name}"}
        limit = min(max(1, limit), 20)  # Clamp between 1-20
        await ctx.info(f"Sampling {limit} rows from: {table_name}")
        sql = f"SELECT * FROM `{table_name}` LIMIT {limit}"
        result = execute_sql(sql)
        if isinstance(result, str) and result.startswith("Error:"):
            await ctx.error(f"Failed to sample table: {result}")
            return {"success": False, "error": result}
        rows = _serialize_result(result)
        return {
            "success": True,
            "table_name": table_name,
            "data": rows,
            "row_count": len(rows),
            "query": sql,
        }
# =============================================================================
# MCP Prompts
# =============================================================================
@mcp.prompt(name="sql_assistant")
def sql_assistant() -> str:
"""System prompt for SQL query assistance."""
# Build dynamic prompt based on configuration
# Reference: Microsoft prompt engineering - clear, structured, avoid unnecessary steps
# Reference: MCP spec - model-driven tool selection, provide decision rules not fixed paths
if ALLOW_UNION and ALLOWED_TABLES:
if "*" in ALLOWED_TABLES:
cross_table = "UNION supported for combining results."
else:
tables_desc = ', '.join(sorted(ALLOWED_TABLES))
cross_table = f"UNION allowed for: {tables_desc}."
else:
cross_table = "Query tables separately."