Skip to content

Commit 1f7d8ac

Browse files
authored
Merge pull request #1865 from asimurka/query_agent_run
LCORE-2310: Query agent run
2 parents 41d55b9 + 70949de commit 1f7d8ac

7 files changed

Lines changed: 502 additions & 313 deletions

File tree

src/app/endpoints/query.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from openai._exceptions import (
1616
APIStatusError as OpenAIAPIStatusError,
1717
)
18+
from typing_extensions import deprecated
1819

1920
from authentication import get_auth_dependency
2021
from authentication.interface import AuthTuple
@@ -42,6 +43,7 @@
4243
from models.common.responses.types import ResponseInput
4344
from models.common.turn_summary import TurnSummary
4445
from models.config import Action
46+
from utils.agents.query import retrieve_agent_response
4547
from utils.conversation_compaction import (
4648
apply_compaction_blocking,
4749
configured_conversation_cache,
@@ -68,7 +70,7 @@
6870
build_turn_summary,
6971
deduplicate_referenced_documents,
7072
extract_vector_store_ids_from_tools,
71-
get_topic_summary,
73+
maybe_get_topic_summary,
7274
prepare_responses_params,
7375
)
7476
from utils.shields import run_shield_moderation, validate_shield_ids_override
@@ -226,12 +228,12 @@ async def query_endpoint_handler(
226228
client = await AsyncLlamaStackClientHolder().update_azure_token()
227229

228230
# Retrieve response using Responses API
229-
turn_summary = await retrieve_response(
231+
turn_summary = await retrieve_agent_response(
230232
client,
231233
responses_params,
232234
moderation_result,
233235
endpoint_path,
234-
original_input=compaction.original_input if compaction.compacted else None,
236+
compaction.original_input if compaction.compacted else None,
235237
)
236238

237239
if moderation_result.decision == "passed":
@@ -249,13 +251,15 @@ async def query_endpoint_handler(
249251
)
250252

251253
# Get topic summary for new conversation
252-
if not user_conversation and query_request.generate_topic_summary:
253-
logger.debug("Generating topic summary for new conversation")
254-
topic_summary = await get_topic_summary(
255-
query_request.query, client, responses_params.model
256-
)
257-
else:
258-
topic_summary = None
254+
should_generate = not user_conversation and bool(
255+
query_request.generate_topic_summary
256+
)
257+
topic_summary = await maybe_get_topic_summary(
258+
generate_topic_summary=should_generate,
259+
input_text=query_request.query,
260+
client=client,
261+
model_id=responses_params.model,
262+
)
259263

260264
logger.info("Consuming tokens")
261265
consume_query_tokens(
@@ -301,6 +305,10 @@ async def query_endpoint_handler(
301305
)
302306

303307

308+
@deprecated(
309+
"Deprecated in favor of utils.agents.query.retrieve_agent_response.",
310+
stacklevel=2,
311+
)
304312
async def retrieve_response(
305313
client: AsyncLlamaStackClient,
306314
responses_params: ResponsesApiParams,

src/utils/agents/query.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from enum import Enum
6-
from typing import TypeAlias, cast
6+
from typing import Optional, TypeAlias, cast
77

88
from fastapi import HTTPException
99
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
@@ -33,6 +33,7 @@
3333
from models.common.agents import AgentTurnAccumulator
3434
from models.common.moderation import ShieldModerationResult
3535
from models.common.responses.responses_api_params import ResponsesApiParams
36+
from models.common.responses.types import ResponseInput
3637
from models.common.turn_summary import TurnSummary
3738
from utils.agents.tool_processor import (
3839
process_function_tool_call,
@@ -281,6 +282,7 @@ async def retrieve_agent_response(
281282
responses_params: ResponsesApiParams,
282283
moderation_result: ShieldModerationResult,
283284
endpoint_path: str,
285+
_original_input: Optional[ResponseInput] = None,
284286
) -> TurnSummary:
285287
"""Retrieve a turn summary from a blocking agent run.
286288
@@ -291,6 +293,7 @@ async def retrieve_agent_response(
291293
responses_params: Prepared Responses API parameters.
292294
moderation_result: Shield moderation outcome for the turn.
293295
endpoint_path: Endpoint path used for metric labeling.
296+
_original_input: Original user input before the explicit-input rewrite.
294297
295298
Returns:
296299
Turn summary for the completed agent run.

tests/integration/conftest.py

Lines changed: 209 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,22 @@
88
import pytest
99
from fastapi import Request, Response
1010
from fastapi.testclient import TestClient
11-
from pytest_mock import MockerFixture
11+
from llama_stack_api.openai_responses import OpenAIResponseObject
12+
from llama_stack_client.types import VersionInfo
13+
from pydantic_ai.messages import (
14+
ModelMessage,
15+
ModelRequest,
16+
ModelResponse,
17+
NativeToolCallPart,
18+
NativeToolReturnPart,
19+
TextPart,
20+
ToolCallPart,
21+
ToolReturnPart,
22+
)
23+
from pydantic_ai.native_tools import FileSearchTool, MCPServerTool
24+
from pydantic_ai.run import AgentRunResult
25+
from pydantic_ai.usage import RunUsage
26+
from pytest_mock import AsyncMockType, MockerFixture
1227
from sqlalchemy import create_engine
1328
from sqlalchemy.engine import Engine
1429
from sqlalchemy.orm import Session, sessionmaker
@@ -70,9 +85,6 @@ def create_mock_llm_response( # pylint: disable=too-many-arguments,too-many-pos
7085
Returns:
7186
Mock LLM response object with the specified configuration.
7287
"""
73-
# pylint: disable=import-outside-toplevel
74-
from llama_stack_api.openai_responses import OpenAIResponseObject
75-
7688
mock_response = mocker.MagicMock(spec=OpenAIResponseObject)
7789
mock_response.id = "response-123"
7890

@@ -154,6 +166,187 @@ def create_mock_tool_call(
154166
return mock_tool_call
155167

156168

169+
def create_agent_run_result(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    mocker: MockerFixture,
    *,
    content: str = "This is a test response about Ansible.",
    response_id: str = "response-123",
    input_tokens: int = 10,
    output_tokens: int = 5,
    model_response: ModelResponse | None = None,
    new_messages: list[ModelMessage] | None = None,
) -> AgentRunResult[str]:
    """Build a mocked AgentRunResult backed by real pydantic-ai messages.

    The returned object is a ``MagicMock`` constrained to the
    ``AgentRunResult`` spec whose ``response``, ``usage`` and
    ``new_messages()`` are pre-populated.

    Args:
        mocker: pytest-mock fixture used to create the mock.
        content: Assistant text placed in the default response. Only used
            when ``model_response`` is not supplied; an empty string yields
            a response with no parts.
        response_id: Provider response identifier for the default response.
            Only used when ``model_response`` is not supplied.
        input_tokens: Input token count reported by ``usage``.
        output_tokens: Output token count reported by ``usage``.
        model_response: Optional pre-built ModelResponse used as-is.
        new_messages: Optional message sequence returned by
            ``new_messages()``; defaults to a list holding only the model
            response.

    Returns:
        Mock AgentRunResult exposing the configured response, usage and
        messages.
    """
    if model_response is None:
        # No response supplied: synthesize a plain text-only one.
        text_parts = []
        if content:
            text_parts.append(TextPart(content))
        model_response = ModelResponse(
            parts=text_parts,
            finish_reason="stop",
            provider_response_id=response_id,
        )

    messages = [model_response] if new_messages is None else new_messages
    usage = RunUsage(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        requests=1,
    )

    run_result = mocker.MagicMock(spec=AgentRunResult)
    run_result.new_messages.return_value = messages
    run_result.usage = usage
    run_result.response = model_response
    return run_result
214+
215+
216+
def create_file_search_agent_run_result(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    mocker: MockerFixture,
    *,
    content: str,
    response_id: str = "response-tool-rag",
    queries: Optional[list[str]] = None,
    results: Optional[list[dict[str, Any]]] = None,
    input_tokens: int = 10,
    output_tokens: int = 5,
) -> AgentRunResult[str]:
    """Build an AgentRunResult whose response includes a file_search call.

    The model response holds, in order, a native file_search tool call, its
    matching tool return, and the assistant text.

    Args:
        mocker: pytest-mock fixture.
        content: Assistant text appended after the tool parts.
        response_id: Provider response identifier.
        queries: Search queries recorded on the tool call; a missing or
            empty value falls back to ``["test query"]``.
        results: Search results recorded on the tool return; a missing or
            empty value becomes an empty list.
        input_tokens: Input token count for the run.
        output_tokens: Output token count for the run.

    Returns:
        Mock AgentRunResult produced by ``create_agent_run_result``.
    """
    tool_name = FileSearchTool.kind
    call_id = "call-fs-1"  # shared id pairs the call with its return
    search_queries = queries if queries else ["test query"]
    search_results = results if results else []

    tool_call = NativeToolCallPart(
        tool_name=tool_name,
        args={"queries": search_queries},
        tool_call_id=call_id,
    )
    tool_return = NativeToolReturnPart(
        tool_name=tool_name,
        tool_call_id=call_id,
        content={"status": "success", "results": search_results},
    )
    response = ModelResponse(
        parts=[tool_call, tool_return, TextPart(content)],
        finish_reason="stop",
        provider_response_id=response_id,
    )
    return create_agent_run_result(
        mocker,
        content=content,
        response_id=response_id,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        model_response=response,
    )
253+
254+
255+
def create_mcp_list_tools_agent_run_result(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    mocker: MockerFixture,
    *,
    content: str,
    response_id: str = "response-mcplist",
    server_label: str = "kubernetes-server",
    tools: Optional[list[dict[str, Any]]] = None,
    input_tokens: int = 15,
    output_tokens: int = 20,
) -> AgentRunResult[str]:
    """Build an AgentRunResult whose response includes an MCP list-tools call.

    The model response holds, in order, a native tool call with the
    ``list_tools`` action, its matching tool return, and the assistant text.

    Args:
        mocker: pytest-mock fixture.
        content: Assistant text appended after the tool parts.
        response_id: Provider response identifier.
        server_label: MCP server label appended to the tool kind to form
            the tool name.
        tools: Tool descriptions recorded on the tool return; a missing or
            empty value becomes an empty list.
        input_tokens: Input token count for the run.
        output_tokens: Output token count for the run.

    Returns:
        Mock AgentRunResult produced by ``create_agent_run_result``.
    """
    tool_name = f"{MCPServerTool.kind}:{server_label}"
    call_id = "mcplist-101"  # shared id pairs the call with its return
    listed_tools = tools if tools else []

    list_call = NativeToolCallPart(
        tool_name=tool_name,
        args={"action": "list_tools"},
        tool_call_id=call_id,
    )
    list_return = NativeToolReturnPart(
        tool_name=tool_name,
        tool_call_id=call_id,
        content={"tools": listed_tools},
    )
    response = ModelResponse(
        parts=[list_call, list_return, TextPart(content)],
        finish_reason="stop",
        provider_response_id=response_id,
    )
    return create_agent_run_result(
        mocker,
        content=content,
        response_id=response_id,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        model_response=response,
    )
289+
290+
291+
def create_multi_tool_agent_run_result(
    mocker: MockerFixture,
    *,
    content: str = "Based on documentation and calculations...",
    response_id: str = "response-multi",
    input_tokens: int = 40,
    output_tokens: int = 60,
) -> AgentRunResult[str]:
    """Build an AgentRunResult combining file_search and function tool calls.

    The model response holds a native file_search call with its return, a
    ``calculate`` function tool call, and the assistant text. The function
    tool's return is delivered separately, in a ModelRequest that follows
    the response in the ``new_messages()`` sequence.

    Args:
        mocker: pytest-mock fixture.
        content: Assistant text appended after the tool parts.
        response_id: Provider response identifier.
        input_tokens: Input token count for the run.
        output_tokens: Output token count for the run.

    Returns:
        Mock AgentRunResult produced by ``create_agent_run_result``.
    """
    search_tool = FileSearchTool.kind
    search_id = "search-1"
    function_tool = "calculate"
    function_id = "func-2"

    search_call = NativeToolCallPart(
        tool_name=search_tool,
        args={"queries": ["Kubernetes deployment"]},
        tool_call_id=search_id,
    )
    search_return = NativeToolReturnPart(
        tool_name=search_tool,
        tool_call_id=search_id,
        content={"status": "success", "results": []},
    )
    calc_call = ToolCallPart(
        tool_name=function_tool,
        args={"operation": "sum"},
        tool_call_id=function_id,
    )
    calc_return = ToolReturnPart(
        tool_name=function_tool,
        content={"result": 2},
        tool_call_id=function_id,
    )

    response = ModelResponse(
        parts=[search_call, search_return, calc_call, TextPart(content)],
        finish_reason="stop",
        provider_response_id=response_id,
    )
    # The function tool result travels in a follow-up request message.
    follow_up = ModelRequest(parts=[calc_return])
    return create_agent_run_result(
        mocker,
        content=content,
        response_id=response_id,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        model_response=response,
        new_messages=[response, follow_up],
    )
339+
340+
341+
def set_query_agent_run(
    mock_query_agent: AsyncMockType,
    mocker: MockerFixture,
    **kwargs: Any,
) -> None:
    """Set the value returned by the mock agent's ``run`` in /query tests.

    Args:
        mock_query_agent: Mock agent whose ``run`` return value is replaced.
        mocker: pytest-mock fixture forwarded to ``create_agent_run_result``.
        **kwargs: Keyword arguments forwarded unchanged to
            ``create_agent_run_result``.
    """
    run_result = create_agent_run_result(mocker, **kwargs)
    mock_query_agent.run.return_value = run_result
348+
349+
157350
# ==========================================
158351
# Fixtures
159352
# ==========================================
@@ -448,10 +641,6 @@ def mock_llama_stack_client_fixture(
448641
Yields:
449642
mock_client: The mocked Llama Stack client instance.
450643
"""
451-
# pylint: disable=import-outside-toplevel
452-
from llama_stack_api.openai_responses import OpenAIResponseObject
453-
from llama_stack_client.types import VersionInfo
454-
455644
# Patch AsyncLlamaStackClientHolder at multiple import locations
456645
# This ensures the mock is active both during app startup (app.main)
457646
# and during endpoint execution (query, conversations_v1, responses, etc.)
@@ -514,3 +703,15 @@ def mock_llama_stack_client_fixture(
514703
mock_holder_instance.get_client.return_value = mock_client
515704

516705
yield mock_client
706+
707+
708+
@pytest.fixture(name="mock_query_agent")
def mock_query_agent_fixture(mocker: MockerFixture) -> AsyncMockType:
    """Patch build_agent for /query and return the mock agent.

    The agent's ``run`` is an AsyncMock that returns the default result of
    ``create_agent_run_result``. The patch object replacing
    ``utils.agents.query.build_agent`` is stored on the agent as
    ``build_agent_mock``.

    Args:
        mocker: pytest-mock fixture.

    Returns:
        The AsyncMock agent that the patched ``build_agent`` returns.
    """
    mock_agent = mocker.AsyncMock()
    mock_agent.run = mocker.AsyncMock(return_value=create_agent_run_result(mocker))
    # Keep a handle on the patch so tests can reach it through the agent.
    mock_agent.build_agent_mock = mocker.patch(
        "utils.agents.query.build_agent",
        return_value=mock_agent,
    )
    return mock_agent

0 commit comments

Comments
 (0)