Skip to content

Commit 32c6fbf

Browse files
authored
Merge pull request lightspeed-core#583 from are-ces/fix-constant-rag
Replacing RAG tool name with constant
2 parents 0de9deb + d8f617e commit 32c6fbf

2 files changed

Lines changed: 33 additions & 30 deletions

File tree

src/app/endpoints/query.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import Annotated, Any, Optional, cast
99

1010
from fastapi import APIRouter, Depends, HTTPException, Request, status
11-
from pydantic import AnyUrl
1211
from llama_stack_client import (
1312
APIConnectionError,
1413
AsyncLlamaStackClient, # type: ignore
@@ -23,6 +22,7 @@
2322
from llama_stack_client.types.model_list_response import ModelListResponse
2423
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
2524
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
25+
from pydantic import AnyUrl
2626

2727
import constants
2828
import metrics
@@ -513,8 +513,7 @@ def parse_referenced_documents(response: Turn) -> list[ReferencedDocument]:
513513
if not isinstance(step, ToolExecutionStep):
514514
continue
515515
for tool_response in step.tool_responses:
516-
# TODO(are-ces): use constant instead
517-
if tool_response.tool_name != "knowledge_search":
516+
if tool_response.tool_name != constants.DEFAULT_RAG_TOOL:
518517
continue
519518
for text_item in tool_response.content:
520519
if not isinstance(text_item, TextContentItem):

src/app/endpoints/streaming_query.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,52 +2,56 @@
22

33
import ast
44
import json
5-
import re
65
import logging
6+
import re
77
from typing import Annotated, Any, AsyncIterator, Iterator, cast
88

9-
from llama_stack_client import APIConnectionError
10-
from llama_stack_client import AsyncLlamaStackClient # type: ignore
11-
from llama_stack_client.types import UserMessage # type: ignore
12-
9+
from fastapi import APIRouter, Depends, HTTPException, Request, status
10+
from fastapi.responses import StreamingResponse
11+
from llama_stack_client import (
12+
APIConnectionError,
13+
AsyncLlamaStackClient, # type: ignore
14+
)
1315
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
16+
from llama_stack_client.types import UserMessage # type: ignore
1417
from llama_stack_client.types.agents.agent_turn_response_stream_chunk import (
1518
AgentTurnResponseStreamChunk,
1619
)
1720
from llama_stack_client.types.shared import ToolCall
1821
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
1922

20-
from fastapi import APIRouter, HTTPException, Request, Depends, status
21-
from fastapi.responses import StreamingResponse
22-
23-
from authentication import get_auth_dependency
24-
from authentication.interface import AuthTuple
25-
from authorization.middleware import authorize
26-
from client import AsyncLlamaStackClientHolder
27-
from configuration import configuration
2823
import metrics
29-
from metrics.utils import update_llm_token_count_from_turn
30-
from models.config import Action
31-
from models.requests import QueryRequest
32-
from models.responses import UnauthorizedResponse, ForbiddenResponse
33-
from models.database.conversations import UserConversation
34-
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
35-
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
36-
from utils.transcripts import store_transcript
37-
from utils.types import TurnSummary
38-
from utils.endpoints import validate_model_provider_override
39-
4024
from app.endpoints.query import (
25+
evaluate_model_hints,
4126
get_rag_toolgroups,
4227
is_input_shield,
4328
is_output_shield,
4429
is_transcripts_enabled,
30+
persist_user_conversation_details,
4531
select_model_and_provider_id,
4632
validate_attachments_metadata,
4733
validate_conversation_ownership,
48-
persist_user_conversation_details,
49-
evaluate_model_hints,
5034
)
35+
from authentication import get_auth_dependency
36+
from authentication.interface import AuthTuple
37+
from authorization.middleware import authorize
38+
from client import AsyncLlamaStackClientHolder
39+
from configuration import configuration
40+
from constants import DEFAULT_RAG_TOOL
41+
from metrics.utils import update_llm_token_count_from_turn
42+
from models.config import Action
43+
from models.database.conversations import UserConversation
44+
from models.requests import QueryRequest
45+
from models.responses import ForbiddenResponse, UnauthorizedResponse
46+
from utils.endpoints import (
47+
check_configuration_loaded,
48+
get_agent,
49+
get_system_prompt,
50+
validate_model_provider_override,
51+
)
52+
from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency
53+
from utils.transcripts import store_transcript
54+
from utils.types import TurnSummary
5155

5256
logger = logging.getLogger("app.endpoints.handlers")
5357
router = APIRouter(tags=["streaming_query"])
@@ -482,7 +486,7 @@ def _handle_tool_execution_event(
482486
}
483487
)
484488

485-
elif r.tool_name == "knowledge_search" and r.content:
489+
elif r.tool_name == DEFAULT_RAG_TOOL and r.content:
486490
summary = ""
487491
for i, text_content_item in enumerate(r.content):
488492
if isinstance(text_content_item, TextContentItem):

0 commit comments

Comments
 (0)