|
1 | 1 | from __future__ import annotations |
2 | 2 | import os |
3 | 3 | import asyncio |
4 | | -from agents import Agent, Runner |
| 4 | +from typing import Optional, Any |
| 5 | +from agents import Agent, Runner, ModelSettings, set_tracing_disabled |
5 | 6 | from agents.mcp import MCPServerSse |
| 7 | +from .session_manager import session_manager |
| 8 | +from .ollama_integration import ( |
| 9 | + check_ollama_available, |
| 10 | + create_ollama_model, |
| 11 | + get_ollama_model_name, |
| 12 | +) |
| 13 | + |
| 14 | +# Disable tracing for local models to avoid errors |
| 15 | +set_tracing_disabled(True) |
6 | 16 |
|
7 | 17 | # Gradio 5.29 SSE endpoint |
8 | 18 | MCP_SSE_URL = os.getenv( |
|
16 | 26 | cache_tools_list=True, |
17 | 27 | ) |
18 | 28 |
|
19 | | -# Initialize agent |
| 29 | +# Base agent instructions |
| 30 | +BASE_INSTRUCTIONS = ( |
| 31 | + "You are a data assistant that can analyze tabular data and create PDFs.\n" |
| 32 | + "You can work with SQL databases, CSV files, and generate PDF reports.\n" |
| 33 | + "Common workflows include:\n" |
| 34 | + "- Query data from database then generate PDF report with results\n" |
| 35 | + "- Analyze CSV files and create summary reports\n" |
| 36 | + "- Generate custom reports based on user specifications\n" |
| 37 | + "You should auto-discover available tools via the MCP server connection.\n\n" |
| 38 | + "IMPORTANT: You have access to conversation memory. The system will maintain your\n" |
| 39 | + "conversation history with the user, so you can refer to previous messages and context.\n" |
| 40 | + "Remember what was discussed earlier and maintain continuity in the conversation.\n\n" |
| 41 | + "When working with databases:\n" |
| 42 | + "- First, discover what tables are available in the database\n" |
| 43 | + "- If the user mentions a table that doesn't exist, look for alternatives\n" |
| 44 | + "- Explore the structure of the tables to understand columns\n" |
| 45 | + "- Execute appropriate queries based on what you discovered\n" |
| 46 | + '- To call the SQL tool, use: {"name": "sql", "arguments": {"query": "YOUR SQL QUERY"}}\n\n' |
| 47 | + "When generating PDF reports:\n" |
| 48 | + "- IMPORTANT: When asked to create a PDF report, create it immediately with the information provided\n" |
| 49 | + "- Generate reports even with minimal information - do not ask for clarification\n" |
| 50 | + "- The 'data_json' parameter should be a JSON string with data to include\n" |
| 51 | + "- Always include the generated PDF file path in your response\n" |
| 52 | + '- Example format: {"title": "Report Title", "data": "Your Data"}\n' |
| 53 | + '- To call the PDF tool, use: {"name": "pdf", "arguments": {"data_json": "JSON string here"}}\n\n' |
| 54 | + "When working with CSV files:\n" |
| 55 | + "- If a user has uploaded a CSV file, it will be available in the uploads directory\n" |
| 56 | + "- Use the csv tool to analyze and provide insights about the data\n" |
| 57 | + "- Remember previous analyses of the same file when the user asks follow-up questions\n" |
| 58 | + "- Always consider the context of previous questions about the data\n" |
| 59 | + '- To call the CSV tool, use: {"name": "csv", "arguments": {"file_path": "/path/to/file.csv"}}\n\n' |
| 60 | + "IMPORTANT: Always execute tools by submitting the proper JSON format directly.\n" |
| 61 | + "DO NOT show explanations of what you're going to do - just directly call the tool with the proper JSON format.\n" |
| 62 | + "After receiving tool results, then you can explain and interpret the results to the user.\n" |
| 63 | +) |
| 64 | + |
| 65 | +# Standard model settings for all agents |
| 66 | +# Use the same settings across providers for consistency (following the example) |
| 67 | +model_settings = ModelSettings(temperature=0.7, tool_choice="auto") |
| 68 | + |
| 69 | +# Initialize agent - we'll modify the model and instructions per session |
20 | 70 | agent = Agent( |
21 | 71 | name="NeurArk Data Assistant", |
22 | | - instructions=( |
23 | | - "You are a data assistant that can analyze tabular data and create PDFs.\n" |
24 | | - "You can work with SQL databases, CSV files, and generate PDF reports.\n" |
25 | | - "Common workflows include:\n" |
26 | | - "- Query data from database then generate PDF report with results\n" |
27 | | - "- Analyze CSV files and create summary reports\n" |
28 | | - "- Generate custom reports based on user specifications\n" |
29 | | - "You should auto-discover available tools via the MCP server connection.\n\n" |
30 | | - "When working with databases:\n" |
31 | | - "- First, discover what tables are available in the database\n" |
32 | | - "- If the user mentions a table that doesn't exist, look for alternatives\n" |
33 | | - "- Explore the structure of the tables to understand columns\n" |
34 | | - "- Execute appropriate queries based on what you discovered\n\n" |
35 | | - "When generating PDF reports:\n" |
36 | | - "- The 'data_json' parameter should be a JSON string with data to include\n" |
37 | | - "- Always include the generated PDF file path in your response\n" |
38 | | - "- Example format: {\"title\": \"Report Title\", \"data\": \"Your Data\"}\n" |
39 | | - ), |
40 | | - model="gpt-4.1-mini", |
| 72 | + instructions=BASE_INSTRUCTIONS, |
| 73 | + model="gpt-4.1-mini", # Default model, will be changed based on provider |
| 74 | + model_settings=model_settings, |
41 | 75 | mcp_servers=[mcp_server], |
42 | 76 | ) |
43 | 77 |
|
| 78 | +# Use the function from ollama_integration.py module |
| 79 | +# Just for backward compatibility with existing code |
| 80 | +_check_ollama_available = check_ollama_available |
44 | 81 |
|
45 | | -async def _run_agent(prompt: str) -> str: |
46 | | - """Run the agent asynchronously with proper server connection handling.""" |
47 | | - # Connect to MCP server before running the agent |
48 | | - async with mcp_server: |
49 | | - # Execute the agent with the prompt |
50 | | - result = await Runner.run(starting_agent=agent, input=prompt) |
51 | | - return result.final_output # String with PDF path or response |
52 | 82 |
|
| 83 | +def answer( |
| 84 | + prompt: str, |
| 85 | + provider: str = "openai", |
| 86 | + session_id: Optional[str] = None, |
| 87 | + prev_result: Optional[Any] = None, |
| 88 | +) -> str: |
| 89 | + """ |
| 90 | + Run the agent with the specified provider and session context. |
53 | 91 |
|
54 | | -def answer(prompt: str) -> str: |
55 | | - """Synchronous wrapper for running the agent.""" |
56 | | - if not os.getenv("OPENAI_API_KEY"): |
57 | | - return "⚠️ OPENAI_API_KEY not set." |
| 92 | + Args: |
| 93 | + prompt: The user prompt to send to the agent |
| 94 | + provider: The LLM provider (openai or ollama) |
| 95 | + session_id: Optional session ID for maintaining conversation context |
| 96 | + prev_result: Previous result object from Runner.run, used to maintain conversation history |
58 | 97 |
|
| 98 | + Returns: |
| 99 | + tuple: The agent's response and the result object for future calls |
| 100 | + """ |
59 | 101 | try: |
60 | | - # Run the async function in a synchronous context |
61 | | - return asyncio.run(_run_agent(prompt)) |
| 102 | + # Create a new session if none provided |
| 103 | + if not session_id: |
| 104 | + session_id = session_manager.create_session() |
| 105 | + print(f"Created new session: {session_id}") |
| 106 | + |
| 107 | + # Exit early if Ollama selected but not available |
| 108 | + if provider == "ollama" and not _check_ollama_available(): |
| 109 | + return "⚠️ Ollama not available or not running.", None |
| 110 | + |
| 111 | + # Exit early if OpenAI selected but API key not set |
| 112 | + if provider == "openai" and not os.getenv("OPENAI_API_KEY"): |
| 113 | + return "⚠️ OPENAI_API_KEY not set.", None |
| 114 | + |
| 115 | + try: |
| 116 | + # Update instructions with session context |
| 117 | + if session_id: |
| 118 | + agent.instructions = session_manager.create_system_prompt( |
| 119 | + session_id, BASE_INSTRUCTIONS |
| 120 | + ) |
| 121 | + |
| 122 | + # Configure the model based on provider |
| 123 | + if provider == "ollama": |
| 124 | + # Get the Ollama model |
| 125 | + model_name = get_ollama_model_name() |
| 126 | + print(f"Using Ollama model: {model_name}") |
| 127 | + |
| 128 | + # Set the agent's model to use Ollama |
| 129 | + agent.model = create_ollama_model() |
| 130 | + else: |
| 131 | + # Get the OpenAI model |
| 132 | + model_name = os.getenv("OPENAI_MODEL", "gpt-4.1-mini") |
| 133 | + print(f"Using OpenAI model: {model_name}") |
| 134 | + |
| 135 | + # Set the agent's model to use OpenAI |
| 136 | + agent.model = model_name |
| 137 | + |
| 138 | + except Exception as e: |
| 139 | + print(f"Error setting up provider: {str(e)}") |
| 140 | + return f"⚠️ Error setting up {provider} client: {str(e)}", None |
| 141 | + |
| 142 | + # Prepare input based on whether prev_result exists |
| 143 | + if prev_result: |
| 144 | + # Use the conversation history from the previous result |
| 145 | + print("Using previous result to maintain conversation history") |
| 146 | + # Add the new user message to the previous conversation history |
| 147 | + input_messages = prev_result.to_input_list() + [ |
| 148 | + {"role": "user", "content": prompt} |
| 149 | + ] |
| 150 | + else: |
| 151 | + # First message in conversation |
| 152 | + print("Starting new conversation") |
| 153 | + input_messages = [{"role": "user", "content": prompt}] |
| 154 | + |
| 155 | + # Still store in session for persistence/logging (but won't be used directly) |
| 156 | + session_manager.add_message(session_id, "user", prompt) |
| 157 | + |
| 158 | + print(f"Running agent with prompt: {prompt[:30]}...") |
| 159 | + |
| 160 | + try: |
| 161 | + # Define async function to run the agent |
| 162 | + async def run_agent_async(): |
| 163 | + # Connect to MCP server |
| 164 | + try: |
| 165 | + print("Connecting to MCP server...") |
| 166 | + await mcp_server.connect() |
| 167 | + print("MCP server connected successfully") |
| 168 | + except Exception as e: |
| 169 | + print(f"Warning: MCP server connection issue: {str(e)}") |
| 170 | + |
| 171 | + # Use async context manager for clean connections |
| 172 | + async with mcp_server: |
| 173 | + # Use input_messages from prev_result or new conversation |
| 174 | + print(f"Running with {len(input_messages)} messages in history") |
| 175 | + if len(input_messages) > 0: |
| 176 | + first_role = input_messages[0].get('role', '?') |
| 177 | + last_role = input_messages[-1].get('role', '?') |
| 178 | + print(f"First message: {first_role}, latest: {last_role}") |
| 179 | + |
| 180 | + result = await Runner.run( |
| 181 | + starting_agent=agent, |
| 182 | + input=input_messages, |
| 183 | + max_turns=10, # Prevent infinite loops |
| 184 | + ) |
| 185 | + |
| 186 | + # Ensure we properly close any OpenAI clients if using Ollama |
| 187 | + if provider == "ollama": |
| 188 | + try: |
| 189 | + # Get the OpenAI client from the model and close it |
| 190 | + if hasattr(agent.model, "openai_client"): |
| 191 | + client = agent.model.openai_client |
| 192 | + if hasattr(client, "aclose"): |
| 193 | + await client.aclose() |
| 194 | + except Exception as e: |
| 195 | + print(f"Warning when closing httpx client: {str(e)}") |
| 196 | + |
| 197 | + return result |
| 198 | + |
| 199 | + # Run the agent with better event loop handling |
| 200 | + try: |
| 201 | + try: |
| 202 | + # Vérifier si une boucle est déjà en cours d'exécution |
| 203 | + loop = asyncio.get_running_loop() |
| 204 | + # Si on est déjà dans une boucle asyncio, utiliser create_task |
| 205 | + task = asyncio.run_coroutine_threadsafe(run_agent_async(), loop) |
| 206 | + result = task.result() |
| 207 | + except RuntimeError: |
| 208 | + # Aucune boucle en cours d'exécution, en créer une nouvelle |
| 209 | + result = asyncio.run(run_agent_async()) |
| 210 | + except Exception as e: |
| 211 | + print(f"Error during async execution: {str(e)}") |
| 212 | + # Ensure any pending tasks are cleaned up |
| 213 | + try: |
| 214 | + for task in asyncio.all_tasks(): |
| 215 | + if not task.done(): |
| 216 | + task.cancel() |
| 217 | + except RuntimeError: |
| 218 | + # Handle the case where there's no running event loop |
| 219 | + pass |
| 220 | + raise |
| 221 | + |
| 222 | + # Get the response text |
| 223 | + response = result.final_output |
| 224 | + print( |
| 225 | + f"DEBUG - Raw LLM response from result.final_output: {response[:150]}" |
| 226 | + ) |
| 227 | + |
| 228 | + # Store the assistant response in session history |
| 229 | + session_manager.add_message(session_id, "assistant", response) |
| 230 | + |
| 231 | + # Return both the response and result object |
| 232 | + return response, result |
| 233 | + |
| 234 | + except Exception as e: |
| 235 | + print(f"Error running agent: {str(e)}") |
| 236 | + import traceback |
| 237 | + |
| 238 | + print(traceback.format_exc()) |
| 239 | + |
| 240 | + # Add error message to history |
| 241 | + error_msg = f"Error: {str(e)}" |
| 242 | + session_manager.add_message(session_id, "assistant", error_msg) |
| 243 | + return error_msg, None |
| 244 | + |
62 | 245 | except Exception as e: |
63 | 246 | import traceback |
| 247 | + |
64 | 248 | error_trace = traceback.format_exc() |
65 | 249 | print(f"Agent error: {str(e)}") |
66 | 250 | print(f"Error trace: {error_trace}") |
67 | | - return f"Error: {str(e)}\nTrace: {error_trace}" |
| 251 | + |
| 252 | + # Add error to history if session exists |
| 253 | + if session_id: |
| 254 | + error_response = f"Error: {str(e)}" |
| 255 | + session_manager.add_message(session_id, "assistant", error_response) |
| 256 | + |
| 257 | + return f"Error: {str(e)}\nTrace: {error_trace}", None |
0 commit comments