Skip to content

Commit 0108f9b

Browse files
committed
fix(rlsapi): use Responses API instead of Agent abstraction
Switch the /v1/infer endpoint from using the AsyncAgent abstraction to the Responses API (client.responses.create()). The AsyncAgent.create_turn() method internally calls self.initialize(), which does not exist in llama-stack-client 0.3.5, causing an AttributeError at runtime.

The Responses API is a better fit for the stateless /infer endpoint:

- Consistent with other endpoints (query_v2, streaming_query_v2)
- Simpler for single-turn inference (no session management)
- Avoids the broken Agent abstraction
- Fewer moving parts

Signed-off-by: Major Hayden <major@redhat.com>
1 parent 1a12376 commit 0108f9b

3 files changed

Lines changed: 131 additions & 182 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
from typing import Annotated, Any, cast
99

1010
from fastapi import APIRouter, Depends, HTTPException
11-
from llama_stack_client import APIConnectionError # type: ignore
12-
from llama_stack_client.types import UserMessage # type: ignore
13-
from llama_stack_client.types.alpha.agents.turn import Turn
11+
from llama_stack.apis.agents.openai_responses import OpenAIResponseObject
12+
from llama_stack_client import APIConnectionError
1413

1514
import constants
1615
from authentication import get_auth_dependency
@@ -27,9 +26,8 @@
2726
)
2827
from models.rlsapi.requests import RlsapiV1InferRequest
2928
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
30-
from utils.endpoints import get_temp_agent
29+
from utils.responses import extract_text_from_response_output_item
3130
from utils.suid import get_suid
32-
from utils.types import content_to_str
3331

3432
logger = logging.getLogger(__name__)
3533
router = APIRouter(tags=["rlsapi-v1"])
@@ -82,8 +80,8 @@ def _get_default_model_id() -> str:
8280
async def retrieve_simple_response(question: str) -> str:
8381
"""Retrieve a simple response from the LLM for a stateless query.
8482
85-
Creates a temporary agent, sends a single turn with the user's question,
86-
and returns the LLM response text. No conversation persistence or tools.
83+
Uses the Responses API for simple stateless inference, consistent with
84+
other endpoints (query_v2, streaming_query_v2).
8785
8886
Args:
8987
question: The combined user input (question + context).
@@ -100,24 +98,19 @@ async def retrieve_simple_response(question: str) -> str:
10098

10199
logger.debug("Using model %s for rlsapi v1 inference", model_id)
102100

103-
agent, session_id, _ = await get_temp_agent(
104-
client, model_id, constants.DEFAULT_SYSTEM_PROMPT
105-
)
106-
107-
response = await agent.create_turn(
108-
messages=[UserMessage(role="user", content=question).model_dump()],
109-
session_id=session_id,
101+
response = await client.responses.create(
102+
input=question,
103+
model=model_id,
104+
instructions=constants.DEFAULT_SYSTEM_PROMPT,
110105
stream=False,
106+
store=False,
111107
)
112-
response = cast(Turn, response)
113-
114-
if getattr(response, "output_message", None) is None:
115-
return ""
108+
response = cast(OpenAIResponseObject, response)
116109

117-
if getattr(response.output_message, "content", None) is None:
118-
return ""
119-
120-
return content_to_str(response.output_message.content)
110+
return "".join(
111+
extract_text_from_response_output_item(output_item)
112+
for output_item in response.output
113+
)
121114

122115

123116
@router.post("/infer", responses=infer_responses)

tests/integration/endpoints/test_rlsapi_v1_integration.py

Lines changed: 73 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
# pylint: disable=protected-access
1010
# pylint: disable=unused-argument
1111

12-
from typing import Any, NamedTuple
12+
from typing import Any
1313

1414
import pytest
1515
from fastapi import HTTPException, status
1616
from llama_stack_client import APIConnectionError
17-
from llama_stack_client.types.alpha.agents.turn import Turn
1817
from pytest_mock import MockerFixture
1918

2019
import constants
@@ -34,26 +33,14 @@
3433
from utils.suid import check_suid
3534

3635

37-
class MockAgentFixture(NamedTuple):
38-
"""Container for mocked Llama Stack agent components."""
39-
40-
client: Any
41-
agent: Any
42-
holder_class: Any
43-
44-
4536
# ==========================================
4637
# Shared Fixtures
4738
# ==========================================
4839

4940

5041
@pytest.fixture(name="rlsapi_config")
5142
def rlsapi_config_fixture(test_config: AppConfig, mocker: MockerFixture) -> AppConfig:
52-
"""Extend test_config with inference defaults required by rlsapi v1.
53-
54-
NOTE(major): The standard test configuration doesn't include inference
55-
settings (default_model, default_provider) which rlsapi v1 requires.
56-
"""
43+
"""Extend test_config with inference defaults required by rlsapi v1."""
5744
test_config.inference.default_model = "test-model"
5845
test_config.inference.default_provider = "test-provider"
5946
mocker.patch("app.endpoints.rlsapi_v1.configuration", test_config)
@@ -66,60 +53,42 @@ def mock_authorization_fixture(mocker: MockerFixture) -> None:
6653
mock_authorization_resolvers(mocker)
6754

6855

69-
def _create_mock_agent(
56+
def _create_mock_response_output(mocker: MockerFixture, text: str) -> Any:
57+
"""Create a mock Responses API output item with assistant message."""
58+
mock_output_item = mocker.Mock()
59+
mock_output_item.type = "message"
60+
mock_output_item.role = "assistant"
61+
mock_output_item.content = text
62+
return mock_output_item
63+
64+
65+
def _setup_responses_mock(
7066
mocker: MockerFixture,
71-
response_content: str = "Use the `ls` command to list files in a directory.",
72-
output_message: Any = "default",
73-
) -> MockAgentFixture:
74-
"""Create a mocked Llama Stack agent with configurable response.
75-
76-
Args:
77-
mocker: pytest-mock fixture
78-
response_content: Text content for the LLM response
79-
output_message: Custom output_message Mock, or "default" to create one,
80-
or None for no output_message
81-
82-
Returns:
83-
MockAgentFixture with client, agent, and holder_class components
84-
"""
67+
response_text: str = "Use the `ls` command to list files in a directory.",
68+
) -> Any:
69+
"""Set up responses.create mock with the given response text."""
70+
mock_response = mocker.Mock()
71+
mock_response.output = [_create_mock_response_output(mocker, response_text)]
72+
73+
mock_responses = mocker.Mock()
74+
mock_responses.create = mocker.AsyncMock(return_value=mock_response)
75+
76+
mock_client = mocker.Mock()
77+
mock_client.responses = mock_responses
78+
8579
mock_holder_class = mocker.patch(
8680
"app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder"
8781
)
88-
mock_client = mocker.AsyncMock()
89-
90-
# Configure output message
91-
if output_message == "default":
92-
mock_output_message = mocker.Mock()
93-
mock_output_message.content = response_content
94-
else:
95-
mock_output_message = output_message
96-
97-
mock_turn = mocker.Mock(spec=Turn)
98-
mock_turn.output_message = mock_output_message
99-
mock_turn.steps = []
100-
101-
mock_agent = mocker.AsyncMock()
102-
mock_agent.create_turn = mocker.AsyncMock(return_value=mock_turn)
103-
mock_agent._agent_id = "test_agent_id"
104-
105-
mocker.patch(
106-
"app.endpoints.rlsapi_v1.get_temp_agent",
107-
return_value=(mock_agent, "test_session_id", None),
108-
)
109-
110-
mock_holder_instance = mock_holder_class.return_value
111-
mock_holder_instance.get_client.return_value = mock_client
82+
mock_holder_class.return_value.get_client.return_value = mock_client
11283

113-
return MockAgentFixture(mock_client, mock_agent, mock_holder_class)
84+
return mock_client
11485

11586

11687
@pytest.fixture(name="mock_llama_stack")
117-
def mock_llama_stack_fixture(
118-
rlsapi_config: AppConfig, mocker: MockerFixture
119-
) -> MockAgentFixture:
88+
def mock_llama_stack_fixture(rlsapi_config: AppConfig, mocker: MockerFixture) -> Any:
12089
"""Mock Llama Stack client with successful response."""
12190
_ = rlsapi_config
122-
return _create_mock_agent(mocker)
91+
return _setup_responses_mock(mocker)
12392

12493

12594
# ==========================================
@@ -129,7 +98,7 @@ def mock_llama_stack_fixture(
12998

13099
@pytest.mark.asyncio
131100
async def test_rlsapi_v1_infer_minimal_request(
132-
mock_llama_stack: MockAgentFixture,
101+
mock_llama_stack: Any,
133102
mock_authorization: None,
134103
test_auth: AuthTuple,
135104
) -> None:
@@ -179,7 +148,7 @@ async def test_rlsapi_v1_infer_minimal_request(
179148
],
180149
)
181150
async def test_rlsapi_v1_infer_with_context(
182-
mock_llama_stack: MockAgentFixture,
151+
mock_llama_stack: Any,
183152
mock_authorization: None,
184153
test_auth: AuthTuple,
185154
context: RlsapiV1Context,
@@ -198,7 +167,7 @@ async def test_rlsapi_v1_infer_with_context(
198167

199168
@pytest.mark.asyncio
200169
async def test_rlsapi_v1_infer_generates_unique_request_ids(
201-
mock_llama_stack: MockAgentFixture,
170+
mock_llama_stack: Any,
202171
mock_authorization: None,
203172
test_auth: AuthTuple,
204173
) -> None:
@@ -229,19 +198,18 @@ async def test_rlsapi_v1_infer_connection_error_returns_503(
229198
"""Test /v1/infer returns 503 when Llama Stack is unavailable."""
230199
_ = rlsapi_config
231200

232-
# Create agent that raises APIConnectionError
233-
mock_holder_class = mocker.patch(
234-
"app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder"
235-
)
236-
mock_agent = mocker.AsyncMock()
237-
mock_agent.create_turn = mocker.AsyncMock(
201+
mock_responses = mocker.Mock()
202+
mock_responses.create = mocker.AsyncMock(
238203
side_effect=APIConnectionError(request=mocker.Mock())
239204
)
240-
mocker.patch(
241-
"app.endpoints.rlsapi_v1.get_temp_agent",
242-
return_value=(mock_agent, "test_session_id", None),
205+
206+
mock_client = mocker.Mock()
207+
mock_client.responses = mock_responses
208+
209+
mock_holder_class = mocker.patch(
210+
"app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder"
243211
)
244-
mock_holder_class.return_value.get_client.return_value = mocker.AsyncMock()
212+
mock_holder_class.return_value.get_client.return_value = mock_client
245213

246214
with pytest.raises(HTTPException) as exc_info:
247215
await infer_endpoint(
@@ -255,29 +223,28 @@ async def test_rlsapi_v1_infer_connection_error_returns_503(
255223

256224

257225
@pytest.mark.asyncio
258-
@pytest.mark.parametrize(
259-
"output_message",
260-
[
261-
pytest.param(None, id="none_output_message"),
262-
pytest.param("empty", id="empty_content"),
263-
],
264-
)
265-
async def test_rlsapi_v1_infer_fallback_responses(
226+
async def test_rlsapi_v1_infer_fallback_response_empty_output(
266227
rlsapi_config: AppConfig,
267228
mock_authorization: None,
268229
test_auth: AuthTuple,
269230
mocker: MockerFixture,
270-
output_message: Any,
271231
) -> None:
272-
"""Test /v1/infer returns fallback for empty/None responses."""
232+
"""Test /v1/infer returns fallback for empty output list."""
273233
_ = rlsapi_config
274234

275-
if output_message == "empty":
276-
mock_output = mocker.Mock()
277-
mock_output.content = ""
278-
_create_mock_agent(mocker, output_message=mock_output)
279-
else:
280-
_create_mock_agent(mocker, output_message=None)
235+
mock_response = mocker.Mock()
236+
mock_response.output = []
237+
238+
mock_responses = mocker.Mock()
239+
mock_responses.create = mocker.AsyncMock(return_value=mock_response)
240+
241+
mock_client = mocker.Mock()
242+
mock_client.responses = mock_responses
243+
244+
mock_holder_class = mocker.patch(
245+
"app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder"
246+
)
247+
mock_holder_class.return_value.get_client.return_value = mock_client
281248

282249
response = await infer_endpoint(
283250
infer_request=RlsapiV1InferRequest(question="Test"),
@@ -301,7 +268,20 @@ async def test_rlsapi_v1_infer_input_source_combination(
301268
) -> None:
302269
"""Test that input sources are properly combined before sending to LLM."""
303270
_ = rlsapi_config
304-
mocks = _create_mock_agent(mocker)
271+
272+
mock_response = mocker.Mock()
273+
mock_response.output = [_create_mock_response_output(mocker, "response text")]
274+
275+
mock_responses = mocker.Mock()
276+
mock_responses.create = mocker.AsyncMock(return_value=mock_response)
277+
278+
mock_client = mocker.Mock()
279+
mock_client.responses = mock_responses
280+
281+
mock_holder_class = mocker.patch(
282+
"app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder"
283+
)
284+
mock_holder_class.return_value.get_client.return_value = mock_client
305285

306286
await infer_endpoint(
307287
infer_request=RlsapiV1InferRequest(
@@ -315,12 +295,11 @@ async def test_rlsapi_v1_infer_input_source_combination(
315295
auth=test_auth,
316296
)
317297

318-
# Verify all parts present in message sent to LLM
319-
call_args = mocks.agent.create_turn.call_args
320-
message_content = call_args.kwargs["messages"][0]["content"]
298+
call_args = mock_responses.create.call_args
299+
input_content = call_args.kwargs["input"]
321300

322301
for expected in ["My question", "stdin content", "attachment content", "terminal"]:
323-
assert expected in message_content
302+
assert expected in input_content
324303

325304

326305
# ==========================================
@@ -334,7 +313,7 @@ async def test_rlsapi_v1_infer_input_source_combination(
334313
[pytest.param(False, id="default_false"), pytest.param(True, id="explicit_true")],
335314
)
336315
async def test_rlsapi_v1_infer_skip_rag(
337-
mock_llama_stack: MockAgentFixture,
316+
mock_llama_stack: Any,
338317
mock_authorization: None,
339318
test_auth: AuthTuple,
340319
skip_rag: bool,

0 commit comments

Comments
 (0)