# client.py
# Command-line MCP client: starts an MCP server script over stdio and lets a
# Gemini model call that server's tools while answering interactive queries.
# Import necessary libraries
import asyncio # For handling asynchronous operations
import os # For environment variable access
import sys # For system-specific parameters and functions
import json # For handling JSON data (used when printing function declarations)
# Import MCP client components
from typing import Optional # For type hinting optional values
from contextlib import AsyncExitStack # For managing multiple async tasks
from mcp import ClientSession, StdioServerParameters # MCP session management
from mcp.client.stdio import stdio_client # MCP client for standard I/O communication
# Import Google's Gen AI SDK
from google import genai
from google.genai import types
from google.genai.types import Tool, FunctionDeclaration
from google.genai.types import GenerateContentConfig
from dotenv import load_dotenv # For loading API keys from a .env file
# Load environment variables from .env file
load_dotenv()
class MCPClient:
    """Interactive client that lets a Gemini model call tools served by an MCP server.

    The client starts the server script as a subprocess speaking MCP over
    stdio, advertises the server's tools to Gemini as function declarations,
    and runs any function call Gemini requests through the MCP session.
    """

    # Model used when the caller does not choose one.
    DEFAULT_MODEL = 'gemini-2.0-flash-001'

    def __init__(self, model: str = DEFAULT_MODEL):
        """Initialize the MCP client and configure the Gemini API.

        Args:
            model: Name of the Gemini model used for every request.

        Raises:
            ValueError: If GEMINI_API_KEY is not set in the environment.
        """
        self.session: Optional[ClientSession] = None  # Set by connect_to_server()
        self.exit_stack = AsyncExitStack()  # Owns the stdio transport and the session
        # Defined up front so the attribute exists even before a server is connected.
        self.function_declarations = []
        self.model = model

        # Retrieve the Gemini API key from environment variables
        gemini_api_key = os.getenv("GEMINI_API_KEY")
        if not gemini_api_key:
            raise ValueError("GEMINI_API_KEY not found. Please add it to your .env file.")

        # Configure the Gemini AI client
        self.genai_client = genai.Client(api_key=gemini_api_key)

    async def connect_to_server(self, server_script_path: str):
        """Connect to the MCP server and list available tools.

        Args:
            server_script_path: Path to the server script. A path ending in
                ".py" is run with Python; anything else is run with node.
        """
        # Prefer the interpreter running this client over a bare "python" so
        # the launch does not depend on what "python" resolves to on PATH.
        if server_script_path.endswith('.py'):
            command = sys.executable or "python"
        else:
            command = "node"

        # Start the server as a subprocess and talk to it over stdin/stdout.
        server_params = StdioServerParameters(command=command, args=[server_script_path])
        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport

        # Open the MCP session on top of the transport and perform the handshake.
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(self.stdio, self.write)
        )
        await self.session.initialize()

        # Ask the server which tools it offers.
        response = await self.session.list_tools()
        tools = response.tools
        print("\nConnected to server with tools:", [tool.name for tool in tools])

        # Convert MCP tools to Gemini format
        self.function_declarations = convert_mcp_tools_to_gemini(tools)

    @staticmethod
    def _text_parts(response):
        """Return the text of every text-bearing part in the first candidate of a response."""
        candidates = response.candidates or []
        if not candidates or candidates[0].content is None:
            return []
        return [
            part.text
            for part in (candidates[0].content.parts or [])
            if part.text is not None
        ]

    async def process_query(self, query: str) -> str:
        """Process a user query using the Gemini API and execute tool calls if needed.

        Each function call in Gemini's reply is executed on the MCP server and
        its result is sent back to Gemini in a follow-up request; the text of
        that follow-up reply is collected. Parts without a function call
        contribute their text directly.

        Args:
            query: The user's input query.

        Returns:
            The collected response text, one entry per line.

        Raises:
            RuntimeError: If no MCP session has been established yet.
        """
        if self.session is None:
            raise RuntimeError(
                "Not connected to an MCP server. Call connect_to_server() first."
            )

        # Format user input as a structured Content object for Gemini
        user_prompt_content = types.Content(
            role='user',
            parts=[types.Part.from_text(text=query)],
        )
        # The same tool configuration is sent with every request of this query.
        request_config = types.GenerateContentConfig(tools=self.function_declarations)

        # Use the SDK's async surface so the event loop is not blocked while
        # waiting for the model.
        response = await self.genai_client.aio.models.generate_content(
            model=self.model,
            contents=[user_prompt_content],
            config=request_config,
        )

        final_text = []
        # candidates / content / parts may each be absent; treat that as "nothing to add".
        for candidate in response.candidates or []:
            content = candidate.content
            if content is None or not content.parts:
                continue

            for part in content.parts:
                if not part.function_call:
                    # Not a tool request: keep its text, skipping parts that carry none.
                    if part.text is not None:
                        final_text.append(part.text)
                    continue

                tool_name = part.function_call.name
                tool_args = part.function_call.args
                print(f"\n[Gemini requested tool call: {tool_name} with args {tool_args}]")

                # Execute the tool on the MCP server. A failure is reported to
                # the model as an error payload rather than aborting the query.
                try:
                    result = await self.session.call_tool(tool_name, tool_args)
                    function_response = {"result": result.content}
                except Exception as e:
                    function_response = {"error": str(e)}

                function_response_content = types.Content(
                    role='tool',
                    parts=[
                        types.Part.from_function_response(
                            name=tool_name,
                            response=function_response,
                        )
                    ],
                )

                # Replay the exchange so far: the user prompt, the model's
                # function call for this part, and the tool's result.
                follow_up = await self.genai_client.aio.models.generate_content(
                    model=self.model,
                    contents=[
                        user_prompt_content,
                        types.Content(role='model', parts=[part]),
                        function_response_content,
                    ],
                    config=request_config,
                )
                final_text.extend(self._text_parts(follow_up))

        # Return the combined response as a single formatted string
        return "\n".join(final_text)

    async def chat_loop(self):
        """Run an interactive chat session with the user.

        Reads queries from standard input until the user types 'quit' or
        input ends. A failing query is reported and the loop continues.
        """
        print("\nMCP Client Started! Type 'quit' to exit.")
        while True:
            try:
                query = input("\nQuery: ").strip()
            except EOFError:
                # Input was closed (e.g. Ctrl-D or a finished pipe): leave the loop.
                break

            if query.lower() == 'quit':
                break
            if not query:
                # Nothing to send; prompt again.
                continue

            # Process the user's query and display the response
            try:
                response = await self.process_query(query)
            except Exception as exc:
                # Top-level boundary of the interactive loop: report and keep going.
                print(f"\nError: {exc}")
                continue
            print("\n" + response)

    async def cleanup(self):
        """Clean up resources before exiting by closing everything on the exit stack."""
        await self.exit_stack.aclose()
def clean_schema(schema):
    """
    Recursively removes 'title' fields from the JSON schema.

    The schema is cleaned in place and also returned. Only schema positions
    are visited: property schemas, definitions, array item schemas,
    additionalProperties, and the anyOf/oneOf/allOf/not combinators. Data
    positions such as 'default' and 'enum' are left alone, and a property
    that is itself *named* "title" is kept (only its own 'title' annotation
    is removed).

    Args:
        schema (dict): The schema dictionary. A list of schemas is cleaned
            element by element; any other value is returned unchanged.

    Returns:
        dict: Cleaned schema without 'title' fields (the same object that was passed in).
    """
    if isinstance(schema, list):
        for index, item in enumerate(schema):
            schema[index] = clean_schema(item)
        return schema
    if not isinstance(schema, dict):
        return schema

    schema.pop("title", None)  # Remove title if present

    # Keywords whose value maps a name to a sub-schema. The names themselves
    # are user data, so only the mapped schemas are cleaned.
    for keyword in ("properties", "$defs", "definitions"):
        named_schemas = schema.get(keyword)
        if isinstance(named_schemas, dict):
            for name in named_schemas:
                named_schemas[name] = clean_schema(named_schemas[name])

    # Keywords whose value is a sub-schema or a list of sub-schemas.
    for keyword in (
        "items",
        "prefixItems",
        "additionalProperties",
        "anyOf",
        "oneOf",
        "allOf",
        "not",
    ):
        nested = schema.get(keyword)
        if isinstance(nested, (dict, list)):
            schema[keyword] = clean_schema(nested)

    return schema
def convert_mcp_tools_to_gemini(mcp_tools):
    """
    Build one Gemini Tool per MCP tool for use with Gemini function calling.

    Args:
        mcp_tools (list): MCP tool objects exposing 'name', 'description', and 'inputSchema'.

    Returns:
        list: Gemini Tool objects, each wrapping a single function declaration,
        in the same order as the input.
    """

    def _as_gemini_tool(mcp_tool):
        # Pass the tool's input schema through clean_schema before using it as parameters.
        cleaned_parameters = clean_schema(mcp_tool.inputSchema)
        declaration = FunctionDeclaration(
            name=mcp_tool.name,
            description=mcp_tool.description,
            parameters=cleaned_parameters,
        )
        return Tool(function_declarations=[declaration])

    return [_as_gemini_tool(mcp_tool) for mcp_tool in mcp_tools]
async def main():
    """Entry coroutine: connect to the server script named on the command line and chat.

    Prints a usage message and exits with status 1 when no server script
    path was given.
    """
    cli_args = sys.argv
    if len(cli_args) < 2:
        print("Usage: python client.py <path_to_server_script>")
        raise SystemExit(1)

    server_script = cli_args[1]
    mcp_client = MCPClient()
    try:
        # Bring up the MCP connection, then hand control to the interactive loop.
        await mcp_client.connect_to_server(server_script)
        await mcp_client.chat_loop()
    finally:
        # Runs on normal exit and on error alike, so resources are always released.
        await mcp_client.cleanup()
if __name__ == "__main__":
    # Executed as a script: run the main() coroutine to completion via asyncio.run.
    asyncio.run(main())