-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmcp_client_patch.py
More file actions
128 lines (109 loc) · 4.77 KB
/
mcp_client_patch.py
File metadata and controls
128 lines (109 loc) · 4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Patch for MCP client to handle malformed JSON-RPC responses from the server.
This addresses the issue where the server returns error responses that don't
conform to the JSON-RPC specification.
"""
import functools
import inspect
import json
import logging
from typing import Any, Dict
# Module-level logger shared by every patch function in this file.
logger = logging.getLogger(__name__)
# Root logging is configured to INFO so the patch status messages are visible.
# NOTE(review): this runs at import time and therefore configures the root
# logger of any application that imports this module; consider moving it
# under the ``__main__`` guard.
logging.basicConfig(level=logging.INFO)
def patch_mcp_json_parsing():
"""
Monkey patch the MCP client to handle malformed JSON-RPC responses.
"""
try:
# Try to import and patch the streamable HTTP client
from mcp.client import streamable_http
# Store the original parse function if it exists
if hasattr(streamable_http, '_parse_json_rpc_response'):
original_parse = streamable_http._parse_json_rpc_response
else:
original_parse = None
def patched_parse_response(response_text: str) -> Dict[str, Any]:
"""
Parse JSON-RPC response with better error handling for malformed responses.
"""
try:
response = json.loads(response_text)
# If this is an error response, ensure it has the required fields
if "error" in response and isinstance(response["error"], dict):
error = response["error"]
# Add missing 'code' field if not present
if "code" not in error:
error["code"] = -32603 # Internal error code
logger.warning("Fixed malformed JSON-RPC error response: added missing 'code' field")
# Ensure message field exists
if "message" not in error:
error["message"] = "Unknown server error"
logger.warning("Fixed malformed JSON-RPC error response: added missing 'message' field")
return response
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON-RPC response: {e}")
# Return a properly formatted error response
return {
"jsonrpc": "2.0",
"id": None,
"error": {
"code": -32700, # Parse error
"message": f"Parse error: {str(e)}"
}
}
# Apply the patch if we found the function
if original_parse:
streamable_http._parse_json_rpc_response = patched_parse_response
logger.info("Successfully patched MCP JSON-RPC response parser")
return True
else:
logger.warning("Could not find MCP JSON-RPC parser function to patch")
return False
except ImportError as e:
logger.error(f"Could not import MCP streamable_http module: {e}")
return False
except Exception as e:
logger.error(f"Failed to apply MCP JSON-RPC patch: {e}")
return False
def patch_mcp_error_handling():
"""
Alternative approach: patch the error handling in the MCP client session.
"""
try:
from mcp.client.session import ClientSession
# Store original method
if hasattr(ClientSession, '_handle_response'):
original_handle_response = ClientSession._handle_response
else:
return False
def patched_handle_response(self, response_data):
"""Handle response with better error handling."""
try:
return original_handle_response(self, response_data)
except Exception as e:
error_msg = str(e)
if "validation errors for JSONRPCMessage" in error_msg:
# Create a proper error response
return {
"jsonrpc": "2.0",
"id": getattr(response_data, 'id', None),
"error": {
"code": -32603,
"message": "Server returned malformed response"
}
}
else:
raise e
ClientSession._handle_response = patched_handle_response
logger.info("Successfully patched MCP ClientSession error handling")
return True
except Exception as e:
logger.error(f"Failed to patch MCP ClientSession: {e}")
return False
if __name__ == "__main__":
    # Run both strategies unconditionally (parser patch first, then session
    # patch); the run counts as a success if at least one of them applied.
    outcomes = (patch_mcp_json_parsing(), patch_mcp_error_handling())
    summary = (
        "✅ MCP client patching completed"
        if any(outcomes)
        else "❌ MCP client patching failed"
    )
    print(summary)