|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import json |
5 | 6 | from enum import Enum |
6 | | -from typing import Optional, TypeAlias, cast |
| 7 | +from typing import Any, Optional, TypeAlias, cast |
7 | 8 |
|
8 | 9 | from fastapi import HTTPException |
9 | 10 | from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient |
|
16 | 17 | UnexpectedModelBehavior, |
17 | 18 | UsageLimitExceeded, |
18 | 19 | ) |
19 | | -from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart |
| 20 | +from pydantic_ai.messages import ( |
| 21 | + ModelRequest, |
| 22 | + ModelResponse, |
| 23 | + ToolCallPart, |
| 24 | + ToolReturnPart, |
| 25 | +) |
20 | 26 | from pydantic_ai.run import AgentRunResult |
21 | 27 | from pydantic_ai.usage import RunUsage |
22 | 28 |
|
@@ -277,6 +283,83 @@ def build_turn_summary_from_agent_run( |
277 | 283 | return state.turn_summary |
278 | 284 |
|
279 | 285 |
|
def _tool_output_as_text(content: Any) -> str:
    """Render a tool return value as the string stored in a conversation item.

    Strings pass through unchanged. Other values are JSON-encoded; values that
    ``json`` cannot encode fall back to ``str()`` so that persisting a turn
    never fails on an unusual tool result.
    """
    if isinstance(content, str):
        return content
    try:
        return json.dumps(content)
    except (TypeError, ValueError):
        # TypeError: object json cannot encode; ValueError: circular reference.
        return str(content)


async def persist_agent_run_to_conversation(
    client: AsyncLlamaStackClient,
    conversation_id: str,
    user_input: ResponseInput,
    run_result: AgentRunResult[str],
) -> None:
    """Persist a completed pydantic AI agent run to a Llama Stack conversation.

    Since the pydantic AI agent path does not pass ``conversation`` to Llama Stack
    (to avoid duplicate history loading on tool-call continuations), the turn must
    be explicitly stored after the run completes.

    Items are built in this order: the user input first, then one entry per
    message returned by ``run_result.new_messages()``. For a model response,
    every tool call is emitted as a ``function_call`` item, followed by an
    assistant ``message`` item when the response carries text. For a model
    request, only tool returns are emitted (as ``function_call_output``
    items); other request parts are skipped.

    Persistence is best-effort: connection and status errors from Llama Stack
    are logged as warnings and not re-raised.

    Args:
        client: Llama Stack client for conversation persistence.
        conversation_id: Llama Stack conversation ID to store items in.
        user_input: Original user input (string or structured items).
        run_result: Completed pydantic AI agent run result.
    """
    items: list[dict[str, Any]] = []

    if isinstance(user_input, str):
        items.append({"type": "message", "role": "user", "content": user_input})
    else:
        items.extend(item.model_dump() for item in user_input)

    for message in run_result.new_messages():
        if isinstance(message, ModelResponse):
            items.extend(
                {
                    "type": "function_call",
                    "call_id": part.tool_call_id or "",
                    "name": part.tool_name,
                    "arguments": part.args_as_json_str(),
                    "status": "completed",
                }
                for part in message.parts
                if isinstance(part, ToolCallPart)
            )
            if message.text:
                items.append(
                    {
                        "type": "message",
                        "role": "assistant",
                        "content": message.text,
                    }
                )
        elif isinstance(message, ModelRequest):
            items.extend(
                {
                    "type": "function_call_output",
                    "call_id": part.tool_call_id or "",
                    "output": _tool_output_as_text(part.content),
                }
                for part in message.parts
                if isinstance(part, ToolReturnPart)
            )

    # Nothing to store (e.g. empty structured input and no new messages).
    if not items:
        return

    try:
        await client.conversations.items.create(conversation_id, items=items)  # type: ignore[arg-type]
    except (APIConnectionError, APIStatusError) as exc:
        logger.warning(
            "Failed to persist agent turn to conversation %s: %s",
            conversation_id,
            exc,
        )
| 362 | + |
280 | 363 | async def retrieve_agent_response( |
281 | 364 | client: AsyncLlamaStackClient, |
282 | 365 | responses_params: ResponsesApiParams, |
@@ -320,6 +403,14 @@ async def retrieve_agent_response( |
320 | 403 | response = map_agent_inference_error(exc, responses_params.model) |
321 | 404 | raise HTTPException(**response.model_dump()) from exc |
322 | 405 |
|
| 406 | + if not responses_params.omit_conversation: |
| 407 | + await persist_agent_run_to_conversation( |
| 408 | + client, |
| 409 | + responses_params.conversation, |
| 410 | + _original_input or responses_params.input, |
| 411 | + run_result, |
| 412 | + ) |
| 413 | + |
323 | 414 | vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools) |
324 | 415 | rag_id_mapping = configuration.rag_id_mapping |
325 | 416 | return build_turn_summary_from_agent_run( |
|
0 commit comments