Skip to content

Commit bdd73aa

Browse files
committed
RHIDP-14000: update mcp and query interrupt tests
Signed-off-by: Jordan Dubrick <jdubrick@redhat.com>
1 parent 82f6ca4 commit bdd73aa

3 files changed

Lines changed: 157 additions & 5 deletions

File tree

tests/integration/endpoints/test_stream_interrupt_integration.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
from collections.abc import Generator
55

66
import pytest
7-
from fastapi import HTTPException
7+
from fastapi import HTTPException, status
8+
from fastapi.testclient import TestClient
89

910
from app.endpoints.stream_interrupt import stream_interrupt_endpoint_handler
11+
from configuration import AppConfig
1012
from models.api.requests import StreamingInterruptRequest
1113
from utils.stream_interrupts import StreamInterruptRegistry
1214

1315
TEST_REQUEST_ID = "123e4567-e89b-12d3-a456-426614174003"
16+
REQUEST_ID_NOT_IN_REGISTRY = "00000000-0000-0000-0000-000000000000"
1417
OWNER_USER_ID = "00000001-0001-0001-0001-000000000001"
1518

1619

@@ -25,10 +28,12 @@ def registry_fixture() -> Generator[StreamInterruptRegistry, None, None]:
2528

2629
@pytest.mark.asyncio
2730
async def test_stream_interrupt_full_round_trip(
31+
test_config: AppConfig,
2832
registry: StreamInterruptRegistry,
2933
) -> None:
3034
"""Full lifecycle: register, interrupt, then verify deregistration."""
31-
35+
# test_config loads configuration so @authorize on the handler can resolve.
36+
_ = test_config
3237
async def pending_stream() -> None:
3338
await asyncio.sleep(10)
3439

@@ -64,3 +69,15 @@ async def pending_stream() -> None:
6469
registry=registry,
6570
)
6671
assert exc_info.value.status_code == 404
72+
73+
74+
def test_stream_interrupt_nonexistent_request_returns_404(
75+
integration_http_client: TestClient,
76+
) -> None:
77+
"""POST /v1/streaming_query/interrupt for unknown stream returns 404."""
78+
response = integration_http_client.post(
79+
"/v1/streaming_query/interrupt",
80+
json={"request_id": REQUEST_ID_NOT_IN_REGISTRY},
81+
)
82+
83+
assert response.status_code == status.HTTP_404_NOT_FOUND

tests/unit/app/endpoints/test_mcp_servers.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from typing import Any
77

88
import pytest
9-
from fastapi import HTTPException
10-
from llama_stack_client import APIConnectionError
9+
from fastapi import HTTPException, status
10+
from llama_stack_client import APIConnectionError, NotFoundError
1111
from pydantic import AnyHttpUrl, SecretStr
1212
from pytest_mock import MockerFixture
1313

@@ -390,6 +390,64 @@ async def test_delete_mcp_server_llama_stack_failure(
390390
assert exc_info.value.status_code == 503
391391

392392

393+
@pytest.mark.asyncio
394+
async def test_delete_mcp_server_toolgroup_not_found_is_idempotent(
395+
mocker: MockerFixture,
396+
mock_configuration: Configuration,
397+
) -> None:
398+
"""Deleting a dynamic server succeeds when Llama Stack toolgroup is already gone."""
399+
app_config = _make_app_config(mocker, mock_configuration)
400+
client = _mock_client(mocker)
401+
client.toolgroups.register.return_value = None
402+
client.toolgroups.unregister.side_effect = NotFoundError(
403+
message="Toolgroup not found",
404+
response=mocker.Mock(request=None),
405+
body=None,
406+
)
407+
408+
body = MCPServerRegistrationRequest(
409+
name="orphan-server",
410+
url="http://localhost:7777/mcp",
411+
provider_id="MCP provider ID",
412+
)
413+
await mcp_servers.register_mcp_server_handler(
414+
request=mocker.Mock(), body=body, auth=MOCK_AUTH
415+
)
416+
assert app_config.is_dynamic_mcp_server("orphan-server")
417+
418+
result = await mcp_servers.delete_mcp_server_handler(
419+
request=mocker.Mock(), name="orphan-server", auth=MOCK_AUTH
420+
)
421+
422+
assert isinstance(result, MCPServerDeleteResponse)
423+
assert result.name == "orphan-server"
424+
assert result.deleted is True
425+
assert not app_config.is_dynamic_mcp_server("orphan-server")
426+
assert not any(s.name == "orphan-server" for s in app_config.mcp_servers)
427+
428+
429+
@pytest.mark.asyncio
430+
async def test_list_mcp_servers_configuration_not_loaded(
431+
mocker: MockerFixture,
432+
) -> None:
433+
"""Test listing MCP servers returns 500 when configuration is not loaded."""
434+
mock_config = AppConfig()
435+
mock_config._configuration = None # pylint: disable=protected-access
436+
mocker.patch("app.endpoints.mcp_servers.configuration", mock_config)
437+
mocker.patch("app.endpoints.mcp_servers.authorize", lambda _: lambda func: func)
438+
439+
with pytest.raises(HTTPException) as exc_info:
440+
await mcp_servers.list_mcp_servers_handler(
441+
request=mocker.Mock(), auth=MOCK_AUTH
442+
)
443+
444+
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
445+
raw_detail = exc_info.value.detail
446+
assert isinstance(raw_detail, dict)
447+
detail: dict[str, Any] = raw_detail
448+
assert detail["response"] == "Configuration is not loaded"
449+
450+
393451
def test_mcp_server_registration_request_validation() -> None:
394452
"""Test request model validation."""
395453
with pytest.raises(Exception):

tests/unit/app/endpoints/test_stream_interrupt.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Unit tests for streaming query interrupt endpoint."""
22

33
import asyncio
4+
import threading
45
from collections.abc import Generator
56

67
import pytest
@@ -9,13 +10,17 @@
910
from app.endpoints.stream_interrupt import stream_interrupt_endpoint_handler
1011
from models.api.requests import StreamingInterruptRequest
1112
from models.api.responses.successful import StreamingInterruptResponse
12-
from utils.stream_interrupts import StreamInterruptRegistry
13+
from utils.stream_interrupts import CancelStreamResult, StreamInterruptRegistry
1314

1415
REQUEST_ID_SUCCESS = "123e4567-e89b-12d3-a456-426614174000"
1516
REQUEST_ID_NOT_FOUND = "123e4567-e89b-12d3-a456-426614174001"
1617
REQUEST_ID_WRONG_USER = "123e4567-e89b-12d3-a456-426614174002"
1718
REQUEST_ID_ALREADY_COMPLETED = "123e4567-e89b-12d3-a456-426614174004"
1819

20+
# CI-friendly sync timeouts for concurrent registry tests.
21+
_CONCURRENT_BARRIER_TIMEOUT_S = 5.0
22+
_CONCURRENT_THREAD_JOIN_TIMEOUT_S = 5.0
23+
1924
OWNER_USER_ID = "00000001-0001-0001-0001-000000000001"
2025
NON_OWNER_USER_ID = "00000001-0001-0001-0001-000000000999"
2126

@@ -148,3 +153,75 @@ async def completed_stream() -> None:
148153
assert isinstance(response, StreamingInterruptResponse)
149154
assert response.request_id == REQUEST_ID_ALREADY_COMPLETED
150155
assert response.interrupted is False
156+
157+
158+
@pytest.mark.asyncio
159+
async def test_stream_interrupt_registry_concurrent_cancel_and_deregister(
160+
registry: StreamInterruptRegistry,
161+
) -> None:
162+
"""Concurrent cancel and deregister do not raise under the registry lock."""
163+
164+
async def pending_stream() -> None:
165+
await asyncio.sleep(10)
166+
167+
task = asyncio.create_task(pending_stream())
168+
registry.register_stream(REQUEST_ID_SUCCESS, OWNER_USER_ID, task)
169+
170+
barrier = threading.Barrier(2)
171+
errors: list[Exception] = []
172+
173+
def deregister_in_thread() -> None:
174+
try:
175+
barrier.wait(timeout=_CONCURRENT_BARRIER_TIMEOUT_S)
176+
registry.deregister_stream(REQUEST_ID_SUCCESS)
177+
except Exception as exc: # pylint: disable=broad-exception-caught
178+
errors.append(exc)
179+
180+
thread = threading.Thread(target=deregister_in_thread)
181+
thread.start()
182+
barrier.wait(timeout=_CONCURRENT_BARRIER_TIMEOUT_S)
183+
184+
result = registry.cancel_stream(REQUEST_ID_SUCCESS, OWNER_USER_ID)
185+
thread.join(timeout=_CONCURRENT_THREAD_JOIN_TIMEOUT_S)
186+
assert not thread.is_alive(), "Deregister thread did not complete within timeout"
187+
188+
assert not errors
189+
assert result in (
190+
CancelStreamResult.CANCELLED,
191+
CancelStreamResult.NOT_FOUND,
192+
)
193+
194+
if not task.done():
195+
task.cancel()
196+
with pytest.raises(asyncio.CancelledError):
197+
await task
198+
199+
200+
@pytest.mark.asyncio
201+
async def test_stream_interrupt_endpoint_double_interrupt(
202+
registry: StreamInterruptRegistry,
203+
) -> None:
204+
"""Second interrupt on the same stream returns interrupted=False."""
205+
206+
async def pending_stream() -> None:
207+
await asyncio.sleep(10)
208+
209+
task = asyncio.create_task(pending_stream())
210+
registry.register_stream(REQUEST_ID_SUCCESS, OWNER_USER_ID, task)
211+
212+
first_response = await stream_interrupt_endpoint_handler(
213+
interrupt_request=StreamingInterruptRequest(request_id=REQUEST_ID_SUCCESS),
214+
auth=(OWNER_USER_ID, "mock_username", False, "mock_token"),
215+
registry=registry,
216+
)
217+
assert first_response.interrupted is True
218+
219+
with pytest.raises(asyncio.CancelledError):
220+
await task
221+
222+
second_response = await stream_interrupt_endpoint_handler(
223+
interrupt_request=StreamingInterruptRequest(request_id=REQUEST_ID_SUCCESS),
224+
auth=(OWNER_USER_ID, "mock_username", False, "mock_token"),
225+
registry=registry,
226+
)
227+
assert second_response.interrupted is False

0 commit comments

Comments
 (0)