Skip to content

Commit 1a0dd23

Browse files
committed
test: update unit tests for shields
Signed-off-by: Jordan Dubrick <jdubrick@redhat.com>
1 parent d464fbc commit 1a0dd23

1 file changed

Lines changed: 118 additions & 6 deletions

File tree

tests/unit/app/endpoints/test_shields.py

Lines changed: 118 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ async def test_shields_endpoint_handler_configuration_not_loaded(
3838

3939
with pytest.raises(HTTPException) as e:
4040
await shields_endpoint_handler(request=request, auth=auth)
41-
assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
42-
assert e.value.detail["response"] == "Configuration is not loaded" # type: ignore
41+
assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
42+
assert e.value.detail["response"] == "Configuration is not loaded" # type: ignore
4343

4444

4545
@pytest.mark.asyncio
@@ -235,7 +235,8 @@ async def test_shields_endpoint_handler_unable_to_retrieve_shields_list(
235235
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
236236

237237
response = await shields_endpoint_handler(request=request, auth=auth)
238-
assert response is not None
238+
assert isinstance(response, ShieldsResponse)
239+
assert response.shields == []
239240

240241

241242
@pytest.mark.asyncio
@@ -302,9 +303,9 @@ async def test_shields_endpoint_llama_stack_connection_error(
302303

303304
with pytest.raises(HTTPException) as e:
304305
await shields_endpoint_handler(request=request, auth=auth)
305-
assert e.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
306-
assert e.value.detail["response"] == "Service unavailable" # type: ignore
307-
assert "Unable to connect to Llama Stack" in e.value.detail["cause"] # type: ignore
306+
assert e.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
307+
assert e.value.detail["response"] == "Unable to connect to Llama Stack" # type: ignore
308+
assert "Connection error" in e.value.detail["cause"] # type: ignore
308309

309310

310311
@pytest.mark.asyncio
@@ -382,3 +383,114 @@ async def test_shields_endpoint_handler_success_with_shields_data(
382383
assert len(response.shields) == 2
383384
assert response.shields[0]["identifier"] == "lightspeed_question_validity-shield"
384385
assert response.shields[1]["identifier"] == "content_filter-shield"
386+
387+
388+
@pytest.mark.asyncio
389+
async def test_shields_endpoint_handler_unexpected_exception(
390+
mocker: MockerFixture,
391+
) -> None:
392+
"""Test the shields endpoint when an unexpected exception is raised."""
393+
mock_authorization_resolvers(mocker)
394+
395+
config_dict: dict[str, Any] = {
396+
"name": "foo",
397+
"service": {
398+
"host": "localhost",
399+
"port": 8080,
400+
"auth_enabled": False,
401+
"workers": 1,
402+
"color_log": True,
403+
"access_log": True,
404+
},
405+
"llama_stack": {
406+
"api_key": "xyzzy",
407+
"url": "http://x.y.com:1234",
408+
"use_as_library_client": False,
409+
},
410+
"user_data_collection": {
411+
"feedback_enabled": False,
412+
},
413+
"customization": None,
414+
"authorization": {"access_rules": []},
415+
"authentication": {"module": "noop"},
416+
}
417+
cfg = AppConfig()
418+
cfg.init_from_dict(config_dict)
419+
420+
mock_client = mocker.AsyncMock()
421+
mock_client.shields.list.side_effect = RuntimeError("unexpected failure")
422+
mock_client_holder = mocker.patch(
423+
"app.endpoints.shields.AsyncLlamaStackClientHolder"
424+
)
425+
mock_client_holder.return_value.get_client.return_value = mock_client
426+
427+
request = Request(
428+
scope={
429+
"type": "http",
430+
"headers": [(b"authorization", b"Bearer invalid-token")],
431+
}
432+
)
433+
434+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
435+
436+
with pytest.raises(RuntimeError, match="unexpected failure"):
437+
await shields_endpoint_handler(request=request, auth=auth)
438+
439+
440+
@pytest.mark.asyncio
441+
async def test_shields_endpoint_handler_malformed_shield_objects(
442+
mocker: MockerFixture,
443+
) -> None:
444+
"""Test the shields endpoint handles shields that may have missing fields."""
445+
mock_authorization_resolvers(mocker)
446+
447+
config_dict: dict[str, Any] = {
448+
"name": "foo",
449+
"service": {
450+
"host": "localhost",
451+
"port": 8080,
452+
"auth_enabled": False,
453+
"workers": 1,
454+
"color_log": True,
455+
"access_log": True,
456+
},
457+
"llama_stack": {
458+
"api_key": "xyzzy",
459+
"url": "http://x.y.com:1234",
460+
"use_as_library_client": False,
461+
},
462+
"user_data_collection": {
463+
"feedback_enabled": False,
464+
},
465+
"customization": None,
466+
"authorization": {"access_rules": []},
467+
"authentication": {"module": "noop"},
468+
}
469+
cfg = AppConfig()
470+
cfg.init_from_dict(config_dict)
471+
472+
mock_shield_minimal = {
473+
"identifier": "minimal-shield",
474+
}
475+
476+
mock_client = mocker.AsyncMock()
477+
mock_client.shields.list.return_value = [mock_shield_minimal]
478+
mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")
479+
mock_lsc.return_value = mock_client
480+
mock_config = mocker.Mock()
481+
mocker.patch("app.endpoints.shields.configuration", mock_config)
482+
483+
request = Request(
484+
scope={
485+
"type": "http",
486+
"headers": [(b"authorization", b"Bearer invalid-token")],
487+
}
488+
)
489+
490+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
491+
492+
response = await shields_endpoint_handler(request=request, auth=auth)
493+
494+
assert isinstance(response, ShieldsResponse)
495+
assert len(response.shields) == 1
496+
assert response.shields[0]["identifier"] == "minimal-shield"

0 commit comments

Comments
 (0)