Skip to content

Commit 0386bad

Browse files
committed
Added non-streaming agent query utilities
1 parent aa29526 commit 0386bad

2 files changed

Lines changed: 815 additions & 0 deletions

File tree

src/utils/agents/query.py

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
"""Non-streaming agent helpers and shared turn-summary builders for agent runs."""
2+
3+
from __future__ import annotations
4+
5+
from enum import Enum
6+
from typing import TypeAlias, cast
7+
8+
from fastapi import HTTPException
9+
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
10+
from pydantic_ai.exceptions import (
11+
AgentRunError,
12+
ContentFilterError,
13+
IncompleteToolCall,
14+
ModelAPIError,
15+
ModelHTTPError,
16+
UnexpectedModelBehavior,
17+
UsageLimitExceeded,
18+
)
19+
from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart
20+
from pydantic_ai.run import AgentRunResult
21+
from pydantic_ai.usage import RunUsage
22+
23+
from configuration import configuration
24+
from log import get_logger
25+
from metrics import recording
26+
from models.api.responses.error import (
27+
AbstractErrorResponse,
28+
InternalServerErrorResponse,
29+
PromptTooLongResponse,
30+
QuotaExceededResponse,
31+
ServiceUnavailableResponse,
32+
)
33+
from models.common.agents import AgentTurnAccumulator
34+
from models.common.moderation import ShieldModerationResult
35+
from models.common.responses.responses_api_params import ResponsesApiParams
36+
from models.common.turn_summary import TurnSummary
37+
from utils.agents.tool_processor import (
38+
process_function_tool_call,
39+
process_function_tool_result,
40+
process_native_tool_call,
41+
process_native_tool_result,
42+
)
43+
from utils.conversations import append_turn_items_to_conversation
44+
from utils.pydantic_ai import build_agent
45+
from utils.query import (
46+
extract_provider_and_model_from_model_id,
47+
handle_known_apistatus_errors,
48+
is_context_length_error,
49+
)
50+
from utils.responses import extract_vector_store_ids_from_tools
51+
from utils.token_counter import TokenCounter
52+
53+
logger = get_logger(__name__)
54+
55+
AgentInferenceError: TypeAlias = (
56+
AgentRunError | APIStatusError | APIConnectionError | RuntimeError
57+
)
58+
59+
60+
class AgentFinishReason(str, Enum):
61+
"""Finish reason for a completed agent model response."""
62+
63+
CONTENT_FILTER = "content_filter"
64+
CANCELLED = "cancelled"
65+
SUCCESS = "stop"
66+
LENGTH = "length"
67+
ERROR = "error"
68+
69+
70+
def map_agent_inference_error(
71+
exc: AgentInferenceError,
72+
model_id: str,
73+
) -> AbstractErrorResponse:
74+
"""Map agent run failures from pydantic-ai or Llama Stack to an LCS error response.
75+
76+
Args:
77+
exc: Agent, HTTP status, connection, or context-length runtime error.
78+
model_id: Model identifier in provider/model format.
79+
80+
Returns:
81+
Structured error response for HTTP or SSE error events.
82+
83+
Raises:
84+
RuntimeError: Re-raised when ``exc`` is a non-agent ``RuntimeError`` that is
85+
not a recognized context-length failure.
86+
"""
87+
if isinstance(exc, AgentRunError):
88+
return map_pydantic_agent_run_error(exc, model_id)
89+
if isinstance(exc, APIStatusError):
90+
return handle_known_apistatus_errors(exc, model_id)
91+
if isinstance(exc, APIConnectionError):
92+
return ServiceUnavailableResponse(
93+
backend_name="Llama Stack",
94+
cause=str(exc),
95+
)
96+
if isinstance(exc, RuntimeError) and is_context_length_error(str(exc)):
97+
return PromptTooLongResponse(model=model_id)
98+
return InternalServerErrorResponse.generic()
99+
100+
101+
def map_pydantic_agent_run_error(
102+
exc: AgentRunError, model_id: str
103+
) -> AbstractErrorResponse:
104+
"""Map pydantic-ai ``AgentRunError`` subclasses to LCS error responses.
105+
106+
Args:
107+
exc: Agent exception to map.
108+
model_id: Model identifier in provider/model format.
109+
110+
Returns:
111+
Structured error response for HTTP or SSE error events.
112+
"""
113+
if isinstance(exc, ContentFilterError):
114+
return InternalServerErrorResponse.query_failed(str(exc))
115+
if isinstance(exc, IncompleteToolCall):
116+
return PromptTooLongResponse(model=model_id)
117+
if isinstance(exc, UnexpectedModelBehavior):
118+
if is_context_length_error(str(exc)):
119+
return PromptTooLongResponse(model=model_id)
120+
return InternalServerErrorResponse.query_failed(str(exc))
121+
if isinstance(exc, UsageLimitExceeded):
122+
return QuotaExceededResponse.model(model_id)
123+
if isinstance(exc, ModelHTTPError):
124+
if is_context_length_error(str(exc)):
125+
return PromptTooLongResponse(model=model_id)
126+
if exc.status_code == 429:
127+
return QuotaExceededResponse.model(model_id)
128+
return InternalServerErrorResponse.generic()
129+
if isinstance(exc, ModelAPIError):
130+
return ServiceUnavailableResponse(
131+
backend_name="Llama Stack",
132+
cause=str(exc),
133+
)
134+
return InternalServerErrorResponse.query_failed(str(exc))
135+
136+
137+
def get_agent_finish_reason(response: ModelResponse) -> AgentFinishReason:
138+
"""Get the finish reason from a completed agent model response.
139+
140+
Args:
141+
response: Last model response from the agent run.
142+
143+
Returns:
144+
Resolved finish reason.
145+
"""
146+
raw_finish_reason = (response.provider_details or {}).get("finish_reason")
147+
if raw_finish_reason == "cancelled":
148+
return AgentFinishReason.CANCELLED
149+
if response.finish_reason is None:
150+
return AgentFinishReason.ERROR
151+
return AgentFinishReason(response.finish_reason)
152+
153+
154+
def get_finish_reason_error(
155+
finish_reason: AgentFinishReason,
156+
model_id: str,
157+
) -> AbstractErrorResponse:
158+
"""Map a non-success agent finish reason to an LCS error response.
159+
160+
Args:
161+
finish_reason: Resolved finish reason from :func:`get_agent_finish_reason`.
162+
model_id: Model identifier in provider/model format.
163+
164+
Returns:
165+
Structured error response for HTTP or SSE error events.
166+
"""
167+
match finish_reason:
168+
case AgentFinishReason.LENGTH:
169+
return PromptTooLongResponse(model=model_id)
170+
case AgentFinishReason.CONTENT_FILTER:
171+
return InternalServerErrorResponse.query_failed(
172+
"The model refused to generate a response due to content policy."
173+
)
174+
case AgentFinishReason.CANCELLED:
175+
return InternalServerErrorResponse.query_failed(
176+
"The response was cancelled before completion."
177+
)
178+
case _:
179+
return InternalServerErrorResponse.query_failed(
180+
"An unexpected error occurred while processing the request."
181+
)
182+
183+
184+
def extract_agent_token_usage(
185+
usage: RunUsage,
186+
model: str,
187+
endpoint_path: str,
188+
) -> TokenCounter:
189+
"""Build token usage for a completed agent run and record related metrics.
190+
191+
Args:
192+
usage: Run usage reported by the agent.
193+
model: Model identifier in provider/model format.
194+
endpoint_path: Endpoint path used for metric labeling.
195+
196+
Returns:
197+
Aggregated token usage counter for the run.
198+
"""
199+
provider_id, model_id = extract_provider_and_model_from_model_id(model)
200+
token_counter = TokenCounter(
201+
input_tokens=usage.input_tokens,
202+
output_tokens=usage.output_tokens,
203+
llm_calls=max(usage.requests, 1),
204+
)
205+
logger.debug(
206+
"Extracted token usage from agent run: input=%d, output=%d, requests=%d",
207+
token_counter.input_tokens,
208+
token_counter.output_tokens,
209+
usage.requests,
210+
)
211+
recording.record_llm_token_usage(
212+
provider_id,
213+
model_id,
214+
token_counter.input_tokens,
215+
token_counter.output_tokens,
216+
endpoint_path,
217+
)
218+
recording.record_llm_call(provider_id, model_id, endpoint_path)
219+
return token_counter
220+
221+
222+
def build_turn_summary_from_agent_run(
223+
run_result: AgentRunResult[str],
224+
*,
225+
model_id: str,
226+
endpoint_path: str,
227+
vector_store_ids: list[str],
228+
rag_id_mapping: dict[str, str],
229+
) -> TurnSummary:
230+
"""Build a turn summary from a completed agent run.
231+
232+
Args:
233+
run_result: Completed agent run result.
234+
model_id: Model identifier in provider/model format.
235+
endpoint_path: Endpoint path used for metric labeling.
236+
vector_store_ids: Vector store IDs used for source mapping.
237+
rag_id_mapping: Mapping from vector store IDs to user-facing source labels.
238+
239+
Returns:
240+
Turn summary with text, tools, RAG metadata, and token usage.
241+
242+
Raises:
243+
HTTPException: When the run failed.
244+
"""
245+
finish_reason = get_agent_finish_reason(run_result.response)
246+
if finish_reason != AgentFinishReason.SUCCESS:
247+
error_response = get_finish_reason_error(finish_reason, model_id)
248+
raise HTTPException(**error_response.model_dump())
249+
250+
state = AgentTurnAccumulator(
251+
vector_store_ids=vector_store_ids,
252+
rag_id_mapping=rag_id_mapping,
253+
turn_summary=TurnSummary(),
254+
)
255+
256+
for message in run_result.new_messages():
257+
if isinstance(message, ModelResponse):
258+
if message.text:
259+
state.turn_summary.llm_response = message.text
260+
for tool_call_part in message.tool_calls:
261+
process_function_tool_call(state, tool_call_part)
262+
for call_part, return_part in message.native_tool_calls:
263+
process_native_tool_call(state, call_part)
264+
process_native_tool_result(state, return_part)
265+
elif isinstance(message, ModelRequest):
266+
for request_part in message.parts:
267+
if isinstance(request_part, ToolReturnPart):
268+
process_function_tool_result(state, request_part)
269+
270+
state.turn_summary.id = run_result.response.provider_response_id or ""
271+
state.turn_summary.token_usage = extract_agent_token_usage(
272+
run_result.usage,
273+
model_id,
274+
endpoint_path,
275+
)
276+
return state.turn_summary
277+
278+
279+
async def retrieve_agent_response(
280+
client: AsyncLlamaStackClient,
281+
responses_params: ResponsesApiParams,
282+
moderation_result: ShieldModerationResult,
283+
endpoint_path: str,
284+
) -> TurnSummary:
285+
"""Retrieve a turn summary from a blocking agent run.
286+
287+
Mirrors :func:`app.endpoints.query.retrieve_response` for the agent path.
288+
289+
Args:
290+
client: Llama Stack client for conversation persistence on moderation block.
291+
responses_params: Prepared Responses API parameters.
292+
moderation_result: Shield moderation outcome for the turn.
293+
endpoint_path: Endpoint path used for metric labeling.
294+
295+
Returns:
296+
Turn summary for the completed agent run.
297+
298+
Raises:
299+
HTTPException: On moderation is not applicable; on agent or provider failure.
300+
"""
301+
if moderation_result.decision == "blocked":
302+
await append_turn_items_to_conversation(
303+
client,
304+
responses_params.conversation,
305+
responses_params.input,
306+
[moderation_result.refusal_response],
307+
)
308+
return TurnSummary(
309+
id=moderation_result.moderation_id,
310+
llm_response=moderation_result.message,
311+
)
312+
try:
313+
agent = build_agent(client, responses_params)
314+
logger.debug("Starting agent non-streaming response processing")
315+
run_result = await agent.run(cast(str, responses_params.input))
316+
except (AgentRunError, APIStatusError, APIConnectionError, RuntimeError) as exc:
317+
response = map_agent_inference_error(exc, responses_params.model)
318+
raise HTTPException(**response.model_dump()) from exc
319+
320+
vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
321+
rag_id_mapping = configuration.rag_id_mapping
322+
return build_turn_summary_from_agent_run(
323+
run_result,
324+
model_id=responses_params.model,
325+
endpoint_path=endpoint_path,
326+
vector_store_ids=vector_store_ids,
327+
rag_id_mapping=rag_id_mapping,
328+
)

0 commit comments

Comments
 (0)