-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathstreamlit_client_stdio.py
More file actions
184 lines (155 loc) · 8.38 KB
/
streamlit_client_stdio.py
File metadata and controls
184 lines (155 loc) · 8.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python
"""
streamlit_client_stdio.py
This file implements a LangChain MCP client that:
- Loads configuration from a JSON file specified by the THEAILANGUAGE_CONFIG environment variable.
- Connects to one or more MCP servers defined in the config.
- Loads available MCP tools from each connected server.
- Uses the Google Gemini API (via LangChain) to create a React agent with access to all tools.
- Runs a single user query through the agent and returns the content of the final message.
Detailed explanations:
- Retries (max_retries=2): If an API call fails due to transient issues (e.g., timeouts), it will retry up to 2 times.
- Temperature (set to 0): A value of 0 means fully deterministic output; increase this for more creative responses.
- Environment Variable: THEAILANGUAGE_CONFIG should point to a config JSON that defines all MCP servers.
"""
import asyncio # For asynchronous operations
import os # To access environment variables and file paths
import sys # For system-specific parameters and error handling
import json # For reading and writing JSON data
from contextlib import AsyncExitStack # For managing multiple asynchronous context managers
# ---------------------------
# MCP Client Imports
# ---------------------------
from mcp import ClientSession, StdioServerParameters # For managing MCP client sessions and server parameters
from mcp.client.stdio import stdio_client # For establishing a stdio connection to an MCP server
# ---------------------------
# Agent and LLM Imports
# ---------------------------
from langchain_mcp_adapters.tools import load_mcp_tools # Adapter to convert MCP tools to LangChain compatible tools
from langgraph.prebuilt import create_react_agent # Function to create a prebuilt React agent using LangGraph
from langchain_google_genai import ChatGoogleGenerativeAI # Wrapper for the Google Gemini API via LangChain
# ---------------------------
# Environment Setup
# ---------------------------
from dotenv import load_dotenv
load_dotenv() # Load environment variables from a .env file (e.g., GOOGLE_API_KEY)
# ---------------------------
# Custom JSON Encoder for LangChain objects
# ---------------------------
class CustomEncoder(json.JSONEncoder):
    """
    JSON encoder that can serialize message-like objects.

    Any object exposing a 'content' attribute (for example the message
    objects found in an agent response) is encoded as a small dictionary
    holding its class name and its content. Everything else is delegated to
    the standard encoder, which raises TypeError for unsupported types.
    """

    def default(self, o):
        # Guard clause: objects without 'content' take the stock code path.
        if not hasattr(o, "content"):
            return super().default(o)

        # Message-like object: keep only the class name and the payload.
        type_name = o.__class__.__name__
        return {"type": type_name, "content": o.content}
# ---------------------------
# Function: read_config_json
# ---------------------------
def read_config_json():
    """
    Read the MCP server configuration JSON.

    Priority:
    1. Use the path stored in the THEAILANGUAGE_CONFIG environment variable.
    2. If it is unset or empty, fall back to 'theailanguage_config.json'
       located in the same directory as this script.

    Returns:
        dict: Parsed JSON content with MCP server definitions.

    Exits:
        Calls sys.exit(1) (raising SystemExit) after printing an error when
        the file cannot be opened, is not valid UTF-8 JSON, or its top-level
        value is not a JSON object.
    """
    config_path = os.getenv("THEAILANGUAGE_CONFIG")
    if not config_path:
        # No explicit path: look next to this script.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(script_dir, "theailanguage_config.json")
        print(f"⚠️ THEAILANGUAGE_CONFIG not set. Falling back to: {config_path}")

    try:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding (which is not UTF-8 on every system).
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError: covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        print(f"❌ Failed to read config file at '{config_path}': {e}")
        sys.exit(1)

    if not isinstance(config, dict):
        # Callers use dict methods on the result, so reject lists/scalars
        # here with a clear message instead of failing later.
        print(
            f"❌ Failed to read config file at '{config_path}': "
            f"top-level JSON value must be an object, got {type(config).__name__}"
        )
        sys.exit(1)

    return config
# ---------------------------
# Main Function: run_agent
# ---------------------------
async def run_agent(query: str) -> str:
    """
    Run a single query through a React agent that can use the tools of every
    configured MCP server.

    Steps:
    1. Instantiate the Google Gemini chat model.
    2. Load the MCP server definitions with read_config_json().
    3. Connect to each server over stdio and collect its tools; a server that
       fails is reported and skipped.
    4. Build a React agent from the model and the collected tools, invoke it
       once with the query, and print the full response.

    Args:
        query: The user message sent to the agent.

    Returns:
        str: The content of the last message in the agent response. When no
        servers are configured, no tools could be loaded, or the response
        contains no messages, a human-readable status message is returned
        instead (the same text is printed for the first two cases).
    """
    # ---------------------------
    # Google Gemini LLM Instantiation
    # ---------------------------
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash",  # Gemini model variant to use
        temperature=0,  # Deterministic responses
        max_retries=2,  # Retry transient API failures up to 2 times
        google_api_key=os.getenv("GOOGLE_API_KEY"),  # Key from the environment
    )

    config = read_config_json()  # Load MCP server configuration
    mcp_servers = config.get("mcpServers", {})  # Server name -> definition
    if not mcp_servers:
        # Return the message (not None) so the declared str contract holds.
        notice = "❌ No MCP servers found in the configuration."
        print(notice)
        return notice

    tools = []  # Tools aggregated across all connected servers

    # AsyncExitStack keeps every connection open until the block exits and
    # then closes them together. The agent is invoked inside this block,
    # while the sessions entered below are still open.
    async with AsyncExitStack() as stack:
        for server_name, server_info in mcp_servers.items():
            print(f"\n🔗 Connecting to MCP Server: {server_name}...")
            try:
                # Built inside the try so that a malformed entry (e.g. a
                # missing "command") only skips this server instead of
                # aborting all of them. "args" is optional.
                server_params = StdioServerParameters(
                    command=server_info["command"],
                    args=server_info.get("args", []),
                )
                # Open the stdio transport, then a client session on top of it.
                read, write = await stack.enter_async_context(stdio_client(server_params))
                session = await stack.enter_async_context(ClientSession(read, write))
                await session.initialize()

                # Convert the server's MCP tools into LangChain tools.
                server_tools = await load_mcp_tools(session)
                for tool in server_tools:
                    print(f"\n🔧 Loaded tool: {tool.name}")
                tools.extend(server_tools)
                print(f"\n✅ {len(server_tools)} tools loaded from {server_name}.")
            except Exception as e:
                # Deliberately broad: one failing server must not prevent the
                # remaining servers from being used.
                print(f"❌ Failed to connect to server {server_name}: {e}")

        if not tools:
            notice = "❌ No tools loaded from any server. Exiting."
            print(notice)
            return notice

        # Create the React agent and invoke it once with the query.
        agent = create_react_agent(llm, tools)
        response = await agent.ainvoke({"messages": query})

        # Print the whole response, as indented JSON when possible.
        print("\nResponse:")
        try:
            print(json.dumps(response, indent=2, cls=CustomEncoder))
        except Exception:
            # Best-effort debugging output: fall back to the raw response.
            print(str(response))

        # Walk the messages from newest to oldest and return the first one
        # that carries content.
        # NOTE(review): content is returned as-is; confirm it is always a
        # plain string for the model in use.
        if isinstance(response, dict) and "messages" in response:
            for message in reversed(response["messages"]):
                if hasattr(message, "content"):
                    return message.content

        return "⚠️ No AI response found in messages"
# ---------------------------
# Entry Point
# ---------------------------
if __name__ == "__main__":
    # run_agent() requires a query. Take it from the command-line arguments,
    # or prompt for it when none were given, instead of calling run_agent()
    # with no argument (which raises TypeError).
    cli_query = " ".join(sys.argv[1:]).strip()
    if not cli_query:
        cli_query = input("Query: ").strip()
    if not cli_query:
        print("❌ No query provided.")
        sys.exit(1)

    # Run the asynchronous run_agent function using asyncio's event loop
    asyncio.run(run_agent(cli_query))