Skip to content

Commit dd950ae

Browse files
committed
Added non-streaming agent query utilities
1 parent aa29526 commit dd950ae

2 files changed

Lines changed: 813 additions & 0 deletions

File tree

src/utils/agents/query.py

Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
"""Non-streaming agent helpers and shared turn-summary builders for agent runs."""
2+
3+
from __future__ import annotations
4+
5+
from enum import Enum
6+
from typing import TypeAlias, cast
7+
8+
from fastapi import HTTPException
9+
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
10+
from pydantic_ai.exceptions import (
11+
AgentRunError,
12+
ContentFilterError,
13+
IncompleteToolCall,
14+
ModelAPIError,
15+
ModelHTTPError,
16+
UnexpectedModelBehavior,
17+
UsageLimitExceeded,
18+
)
19+
from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart
20+
from pydantic_ai.run import AgentRunResult
21+
from pydantic_ai.usage import RunUsage
22+
23+
from configuration import configuration
24+
from log import get_logger
25+
from metrics import recording
26+
from models.api.responses.error import (
27+
AbstractErrorResponse,
28+
InternalServerErrorResponse,
29+
PromptTooLongResponse,
30+
QuotaExceededResponse,
31+
ServiceUnavailableResponse,
32+
)
33+
from models.common.agents import AgentTurnAccumulator
34+
from models.common.moderation import ShieldModerationResult
35+
from models.common.responses.responses_api_params import ResponsesApiParams
36+
from models.common.turn_summary import TurnSummary
37+
from utils.agents.tool_processor import (
38+
process_function_tool_call,
39+
process_function_tool_result,
40+
process_native_tool_call,
41+
process_native_tool_result,
42+
)
43+
from utils.conversations import append_turn_items_to_conversation
44+
from utils.pydantic_ai import build_agent
45+
from utils.query import (
46+
extract_provider_and_model_from_model_id,
47+
handle_known_apistatus_errors,
48+
is_context_length_error,
49+
)
50+
from utils.responses import extract_vector_store_ids_from_tools
51+
from utils.token_counter import TokenCounter
52+
53+
# Module-level logger, named after this module.
logger = get_logger(__name__)


# Union of the exception types accepted by ``map_agent_inference_error``.
# It is only usable as an annotation: ``except`` clauses need a tuple of
# classes, which is why ``retrieve_agent_response`` spells the types out again.
AgentInferenceError: TypeAlias = (
    AgentRunError | APIStatusError | APIConnectionError | RuntimeError
)
58+
59+
60+
class AgentFinishReason(str, Enum):
    """Finish reason for a completed agent model response.

    Members are plain strings (``str`` mixin), so they compare equal to their
    raw values. ``get_agent_finish_reason`` builds members directly from
    ``ModelResponse.finish_reason``; ``SUCCESS`` is the only member that
    ``build_turn_summary_from_agent_run`` treats as a successful run.
    """

    # The response was stopped by content filtering.
    CONTENT_FILTER = "content_filter"
    # Set when the provider details report a "cancelled" finish reason.
    CANCELLED = "cancelled"
    # Normal completion; note the wire value is "stop", not "success".
    SUCCESS = "stop"
    # Generation stopped because of a length limit.
    LENGTH = "length"
    # Explicit error, also used when the response carries no finish reason.
    ERROR = "error"
68+
69+
70+
def map_agent_inference_error(
71+
exc: AgentInferenceError,
72+
model_id: str,
73+
) -> AbstractErrorResponse:
74+
"""Map agent run failures from pydantic-ai or Llama Stack to an LCS error response.
75+
76+
Args:
77+
exc: Agent, HTTP status, connection, or context-length runtime error.
78+
model_id: Model identifier in provider/model format.
79+
80+
Returns:
81+
Structured error response for HTTP or SSE error events.
82+
83+
Raises:
84+
RuntimeError: Re-raised when ``exc`` is a non-agent ``RuntimeError`` that is
85+
not a recognized context-length failure.
86+
"""
87+
if isinstance(exc, AgentRunError):
88+
return map_pydantic_agent_run_error(exc, model_id)
89+
if isinstance(exc, APIStatusError):
90+
return handle_known_apistatus_errors(exc, model_id)
91+
if isinstance(exc, APIConnectionError):
92+
return ServiceUnavailableResponse(
93+
backend_name="Llama Stack",
94+
cause=str(exc),
95+
)
96+
if isinstance(exc, RuntimeError) and is_context_length_error(str(exc)):
97+
return PromptTooLongResponse(model=model_id)
98+
return InternalServerErrorResponse.generic()
99+
100+
101+
def map_pydantic_agent_run_error(
102+
exc: AgentRunError, model_id: str
103+
) -> AbstractErrorResponse:
104+
"""Map pydantic-ai ``AgentRunError`` subclasses to LCS error responses.
105+
106+
Args:
107+
exc: Agent exception to map.
108+
model_id: Model identifier in provider/model format.
109+
110+
Returns:
111+
Structured error response for HTTP or SSE error events.
112+
"""
113+
if isinstance(exc, ContentFilterError):
114+
return InternalServerErrorResponse.query_failed(str(exc))
115+
if isinstance(exc, IncompleteToolCall):
116+
return PromptTooLongResponse(model=model_id)
117+
if isinstance(exc, UnexpectedModelBehavior):
118+
return PromptTooLongResponse(model=model_id)
119+
if isinstance(exc, UsageLimitExceeded):
120+
return QuotaExceededResponse.model(model_id)
121+
if isinstance(exc, ModelHTTPError):
122+
if is_context_length_error(str(exc)):
123+
return PromptTooLongResponse(model=model_id)
124+
if exc.status_code == 429:
125+
return QuotaExceededResponse.model(model_id)
126+
return InternalServerErrorResponse.generic()
127+
if isinstance(exc, ModelAPIError):
128+
return ServiceUnavailableResponse(
129+
backend_name="Llama Stack",
130+
cause=str(exc),
131+
)
132+
return InternalServerErrorResponse.query_failed(str(exc))
133+
134+
135+
def get_agent_finish_reason(response: ModelResponse) -> AgentFinishReason:
    """Get the finish reason from a completed agent model response.

    Resolution order:

    1. A ``"cancelled"`` finish reason in ``response.provider_details`` wins
       over everything else.
    2. A missing ``response.finish_reason`` is reported as an error.
    3. Otherwise ``response.finish_reason`` is converted to the enum; values
       the enum does not declare are reported as an error rather than
       raising.

    Args:
        response: Last model response from the agent run.

    Returns:
        Resolved finish reason. This function does not raise for unknown
        finish-reason strings.
    """
    provider_details = response.provider_details or {}
    if provider_details.get("finish_reason") == "cancelled":
        return AgentFinishReason.CANCELLED

    finish_reason = response.finish_reason
    if finish_reason is None:
        return AgentFinishReason.ERROR

    try:
        return AgentFinishReason(finish_reason)
    except ValueError:
        # The response may carry finish reasons this enum does not model
        # (for example "tool_call"). Enum lookup raises ValueError for those,
        # which previously escaped as an unhandled exception.
        logger.warning(
            "Unsupported agent finish reason %r; treating it as an error",
            finish_reason,
        )
        return AgentFinishReason.ERROR
150+
151+
152+
def get_finish_reason_error(
    finish_reason: AgentFinishReason,
    model_id: str,
) -> AbstractErrorResponse:
    """Build the LCS error response for a non-success agent finish reason.

    Args:
        finish_reason: Resolved finish reason from :func:`get_agent_finish_reason`.
        model_id: Model identifier in provider/model format.

    Returns:
        A prompt-too-long response for ``LENGTH``; otherwise a "query failed"
        internal server error whose message depends on the reason. Any reason
        without a dedicated message (including ``SUCCESS``, if passed) gets
        the generic "unexpected error" message.
    """
    # LENGTH is the only reason that maps to a different response type.
    if finish_reason == AgentFinishReason.LENGTH:
        return PromptTooLongResponse(model=model_id)

    if finish_reason == AgentFinishReason.CONTENT_FILTER:
        detail = "The model refused to generate a response due to content policy."
    elif finish_reason == AgentFinishReason.CANCELLED:
        detail = "The response was cancelled before completion."
    else:
        detail = "An unexpected error occurred while processing the request."
    return InternalServerErrorResponse.query_failed(detail)
180+
181+
182+
def extract_agent_token_usage(
    usage: RunUsage,
    model: str,
    endpoint_path: str,
) -> TokenCounter:
    """Convert agent run usage into a ``TokenCounter`` and record metrics.

    Besides building the counter, this records the token usage and a single
    LLM call through the ``recording`` module, labeled with the provider and
    model parsed from ``model``.

    Args:
        usage: Run usage reported by the agent.
        model: Model identifier in provider/model format.
        endpoint_path: Endpoint path used for metric labeling.

    Returns:
        Token usage counter for the run. ``llm_calls`` is the reported request
        count, but never less than 1.
    """
    provider_id, model_id = extract_provider_and_model_from_model_id(model)

    request_count = usage.requests
    counter = TokenCounter(
        input_tokens=usage.input_tokens,
        output_tokens=usage.output_tokens,
        # A completed run counts as at least one call even if none was reported.
        llm_calls=max(request_count, 1),
    )
    tokens_in = counter.input_tokens
    tokens_out = counter.output_tokens

    logger.debug(
        "Extracted token usage from agent run: input=%d, output=%d, requests=%d",
        tokens_in,
        tokens_out,
        request_count,
    )

    recording.record_llm_token_usage(
        provider_id, model_id, tokens_in, tokens_out, endpoint_path
    )
    recording.record_llm_call(provider_id, model_id, endpoint_path)
    return counter
218+
219+
220+
def build_turn_summary_from_agent_run(
221+
run_result: AgentRunResult[str],
222+
*,
223+
model_id: str,
224+
endpoint_path: str,
225+
vector_store_ids: list[str],
226+
rag_id_mapping: dict[str, str],
227+
) -> TurnSummary:
228+
"""Build a turn summary from a completed agent run.
229+
230+
Args:
231+
run_result: Completed agent run result.
232+
model_id: Model identifier in provider/model format.
233+
endpoint_path: Endpoint path used for metric labeling.
234+
vector_store_ids: Vector store IDs used for source mapping.
235+
rag_id_mapping: Mapping from vector store IDs to user-facing source labels.
236+
237+
Returns:
238+
Turn summary with text, tools, RAG metadata, and token usage.
239+
240+
Raises:
241+
HTTPException: When the run failed.
242+
"""
243+
finish_reason = get_agent_finish_reason(run_result.response)
244+
if finish_reason != AgentFinishReason.SUCCESS:
245+
error_response = get_finish_reason_error(finish_reason, model_id)
246+
raise HTTPException(**error_response.model_dump())
247+
248+
state = AgentTurnAccumulator(
249+
vector_store_ids=vector_store_ids,
250+
rag_id_mapping=rag_id_mapping,
251+
turn_summary=TurnSummary(),
252+
)
253+
254+
for message in run_result.new_messages():
255+
if isinstance(message, ModelResponse):
256+
if message.text:
257+
state.turn_summary.llm_response = message.text
258+
for tool_call_part in message.tool_calls:
259+
process_function_tool_call(state, tool_call_part)
260+
for call_part, return_part in message.native_tool_calls:
261+
process_native_tool_call(state, call_part)
262+
process_native_tool_result(state, return_part)
263+
elif isinstance(message, ModelRequest):
264+
for request_part in message.parts:
265+
if isinstance(request_part, ToolReturnPart):
266+
process_function_tool_result(state, request_part)
267+
268+
state.turn_summary.id = run_result.response.provider_response_id or ""
269+
state.turn_summary.token_usage = extract_agent_token_usage(
270+
run_result.usage,
271+
model_id,
272+
endpoint_path,
273+
)
274+
return state.turn_summary
275+
276+
277+
async def retrieve_agent_response(
    client: AsyncLlamaStackClient,
    responses_params: ResponsesApiParams,
    moderation_result: ShieldModerationResult,
    endpoint_path: str,
) -> TurnSummary:
    """Run the agent to completion and return the resulting turn summary.

    Mirrors :func:`app.endpoints.query.retrieve_response` for the agent path.

    When moderation blocked the turn, the agent is not run: the input and the
    refusal are appended to the conversation and a summary carrying the
    moderation message is returned instead.

    Args:
        client: Llama Stack client, used to build the agent and to persist
            the conversation when moderation blocks the turn.
        responses_params: Prepared Responses API parameters.
        moderation_result: Shield moderation outcome for the turn.
        endpoint_path: Endpoint path used for metric labeling.

    Returns:
        Turn summary for the completed agent run, or the moderation refusal
        summary when the turn was blocked.

    Raises:
        HTTPException: When building or running the agent raises an agent,
            API status, connection, or runtime error, or when the run ends
            with a non-success finish reason. A moderation block does not
            raise.
    """
    if moderation_result.decision == "blocked":
        await append_turn_items_to_conversation(
            client,
            responses_params.conversation,
            responses_params.input,
            [moderation_result.refusal_response],
        )
        return TurnSummary(
            id=moderation_result.moderation_id,
            llm_response=moderation_result.message,
        )

    try:
        agent = build_agent(client, responses_params)
        logger.debug("Starting agent non-streaming response processing")
        run_result = await agent.run(cast(str, responses_params.input))
    except (AgentRunError, APIStatusError, APIConnectionError, RuntimeError) as exc:
        error_response = map_agent_inference_error(exc, responses_params.model)
        raise HTTPException(**error_response.model_dump()) from exc

    store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
    return build_turn_summary_from_agent_run(
        run_result,
        model_id=responses_params.model,
        endpoint_path=endpoint_path,
        vector_store_ids=store_ids,
        rag_id_mapping=configuration.rag_id_mapping,
    )

0 commit comments

Comments
 (0)