|
14 | 14 | """ |
15 | 15 |
|
16 | 16 | import json |
17 | | -from typing import Optional |
| 17 | +from collections.abc import Generator |
| 18 | +from typing import Optional, Union |
18 | 19 |
|
19 | 20 | from pgadmin.llm.client import get_llm_client, is_llm_available |
20 | | -from pgadmin.llm.models import Message, StopReason |
| 21 | +from pgadmin.llm.models import Message, LLMResponse, StopReason |
21 | 22 | from pgadmin.llm.tools import DATABASE_TOOLS, execute_tool, DatabaseToolError |
22 | 23 | from pgadmin.llm.utils import get_max_tool_iterations |
23 | 24 |
|
@@ -153,6 +154,117 @@ def chat_with_database( |
153 | 154 | ) |
154 | 155 |
|
155 | 156 |
|
def chat_with_database_stream(
    user_message: str,
    sid: int,
    did: int,
    conversation_history: Optional[list[Message]] = None,
    system_prompt: Optional[str] = None,
    max_tool_iterations: Optional[int] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None
) -> Generator[
    Union[str, tuple[str, list[Message]], tuple[str, list[str]]],
    None,
    None
]:
    """
    Stream an LLM chat conversation with database tool access.

    Like chat_with_database, but yields text chunks as each LLM
    response streams in, running requested database tools between
    iterations until the model produces a final answer.

    Args:
        user_message: The user's chat message to append to the
            conversation.
        sid: Server id passed through to tool execution.
        did: Database id passed through to tool execution.
        conversation_history: Prior conversation as a list of Message
            objects. The list itself is not mutated; a copy is extended
            and returned in the final tuple.
        system_prompt: System prompt to use; defaults to
            DEFAULT_SYSTEM_PROMPT when None.
        max_tool_iterations: Cap on LLM round-trips; defaults to
            get_max_tool_iterations() when None.
        provider: Optional LLM provider override.
        model: Optional model name override.

    Yields:
        str: Text content chunks as they arrive from the LLM. Note
            that chunks may arrive during ANY iteration, including
            ones that end in tool use — the 'tool_use' marker below
            tells the caller when to reset/discard streamed text.
        tuple ('tool_use', list[str]): Emitted once per tool-use
            iteration, carrying the names of the tools about to be
            executed, so the caller can reset its streaming state and
            show a thinking indicator.
        tuple (final_response_text, updated_conversation_history):
            Always the last item yielded; the generator returns
            immediately after it.

    Raises:
        LLMClientError: If the LLM request fails.
        RuntimeError: If LLM is not available, the stream produces no
            response object, or max tool iterations are exceeded.
    """
    if not is_llm_available():
        raise RuntimeError("LLM is not configured. Please configure an LLM "
                           "provider in Preferences > AI.")

    client = get_llm_client(provider=provider, model=model)
    if not client:
        raise RuntimeError("Failed to create LLM client")

    # Copy so the caller's history list is never mutated.
    messages = list(conversation_history) if conversation_history else []
    messages.append(Message.user(user_message))

    if system_prompt is None:
        system_prompt = DEFAULT_SYSTEM_PROMPT

    if max_tool_iterations is None:
        max_tool_iterations = get_max_tool_iterations()

    iteration = 0
    while iteration < max_tool_iterations:
        iteration += 1

        # Stream the LLM response, yielding text chunks as they arrive.
        # The stream interleaves str chunks with a final LLMResponse
        # object summarizing the completed response.
        response = None
        for item in client.chat_stream(
            messages=messages,
            tools=DATABASE_TOOLS,
            system_prompt=system_prompt
        ):
            if isinstance(item, LLMResponse):
                response = item
            elif isinstance(item, str):
                yield item

        if response is None:
            raise RuntimeError("No response received from LLM")

        messages.append(response.to_message())

        if response.stop_reason != StopReason.TOOL_USE:
            # Final response - yield the completion tuple
            yield (response.content, messages)
            return

        # Signal that tools are being executed so the caller can
        # reset streaming state and show a thinking indicator
        yield ('tool_use', [tc.name for tc in response.tool_calls])

        # Execute each requested tool; errors become tool results
        # rather than aborting the conversation, so the model can
        # recover or explain the failure.
        tool_results = []
        for tool_call in response.tool_calls:
            try:
                result = execute_tool(
                    tool_name=tool_call.name,
                    arguments=tool_call.arguments,
                    sid=sid,
                    did=did
                )
                tool_results.append(Message.tool_result(
                    tool_call_id=tool_call.id,
                    # default=str: tool results may contain Decimal,
                    # datetime, etc. from DB rows.
                    content=json.dumps(result, default=str),
                    is_error=False
                ))
            except (DatabaseToolError, ValueError) as e:
                tool_results.append(Message.tool_result(
                    tool_call_id=tool_call.id,
                    content=json.dumps({"error": str(e)}),
                    is_error=True
                ))
            except Exception as e:
                # Catch-all so one misbehaving tool cannot kill the
                # whole streaming conversation; the error is surfaced
                # to the model as a tool result.
                tool_results.append(Message.tool_result(
                    tool_call_id=tool_call.id,
                    content=json.dumps({
                        "error": f"Unexpected error: {str(e)}"
                    }),
                    is_error=True
                ))

        messages.extend(tool_results)

    raise RuntimeError(
        f"Exceeded maximum tool iterations ({max_tool_iterations})"
    )
| 267 | + |
156 | 268 | def single_query( |
157 | 269 | question: str, |
158 | 270 | sid: int, |
|
0 commit comments