Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion tests/integration/endpoints/test_stream_interrupt_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
from collections.abc import Generator

import pytest
from fastapi import HTTPException
from fastapi import HTTPException, status
from fastapi.testclient import TestClient

from app.endpoints.stream_interrupt import stream_interrupt_endpoint_handler
from configuration import AppConfig
from models.api.requests import StreamingInterruptRequest
from utils.stream_interrupts import StreamInterruptRegistry

TEST_REQUEST_ID = "123e4567-e89b-12d3-a456-426614174003"
REQUEST_ID_NOT_IN_REGISTRY = "00000000-0000-0000-0000-000000000000"
OWNER_USER_ID = "00000001-0001-0001-0001-000000000001"


Expand All @@ -25,9 +28,12 @@ def registry_fixture() -> Generator[StreamInterruptRegistry, None, None]:

@pytest.mark.asyncio
async def test_stream_interrupt_full_round_trip(
test_config: AppConfig,
registry: StreamInterruptRegistry,
) -> None:
"""Full lifecycle: register, interrupt, then verify deregistration."""
# test_config loads configuration so @authorize on the handler can resolve.
_ = test_config

async def pending_stream() -> None:
await asyncio.sleep(10)
Expand Down Expand Up @@ -64,3 +70,15 @@ async def pending_stream() -> None:
registry=registry,
)
assert exc_info.value.status_code == 404


def test_stream_interrupt_nonexistent_request_returns_404(
integration_http_client: TestClient,
) -> None:
"""POST /v1/streaming_query/interrupt for unknown stream returns 404."""
response = integration_http_client.post(
"/v1/streaming_query/interrupt",
json={"request_id": REQUEST_ID_NOT_IN_REGISTRY},
)

assert response.status_code == status.HTTP_404_NOT_FOUND
62 changes: 60 additions & 2 deletions tests/unit/app/endpoints/test_mcp_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from typing import Any

import pytest
from fastapi import HTTPException
from llama_stack_client import APIConnectionError
from fastapi import HTTPException, status
from llama_stack_client import APIConnectionError, NotFoundError
from pydantic import AnyHttpUrl, SecretStr
from pytest_mock import MockerFixture

Expand Down Expand Up @@ -390,6 +390,64 @@ async def test_delete_mcp_server_llama_stack_failure(
assert exc_info.value.status_code == 503


@pytest.mark.asyncio
async def test_delete_mcp_server_toolgroup_not_found_is_idempotent(
mocker: MockerFixture,
mock_configuration: Configuration,
) -> None:
"""Deleting a dynamic server succeeds when Llama Stack toolgroup is already gone."""
app_config = _make_app_config(mocker, mock_configuration)
client = _mock_client(mocker)
client.toolgroups.register.return_value = None
client.toolgroups.unregister.side_effect = NotFoundError(
message="Toolgroup not found",
response=mocker.Mock(request=None),
body=None,
)

body = MCPServerRegistrationRequest(
name="orphan-server",
url="http://localhost:7777/mcp",
provider_id="MCP provider ID",
)
await mcp_servers.register_mcp_server_handler(
request=mocker.Mock(), body=body, auth=MOCK_AUTH
)
assert app_config.is_dynamic_mcp_server("orphan-server")

result = await mcp_servers.delete_mcp_server_handler(
request=mocker.Mock(), name="orphan-server", auth=MOCK_AUTH
)

assert isinstance(result, MCPServerDeleteResponse)
assert result.name == "orphan-server"
assert result.deleted is True
assert not app_config.is_dynamic_mcp_server("orphan-server")
assert not any(s.name == "orphan-server" for s in app_config.mcp_servers)


@pytest.mark.asyncio
async def test_list_mcp_servers_configuration_not_loaded(
mocker: MockerFixture,
) -> None:
"""Test listing MCP servers returns 500 when configuration is not loaded."""
mock_config = AppConfig()
mock_config._configuration = None # pylint: disable=protected-access
mocker.patch("app.endpoints.mcp_servers.configuration", mock_config)
mocker.patch("app.endpoints.mcp_servers.authorize", lambda _: lambda func: func)

with pytest.raises(HTTPException) as exc_info:
await mcp_servers.list_mcp_servers_handler(
request=mocker.Mock(), auth=MOCK_AUTH
)

assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
raw_detail = exc_info.value.detail
assert isinstance(raw_detail, dict)
detail: dict[str, Any] = raw_detail
assert detail["response"] == "Configuration is not loaded"


def test_mcp_server_registration_request_validation() -> None:
"""Test request model validation."""
with pytest.raises(Exception):
Expand Down
79 changes: 78 additions & 1 deletion tests/unit/app/endpoints/test_stream_interrupt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Unit tests for streaming query interrupt endpoint."""

import asyncio
import threading
from collections.abc import Generator

import pytest
Expand All @@ -9,13 +10,17 @@
from app.endpoints.stream_interrupt import stream_interrupt_endpoint_handler
from models.api.requests import StreamingInterruptRequest
from models.api.responses.successful import StreamingInterruptResponse
from utils.stream_interrupts import StreamInterruptRegistry
from utils.stream_interrupts import CancelStreamResult, StreamInterruptRegistry

REQUEST_ID_SUCCESS = "123e4567-e89b-12d3-a456-426614174000"
REQUEST_ID_NOT_FOUND = "123e4567-e89b-12d3-a456-426614174001"
REQUEST_ID_WRONG_USER = "123e4567-e89b-12d3-a456-426614174002"
REQUEST_ID_ALREADY_COMPLETED = "123e4567-e89b-12d3-a456-426614174004"

# CI-friendly sync timeouts for concurrent registry tests.
_CONCURRENT_BARRIER_TIMEOUT_S = 5.0
_CONCURRENT_THREAD_JOIN_TIMEOUT_S = 5.0

OWNER_USER_ID = "00000001-0001-0001-0001-000000000001"
NON_OWNER_USER_ID = "00000001-0001-0001-0001-000000000999"

Expand Down Expand Up @@ -148,3 +153,75 @@ async def completed_stream() -> None:
assert isinstance(response, StreamingInterruptResponse)
assert response.request_id == REQUEST_ID_ALREADY_COMPLETED
assert response.interrupted is False


@pytest.mark.asyncio
async def test_stream_interrupt_registry_concurrent_cancel_and_deregister(
registry: StreamInterruptRegistry,
) -> None:
"""Concurrent cancel and deregister do not raise under the registry lock."""

async def pending_stream() -> None:
await asyncio.sleep(10)

task = asyncio.create_task(pending_stream())
registry.register_stream(REQUEST_ID_SUCCESS, OWNER_USER_ID, task)

barrier = threading.Barrier(2)
errors: list[Exception] = []

def deregister_in_thread() -> None:
try:
barrier.wait(timeout=_CONCURRENT_BARRIER_TIMEOUT_S)
registry.deregister_stream(REQUEST_ID_SUCCESS)
except Exception as exc: # pylint: disable=broad-exception-caught
errors.append(exc)

thread = threading.Thread(target=deregister_in_thread)
thread.start()
barrier.wait(timeout=_CONCURRENT_BARRIER_TIMEOUT_S)

result = registry.cancel_stream(REQUEST_ID_SUCCESS, OWNER_USER_ID)
thread.join(timeout=_CONCURRENT_THREAD_JOIN_TIMEOUT_S)
assert not thread.is_alive(), "Deregister thread did not complete within timeout"

assert not errors
assert result in (
CancelStreamResult.CANCELLED,
CancelStreamResult.NOT_FOUND,
)

if not task.done():
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task


@pytest.mark.asyncio
async def test_stream_interrupt_endpoint_double_interrupt(
registry: StreamInterruptRegistry,
) -> None:
"""Second interrupt on the same stream returns interrupted=False."""

async def pending_stream() -> None:
await asyncio.sleep(10)

task = asyncio.create_task(pending_stream())
registry.register_stream(REQUEST_ID_SUCCESS, OWNER_USER_ID, task)

first_response = await stream_interrupt_endpoint_handler(
interrupt_request=StreamingInterruptRequest(request_id=REQUEST_ID_SUCCESS),
auth=(OWNER_USER_ID, "mock_username", False, "mock_token"),
registry=registry,
)
assert first_response.interrupted is True

with pytest.raises(asyncio.CancelledError):
await task

second_response = await stream_interrupt_endpoint_handler(
interrupt_request=StreamingInterruptRequest(request_id=REQUEST_ID_SUCCESS),
auth=(OWNER_USER_ID, "mock_username", False, "mock_token"),
registry=registry,
)
assert second_response.interrupted is False
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# pylint: disable=protected-access

import json
from collections.abc import AsyncGenerator
from typing import Any

import httpx
Expand Down Expand Up @@ -45,7 +46,7 @@ class TestAsyncByteStream:
async def test_iterates_chunks(self) -> None:
"""Test that _AsyncByteStream yields all chunks from the wrapped generator."""

async def gen():
async def gen() -> AsyncGenerator[bytes, None]:
yield b"chunk1"
yield b"chunk2"
yield b"chunk3"
Expand All @@ -59,7 +60,7 @@ async def gen():
async def test_empty_generator(self) -> None:
"""Test that _AsyncByteStream handles an empty generator gracefully."""

async def gen():
async def gen() -> AsyncGenerator[bytes, None]:
return
yield # pragma: no cover

Expand Down Expand Up @@ -134,7 +135,7 @@ async def test_streaming_request(
content=json.dumps(body).encode("utf-8"),
)

async def mock_stream_result():
async def mock_stream_result() -> AsyncGenerator[dict[str, int], None]:
yield {"chunk": 1}
yield {"chunk": 2}

Expand Down Expand Up @@ -317,7 +318,7 @@ async def test_produces_sse_format(
) -> None:
"""Test that streaming responses produce SSE-formatted byte chunks."""

async def mock_stream():
async def mock_stream() -> AsyncGenerator[dict[str, str], None]:
yield {"delta": "hello"}
yield {"delta": "world"}

Expand Down
Loading