Skip to content

Commit b32cac2

Browse files
committed
fix: address all code review findings for MCP server
1 parent 6565804 commit b32cac2

11 files changed

Lines changed: 368 additions & 74 deletions

TablePro/Core/MCP/MCPAuthGuard.swift

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ actor MCPAuthGuard {
113113
// MARK: - User Approval (askEachTime)
114114

115115
private func promptUserApproval(connectionName: String, databaseType: String) async throws -> Bool {
116+
// Use a task group so the actor suspends (freeing it for other requests)
117+
// while the approval dialog is shown on the main thread.
118+
// Race the dialog against a 30-second timeout.
116119
let approvalTask = Task { @MainActor in
117120
NSApp.requestUserAttention(.criticalRequest)
118121
NSApp.activate(ignoringOtherApps: true)
@@ -129,8 +132,7 @@ actor MCPAuthGuard {
129132
)
130133
}
131134

132-
// Race against a 30-second timeout
133-
return try await withThrowingTaskGroup(of: Bool.self) { group in
135+
let approved = try await withThrowingTaskGroup(of: Bool.self) { group in
134136
group.addTask {
135137
await approvalTask.value
136138
}
@@ -146,6 +148,13 @@ actor MCPAuthGuard {
146148
group.cancelAll()
147149
return result
148150
}
151+
152+
if approved {
153+
return true
154+
}
155+
throw MCPError.forbidden(
156+
String(localized: "User denied MCP access to this connection")
157+
)
149158
}
150159

151160
// MARK: - Session Cleanup

TablePro/Core/MCP/MCPConnectionBridge.swift

Lines changed: 92 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,32 @@ actor MCPConnectionBridge {
4545
func connect(connectionId: UUID) async throws -> JSONValue {
4646
let connection = try await resolveConnection(connectionId)
4747

48+
// Check if session already exists and is connected -- reuse without switching UI
49+
let existingSession = await MainActor.run {
50+
DatabaseManager.shared.activeSessions[connectionId]
51+
}
52+
53+
if let existing = existingSession, existing.driver != nil {
54+
// Already connected, return current state without switching the UI's active session
55+
let serverVersion = existing.driver?.serverVersion
56+
let currentDatabase = existing.activeDatabase
57+
let currentSchema = existing.currentSchema
58+
59+
var result: [String: JSONValue] = [
60+
"status": "connected",
61+
"current_database": .string(currentDatabase)
62+
]
63+
if let version = serverVersion {
64+
result["server_version"] = .string(version)
65+
}
66+
if let schema = currentSchema {
67+
result["current_schema"] = .string(schema)
68+
}
69+
return .object(result)
70+
}
71+
72+
// Not connected yet -- create a new session via DatabaseManager.
73+
// connectToSession is @MainActor; Swift hops automatically for async calls.
4874
try await DatabaseManager.shared.connectToSession(connection)
4975

5076
let (serverVersion, currentDatabase, currentSchema) = await MainActor.run {
@@ -81,43 +107,56 @@ actor MCPConnectionBridge {
81107
}
82108

83109
func getConnectionStatus(connectionId: UUID) async throws -> JSONValue {
84-
let sessionInfo = await MainActor.run {
85-
() -> (status: ConnectionStatus, database: String, schema: String?, version: String?, connectedAt: Date)? in
110+
let core = await MainActor.run {
111+
() -> (status: ConnectionStatus, database: String, schema: String?)? in
86112
guard let session = DatabaseManager.shared.activeSessions[connectionId] else {
87113
return nil
88114
}
89-
return (
90-
status: session.status,
91-
database: session.activeDatabase,
92-
schema: session.currentSchema,
93-
version: session.driver?.serverVersion,
94-
connectedAt: session.connectedAt
95-
)
115+
return (session.status, session.activeDatabase, session.currentSchema)
96116
}
97117

98-
guard let info = sessionInfo else {
118+
guard let core else {
99119
throw MCPError.notConnected(connectionId)
100120
}
101121

122+
let meta = await MainActor.run {
123+
() -> (version: String?, connectedAt: Date, lastActiveAt: Date) in
124+
let session = DatabaseManager.shared.activeSessions[connectionId]
125+
return (
126+
session?.driver?.serverVersion,
127+
session?.connectedAt ?? Date(),
128+
session?.lastActiveAt ?? Date()
129+
)
130+
}
131+
102132
let statusString: String
103-
switch info.status {
133+
var errorDetail: JSONValue?
134+
switch core.status {
104135
case .connected: statusString = "connected"
105136
case .connecting: statusString = "connecting"
106137
case .disconnected: statusString = "disconnected"
107-
case .error(let msg): statusString = "error: \(msg)"
138+
case .error(let msg):
139+
statusString = "error"
140+
errorDetail = .object([
141+
"message": .string(msg)
142+
])
108143
}
109144

110145
var result: [String: JSONValue] = [
111146
"status": .string(statusString),
112-
"current_database": .string(info.database),
113-
"connected_at": .string(ISO8601DateFormatter().string(from: info.connectedAt))
147+
"current_database": .string(core.database),
148+
"connected_at": .string(ISO8601DateFormatter().string(from: meta.connectedAt)),
149+
"last_active_at": .string(ISO8601DateFormatter().string(from: meta.lastActiveAt))
114150
]
115-
if let schema = info.schema {
151+
if let schema = core.schema {
116152
result["current_schema"] = .string(schema)
117153
}
118-
if let version = info.version {
154+
if let version = meta.version {
119155
result["server_version"] = .string(version)
120156
}
157+
if let errorDetail {
158+
result["error"] = errorDetail
159+
}
121160

122161
return .object(result)
123162
}
@@ -146,6 +185,8 @@ actor MCPConnectionBridge {
146185
}
147186
group.addTask {
148187
try await Task.sleep(for: .seconds(timeoutSeconds))
188+
// Cancel the driver query before throwing
189+
try? driver.cancelQuery()
149190
throw MCPError.timeout("Query timed out after \(timeoutSeconds) seconds")
150191
}
151192
guard let first = try await group.next() else {
@@ -156,7 +197,7 @@ actor MCPConnectionBridge {
156197
}
157198
}
158199

159-
let executionTimeMs = (CFAbsoluteTimeGetCurrent() - startTime) * 1000
200+
let executionTimeMs = (CFAbsoluteTimeGetCurrent() - startTime) * 1_000
160201
let isTruncated = result.rows.count > maxRows
161202
let rows = isTruncated ? Array(result.rows.prefix(maxRows)) : result.rows
162203

@@ -316,6 +357,7 @@ actor MCPConnectionBridge {
316357
// MARK: - Database/Schema Switching
317358

318359
func switchDatabase(connectionId: UUID, database: String) async throws -> JSONValue {
360+
// switchDatabase is @MainActor; Swift hops automatically for async calls.
319361
try await DatabaseManager.shared.switchDatabase(to: database, for: connectionId)
320362
return .object([
321363
"status": "switched",
@@ -324,6 +366,7 @@ actor MCPConnectionBridge {
324366
}
325367

326368
func switchSchema(connectionId: UUID, schema: String) async throws -> JSONValue {
369+
// switchSchema is @MainActor; Swift hops automatically for async calls.
327370
try await DatabaseManager.shared.switchSchema(to: schema, for: connectionId)
328371
return .object([
329372
"status": "switched",
@@ -334,14 +377,34 @@ actor MCPConnectionBridge {
334377
// MARK: - Schema Resource (for resources/read)
335378

336379
func fetchSchemaResource(connectionId: UUID) async throws -> JSONValue {
337-
let (driver, _) = try await resolveDriver(connectionId)
380+
// Check SchemaProviderRegistry cache first
381+
let provider = await MainActor.run {
382+
SchemaProviderRegistry.shared.provider(for: connectionId)
383+
}
384+
var cachedTables: [TableInfo] = []
385+
if let provider {
386+
let cached = await provider.getTables()
387+
if !cached.isEmpty {
388+
cachedTables = cached
389+
}
390+
}
338391

339-
let tables = try await DatabaseManager.shared.trackOperation(sessionId: connectionId) {
340-
try await driver.fetchTables()
392+
let tables: [TableInfo]
393+
if !cachedTables.isEmpty {
394+
tables = cachedTables
395+
} else {
396+
let (driver, _) = try await resolveDriver(connectionId)
397+
tables = try await DatabaseManager.shared.trackOperation(sessionId: connectionId) {
398+
try await driver.fetchTables()
399+
}
341400
}
342401

402+
// Limit to first 100 tables to prevent excessive round-trips
403+
let limitedTables = Array(tables.prefix(100))
404+
405+
let (driver, _) = try await resolveDriver(connectionId)
343406
var tableSchemas: [JSONValue] = []
344-
for table in tables {
407+
for table in limitedTables {
345408
let columns = try await DatabaseManager.shared.trackOperation(sessionId: connectionId) {
346409
try await driver.fetchColumns(table: table.name)
347410
}
@@ -362,7 +425,13 @@ actor MCPConnectionBridge {
362425
]))
363426
}
364427

365-
return .object(["tables": .array(tableSchemas)])
428+
var result: [String: JSONValue] = ["tables": .array(tableSchemas)]
429+
if tables.count > 100 {
430+
result["truncated"] = .bool(true)
431+
result["total_tables"] = .int(tables.count)
432+
}
433+
434+
return .object(result)
366435
}
367436

368437
// MARK: - History Resource
@@ -394,7 +463,7 @@ actor MCPConnectionBridge {
394463
"query": .string(entry.query),
395464
"database_name": .string(entry.databaseName),
396465
"executed_at": .string(ISO8601DateFormatter().string(from: entry.executedAt)),
397-
"execution_time_ms": .double(entry.executionTime * 1000),
466+
"execution_time_ms": .double(entry.executionTime * 1_000),
398467
"row_count": .int(entry.rowCount),
399468
"was_successful": .bool(entry.wasSuccessful)
400469
]

TablePro/Core/MCP/MCPHTTPParser.swift

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ enum HTTPParseError: Error, Sendable {
2121
case malformedHeaders
2222
case unsupportedMethod(String)
2323
case bodyTooLarge
24+
case malformedChunkedEncoding
2425
}
2526

2627
enum MCPHTTPParser {
@@ -90,7 +91,16 @@ enum MCPHTTPParser {
9091
if let transferEncoding = headers["transfer-encoding"],
9192
transferEncoding.lowercased().contains("chunked")
9293
{
93-
logger.warning("Chunked transfer encoding not supported, treating body as empty")
94+
let bodyData = data[bodyStartIndex...]
95+
switch decodeChunkedBody(bodyData) {
96+
case .success(let decoded):
97+
if decoded.count > maxBodySize {
98+
return .failure(.bodyTooLarge)
99+
}
100+
body = decoded
101+
case .failure(let error):
102+
return .failure(error)
103+
}
94104
} else if let contentLengthStr = headers["content-length"],
95105
let contentLength = Int(contentLengthStr)
96106
{
@@ -115,6 +125,70 @@ enum MCPHTTPParser {
115125
))
116126
}
117127

128+
// MARK: - Chunked Transfer Encoding
129+
130+
/// Decode chunked transfer encoding body.
131+
/// Format: <chunk-size-hex>\r\n<chunk-data>\r\n ... 0\r\n\r\n
132+
private static func decodeChunkedBody(_ data: Data) -> Result<Data, HTTPParseError> {
133+
var result = Data()
134+
var offset = data.startIndex
135+
136+
while offset < data.endIndex {
137+
// Find the end of the chunk size line
138+
guard let lineEnd = findCRLF(in: data, from: offset) else {
139+
return .failure(.incomplete)
140+
}
141+
142+
let sizeData = data[offset..<lineEnd]
143+
guard let sizeString = String(data: sizeData, encoding: .ascii)?.trimmingCharacters(in: .whitespaces),
144+
let chunkSize = UInt(sizeString, radix: 16)
145+
else {
146+
return .failure(.malformedChunkedEncoding)
147+
}
148+
149+
// Move past the \r\n after chunk size
150+
let chunkDataStart = lineEnd + 2
151+
152+
// Terminal chunk
153+
if chunkSize == 0 {
154+
return .success(result)
155+
}
156+
157+
let chunkDataEnd = chunkDataStart + Int(chunkSize)
158+
159+
// Check we have enough data for the chunk + trailing \r\n
160+
guard chunkDataEnd + 2 <= data.endIndex else {
161+
return .failure(.incomplete)
162+
}
163+
164+
// Check accumulated size
165+
if result.count + Int(chunkSize) > maxBodySize {
166+
return .failure(.bodyTooLarge)
167+
}
168+
169+
result.append(data[chunkDataStart..<chunkDataEnd])
170+
171+
// Skip past the trailing \r\n after chunk data
172+
offset = chunkDataEnd + 2
173+
}
174+
175+
return .failure(.incomplete)
176+
}
177+
178+
/// Find \r\n in data starting from given offset
179+
private static func findCRLF(in data: Data, from start: Data.Index) -> Data.Index? {
180+
var i = start
181+
while i < data.endIndex - 1 {
182+
if data[i] == 0x0D, data[i + 1] == 0x0A {
183+
return i
184+
}
185+
i += 1
186+
}
187+
return nil
188+
}
189+
190+
// MARK: - Response Building
191+
118192
static func buildResponse(
119193
status: Int,
120194
statusText: String,
@@ -161,6 +235,7 @@ enum MCPHTTPParser {
161235
case 400: return "Bad Request"
162236
case 404: return "Not Found"
163237
case 405: return "Method Not Allowed"
238+
case 413: return "Content Too Large"
164239
case 500: return "Internal Server Error"
165240
default: return "Unknown"
166241
}

0 commit comments

Comments
 (0)