66from typing import Any
77
88import pytest
9- from fastapi import HTTPException , Request
9+ from fastapi import HTTPException , Request , status
1010from fastapi .responses import StreamingResponse
1111from llama_stack_api .openai_responses import (
1212 OpenAIResponseObject ,
@@ -257,20 +257,14 @@ class TestOLSCompatibilityIntegration:
257257
258258 def test_media_type_validation (self ) -> None :
259259 """Test that media type validation works correctly."""
260- valid_request = QueryRequest (
261- query = "test" , media_type = "application/json"
262- ) # pyright: ignore[reportCallIssue]
260+ valid_request = QueryRequest (query = "test" , media_type = "application/json" ) # pyright: ignore[reportCallIssue]
263261 assert valid_request .media_type == "application/json"
264262
265- valid_request = QueryRequest (
266- query = "test" , media_type = "text/plain"
267- ) # pyright: ignore[reportCallIssue]
263+ valid_request = QueryRequest (query = "test" , media_type = "text/plain" ) # pyright: ignore[reportCallIssue]
268264 assert valid_request .media_type == "text/plain"
269265
270266 with pytest .raises (ValueError , match = "media_type must be either" ):
271- QueryRequest (
272- query = "test" , media_type = "invalid/type"
273- ) # pyright: ignore[reportCallIssue]
267+ QueryRequest (query = "test" , media_type = "invalid/type" ) # pyright: ignore[reportCallIssue]
274268
275269 def test_ols_end_event_structure (self ) -> None :
276270 """Test that end event follows OLS structure."""
@@ -322,9 +316,7 @@ async def test_successful_streaming_query(
322316 mocker : MockerFixture ,
323317 ) -> None :
324318 """Test successful streaming query."""
325- query_request = QueryRequest (
326- query = "What is Kubernetes?"
327- ) # pyright: ignore[reportCallIssue]
319+ query_request = QueryRequest (query = "What is Kubernetes?" ) # pyright: ignore[reportCallIssue]
328320
329321 mocker .patch ("app.endpoints.streaming_query.configuration" , setup_configuration )
330322 mocker .patch ("app.endpoints.streaming_query.check_configuration_loaded" )
@@ -574,9 +566,7 @@ async def test_streaming_query_azure_token_refresh(
574566 mocker : MockerFixture ,
575567 ) -> None :
576568 """Test streaming query refreshes Azure token when needed."""
577- query_request = QueryRequest (
578- query = "What is Kubernetes?"
579- ) # pyright: ignore[reportCallIssue]
569+ query_request = QueryRequest (query = "What is Kubernetes?" ) # pyright: ignore[reportCallIssue]
580570
581571 mocker .patch ("app.endpoints.streaming_query.configuration" , setup_configuration )
582572 mocker .patch ("app.endpoints.streaming_query.check_configuration_loaded" )
@@ -679,9 +669,7 @@ async def test_retrieve_response_generator_success(
679669
680670 mock_context = mocker .Mock (spec = ResponseGeneratorContext )
681671 mock_context .client = mock_client
682- mock_context .query_request = QueryRequest (
683- query = "test"
684- ) # pyright: ignore[reportCallIssue]
672+ mock_context .query_request = QueryRequest (query = "test" ) # pyright: ignore[reportCallIssue]
685673
686674 async def mock_response_gen () -> AsyncIterator [str ]:
687675 yield "test"
@@ -769,9 +757,7 @@ async def test_retrieve_response_generator_connection_error(
769757
770758 mock_context = mocker .Mock (spec = ResponseGeneratorContext )
771759 mock_context .client = mock_client
772- mock_context .query_request = QueryRequest (
773- query = "test"
774- ) # pyright: ignore[reportCallIssue]
760+ mock_context .query_request = QueryRequest (query = "test" ) # pyright: ignore[reportCallIssue]
775761
776762 mocker .patch (
777763 "app.endpoints.streaming_query.run_shield_moderation" ,
@@ -822,9 +808,7 @@ async def test_retrieve_response_generator_api_status_error(
822808
823809 mock_context = mocker .Mock (spec = ResponseGeneratorContext )
824810 mock_context .client = mock_client
825- mock_context .query_request = QueryRequest (
826- query = "test"
827- ) # pyright: ignore[reportCallIssue]
811+ mock_context .query_request = QueryRequest (query = "test" ) # pyright: ignore[reportCallIssue]
828812
829813 mocker .patch (
830814 "app.endpoints.streaming_query.run_shield_moderation" ,
@@ -872,9 +856,7 @@ async def test_retrieve_response_generator_runtime_error_context_length(
872856
873857 mock_context = mocker .Mock (spec = ResponseGeneratorContext )
874858 mock_context .client = mock_client
875- mock_context .query_request = QueryRequest (
876- query = "test"
877- ) # pyright: ignore[reportCallIssue]
859+ mock_context .query_request = QueryRequest (query = "test" ) # pyright: ignore[reportCallIssue]
878860
879861 mocker .patch (
880862 "app.endpoints.streaming_query.run_shield_moderation" ,
@@ -919,9 +901,7 @@ async def test_retrieve_response_generator_runtime_error_other(
919901
920902 mock_context = mocker .Mock (spec = ResponseGeneratorContext )
921903 mock_context .client = mock_client
922- mock_context .query_request = QueryRequest (
923- query = "test"
924- ) # pyright: ignore[reportCallIssue]
904+ mock_context .query_request = QueryRequest (query = "test" ) # pyright: ignore[reportCallIssue]
925905
926906 mocker .patch (
927907 "app.endpoints.streaming_query.run_shield_moderation" ,
@@ -932,8 +912,9 @@ async def test_retrieve_response_generator_runtime_error_other(
932912 side_effect = RuntimeError ("Some other error" )
933913 )
934914
935- with pytest .raises (RuntimeError ) :
915+ with pytest .raises (HTTPException ) as exc_info :
936916 await retrieve_response_generator (mock_responses_params , mock_context )
917+ assert exc_info .value .status_code == status .HTTP_500_INTERNAL_SERVER_ERROR
937918
938919
939920class TestGenerateResponse :
@@ -950,9 +931,7 @@ async def mock_generator() -> AsyncIterator[str]:
950931 mock_context = mocker .Mock (spec = ResponseGeneratorContext )
951932 mock_context .conversation_id = "conv_123"
952933 mock_context .user_id = "user_123"
953- mock_context .query_request = QueryRequest (
954- query = "test"
955- ) # pyright: ignore[reportCallIssue]
934+ mock_context .query_request = QueryRequest (query = "test" ) # pyright: ignore[reportCallIssue]
956935 mock_context .started_at = "2024-01-01T00:00:00Z"
957936 mock_context .skip_userid_check = False
958937
@@ -1047,9 +1026,7 @@ async def mock_generator() -> AsyncIterator[str]:
10471026
10481027 mock_context = mocker .Mock (spec = ResponseGeneratorContext )
10491028 mock_context .conversation_id = "conv_123"
1050- mock_context .query_request = QueryRequest (
1051- query = "test"
1052- ) # pyright: ignore[reportCallIssue]
1029+ mock_context .query_request = QueryRequest (query = "test" ) # pyright: ignore[reportCallIssue]
10531030 mock_context .started_at = "2024-01-01T00:00:00Z"
10541031 mock_context .skip_userid_check = False
10551032
@@ -1082,9 +1059,7 @@ async def mock_generator() -> AsyncIterator[str]:
10821059
10831060 mock_context = mocker .Mock (spec = ResponseGeneratorContext )
10841061 mock_context .conversation_id = "conv_123"
1085- mock_context .query_request = QueryRequest (
1086- query = "test"
1087- ) # pyright: ignore[reportCallIssue]
1062+ mock_context .query_request = QueryRequest (query = "test" ) # pyright: ignore[reportCallIssue]
10881063 mock_context .started_at = "2024-01-01T00:00:00Z"
10891064 mock_context .skip_userid_check = False
10901065
0 commit comments