diff --git a/src/utils/agents/query.py b/src/utils/agents/query.py new file mode 100644 index 000000000..0322c43ad --- /dev/null +++ b/src/utils/agents/query.py @@ -0,0 +1,328 @@ +"""Non-streaming agent helpers and shared turn-summary builders for agent runs.""" + +from __future__ import annotations + +from enum import Enum +from typing import TypeAlias, cast + +from fastapi import HTTPException +from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient +from pydantic_ai.exceptions import ( + AgentRunError, + ContentFilterError, + IncompleteToolCall, + ModelAPIError, + ModelHTTPError, + UnexpectedModelBehavior, + UsageLimitExceeded, +) +from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart +from pydantic_ai.run import AgentRunResult +from pydantic_ai.usage import RunUsage + +from configuration import configuration +from log import get_logger +from metrics import recording +from models.api.responses.error import ( + AbstractErrorResponse, + InternalServerErrorResponse, + PromptTooLongResponse, + QuotaExceededResponse, + ServiceUnavailableResponse, +) +from models.common.agents import AgentTurnAccumulator +from models.common.moderation import ShieldModerationResult +from models.common.responses.responses_api_params import ResponsesApiParams +from models.common.turn_summary import TurnSummary +from utils.agents.tool_processor import ( + process_function_tool_call, + process_function_tool_result, + process_native_tool_call, + process_native_tool_result, +) +from utils.conversations import append_turn_items_to_conversation +from utils.pydantic_ai import build_agent +from utils.query import ( + extract_provider_and_model_from_model_id, + handle_known_apistatus_errors, + is_context_length_error, +) +from utils.responses import extract_vector_store_ids_from_tools +from utils.token_counter import TokenCounter + +logger = get_logger(__name__) + +AgentInferenceError: TypeAlias = ( + AgentRunError | APIStatusError | APIConnectionError | RuntimeError +) + + +class AgentFinishReason(str, Enum): + """Finish reason for a completed agent model response.""" + + CONTENT_FILTER = "content_filter" + CANCELLED = "cancelled" + SUCCESS = "stop" + LENGTH = "length" + ERROR = "error" + + +def map_agent_inference_error( + exc: AgentInferenceError, + model_id: str, +) -> AbstractErrorResponse: + """Map agent run failures from pydantic-ai or Llama Stack to an LCS error response. + + Args: + exc: Agent, HTTP status, connection, or context-length runtime error. + model_id: Model identifier in provider/model format. + + Returns: + Structured error response for HTTP or SSE error events. + + Raises: + RuntimeError: Re-raised when ``exc`` is a non-agent ``RuntimeError`` that is + not a recognized context-length failure. + """ + match exc: + case AgentRunError() as agent_exc: + return map_pydantic_agent_run_error(agent_exc, model_id) + case APIStatusError() as status_exc: + return handle_known_apistatus_errors(status_exc, model_id) + case APIConnectionError() as connection_exc: + return ServiceUnavailableResponse( + backend_name="Llama Stack", + cause=str(connection_exc), + ) + case RuntimeError() as runtime_exc if is_context_length_error(str(runtime_exc)): + return PromptTooLongResponse(model=model_id) + case _: + return InternalServerErrorResponse.generic() + + +def map_pydantic_agent_run_error( + exc: AgentRunError, model_id: str +) -> AbstractErrorResponse: + """Map pydantic-ai ``AgentRunError`` subclasses to LCS error responses. + + Args: + exc: Agent exception to map. + model_id: Model identifier in provider/model format. + + Returns: + Structured error response for HTTP or SSE error events. + """ + match exc: + case ContentFilterError() as filter_exc: + return InternalServerErrorResponse.query_failed(str(filter_exc)) + case IncompleteToolCall() | UnexpectedModelBehavior(): + return PromptTooLongResponse(model=model_id) + case UsageLimitExceeded(): + return QuotaExceededResponse.model(model_id) + case ModelHTTPError() as http_exc if is_context_length_error(str(http_exc)): + return PromptTooLongResponse(model=model_id) + case ModelHTTPError(status_code=429): + return QuotaExceededResponse.model(model_id) + case ModelHTTPError(): + return InternalServerErrorResponse.generic() + case ModelAPIError() as api_exc: + return ServiceUnavailableResponse( + backend_name="Llama Stack", + cause=str(api_exc), + ) + case _: + return InternalServerErrorResponse.query_failed(str(exc)) + + +def get_agent_finish_reason(response: ModelResponse) -> AgentFinishReason: + """Get the finish reason from a completed agent model response. + + Args: + response: Last model response from the agent run. + + Returns: + Resolved finish reason. + """ + raw_finish_reason = (response.provider_details or {}).get("finish_reason") + if raw_finish_reason == "cancelled": + return AgentFinishReason.CANCELLED + if response.finish_reason is None: + return AgentFinishReason.ERROR + return AgentFinishReason(response.finish_reason) + + +def get_finish_reason_error( + finish_reason: AgentFinishReason, + model_id: str, +) -> AbstractErrorResponse: + """Map a non-success agent finish reason to an LCS error response. + + Args: + finish_reason: Resolved finish reason from :func:`get_agent_finish_reason`. + model_id: Model identifier in provider/model format. + + Returns: + Structured error response for HTTP or SSE error events. + """ + match finish_reason: + case AgentFinishReason.LENGTH: + return PromptTooLongResponse(model=model_id) + case AgentFinishReason.CONTENT_FILTER: + return InternalServerErrorResponse.query_failed( + "The model refused to generate a response due to content policy." + ) + case AgentFinishReason.CANCELLED: + return InternalServerErrorResponse.query_failed( + "The response was cancelled before completion." + ) + case _: + return InternalServerErrorResponse.query_failed( + "An unexpected error occurred while processing the request." + ) + + +def extract_agent_token_usage( + usage: RunUsage, + model: str, + endpoint_path: str, +) -> TokenCounter: + """Build token usage for a completed agent run and record related metrics. + + Args: + usage: Run usage reported by the agent. + model: Model identifier in provider/model format. + endpoint_path: Endpoint path used for metric labeling. + + Returns: + Aggregated token usage counter for the run. + """ + provider_id, model_id = extract_provider_and_model_from_model_id(model) + token_counter = TokenCounter( + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + llm_calls=max(usage.requests, 1), + ) + logger.debug( + "Extracted token usage from agent run: input=%d, output=%d, requests=%d", + token_counter.input_tokens, + token_counter.output_tokens, + usage.requests, + ) + recording.record_llm_token_usage( + provider_id, + model_id, + token_counter.input_tokens, + token_counter.output_tokens, + endpoint_path, + ) + recording.record_llm_call(provider_id, model_id, endpoint_path) + return token_counter + + +def build_turn_summary_from_agent_run( + run_result: AgentRunResult[str], + *, + model_id: str, + endpoint_path: str, + vector_store_ids: list[str], + rag_id_mapping: dict[str, str], +) -> TurnSummary: + """Build a turn summary from a completed agent run. + + Args: + run_result: Completed agent run result. + model_id: Model identifier in provider/model format. + endpoint_path: Endpoint path used for metric labeling. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + Turn summary with text, tools, RAG metadata, and token usage. + + Raises: + HTTPException: When the run failed. + """ + finish_reason = get_agent_finish_reason(run_result.response) + if finish_reason != AgentFinishReason.SUCCESS: + error_response = get_finish_reason_error(finish_reason, model_id) + raise HTTPException(**error_response.model_dump()) + + state = AgentTurnAccumulator( + vector_store_ids=vector_store_ids, + rag_id_mapping=rag_id_mapping, + turn_summary=TurnSummary(), + ) + + for message in run_result.new_messages(): + if isinstance(message, ModelResponse): + if message.text: + state.turn_summary.llm_response = message.text + for tool_call_part in message.tool_calls: + process_function_tool_call(state, tool_call_part) + for call_part, return_part in message.native_tool_calls: + process_native_tool_call(state, call_part) + process_native_tool_result(state, return_part) + elif isinstance(message, ModelRequest): + for request_part in message.parts: + if isinstance(request_part, ToolReturnPart): + process_function_tool_result(state, request_part) + + state.turn_summary.id = run_result.response.provider_response_id or "" + state.turn_summary.token_usage = extract_agent_token_usage( + run_result.usage, + model_id, + endpoint_path, + ) + return state.turn_summary + + +async def retrieve_agent_response( + client: AsyncLlamaStackClient, + responses_params: ResponsesApiParams, + moderation_result: ShieldModerationResult, + endpoint_path: str, +) -> TurnSummary: + """Retrieve a turn summary from a blocking agent run. + + Mirrors :func:`app.endpoints.query.retrieve_response` for the agent path. + + Args: + client: Llama Stack client for conversation persistence on moderation block. + responses_params: Prepared Responses API parameters. + moderation_result: Shield moderation outcome for the turn. + endpoint_path: Endpoint path used for metric labeling. + + Returns: + Turn summary for the completed agent run. + + Raises: + HTTPException: On moderation is not applicable; on agent or provider failure. + """ + if moderation_result.decision == "blocked": + await append_turn_items_to_conversation( + client, + responses_params.conversation, + responses_params.input, + [moderation_result.refusal_response], + ) + return TurnSummary( + id=moderation_result.moderation_id, + llm_response=moderation_result.message, + ) + try: + agent = build_agent(client, responses_params) + logger.debug("Starting agent non-streaming response processing") + run_result = await agent.run(cast(str, responses_params.input)) + except (AgentRunError, APIStatusError, APIConnectionError, RuntimeError) as exc: + response = map_agent_inference_error(exc, responses_params.model) + raise HTTPException(**response.model_dump()) from exc + + vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools) + rag_id_mapping = configuration.rag_id_mapping + return build_turn_summary_from_agent_run( + run_result, + model_id=responses_params.model, + endpoint_path=endpoint_path, + vector_store_ids=vector_store_ids, + rag_id_mapping=rag_id_mapping, + ) diff --git a/tests/unit/utils/agents/test_query.py b/tests/unit/utils/agents/test_query.py new file mode 100644 index 000000000..6baa1380f --- /dev/null +++ b/tests/unit/utils/agents/test_query.py @@ -0,0 +1,487 @@ +"""Unit tests for utils.agents.query module.""" + +from collections.abc import Callable +from typing import Any + +import pytest +from fastapi import HTTPException +from llama_stack_api.openai_responses import ( + OpenAIResponseMessage as ResponseMessage, +) +from llama_stack_client import APIConnectionError, APIStatusError +from pydantic_ai.messages import ( + FinishReason, + ModelRequest, + ModelResponse, + NativeToolCallPart, + NativeToolReturnPart, + TextPart, + ToolCallPart, + ToolReturnPart, +) +from pydantic_ai.native_tools import FileSearchTool +from pydantic_ai.usage import RunUsage +from pytest_mock import MockerFixture + +from constants import ENDPOINT_PATH_QUERY +from models.common.moderation import ShieldModerationBlocked, ShieldModerationPassed +from models.common.responses.responses_api_params import ResponsesApiParams +from models.common.responses.types import ResponseInput +from models.common.turn_summary import TurnSummary +from utils.agents.query import ( + AgentFinishReason, + build_turn_summary_from_agent_run, + extract_agent_token_usage, + get_agent_finish_reason, + retrieve_agent_response, +) +from utils.token_counter import TokenCounter + + +@pytest.fixture(name="make_agent_run_result") +def make_agent_run_result_fixture(mocker: MockerFixture) -> Callable[..., Any]: + """Return a factory that builds mock AgentRunResult objects.""" + + def _make( + *, + content: str = "Hello from the agent.", + response_id: str = "response-123", + input_tokens: int = 10, + output_tokens: int = 5, + finish_reason: FinishReason | None = "stop", + model_response: ModelResponse | None = None, + new_messages: list[Any] | None = None, + ) -> Any: + if model_response is None: + model_response = ModelResponse( + parts=[TextPart(content=content)], + finish_reason=finish_reason, + provider_response_id=response_id, + ) + messages = new_messages if new_messages is not None else [model_response] + run_result = mocker.MagicMock() + run_result.response = model_response + run_result.usage = RunUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + requests=1, + ) + run_result.new_messages.return_value = messages + return run_result + + return _make + + +@pytest.fixture(name="make_responses_params") +def make_responses_params_fixture() -> Callable[..., ResponsesApiParams]: + """Return a factory that builds ResponsesApiParams for agent query tests.""" + + def _make( + *, + model: str = "provider1/model1", + input_text: ResponseInput = "What is OpenShift?", + conversation: str | None = "conv_123", + ) -> ResponsesApiParams: + return ResponsesApiParams.model_validate( + { + "model": model, + "input": input_text, + "conversation": conversation, + "stream": False, + "store": True, + } + ) + + return _make + + +@pytest.fixture(name="responses_params") +def responses_params_fixture( + make_responses_params: Callable[..., ResponsesApiParams], +) -> ResponsesApiParams: + """Default ResponsesApiParams for agent query tests.""" + return make_responses_params() + + +@pytest.fixture(name="blocked_moderation") +def blocked_moderation_fixture() -> ShieldModerationBlocked: + """Blocked shield moderation result for tests.""" + return ShieldModerationBlocked( + message="Content blocked by shield.", + moderation_id="modr-test-456", + refusal_response=ResponseMessage( + role="assistant", + content="Content blocked by shield.", + ), + ) + + +@pytest.fixture(name="patch_recording_metrics") +def patch_recording_metrics_fixture(mocker: MockerFixture) -> None: + """Patch LLM recording helpers so token usage tests stay isolated.""" + mock_config = mocker.MagicMock() + mock_config.rag_id_mapping = {} + mocker.patch("utils.agents.query.configuration", mock_config) + mocker.patch( + "utils.agents.query.extract_vector_store_ids_from_tools", + return_value=[], + ) + mocker.patch("utils.agents.query.recording.record_llm_token_usage") + mocker.patch("utils.agents.query.recording.record_llm_call") + + +class TestGetAgentFinishReason: + """Tests for get_agent_finish_reason.""" + + def test_returns_success_for_stop(self) -> None: + """Test a normal stop finish reason maps to SUCCESS.""" + response = ModelResponse( + parts=[TextPart(content="done")], + finish_reason="stop", + ) + + assert get_agent_finish_reason(response) == AgentFinishReason.SUCCESS + + def test_returns_length_finish_reason(self) -> None: + """Test length finish reason is preserved.""" + response = ModelResponse( + parts=[TextPart(content="truncated")], + finish_reason="length", + ) + + assert get_agent_finish_reason(response) == AgentFinishReason.LENGTH + + def test_returns_content_filter_finish_reason(self) -> None: + """Test content_filter finish reason is preserved.""" + response = ModelResponse( + parts=[TextPart(content="")], + finish_reason="content_filter", + ) + + assert get_agent_finish_reason(response) == AgentFinishReason.CONTENT_FILTER + + def test_returns_error_when_finish_reason_missing(self) -> None: + """Test missing finish reason maps to ERROR.""" + response = ModelResponse( + parts=[TextPart(content="partial")], + finish_reason=None, + ) + + assert get_agent_finish_reason(response) == AgentFinishReason.ERROR + + def test_returns_cancelled_from_provider_details(self) -> None: + """Test provider_details cancelled overrides the model finish reason.""" + response = ModelResponse( + parts=[TextPart(content="partial")], + finish_reason="stop", + provider_details={"finish_reason": "cancelled"}, + ) + + assert get_agent_finish_reason(response) == AgentFinishReason.CANCELLED + + +class TestExtractAgentTokenUsage: + """Tests for extract_agent_token_usage.""" + + def test_builds_token_counter_from_usage( + self, patch_recording_metrics: None + ) -> None: + """Test token usage is mapped from RunUsage to TokenCounter.""" + usage = RunUsage(input_tokens=42, output_tokens=17, requests=2) + + counter = extract_agent_token_usage( + usage, + model="provider1/model1", + endpoint_path=ENDPOINT_PATH_QUERY, + ) + + assert counter == TokenCounter( + input_tokens=42, + output_tokens=17, + llm_calls=2, + ) + + def test_llm_calls_minimum_one_when_requests_zero( + self, patch_recording_metrics: None + ) -> None: + """Test llm_calls is at least 1 when the agent reports zero requests.""" + usage = RunUsage(input_tokens=1, output_tokens=1, requests=0) + + counter = extract_agent_token_usage( + usage, + model="provider1/model1", + endpoint_path=ENDPOINT_PATH_QUERY, + ) + + assert counter.llm_calls == 1 + + def test_records_metrics(self, mocker: MockerFixture) -> None: + """Test LLM token usage and call metrics are recorded.""" + mock_record_usage = mocker.patch( + "utils.agents.query.recording.record_llm_token_usage" + ) + mock_record_call = mocker.patch("utils.agents.query.recording.record_llm_call") + usage = RunUsage(input_tokens=8, output_tokens=3, requests=1) + + extract_agent_token_usage( + usage, + model="my-provider/my-model", + endpoint_path=ENDPOINT_PATH_QUERY, + ) + + mock_record_usage.assert_called_once_with( + "my-provider", + "my-model", + 8, + 3, + ENDPOINT_PATH_QUERY, + ) + mock_record_call.assert_called_once_with( + "my-provider", + "my-model", + ENDPOINT_PATH_QUERY, + ) + + +class TestBuildTurnSummaryFromAgentRun: + """Tests for build_turn_summary_from_agent_run.""" + + def test_builds_summary_from_text_response( + self, + make_agent_run_result: Callable[..., Any], + patch_recording_metrics: None, + ) -> None: + """Test a successful agent run produces text, tools, id, and token usage.""" + llm_text = "OpenShift is a Kubernetes distribution." + file_search_call = NativeToolCallPart( + tool_name=FileSearchTool.kind, + args={"queries": ["OpenShift"]}, + tool_call_id="fs-1", + ) + file_search_return = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="fs-1", + content={"status": "success", "results": []}, + ) + function_call = ToolCallPart( + tool_name="calculate", + args={"x": 1}, + tool_call_id="fn-1", + ) + function_return = ToolReturnPart( + tool_name="calculate", + content={"result": 2}, + tool_call_id="fn-1", + ) + model_response = ModelResponse( + parts=[ + file_search_call, + file_search_return, + function_call, + TextPart(llm_text), + ], + finish_reason="stop", + provider_response_id="resp-agent-1", + ) + run_result = make_agent_run_result( + input_tokens=20, + output_tokens=12, + model_response=model_response, + new_messages=[model_response, ModelRequest(parts=[function_return])], + ) + + summary = build_turn_summary_from_agent_run( + run_result, + model_id="provider1/model1", + endpoint_path=ENDPOINT_PATH_QUERY, + vector_store_ids=[], + rag_id_mapping={}, + ) + + assert summary.llm_response == llm_text + assert summary.id == "resp-agent-1" + assert summary.token_usage.input_tokens == 20 + assert summary.token_usage.output_tokens == 12 + assert summary.token_usage.llm_calls == 1 + assert len(summary.tool_calls) == 2 + assert {call.name for call in summary.tool_calls} == { + "calculate", + "file_search", + } + assert len(summary.tool_results) == 2 + assert {result.type for result in summary.tool_results} == { + "file_search_call", + "function_call_output", + } + + def test_raises_http_exception_on_length_finish_reason( + self, + make_agent_run_result: Callable[..., Any], + ) -> None: + """Test non-success finish reason raises HTTPException.""" + run_result = make_agent_run_result(finish_reason="length") + + with pytest.raises(HTTPException) as exc_info: + build_turn_summary_from_agent_run( + run_result, + model_id="provider1/model1", + endpoint_path=ENDPOINT_PATH_QUERY, + vector_store_ids=[], + rag_id_mapping={}, + ) + + assert exc_info.value.status_code == 413 + + def test_raises_http_exception_on_missing_finish_reason( + self, + make_agent_run_result: Callable[..., Any], + ) -> None: + """Test missing finish reason is treated as an error.""" + run_result = make_agent_run_result( + content="partial", + response_id="resp-error", + finish_reason=None, + ) + + with pytest.raises(HTTPException) as exc_info: + build_turn_summary_from_agent_run( + run_result, + model_id="provider1/model1", + endpoint_path=ENDPOINT_PATH_QUERY, + vector_store_ids=[], + rag_id_mapping={}, + ) + + assert exc_info.value.status_code == 500 + + +class TestRetrieveAgentResponse: + """Tests for retrieve_agent_response.""" + + @pytest.mark.asyncio + async def test_blocked_moderation_returns_refusal_summary( + self, + mocker: MockerFixture, + responses_params: ResponsesApiParams, + blocked_moderation: ShieldModerationBlocked, + ) -> None: + """Test blocked moderation persists refusal and returns a turn summary.""" + mock_client = mocker.AsyncMock() + mock_append = mocker.patch( + "utils.agents.query.append_turn_items_to_conversation", + new=mocker.AsyncMock(), + ) + + summary = await retrieve_agent_response( + client=mock_client, + responses_params=responses_params, + moderation_result=blocked_moderation, + endpoint_path=ENDPOINT_PATH_QUERY, + ) + + mock_append.assert_awaited_once_with( + mock_client, + responses_params.conversation, + responses_params.input, + [blocked_moderation.refusal_response], + ) + assert summary == TurnSummary( + id="modr-test-456", + llm_response="Content blocked by shield.", + ) + + @pytest.mark.asyncio + async def test_success_returns_turn_summary( + self, + mocker: MockerFixture, + make_agent_run_result: Callable[..., Any], + make_responses_params: Callable[..., ResponsesApiParams], + patch_recording_metrics: None, + ) -> None: + """Test a successful agent run returns a populated turn summary.""" + run_result = make_agent_run_result( + content="Hello!", + response_id="resp-success", + ) + mock_agent = mocker.AsyncMock() + mock_agent.run = mocker.AsyncMock(return_value=run_result) + mocker.patch( + "utils.agents.query.build_agent", + return_value=mock_agent, + ) + + summary = await retrieve_agent_response( + client=mocker.AsyncMock(), + responses_params=make_responses_params(input_text="Say hello"), + moderation_result=ShieldModerationPassed(), + endpoint_path=ENDPOINT_PATH_QUERY, + ) + + mock_agent.run.assert_awaited_once_with("Say hello") + assert summary.llm_response == "Hello!" + assert summary.id == "resp-success" + + @pytest.mark.asyncio + async def test_agent_connection_error_raises_http_exception( + self, + mocker: MockerFixture, + responses_params: ResponsesApiParams, + ) -> None: + """Test Llama Stack connection errors are mapped to HTTPException.""" + mock_agent = mocker.AsyncMock() + mock_agent.run = mocker.AsyncMock( + side_effect=APIConnectionError(request=mocker.Mock()) + ) + mocker.patch( + "utils.agents.query.build_agent", + return_value=mock_agent, + ) + + with pytest.raises(HTTPException) as exc_info: + await retrieve_agent_response( + client=mocker.AsyncMock(), + responses_params=responses_params, + moderation_result=ShieldModerationPassed(), + endpoint_path=ENDPOINT_PATH_QUERY, + ) + + assert exc_info.value.status_code == 503 + + @pytest.mark.asyncio + async def test_api_status_error_raises_http_exception( + self, + mocker: MockerFixture, + responses_params: ResponsesApiParams, + ) -> None: + """Test API status errors from the agent run are mapped to HTTPException.""" + mock_agent = mocker.AsyncMock() + mock_agent.run = mocker.AsyncMock( + side_effect=APIStatusError( + message="quota exceeded", + response=mocker.Mock(), + body=None, + ) + ) + mocker.patch( + "utils.agents.query.build_agent", + return_value=mock_agent, + ) + mock_error = mocker.Mock() + mock_error.model_dump.return_value = { + "status_code": 429, + "detail": {"response": "Quota exceeded", "cause": "quota exceeded"}, + } + mocker.patch( + "utils.agents.query.handle_known_apistatus_errors", + return_value=mock_error, + ) + + with pytest.raises(HTTPException) as exc_info: + await retrieve_agent_response( + client=mocker.AsyncMock(), + responses_params=responses_params, + moderation_result=ShieldModerationPassed(), + endpoint_path=ENDPOINT_PATH_QUERY, + ) + + assert exc_info.value.status_code == 429