Skip to content

Commit 564434e

Browse files
committed
feat(rlsapi): implement LLM integration for v1 /infer endpoint
Wire up the /infer endpoint to Llama Stack for actual inference:
- Handle empty LLM responses with fallback message
- Include request_id in all log statements for tracing

Signed-off-by: Major Hayden <major@redhat.com>
1 parent 2a2a30f commit 564434e

2 files changed

Lines changed: 336 additions & 15 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,31 @@
55
"""
66

77
import logging
8-
from typing import Annotated, Any
8+
from typing import Annotated, Any, cast
99

10-
from fastapi import APIRouter, Depends
10+
from fastapi import APIRouter, Depends, HTTPException
11+
from llama_stack_client import APIConnectionError # type: ignore
12+
from llama_stack_client.types import UserMessage # type: ignore
13+
from llama_stack_client.types.alpha.agents.turn import Turn
1114

15+
import constants
1216
from authentication import get_auth_dependency
1317
from authentication.interface import AuthTuple
1418
from authorization.middleware import authorize
19+
from client import AsyncLlamaStackClientHolder
20+
from configuration import configuration
1521
from models.config import Action
1622
from models.responses import (
1723
ForbiddenResponse,
24+
ServiceUnavailableResponse,
1825
UnauthorizedResponse,
1926
UnprocessableEntityResponse,
2027
)
2128
from models.rlsapi.requests import RlsapiV1InferRequest
2229
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
30+
from utils.endpoints import get_temp_agent
2331
from utils.suid import get_suid
32+
from utils.types import content_to_str
2433

2534
logger = logging.getLogger(__name__)
2635
router = APIRouter(tags=["rlsapi-v1"])
@@ -33,9 +42,83 @@
3342
),
3443
403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
3544
422: UnprocessableEntityResponse.openapi_response(),
45+
503: ServiceUnavailableResponse.openapi_response(),
3646
}
3747

3848

49+
def _get_default_model_id() -> str:
    """Build the default model identifier from the service configuration.

    Reads ``default_provider`` and ``default_model`` from
    ``configuration.inference`` and joins them as ``provider/model``.

    Returns:
        The identifier string in ``<provider>/<model>`` form.

    Raises:
        HTTPException: With status 503 when the inference section is missing,
            or when either the default model or the default provider is unset
            or empty.
    """
    inference = configuration.inference

    if inference is None:
        cause = "No inference configuration available"
    else:
        default_model = inference.default_model
        default_provider = inference.default_provider
        if default_model and default_provider:
            return f"{default_provider}/{default_model}"
        cause = "No default model configured for rlsapi v1 inference"

    # Both failure modes are reported the same way: log the cause, then
    # surface it to the client as a 503.
    logger.error(cause)
    raise HTTPException(
        status_code=503,
        detail={"response": "Service configuration error", "cause": cause},
    )
80+
81+
82+
async def retrieve_simple_response(question: str) -> str:
    """Ask the LLM a single question and return its answer as text.

    Obtains the Llama Stack client, resolves the default model, creates a
    temporary agent with the default system prompt, and sends one
    non-streaming turn containing ``question`` as a user message.

    Args:
        question: The text to send as the user message.

    Returns:
        The turn's output message content converted to a string, or an empty
        string when the turn has no output message or the message has no
        content.

    Raises:
        HTTPException: Propagated from ``_get_default_model_id`` when no
            default model can be determined.
        APIConnectionError: Not caught here; per the original docstring it is
            raised when the Llama Stack service is unreachable.
    """
    llama_client = AsyncLlamaStackClientHolder().get_client()
    default_model = _get_default_model_id()

    logger.debug("Using model %s for rlsapi v1 inference", default_model)

    temp_agent, temp_session_id, _ = await get_temp_agent(
        llama_client, default_model, constants.DEFAULT_SYSTEM_PROMPT
    )

    user_message = UserMessage(role="user", content=question).model_dump()
    turn = cast(
        Turn,
        await temp_agent.create_turn(
            messages=[user_message],
            session_id=temp_session_id,
            stream=False,
        ),
    )

    # Either level may be absent; treat both as "no answer" and let the
    # caller decide how to handle an empty string.
    output_message = getattr(turn, "output_message", None)
    if output_message is None:
        return ""

    content = getattr(output_message, "content", None)
    if content is None:
        return ""

    return content_to_str(content)
120+
121+
39122
@router.post("/infer", responses=infer_responses)
40123
@authorize(Action.RLSAPI_V1_INFER)
41124
async def infer_endpoint(
@@ -55,6 +138,9 @@ async def infer_endpoint(
55138
56139
Returns:
57140
RlsapiV1InferResponse containing the generated response text and request ID.
141+
142+
Raises:
143+
HTTPException: 503 if the LLM service is unavailable.
58144
"""
59145
# Authentication enforced by get_auth_dependency(), authorization by @authorize decorator.
60146
_ = auth
@@ -66,14 +152,28 @@ async def infer_endpoint(
66152

67153
# Combine all input sources (question, stdin, attachments, terminal)
68154
input_source = infer_request.get_input_source()
69-
logger.debug("Combined input source length: %d", len(input_source))
70-
71-
# NOTE(major): Placeholder until we wire up the LLM integration.
72-
response_text = (
73-
"Inference endpoint is functional. "
74-
"LLM integration will be added in a subsequent update."
155+
logger.debug(
156+
"Request %s: Combined input source length: %d", request_id, len(input_source)
75157
)
76158

159+
try:
160+
response_text = await retrieve_simple_response(input_source)
161+
except APIConnectionError as e:
162+
logger.error(
163+
"Unable to connect to Llama Stack for request %s: %s", request_id, e
164+
)
165+
response = ServiceUnavailableResponse(
166+
backend_name="Llama Stack",
167+
cause=str(e),
168+
)
169+
raise HTTPException(**response.model_dump()) from e
170+
171+
if not response_text:
172+
logger.warning("Empty response from LLM for request %s", request_id)
173+
response_text = constants.UNABLE_TO_PROCESS_RESPONSE
174+
175+
logger.info("Completed rlsapi v1 /infer request %s", request_id)
176+
77177
return RlsapiV1InferResponse(
78178
data=RlsapiV1InferData(
79179
text=response_text,

0 commit comments

Comments
 (0)