Skip to content

Commit 949a8de

Browse files
committed
potential fix
1 parent 9ff72ff commit 949a8de

4 files changed

Lines changed: 111 additions & 7 deletions

File tree

src/utils/agents/query.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from __future__ import annotations
44

5+
import json
56
from enum import Enum
6-
from typing import Optional, TypeAlias, cast
7+
from typing import Any, Optional, TypeAlias, cast
78

89
from fastapi import HTTPException
910
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
@@ -16,7 +17,12 @@
1617
UnexpectedModelBehavior,
1718
UsageLimitExceeded,
1819
)
19-
from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart
20+
from pydantic_ai.messages import (
21+
ModelRequest,
22+
ModelResponse,
23+
ToolCallPart,
24+
ToolReturnPart,
25+
)
2026
from pydantic_ai.run import AgentRunResult
2127
from pydantic_ai.usage import RunUsage
2228

@@ -277,6 +283,83 @@ def build_turn_summary_from_agent_run(
277283
return state.turn_summary
278284

279285

286+
def _tool_output_as_text(content: Any) -> str:
    """Render a tool return value as the string stored in a ``function_call_output``.

    Strings are passed through untouched. Anything else is JSON-encoded; values
    that ``json`` cannot encode natively are stringified instead of raising, so
    that persisting a finished turn never fails on an exotic tool result.
    """
    if isinstance(content, str):
        return content
    try:
        # default=str is deliberate: persistence is best-effort, and a lossy
        # textual form of an unserializable leaf beats an unhandled TypeError.
        return json.dumps(content, default=str)
    except (TypeError, ValueError):
        # Still unencodable (e.g. circular references or non-string dict keys).
        return str(content)


async def persist_agent_run_to_conversation(
    client: AsyncLlamaStackClient,
    conversation_id: str,
    user_input: ResponseInput,
    run_result: AgentRunResult[str],
) -> None:
    """Persist a completed pydantic AI agent run to a Llama Stack conversation.

    Since the pydantic AI agent path does not pass ``conversation`` to Llama Stack
    (to avoid duplicate history loading on tool-call continuations), the turn must
    be explicitly stored after the run completes.

    Conversation items are built in this order: the user input first, then, for
    each new message of the run, either the function calls of a model response
    followed by its assistant text (if any), or the function call outputs of a
    model request.

    Persistence is best-effort: ``APIConnectionError`` and ``APIStatusError``
    raised by the client are logged as warnings and not re-raised.

    Args:
        client: Llama Stack client for conversation persistence.
        conversation_id: Llama Stack conversation ID to store items in.
        user_input: Original user input (string or structured items).
        run_result: Completed pydantic AI agent run result.
    """
    items: list[dict[str, Any]] = []

    if isinstance(user_input, str):
        items.append({"type": "message", "role": "user", "content": user_input})
    else:
        items.extend(item.model_dump() for item in user_input)

    for message in run_result.new_messages():
        if isinstance(message, ModelResponse):
            for part in message.parts:
                if isinstance(part, ToolCallPart):
                    items.append(
                        {
                            "type": "function_call",
                            "call_id": part.tool_call_id or "",
                            "name": part.tool_name,
                            "arguments": part.args_as_json_str(),
                            "status": "completed",
                        }
                    )
            # Assistant text is stored after the tool calls of the same response.
            if message.text:
                items.append(
                    {
                        "type": "message",
                        "role": "assistant",
                        "content": message.text,
                    }
                )
        elif isinstance(message, ModelRequest):
            for part in message.parts:
                if isinstance(part, ToolReturnPart):
                    items.append(
                        {
                            "type": "function_call_output",
                            "call_id": part.tool_call_id or "",
                            "output": _tool_output_as_text(part.content),
                        }
                    )

    if not items:
        return

    try:
        await client.conversations.items.create(conversation_id, items=items)  # type: ignore[arg-type]
    except (APIConnectionError, APIStatusError) as exc:
        logger.warning(
            "Failed to persist agent turn to conversation %s: %s",
            conversation_id,
            exc,
        )
361+
362+
280363
async def retrieve_agent_response(
281364
client: AsyncLlamaStackClient,
282365
responses_params: ResponsesApiParams,
@@ -320,6 +403,14 @@ async def retrieve_agent_response(
320403
response = map_agent_inference_error(exc, responses_params.model)
321404
raise HTTPException(**response.model_dump()) from exc
322405

406+
if not responses_params.omit_conversation:
407+
await persist_agent_run_to_conversation(
408+
client,
409+
responses_params.conversation,
410+
_original_input or responses_params.input,
411+
run_result,
412+
)
413+
323414
vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
324415
rag_id_mapping = configuration.rag_id_mapping
325416
return build_turn_summary_from_agent_run(

src/utils/agents/streaming.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
get_agent_finish_reason,
5050
get_finish_reason_error,
5151
map_agent_inference_error,
52+
persist_agent_run_to_conversation,
5253
)
5354
from utils.agents.tool_processor import (
5455
process_function_tool_call,
@@ -300,6 +301,15 @@ async def agent_response_generator(
300301
return
301302

302303
run_result = dispatch_state.run_result
304+
305+
if not responses_params.omit_conversation:
306+
await persist_agent_run_to_conversation(
307+
context.client,
308+
responses_params.conversation,
309+
responses_params.input,
310+
run_result,
311+
)
312+
303313
turn_summary.token_usage = extract_agent_token_usage(
304314
run_result.usage,
305315
responses_params.model,

src/utils/pydantic_ai.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset(
2121
{
22-
"conversation",
2322
"max_infer_iters",
2423
"tools",
2524
"tool_choice",
@@ -68,10 +67,9 @@ def _model_settings_from_responses_params(
6867
if responses_params.extra_headers:
6968
settings_dict["extra_headers"] = dict(responses_params.extra_headers)
7069
settings_dict["openai_store"] = responses_params.store
71-
if responses_params.previous_response_id is not None:
72-
settings_dict["openai_previous_response_id"] = (
73-
responses_params.previous_response_id
74-
)
70+
settings_dict["openai_previous_response_id"] = (
71+
responses_params.previous_response_id or "auto"
72+
)
7573
return cast(OpenAIResponsesModelSettings, settings_dict)
7674

7775

src/utils/responses.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,10 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
420420
# Normalize Vertex AI model IDs to work around llama-stack 0.6.x bug
421421
normalized_model = normalize_vertex_ai_model_id(model)
422422

423+
previous_response_id = (
424+
user_conversation.last_response_id if user_conversation else None
425+
)
426+
423427
return ResponsesApiParams(
424428
input=input_text,
425429
model=normalized_model,
@@ -429,6 +433,7 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
429433
stream=stream,
430434
store=store,
431435
extra_headers=extra_headers,
436+
previous_response_id=previous_response_id,
432437
)
433438

434439

0 commit comments

Comments
 (0)