|
2 | 2 |
|
3 | 3 | import ast |
4 | 4 | import json |
5 | | -import re |
6 | 5 | import logging |
| 6 | +import re |
7 | 7 | from typing import Annotated, Any, AsyncIterator, Iterator, cast |
8 | 8 |
|
9 | | -from llama_stack_client import APIConnectionError |
10 | | -from llama_stack_client import AsyncLlamaStackClient # type: ignore |
11 | | -from llama_stack_client.types import UserMessage # type: ignore |
12 | | - |
| 9 | +from fastapi import APIRouter, Depends, HTTPException, Request, status |
| 10 | +from fastapi.responses import StreamingResponse |
| 11 | +from llama_stack_client import ( |
| 12 | + APIConnectionError, |
| 13 | + AsyncLlamaStackClient, # type: ignore |
| 14 | +) |
13 | 15 | from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str |
| 16 | +from llama_stack_client.types import UserMessage # type: ignore |
14 | 17 | from llama_stack_client.types.agents.agent_turn_response_stream_chunk import ( |
15 | 18 | AgentTurnResponseStreamChunk, |
16 | 19 | ) |
17 | 20 | from llama_stack_client.types.shared import ToolCall |
18 | 21 | from llama_stack_client.types.shared.interleaved_content_item import TextContentItem |
19 | 22 |
|
20 | | -from fastapi import APIRouter, HTTPException, Request, Depends, status |
21 | | -from fastapi.responses import StreamingResponse |
22 | | - |
23 | | -from authentication import get_auth_dependency |
24 | | -from authentication.interface import AuthTuple |
25 | | -from authorization.middleware import authorize |
26 | | -from client import AsyncLlamaStackClientHolder |
27 | | -from configuration import configuration |
28 | 23 | import metrics |
29 | | -from metrics.utils import update_llm_token_count_from_turn |
30 | | -from models.config import Action |
31 | | -from models.requests import QueryRequest |
32 | | -from models.responses import UnauthorizedResponse, ForbiddenResponse |
33 | | -from models.database.conversations import UserConversation |
34 | | -from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt |
35 | | -from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups |
36 | | -from utils.transcripts import store_transcript |
37 | | -from utils.types import TurnSummary |
38 | | -from utils.endpoints import validate_model_provider_override |
39 | | - |
40 | 24 | from app.endpoints.query import ( |
| 25 | + evaluate_model_hints, |
41 | 26 | get_rag_toolgroups, |
42 | 27 | is_input_shield, |
43 | 28 | is_output_shield, |
44 | 29 | is_transcripts_enabled, |
| 30 | + persist_user_conversation_details, |
45 | 31 | select_model_and_provider_id, |
46 | 32 | validate_attachments_metadata, |
47 | 33 | validate_conversation_ownership, |
48 | | - persist_user_conversation_details, |
49 | | - evaluate_model_hints, |
50 | 34 | ) |
| 35 | +from authentication import get_auth_dependency |
| 36 | +from authentication.interface import AuthTuple |
| 37 | +from authorization.middleware import authorize |
| 38 | +from client import AsyncLlamaStackClientHolder |
| 39 | +from configuration import configuration |
| 40 | +from constants import DEFAULT_RAG_TOOL |
| 41 | +from metrics.utils import update_llm_token_count_from_turn |
| 42 | +from models.config import Action |
| 43 | +from models.database.conversations import UserConversation |
| 44 | +from models.requests import QueryRequest |
| 45 | +from models.responses import ForbiddenResponse, UnauthorizedResponse |
| 46 | +from utils.endpoints import ( |
| 47 | + check_configuration_loaded, |
| 48 | + get_agent, |
| 49 | + get_system_prompt, |
| 50 | + validate_model_provider_override, |
| 51 | +) |
| 52 | +from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency |
| 53 | +from utils.transcripts import store_transcript |
| 54 | +from utils.types import TurnSummary |
51 | 55 |
|
52 | 56 | logger = logging.getLogger("app.endpoints.handlers") |
53 | 57 | router = APIRouter(tags=["streaming_query"]) |
@@ -482,7 +486,7 @@ def _handle_tool_execution_event( |
482 | 486 | } |
483 | 487 | ) |
484 | 488 |
|
485 | | - elif r.tool_name == "knowledge_search" and r.content: |
| 489 | + elif r.tool_name == DEFAULT_RAG_TOOL and r.content: |
486 | 490 | summary = "" |
487 | 491 | for i, text_content_item in enumerate(r.content): |
488 | 492 | if isinstance(text_content_item, TextContentItem): |
|
0 commit comments