|
| 1 | +"""Handler for REST API call to provide answer to streaming query.""" |
| 2 | + |
| 3 | +import json |
| 4 | +import logging |
| 5 | +from typing import Any, AsyncIterator |
| 6 | + |
| 7 | +from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore |
| 8 | +from llama_stack_client import AsyncLlamaStackClient # type: ignore |
| 9 | +from llama_stack_client.types import UserMessage # type: ignore |
| 10 | + |
| 11 | +from fastapi import APIRouter, Request, Depends |
| 12 | +from fastapi.responses import StreamingResponse |
| 13 | + |
| 14 | +from client import get_async_llama_stack_client |
| 15 | +from configuration import configuration |
| 16 | +from models.requests import QueryRequest |
| 17 | +import constants |
| 18 | +from utils.auth import auth_dependency |
| 19 | +from utils.common import retrieve_user_id |
| 20 | + |
| 21 | + |
| 22 | +from app.endpoints.query import ( |
| 23 | + is_transcripts_enabled, |
| 24 | + retrieve_conversation_id, |
| 25 | + store_transcript, |
| 26 | + select_model_id, |
| 27 | + validate_attachments_metadata, |
| 28 | +) |
| 29 | + |
| 30 | +logger = logging.getLogger("app.endpoints.handlers") |
| 31 | +router = APIRouter(tags=["streaming_query"]) |
| 32 | + |
| 33 | + |
| 34 | +def format_stream_data(d: dict) -> str: |
| 35 | + """Format outbound data in the Event Stream Format.""" |
| 36 | + data = json.dumps(d) |
| 37 | + return f"data: {data}\n\n" |
| 38 | + |
| 39 | + |
| 40 | +def stream_start_event(conversation_id: str) -> str: |
| 41 | + """Yield the start of the data stream. |
| 42 | +
|
| 43 | + Args: |
| 44 | + conversation_id: The conversation ID (UUID). |
| 45 | + """ |
| 46 | + return format_stream_data( |
| 47 | + { |
| 48 | + "event": "start", |
| 49 | + "data": { |
| 50 | + "conversation_id": conversation_id, |
| 51 | + }, |
| 52 | + } |
| 53 | + ) |
| 54 | + |
| 55 | + |
| 56 | +def stream_end_event() -> str: |
| 57 | + """Yield the end of the data stream.""" |
| 58 | + return format_stream_data( |
| 59 | + { |
| 60 | + "event": "end", |
| 61 | + "data": { |
| 62 | + "referenced_documents": [], # TODO(jboos): implement referenced documents |
| 63 | + "truncated": None, # TODO(jboos): implement truncated |
| 64 | + "input_tokens": 0, # TODO(jboos): implement input tokens |
| 65 | + "output_tokens": 0, # TODO(jboos): implement output tokens |
| 66 | + }, |
| 67 | + "available_quotas": {}, # TODO(jboos): implement available quotas |
| 68 | + } |
| 69 | + ) |
| 70 | + |
| 71 | + |
| 72 | +def stream_build_event(chunk: Any, chunk_id: int) -> str | None: |
| 73 | + """Build a streaming event from a chunk response. |
| 74 | +
|
| 75 | + This function processes chunks from the LLama Stack streaming response and formats |
| 76 | + them into Server-Sent Events (SSE) format for the client. It handles two main |
| 77 | + event types: |
| 78 | +
|
| 79 | + 1. step_progress: Contains text deltas from the model inference process |
| 80 | + 2. step_complete: Contains information about completed tool execution steps |
| 81 | +
|
| 82 | + Args: |
| 83 | + chunk: The streaming chunk from LLama Stack containing event data |
| 84 | + chunk_id: The current chunk ID counter (gets incremented for each token) |
| 85 | +
|
| 86 | + Returns: |
| 87 | + str | None: A formatted SSE data string with event information, or None if |
| 88 | + the chunk doesn't contain processable event data |
| 89 | + """ |
| 90 | + if hasattr(chunk.event, "payload"): |
| 91 | + if chunk.event.payload.event_type == "step_progress": |
| 92 | + if hasattr(chunk.event.payload.delta, "text"): |
| 93 | + text = chunk.event.payload.delta.text |
| 94 | + return format_stream_data( |
| 95 | + { |
| 96 | + "event": "token", |
| 97 | + "data": { |
| 98 | + "id": chunk_id, |
| 99 | + "role": chunk.event.payload.step_type, |
| 100 | + "token": text, |
| 101 | + }, |
| 102 | + } |
| 103 | + ) |
| 104 | + if chunk.event.payload.event_type == "step_complete": |
| 105 | + if chunk.event.payload.step_details.step_type == "tool_execution": |
| 106 | + if chunk.event.payload.step_details.tool_calls: |
| 107 | + tool_name = str( |
| 108 | + chunk.event.payload.step_details.tool_calls[0].tool_name |
| 109 | + ) |
| 110 | + return format_stream_data( |
| 111 | + { |
| 112 | + "event": "token", |
| 113 | + "data": { |
| 114 | + "id": chunk_id, |
| 115 | + "role": chunk.event.payload.step_type, |
| 116 | + "token": tool_name, |
| 117 | + }, |
| 118 | + } |
| 119 | + ) |
| 120 | + return None |
| 121 | + |
| 122 | + |
| 123 | +@router.post("/streaming_query") |
| 124 | +async def streaming_query_endpoint_handler( |
| 125 | + _request: Request, |
| 126 | + query_request: QueryRequest, |
| 127 | + auth: Any = Depends(auth_dependency), |
| 128 | +) -> StreamingResponse: |
| 129 | + """Handle request to the /streaming_query endpoint.""" |
| 130 | + llama_stack_config = configuration.llama_stack_configuration |
| 131 | + logger.info("LLama stack config: %s", llama_stack_config) |
| 132 | + client = await get_async_llama_stack_client(llama_stack_config) |
| 133 | + model_id = select_model_id(await client.models.list(), query_request) |
| 134 | + conversation_id = retrieve_conversation_id(query_request) |
| 135 | + response = await retrieve_response(client, model_id, query_request) |
| 136 | + |
| 137 | + async def response_generator(turn_response: Any) -> AsyncIterator[str]: |
| 138 | + """Generate SSE formatted streaming response.""" |
| 139 | + chunk_id = 0 |
| 140 | + complete_response = "" |
| 141 | + |
| 142 | + # Send start event |
| 143 | + yield stream_start_event(conversation_id) |
| 144 | + |
| 145 | + async for chunk in turn_response: |
| 146 | + if event := stream_build_event(chunk, chunk_id): |
| 147 | + complete_response += json.loads(event.replace("data: ", ""))["data"][ |
| 148 | + "token" |
| 149 | + ] |
| 150 | + chunk_id += 1 |
| 151 | + yield event |
| 152 | + |
| 153 | + yield stream_end_event() |
| 154 | + |
| 155 | + if not is_transcripts_enabled(): |
| 156 | + logger.debug("Transcript collection is disabled in the configuration") |
| 157 | + else: |
| 158 | + store_transcript( |
| 159 | + user_id=retrieve_user_id(auth), |
| 160 | + conversation_id=conversation_id, |
| 161 | + query_is_valid=True, # TODO(lucasagomes): implement as part of query validation |
| 162 | + query=query_request.query, |
| 163 | + query_request=query_request, |
| 164 | + response=complete_response, |
| 165 | + rag_chunks=[], # TODO(lucasagomes): implement rag_chunks |
| 166 | + truncated=False, # TODO(lucasagomes): implement truncation as part of quota work |
| 167 | + attachments=query_request.attachments or [], |
| 168 | + ) |
| 169 | + |
| 170 | + return StreamingResponse(response_generator(response)) |
| 171 | + |
| 172 | + |
| 173 | +async def retrieve_response( |
| 174 | + client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest |
| 175 | +) -> Any: |
| 176 | + """Retrieve response from LLMs and agents.""" |
| 177 | + available_shields = [shield.identifier for shield in await client.shields.list()] |
| 178 | + if not available_shields: |
| 179 | + logger.info("No available shields. Disabling safety") |
| 180 | + else: |
| 181 | + logger.info("Available shields found: %s", available_shields) |
| 182 | + |
| 183 | + # use system prompt from request or default one |
| 184 | + system_prompt = ( |
| 185 | + query_request.system_prompt |
| 186 | + if query_request.system_prompt |
| 187 | + else constants.DEFAULT_SYSTEM_PROMPT |
| 188 | + ) |
| 189 | + logger.debug("Using system prompt: %s", system_prompt) |
| 190 | + |
| 191 | + # TODO(lucasagomes): redact attachments content before sending to LLM |
| 192 | + # if attachments are provided, validate them |
| 193 | + if query_request.attachments: |
| 194 | + validate_attachments_metadata(query_request.attachments) |
| 195 | + |
| 196 | + agent = AsyncAgent( |
| 197 | + client, # type: ignore[arg-type] |
| 198 | + model=model_id, |
| 199 | + instructions=system_prompt, |
| 200 | + input_shields=available_shields if available_shields else [], |
| 201 | + tools=[], |
| 202 | + ) |
| 203 | + session_id = await agent.create_session("chat_session") |
| 204 | + logger.debug("Session ID: %s", session_id) |
| 205 | + response = await agent.create_turn( |
| 206 | + messages=[UserMessage(role="user", content=query_request.query)], |
| 207 | + session_id=session_id, |
| 208 | + documents=query_request.get_documents(), |
| 209 | + stream=True, |
| 210 | + ) |
| 211 | + |
| 212 | + return response |
0 commit comments