Skip to content

Commit 736f162

Browse files
committed
Wired query agent.run
1 parent 9105412 commit 736f162

6 files changed

Lines changed: 703 additions & 267 deletions

File tree

src/app/endpoints/query.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from models.common.responses.responses_api_params import ResponsesApiParams
4242
from models.common.turn_summary import TurnSummary
4343
from models.config import Action
44+
from utils.agents.query import retrieve_agent_response
4445
from utils.conversations import append_turn_items_to_conversation
4546
from utils.endpoints import (
4647
check_configuration_loaded,
@@ -206,8 +207,13 @@ async def query_endpoint_handler(
206207
client = await AsyncLlamaStackClientHolder().update_azure_token()
207208

208209
# Retrieve response using Responses API
209-
turn_summary = await retrieve_response(
210-
client, responses_params, moderation_result, endpoint_path
210+
turn_summary = await retrieve_agent_response(
211+
client=client,
212+
responses_params=responses_params,
213+
moderation_result=moderation_result,
214+
endpoint_path=endpoint_path,
215+
vector_store_ids=query_request.vector_store_ids or [],
216+
rag_id_mapping=configuration.rag_id_mapping or {},
211217
)
212218

213219
if moderation_result.decision == "passed":

src/utils/agents/query.py

Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
"""Non-streaming agent helpers and shared turn-summary builders for agent runs."""
2+
3+
from __future__ import annotations
4+
5+
from enum import Enum
6+
from typing import TypeAlias
7+
8+
from fastapi import HTTPException
9+
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
10+
from pydantic_ai.exceptions import (
11+
AgentRunError,
12+
ContentFilterError,
13+
IncompleteToolCall,
14+
ModelAPIError,
15+
ModelHTTPError,
16+
UnexpectedModelBehavior,
17+
UsageLimitExceeded,
18+
)
19+
from pydantic_ai.messages import ModelRequest, ModelResponse, ToolReturnPart
20+
from pydantic_ai.run import AgentRunResult
21+
from pydantic_ai.usage import RunUsage
22+
23+
from log import get_logger
24+
from metrics import recording
25+
from models.api.responses.error import (
26+
AbstractErrorResponse,
27+
InternalServerErrorResponse,
28+
PromptTooLongResponse,
29+
QuotaExceededResponse,
30+
ServiceUnavailableResponse,
31+
)
32+
from models.common.agents import AgentTurnAccumulator
33+
from models.common.moderation import ShieldModerationResult
34+
from models.common.responses.responses_api_params import ResponsesApiParams
35+
from models.common.turn_summary import TurnSummary
36+
from utils.agents.tool_processor import (
37+
process_function_tool_call,
38+
process_function_tool_result,
39+
process_native_tool_call,
40+
process_native_tool_result,
41+
)
42+
from utils.conversations import append_turn_items_to_conversation
43+
from utils.pydantic_ai import build_agent
44+
from utils.query import (
45+
extract_provider_and_model_from_model_id,
46+
handle_known_apistatus_errors,
47+
is_context_length_error,
48+
)
49+
from utils.responses import extract_text_from_response_items
50+
from utils.token_counter import TokenCounter
51+
52+
logger = get_logger(__name__)
53+
54+
AgentInferenceError: TypeAlias = (
55+
AgentRunError | APIStatusError | APIConnectionError | RuntimeError
56+
)
57+
58+
59+
class AgentFinishReason(str, Enum):
60+
"""Finish reason for a completed agent model response."""
61+
62+
CONTENT_FILTER = "content_filter"
63+
CANCELLED = "cancelled"
64+
SUCCESS = "stop"
65+
LENGTH = "length"
66+
ERROR = "error"
67+
68+
69+
def map_agent_inference_error(
70+
exc: AgentInferenceError,
71+
model_id: str,
72+
) -> AbstractErrorResponse:
73+
"""Map agent run failures from pydantic-ai or Llama Stack to an LCS error response.
74+
75+
Args:
76+
exc: Agent, HTTP status, connection, or context-length runtime error.
77+
model_id: Model identifier in provider/model format.
78+
79+
Returns:
80+
Structured error response for HTTP or SSE error events.
81+
82+
Raises:
83+
RuntimeError: Re-raised when ``exc`` is a non-agent ``RuntimeError`` that is
84+
not a recognized context-length failure.
85+
"""
86+
if isinstance(exc, AgentRunError):
87+
return map_pydantic_agent_run_error(exc, model_id)
88+
if isinstance(exc, APIStatusError):
89+
return handle_known_apistatus_errors(exc, model_id)
90+
if isinstance(exc, APIConnectionError):
91+
return ServiceUnavailableResponse(
92+
backend_name="Llama Stack",
93+
cause=str(exc),
94+
)
95+
if isinstance(exc, RuntimeError) and is_context_length_error(str(exc)):
96+
return PromptTooLongResponse(model=model_id)
97+
return InternalServerErrorResponse.generic()
98+
99+
100+
def map_pydantic_agent_run_error(
101+
exc: AgentRunError, model_id: str
102+
) -> AbstractErrorResponse:
103+
"""Map pydantic-ai ``AgentRunError`` subclasses to LCS error responses.
104+
105+
Args:
106+
exc: Agent exception to map.
107+
model_id: Model identifier in provider/model format.
108+
109+
Returns:
110+
Structured error response for HTTP or SSE error events.
111+
"""
112+
if isinstance(exc, ContentFilterError):
113+
return InternalServerErrorResponse.query_failed(str(exc))
114+
if isinstance(exc, IncompleteToolCall):
115+
return PromptTooLongResponse(model=model_id)
116+
if isinstance(exc, UnexpectedModelBehavior):
117+
if is_context_length_error(str(exc)):
118+
return PromptTooLongResponse(model=model_id)
119+
return InternalServerErrorResponse.query_failed(str(exc))
120+
if isinstance(exc, UsageLimitExceeded):
121+
return QuotaExceededResponse.model(model_id)
122+
if isinstance(exc, ModelHTTPError):
123+
if is_context_length_error(str(exc)):
124+
return PromptTooLongResponse(model=model_id)
125+
if exc.status_code == 429:
126+
return QuotaExceededResponse.model(model_id)
127+
return InternalServerErrorResponse.generic()
128+
if isinstance(exc, ModelAPIError):
129+
return ServiceUnavailableResponse(
130+
backend_name="Llama Stack",
131+
cause=str(exc),
132+
)
133+
return InternalServerErrorResponse.query_failed(str(exc))
134+
135+
136+
def get_agent_finish_reason(response: ModelResponse) -> AgentFinishReason:
137+
"""Get the finish reason from a completed agent model response.
138+
139+
Args:
140+
response: Last model response from the agent run.
141+
142+
Returns:
143+
Resolved finish reason.
144+
"""
145+
raw_finish_reason = (response.provider_details or {}).get("finish_reason")
146+
if raw_finish_reason == "cancelled":
147+
return AgentFinishReason.CANCELLED
148+
if response.finish_reason is None:
149+
return AgentFinishReason.ERROR
150+
return AgentFinishReason(response.finish_reason)
151+
152+
153+
def get_finish_reason_error(
154+
finish_reason: AgentFinishReason,
155+
model_id: str,
156+
) -> AbstractErrorResponse:
157+
"""Map a non-success agent finish reason to an LCS error response.
158+
159+
Args:
160+
finish_reason: Resolved finish reason from :func:`get_agent_finish_reason`.
161+
model_id: Model identifier in provider/model format.
162+
163+
Returns:
164+
Structured error response for HTTP or SSE error events.
165+
"""
166+
match finish_reason:
167+
case AgentFinishReason.LENGTH:
168+
return PromptTooLongResponse(model=model_id)
169+
case AgentFinishReason.CONTENT_FILTER:
170+
return InternalServerErrorResponse.query_failed(
171+
"The model refused to generate a response due to content policy."
172+
)
173+
case AgentFinishReason.CANCELLED:
174+
return InternalServerErrorResponse.query_failed(
175+
"The response was cancelled before completion."
176+
)
177+
case _:
178+
return InternalServerErrorResponse.query_failed(
179+
"An unexpected error occurred while processing the request."
180+
)
181+
182+
183+
def extract_agent_token_usage(
184+
usage: RunUsage,
185+
model: str,
186+
endpoint_path: str,
187+
) -> TokenCounter:
188+
"""Build token usage for a completed agent run and record related metrics.
189+
190+
Args:
191+
usage: Run usage reported by the agent.
192+
model: Model identifier in provider/model format.
193+
endpoint_path: Endpoint path used for metric labeling.
194+
195+
Returns:
196+
Aggregated token usage counter for the run.
197+
"""
198+
provider_id, model_id = extract_provider_and_model_from_model_id(model)
199+
token_counter = TokenCounter(
200+
input_tokens=usage.input_tokens,
201+
output_tokens=usage.output_tokens,
202+
llm_calls=max(usage.requests, 1),
203+
)
204+
logger.debug(
205+
"Extracted token usage from agent run: input=%d, output=%d, requests=%d",
206+
token_counter.input_tokens,
207+
token_counter.output_tokens,
208+
usage.requests,
209+
)
210+
recording.record_llm_token_usage(
211+
provider_id,
212+
model_id,
213+
token_counter.input_tokens,
214+
token_counter.output_tokens,
215+
endpoint_path,
216+
)
217+
recording.record_llm_call(provider_id, model_id, endpoint_path)
218+
return token_counter
219+
220+
221+
def build_turn_summary_from_agent_run(
222+
run_result: AgentRunResult[str],
223+
*,
224+
model_id: str,
225+
endpoint_path: str,
226+
vector_store_ids: list[str] | None = None,
227+
rag_id_mapping: dict[str, str] | None = None,
228+
) -> TurnSummary:
229+
"""Build a turn summary from a completed agent run.
230+
231+
Args:
232+
run_result: Completed agent run result.
233+
model_id: Model identifier in provider/model format.
234+
endpoint_path: Endpoint path used for metric labeling.
235+
vector_store_ids: Vector store IDs used for source mapping.
236+
rag_id_mapping: Mapping from vector store IDs to user-facing source labels.
237+
238+
Returns:
239+
Turn summary with text, tools, RAG metadata, and token usage.
240+
241+
Raises:
242+
HTTPException: When the run failed.
243+
"""
244+
finish_reason = get_agent_finish_reason(run_result.response)
245+
if finish_reason != AgentFinishReason.SUCCESS:
246+
error_response = get_finish_reason_error(finish_reason, model_id)
247+
raise HTTPException(**error_response.model_dump())
248+
249+
state = AgentTurnAccumulator(
250+
vector_store_ids=vector_store_ids or [],
251+
rag_id_mapping=rag_id_mapping or {},
252+
turn_summary=TurnSummary(),
253+
)
254+
255+
for message in run_result.new_messages():
256+
if isinstance(message, ModelResponse):
257+
if message.text:
258+
state.turn_summary.llm_response = message.text
259+
for tool_call_part in message.tool_calls:
260+
process_function_tool_call(state, tool_call_part)
261+
for call_part, return_part in message.native_tool_calls:
262+
process_native_tool_call(state, call_part)
263+
process_native_tool_result(state, return_part)
264+
elif isinstance(message, ModelRequest):
265+
for request_part in message.parts:
266+
if isinstance(request_part, ToolReturnPart):
267+
process_function_tool_result(state, request_part)
268+
269+
state.turn_summary.id = run_result.response.provider_response_id or ""
270+
state.turn_summary.token_usage = extract_agent_token_usage(
271+
run_result.usage,
272+
model_id,
273+
endpoint_path,
274+
)
275+
return state.turn_summary
276+
277+
278+
async def retrieve_agent_response(
279+
client: AsyncLlamaStackClient,
280+
responses_params: ResponsesApiParams,
281+
moderation_result: ShieldModerationResult,
282+
endpoint_path: str,
283+
vector_store_ids: list[str],
284+
rag_id_mapping: dict[str, str],
285+
) -> TurnSummary:
286+
"""Retrieve a turn summary from a blocking agent run.
287+
288+
Mirrors :func:`app.endpoints.query.retrieve_response` for the agent path.
289+
290+
Args:
291+
client: Llama Stack client for conversation persistence on moderation block.
292+
responses_params: Prepared Responses API parameters.
293+
moderation_result: Shield moderation outcome for the turn.
294+
endpoint_path: Endpoint path used for metric labeling.
295+
vector_store_ids: Vector store IDs used for source mapping.
296+
rag_id_mapping: Mapping from vector store IDs to user-facing source labels.
297+
298+
Returns:
299+
Turn summary for the completed agent run.
300+
301+
Raises:
302+
HTTPException: On moderation is not applicable; on agent or provider failure.
303+
"""
304+
if moderation_result.decision == "blocked":
305+
await append_turn_items_to_conversation(
306+
client,
307+
responses_params.conversation,
308+
responses_params.input,
309+
[moderation_result.refusal_response],
310+
)
311+
return TurnSummary(
312+
id=moderation_result.moderation_id,
313+
llm_response=moderation_result.message,
314+
)
315+
prompt = (
316+
responses_params.input
317+
if isinstance(responses_params.input, str)
318+
else extract_text_from_response_items(responses_params.input)
319+
)
320+
321+
try:
322+
agent = build_agent(client, responses_params)
323+
logger.debug("Starting agent non-streaming response processing")
324+
run_result = await agent.run(prompt)
325+
except (AgentRunError, APIStatusError, APIConnectionError, RuntimeError) as exc:
326+
response = map_agent_inference_error(exc, responses_params.model)
327+
raise HTTPException(**response.model_dump()) from exc
328+
329+
return build_turn_summary_from_agent_run(
330+
run_result,
331+
model_id=responses_params.model,
332+
endpoint_path=endpoint_path,
333+
vector_store_ids=vector_store_ids,
334+
rag_id_mapping=rag_id_mapping,
335+
)

0 commit comments

Comments
 (0)