Skip to content

Commit 684f2c3

Browse files
committed
Created v1/responses endpoint
1 parent 6f7aab4 commit 684f2c3

8 files changed

Lines changed: 1893 additions & 48 deletions

File tree

docs/responses.md

Lines changed: 447 additions & 0 deletions
Large diffs are not rendered by default.

src/app/endpoints/responses.py

Lines changed: 393 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,393 @@
1+
# pylint: disable=too-many-locals,too-many-branches,too-many-nested-blocks
2+
3+
"""Handler for REST API call to provide answer using Responses API (LCORE specification)."""
4+
5+
import json
6+
import logging
7+
from datetime import UTC, datetime
8+
import sqlite3
9+
from typing import Annotated, Any, AsyncIterator, Optional, Union, cast
10+
11+
from fastapi import APIRouter, Depends, HTTPException, Request
12+
from fastapi.responses import StreamingResponse
13+
from llama_stack_api.openai_responses import (
14+
OpenAIResponseObject,
15+
OpenAIResponseObjectStream,
16+
)
17+
from llama_stack_client import (
18+
APIConnectionError,
19+
APIStatusError as LLSApiStatusError,
20+
)
21+
from openai._exceptions import (
22+
APIStatusError as OpenAIAPIStatusError,
23+
)
24+
import psycopg2
25+
from sqlalchemy.exc import SQLAlchemyError
26+
27+
from authentication import get_auth_dependency
28+
from authentication.interface import AuthTuple
29+
from authorization.azure_token_manager import AzureEntraIDManager
30+
from authorization.middleware import authorize
31+
from cache.cache_error import CacheError
32+
from client import AsyncLlamaStackClientHolder
33+
from configuration import configuration
34+
from models.cache_entry import CacheEntry
35+
from models.config import Action
36+
from models.requests import ResponsesRequest
37+
from models.responses import (
38+
ForbiddenResponse,
39+
InternalServerErrorResponse,
40+
NotFoundResponse,
41+
PromptTooLongResponse,
42+
ResponsesResponse,
43+
QuotaExceededResponse,
44+
ServiceUnavailableResponse,
45+
UnauthorizedResponse,
46+
UnprocessableEntityResponse,
47+
)
48+
from utils.endpoints import (
49+
check_configuration_loaded,
50+
validate_and_retrieve_conversation,
51+
)
52+
from utils.mcp_headers import mcp_headers_dependency
53+
from utils.query import (
54+
consume_query_tokens,
55+
handle_known_apistatus_errors,
56+
persist_user_conversation_details,
57+
store_conversation_into_cache,
58+
store_query_results,
59+
update_azure_token,
60+
)
61+
from utils.quota import check_tokens_available, get_available_quotas
62+
from utils.responses import (
63+
extract_text_from_input,
64+
extract_token_usage,
65+
get_topic_summary,
66+
select_model_for_responses,
67+
validate_model_override_permissions,
68+
)
69+
from utils.shields import (
70+
append_turn_to_conversation,
71+
run_shield_moderation,
72+
)
73+
from utils.suid import normalize_conversation_id, to_llama_stack_conversation_id
74+
75+
logger = logging.getLogger("app.endpoints.handlers")
76+
router = APIRouter(tags=["responses"])
77+
78+
responses_response: dict[int | str, dict[str, Any]] = {
79+
200: ResponsesResponse.openapi_response(),
80+
401: UnauthorizedResponse.openapi_response(
81+
examples=["missing header", "missing token"]
82+
),
83+
403: ForbiddenResponse.openapi_response(
84+
examples=["endpoint", "conversation read", "model override"]
85+
),
86+
404: NotFoundResponse.openapi_response(
87+
examples=["model", "conversation", "provider"]
88+
),
89+
413: PromptTooLongResponse.openapi_response(),
90+
422: UnprocessableEntityResponse.openapi_response(),
91+
429: QuotaExceededResponse.openapi_response(),
92+
500: InternalServerErrorResponse.openapi_response(examples=["configuration"]),
93+
503: ServiceUnavailableResponse.openapi_response(),
94+
}
95+
96+
97+
@router.post(
98+
"/responses",
99+
responses=responses_response,
100+
summary="Responses Endpoint Handler",
101+
)
102+
@authorize(Action.QUERY)
103+
async def responses_endpoint_handler(
104+
request: Request,
105+
responses_request: ResponsesRequest,
106+
auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
107+
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
108+
) -> Union[ResponsesResponse, StreamingResponse]:
109+
"""
110+
Handle request to the /responses endpoint using Responses API (LCORE specification).
111+
112+
Processes a POST request to the responses endpoint, forwarding the
113+
user's request to a selected Llama Stack LLM and returning the generated response
114+
following the LCORE OpenAPI specification.
115+
116+
Returns:
117+
ResponsesResponse: Contains the response following LCORE specification (non-streaming).
118+
StreamingResponse: SSE-formatted streaming response with enriched events (streaming).
119+
- response.created event includes conversation attribute
120+
- response.completed event includes available_quotas attribute
121+
122+
Raises:
123+
HTTPException:
124+
- 401: Unauthorized - Missing or invalid credentials
125+
- 403: Forbidden - Insufficient permissions or model override not allowed
126+
- 404: Not Found - Conversation, model, or provider not found
127+
- 413: Prompt too long - Prompt exceeded model's context window size
128+
- 422: Unprocessable Entity - Request validation failed
129+
- 429: Quota limit exceeded - The token quota for model or user has been exceeded
130+
- 500: Internal Server Error - Configuration not loaded or other server errors
131+
- 503: Service Unavailable - Unable to connect to Llama Stack backend
132+
"""
133+
check_configuration_loaded(configuration)
134+
135+
started_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
136+
user_id, _, _skip_userid_check, token = auth
137+
138+
# Check token availability
139+
check_tokens_available(configuration.quota_limiters, user_id)
140+
141+
# Enforce RBAC: optionally disallow overriding model/provider in requests
142+
if responses_request.model:
143+
validate_model_override_permissions(
144+
responses_request.model,
145+
request.state.authorized_actions,
146+
)
147+
148+
user_conversation = None
149+
if responses_request.conversation:
150+
logger.debug(
151+
"Conversation ID specified in request: %s", responses_request.conversation
152+
)
153+
user_conversation = validate_and_retrieve_conversation(
154+
normalized_conv_id=normalize_conversation_id(responses_request.conversation),
155+
user_id=user_id,
156+
others_allowed=Action.READ_OTHERS_CONVERSATIONS
157+
in request.state.authorized_actions,
158+
)
159+
# Convert to llama-stack format if needed
160+
responses_request.conversation = to_llama_stack_conversation_id(user_conversation.id)
161+
162+
client = AsyncLlamaStackClientHolder().get_client()
163+
164+
# LCORE-specific: Automatically select model if not provided in request
165+
# This extends the base LLS API which requires model to be specified.
166+
if not responses_request.model:
167+
responses_request.model = await select_model_for_responses(
168+
client, user_conversation
169+
)
170+
171+
# Prepare API request parameters
172+
api_params = responses_request.model_dump(
173+
exclude_none=True, exclude={"generate_topic_summary"}
174+
)
175+
176+
# Handle Azure token refresh if needed
177+
if (
178+
api_params["model"].startswith("azure")
179+
and AzureEntraIDManager().is_entra_id_configured
180+
and AzureEntraIDManager().is_token_expired
181+
and AzureEntraIDManager().refresh_token()
182+
):
183+
client = await update_azure_token(client)
184+
185+
# Retrieve response using Responses API
186+
try:
187+
# Extract text from input for shield moderation (input can be string or complex object)
188+
input_text_for_moderation = extract_text_from_input(responses_request.input)
189+
moderation_result = await run_shield_moderation(
190+
client, input_text_for_moderation
191+
)
192+
if moderation_result.blocked:
193+
violation_message = moderation_result.message or ""
194+
if responses_request.conversation:
195+
await append_turn_to_conversation(
196+
client,
197+
responses_request.conversation,
198+
input_text_for_moderation,
199+
violation_message,
200+
)
201+
return ResponsesResponse.model_construct(
202+
status="blocked",
203+
text=violation_message,
204+
error={"message": violation_message},
205+
conversation=responses_request.conversation,
206+
)
207+
208+
response = await client.responses.create(**api_params)
209+
210+
# Handle streaming response
211+
if responses_request.stream:
212+
stream_iterator = cast(AsyncIterator[OpenAIResponseObjectStream], response)
213+
return StreamingResponse(
214+
_stream_responses(
215+
stream_iterator,
216+
responses_request.conversation,
217+
user_id,
218+
api_params.get("model", ""),
219+
),
220+
media_type="text/event-stream",
221+
)
222+
223+
response = cast(OpenAIResponseObject, response)
224+
225+
except RuntimeError as e: # library mode wraps 413 into runtime error
226+
if "context_length" in str(e).lower():
227+
error_response = PromptTooLongResponse(model=api_params.get("model", ""))
228+
raise HTTPException(**error_response.model_dump()) from e
229+
raise e
230+
except APIConnectionError as e:
231+
error_response = ServiceUnavailableResponse(
232+
backend_name="Llama Stack",
233+
cause=str(e),
234+
)
235+
raise HTTPException(**error_response.model_dump()) from e
236+
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
237+
error_response = handle_known_apistatus_errors(e, api_params.get("model", ""))
238+
raise HTTPException(**error_response.model_dump()) from e
239+
240+
# Extract token usage
241+
token_usage = extract_token_usage(response, api_params["model"])
242+
243+
# Consume tokens
244+
logger.info("Consuming tokens")
245+
consume_query_tokens(
246+
user_id=user_id,
247+
model_id=api_params["model"],
248+
token_usage=token_usage,
249+
configuration=configuration,
250+
)
251+
252+
# Get available quotas
253+
logger.info("Getting available quotas")
254+
available_quotas = get_available_quotas(
255+
quota_limiters=configuration.quota_limiters, user_id=user_id
256+
)
257+
258+
# Get topic summary for new conversation
259+
if not user_conversation and responses_request.generate_topic_summary:
260+
logger.debug("Generating topic summary for new conversation")
261+
topic_summary = await get_topic_summary(
262+
extract_text_from_input(responses_request.input), client, api_params["model"]
263+
)
264+
else:
265+
topic_summary = None
266+
267+
try:
268+
logger.info("Persisting conversation details")
269+
# Extract provider_id from model_id (format: "provider/model")
270+
persist_user_conversation_details(
271+
user_id=user_id,
272+
conversation_id=responses_request.conversation or "",
273+
model_id=api_params["model"],
274+
provider_id=api_params["model"].split("/")[0], # type: ignore
275+
topic_summary=topic_summary,
276+
)
277+
except SQLAlchemyError as e:
278+
logger.exception("Error persisting conversation details.")
279+
response = InternalServerErrorResponse.database_error()
280+
raise HTTPException(**response.model_dump()) from e
281+
282+
# Store conversation in cache
283+
try:
284+
completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
285+
cache_entry = CacheEntry(
286+
query=extract_text_from_input(responses_request.input),
287+
response="",
288+
provider=api_params["model"].split("/")[0], # type: ignore
289+
model=api_params["model"],
290+
started_at=started_at,
291+
completed_at=completed_at,
292+
referenced_documents=None,
293+
tool_calls=None,
294+
tool_results=None,
295+
)
296+
297+
logger.info("Storing conversation in cache")
298+
store_conversation_into_cache(
299+
config=configuration,
300+
user_id=user_id,
301+
conversation_id=responses_request.conversation or "",
302+
cache_entry=cache_entry,
303+
_skip_userid_check=_skip_userid_check,
304+
topic_summary=topic_summary,
305+
)
306+
except (CacheError, ValueError, psycopg2.Error, sqlite3.Error) as e:
307+
logger.exception("Error storing conversation in cache: %s", e)
308+
response = InternalServerErrorResponse.database_error()
309+
raise HTTPException(**response.model_dump()) from e
310+
311+
# Extract response fields using model_dump, excluding output/text which are handled separately
312+
response_dict = cast(OpenAIResponseObject, response).model_dump()
313+
314+
logger.info("Building final response")
315+
return ResponsesResponse(
316+
**response_dict,
317+
conversation=responses_request.conversation,
318+
available_quotas=available_quotas,
319+
)
320+
321+
322+
async def _stream_responses(
323+
stream: AsyncIterator[OpenAIResponseObjectStream],
324+
conversation_id: Optional[str],
325+
user_id: str,
326+
model_id: str,
327+
) -> AsyncIterator[str]:
328+
"""Generate SSE-formatted streaming response with LCORE-enriched events.
329+
330+
Processes streaming chunks from Llama Stack and converts them to
331+
Server-Sent Events (SSE) format, enriching response.created with conversation
332+
and response.completed with available_quotas. All other events are forwarded
333+
exactly as received from the stream.
334+
335+
Args:
336+
stream: The streaming response from Llama Stack
337+
conversation_id: The conversation ID to include in response.created
338+
user_id: User ID for quota retrieval
339+
model_id: Model ID for token usage extraction
340+
341+
Yields:
342+
SSE-formatted strings for streaming events.
343+
"""
344+
normalized_conv_id = normalize_conversation_id(conversation_id) if conversation_id else None
345+
latest_response_object: Optional[OpenAIResponseObject] = None
346+
347+
async for chunk in stream:
348+
event_type = getattr(chunk, "type", None)
349+
logger.debug("Processing streaming chunk, type: %s", event_type)
350+
351+
# Get the original chunk data as dict (exact same structure as original)
352+
chunk_dict = chunk.model_dump() if hasattr(chunk, "model_dump") else {}
353+
354+
# Enrich response.created event with conversation attribute
355+
if event_type == "response.created":
356+
response_obj = getattr(chunk, "response", None)
357+
if response_obj:
358+
latest_response_object = cast(OpenAIResponseObject, response_obj)
359+
360+
# Add conversation attribute to the original chunk data
361+
if normalized_conv_id:
362+
chunk_dict["conversation"] = normalized_conv_id
363+
364+
# Enrich response.completed event with available_quotas attribute
365+
elif event_type == "response.completed":
366+
response_obj = getattr(chunk, "response", None)
367+
if response_obj:
368+
latest_response_object = cast(OpenAIResponseObject, response_obj)
369+
370+
# Extract token usage
371+
token_usage_obj = None
372+
if latest_response_object:
373+
token_usage_obj = extract_token_usage(latest_response_object, model_id)
374+
375+
# Get available quotas
376+
available_quotas = get_available_quotas(
377+
quota_limiters=configuration.quota_limiters, user_id=user_id
378+
)
379+
380+
# Consume tokens
381+
if token_usage_obj and latest_response_object:
382+
consume_query_tokens(
383+
user_id=user_id,
384+
model_id=model_id,
385+
token_usage=token_usage_obj,
386+
configuration=configuration,
387+
)
388+
389+
# Add available_quotas attribute to the original chunk data
390+
if available_quotas:
391+
chunk_dict["available_quotas"] = available_quotas
392+
393+
yield json.dumps(chunk_dict)

0 commit comments

Comments
 (0)