diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index f7fd5f632..5234a56eb 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -41,6 +41,7 @@ from models.common.responses.responses_api_params import ResponsesApiParams from models.common.turn_summary import TurnSummary from models.config import Action +from utils.agents.query import retrieve_agent_response from utils.conversations import append_turn_items_to_conversation from utils.endpoints import ( check_configuration_loaded, @@ -206,8 +207,13 @@ async def query_endpoint_handler( client = await AsyncLlamaStackClientHolder().update_azure_token() # Retrieve response using Responses API - turn_summary = await retrieve_response( - client, responses_params, moderation_result, endpoint_path + turn_summary = await retrieve_agent_response( + client=client, + responses_params=responses_params, + moderation_result=moderation_result, + endpoint_path=endpoint_path, + vector_store_ids=query_request.vector_store_ids or [], + rag_id_mapping=configuration.rag_id_mapping or {}, ) if moderation_result.decision == "passed": diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index c88fb03dd..dbc802f93 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -4,7 +4,6 @@ import asyncio import datetime -import json from collections.abc import AsyncIterator from typing import Annotated, Any, Optional, cast @@ -78,6 +77,10 @@ from models.common.responses.responses_api_params import ResponsesApiParams from models.common.turn_summary import ReferencedDocument, TurnSummary from models.config import Action +from utils.agents.streaming import ( + generate_agent_response, + retrieve_agent_response_generator, +) from utils.conversations import append_turn_items_to_conversation from utils.endpoints import ( check_configuration_loaded, @@ -115,6 +118,11 @@ validate_shield_ids_override, ) from utils.stream_interrupts import get_stream_interrupt_registry +from utils.streaming_sse import ( + format_stream_data, + shield_violation_generator, + stream_event, +) from utils.suid import get_suid, normalize_conversation_id from utils.token_counter import TokenCounter from utils.vector_search import build_rag_context @@ -287,7 +295,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals ) recording.record_llm_call(provider_id, model_id, endpoint_path) - generator, turn_summary = await retrieve_response_generator( + generator, turn_summary = await retrieve_agent_response_generator( responses_params=responses_params, context=context, endpoint_path=endpoint_path, @@ -306,7 +314,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals ) return StreamingResponse( - generate_response( + generate_agent_response( generator=generator, context=context, responses_params=responses_params, @@ -738,6 +746,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat # Store MCP call item info for later lookup when arguments.done event occurs elif event_type == "response.output_item.added": item_added_chunk = cast(OutputItemAddedChunk, chunk) + if item_added_chunk.item.type == "mcp_call": mcp_call_item = cast(MCPCall, item_added_chunk.item) mcp_calls[item_added_chunk.output_index] = ( @@ -923,22 +932,6 @@ def stream_http_error_event( ) -def format_stream_data(d: dict) -> str: - """ - Create a response generator function for Responses API streaming. - - Parameters: - ---------- - d (dict): The data to be formatted as an SSE event. - - Returns: - ------- - str: The formatted SSE data string. - """ - data = json.dumps(d) - return f"data: {data}\n\n" - - def stream_start_event(conversation_id: str, request_id: str) -> str: """Format an SSE start event for a streaming response. @@ -1038,61 +1031,3 @@ def stream_end_event( "available_quotas": available_quotas, } ) - - -def stream_event(data: dict, event_type: str, media_type: str) -> str: - """Build an item to yield based on media type. - - Args: - data: Dictionary containing the event data - event_type: Type of event (token, tool call, etc.) - media_type: The media type for the response format - - Returns: - SSE-formatted string representing the event - """ - if media_type == MEDIA_TYPE_TEXT: - if event_type == LLM_TOKEN_EVENT: - return data.get("token", "") - if event_type == LLM_TOOL_CALL_EVENT: - return f"[Tool Call: {data.get('function_name', 'unknown')}]\n" - if event_type == LLM_TOOL_RESULT_EVENT: - return "[Tool Result]\n" - if event_type == LLM_TURN_COMPLETE_EVENT: - return "" - return "" - - return format_stream_data( - { - "event": event_type, - "data": data, - } - ) - - -async def shield_violation_generator( - violation_message: str, - media_type: str = MEDIA_TYPE_TEXT, -) -> AsyncIterator[str]: - """ - Create an SSE stream for shield violation responses. - - Yields start, token, and end events immediately for shield violations. - This function creates a minimal streaming response without going through - the Llama Stack response format. - - Args: - violation_message: The violation message to display. - media_type: The media type for the response format. - - Yields: - str: SSE-formatted strings for start, token, and end events. - """ - yield stream_event( - { - "id": 0, - "token": violation_message, - }, - LLM_TOKEN_EVENT, - media_type, - ) diff --git a/src/pydantic_ai_lightspeed/llamastack/__init__.py b/src/pydantic_ai_lightspeed/llamastack/__init__.py index 47eda1e7d..fac9ee826 100644 --- a/src/pydantic_ai_lightspeed/llamastack/__init__.py +++ b/src/pydantic_ai_lightspeed/llamastack/__init__.py @@ -1,5 +1,6 @@ """Pydantic AI provider for Llama Stack.""" +from pydantic_ai_lightspeed.llamastack._model import LlamaStackResponsesModel from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider -__all__ = ["LlamaStackProvider"] +__all__ = ["LlamaStackProvider", "LlamaStackResponsesModel"] diff --git a/src/pydantic_ai_lightspeed/llamastack/_model.py b/src/pydantic_ai_lightspeed/llamastack/_model.py new file mode 100644 index 000000000..26b90be20 --- /dev/null +++ b/src/pydantic_ai_lightspeed/llamastack/_model.py @@ -0,0 +1,571 @@ +"""Custom OpenAI Responses model that works around Llama Stack streaming quirks. + +Llama Stack's Responses API emits MCP tool argument events *before* +``ResponseOutputItemAddedEvent`` and uses event type names that differ from the +OpenAI SDK (``response.mcp_call.arguments.*`` vs ``response.mcp_call_arguments.*``). +pydantic_ai expects: + +* ``output_item.added`` first so it can register the MCP call part and seed args + up to ``"tool_args":`` +* ``response.mcp_call_arguments.delta`` fragments for the tool-args JSON body +* ``response.mcp_call_arguments.done`` (pydantic_ai appends only ``}``) + +This module buffers pre-announcement events per ``item_id`` and replays them in +that order once the ``mcp_call`` output item is announced. Post-announcement +``function_call_arguments.*`` events for ``mcp_call`` items are converted to the +MCP argument form pydantic_ai handles. +""" + +from __future__ import annotations as _annotations + +from collections import defaultdict +from collections.abc import AsyncIterator, Iterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import Any, cast + +from openai import AsyncStream +from openai.types import responses +from pydantic_ai import UnexpectedModelBehavior +from pydantic_ai._run_context import RunContext +from pydantic_ai._utils import PeekableAsyncStream, Unset, number_to_datetime +from pydantic_ai.messages import ModelMessage +from pydantic_ai.models import ( + ModelRequestParameters, + StreamedResponse, + check_allow_model_requests, +) +from pydantic_ai.models.openai import ( + OpenAIResponsesModel, + OpenAIResponsesModelSettings, + OpenAIResponsesStreamedResponse, + _map_api_errors, +) +from pydantic_ai.settings import ModelSettings + +from log import get_logger + +logger = get_logger(__name__) + +_LLS_MCP_ARGUMENTS_DELTA_TYPE = "response.mcp_call.arguments.delta" +_LLS_MCP_ARGUMENTS_DONE_TYPE = "response.mcp_call.arguments.done" +_SDK_MCP_ARGUMENTS_DELTA_TYPE = "response.mcp_call_arguments.delta" +_SDK_MCP_ARGUMENTS_DONE_TYPE = "response.mcp_call_arguments.done" + +_MCP_ARGUMENTS_DONE_TYPES = frozenset( + { + _LLS_MCP_ARGUMENTS_DONE_TYPE, + _SDK_MCP_ARGUMENTS_DONE_TYPE, + } +) + + +@dataclass +class _PreAnnouncementArguments: + """Argument fragments buffered before ``output_item.added`` (always MCP for LLS).""" + + delta_fragments: list[str] = field(default_factory=list) + arguments_done: bool = False + done_arguments: str = "{}" + pending_output_done: responses.ResponseOutputItemDoneEvent | None = None + + def has_content(self) -> bool: + """Return whether any argument events are buffered.""" + return bool(self.delta_fragments) or self.arguments_done + + +@dataclass +class _BufferedMcpListToolsEvents: + """Buffered MCP list-tools lifecycle events keyed by item id.""" + + in_progress: responses.ResponseMcpListToolsInProgressEvent | None = None + completed: responses.ResponseMcpListToolsCompletedEvent | None = None + output_done: responses.ResponseOutputItemDoneEvent | None = None + + def has_content(self) -> bool: + """Return whether any list-tools lifecycle events are buffered.""" + return ( + self.in_progress is not None + or self.completed is not None + or self.output_done is not None + ) + + +class _FilteredResponseStream: + """Wraps an OpenAI AsyncStream to reorder and normalize Llama Stack events.""" + + def __init__(self, source: AsyncStream[responses.ResponseStreamEvent]) -> None: + """Wrap an existing stream with reordering logic. + + Args: + source: The raw OpenAI AsyncStream to reorder. + """ + self._source = source + self._released_item_ids: set[str] = set() + self._item_types: dict[str, str] = {} + self._mcp_args_complete: set[str] = set() + self._pre_args_buffers: dict[str, _PreAnnouncementArguments] = defaultdict( + _PreAnnouncementArguments + ) + self._list_tools_buffers: dict[str, _BufferedMcpListToolsEvents] = defaultdict( + _BufferedMcpListToolsEvents + ) + + async def close(self) -> None: + """Close the underlying stream.""" + await self._source.close() + + def __aiter__(self) -> AsyncIterator[responses.ResponseStreamEvent]: + """Return async iterator that reorders events.""" + return self._filtered_iter() + + async def _filtered_iter( + self, + ) -> AsyncIterator[responses.ResponseStreamEvent]: + """Yield events, buffering and normalizing Llama Stack streaming quirks.""" + async for event in self._source: + if isinstance(event, responses.ResponseOutputItemAddedEvent): + async for reordered in self._handle_output_item_added(event): + yield reordered + continue + + if isinstance(event, responses.ResponseOutputItemDoneEvent): + async for reordered in self._handle_output_item_done(event): + yield reordered + continue + + if isinstance(event, responses.ResponseFunctionCallArgumentsDeltaEvent): + async for reordered in self._handle_argument_delta(event): + yield reordered + continue + + if isinstance(event, responses.ResponseFunctionCallArgumentsDoneEvent): + async for reordered in self._handle_argument_done(event): + yield reordered + continue + + if isinstance(event, responses.ResponseMcpCallArgumentsDeltaEvent): + async for reordered in self._handle_mcp_argument_delta(event): + yield reordered + continue + + if getattr(event, "type", None) in _MCP_ARGUMENTS_DONE_TYPES: + async for reordered in self._handle_mcp_argument_done(event): + yield reordered + continue + + event_type = getattr(event, "type", None) + if event_type == _LLS_MCP_ARGUMENTS_DELTA_TYPE: + async for reordered in self._handle_lls_mcp_argument_delta(event): + yield reordered + continue + + if isinstance(event, responses.ResponseMcpListToolsInProgressEvent): + if event.item_id in self._released_item_ids: + yield event + else: + self._list_tools_buffers[event.item_id].in_progress = event + continue + + if isinstance(event, responses.ResponseMcpListToolsCompletedEvent): + if event.item_id in self._released_item_ids: + yield event + else: + self._list_tools_buffers[event.item_id].completed = event + continue + + yield event + + for item_id, buffer in list(self._pre_args_buffers.items()): + if not buffer.has_content(): + continue + logger.warning( + "Flushing buffered argument event(s) without output_item.added " + "for item_id=%s", + item_id, + ) + for flushed in self._replay_mcp_argument_events(item_id, buffer): + yield flushed + if buffer.pending_output_done is not None: + yield buffer.pending_output_done + self._pre_args_buffers.pop(item_id, None) + + for item_id, buffer in list(self._list_tools_buffers.items()): + if not buffer.has_content(): + continue + logger.warning( + "Flushing buffered mcp_list_tools event(s) without output_item.added " + "for item_id=%s", + item_id, + ) + for replayed in self._replay_list_tools_events(buffer): + yield replayed + self._list_tools_buffers.pop(item_id, None) + + async def _handle_output_item_added( + self, event: responses.ResponseOutputItemAddedEvent + ) -> AsyncIterator[responses.ResponseStreamEvent]: + """Yield ``output_item.added`` then buffered follow-up events for the item.""" + item_id = getattr(event.item, "id", None) + item_type = getattr(event.item, "type", None) + if item_id is not None: + self._released_item_ids.add(item_id) + if item_type is not None: + self._item_types[item_id] = item_type + + yield event + + if item_id is None: + return + + if item_type == "mcp_list_tools": + list_tools_buffer = self._list_tools_buffers.pop( + item_id, _BufferedMcpListToolsEvents() + ) + for replayed in self._replay_list_tools_events(list_tools_buffer): + yield replayed + return + + pre_args = self._pre_args_buffers.pop(item_id, _PreAnnouncementArguments()) + if not pre_args.has_content(): + if pre_args.pending_output_done is not None: + yield pre_args.pending_output_done + return + + if item_type == "mcp_call": + logger.debug( + "Replaying buffered MCP argument events after output_item.added " + "for item_id=%s", + item_id, + ) + for replayed in self._replay_mcp_argument_events(item_id, pre_args): + yield replayed + else: + logger.debug( + "Replaying buffered function argument events after output_item.added " + "for item_id=%s", + item_id, + ) + for replayed in self._replay_function_argument_events(item_id, pre_args): + yield replayed + + if pre_args.pending_output_done is not None: + yield pre_args.pending_output_done + + async def _handle_output_item_done( + self, event: responses.ResponseOutputItemDoneEvent + ) -> AsyncIterator[responses.ResponseStreamEvent]: + """Hold MCP call ``output_item.done`` until argument streaming has finished.""" + item_id = getattr(event.item, "id", None) + item_type = getattr(event.item, "type", None) + + if ( + item_type == "mcp_list_tools" + and item_id is not None + and item_id not in self._released_item_ids + ): + self._list_tools_buffers[item_id].output_done = event + return + + if ( + item_type == "mcp_call" + and item_id is not None + and item_id not in self._mcp_args_complete + ): + self._pre_args_buffers[item_id].pending_output_done = event + logger.debug( + "Buffering mcp_call output_item.done until arguments.done for item_id=%s", + item_id, + ) + return + + yield event + + async def _handle_argument_delta( + self, event: responses.ResponseFunctionCallArgumentsDeltaEvent + ) -> AsyncIterator[responses.ResponseStreamEvent]: + """Buffer or convert function argument deltas.""" + if event.item_id not in self._released_item_ids: + self._pre_args_buffers[event.item_id].delta_fragments.append(event.delta) + logger.debug( + "Buffering pre-announcement argument delta for item_id=%s", + event.item_id, + ) + return + + if self._item_types.get(event.item_id) == "mcp_call": + yield self._to_mcp_arguments_delta(event) + return + + yield event + + async def _handle_argument_done( + self, event: responses.ResponseFunctionCallArgumentsDoneEvent + ) -> AsyncIterator[responses.ResponseStreamEvent]: + """Buffer or drop function argument done events for MCP calls.""" + if event.item_id not in self._released_item_ids: + self._pre_args_buffers[event.item_id].arguments_done = True + self._pre_args_buffers[event.item_id].done_arguments = event.arguments + logger.debug( + "Buffering pre-announcement arguments.done for item_id=%s", + event.item_id, + ) + return + + if self._item_types.get(event.item_id) == "mcp_call": + if event.item_id not in self._mcp_args_complete: + yield self._to_mcp_arguments_done(event) + self._mcp_args_complete.add(event.item_id) + pending = self._pre_args_buffers[event.item_id].pending_output_done + if pending is not None: + self._pre_args_buffers[event.item_id].pending_output_done = None + yield pending + return + + yield event + + async def _handle_mcp_argument_delta( + self, event: responses.ResponseMcpCallArgumentsDeltaEvent + ) -> AsyncIterator[responses.ResponseStreamEvent]: + """Buffer or pass through OpenAI-form MCP argument deltas.""" + if event.item_id not in self._released_item_ids: + self._pre_args_buffers[event.item_id].delta_fragments.append(event.delta) + return + yield event + + async def _handle_lls_mcp_argument_delta( + self, event: responses.ResponseStreamEvent + ) -> AsyncIterator[responses.ResponseStreamEvent]: + """Normalize Llama Stack dot-form MCP argument deltas.""" + item_id = cast(str, getattr(event, "item_id", None)) + delta = cast(str, getattr(event, "delta", "")) + if item_id not in self._released_item_ids: + self._pre_args_buffers[item_id].delta_fragments.append(delta) + return + yield self._build_mcp_arguments_delta( + item_id=item_id, + delta=delta, + output_index=getattr(event, "output_index", 0), + sequence_number=getattr(event, "sequence_number", 0), + ) + + async def _handle_mcp_argument_done( + self, event: responses.ResponseStreamEvent + ) -> AsyncIterator[responses.ResponseStreamEvent]: + """Buffer or normalize MCP arguments.done events.""" + mcp_done = self._normalize_mcp_arguments_done(event) + if mcp_done.item_id not in self._released_item_ids: + self._pre_args_buffers[mcp_done.item_id].arguments_done = True + self._pre_args_buffers[mcp_done.item_id].done_arguments = mcp_done.arguments + return + + if mcp_done.item_id not in self._mcp_args_complete: + yield mcp_done + self._mcp_args_complete.add(mcp_done.item_id) + pending = self._pre_args_buffers[mcp_done.item_id].pending_output_done + if pending is not None: + self._pre_args_buffers[mcp_done.item_id].pending_output_done = None + yield pending + + def _replay_list_tools_events( + self, buffer: _BufferedMcpListToolsEvents + ) -> Iterator[responses.ResponseStreamEvent]: + """Replay buffered MCP list-tools lifecycle events in pydantic_ai order.""" + if buffer.in_progress is not None: + yield buffer.in_progress + if buffer.completed is not None: + yield buffer.completed + if buffer.output_done is not None: + yield buffer.output_done + + def _replay_function_argument_events( + self, + item_id: str, + buffer: _PreAnnouncementArguments, + ) -> Iterator[responses.ResponseStreamEvent]: + """Replay buffered argument fragments as function-call events.""" + output_index = 0 + sequence_number = 0 + for fragment in buffer.delta_fragments: + yield responses.ResponseFunctionCallArgumentsDeltaEvent.model_validate( + { + "type": "response.function_call_arguments.delta", + "item_id": item_id, + "output_index": output_index, + "sequence_number": sequence_number, + "delta": fragment, + } + ) + sequence_number += 1 + + if buffer.arguments_done: + yield responses.ResponseFunctionCallArgumentsDoneEvent.model_validate( + { + "type": "response.function_call_arguments.done", + "item_id": item_id, + "output_index": output_index, + "sequence_number": sequence_number, + "arguments": buffer.done_arguments, + "name": "", + } + ) + + def _replay_mcp_argument_events( + self, + item_id: str, + buffer: _PreAnnouncementArguments, + ) -> Iterator[responses.ResponseStreamEvent]: + """Replay buffered MCP argument fragments after ``output_item.added``. + + pydantic_ai seeds args up to ``"tool_args":`` when the item is announced, + then appends delta fragments and closes the JSON object on arguments.done. + """ + if not buffer.delta_fragments and not buffer.arguments_done: + return + + output_index = 0 + sequence_number = 0 + for fragment in buffer.delta_fragments: + yield self._build_mcp_arguments_delta( + item_id=item_id, + delta=fragment, + output_index=output_index, + sequence_number=sequence_number, + ) + sequence_number += 1 + + if buffer.arguments_done: + yield self._build_mcp_arguments_done( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + self._mcp_args_complete.add(item_id) + + def _to_mcp_arguments_delta( + self, event: responses.ResponseFunctionCallArgumentsDeltaEvent + ) -> responses.ResponseMcpCallArgumentsDeltaEvent: + """Convert a function argument delta into the MCP form pydantic_ai expects.""" + return self._build_mcp_arguments_delta( + item_id=event.item_id, + delta=event.delta, + output_index=event.output_index, + sequence_number=event.sequence_number, + ) + + def _to_mcp_arguments_done( + self, event: responses.ResponseFunctionCallArgumentsDoneEvent + ) -> responses.ResponseMcpCallArgumentsDoneEvent: + """Convert a misclassified function arguments.done into MCP form.""" + return self._build_mcp_arguments_done( + item_id=event.item_id, + output_index=event.output_index, + sequence_number=event.sequence_number, + arguments=event.arguments, + ) + + @staticmethod + def _build_mcp_arguments_delta( + *, + item_id: str, + delta: str, + output_index: int, + sequence_number: int, + ) -> responses.ResponseMcpCallArgumentsDeltaEvent: + """Build an OpenAI SDK MCP arguments delta event.""" + return responses.ResponseMcpCallArgumentsDeltaEvent.model_validate( + { + "type": _SDK_MCP_ARGUMENTS_DELTA_TYPE, + "item_id": item_id, + "output_index": output_index, + "sequence_number": sequence_number, + "delta": delta, + } + ) + + @staticmethod + def _build_mcp_arguments_done( + *, + item_id: str, + output_index: int, + sequence_number: int, + arguments: str = "{}", + ) -> responses.ResponseMcpCallArgumentsDoneEvent: + """Build an OpenAI SDK MCP arguments done event.""" + return responses.ResponseMcpCallArgumentsDoneEvent.model_validate( + { + "type": _SDK_MCP_ARGUMENTS_DONE_TYPE, + "item_id": item_id, + "output_index": output_index, + "sequence_number": sequence_number, + "arguments": arguments, + } + ) + + @staticmethod + def _normalize_mcp_arguments_done( + event: responses.ResponseStreamEvent, + ) -> responses.ResponseMcpCallArgumentsDoneEvent: + """Normalize Llama Stack MCP done event types to the OpenAI SDK form.""" + if isinstance(event, responses.ResponseMcpCallArgumentsDoneEvent): + return event + return responses.ResponseMcpCallArgumentsDoneEvent.model_validate( + { + **event.model_dump(exclude={"type"}), + "type": _SDK_MCP_ARGUMENTS_DONE_TYPE, + } + ) + + +class LlamaStackResponsesModel(OpenAIResponsesModel): + """OpenAI Responses model with Llama Stack streaming compatibility fixes.""" + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, + ) -> AsyncIterator[StreamedResponse]: + """Request a streaming response with Llama Stack event normalization.""" + check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) + model_settings_cast = cast(OpenAIResponsesModelSettings, model_settings or {}) + response = await self._responses_create( + messages, True, model_settings_cast, model_request_parameters + ) + + filtered_stream = _FilteredResponseStream(response) + + async with response: + peekable: PeekableAsyncStream[ + responses.ResponseStreamEvent, _FilteredResponseStream + ] = PeekableAsyncStream(filtered_stream) + + with _map_api_errors(self.model_name): + first_chunk = await peekable.peek() + + if isinstance(first_chunk, Unset): + raise UnexpectedModelBehavior( + "Streamed response ended without content or tool calls" + ) + + assert isinstance(first_chunk, responses.ResponseCreatedEvent) + + yield OpenAIResponsesStreamedResponse( + model_request_parameters=model_request_parameters, + _model_name=first_chunk.response.model, + _model_settings=model_settings_cast, + _response=peekable, # type: ignore[arg-type] + _provider_name=self._provider.name, + _provider_url=self._provider.base_url, + _provider_timestamp=( + number_to_datetime(first_chunk.response.created_at) + if first_chunk.response.created_at + else None + ), + ) diff --git a/src/pydantic_ai_lightspeed/llamastack/_transport.py b/src/pydantic_ai_lightspeed/llamastack/_transport.py index 1d63bd60f..e5401bd68 100644 --- a/src/pydantic_ai_lightspeed/llamastack/_transport.py +++ b/src/pydantic_ai_lightspeed/llamastack/_transport.py @@ -17,6 +17,7 @@ ) from llama_stack.core.server.routes import find_matching_route from llama_stack.core.utils.context import preserve_contexts_async_generator +from starlette.responses import StreamingResponse class _AsyncByteStream(httpx.AsyncByteStream): @@ -183,9 +184,16 @@ async def _handle_streaming( result = await func(**merged_body) async def gen() -> AsyncGenerator[bytes, None]: - async for chunk in result: - data = json.dumps(convert_pydantic_to_json_value(chunk)) - yield f"data: {data}\n\n".encode("utf-8") + if isinstance(result, StreamingResponse): + async for chunk in result.body_iterator: + if isinstance(chunk, str): + yield chunk.encode("utf-8") + else: + yield bytes(chunk) + else: + async for chunk in result: + data = json.dumps(convert_pydantic_to_json_value(chunk)) + yield f"data: {data}\n\n".encode("utf-8") wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR]) diff --git a/src/utils/agents/query.py b/src/utils/agents/query.py new file mode 100644 index 000000000..3eefb4b74 --- /dev/null +++ b/src/utils/agents/query.py @@ -0,0 +1,335 @@ +"""Non-streaming agent helpers and shared turn-summary builders for agent runs.""" + +from __future__ import annotations + +from enum import Enum +from typing import TypeAlias + +from fastapi import HTTPException +from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient +from pydantic_ai.exceptions import ( + AgentRunError, + ContentFilterError, + IncompleteToolCall, + ModelAPIError, + ModelHTTPError, + UnexpectedModelBehavior, + UsageLimitExceeded, +) +from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart +from pydantic_ai.run import AgentRunResult +from pydantic_ai.usage import RunUsage + +from log import get_logger +from metrics import recording +from models.api.responses.error import ( + AbstractErrorResponse, + InternalServerErrorResponse, + PromptTooLongResponse, + QuotaExceededResponse, + ServiceUnavailableResponse, +) +from models.common.agents import AgentTurnAccumulator +from models.common.moderation import ShieldModerationResult +from models.common.responses.responses_api_params import ResponsesApiParams +from models.common.turn_summary import TurnSummary +from utils.agents.tool_processor import ( + process_function_tool_call, + process_function_tool_result, + process_native_tool_call, + process_native_tool_result, +) +from utils.conversations import append_turn_items_to_conversation +from utils.pydantic_ai import build_agent +from utils.query import ( + extract_provider_and_model_from_model_id, + handle_known_apistatus_errors, + is_context_length_error, +) +from utils.responses import extract_text_from_response_items +from utils.token_counter import TokenCounter + +logger = get_logger(__name__) + +AgentInferenceError: TypeAlias = ( + AgentRunError | APIStatusError | APIConnectionError | RuntimeError +) + + +class AgentFinishReason(str, Enum): + """Finish reason for a completed agent model response.""" + + CONTENT_FILTER = "content_filter" + CANCELLED = "cancelled" + SUCCESS = "stop" + LENGTH = "length" + ERROR = "error" + + +def map_agent_inference_error( + exc: AgentInferenceError, + model_id: str, +) -> AbstractErrorResponse: + """Map agent run failures from pydantic-ai or Llama Stack to an LCS error response. + + Args: + exc: Agent, HTTP status, connection, or context-length runtime error. + model_id: Model identifier in provider/model format. + + Returns: + Structured error response for HTTP or SSE error events. + + Raises: + RuntimeError: Re-raised when ``exc`` is a non-agent ``RuntimeError`` that is + not a recognized context-length failure. + """ + if isinstance(exc, AgentRunError): + return map_pydantic_agent_run_error(exc, model_id) + if isinstance(exc, APIStatusError): + return handle_known_apistatus_errors(exc, model_id) + if isinstance(exc, APIConnectionError): + return ServiceUnavailableResponse( + backend_name="Llama Stack", + cause=str(exc), + ) + if isinstance(exc, RuntimeError) and is_context_length_error(str(exc)): + return PromptTooLongResponse(model=model_id) + return InternalServerErrorResponse.generic() + + +def map_pydantic_agent_run_error( + exc: AgentRunError, model_id: str +) -> AbstractErrorResponse: + """Map pydantic-ai ``AgentRunError`` subclasses to LCS error responses. + + Args: + exc: Agent exception to map. + model_id: Model identifier in provider/model format. + + Returns: + Structured error response for HTTP or SSE error events. + """ + if isinstance(exc, ContentFilterError): + return InternalServerErrorResponse.query_failed(str(exc)) + if isinstance(exc, IncompleteToolCall): + return PromptTooLongResponse(model=model_id) + if isinstance(exc, UnexpectedModelBehavior): + return PromptTooLongResponse(model=model_id) + if isinstance(exc, UsageLimitExceeded): + return QuotaExceededResponse.model(model_id) + if isinstance(exc, ModelHTTPError): + if is_context_length_error(str(exc)): + return PromptTooLongResponse(model=model_id) + if exc.status_code == 429: + return QuotaExceededResponse.model(model_id) + return InternalServerErrorResponse.generic() + if isinstance(exc, ModelAPIError): + return ServiceUnavailableResponse( + backend_name="Llama Stack", + cause=str(exc), + ) + return InternalServerErrorResponse.query_failed(str(exc)) + + +def get_agent_finish_reason(response: ModelResponse) -> AgentFinishReason: + """Get the finish reason from a completed agent model response. + + Args: + response: Last model response from the agent run. + + Returns: + Resolved finish reason. + """ + raw_finish_reason = (response.provider_details or {}).get("finish_reason") + if raw_finish_reason == "cancelled": + return AgentFinishReason.CANCELLED + if response.finish_reason is None: + return AgentFinishReason.ERROR + return AgentFinishReason(response.finish_reason) + + +def get_finish_reason_error( + finish_reason: AgentFinishReason, + model_id: str, +) -> AbstractErrorResponse: + """Map a non-success agent finish reason to an LCS error response. + + Args: + finish_reason: Resolved finish reason from :func:`get_agent_finish_reason`. + model_id: Model identifier in provider/model format. + + Returns: + Structured error response for HTTP or SSE error events. + """ + match finish_reason: + case AgentFinishReason.LENGTH: + return PromptTooLongResponse(model=model_id) + case AgentFinishReason.CONTENT_FILTER: + return InternalServerErrorResponse.query_failed( + "The model refused to generate a response due to content policy." + ) + case AgentFinishReason.CANCELLED: + return InternalServerErrorResponse.query_failed( + "The response was cancelled before completion." + ) + case _: + return InternalServerErrorResponse.query_failed( + "An unexpected error occurred while processing the request." + ) + + +def extract_agent_token_usage( + usage: RunUsage, + model: str, + endpoint_path: str, +) -> TokenCounter: + """Build token usage for a completed agent run and record related metrics. + + Args: + usage: Run usage reported by the agent. + model: Model identifier in provider/model format. + endpoint_path: Endpoint path used for metric labeling. + + Returns: + Aggregated token usage counter for the run. + """ + provider_id, model_id = extract_provider_and_model_from_model_id(model) + token_counter = TokenCounter( + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + llm_calls=max(usage.requests, 1), + ) + logger.debug( + "Extracted token usage from agent run: input=%d, output=%d, requests=%d", + token_counter.input_tokens, + token_counter.output_tokens, + usage.requests, + ) + recording.record_llm_token_usage( + provider_id, + model_id, + token_counter.input_tokens, + token_counter.output_tokens, + endpoint_path, + ) + recording.record_llm_call(provider_id, model_id, endpoint_path) + return token_counter + + +def build_turn_summary_from_agent_run( + run_result: AgentRunResult[str], + *, + model_id: str, + endpoint_path: str, + vector_store_ids: list[str] | None = None, + rag_id_mapping: dict[str, str] | None = None, +) -> TurnSummary: + """Build a turn summary from a completed agent run. + + Args: + run_result: Completed agent run result. + model_id: Model identifier in provider/model format. + endpoint_path: Endpoint path used for metric labeling. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + Turn summary with text, tools, RAG metadata, and token usage. + + Raises: + HTTPException: When the run failed. + """ + finish_reason = get_agent_finish_reason(run_result.response) + if finish_reason != AgentFinishReason.SUCCESS: + error_response = get_finish_reason_error(finish_reason, model_id) + raise HTTPException(**error_response.model_dump()) + + state = AgentTurnAccumulator( + vector_store_ids=vector_store_ids or [], + rag_id_mapping=rag_id_mapping or {}, + turn_summary=TurnSummary(), + ) + + for message in run_result.new_messages(): + if isinstance(message, ModelResponse): + if message.text: + state.turn_summary.llm_response = message.text + for tool_call_part in message.tool_calls: + process_function_tool_call(state, tool_call_part) + for call_part, return_part in message.native_tool_calls: + logger.error(f"Native tool call: {repr(call_part)}") + logger.error(f"Native tool return: {repr(return_part)}") + process_native_tool_call(state, call_part) + process_native_tool_result(state, return_part) + elif isinstance(message, ModelRequest): + for request_part in message.parts: + if isinstance(request_part, ToolReturnPart): + process_function_tool_result(state, request_part) + + state.turn_summary.id = run_result.response.provider_response_id or "" + state.turn_summary.token_usage = extract_agent_token_usage( + run_result.usage, + model_id, + endpoint_path, + ) + return state.turn_summary + + +async def retrieve_agent_response( + client: AsyncLlamaStackClient, + responses_params: ResponsesApiParams, + moderation_result: ShieldModerationResult, + endpoint_path: str, + vector_store_ids: list[str], + rag_id_mapping: dict[str, str], +) -> TurnSummary: + """Retrieve a turn summary from a blocking agent run. + + Mirrors :func:`app.endpoints.query.retrieve_response` for the agent path. + + Args: + client: Llama Stack client for conversation persistence on moderation block. + responses_params: Prepared Responses API parameters. + moderation_result: Shield moderation outcome for the turn. + endpoint_path: Endpoint path used for metric labeling. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + Turn summary for the completed agent run. + + Raises: + HTTPException: On moderation is not applicable; on agent or provider failure. + """ + if moderation_result.decision == "blocked": + await append_turn_items_to_conversation( + client, + responses_params.conversation, + responses_params.input, + [moderation_result.refusal_response], + ) + return TurnSummary( + id=moderation_result.moderation_id, + llm_response=moderation_result.message, + ) + prompt = ( + responses_params.input + if isinstance(responses_params.input, str) + else extract_text_from_response_items(responses_params.input) + ) + + try: + agent = build_agent(client, responses_params) + logger.debug("Starting agent non-streaming response processing") + run_result = await agent.run(prompt) + except (AgentRunError, APIStatusError, APIConnectionError, RuntimeError) as exc: + response = map_agent_inference_error(exc, responses_params.model) + raise HTTPException(**response.model_dump()) from exc + + return build_turn_summary_from_agent_run( + run_result, + model_id=responses_params.model, + endpoint_path=endpoint_path, + vector_store_ids=vector_store_ids, + rag_id_mapping=rag_id_mapping, + ) diff --git a/src/utils/agents/streaming.py b/src/utils/agents/streaming.py new file mode 100644 index 000000000..5b411db1f --- /dev/null +++ b/src/utils/agents/streaming.py @@ -0,0 +1,482 @@ +"""Agent streaming helpers for the streaming_query flow.""" + +from __future__ import annotations + +import datetime +from collections.abc import AsyncIterator +from functools import singledispatch +from typing import Any, Final, Optional, cast + +from fastapi import HTTPException +from llama_stack_client import APIConnectionError, APIStatusError +from pydantic_ai import Agent, AgentRunError, AgentRunResultEvent, ToolReturnPart +from pydantic_ai.messages import ( + AgentStreamEvent, + FunctionToolCallEvent, + FunctionToolResultEvent, + NativeToolCallPart, + NativeToolReturnPart, + PartDeltaEvent, + PartEndEvent, + PartStartEvent, + TextPart, + TextPartDelta, +) + +from configuration import configuration +from constants import MEDIA_TYPE_JSON +from log import get_logger +from models.common.agents import ( + AgentTurnAccumulator, + EndStreamPayload, + ErrorStreamPayload, + StartStreamPayload, + StreamEventPayload, + TokenStreamPayload, + ToolCallStreamPayload, + ToolResultStreamPayload, + TurnCompleteStreamPayload, +) +from models.common.responses.contexts import ResponseGeneratorContext +from models.common.responses.responses_api_params import ResponsesApiParams +from models.common.turn_summary import TurnSummary +from utils.agents.query import ( + AgentFinishReason, + extract_agent_token_usage, + get_agent_finish_reason, + get_finish_reason_error, + map_agent_inference_error, +) +from utils.agents.tool_processor import ( + process_function_tool_call, + process_function_tool_result, + process_native_tool_call, + process_native_tool_result, +) +from utils.conversations import append_turn_items_to_conversation +from utils.pydantic_ai import build_agent +from utils.query import consume_query_tokens, store_query_results +from utils.quota import get_available_quotas +from utils.responses import ( + deduplicate_referenced_documents, + maybe_get_topic_summary, +) +from utils.streaming_sse import shield_violation_generator + +logger = get_logger(__name__) + +DEFAULT_REFUSAL_RESPONSE: Final[str] = ( + "I cannot process this request due to policy restrictions." +) + + +async def retrieve_agent_response_generator( + responses_params: ResponsesApiParams, + context: ResponseGeneratorContext, + endpoint_path: str, +) -> tuple[AsyncIterator[str], TurnSummary]: + """Return the SSE generator and mutable turn summary for an agent run. + + Args: + responses_params: Prepared Responses API parameters. + context: Streaming request context and moderation result. + endpoint_path: Endpoint path used for metric labeling. + + Returns: + Tuple of SSE async iterator and mutable turn summary. + """ + turn_summary = TurnSummary() + try: + if context.moderation_result.decision == "blocked": + turn_summary.llm_response = context.moderation_result.message + turn_summary.id = context.moderation_result.moderation_id + await append_turn_items_to_conversation( + context.client, + responses_params.conversation, + responses_params.input, + [context.moderation_result.refusal_response], + ) + media_type = context.query_request.media_type or MEDIA_TYPE_JSON + return ( + shield_violation_generator( + context.moderation_result.message, + media_type, + ), + turn_summary, + ) + + agent = build_agent(context.client, responses_params) + + return ( + agent_response_generator( + agent, + responses_params, + context, + turn_summary, + endpoint_path, + ), + turn_summary, + ) + except (AgentRunError, APIStatusError, APIConnectionError, RuntimeError) as exc: + response = map_agent_inference_error(exc, responses_params.model) + raise HTTPException(**response.model_dump()) from exc + + +async def generate_agent_response( + generator: AsyncIterator[str], + context: ResponseGeneratorContext, + responses_params: ResponsesApiParams, + turn_summary: TurnSummary, +) -> AsyncIterator[str]: + """Wrap an agent SSE generator with cleanup logic. + + Re-yields events from the generator, handles errors, and ensures + persistence and token consumption after completion. + + Args: + generator: The base agent SSE generator to wrap. + context: The response generator context. + responses_params: The Responses API parameters. + turn_summary: TurnSummary populated during streaming. + + Yields: + SSE-formatted strings from the wrapped generator. + """ + media_type = context.query_request.media_type or MEDIA_TYPE_JSON + prompt = cast(str, responses_params.input) + stream_completed = False + yield serialize_event( + StartStreamPayload.create( + conversation_id=context.conversation_id, + request_id=context.request_id, + ), + media_type, + ) + try: + async for event in generator: + yield event + + stream_completed = True + + except (AgentRunError, APIStatusError, APIConnectionError, RuntimeError) as exc: + error_response = map_agent_inference_error(exc, responses_params.model) + yield serialize_event( + ErrorStreamPayload.from_error_response(error_response), + media_type, + ) + + if not stream_completed: + return + + should_generate_topic_summary = ( + context.query_request.conversation_id is None + and bool(context.query_request.generate_topic_summary) + ) + topic_summary = await maybe_get_topic_summary( + generate_topic_summary=should_generate_topic_summary, + input_text=prompt, + client=context.client, + model_id=responses_params.model, + ) + logger.info("Consuming tokens") + consume_query_tokens( + user_id=context.user_id, + model_id=responses_params.model, + token_usage=turn_summary.token_usage, + ) + logger.info("Getting available quotas") + available_quotas = get_available_quotas( + quota_limiters=configuration.quota_limiters, + user_id=context.user_id, + ) + end_payload = EndStreamPayload.create( + referenced_documents=turn_summary.referenced_documents, + input_tokens=turn_summary.token_usage.input_tokens, + output_tokens=turn_summary.token_usage.output_tokens, + available_quotas=available_quotas, + ) + yield serialize_event(end_payload, media_type) + + completed_at = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%dT%H:%M:%SZ") + logger.info("Storing query results") + store_query_results( + user_id=context.user_id, + conversation_id=context.conversation_id, + model=responses_params.model, + completed_at=completed_at, + started_at=context.started_at, + summary=turn_summary, + query=prompt, + skip_userid_check=context.skip_userid_check, + topic_summary=topic_summary, + ) + logger.info("Agent streaming complete") + + +async def agent_response_generator( + agent: Agent[Any, str], + responses_params: ResponsesApiParams, + context: ResponseGeneratorContext, + turn_summary: TurnSummary, + endpoint_path: str, +) -> AsyncIterator[str]: + """Stream SSE events from an agent run and update the turn summary. + + Args: + agent: Agent to execute. + responses_params: Prepared Responses API parameters. + context: Streaming request context. + turn_summary: Mutable summary to fill while streaming. + endpoint_path: Endpoint path used for metric labeling. + + Yields: + Serialized SSE event strings. + """ + media_type = context.query_request.media_type or MEDIA_TYPE_JSON + dispatch_state = AgentTurnAccumulator( + vector_store_ids=context.vector_store_ids, + rag_id_mapping=context.rag_id_mapping, + turn_summary=turn_summary, + ) + prompt = cast(str, responses_params.input) # query is always a string + + logger.debug("Starting agent streaming response processing") + async with agent.run_stream_events(prompt) as stream: + async for event in stream: + logger.error(f"event: {event.event_kind} {repr(event)}") + if payload := dispatch_stream_event(event, dispatch_state): + # print(f"payload: {payload.serialize_json()}") + yield serialize_event(payload, media_type) + + if dispatch_state.run_result is None: + logger.error("No final result received from agent run") + return + + run_result = dispatch_state.run_result + turn_summary.token_usage = extract_agent_token_usage( + run_result.usage, + responses_params.model, + endpoint_path, + ) + + finish_reason = get_agent_finish_reason(run_result.response) + if finish_reason != AgentFinishReason.SUCCESS: + error_response = get_finish_reason_error(finish_reason, responses_params.model) + yield serialize_event( + ErrorStreamPayload.from_error_response(error_response), + media_type, + ) + + turn_summary.referenced_documents = deduplicate_referenced_documents( + context.inline_rag_context.referenced_documents + + turn_summary.referenced_documents + ) + turn_summary.rag_chunks = ( + context.inline_rag_context.rag_chunks + turn_summary.rag_chunks + ) + + +def serialize_event( + payload: StreamEventPayload, + media_type: str = MEDIA_TYPE_JSON, +) -> str: + """Serialize an LLM stream payload (token, tool, turn complete) for the client.""" + if media_type == MEDIA_TYPE_JSON: + return payload.serialize_json() + return payload.serialize_text() + + +def _process_token( + state: AgentTurnAccumulator, + text: str, +) -> StreamEventPayload: + """Append text to state and build a token stream payload. + + Args: + state: Mutable dispatch reducer state. + text: Token text to append and emit. + + Returns: + Token stream payload containing the emitted token chunk. + """ + state.text_parts.append(text) + payload = TokenStreamPayload.create( + chunk_id=state.chunk_id, + token=text, + ) + state.chunk_id += 1 + return payload + + +@singledispatch +def dispatch_stream_event( + event: AgentStreamEvent | AgentRunResultEvent, + _state: AgentTurnAccumulator, +) -> Optional[StreamEventPayload]: + """Map a pydantic-ai stream event to an SSE payload. + + Args: + event: Agent stream event emitted by the runtime. + _state: Mutable accumulator for stream processing. + + Returns: + None when the event does not map to an SSE payload. + """ + logger.debug("Ignoring event kind=%s", event.event_kind) + return None + + +@dispatch_stream_event.register +def _( + event: AgentRunResultEvent, + state: AgentTurnAccumulator, +) -> Optional[StreamEventPayload]: + """Handle final run result event and emit completion payload. + + Args: + event: Final run result event. + state: Mutable accumulator for stream processing. + + Returns: + Completion stream payload. + """ + state.run_result = event.result + state.turn_summary.id = state.run_result.response.provider_response_id or "" + if state.run_result.response.finish_reason == "content_filter": + provider_details = state.run_result.response.provider_details or {} + final_text = ( + provider_details.get("refusal_response") or DEFAULT_REFUSAL_RESPONSE + ) + else: + final_text = state.run_result.response.text or "".join(state.text_parts) + + payload = TurnCompleteStreamPayload.create( + chunk_id=state.chunk_id, + token=final_text, + ) + state.chunk_id += 1 + return payload + + +@dispatch_stream_event.register +def _( + event: PartStartEvent, + state: AgentTurnAccumulator, +) -> Optional[StreamEventPayload]: + """Handle start of a model response part. + + Args: + event: Part start event. + state: Mutable accumulator for stream processing. + + Returns: + Optional stream payload emitted at part start. + """ + part = event.part + if isinstance(part, TextPart): + state.increment_round_if_pending() + return _process_token(state, part.content) + + if isinstance(part, NativeToolReturnPart): + if tool_result := process_native_tool_result(state, part): + # print(f"Tool result summarized: {tool_result.model_dump_json()}") + return ToolResultStreamPayload(data=tool_result) + return None + + logger.debug("Ignoring part start kind=%s", part.part_kind) + return None + + +@dispatch_stream_event.register +def _( + event: PartDeltaEvent, + state: AgentTurnAccumulator, +) -> Optional[StreamEventPayload]: + """Handle incremental updates to a model response part. + + Args: + event: Part delta event. + state: Mutable accumulator for stream processing. + + Returns: + Optional stream payload emitted for the delta. + """ + delta = event.delta + if isinstance(delta, TextPartDelta): + return _process_token(state, delta.content_delta) + + logger.debug("Ignoring part delta kind=%s", delta.part_delta_kind) + return None + + +@dispatch_stream_event.register +def _( + event: PartEndEvent, + state: AgentTurnAccumulator, +) -> Optional[StreamEventPayload]: + """Handle completion of a model response part. + + Args: + event: Part end event. + state: Mutable accumulator for stream processing. + + Returns: + Optional stream payload emitted at part end. + """ + part = event.part + if isinstance(part, TextPart): + state.turn_summary.llm_response += ( + part.content or "".join(state.text_parts) + "\n\n" + ) + state.text_parts.clear() + return None + + if isinstance(part, NativeToolCallPart): + if summary := process_native_tool_call(state, part): + return ToolCallStreamPayload(data=summary) + return None + + logger.debug("Ignoring part end kind=%s", part.part_kind) + return None + + +@dispatch_stream_event.register +def _( + event: FunctionToolCallEvent, + state: AgentTurnAccumulator, +) -> Optional[StreamEventPayload]: + """Handle function tool call event. + + Args: + event: Function tool call event. + state: Mutable accumulator for stream processing. + + Returns: + Tool call stream payload or None. + """ + if summary := process_function_tool_call(state, event.part): + return ToolCallStreamPayload(data=summary) + return None + + +@dispatch_stream_event.register +def _( + event: FunctionToolResultEvent, + state: AgentTurnAccumulator, +) -> Optional[StreamEventPayload]: + """Handle function tool result event. + + Args: + event: Function tool result event. + state: Mutable accumulator for stream processing. + + Returns: + Tool result stream payload or None. + """ + part = event.part + if not isinstance(part, ToolReturnPart): + return None + + if result := process_function_tool_result(state, part): + return ToolResultStreamPayload(data=result) + return None diff --git a/src/utils/agents/tool_processor.py b/src/utils/agents/tool_processor.py new file mode 100644 index 000000000..4f6d283c8 --- /dev/null +++ b/src/utils/agents/tool_processor.py @@ -0,0 +1,530 @@ +"""Process and record pydantic-ai tool parts during agent stream dispatch.""" + +from __future__ import annotations + +import json +from typing import Any, Optional, cast + +from openai.types.responses.response_file_search_tool_call import ( + Result as OpenAIFileSearchResult, +) +from pydantic import AnyUrl +from pydantic_ai.messages import ( + NativeToolCallPart, + NativeToolReturnPart, + ToolCallPart, + ToolReturnPart, +) +from pydantic_ai.native_tools import FileSearchTool, MCPServerTool, WebSearchTool + +from constants import DEFAULT_RAG_TOOL +from log import get_logger +from models.common.agents import AgentTurnAccumulator +from models.common.turn_summary import ( + MCPListToolsSummary, + RAGChunk, + ReferencedDocument, + ToolCallSummary, + ToolInfoSummary, + ToolResultSummary, +) +from utils.responses import resolve_source_for_result + +logger = get_logger(__name__) + +_FILE_SEARCH_URL_KEYS = ("doc_url", "docs_url", "url", "link", "reference_url") +_MCP_SERVER_TOOL_PREFIX = f"{MCPServerTool.kind}:" + + +def summarize_function_tool_call(part: ToolCallPart) -> ToolCallSummary: + """Build a tool-call summary for a client function tool call. + + Args: + part: Function tool call part emitted by the agent. + + Returns: + Tool call summary in LCS turn-summary format. + """ + return ToolCallSummary( + id=part.tool_call_id, + name=part.tool_name, + args=part.args_as_dict(), + type="function_call", + ) + + +def summarize_native_tool_call( + part: NativeToolCallPart, +) -> Optional[ToolCallSummary]: + """Build a tool-call summary for a native agent tool call. + + Args: + part: Native tool call part emitted by the model. + + Returns: + Tool call summary in LCS turn-summary format. + """ + call_id = part.tool_call_id + args = part.args_as_dict() + match part.tool_name: + case WebSearchTool.kind: + return ToolCallSummary( + id=call_id, + name=part.tool_name, + args=args, + type="web_search_call", + ) + case FileSearchTool.kind: + return ToolCallSummary( + id=call_id, + name=DEFAULT_RAG_TOOL, + args=args, + type="file_search_call", + ) + case tool_name if tool_name.startswith(_MCP_SERVER_TOOL_PREFIX): + label = tool_name.removeprefix(_MCP_SERVER_TOOL_PREFIX) + action = args.get("action") + # MCP list tools + if action == "list_tools": + return ToolCallSummary( + id=call_id, + name="mcp_list_tools", + args={"server_label": label}, + type="mcp_list_tools", + ) + + # MCP call + return ToolCallSummary( + id=call_id, + name=args.get("tool_name") or "", + args=args.get("tool_args", {}), + type="mcp_call", + ) + case _: + logger.warning(f"Unknown tool name: {part.tool_name}") + return None + + +def process_function_tool_call( + state: AgentTurnAccumulator, + part: ToolCallPart, +) -> Optional[ToolCallSummary]: + """Record a client function tool call on dispatch state. + + Args: + state: Mutable dispatch reducer state. + part: Function tool call part from the agent. + + Returns: + Tool call summary when recorded, otherwise None if already emitted. + """ + if part.tool_call_id in state.emitted_tool_call_ids: + return None + summary = summarize_function_tool_call(part) + state.increment_round_if_pending() + state.emitted_tool_call_ids.add(summary.id) + state.turn_summary.tool_calls.append(summary) + return summary + + +def process_native_tool_call( + state: AgentTurnAccumulator, + part: NativeToolCallPart, +) -> Optional[ToolCallSummary]: + """Record a native tool call on dispatch state. + + Args: + state: Mutable dispatch reducer state. + part: Native tool call part from the model. + + Returns: + Tool call summary when recorded, otherwise None if already emitted. + """ + if part.tool_call_id in state.emitted_tool_call_ids: + return None + if summary := summarize_native_tool_call(part): + state.increment_round_if_pending() + state.emitted_tool_call_ids.add(summary.id) + state.turn_summary.tool_calls.append(summary) + return summary + logger.warning("Tool call not summarized: %s", part.tool_call_id) + return None + + +def process_native_tool_result( + state: AgentTurnAccumulator, + part: NativeToolReturnPart, +) -> Optional[ToolResultSummary]: + """Record a native tool return on dispatch state. + + Args: + state: Mutable dispatch reducer state. + part: Native tool return part from the model. + + Returns: + Tool result summary when recorded, otherwise None if already emitted. + """ + if part.tool_call_id in state.emitted_tool_result_ids: + return None + + match part.tool_name: + case FileSearchTool.kind: + tool_result, rag_chunks, referenced_documents = ( + summarize_file_search_result( + part, + state.tool_round, + state.seen_docs, + state.vector_store_ids, + state.rag_id_mapping, + ) + ) + state.turn_summary.rag_chunks.extend(rag_chunks) + state.turn_summary.referenced_documents.extend(referenced_documents) + case WebSearchTool.kind: + tool_result = summarize_web_search_result(part, state.tool_round) + case tool_name if tool_name.startswith(_MCP_SERVER_TOOL_PREFIX): + tool_result = summarize_mcp_tool_result(part, state.tool_round) + case _: + logger.warning(f"Unknown tool name: {part.tool_name}") + return None + + state.emitted_tool_result_ids.add(tool_result.id) + state.turn_summary.tool_results.append(tool_result) + state.round_increment_pending = True + return tool_result + + +def process_function_tool_result( + state: AgentTurnAccumulator, + part: ToolReturnPart, +) -> Optional[ToolResultSummary]: + """Record a client function tool return on dispatch state. + + Args: + state: Mutable dispatch reducer state. + part: Function tool return part from the agent. + + Returns: + Tool result summary when recorded, otherwise None if already emitted. + """ + if part.tool_call_id in state.emitted_tool_result_ids: + return None + tool_result = summarize_function_tool_result(part, state.tool_round) + state.emitted_tool_result_ids.add(tool_result.id) + state.turn_summary.tool_results.append(tool_result) + state.round_increment_pending = True + return tool_result + + +def summarize_function_tool_result( + part: ToolReturnPart, + tool_round: int, +) -> ToolResultSummary: + """Build a tool-result summary for a client function tool return. + + Args: + part: Function tool return part emitted by the agent. + tool_round: Tool execution round number for this result. + + Returns: + Tool result summary in LCS turn-summary format. + """ + return ToolResultSummary( + id=part.tool_call_id, + status="success", + content=part.model_response_str(), + type="function_call_output", + round=tool_round, + ) + + +def referenced_documents_from_file_search_results( + results: list[OpenAIFileSearchResult], + seen_docs: set[tuple[str, str]], + vector_store_ids: list[str], + rag_id_mapping: dict[str, str], +) -> list[ReferencedDocument]: + """Parse referenced documents from OpenAI file-search result rows. + + Args: + results: Validated file-search result rows. + seen_docs: Dedupe keys already emitted; updated in place. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + Newly discovered referenced documents from these result rows. + """ + documents: list[ReferencedDocument] = [] + for result in results: + doc = build_referenced_document(result, vector_store_ids, rag_id_mapping) + if doc is None: + continue + + dedup_key = (str(doc.doc_url or ""), doc.doc_title or "") + if dedup_key in seen_docs: + continue + + seen_docs.add(dedup_key) + documents.append(doc) + + return documents + + +def build_referenced_document( + result: OpenAIFileSearchResult, + vector_store_ids: list[str], + rag_id_mapping: dict[str, str], +) -> Optional[ReferencedDocument]: + """Build one referenced document from a single file-search result row. + + Args: + result: OpenAI file-search result row. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + Referenced document when metadata is present, otherwise None. + """ + attributes = result.attributes or {} + + doc_url = _file_search_attribute_url(attributes) + doc_title = _file_search_attribute_str(attributes, "title") + if not (doc_title or doc_url): + return None + + doc_id = _file_search_attribute_str( + attributes, "document_id" + ) or _file_search_attribute_str(attributes, "doc_id") + return ReferencedDocument( + doc_url=AnyUrl(doc_url) if doc_url else None, + doc_title=doc_title, + source=resolve_source_for_result(attributes, vector_store_ids, rag_id_mapping), + document_id=doc_id, + ) + + +def _file_search_attribute_str( + attributes: dict[str, str | float | bool], + key: str, +) -> Optional[str]: + """Read a non-empty string metadata field from file-search attributes. + + Args: + attributes: File-search result metadata attributes. + key: Metadata key to read. + + Returns: + Non-empty string value for the key, or None. + """ + return str(value) if (value := attributes.get(key)) else None + + +def _file_search_attribute_url( + attributes: dict[str, str | float | bool], +) -> Optional[str]: + """Extract the first available document URL from file-search attributes. + + Args: + attributes: File-search result metadata attributes. + + Returns: + First matching URL value as a string, or None. + """ + for key in _FILE_SEARCH_URL_KEYS: + if url := _file_search_attribute_str(attributes, key): + return url + return None + + +def rag_chunks_from_file_search_results( + results: list[OpenAIFileSearchResult], + vector_store_ids: list[str], + rag_id_mapping: dict[str, str], +) -> list[RAGChunk]: + """Extract RAG chunks from OpenAI file-search result rows. + + Args: + results: Validated file-search result rows. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + RAG chunks extracted from these result rows. + """ + return [ + RAGChunk( + content=result.text, + source=resolve_source_for_result( + result.attributes or {}, vector_store_ids, rag_id_mapping + ), + score=result.score, + attributes=result.attributes or None, + ) + for result in results + if result.text + ] + + +def summarize_web_search_result( + part: NativeToolReturnPart, + tool_round: int, +) -> ToolResultSummary: + """Build a tool-result summary from a native web-search return. + + Args: + part: Native web-search tool return part from the model stream. + tool_round: Tool execution round number for this result. + + Returns: + Tool result summary in LCS turn-summary format. + """ + content = cast(dict[str, Any], part.content) + status = str(content.pop("status")) + return ToolResultSummary( + id=part.tool_call_id, + status=status, + content=json.dumps(content) if content else "", + type="web_search_call", + round=tool_round, + ) + + +def summarize_mcp_list_tools_result( + part: NativeToolReturnPart, + tool_round: int, +) -> ToolResultSummary: + """Build a tool-result summary from a native MCP list-tools return. + + Args: + part: Native MCP list-tools return part from the model stream. + tool_round: Tool execution round number for this result. + + Returns: + Tool result summary in LCS turn-summary format. + """ + content = cast(dict[str, Any], part.content) + call_id = part.tool_call_id + label = part.tool_name.removeprefix(f"{MCPServerTool.kind}:") + + if error := content.get("error"): + return ToolResultSummary( + id=call_id, + status="failure", + content=str(error), + type="mcp_list_tools", + round=tool_round, + ) + + list_summary = MCPListToolsSummary( + server_label=label, + tools=[ToolInfoSummary.model_validate(tool) for tool in content["tools"]], + ) + return ToolResultSummary( + id=call_id, + status="success", + content=json.dumps(list_summary.model_dump()), + type="mcp_list_tools", + round=tool_round, + ) + + +def summarize_mcp_call_result( + part: NativeToolReturnPart, + tool_round: int, +) -> ToolResultSummary: + """Build a tool-result summary from a native MCP tool call return. + + Args: + part: Native MCP call return part from the model stream. + tool_round: Tool execution round number for this result. + + Returns: + Tool result summary in LCS turn-summary format. + """ + content = cast(dict[str, Any], part.content) + call_id = part.tool_call_id + + if error := content.get("error"): + return ToolResultSummary( + id=call_id, + status="failure", + content=str(error), + type="mcp_call", + round=tool_round, + ) + + output = content.get("output", "") + return ToolResultSummary( + id=call_id, + status="success", + content=str(output), + type="mcp_call", + round=tool_round, + ) + + +def summarize_mcp_tool_result( + part: NativeToolReturnPart, + tool_round: int, +) -> ToolResultSummary: + """Build a tool-result summary from a native MCP server tool return. + + Dispatches to list-tools or call processors based on return shape. + + Args: + part: Native MCP tool return part from the model stream. + tool_round: Tool execution round number for this result. + + Returns: + Tool result summary in LCS turn-summary format. + """ + content = cast(dict[str, Any], part.content) + if "tools" in content and "error" in content: + return summarize_mcp_list_tools_result(part, tool_round) + return summarize_mcp_call_result(part, tool_round) + + +def summarize_file_search_result( + part: NativeToolReturnPart, + tool_round: int, + seen_docs: set[tuple[str, str]], + vector_store_ids: list[str], + rag_id_mapping: dict[str, str], +) -> tuple[ToolResultSummary, list[RAGChunk], list[ReferencedDocument]]: + """Build tool result, RAG chunks, and referenced docs from a file-search return. + + Args: + part: Native file-search tool return part from the model stream. + tool_round: Tool execution round number for this result. + seen_docs: Dedupe keys for referenced documents; updated in place. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + Tool result summary, RAG chunks, and referenced documents for this return. + """ + content = cast(dict[str, Any], part.content) + tool_result = ToolResultSummary( + id=part.tool_call_id, + status=str(content.pop("status")), + content=json.dumps(content), + type="file_search_call", + round=tool_round, + ) + results = [ + OpenAIFileSearchResult.model_validate(result) + for result in content.get("results", []) + ] + rag_chunks = rag_chunks_from_file_search_results( + results, + vector_store_ids=vector_store_ids, + rag_id_mapping=rag_id_mapping, + ) + referenced_documents = referenced_documents_from_file_search_results( + results, + seen_docs, + vector_store_ids=vector_store_ids, + rag_id_mapping=rag_id_mapping, + ) + return tool_result, rag_chunks, referenced_documents diff --git a/src/utils/pydantic_ai.py b/src/utils/pydantic_ai.py index 5df570dc9..2574c0ca3 100644 --- a/src/utils/pydantic_ai.py +++ b/src/utils/pydantic_ai.py @@ -7,10 +7,13 @@ from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient from llama_stack_client import AsyncLlamaStackClient from pydantic_ai import Agent -from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings +from pydantic_ai.models.openai import OpenAIResponsesModelSettings from models.common.responses.responses_api_params import ResponsesApiParams -from pydantic_ai_lightspeed.llamastack import LlamaStackProvider +from pydantic_ai_lightspeed.llamastack import ( + LlamaStackProvider, + LlamaStackResponsesModel, +) _LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset( { @@ -92,7 +95,7 @@ def build_agent( provider = _llama_stack_provider_from_client(client) settings = _model_settings_from_responses_params(responses_params) - model = OpenAIResponsesModel( + model = LlamaStackResponsesModel( responses_params.model, provider=provider, settings=settings, diff --git a/src/utils/streaming_sse.py b/src/utils/streaming_sse.py new file mode 100644 index 000000000..6b412d35b --- /dev/null +++ b/src/utils/streaming_sse.py @@ -0,0 +1,81 @@ +"""Shared SSE formatting helpers for streaming endpoints.""" + +import json +from collections.abc import AsyncIterator + +from constants import ( + LLM_TOKEN_EVENT, + LLM_TOOL_CALL_EVENT, + LLM_TOOL_RESULT_EVENT, + LLM_TURN_COMPLETE_EVENT, + MEDIA_TYPE_TEXT, +) + + +def format_stream_data(d: dict) -> str: + """Format a dictionary as an SSE ``data:`` line. + + Args: + d: Event payload to serialize. + + Returns: + SSE-formatted data string. + """ + data = json.dumps(d) + return f"data: {data}\n\n" + + +def stream_event(data: dict, event_type: str, media_type: str) -> str: + """Build a streaming event string for JSON or plain-text clients. + + Args: + data: Dictionary containing the event data. + event_type: Type of event (token, tool call, etc.). + media_type: The media type for the response format. + + Returns: + SSE-formatted string representing the event. + """ + if media_type == MEDIA_TYPE_TEXT: + if event_type == LLM_TOKEN_EVENT: + return data.get("token", "") + if event_type == LLM_TOOL_CALL_EVENT: + return f"[Tool Call: {data.get('function_name', 'unknown')}]\n" + if event_type == LLM_TOOL_RESULT_EVENT: + return "[Tool Result]\n" + if event_type == LLM_TURN_COMPLETE_EVENT: + return "" + return "" + + return format_stream_data( + { + "event": event_type, + "data": data, + } + ) + + +async def shield_violation_generator( + violation_message: str, + media_type: str = MEDIA_TYPE_TEXT, +) -> AsyncIterator[str]: + """Create an SSE stream for shield violation responses. + + Yields a token event immediately for shield violations without going + through the Llama Stack response format. + + Args: + violation_message: The violation message to display. + media_type: The media type for the response format. + + Yields: + SSE-formatted strings for the violation token event. + """ + yield stream_event( + { + "id": 0, + "token": violation_message, + }, + LLM_TOKEN_EVENT, + media_type, + ) diff --git a/tests/e2e/features/steps/llm_query_response.py b/tests/e2e/features/steps/llm_query_response.py index 18c76a4cf..5ae926b5b 100644 --- a/tests/e2e/features/steps/llm_query_response.py +++ b/tests/e2e/features/steps/llm_query_response.py @@ -93,23 +93,9 @@ def wait_for_complete_response(context: Context) -> None: """Wait for the response to be complete.""" context.response_data = _parse_streaming_response(context.response.text) context.response.raise_for_status() - assert context.response_data["finished"] is True - - -@step('I use "{endpoint}" to ask question') -def ask_question(context: Context, endpoint: str) -> None: - """Call the service REST API endpoint with question.""" - base = f"http://{context.hostname}:{context.port}" - path = f"{context.api_prefix}/{endpoint}".replace("//", "/") - url = base + path - - # Replace {MODEL} and {PROVIDER} placeholders with actual values - json_str = replace_placeholders(context, context.text or "{}") - - data = json.loads(json_str) - context.response = request_with_transient_retry( - method="POST", url=url, json=data, timeout=DEFAULT_LLM_TIMEOUT - ) + assert ( + context.response_data["finished"] is True + ), f"Response is not finished: {context.response_data}" def _read_streamed_response(response: requests.Response) -> str: @@ -124,41 +110,72 @@ def _read_streamed_response(response: requests.Response) -> str: return "".join(chunks) -@step('I use "{endpoint}" to ask question with authorization header') -def ask_question_authorized(context: Context, endpoint: str) -> None: - """Call the service REST API endpoint with question.""" +def _uses_sse(endpoint: str, data: dict[str, Any]) -> bool: + """Return whether the endpoint delivers an SSE stream for the given payload.""" + return endpoint == "streaming_query" or ( + endpoint == "responses" and bool(data.get("stream")) + ) + + +def _post_question( + context: Context, + endpoint: str, + headers: dict[str, str] | None = None, + extra_data: dict[str, Any] | None = None, +) -> requests.Response: + """POST a question to the service REST API endpoint. + + Parameters: + context: Behave context with hostname, port, and request body text. + endpoint: API endpoint name (e.g. ``query``, ``streaming_query``). + headers: Optional HTTP headers (e.g. authorization). + extra_data: Optional fields merged into the JSON request body. + + Returns: + The HTTP response, with streamed bodies fully consumed when applicable. + """ base = f"http://{context.hostname}:{context.port}" path = f"{context.api_prefix}/{endpoint}".replace("//", "/") url = base + path - # Replace {MODEL} and {PROVIDER} placeholders with actual values json_str = replace_placeholders(context, context.text or "{}") - data = json.loads(json_str) - use_sse = endpoint == "streaming_query" or ( - endpoint == "responses" and bool(data.get("stream")) - ) - if use_sse: + if extra_data: + data.update(extra_data) + + if _uses_sse(endpoint, data): resp = request_with_transient_retry( method="POST", url=url, json=data, - headers=context.auth_headers, + headers=headers, timeout=DEFAULT_LLM_TIMEOUT, stream=True, ) # Consume stream so server close after error event does not raise body = _read_streamed_response(resp) resp._content = body.encode(resp.encoding or "utf-8") - context.response = resp - else: - context.response = request_with_transient_retry( - method="POST", - url=url, - json=data, - headers=context.auth_headers, - timeout=DEFAULT_LLM_TIMEOUT, - ) + return resp + + return request_with_transient_retry( + method="POST", + url=url, + json=data, + headers=headers, + timeout=DEFAULT_LLM_TIMEOUT, + ) + + +@step('I use "{endpoint}" to ask question') +def ask_question(context: Context, endpoint: str) -> None: + """Call the service REST API endpoint with question.""" + context.response = _post_question(context, endpoint) + + +@step('I use "{endpoint}" to ask question with authorization header') +def ask_question_authorized(context: Context, endpoint: str) -> None: + """Call the service REST API endpoint with question.""" + context.response = _post_question(context, endpoint, headers=context.auth_headers) # Query length chosen to exceed typical model context windows (e.g. 128k tokens) @@ -188,19 +205,12 @@ def store_conversation_details(context: Context) -> None: @step('I use "{endpoint}" to ask question with same conversation_id') def ask_question_in_same_conversation(context: Context, endpoint: str) -> None: """Call the service REST API endpoint with question, but use the existing conversation id.""" - base = f"http://{context.hostname}:{context.port}" - path = f"{context.api_prefix}/{endpoint}".replace("//", "/") - url = base + path - - # Replace {MODEL} and {PROVIDER} placeholders with actual values - json_str = replace_placeholders(context, context.text or "{}") - - data = json.loads(json_str) - headers = context.auth_headers if hasattr(context, "auth_headers") else {} - data["conversation_id"] = context.response_data["conversation_id"] - - context.response = request_with_transient_retry( - method="POST", url=url, json=data, headers=headers, timeout=DEFAULT_LLM_TIMEOUT + headers = context.auth_headers if hasattr(context, "auth_headers") else None + context.response = _post_question( + context, + endpoint, + headers=headers, + extra_data={"conversation_id": context.response_data["conversation_id"]}, ) @@ -366,12 +376,12 @@ def _parse_streaming_response(response_text: str) -> dict: full_response = "" full_response_split = [] finished = False - first_token = True stream_error = ( None # {"status_code": int, "response": str, "cause": str} if event "error" ) for line in lines: + print(f"line: {line}") if line.startswith("data: "): try: data = json.loads(line[6:]) # Remove 'data: ' prefix @@ -380,10 +390,6 @@ def _parse_streaming_response(response_text: str) -> dict: if event == "start": conversation_id = data["data"]["conversation_id"] elif event == "token": - # Skip the first token (shield status message) - if first_token: - first_token = False - continue full_response_split.append(data["data"]["token"]) elif event == "turn_complete": full_response = data["data"]["token"] diff --git a/tests/e2e/test_list.txt b/tests/e2e/test_list.txt index f47aa1d97..39804acb7 100644 --- a/tests/e2e/test_list.txt +++ b/tests/e2e/test_list.txt @@ -1,30 +1,3 @@ -features/authorized_noop.feature -features/health.feature -features/info.feature -features/models.feature -features/rest_api.feature -features/smoketests.feature -features/authorized_noop_token.feature -features/conversation_cache_v2.feature -features/conversations.feature -features/prompts.feature -features/faiss.feature features/inline_rag.feature -features/feedback.feature -features/query.feature -features/responses.feature -features/responses_streaming.feature -features/rlsapi_v1.feature features/streaming_query.feature -features/http_401_unauthorized.feature -features/authorized_rh_identity.feature -features/rbac.feature -features/rlsapi_v1_errors.feature -features/llama_stack_disrupted.feature features/mcp.feature -features/mcp_servers_api.feature -features/mcp_servers_api_auth.feature -features/mcp_servers_api_no_config.feature -features/proxy.feature -features/tls.feature -features/opentelemetry.feature diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7fc2edfa0..ee36161ae 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -3,11 +3,24 @@ import os from collections.abc import Generator from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, cast import pytest from fastapi import Request, Response from fastapi.testclient import TestClient +from llama_stack_api.openai_responses import OpenAIResponseObject +from llama_stack_client.types import VersionInfo +from pydantic_ai import AgentRunResultEvent +from pydantic_ai.messages import ( + ModelResponse, + NativeToolCallPart, + NativeToolReturnPart, + PartEndEvent, + PartStartEvent, + TextPart, +) +from pydantic_ai.native_tools import FileSearchTool, MCPServerTool +from pydantic_ai.usage import RunUsage from pytest_mock import MockerFixture from sqlalchemy import create_engine from sqlalchemy.engine import Engine @@ -70,9 +83,6 @@ def create_mock_llm_response( # pylint: disable=too-many-arguments,too-many-pos Returns: Mock LLM response object with the specified configuration. """ - # pylint: disable=import-outside-toplevel - from llama_stack_api.openai_responses import OpenAIResponseObject - mock_response = mocker.MagicMock(spec=OpenAIResponseObject) mock_response.id = "response-123" @@ -154,6 +164,264 @@ def create_mock_tool_call( return mock_tool_call +def create_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str = "This is a test response about Ansible.", + response_id: str = "response-123", + input_tokens: int = 10, + output_tokens: int = 5, + model_response: Any = None, + new_messages: Optional[list[Any]] = None, +) -> Any: + """Create a mock AgentRunResult wired for retrieve_agent_response. + + Uses real pydantic-ai message types so build_turn_summary_from_agent_run + exercises the same path as production agent runs. + + Args: + mocker: pytest-mock fixture. + content: Assistant text content for the run. + response_id: Provider response identifier. + input_tokens: Input token count for the run. + output_tokens: Output token count for the run. + model_response: Optional pre-built ModelResponse. + new_messages: Optional message sequence returned by new_messages(). + + Returns: + Mock AgentRunResult compatible with build_turn_summary_from_agent_run. + """ + if model_response is None: + parts = [TextPart(content)] if content else [] + model_response = ModelResponse( + parts=parts, + finish_reason="stop", + provider_response_id=response_id, + ) + + messages = new_messages if new_messages is not None else [model_response] + run_result = mocker.MagicMock() + run_result.response = model_response + run_result.usage = RunUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + requests=1, + ) + run_result.new_messages.return_value = messages + return run_result + + +def create_file_search_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str, + response_id: str = "response-tool-rag", + queries: Optional[list[str]] = None, + results: Optional[list[dict[str, Any]]] = None, + input_tokens: int = 10, + output_tokens: int = 5, +) -> Any: + """Create an AgentRunResult containing a native file_search tool call.""" + call = NativeToolCallPart( + tool_name=FileSearchTool.kind, + args={"queries": queries or ["test query"]}, + tool_call_id="call-fs-1", + ) + return_part = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="call-fs-1", + content={ + "status": "success", + "results": results or [], + }, + ) + model_response = ModelResponse( + parts=[call, return_part, TextPart(content)], + finish_reason="stop", + provider_response_id=response_id, + ) + return create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + model_response=model_response, + ) + + +def create_mcp_list_tools_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str, + response_id: str = "response-mcplist", + server_label: str = "kubernetes-server", + tools: Optional[list[dict[str, Any]]] = None, + input_tokens: int = 15, + output_tokens: int = 20, +) -> Any: + """Create an AgentRunResult containing an MCP list-tools native tool call.""" + call = NativeToolCallPart( + tool_name=f"{MCPServerTool.kind}:{server_label}", + args={"action": "list_tools"}, + tool_call_id="mcplist-101", + ) + return_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:{server_label}", + tool_call_id="mcplist-101", + content={"tools": tools or []}, + ) + model_response = ModelResponse( + parts=[call, return_part, TextPart(content)], + finish_reason="stop", + provider_response_id=response_id, + ) + return create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + model_response=model_response, + ) + + +def configure_query_agent_mock( + mocker: MockerFixture, + *, + run_result: Any = None, + run_side_effect: Any = None, +) -> Any: + """Patch build_agent for /query integration tests and return the mock agent. + + Args: + mocker: pytest-mock fixture. + run_result: AgentRunResult returned by agent.run(). + run_side_effect: Optional exception side effect for agent.run(). + + Returns: + Mock agent exposing AsyncMock run(). + """ + if run_result is None: + run_result = create_agent_run_result(mocker) + + mock_agent = mocker.AsyncMock() + if run_side_effect is not None: + mock_agent.run = mocker.AsyncMock(side_effect=run_side_effect) + else: + mock_agent.run = mocker.AsyncMock(return_value=run_result) + + build_agent_mock = mocker.patch( + "utils.agents.query.build_agent", + return_value=mock_agent, + ) + mock_agent.build_agent_mock = build_agent_mock + return mock_agent + + +def create_text_agent_stream_events( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str = "Based on the documentation, OpenShift is a Kubernetes distribution.", + response_id: str = "response-inline-stream", + input_tokens: int = 50, + output_tokens: int = 20, +) -> list[Any]: + """Build pydantic-ai stream events for a simple text agent response.""" + run_result = create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + text_part = TextPart(content=content) + return [ + PartStartEvent(index=0, part=text_part), + PartEndEvent(index=0, part=text_part), + AgentRunResultEvent(result=run_result), + ] + + +def create_file_search_agent_stream_events( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str, + response_id: str = "response-tool-stream", + queries: Optional[list[str]] = None, + results: Optional[list[dict[str, Any]]] = None, + input_tokens: int = 60, + output_tokens: int = 25, +) -> list[Any]: + """Build pydantic-ai stream events for a file_search tool agent response.""" + run_result = create_file_search_agent_run_result( + mocker, + content=content, + response_id=response_id, + queries=queries, + results=results, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + call_part, return_part, text_part = run_result.response.parts + return [ + PartEndEvent(index=0, part=call_part), + PartStartEvent(index=1, part=return_part), + PartStartEvent(index=2, part=text_part), + PartEndEvent(index=2, part=text_part), + AgentRunResultEvent(result=run_result), + ] + + +def configure_streaming_agent_mock( + mocker: MockerFixture, + *, + stream_events: Optional[list[Any]] = None, +) -> Any: + """Patch build_agent for /streaming_query integration tests. + + Args: + mocker: pytest-mock fixture. + stream_events: Optional pydantic-ai events yielded by run_stream_events. + + Returns: + Mock agent exposing run_stream_events(). + """ + events = stream_events or create_text_agent_stream_events(mocker) + + def _run_stream_events_side_effect(_prompt: str) -> Any: + async def _event_stream() -> Any: + for event in events: + yield event + + ctx = mocker.MagicMock() + ctx.__aenter__ = mocker.AsyncMock(return_value=_event_stream()) + ctx.__aexit__ = mocker.AsyncMock(return_value=None) + return ctx + + mock_agent = mocker.MagicMock() + mock_agent.run_stream_events = mocker.MagicMock( + side_effect=_run_stream_events_side_effect + ) + + build_agent_mock = mocker.patch( + "utils.agents.streaming.build_agent", + return_value=mock_agent, + ) + mock_agent.build_agent_mock = build_agent_mock + return mock_agent + + +def get_agent_responses_params(mock_client: Any) -> Any: + """Return ResponsesApiParams passed to the patched streaming build_agent.""" + return mock_client.build_agent_mock.call_args[0][1] + + +def get_agent_input_text(mock_client: Any) -> str: + """Return the agent prompt text from the patched streaming build_agent call.""" + return cast(str, get_agent_responses_params(mock_client).input) + + # ========================================== # Fixtures # ========================================== @@ -448,10 +716,6 @@ def mock_llama_stack_client_fixture( Yields: mock_client: The mocked Llama Stack client instance. """ - # pylint: disable=import-outside-toplevel - from llama_stack_api.openai_responses import OpenAIResponseObject - from llama_stack_client.types import VersionInfo - # Patch AsyncLlamaStackClientHolder at multiple import locations # This ensures the mock is active both during app startup (app.main) # and during endpoint execution (query, conversations_v1, responses, etc.) @@ -484,6 +748,10 @@ def mock_llama_stack_client_fixture( mock_client.responses.create.return_value = mock_response + mock_agent = configure_query_agent_mock(mocker) + mock_client.query_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock + # Mock models.list mock_model = mocker.MagicMock() mock_model.id = "test-provider/test-model" diff --git a/tests/integration/endpoints/test_query_byok_integration.py b/tests/integration/endpoints/test_query_byok_integration.py index b2a659f19..02e107388 100644 --- a/tests/integration/endpoints/test_query_byok_integration.py +++ b/tests/integration/endpoints/test_query_byok_integration.py @@ -17,6 +17,10 @@ from configuration import AppConfig from models.api.requests import QueryRequest from models.api.responses.successful import QueryResponse +from tests.integration.conftest import ( + configure_query_agent_mock, + create_file_search_agent_run_result, +) # --------------------------------------------------------------------------- # Helpers @@ -87,8 +91,8 @@ def _make_vector_io_response( def _build_base_mock_client(mocker: MockerFixture) -> Any: """Build a base mock Llama Stack client with common stubs. - Configures models, shields, conversations, version, and a default - responses.create return value. + Configures models, shields, conversations, version, and a default agent.run + return value. responses.create remains available for topic summary generation. """ mock_client = mocker.AsyncMock() @@ -131,6 +135,10 @@ def _build_base_mock_client(mocker: MockerFixture) -> Any: mock_response.usage = mock_usage mock_client.responses.create.return_value = mock_response + mock_agent = configure_query_agent_mock(mocker) + mock_client.query_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock + return mock_client @@ -171,8 +179,8 @@ def mock_byok_tool_rag_client_fixture( ) -> Generator[Any, None, None]: """Mock Llama Stack client with BYOK tool RAG (file_search) configured. - Configures vector_stores.list with a BYOK store and responses.create - to return a file_search_call output item alongside the assistant message. + Configures vector_stores.list with a BYOK store and agent.run to return + a file_search tool result alongside the assistant message. """ mock_holder_class = mocker.patch("app.endpoints.query.AsyncLlamaStackClientHolder") mock_client = _build_base_mock_client(mocker) @@ -190,54 +198,26 @@ def mock_byok_tool_rag_client_fixture( mock_list_result.data = [mock_vector_store] mock_client.vector_stores.list.return_value = mock_list_result - # Response with file_search tool call - mock_response = mocker.MagicMock(spec=OpenAIResponseObject) - mock_response.id = "response-tool-rag" - - mock_tool_output = mocker.MagicMock() - mock_tool_output.type = "file_search_call" - mock_tool_output.id = "call-fs-1" - mock_tool_output.queries = ["What is OpenShift?"] - mock_tool_output.status = "completed" - - mock_result = mocker.MagicMock() - mock_result.file_id = "doc-ocp-1" - mock_result.filename = "openshift-docs.txt" - mock_result.score = 0.92 - mock_result.text = "OpenShift is a Kubernetes distribution by Red Hat." - mock_result.attributes = { - "doc_url": "https://docs.redhat.com/ocp/overview", - "link": "https://docs.redhat.com/ocp/overview", - } - mock_result.model_dump = mocker.Mock( - return_value={ - "file_id": "doc-ocp-1", - "filename": "openshift-docs.txt", - "score": 0.92, - "text": "OpenShift is a Kubernetes distribution by Red Hat.", - "attributes": { - "doc_url": "https://docs.redhat.com/ocp/overview", - }, - } - ) - mock_tool_output.results = [mock_result] - - mock_message = mocker.MagicMock() - mock_message.type = "message" - mock_message.role = "assistant" - mock_message.content = ( - "Based on the documentation, OpenShift is a Kubernetes distribution." + tool_run_result = create_file_search_agent_run_result( + mocker, + content=("Based on the documentation, OpenShift is a Kubernetes distribution."), + response_id="response-tool-rag", + queries=["What is OpenShift?"], + results=[ + { + "text": "OpenShift is a Kubernetes distribution by Red Hat.", + "score": 0.92, + "attributes": { + "doc_url": "https://docs.redhat.com/ocp/overview", + "title": "openshift-docs.txt", + "document_id": "doc-ocp-1", + }, + } + ], + input_tokens=60, + output_tokens=25, ) - mock_message.refusal = None - - mock_response.output = [mock_tool_output, mock_message] - mock_response.stop_reason = "end_turn" - mock_response.tool_calls = [] - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 60 - mock_usage.output_tokens = 25 - mock_response.usage = mock_usage - mock_client.responses.create.return_value = mock_response + mock_client.query_agent.run.return_value = tool_run_result mock_holder_class.return_value.get_client.return_value = mock_client yield mock_client @@ -317,7 +297,7 @@ async def test_query_byok_inline_rag_injects_context( Verifies: - vector_io.query is called for BYOK inline RAG - - RAG context is injected into the responses.create input + - RAG context is injected into the agent prompt - Response includes RAG chunks from inline sources """ _ = byok_config @@ -353,13 +333,10 @@ async def test_query_byok_inline_rag_injects_context( call_kwargs = mock_byok_client.vector_io.query.call_args.kwargs assert call_kwargs["query"] == "What is OpenShift?" - # Verify RAG context was injected into responses.create input - # Use call_args_list[0] — the first call is the main query; - # a second call may follow for topic summary generation. - create_kwargs = mock_byok_client.responses.create.call_args_list[0].kwargs - input_text = create_kwargs["input"] - assert "file_search found" in input_text - assert "OpenShift is a Kubernetes distribution" in input_text + # Verify RAG context was injected into the agent prompt + prompt = mock_byok_client.query_agent.run.call_args.args[0] + assert "file_search found" in prompt + assert "OpenShift is a Kubernetes distribution" in prompt # Verify RAG chunks are included in the response assert response.rag_chunks is not None @@ -802,47 +779,27 @@ async def test_query_byok_combined_inline_and_tool_rag( # pylint: disable=too-m mock_list_result.data = [mock_vector_store] mock_client.vector_stores.list.return_value = mock_list_result - # Response includes file_search_call (tool RAG result) - mock_response = mocker.MagicMock(spec=OpenAIResponseObject) - mock_response.id = "response-combined" - - mock_tool_output = mocker.MagicMock() - mock_tool_output.type = "file_search_call" - mock_tool_output.id = "call-fs-combined" - mock_tool_output.queries = ["What is OpenShift?"] - mock_tool_output.status = "completed" - - mock_result = mocker.MagicMock() - mock_result.file_id = "doc-tool-1" - mock_result.filename = "tool-doc.txt" - mock_result.score = 0.90 - mock_result.text = "Tool-based RAG result about OpenShift." - mock_result.attributes = {"doc_url": "https://example.com/tool-doc"} - mock_result.model_dump = mocker.Mock( - return_value={ - "file_id": "doc-tool-1", - "filename": "tool-doc.txt", - "score": 0.90, - "text": "Tool-based RAG result about OpenShift.", - "attributes": {"doc_url": "https://example.com/tool-doc"}, - } + # Agent run includes file_search tool RAG result + combined_run_result = create_file_search_agent_run_result( + mocker, + content="Combined answer from inline and tool RAG.", + response_id="response-combined", + queries=["What is OpenShift?"], + results=[ + { + "text": "Tool-based RAG result about OpenShift.", + "score": 0.90, + "attributes": { + "doc_url": "https://example.com/tool-doc", + "title": "tool-doc.txt", + "document_id": "doc-tool-1", + }, + } + ], + input_tokens=80, + output_tokens=30, ) - mock_tool_output.results = [mock_result] - - mock_message = mocker.MagicMock() - mock_message.type = "message" - mock_message.role = "assistant" - mock_message.content = "Combined answer from inline and tool RAG." - mock_message.refusal = None - - mock_response.output = [mock_tool_output, mock_message] - mock_response.stop_reason = "end_turn" - mock_response.tool_calls = [] - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 80 - mock_usage.output_tokens = 30 - mock_response.usage = mock_usage - mock_client.responses.create.return_value = mock_response + mock_client.query_agent.run.return_value = combined_run_result mock_holder_class.return_value.get_client.return_value = mock_client diff --git a/tests/integration/endpoints/test_query_integration.py b/tests/integration/endpoints/test_query_integration.py index 725f01d25..a2d29e182 100644 --- a/tests/integration/endpoints/test_query_integration.py +++ b/tests/integration/endpoints/test_query_integration.py @@ -7,8 +7,17 @@ import pytest from fastapi import HTTPException, Request, status -from llama_stack_api.openai_responses import OpenAIResponseObject from llama_stack_client import APIConnectionError +from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + NativeToolCallPart, + NativeToolReturnPart, + TextPart, + ToolCallPart, + ToolReturnPart, +) +from pydantic_ai.native_tools import FileSearchTool from pytest_mock import AsyncMockType, MockerFixture from sqlalchemy.orm import Session @@ -25,7 +34,9 @@ from tests.integration.conftest import ( TEST_CONVERSATION_ID, TEST_NON_EXISTENT_ID, - create_mock_llm_response, + create_agent_run_result, + create_file_search_agent_run_result, + create_mcp_list_tools_agent_run_result, ) # File-specific test constants @@ -115,7 +126,7 @@ async def test_query_v2_endpoint_handles_connection_error( """ _ = test_config - mock_llama_stack_client.responses.create.side_effect = APIConnectionError( + mock_llama_stack_client.query_agent.run.side_effect = APIConnectionError( request=mocker.Mock() ) @@ -453,50 +464,24 @@ async def test_query_v2_endpoint_with_tool_calls( """ _ = test_config - mock_response = mocker.MagicMock(spec=OpenAIResponseObject) - mock_response.id = "response-789" - - mock_tool_output = mocker.MagicMock() - mock_tool_output.type = "file_search_call" - mock_tool_output.id = "call-1" - mock_tool_output.queries = ["What is Ansible"] - mock_tool_output.status = "completed" - mock_result = mocker.MagicMock() - mock_result.file_id = "doc-1" - mock_result.filename = "ansible-docs.txt" - mock_result.score = 0.95 - mock_result.text = "Ansible is an open-source automation tool..." - mock_result.attributes = { - "doc_url": "https://example.com/ansible-docs.txt", - "link": "https://example.com/ansible-docs.txt", - } - mock_result.model_dump = mocker.Mock( - return_value={ - "file_id": "doc-1", - "filename": "ansible-docs.txt", - "score": 0.95, - "text": "Ansible is an open-source automation tool...", - "attributes": { - "doc_url": "https://example.com/ansible-docs.txt", - "link": "https://example.com/ansible-docs.txt", - }, - } + tool_run_result = create_file_search_agent_run_result( + mocker, + content="Based on the documentation, Ansible is...", + response_id="response-789", + queries=["What is Ansible"], + results=[ + { + "text": "Ansible is an open-source automation tool...", + "score": 0.95, + "attributes": { + "doc_url": "https://example.com/ansible-docs.txt", + "title": "ansible-docs.txt", + "document_id": "doc-1", + }, + } + ], ) - mock_tool_output.results = [mock_result] - - mock_message_output = mocker.MagicMock() - mock_message_output.type = "message" - mock_message_output.role = "assistant" - mock_message_output.content = "Based on the documentation, Ansible is..." - - mock_response.output = [mock_tool_output, mock_message_output] - mock_response.stop_reason = "end_turn" - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 10 - mock_usage.output_tokens = 5 - mock_response.usage = mock_usage - - mock_llama_stack_client.responses.create.return_value = mock_response + mock_llama_stack_client.query_agent.run.return_value = tool_run_result query_request = QueryRequest(query="What is Ansible?") @@ -537,38 +522,23 @@ async def test_query_v2_endpoint_with_mcp_list_tools( """ _ = test_config - mock_response = mocker.MagicMock() - mock_response.id = "response-mcplist" - - mock_tool1 = mocker.MagicMock() - mock_tool1.name = "list_pods" - mock_tool1.description = "List Kubernetes pods" - mock_tool1.input_schema = {"type": "object", "properties": {}} - - mock_tool2 = mocker.MagicMock() - mock_tool2.name = "get_deployment" - mock_tool2.description = "Get Kubernetes deployment" - mock_tool2.input_schema = {"type": "object", "properties": {}} - - mock_mcp_list = mocker.MagicMock() - mock_mcp_list.type = "mcp_list_tools" - mock_mcp_list.id = "mcplist-101" - mock_mcp_list.server_label = "kubernetes-server" - mock_mcp_list.tools = [mock_tool1, mock_tool2] - - mock_message = mocker.MagicMock() - mock_message.type = "message" - mock_message.role = "assistant" - mock_message.content = "Available tools: list_pods, get_deployment" - - mock_response.output = [mock_mcp_list, mock_message] - mock_response.tool_calls = [] - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 15 - mock_usage.output_tokens = 20 - mock_response.usage = mock_usage - - mock_llama_stack_client.responses.create.return_value = mock_response + mcp_run_result = create_mcp_list_tools_agent_run_result( + mocker, + content="Available tools: list_pods, get_deployment", + tools=[ + { + "name": "list_pods", + "description": "List Kubernetes pods", + "input_schema": {"type": "object", "properties": {}}, + }, + { + "name": "get_deployment", + "description": "Get Kubernetes deployment", + "input_schema": {"type": "object", "properties": {}}, + }, + ], + ) + mock_llama_stack_client.query_agent.run.return_value = mcp_run_result query_request = QueryRequest(query="What tools are available?") @@ -609,37 +579,46 @@ async def test_query_v2_endpoint_with_multiple_tool_types( """ _ = test_config - mock_response = mocker.MagicMock() - mock_response.id = "response-multi" - - mock_file_search = mocker.MagicMock() - mock_file_search.type = "file_search_call" - mock_file_search.id = "search-1" - mock_file_search.queries = ["Kubernetes deployment"] - mock_file_search.status = "completed" - mock_file_search.results = [] - - mock_function = mocker.MagicMock() - mock_function.type = "function_call" - mock_function.id = "func-2" - mock_function.call_id = "func-2" - mock_function.name = "calculate" - mock_function.arguments = '{"operation": "sum"}' - mock_function.status = "completed" - - mock_message = mocker.MagicMock() - mock_message.type = "message" - mock_message.role = "assistant" - mock_message.content = "Based on documentation and calculations..." - - mock_response.output = [mock_file_search, mock_function, mock_message] - mock_response.tool_calls = [] - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 40 - mock_usage.output_tokens = 60 - mock_response.usage = mock_usage - - mock_llama_stack_client.responses.create.return_value = mock_response + file_search_call = NativeToolCallPart( + tool_name=FileSearchTool.kind, + args={"queries": ["Kubernetes deployment"]}, + tool_call_id="search-1", + ) + file_search_return = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="search-1", + content={"status": "success", "results": []}, + ) + function_call = ToolCallPart( + tool_name="calculate", + args={"operation": "sum"}, + tool_call_id="func-2", + ) + function_return = ToolReturnPart( + tool_name="calculate", + content={"result": 2}, + tool_call_id="func-2", + ) + model_response = ModelResponse( + parts=[ + file_search_call, + file_search_return, + function_call, + TextPart("Based on documentation and calculations..."), + ], + finish_reason="stop", + provider_response_id="response-multi", + ) + multi_tool_run = create_agent_run_result( + mocker, + content="Based on documentation and calculations...", + response_id="response-multi", + input_tokens=40, + output_tokens=60, + model_response=model_response, + new_messages=[model_response, ModelRequest(parts=[function_return])], + ) + mock_llama_stack_client.query_agent.run.return_value = multi_tool_run query_request = QueryRequest(query="Search docs and calculate deployment replicas") @@ -711,9 +690,9 @@ async def test_query_v2_endpoint_bypasses_tools_when_no_tools_true( assert response.conversation_id is not None assert response.response is not None - # Verify NO tools were passed to Llama Stack (despite vector stores being available) - call_kwargs = mock_llama_stack_client.responses.create.call_args.kwargs - assert call_kwargs.get("tools") is None + # Verify NO tools were passed to the agent (despite vector stores being available) + responses_params = mock_llama_stack_client.build_agent_mock.call_args[0][1] + assert responses_params.tools is None @pytest.mark.asyncio @@ -770,11 +749,11 @@ async def test_query_v2_endpoint_uses_tools_when_available( assert response.conversation_id is not None assert response.response is not None - # Verify tools were passed to Llama Stack (real tool preparation logic ran) - call_kwargs = mock_llama_stack_client.responses.create.call_args_list[0].kwargs - assert call_kwargs.get("tools") is not None - assert len(call_kwargs["tools"]) > 0 - assert any(tool.get("type") == "file_search" for tool in call_kwargs["tools"]) + # Verify tools were passed to the agent (real tool preparation logic ran) + responses_params = mock_llama_stack_client.build_agent_mock.call_args[0][1] + assert responses_params.tools is not None + assert len(responses_params.tools) > 0 + assert any(tool.type == "file_search" for tool in responses_params.tools) # ========================================== @@ -876,16 +855,15 @@ async def test_query_v2_endpoint_updates_existing_conversation( original_topic = existing_conversation.topic_summary original_count = existing_conversation.message_count - # Create a proper mock response with all required attributes - mock_response = create_mock_llm_response( + # Create a proper agent run result with empty assistant text + empty_run_result = create_agent_run_result( mocker, content="", + response_id=EXISTING_CONV_ID, input_tokens=10, output_tokens=5, ) - mock_response.id = EXISTING_CONV_ID - mock_response.output = [] # Override to empty for this test - mock_llama_stack_client.responses.create.return_value = mock_response + mock_llama_stack_client.query_agent.run.return_value = empty_run_result query_request = QueryRequest(query="Tell me more", conversation_id=EXISTING_CONV_ID) @@ -1110,17 +1088,15 @@ async def test_query_v2_endpoint_with_shield_violation( """ _ = test_config - # Configure Llama Stack mock to return response with violation - mock_response = create_mock_llm_response( + # Configure agent mock to return advisory shield-violation-style content + violation_run_result = create_agent_run_result( mocker, content="I cannot respond to this request", - refusal="Content violates safety policy", + response_id="response-violation", input_tokens=10, output_tokens=5, ) - mock_response.id = "response-violation" - - mock_llama_stack_client.responses.create.return_value = mock_response + mock_llama_stack_client.query_agent.run.return_value = violation_run_result query_request = QueryRequest(query="Inappropriate query") @@ -1182,13 +1158,10 @@ async def test_query_v2_endpoint_without_shields( assert response.conversation_id is not None assert response.response is not None - # Verify extra_body was not included (or guardrails is empty) - call_kwargs = mock_llama_stack_client.responses.create.call_args.kwargs - if "extra_body" in call_kwargs: - assert ( - "guardrails" not in call_kwargs["extra_body"] - or not call_kwargs["extra_body"]["guardrails"] - ) + # Verify responses params passed to the agent do not include guardrails + responses_params = mock_llama_stack_client.build_agent_mock.call_args[0][1] + dumped_params = responses_params.model_dump(exclude_none=True) + assert "guardrails" not in dumped_params @pytest.mark.asyncio @@ -1217,17 +1190,14 @@ async def test_query_v2_endpoint_handles_empty_llm_response( """ _ = test_config - # Create a response with truly empty output array (no assistant messages) - mock_response = create_mock_llm_response( + empty_run_result = create_agent_run_result( mocker, content="", + response_id="response-empty", input_tokens=10, output_tokens=0, ) - mock_response.id = "response-empty" - mock_response.output = [] # Override to test truly empty response - - mock_llama_stack_client.responses.create.return_value = mock_response + mock_llama_stack_client.query_agent.run.return_value = empty_run_result query_request = QueryRequest(query="What is Ansible?") @@ -1276,16 +1246,14 @@ async def test_query_v2_endpoint_quota_integration( _ = test_config _ = patch_db_session - mock_response = create_mock_llm_response( + quota_run_result = create_agent_run_result( mocker, content="", + response_id="response-quota", input_tokens=100, output_tokens=50, ) - mock_response.id = "response-quota" - mock_response.output = [] # Override to empty for this test - - mock_llama_stack_client.responses.create.return_value = mock_response + mock_llama_stack_client.query_agent.run.return_value = quota_run_result mock_consume = mocker.spy(app.endpoints.query, "consume_query_tokens") _ = mocker.spy(app.endpoints.query, "get_available_quotas") @@ -1513,15 +1481,14 @@ async def test_query_v2_endpoint_uses_conversation_history_model( patch_db_session.add(existing_conv) patch_db_session.commit() - mock_response = create_mock_llm_response( + history_run_result = create_agent_run_result( mocker, content="", + response_id=EXISTING_CONV_ID, input_tokens=10, output_tokens=5, ) - mock_response.id = EXISTING_CONV_ID - mock_response.output = [] # Override to empty for this test - mock_llama_stack_client.responses.create.return_value = mock_response + mock_llama_stack_client.query_agent.run.return_value = history_run_result query_request = QueryRequest(query="Tell me more", conversation_id=EXISTING_CONV_ID) diff --git a/tests/integration/endpoints/test_streaming_query_byok_integration.py b/tests/integration/endpoints/test_streaming_query_byok_integration.py index c539d4294..33bf90085 100644 --- a/tests/integration/endpoints/test_streaming_query_byok_integration.py +++ b/tests/integration/endpoints/test_streaming_query_byok_integration.py @@ -3,13 +3,12 @@ # pylint: disable=too-many-lines import json -from collections.abc import AsyncIterator, Generator +from collections.abc import Generator from typing import Any import pytest from fastapi import Request, status from fastapi.responses import StreamingResponse -from llama_stack_api.openai_responses import OpenAIResponseObject from pytest_mock import AsyncMockType, MockerFixture import constants @@ -17,6 +16,12 @@ from authentication.interface import AuthTuple from configuration import AppConfig from models.api.requests import QueryRequest +from tests.integration.conftest import ( + configure_streaming_agent_mock, + create_file_search_agent_stream_events, + get_agent_input_text, + get_agent_responses_params, +) from tests.integration.endpoints.test_query_byok_integration import ( _build_base_mock_client, _make_byok_vector_io_response, @@ -50,39 +55,18 @@ async def _collect_sse_events(response: StreamingResponse) -> list[dict[str, Any def _build_base_streaming_mock_client(mocker: MockerFixture) -> Any: """Build a base mock Llama Stack client configured for streaming responses. - Extends the base query mock client with streaming-specific stubs: - conversations.items.create and a streaming responses.create. + Extends the base query mock client with a patched pydantic-ai streaming + agent and topic-summary responses.create stub. """ mock_client = _build_base_mock_client(mocker) - # Streaming additions + mock_agent = configure_streaming_agent_mock(mocker) + mock_client.streaming_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock + mock_client.conversations.items.create = mocker.AsyncMock() - async def _mock_stream() -> AsyncIterator[Any]: - chunk = mocker.MagicMock() - chunk.type = "response.output_text.done" - chunk.text = ( - "Based on the documentation, OpenShift is a Kubernetes distribution." - ) - yield chunk - - # Emit response.completed so referenced_documents propagate to end event - completed_chunk = mocker.MagicMock() - completed_chunk.type = "response.completed" - mock_final = mocker.MagicMock(spec=OpenAIResponseObject) - mock_final.id = "response-inline-stream" - mock_final.error = None - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 50 - mock_usage.output_tokens = 20 - mock_final.usage = mock_usage - mock_final.output = [] - completed_chunk.response = mock_final - yield completed_chunk - - async def _responses_create(**kwargs: Any) -> Any: - if kwargs.get("stream", True): - return _mock_stream() + async def _responses_create(**_kwargs: Any) -> Any: mock_resp = mocker.MagicMock() mock_resp.output = [mocker.MagicMock(content="topic summary")] return mock_resp @@ -152,78 +136,32 @@ def mock_streaming_byok_tool_client_fixture( # pylint: disable=too-many-stateme mock_list_result.data = [mock_vector_store] mock_client.vector_stores.list.return_value = mock_list_result - # Build a streaming response with file_search and completion events - async def _mock_tool_stream() -> AsyncIterator[Any]: - # file_search output item done - item_done_chunk = mocker.MagicMock() - item_done_chunk.type = "response.output_item.done" - item_done_chunk.output_index = 0 - - mock_item = mocker.MagicMock() - mock_item.type = "file_search_call" - mock_item.id = "call-fs-stream-1" - mock_item.queries = ["What is OpenShift?"] - mock_item.status = "completed" - - mock_result = mocker.MagicMock() - mock_result.file_id = "doc-ocp-1" - mock_result.filename = "openshift-docs.txt" - mock_result.score = 0.92 - mock_result.text = "OpenShift is a Kubernetes distribution by Red Hat." - mock_result.attributes = { - "doc_url": "https://docs.redhat.com/ocp/overview", - } - mock_result.model_dump = mocker.Mock( - return_value={ - "file_id": "doc-ocp-1", - "filename": "openshift-docs.txt", - "score": 0.92, - "text": "OpenShift is a Kubernetes distribution.", - "attributes": {"doc_url": "https://docs.redhat.com/ocp/overview"}, - } - ) - mock_item.results = [mock_result] - item_done_chunk.item = mock_item - yield item_done_chunk - - # Text done - text_done_chunk = mocker.MagicMock() - text_done_chunk.type = "response.output_text.done" - text_done_chunk.text = ( - "Based on the documentation, OpenShift is a Kubernetes distribution." - ) - yield text_done_chunk - - # Response completed - completed_chunk = mocker.MagicMock() - completed_chunk.type = "response.completed" - mock_final_response = mocker.MagicMock(spec=OpenAIResponseObject) - mock_final_response.id = "response-tool-stream" - mock_final_response.error = None - - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 60 - mock_usage.output_tokens = 25 - mock_final_response.usage = mock_usage - - # file_search results in the final response output - mock_fs_output = mocker.MagicMock() - mock_fs_output.type = "file_search_call" - mock_fs_output.id = "call-fs-stream-1" - mock_fs_output.results = [mock_result] - mock_final_response.output = [mock_fs_output] - - completed_chunk.response = mock_final_response - yield completed_chunk - - async def _responses_create(**kwargs: Any) -> Any: - if kwargs.get("stream", True): - return _mock_tool_stream() - mock_resp = mocker.MagicMock() - mock_resp.output = [mocker.MagicMock(content="topic summary")] - return mock_resp - - mock_client.responses.create = mocker.AsyncMock(side_effect=_responses_create) + mock_agent = configure_streaming_agent_mock( + mocker, + stream_events=create_file_search_agent_stream_events( + mocker, + content=( + "Based on the documentation, OpenShift is a Kubernetes distribution." + ), + response_id="response-tool-stream", + queries=["What is OpenShift?"], + results=[ + { + "text": "OpenShift is a Kubernetes distribution by Red Hat.", + "score": 0.92, + "attributes": { + "doc_url": "https://docs.redhat.com/ocp/overview", + "title": "openshift-docs.txt", + "document_id": "doc-ocp-1", + }, + } + ], + input_tokens=60, + output_tokens=25, + ), + ) + mock_client.streaming_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock mock_holder_class.return_value.get_client.return_value = mock_client yield mock_client @@ -309,12 +247,8 @@ async def test_streaming_query_byok_inline_rag_injects_context( assert isinstance(response, StreamingResponse) - # Verify RAG context was injected into responses.create input - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_streaming_byok_client.responses.create.call_args_list[0] - call_kwargs = create_call.kwargs - input_text = call_kwargs["input"] + # Verify RAG context was injected into the agent prompt + input_text = get_agent_input_text(mock_streaming_byok_client) assert "file_search found" in input_text assert "OpenShift is a Kubernetes distribution" in input_text @@ -448,11 +382,8 @@ async def test_streaming_query_byok_request_vector_store_ids_filters_configured_ call_kwargs = mock_client.vector_io.query.call_args.kwargs assert call_kwargs["vector_store_id"] == "vs-source-a" - # Verify source-a context was injected into the LLM input - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + # Verify source-a context was injected into the agent prompt + input_text = get_agent_input_text(mock_client) assert "file_search found" in input_text @@ -484,10 +415,7 @@ async def test_streaming_query_byok_inline_rag_empty_vector_store_ids_no_context assert isinstance(response, StreamingResponse) mock_streaming_byok_client.vector_io.query.assert_not_called() - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_streaming_byok_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_streaming_byok_client) assert "file_search found" not in input_text @@ -525,11 +453,7 @@ async def test_streaming_query_byok_inline_rag_error_handled_gracefully( assert isinstance(response, StreamingResponse) # No inline RAG context should be injected when the search fails. - # "file_search found" is the header added by _format_rag_context when chunks are present. - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_streaming_byok_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_streaming_byok_client) assert "file_search found" not in input_text @@ -725,17 +649,14 @@ async def test_streaming_query_byok_combined_inline_and_tool_rag( assert isinstance(response, StreamingResponse) assert response.status_code == status.HTTP_200_OK - # Verify inline RAG context was injected - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - call_kwargs = create_call.kwargs - input_text = call_kwargs["input"] + # Verify inline RAG context was injected into the agent prompt + input_text = get_agent_input_text(mock_client) assert "file_search found" in input_text - # Verify tool RAG file_search was passed - assert call_kwargs.get("tools") is not None - assert any(tool.get("type") == "file_search" for tool in call_kwargs["tools"]) + # Verify tool RAG file_search was passed to the agent + responses_params = get_agent_responses_params(mock_client) + assert responses_params.tools is not None + assert any(tool.type == "file_search" for tool in responses_params.tools) # ============================================================================== @@ -812,10 +733,7 @@ async def test_streaming_query_byok_only_configured_rag_id_is_queried( ] assert "vs-source-b" not in queried_stores - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_client) assert "file_search found" in input_text @@ -897,10 +815,7 @@ async def _side_effect(**kwargs: Any) -> Any: assert isinstance(response, StreamingResponse) # Verify Doc B (weighted 2.0) appears before Doc A (weighted 0.9) in context - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_client) pos_b = input_text.find("Doc B low similarity boosted") pos_a = input_text.find("Doc A high similarity") assert pos_b != -1 and pos_a != -1 @@ -969,10 +884,7 @@ async def test_streaming_query_rag_content_limit_caps_context( # pylint: disabl assert isinstance(response, StreamingResponse) # Verify the context header reports the capped count - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_client) expected_header = f"file_search found {constants.INLINE_RAG_MAX_CHUNKS} chunks:" assert expected_header in input_text @@ -1058,10 +970,7 @@ async def _side_effect(**kwargs: Any) -> Any: assert isinstance(response, StreamingResponse) - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_client) expected_header = f"file_search found {constants.INLINE_RAG_MAX_CHUNKS} chunks:" assert expected_header in input_text @@ -1132,8 +1041,7 @@ async def test_streaming_query_rag_content_limit_caps_inline_rag( # pylint: dis assert isinstance(response, StreamingResponse) - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_client) expected_header = "file_search found 3 chunks:" assert expected_header in input_text diff --git a/tests/integration/endpoints/test_streaming_query_integration.py b/tests/integration/endpoints/test_streaming_query_integration.py index 5a7e51620..10ecb5a6f 100644 --- a/tests/integration/endpoints/test_streaming_query_integration.py +++ b/tests/integration/endpoints/test_streaming_query_integration.py @@ -1,12 +1,13 @@ """Integration tests for the /streaming_query endpoint (using Responses API).""" -from collections.abc import AsyncIterator, Generator +from collections.abc import Generator from typing import Any import pytest from fastapi import HTTPException, Request, status from fastapi.responses import StreamingResponse from fastapi.testclient import TestClient +from llama_stack_client.types import VersionInfo from pytest_mock import AsyncMockType, MockerFixture from app.endpoints.streaming_query import streaming_query_endpoint_handler @@ -14,6 +15,13 @@ from configuration import AppConfig from models.api.requests import QueryRequest from models.common.query import Attachment +from tests.integration.conftest import ( + configure_streaming_agent_mock, + create_text_agent_stream_events, +) +from tests.integration.endpoints.test_query_byok_integration import ( + _build_base_mock_client, +) @pytest.fixture(name="mock_streaming_llama_stack_client") @@ -22,32 +30,26 @@ def mock_llama_stack_streaming_fixture( ) -> Generator[Any, None, None]: """Mock only the Llama Stack client (holder + client). - Configures the client so the real handler runs: models, vector_stores, - conversations, shields, vector_io, and responses.create returning a minimal - stream. No other code paths are patched. + Configures the client so the real handler runs with a patched pydantic-ai + streaming agent. No other code paths are patched. """ mock_holder_class = mocker.patch( "app.endpoints.streaming_query.AsyncLlamaStackClientHolder" ) - mock_client = mocker.AsyncMock() - - mock_model = mocker.MagicMock() - mock_model.id = "test-provider/test-model" - mock_model.custom_metadata = { - "provider_id": "test-provider", - "model_type": "llm", - } - mock_client.models.list.return_value = [mock_model] - - mock_vector_stores_response = mocker.MagicMock() - mock_vector_stores_response.data = [] - mock_client.vector_stores.list.return_value = mock_vector_stores_response - - mock_conversation = mocker.MagicMock() - mock_conversation.id = "conv_" + "a" * 48 - mock_client.conversations.create = mocker.AsyncMock(return_value=mock_conversation) - - mock_client.shields.list.return_value = [] + mock_client = _build_base_mock_client(mocker) + + mock_agent = configure_streaming_agent_mock( + mocker, + stream_events=create_text_agent_stream_events( + mocker, + content="test", + response_id="response-stream-test", + input_tokens=10, + output_tokens=5, + ), + ) + mock_client.streaming_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock mock_client.conversations.items.create = mocker.AsyncMock() @@ -56,20 +58,13 @@ def mock_llama_stack_streaming_fixture( mock_vector_io_response.scores = [] mock_client.vector_io.query = mocker.AsyncMock(return_value=mock_vector_io_response) - async def _mock_stream() -> AsyncIterator[Any]: - chunk = mocker.MagicMock() - chunk.type = "response.output_text.done" - chunk.text = "test" - yield chunk - - async def _responses_create(**kwargs: Any) -> Any: - if kwargs.get("stream", True): - return _mock_stream() + async def _responses_create(**_kwargs: Any) -> Any: mock_resp = mocker.MagicMock() mock_resp.output = [mocker.MagicMock(content="topic summary")] return mock_resp mock_client.responses.create = mocker.AsyncMock(side_effect=_responses_create) + mock_client.inspect.version.return_value = VersionInfo(version="0.2.22") mock_holder_class.return_value.get_client.return_value = mock_client diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 5a6b43684..231e2486d 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -153,11 +153,14 @@ async def test_successful_query_no_conversation( "Kubernetes is a container orchestration platform" ) - async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: + async def mock_retrieve_agent_response( + *_args: Any, **_kwargs: Any + ) -> TurnSummary: return mock_turn_summary mocker.patch( - "app.endpoints.query.retrieve_response", side_effect=mock_retrieve_response + "app.endpoints.query.retrieve_agent_response", + side_effect=mock_retrieve_agent_response, ) mocker.patch( @@ -245,7 +248,7 @@ async def test_query_merges_inline_and_tool_rag_chunks_and_documents( mock_turn_summary.referenced_documents = [tool_doc] mocker.patch( - "app.endpoints.query.retrieve_response", + "app.endpoints.query.retrieve_agent_response", new=mocker.AsyncMock(return_value=mock_turn_summary), ) mocker.patch("app.endpoints.query.store_query_results") @@ -317,7 +320,7 @@ async def test_successful_query_with_conversation( new=mocker.AsyncMock(return_value=ShieldModerationPassed()), ) mocker.patch( - "app.endpoints.query.retrieve_response", + "app.endpoints.query.retrieve_agent_response", new=mocker.AsyncMock(return_value=TurnSummary()), ) mocker.patch("app.endpoints.query.store_query_results") @@ -394,11 +397,14 @@ async def test_query_with_attachments( new=mocker.AsyncMock(return_value=mock_responses_params), ) - async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: + async def mock_retrieve_agent_response( + *_args: Any, **_kwargs: Any + ) -> TurnSummary: return TurnSummary() mocker.patch( - "app.endpoints.query.retrieve_response", side_effect=mock_retrieve_response + "app.endpoints.query.retrieve_agent_response", + side_effect=mock_retrieve_agent_response, ) mocker.patch( "app.endpoints.query.normalize_conversation_id", return_value="123" @@ -459,7 +465,7 @@ async def test_query_with_topic_summary( ) mocker.patch( - "app.endpoints.query.retrieve_response", + "app.endpoints.query.retrieve_agent_response", new=mocker.AsyncMock(return_value=TurnSummary()), ) mock_get_topic_summary = mocker.patch( @@ -545,11 +551,14 @@ async def test_query_azure_token_refresh( return_value=mock_updated_client ) - async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: + async def mock_retrieve_agent_response( + *_args: Any, **_kwargs: Any + ) -> TurnSummary: return TurnSummary() mocker.patch( - "app.endpoints.query.retrieve_response", side_effect=mock_retrieve_response + "app.endpoints.query.retrieve_agent_response", + side_effect=mock_retrieve_agent_response, ) mocker.patch( "app.endpoints.query.normalize_conversation_id", return_value="123" diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 19176d57f..d7ed215d4 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -397,19 +397,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -484,19 +484,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -582,19 +582,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -678,19 +678,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -776,19 +776,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", diff --git a/tests/unit/pydantic_ai_lightspeed/llamastack/test_model.py b/tests/unit/pydantic_ai_lightspeed/llamastack/test_model.py new file mode 100644 index 000000000..f9f2a7aeb --- /dev/null +++ b/tests/unit/pydantic_ai_lightspeed/llamastack/test_model.py @@ -0,0 +1,470 @@ +"""Unit tests for pydantic_ai_lightspeed.llamastack._model module.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from dataclasses import dataclass +from typing import Any + +import pytest +from openai.types import responses + +from pydantic_ai_lightspeed.llamastack._model import _FilteredResponseStream + + +class _FakeAsyncStream: + """Minimal AsyncStream stand-in for _FilteredResponseStream tests.""" + + def __init__(self, events: list[responses.ResponseStreamEvent]) -> None: + """Store events to replay from the fake stream. + + Args: + events: Ordered upstream events before reordering. + """ + self._events = events + + def __aiter__(self) -> AsyncIterator[responses.ResponseStreamEvent]: + """Return an async iterator over the configured events.""" + return self._iter_events() + + async def _iter_events(self) -> AsyncIterator[responses.ResponseStreamEvent]: + """Yield each configured event in order.""" + for event in self._events: + yield event + + async def close(self) -> None: + """No-op close for test double compatibility.""" + + +async def _collect_events( + stream: _FilteredResponseStream, +) -> list[responses.ResponseStreamEvent]: + """Drain a filtered stream into a list. + + Args: + stream: Filtered response stream under test. + + Returns: + All events emitted by the filtered stream. + """ + return [event async for event in stream] + + +def _function_delta( + *, + item_id: str, + delta: str, + sequence_number: int, +) -> responses.ResponseFunctionCallArgumentsDeltaEvent: + """Build a function-call arguments delta test event. + + Args: + item_id: Tool-call item identifier. + delta: Argument fragment string. + sequence_number: Event sequence number. + + Returns: + Function-call arguments delta event. + """ + return responses.ResponseFunctionCallArgumentsDeltaEvent.model_validate( + { + "type": "response.function_call_arguments.delta", + "item_id": item_id, + "output_index": 1, + "sequence_number": sequence_number, + "delta": delta, + } + ) + + +def _function_done( + *, + item_id: str, + arguments: str, + sequence_number: int, +) -> responses.ResponseFunctionCallArgumentsDoneEvent: + """Build a function-call arguments done test event. + + Args: + item_id: Tool-call item identifier. + arguments: Final JSON arguments string. + sequence_number: Event sequence number. + + Returns: + Function-call arguments done event. + """ + return responses.ResponseFunctionCallArgumentsDoneEvent.model_validate( + { + "type": "response.function_call_arguments.done", + "item_id": item_id, + "output_index": 1, + "sequence_number": sequence_number, + "arguments": arguments, + "name": "client_tool", + } + ) + + +def _mcp_added( + *, + item_id: str, + sequence_number: int, +) -> responses.ResponseOutputItemAddedEvent: + """Build an MCP output item added test event. + + Args: + item_id: MCP tool-call item identifier. + sequence_number: Event sequence number. + + Returns: + Output item added event for an MCP call. + """ + return responses.ResponseOutputItemAddedEvent.model_validate( + { + "type": "response.output_item.added", + "output_index": 1, + "sequence_number": sequence_number, + "item": { + "type": "mcp_call", + "id": item_id, + "name": "unit_convert", + "arguments": "", + "server_label": "datautils", + }, + } + ) + + +@dataclass +class _LlsMcpArgumentsDone: + """Llama Stack MCP arguments.done event shape before OpenAI SDK normalization.""" + + item_id: str + output_index: int + sequence_number: int + arguments: str + type: str = "response.mcp_call.arguments.done" + + def model_dump(self, exclude: set[str] | None = None) -> dict[str, Any]: + """Return a dict compatible with MCP done normalization. + + Args: + exclude: Optional field names to omit from the dump. + + Returns: + Serialized event fields. + """ + data = { + "type": self.type, + "item_id": self.item_id, + "output_index": self.output_index, + "sequence_number": self.sequence_number, + "arguments": self.arguments, + } + if exclude: + for key in exclude: + data.pop(key, None) + return data + + +def _list_tools_added( + *, + item_id: str, + sequence_number: int, +) -> responses.ResponseOutputItemAddedEvent: + """Build an MCP list-tools output item added test event. + + Args: + item_id: MCP list-tools item identifier. + sequence_number: Event sequence number. + + Returns: + Output item added event for an MCP list-tools call. + """ + return responses.ResponseOutputItemAddedEvent.model_validate( + { + "type": "response.output_item.added", + "output_index": 0, + "sequence_number": sequence_number, + "item": { + "type": "mcp_list_tools", + "id": item_id, + "server_label": "datautils", + "tools": [], + }, + } + ) + + +def _list_tools_done( + *, + item_id: str, + sequence_number: int, +) -> responses.ResponseOutputItemDoneEvent: + """Build an MCP list-tools output item done test event. + + Args: + item_id: MCP list-tools item identifier. + sequence_number: Event sequence number. + + Returns: + Output item done event for an MCP list-tools call. + """ + return responses.ResponseOutputItemDoneEvent.model_validate( + { + "type": "response.output_item.done", + "output_index": 0, + "sequence_number": sequence_number, + "item": { + "type": "mcp_list_tools", + "id": item_id, + "server_label": "datautils", + "tools": [{"name": "tool_a", "input_schema": {}}], + }, + } + ) + + +def _function_added( + *, + item_id: str, + sequence_number: int, +) -> responses.ResponseOutputItemAddedEvent: + """Build a function output item added test event. + + Args: + item_id: Function tool-call item identifier. + sequence_number: Event sequence number. + + Returns: + Output item added event for a function tool call. + """ + return responses.ResponseOutputItemAddedEvent.model_validate( + { + "type": "response.output_item.added", + "output_index": 1, + "sequence_number": sequence_number, + "item": { + "type": "function_call", + "id": item_id, + "call_id": "call_123", + "name": "client_tool", + "arguments": "", + "status": "in_progress", + }, + } + ) + + +class TestFilteredResponseStream: + """Tests for _FilteredResponseStream event reordering.""" + + @pytest.mark.asyncio + async def test_reorders_mcp_events_after_output_item_added(self) -> None: + """Test MCP deltas and done are replayed after output_item.added.""" + item_id = "fc_mcp" + upstream = [ + _function_delta(item_id=item_id, delta='{"', sequence_number=1), + _function_delta(item_id=item_id, delta="value", sequence_number=2), + _function_delta(item_id=item_id, delta='":100}', sequence_number=3), + _LlsMcpArgumentsDone( + item_id=item_id, + output_index=1, + sequence_number=4, + arguments='{"value":100}', + ), + _mcp_added(item_id=item_id, sequence_number=5), + ] + events = await _collect_events( + _FilteredResponseStream(_FakeAsyncStream(upstream)) # type: ignore[arg-type] + ) + types = [event.type for event in events] + + assert types[0] == "response.output_item.added" + assert types[1:4] == ["response.mcp_call_arguments.delta"] * 3 + assert types[4] == "response.mcp_call_arguments.done" + mcp_deltas = [ + event + for event in events[1:4] + if isinstance(event, responses.ResponseMcpCallArgumentsDeltaEvent) + ] + assert [delta.delta for delta in mcp_deltas] == ['{"', "value", '":100}'] + + @pytest.mark.asyncio + async def test_reorders_function_events_after_output_item_added(self) -> None: + """Test function deltas and done are replayed after output_item.added.""" + item_id = "fc_fn" + delta = _function_delta(item_id=item_id, delta='{"x":1}', sequence_number=1) + done = _function_done( + item_id=item_id, + arguments='{"x":1}', + sequence_number=2, + ) + added = _function_added(item_id=item_id, sequence_number=3) + events = await _collect_events( + _FilteredResponseStream(_FakeAsyncStream([delta, done, added])) # type: ignore[arg-type] + ) + types = [event.type for event in events] + + assert types == [ + "response.output_item.added", + "response.function_call_arguments.delta", + "response.function_call_arguments.done", + ] + + @pytest.mark.asyncio + async def test_passes_through_events_after_output_item_added(self) -> None: + """Test post-announcement deltas are not buffered.""" + item_id = "fc_live" + added = _function_added(item_id=item_id, sequence_number=1) + delta = _function_delta(item_id=item_id, delta='{"x":1}', sequence_number=2) + events = await _collect_events( + _FilteredResponseStream(_FakeAsyncStream([added, delta])) # type: ignore[arg-type] + ) + types = [event.type for event in events] + + assert types == [ + "response.output_item.added", + "response.function_call_arguments.delta", + ] + + @pytest.mark.asyncio + async def test_reorders_mcp_list_tools_events_after_output_item_added(self) -> None: + """Test list-tools lifecycle events replay after output_item.added.""" + item_id = "mcp_list_test" + upstream = [ + responses.ResponseMcpListToolsInProgressEvent.model_validate( + { + "type": "response.mcp_list_tools.in_progress", + "item_id": item_id, + "output_index": 0, + "sequence_number": 1, + } + ), + _list_tools_added(item_id=item_id, sequence_number=2), + responses.ResponseMcpListToolsCompletedEvent.model_validate( + { + "type": "response.mcp_list_tools.completed", + "item_id": item_id, + "output_index": 0, + "sequence_number": 3, + } + ), + _list_tools_done(item_id=item_id, sequence_number=4), + ] + events = await _collect_events( + _FilteredResponseStream(_FakeAsyncStream(upstream)) # type: ignore[arg-type] + ) + types = [event.type for event in events] + + assert types == [ + "response.output_item.added", + "response.mcp_list_tools.in_progress", + "response.mcp_list_tools.completed", + "response.output_item.done", + ] + + @pytest.mark.asyncio + async def test_reorders_all_mcp_list_tools_events_before_added(self) -> None: + """Test list-tools events buffered when they all arrive before added.""" + item_id = "mcp_list_early" + upstream = [ + responses.ResponseMcpListToolsInProgressEvent.model_validate( + { + "type": "response.mcp_list_tools.in_progress", + "item_id": item_id, + "output_index": 0, + "sequence_number": 1, + } + ), + responses.ResponseMcpListToolsCompletedEvent.model_validate( + { + "type": "response.mcp_list_tools.completed", + "item_id": item_id, + "output_index": 0, + "sequence_number": 2, + } + ), + _list_tools_done(item_id=item_id, sequence_number=3), + _list_tools_added(item_id=item_id, sequence_number=4), + ] + events = await _collect_events( + _FilteredResponseStream(_FakeAsyncStream(upstream)) # type: ignore[arg-type] + ) + types = [event.type for event in events] + + assert types == [ + "response.output_item.added", + "response.mcp_list_tools.in_progress", + "response.mcp_list_tools.completed", + "response.output_item.done", + ] + + @pytest.mark.asyncio + async def test_flushes_buffered_events_when_added_never_arrives(self) -> None: + """Test buffered events are flushed if output_item.added never arrives.""" + item_id = "fc_orphan" + delta = _function_delta(item_id=item_id, delta="{}", sequence_number=1) + events = await _collect_events( + _FilteredResponseStream(_FakeAsyncStream([delta])) # type: ignore[arg-type] + ) + + assert len(events) == 1 + assert events[0].type == "response.mcp_call_arguments.delta" + + @pytest.mark.asyncio + async def test_converts_post_added_function_deltas_for_mcp_call(self) -> None: + """Test function argument deltas after added are rewritten for MCP calls.""" + item_id = "fc_live_mcp" + added = _mcp_added(item_id=item_id, sequence_number=1) + delta = _function_delta(item_id=item_id, delta='{"value":1}', sequence_number=2) + events = await _collect_events( + _FilteredResponseStream(_FakeAsyncStream([added, delta])) # type: ignore[arg-type] + ) + types = [event.type for event in events] + + assert types == [ + "response.output_item.added", + "response.mcp_call_arguments.delta", + ] + + @pytest.mark.asyncio + async def test_buffers_mcp_output_done_until_arguments_done(self) -> None: + """Test mcp_call output_item.done is held until arguments.done is emitted.""" + item_id = "fc_mcp_done_order" + added = _mcp_added(item_id=item_id, sequence_number=1) + output_done = responses.ResponseOutputItemDoneEvent.model_validate( + { + "type": "response.output_item.done", + "output_index": 1, + "sequence_number": 2, + "item": { + "type": "mcp_call", + "id": item_id, + "name": "unit_convert", + "arguments": '{"action":"call_tool","tool_name":"unit_convert","tool_args":{}}', + "server_label": "datautils", + }, + } + ) + mcp_done = responses.ResponseMcpCallArgumentsDoneEvent.model_validate( + { + "type": "response.mcp_call_arguments.done", + "item_id": item_id, + "output_index": 1, + "sequence_number": 3, + "arguments": "{}", + } + ) + events = await _collect_events( + _FilteredResponseStream( # type: ignore[arg-type] + _FakeAsyncStream([added, output_done, mcp_done]) + ) + ) + types = [event.type for event in events] + + assert types == [ + "response.output_item.added", + "response.mcp_call_arguments.done", + "response.output_item.done", + ] diff --git a/tests/unit/utils/agents/test_tool_processor.py b/tests/unit/utils/agents/test_tool_processor.py new file mode 100644 index 000000000..4acf9b9d1 --- /dev/null +++ b/tests/unit/utils/agents/test_tool_processor.py @@ -0,0 +1,674 @@ +"""Unit tests for utils.agents.tool_processor module.""" + +import json + +import pytest +from openai.types.responses.response_file_search_tool_call import ( + Result as OpenAIFileSearchResult, +) +from pydantic import AnyUrl +from pydantic_ai.messages import ( + NativeToolCallPart, + NativeToolReturnPart, + ToolCallPart, + ToolReturnPart, +) +from pydantic_ai.native_tools import FileSearchTool, MCPServerTool, WebSearchTool +from pytest_mock import MockerFixture + +from constants import DEFAULT_RAG_TOOL +from models.common.agents import AgentTurnAccumulator +from models.common.turn_summary import TurnSummary +from utils.agents.tool_processor import ( + build_referenced_document, + process_function_tool_call, + process_function_tool_result, + process_native_tool_call, + process_native_tool_result, + rag_chunks_from_file_search_results, + referenced_documents_from_file_search_results, + summarize_file_search_result, + summarize_function_tool_call, + summarize_function_tool_result, + summarize_mcp_call_result, + summarize_mcp_list_tools_result, + summarize_mcp_tool_result, + summarize_native_tool_call, + summarize_web_search_result, +) + + +@pytest.fixture(name="turn_state") +def turn_state_fixture() -> AgentTurnAccumulator: + """Create a fresh agent turn accumulator for dispatch tests.""" + return AgentTurnAccumulator( + vector_store_ids=["vs-001"], + rag_id_mapping={"vs-001": "ocp-docs"}, + turn_summary=TurnSummary(), + ) + + +def _file_search_result(**kwargs: object) -> OpenAIFileSearchResult: + """Build a validated OpenAI file-search result row.""" + return OpenAIFileSearchResult.model_validate(kwargs) + + +class TestSummarizeFunctionToolCall: + """Tests for summarize_function_tool_call.""" + + def test_builds_function_call_summary(self) -> None: + """Test function tool call is mapped to ToolCallSummary.""" + part = ToolCallPart( + tool_name="my_fn", + args={"key": "value"}, + tool_call_id="call-fn-1", + ) + + summary = summarize_function_tool_call(part) + + assert summary.id == "call-fn-1" + assert summary.name == "my_fn" + assert summary.args == {"key": "value"} + assert summary.type == "function_call" + + +class TestSummarizeNativeToolCall: + """Tests for summarize_native_tool_call.""" + + def test_web_search_call(self) -> None: + """Test web search native tool call summary.""" + part = NativeToolCallPart( + tool_name=WebSearchTool.kind, + args={"query": "OpenShift"}, + tool_call_id="ws-1", + ) + + summary = summarize_native_tool_call(part) + + assert summary is not None + assert summary.type == "web_search_call" + assert summary.name == WebSearchTool.kind + + def test_file_search_call(self) -> None: + """Test file search native tool call uses DEFAULT_RAG_TOOL name.""" + part = NativeToolCallPart( + tool_name=FileSearchTool.kind, + args={"queries": ["docs"]}, + tool_call_id="fs-1", + ) + + summary = summarize_native_tool_call(part) + + assert summary is not None + assert summary.name == DEFAULT_RAG_TOOL + assert summary.type == "file_search_call" + + def test_mcp_list_tools_call(self) -> None: + """Test MCP list-tools action summary.""" + part = NativeToolCallPart( + tool_name=f"{MCPServerTool.kind}:srv", + args={"action": "list_tools"}, + tool_call_id="mcp-list-1", + ) + + summary = summarize_native_tool_call(part) + + assert summary is not None + assert summary.name == "mcp_list_tools" + assert summary.args == {"server_label": "srv"} + assert summary.type == "mcp_list_tools" + + def test_mcp_list_tools_call_with_label(self) -> None: + """Test labeled MCP list-tools action uses the server label suffix.""" + part = NativeToolCallPart( + tool_name=f"{MCPServerTool.kind}:myserver", + args={"action": "list_tools"}, + tool_call_id="mcp-list-labeled", + ) + + summary = summarize_native_tool_call(part) + + assert summary is not None + assert summary.args == {"server_label": "myserver"} + + def test_mcp_call(self) -> None: + """Test MCP tool call summary.""" + part = NativeToolCallPart( + tool_name=f"{MCPServerTool.kind}:srv", + args={ + "action": "call", + "tool_name": "remote_tool", + "tool_args": {"arg": 1}, + }, + tool_call_id="mcp-call-1", + ) + + summary = summarize_native_tool_call(part) + + assert summary is not None + assert summary.name == "remote_tool" + assert summary.args == {"arg": 1} + assert summary.type == "mcp_call" + + def test_unknown_tool_returns_none(self, mocker: MockerFixture) -> None: + """Test unknown native tool logs warning and returns None.""" + mock_warning = mocker.patch("utils.agents.tool_processor.logger.warning") + part = NativeToolCallPart( + tool_name="unknown_tool", + args={}, + tool_call_id="unk-1", + ) + + assert summarize_native_tool_call(part) is None + mock_warning.assert_called_once() + + +class TestProcessFunctionToolCall: + """Tests for process_function_tool_call.""" + + def test_records_tool_call_on_state(self, turn_state: AgentTurnAccumulator) -> None: + """Test first function tool call is recorded on turn state.""" + part = ToolCallPart( + tool_name="fn", + args={"x": 1}, + tool_call_id="call-1", + ) + + summary = process_function_tool_call(turn_state, part) + + assert summary is not None + assert turn_state.turn_summary.tool_calls == [summary] + assert "call-1" in turn_state.emitted_tool_call_ids + + def test_skips_duplicate_tool_call(self, turn_state: AgentTurnAccumulator) -> None: + """Test duplicate function tool call id is not recorded twice.""" + part = ToolCallPart(tool_name="fn", args={}, tool_call_id="call-dup") + process_function_tool_call(turn_state, part) + + assert process_function_tool_call(turn_state, part) is None + assert len(turn_state.turn_summary.tool_calls) == 1 + + def test_increments_round_when_pending( + self, turn_state: AgentTurnAccumulator + ) -> None: + """Test pending round increment runs before recording tool call.""" + turn_state.round_increment_pending = True + turn_state.tool_round = 2 + part = ToolCallPart(tool_name="fn", args={}, tool_call_id="call-round") + + process_function_tool_call(turn_state, part) + + assert turn_state.tool_round == 3 + assert not turn_state.round_increment_pending + + +class TestProcessNativeToolCall: + """Tests for process_native_tool_call.""" + + def test_records_native_tool_call(self, turn_state: AgentTurnAccumulator) -> None: + """Test native tool call is recorded on turn state.""" + part = NativeToolCallPart( + tool_name=WebSearchTool.kind, + args={"query": "q"}, + tool_call_id="ws-record", + ) + + summary = process_native_tool_call(turn_state, part) + + assert summary is not None + assert turn_state.turn_summary.tool_calls == [summary] + + def test_skips_duplicate_and_unknown( + self, turn_state: AgentTurnAccumulator, mocker: MockerFixture + ) -> None: + """Test duplicate ids and unknown tools are not recorded.""" + mocker.patch("utils.agents.tool_processor.logger.warning") + part = NativeToolCallPart( + tool_name="unknown", + args={}, + tool_call_id="unk-record", + ) + + assert process_native_tool_call(turn_state, part) is None + assert not turn_state.turn_summary.tool_calls + + known = NativeToolCallPart( + tool_name=WebSearchTool.kind, + args={}, + tool_call_id="ws-dup", + ) + process_native_tool_call(turn_state, known) + assert process_native_tool_call(turn_state, known) is None + + def test_defers_incomplete_mcp_call_until_tool_name_present( + self, turn_state: AgentTurnAccumulator + ) -> None: + """Test MCP call PartEnd is skipped until streamed args include tool_name.""" + incomplete = NativeToolCallPart( + tool_name=f"{MCPServerTool.kind}:datautils", + tool_call_id="fc-mcp-incomplete", + args=None, + provider_name="llama-stack", + ) + complete = NativeToolCallPart( + tool_name=f"{MCPServerTool.kind}:datautils", + tool_call_id="fc-mcp-incomplete", + args={ + "action": "call_tool", + "tool_name": "unit_convert", + "tool_args": {"value": 100, "from_unit": "mi", "to_unit": "km"}, + }, + provider_name="llama-stack", + ) + + assert process_native_tool_call(turn_state, incomplete) is None + assert turn_state.emitted_tool_call_ids == set() + + summary = process_native_tool_call(turn_state, complete) + + assert summary is not None + assert summary.name == "unit_convert" + assert summary.type == "mcp_call" + assert summary.args == { + "value": 100, + "from_unit": "mi", + "to_unit": "km", + } + + def test_emits_mcp_list_tools_without_deferral( + self, turn_state: AgentTurnAccumulator + ) -> None: + """Test MCP list-tools calls emit on first PartEnd with full args.""" + part = NativeToolCallPart( + tool_name=f"{MCPServerTool.kind}:datautils", + tool_call_id="mcp-list-1", + args={"action": "list_tools"}, + provider_name="llama-stack", + ) + + summary = process_native_tool_call(turn_state, part) + + assert summary is not None + assert summary.type == "mcp_list_tools" + + +class TestSummarizeFunctionToolResult: + """Tests for summarize_function_tool_result.""" + + def test_builds_function_tool_result(self) -> None: + """Test function tool return maps to ToolResultSummary.""" + part = ToolReturnPart( + tool_name="fn", + content={"answer": 42}, + tool_call_id="result-1", + ) + + result = summarize_function_tool_result(part, tool_round=3) + + assert result.id == "result-1" + assert result.status == "success" + assert result.type == "function_call_output" + assert result.round == 3 + assert json.loads(result.content) == {"answer": 42} + + +class TestProcessFunctionToolResult: + """Tests for process_function_tool_result.""" + + def test_records_function_tool_result( + self, turn_state: AgentTurnAccumulator + ) -> None: + """Test function tool result is recorded and marks round pending.""" + part = ToolReturnPart( + tool_name="fn", + content="ok", + tool_call_id="result-record", + ) + + result = process_function_tool_result(turn_state, part) + + assert result is not None + assert turn_state.turn_summary.tool_results == [result] + assert turn_state.round_increment_pending + assert "result-record" in turn_state.emitted_tool_result_ids + + def test_skips_duplicate_result(self, turn_state: AgentTurnAccumulator) -> None: + """Test duplicate function tool result id is ignored.""" + part = ToolReturnPart(tool_name="fn", content="ok", tool_call_id="result-dup") + process_function_tool_result(turn_state, part) + + assert process_function_tool_result(turn_state, part) is None + assert len(turn_state.turn_summary.tool_results) == 1 + + +class TestBuildReferencedDocument: + """Tests for build_referenced_document.""" + + def test_returns_none_without_title_or_url(self) -> None: + """Test result without title or URL metadata is skipped.""" + result = _file_search_result(attributes={"document_id": "only-id"}) + + assert build_referenced_document(result, ["vs-001"], {}) is None + + def test_builds_from_url_and_title_with_source_mapping(self) -> None: + """Test referenced document resolves source from vector store mapping.""" + result = _file_search_result( + attributes={ + "link": "https://example.com/doc", + "title": "Example Doc", + "document_id": "doc-1", + } + ) + + doc = build_referenced_document(result, ["vs-001"], {"vs-001": "mapped-source"}) + + assert doc is not None + assert doc.doc_url == AnyUrl("https://example.com/doc") + assert doc.doc_title == "Example Doc" + assert doc.document_id == "doc-1" + assert doc.source == "mapped-source" + + def test_supports_alternate_url_and_id_keys(self) -> None: + """Test doc_url and doc_id attribute key fallbacks.""" + result = _file_search_result( + attributes={ + "docs_url": "https://example.com/alt", + "title": "Alt Doc", + "doc_id": "alt-id", + } + ) + + doc = build_referenced_document(result, [], {}) + + assert doc is not None + assert doc.doc_url == AnyUrl("https://example.com/alt") + assert doc.document_id == "alt-id" + + def test_title_only_document(self) -> None: + """Test referenced document can be built with title only.""" + result = _file_search_result(attributes={"title": "Title Only"}) + + doc = build_referenced_document(result, [], {}) + + assert doc is not None + assert doc.doc_url is None + assert doc.doc_title == "Title Only" + + +class TestReferencedDocumentsFromFileSearchResults: + """Tests for referenced_documents_from_file_search_results.""" + + def test_deduplicates_documents(self) -> None: + """Test seen_docs prevents duplicate referenced documents.""" + results = [ + _file_search_result(attributes={"url": "https://dup.com", "title": "Same"}), + _file_search_result( + attributes={"link": "https://dup.com", "title": "Same"} + ), + _file_search_result( + attributes={"url": "https://other.com", "title": "Other"} + ), + _file_search_result(attributes={"document_id": "no-metadata"}), + ] + seen_docs: set[tuple[str, str]] = set() + + documents = referenced_documents_from_file_search_results( + results, seen_docs, ["vs-001"], {"vs-001": "source"} + ) + + assert len(documents) == 2 + assert len(seen_docs) == 2 + + +class TestRagChunksFromFileSearchResults: + """Tests for rag_chunks_from_file_search_results.""" + + def test_skips_empty_text_and_maps_source(self) -> None: + """Test chunks without text are skipped and source is resolved.""" + results = [ + _file_search_result(text="chunk one", score=0.8, attributes={}), + _file_search_result(text="", score=0.5, attributes={}), + ] + + chunks = rag_chunks_from_file_search_results( + results, ["vs-001"], {"vs-001": "mapped"} + ) + + assert len(chunks) == 1 + assert chunks[0].content == "chunk one" + assert chunks[0].source == "mapped" + assert chunks[0].score == 0.8 + + +class TestSummarizeWebSearchResult: + """Tests for summarize_web_search_result.""" + + def test_serializes_remaining_content(self) -> None: + """Test web search result keeps non-status fields as JSON content.""" + part = NativeToolReturnPart( + tool_name=WebSearchTool.kind, + tool_call_id="ws-result", + content={"status": "success", "results": [{"title": "hit"}]}, + ) + + result = summarize_web_search_result(part, tool_round=1) + + assert result.status == "success" + assert result.type == "web_search_call" + assert json.loads(result.content) == {"results": [{"title": "hit"}]} + + def test_empty_content_when_only_status(self) -> None: + """Test web search result content is empty when only status remains.""" + part = NativeToolReturnPart( + tool_name=WebSearchTool.kind, + tool_call_id="ws-empty", + content={"status": "success"}, + ) + + result = summarize_web_search_result(part, tool_round=2) + + assert not result.content + + +class TestSummarizeMcpResults: + """Tests for MCP tool result summarizers.""" + + def test_list_tools_success(self) -> None: + """Test MCP list-tools success payload is serialized.""" + part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-list", + content={ + "tools": [ + {"name": "tool_a", "description": "does things"}, + ] + }, + ) + + result = summarize_mcp_list_tools_result(part, tool_round=1) + + assert result.status == "success" + assert result.type == "mcp_list_tools" + payload = json.loads(result.content) + assert payload["server_label"] == "srv" + assert payload["tools"][0]["name"] == "tool_a" + + def test_list_tools_error(self) -> None: + """Test MCP list-tools error returns failure summary.""" + part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-list-err", + content={"error": "unavailable"}, + ) + + result = summarize_mcp_list_tools_result(part, tool_round=1) + + assert result.status == "failure" + assert result.content == "unavailable" + + def test_mcp_call_success_and_error(self) -> None: + """Test MCP call success and error summaries.""" + success_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-call-ok", + content={"output": "done"}, + ) + error_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-call-err", + content={"error": "failed"}, + ) + + success = summarize_mcp_call_result(success_part, tool_round=2) + error = summarize_mcp_call_result(error_part, tool_round=2) + + assert success.status == "success" + assert success.content == "done" + assert error.status == "failure" + assert error.content == "failed" + + def test_mcp_tool_result_dispatches_by_shape(self) -> None: + """Test summarize_mcp_tool_result routes list-tools vs call payloads.""" + list_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="dispatch-list", + content={"tools": [], "error": None}, + ) + call_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="dispatch-call", + content={"output": "ok", "error": None}, + ) + + list_result = summarize_mcp_tool_result(list_part, tool_round=1) + call_result = summarize_mcp_tool_result(call_part, tool_round=1) + + assert list_result.type == "mcp_list_tools" + assert call_result.type == "mcp_call" + + +class TestSummarizeFileSearchResult: + """Tests for summarize_file_search_result.""" + + def test_builds_tool_result_rag_chunks_and_referenced_docs(self) -> None: + """Test file-search return produces result, chunks, and referenced docs.""" + part = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="fs-result", + content={ + "status": "success", + "results": [ + { + "text": "chunk text", + "score": 0.95, + "attributes": { + "title": "Doc", + "url": "https://example.com", + }, + }, + {"text": "", "attributes": {}}, + ], + }, + ) + seen_docs: set[tuple[str, str]] = set() + + tool_result, rag_chunks, referenced_docs = summarize_file_search_result( + part, + tool_round=4, + seen_docs=seen_docs, + vector_store_ids=["vs-001"], + rag_id_mapping={"vs-001": "mapped"}, + ) + + assert tool_result.status == "success" + assert tool_result.type == "file_search_call" + assert tool_result.round == 4 + assert len(rag_chunks) == 1 + assert rag_chunks[0].content == "chunk text" + assert len(referenced_docs) == 1 + assert referenced_docs[0].doc_title == "Doc" + assert len(seen_docs) == 1 + + +class TestProcessNativeToolResult: + """Tests for process_native_tool_result.""" + + def test_records_file_search_result(self, turn_state: AgentTurnAccumulator) -> None: + """Test file-search result updates tool results, RAG chunks, and docs.""" + part = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="fs-process", + content={ + "status": "success", + "results": [ + { + "text": "rag", + "attributes": {"title": "RAG Doc", "url": "https://rag"}, + } + ], + }, + ) + + result = process_native_tool_result(turn_state, part) + + assert result is not None + assert turn_state.turn_summary.tool_results == [result] + assert len(turn_state.turn_summary.rag_chunks) == 1 + assert len(turn_state.turn_summary.referenced_documents) == 1 + assert turn_state.round_increment_pending + + def test_records_labeled_mcp_result(self, turn_state: AgentTurnAccumulator) -> None: + """Test labeled MCP tool return is processed like unlabeled MCP returns.""" + part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-labeled", + content={"output": "labeled-output"}, + ) + + result = process_native_tool_result(turn_state, part) + + assert result is not None + assert result.content == "labeled-output" + + def test_records_web_search_and_mcp_results( + self, turn_state: AgentTurnAccumulator + ) -> None: + """Test web search and MCP results are recorded on turn state.""" + web_part = NativeToolReturnPart( + tool_name=WebSearchTool.kind, + tool_call_id="ws-process", + content={"status": "success"}, + ) + mcp_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-process", + content={"output": "mcp-output"}, + ) + + web_result = process_native_tool_result(turn_state, web_part) + mcp_result = process_native_tool_result(turn_state, mcp_part) + + assert web_result is not None + assert mcp_result is not None + assert len(turn_state.turn_summary.tool_results) == 2 + + def test_skips_duplicate_and_unknown( + self, turn_state: AgentTurnAccumulator, mocker: MockerFixture + ) -> None: + """Test duplicate ids and unknown tool returns are ignored.""" + mocker.patch("utils.agents.tool_processor.logger.warning") + part = NativeToolReturnPart( + tool_name="unknown", + tool_call_id="unk-result", + content={"status": "success"}, + ) + + assert process_native_tool_result(turn_state, part) is None + + known = NativeToolReturnPart( + tool_name=WebSearchTool.kind, + tool_call_id="ws-dup-result", + content={"status": "success"}, + ) + process_native_tool_result(turn_state, known) + assert process_native_tool_result(turn_state, known) is None