Skip to content

Commit e01a1d6

Browse files
authored
LCORE-177: Implemented streaming_query endpoint (#126)
* Initial streaming_query commit * Added streaming_query endpoint so that it conforms with road-core streaming_query endpoint. * Added Unit Tests for streaming_query * Added unit testing for streaming_query endpoint, it is very closely based off of test_query. * test_streaming_query was generated using vibe coding. * Fixed checks * Still not passing `black` test. That's because streaming_query is based off of query so they have similar code. * Fixed mypy type errors * Switched to using EventLogger * Switched to EventLogger in order to capture more output. * Added AsyncAgent * Removed use of EventLogger due to new use of AsyncAgent * Added Tool Execution to Stream * Added tool execution events to output stream. * Separated response generator out even more. * Fixed PyDocs * Formatted Code * I need to find a way to automate this!!! * Changed APIRouter tag * Fixed Pyright
1 parent 490ad8c commit e01a1d6

8 files changed

Lines changed: 638 additions & 12 deletions

File tree

src/app/endpoints/query.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def query_endpoint_handler(
6565
llama_stack_config = configuration.llama_stack_configuration
6666
logger.info("LLama stack config: %s", llama_stack_config)
6767
client = get_llama_stack_client(llama_stack_config)
68-
model_id = select_model_id(client, query_request)
68+
model_id = select_model_id(client.models.list(), query_request)
6969
conversation_id = retrieve_conversation_id(query_request)
7070
response = retrieve_response(client, model_id, query_request, auth)
7171

@@ -87,9 +87,8 @@ def query_endpoint_handler(
8787
return QueryResponse(conversation_id=conversation_id, response=response)
8888

8989

90-
def select_model_id(client: LlamaStackClient, query_request: QueryRequest) -> str:
90+
def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
9191
"""Select the model ID based on the request or available models."""
92-
models: ModelListResponse = client.models.list()
9392
model_id = query_request.model
9493
provider_id = query_request.provider
9594

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""Handler for REST API call to provide answer to streaming query."""
2+
3+
import json
4+
import logging
5+
from typing import Any, AsyncIterator
6+
7+
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
8+
from llama_stack_client import AsyncLlamaStackClient # type: ignore
9+
from llama_stack_client.types import UserMessage # type: ignore
10+
11+
from fastapi import APIRouter, Request, Depends
12+
from fastapi.responses import StreamingResponse
13+
14+
from client import get_async_llama_stack_client
15+
from configuration import configuration
16+
from models.requests import QueryRequest
17+
import constants
18+
from utils.auth import auth_dependency
19+
from utils.common import retrieve_user_id
20+
21+
22+
from app.endpoints.query import (
23+
is_transcripts_enabled,
24+
retrieve_conversation_id,
25+
store_transcript,
26+
select_model_id,
27+
validate_attachments_metadata,
28+
)
29+
30+
logger = logging.getLogger("app.endpoints.handlers")
31+
router = APIRouter(tags=["streaming_query"])
32+
33+
34+
def format_stream_data(d: dict) -> str:
    """Serialize a payload as one record in the Event Stream Format.

    Args:
        d: The JSON-serializable payload to send.

    Returns:
        The payload dumped as JSON, prefixed with ``data: `` and terminated
        by a blank line.
    """
    return "data: " + json.dumps(d) + "\n\n"
38+
39+
40+
def stream_start_event(conversation_id: str) -> str:
    """Build the event that opens the data stream.

    Args:
        conversation_id: The conversation ID (UUID).

    Returns:
        The formatted "start" event carrying the conversation ID.
    """
    start_data = {"conversation_id": conversation_id}
    return format_stream_data({"event": "start", "data": start_data})
54+
55+
56+
def stream_end_event() -> str:
    """Build the event that closes the data stream.

    Returns:
        The formatted "end" event. All summary fields currently hold
        placeholder values (see the TODO notes below).
    """
    end_data = {
        "referenced_documents": [],  # TODO(jboos): implement referenced documents
        "truncated": None,  # TODO(jboos): implement truncated
        "input_tokens": 0,  # TODO(jboos): implement input tokens
        "output_tokens": 0,  # TODO(jboos): implement output tokens
    }
    end_event = {
        "event": "end",
        "data": end_data,
        "available_quotas": {},  # TODO(jboos): implement available quotas
    }
    return format_stream_data(end_event)
70+
71+
72+
def stream_build_event(chunk: Any, chunk_id: int) -> str | None:
73+
"""Build a streaming event from a chunk response.
74+
75+
This function processes chunks from the LLama Stack streaming response and formats
76+
them into Server-Sent Events (SSE) format for the client. It handles two main
77+
event types:
78+
79+
1. step_progress: Contains text deltas from the model inference process
80+
2. step_complete: Contains information about completed tool execution steps
81+
82+
Args:
83+
chunk: The streaming chunk from LLama Stack containing event data
84+
chunk_id: The current chunk ID counter (gets incremented for each token)
85+
86+
Returns:
87+
str | None: A formatted SSE data string with event information, or None if
88+
the chunk doesn't contain processable event data
89+
"""
90+
if hasattr(chunk.event, "payload"):
91+
if chunk.event.payload.event_type == "step_progress":
92+
if hasattr(chunk.event.payload.delta, "text"):
93+
text = chunk.event.payload.delta.text
94+
return format_stream_data(
95+
{
96+
"event": "token",
97+
"data": {
98+
"id": chunk_id,
99+
"role": chunk.event.payload.step_type,
100+
"token": text,
101+
},
102+
}
103+
)
104+
if chunk.event.payload.event_type == "step_complete":
105+
if chunk.event.payload.step_details.step_type == "tool_execution":
106+
if chunk.event.payload.step_details.tool_calls:
107+
tool_name = str(
108+
chunk.event.payload.step_details.tool_calls[0].tool_name
109+
)
110+
return format_stream_data(
111+
{
112+
"event": "token",
113+
"data": {
114+
"id": chunk_id,
115+
"role": chunk.event.payload.step_type,
116+
"token": tool_name,
117+
},
118+
}
119+
)
120+
return None
121+
122+
123+
@router.post("/streaming_query")
async def streaming_query_endpoint_handler(
    _request: Request,
    query_request: QueryRequest,
    auth: Any = Depends(auth_dependency),
) -> StreamingResponse:
    """Handle request to the /streaming_query endpoint.

    Selects a model, starts a streaming agent turn and returns a streaming
    response that emits a start event, one event per processable chunk, and
    an end event. Once the stream is exhausted, the concatenated tokens are
    stored as a transcript when transcript collection is enabled.

    Args:
        _request: The incoming HTTP request (unused).
        query_request: The query payload sent by the client.
        auth: The value produced by the authentication dependency.

    Returns:
        StreamingResponse: The response wrapping the event generator.
    """
    llama_stack_config = configuration.llama_stack_configuration
    logger.info("LLama stack config: %s", llama_stack_config)
    client = await get_async_llama_stack_client(llama_stack_config)
    model_id = select_model_id(await client.models.list(), query_request)
    conversation_id = retrieve_conversation_id(query_request)
    response = await retrieve_response(client, model_id, query_request)

    async def response_generator(turn_response: Any) -> AsyncIterator[str]:
        """Generate SSE formatted streaming response."""
        chunk_id = 0
        tokens: list[str] = []

        # Send start event
        yield stream_start_event(conversation_id)

        async for chunk in turn_response:
            if event := stream_build_event(chunk, chunk_id):
                # Strip only the leading field name. A blanket
                # replace("data: ", "") would also delete that substring from
                # inside the token text and corrupt the stored transcript.
                payload = json.loads(event.removeprefix("data: "))
                tokens.append(payload["data"]["token"])
                chunk_id += 1
                yield event

        yield stream_end_event()

        if not is_transcripts_enabled():
            logger.debug("Transcript collection is disabled in the configuration")
        else:
            store_transcript(
                user_id=retrieve_user_id(auth),
                conversation_id=conversation_id,
                query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
                query=query_request.query,
                query_request=query_request,
                response="".join(tokens),
                rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
                truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
                attachments=query_request.attachments or [],
            )

    return StreamingResponse(response_generator(response))
171+
172+
173+
async def retrieve_response(
    client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest
) -> Any:
    """Start a streaming agent turn for the query and return its response.

    Args:
        client: The async Llama Stack client used to list shields and to
            back the agent.
        model_id: Identifier of the model the agent should use.
        query_request: The query payload (query text, optional system prompt
            and attachments).

    Returns:
        Whatever ``agent.create_turn`` returns when called with
        ``stream=True``.
    """
    shields = await client.shields.list()
    available_shields = [shield.identifier for shield in shields]
    if available_shields:
        logger.info("Available shields found: %s", available_shields)
    else:
        logger.info("No available shields. Disabling safety")

    # The request's system prompt wins; fall back to the default when unset or empty.
    system_prompt = query_request.system_prompt or constants.DEFAULT_SYSTEM_PROMPT
    logger.debug("Using system prompt: %s", system_prompt)

    # TODO(lucasagomes): redact attachments content before sending to LLM
    # Attachments, when present, are validated before the agent is created.
    if query_request.attachments:
        validate_attachments_metadata(query_request.attachments)

    agent = AsyncAgent(
        client,  # type: ignore[arg-type]
        model=model_id,
        instructions=system_prompt,
        input_shields=available_shields or [],
        tools=[],
    )
    session_id = await agent.create_session("chat_session")
    logger.debug("Session ID: %s", session_id)

    user_message = UserMessage(role="user", content=query_request.query)
    return await agent.create_turn(
        messages=[user_message],
        session_id=session_id,
        documents=query_request.get_documents(),
        stream=True,
    )

src/app/routers.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@
22

33
from fastapi import FastAPI
44

5-
from app.endpoints import info, models, root, query, health, config, feedback
5+
from app.endpoints import (
6+
info,
7+
models,
8+
root,
9+
query,
10+
health,
11+
config,
12+
feedback,
13+
streaming_query,
14+
)
615

716

817
def include_routers(app: FastAPI) -> None:
@@ -18,3 +27,4 @@ def include_routers(app: FastAPI) -> None:
1827
app.include_router(health.router, prefix="/v1")
1928
app.include_router(config.router, prefix="/v1")
2029
app.include_router(feedback.router, prefix="/v1")
30+
app.include_router(streaming_query.router, prefix="/v1")

src/client.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
import logging
44

5-
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient # type: ignore
6-
from llama_stack_client import LlamaStackClient # type: ignore
5+
from llama_stack.distribution.library_client import (
6+
AsyncLlamaStackAsLibraryClient, # type: ignore
7+
LlamaStackAsLibraryClient, # type: ignore
8+
)
9+
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient # type: ignore
710
from models.config import LLamaStackConfiguration
811

912
logger = logging.getLogger(__name__)
@@ -29,3 +32,25 @@ def get_llama_stack_client(
2932
return LlamaStackClient(
3033
base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
3134
)
35+
36+
37+
async def get_async_llama_stack_client(
    llama_stack_config: LLamaStackConfiguration,
) -> AsyncLlamaStackClient:
    """Create the async Llama Stack client selected by the configuration.

    Args:
        llama_stack_config: Configuration choosing between library mode and
            service mode, with the settings each mode needs.

    Returns:
        An initialized library client when ``use_as_library_client`` is
        True, otherwise a client built from the configured URL and API key.

    Raises:
        Exception: If library mode is requested but
            ``library_client_config_path`` is not set.
    """
    if llama_stack_config.use_as_library_client is not True:
        logger.info("Using Llama stack running as a service")
        return AsyncLlamaStackClient(
            base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
        )

    config_path = llama_stack_config.library_client_config_path
    if config_path is None:
        msg = "Configuration problem: library_client_config_path option is not set"
        logger.error(msg)
        # tisnik: use custom exception there - with cause etc.
        raise Exception(msg)  # pylint: disable=broad-exception-raised

    logger.info("Using Llama stack as library client")
    client = AsyncLlamaStackAsLibraryClient(config_path)
    await client.initialize()
    return client

tests/unit/app/endpoints/test_query.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def test_select_model_id(mocker):
158158
query="What is OpenStack?", model="model1", provider="provider1"
159159
)
160160

161-
model_id = select_model_id(mock_client, query_request)
161+
model_id = select_model_id(mock_client.models.list(), query_request)
162162

163163
assert model_id == "model1"
164164

@@ -180,7 +180,7 @@ def test_select_model_id_no_model(mocker):
180180

181181
query_request = QueryRequest(query="What is OpenStack?")
182182

183-
model_id = select_model_id(mock_client, query_request)
183+
model_id = select_model_id(mock_client.models.list(), query_request)
184184

185185
# Assert return the first available LLM model
186186
assert model_id == "first_model"
@@ -198,7 +198,7 @@ def test_select_model_id_invalid_model(mocker):
198198
)
199199

200200
with pytest.raises(Exception) as exc_info:
201-
select_model_id(mock_client, query_request)
201+
select_model_id(mock_client.models.list(), query_request)
202202

203203
assert (
204204
"Model invalid_model from provider provider1 not found in available models"
@@ -215,7 +215,7 @@ def test_no_available_models(mocker):
215215
query_request = QueryRequest(query="What is OpenStack?", model=None, provider=None)
216216

217217
with pytest.raises(Exception) as exc_info:
218-
select_model_id(mock_client, query_request)
218+
select_model_id(mock_client.models.list(), query_request)
219219

220220
assert "No LLM model found in available models" in str(exc_info.value)
221221

0 commit comments

Comments
 (0)