Skip to content

Commit 6f0cb03

Browse files
authored
Merge pull request #1872 from Jdubrick/update-shields-tests
RHIDP-14000: update unit tests for shields
2 parents f81b649 + 123f130 commit 6f0cb03

1 file changed

Lines changed: 122 additions & 8 deletions

File tree

tests/unit/app/endpoints/test_shields.py

Lines changed: 122 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ async def test_shields_endpoint_handler_configuration_not_loaded(
3838

3939
with pytest.raises(HTTPException) as e:
4040
await shields_endpoint_handler(request=request, auth=auth)
41-
assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
42-
assert e.value.detail["response"] == "Configuration is not loaded" # type: ignore
41+
assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
42+
assert e.value.detail["response"] == "Configuration is not loaded" # type: ignore
4343

4444

4545
@pytest.mark.asyncio
@@ -235,7 +235,8 @@ async def test_shields_endpoint_handler_unable_to_retrieve_shields_list(
235235
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
236236

237237
response = await shields_endpoint_handler(request=request, auth=auth)
238-
assert response is not None
238+
assert isinstance(response, ShieldsResponse)
239+
assert response.shields == []
239240

240241

241242
@pytest.mark.asyncio
@@ -249,8 +250,8 @@ async def test_shields_endpoint_llama_stack_connection_error(
249250
250251
Simulates the Llama Stack client raising an APIConnectionError and asserts
251252
that calling the endpoint raises an HTTPException with status 503, a detail
252-
response of "Service unavailable", and a detail cause that contains "Unable
253-
to connect to Llama Stack".
253+
response of "Unable to connect to Llama Stack", and a detail cause that
254+
contains "Connection error".
254255
"""
255256
mock_authorization_resolvers(mocker)
256257

@@ -290,6 +291,8 @@ async def test_shields_endpoint_llama_stack_connection_error(
290291
cfg = AppConfig()
291292
cfg.init_from_dict(config_dict)
292293

294+
mocker.patch("app.endpoints.shields.configuration", cfg)
295+
293296
request = Request(
294297
scope={
295298
"type": "http",
@@ -302,9 +305,9 @@ async def test_shields_endpoint_llama_stack_connection_error(
302305

303306
with pytest.raises(HTTPException) as e:
304307
await shields_endpoint_handler(request=request, auth=auth)
305-
assert e.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
306-
assert e.value.detail["response"] == "Service unavailable" # type: ignore
307-
assert "Unable to connect to Llama Stack" in e.value.detail["cause"] # type: ignore
308+
assert e.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
309+
assert e.value.detail["response"] == "Unable to connect to Llama Stack" # type: ignore
310+
assert "Connection error" in e.value.detail["cause"] # type: ignore
308311

309312

310313
@pytest.mark.asyncio
@@ -382,3 +385,114 @@ async def test_shields_endpoint_handler_success_with_shields_data(
382385
assert len(response.shields) == 2
383386
assert response.shields[0]["identifier"] == "lightspeed_question_validity-shield"
384387
assert response.shields[1]["identifier"] == "content_filter-shield"
388+
389+
390+
@pytest.mark.asyncio
391+
async def test_shields_endpoint_handler_unexpected_exception(
392+
mocker: MockerFixture,
393+
) -> None:
394+
"""Test the shields endpoint when an unexpected exception is raised."""
395+
mock_authorization_resolvers(mocker)
396+
397+
config_dict: dict[str, Any] = {
398+
"name": "foo",
399+
"service": {
400+
"host": "localhost",
401+
"port": 8080,
402+
"auth_enabled": False,
403+
"workers": 1,
404+
"color_log": True,
405+
"access_log": True,
406+
},
407+
"llama_stack": {
408+
"api_key": "xyzzy",
409+
"url": "http://x.y.com:1234",
410+
"use_as_library_client": False,
411+
},
412+
"user_data_collection": {
413+
"feedback_enabled": False,
414+
},
415+
"customization": None,
416+
"authorization": {"access_rules": []},
417+
"authentication": {"module": "noop"},
418+
}
419+
cfg = AppConfig()
420+
cfg.init_from_dict(config_dict)
421+
422+
mock_client = mocker.AsyncMock()
423+
mock_client.shields.list.side_effect = RuntimeError("unexpected failure")
424+
mock_client_holder = mocker.patch(
425+
"app.endpoints.shields.AsyncLlamaStackClientHolder"
426+
)
427+
mock_client_holder.return_value.get_client.return_value = mock_client
428+
429+
request = Request(
430+
scope={
431+
"type": "http",
432+
"headers": [(b"authorization", b"Bearer invalid-token")],
433+
}
434+
)
435+
436+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
437+
438+
with pytest.raises(RuntimeError, match="unexpected failure"):
439+
await shields_endpoint_handler(request=request, auth=auth)
440+
441+
442+
@pytest.mark.asyncio
443+
async def test_shields_endpoint_handler_malformed_shield_objects(
444+
mocker: MockerFixture,
445+
) -> None:
446+
"""Test the shields endpoint handles shields that may have missing fields."""
447+
mock_authorization_resolvers(mocker)
448+
449+
config_dict: dict[str, Any] = {
450+
"name": "foo",
451+
"service": {
452+
"host": "localhost",
453+
"port": 8080,
454+
"auth_enabled": False,
455+
"workers": 1,
456+
"color_log": True,
457+
"access_log": True,
458+
},
459+
"llama_stack": {
460+
"api_key": "xyzzy",
461+
"url": "http://x.y.com:1234",
462+
"use_as_library_client": False,
463+
},
464+
"user_data_collection": {
465+
"feedback_enabled": False,
466+
},
467+
"customization": None,
468+
"authorization": {"access_rules": []},
469+
"authentication": {"module": "noop"},
470+
}
471+
cfg = AppConfig()
472+
cfg.init_from_dict(config_dict)
473+
474+
mock_shield_minimal = {
475+
"identifier": "minimal-shield",
476+
}
477+
478+
mock_client = mocker.AsyncMock()
479+
mock_client.shields.list.return_value = [mock_shield_minimal]
480+
mock_client_holder = mocker.patch(
481+
"app.endpoints.shields.AsyncLlamaStackClientHolder"
482+
)
483+
mock_client_holder.return_value.get_client.return_value = mock_client
484+
485+
request = Request(
486+
scope={
487+
"type": "http",
488+
"headers": [(b"authorization", b"Bearer invalid-token")],
489+
}
490+
)
491+
492+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
493+
494+
response = await shields_endpoint_handler(request=request, auth=auth)
495+
496+
assert isinstance(response, ShieldsResponse)
497+
assert len(response.shields) == 1
498+
assert response.shields[0]["identifier"] == "minimal-shield"

0 commit comments

Comments
 (0)